11"""Mixin for event detection algorithms that work similar to Rampp et al."""
22
3+ import warnings
34from typing import Any , Callable , Optional , Union
45
56import numpy as np
67import pandas as pd
78from joblib import Memory
89from numpy .linalg import norm
9- from typing_extensions import Self
10+ from typing_extensions import Literal , Self
1011
1112from gaitmap .utils ._algo_helper import invert_result_dictionary , set_params_from_dict
1213from gaitmap .utils ._types import _Hashable
2223)
2324from gaitmap .utils .exceptions import ValidationError
2425from gaitmap .utils .stride_list_conversion import (
25- _segmented_stride_list_to_min_vel_single_sensor ,
26+ _stride_list_to_min_vel_single_sensor ,
2627 enforce_stride_list_consistency ,
2728)
2829
@@ -33,21 +34,24 @@ class _EventDetectionMixin:
3334 detect_only : Optional [tuple [str , ...]]
3435
3536 min_vel_event_list_ : Optional [Union [pd .DataFrame , dict [str , pd .DataFrame ]]]
36- segmented_event_list_ : Optional [Union [pd .DataFrame , dict [str , pd .DataFrame ]]]
37+ annotated_original_event_list_ : Optional [Union [pd .DataFrame , dict [str , pd .DataFrame ]]]
3738
3839 data : SensorData
3940 sampling_rate_hz : float
4041 stride_list : pd .DataFrame
42+ input_stride_type : Literal ["segmented" , "ic" ]
4143
4244 def __init__ (
4345 self ,
4446 memory : Optional [Memory ] = None ,
4547 enforce_consistency : bool = True ,
4648 detect_only : Optional [tuple [str , ...]] = None ,
47- ) -> None :
49+ input_stride_type : Literal ["segmented" , "ic" ] = "segmented" ,
50+ ):
4851 self .memory = memory
4952 self .enforce_consistency = enforce_consistency
5053 self .detect_only = detect_only
54+ self .input_stride_type = input_stride_type
5155
5256 def detect (self , data : SensorData , stride_list : StrideList , * , sampling_rate_hz : float ) -> Self :
5357 """Find gait events in data within strides provided by stride_list.
@@ -121,50 +125,62 @@ def _detect_single_dataset(
121125 # find events in all segments
122126 event_detection_func = self ._select_all_event_detection_method ()
123127 event_detection_func = memory .cache (event_detection_func )
124- ic , tc , min_vel = event_detection_func (gyr , acc , stride_list , events = events , ** detect_kwargs )
128+ ic , tc , min_vel = event_detection_func (
129+ gyr , acc , stride_list , events = events , input_stride_type = self .input_stride_type , ** detect_kwargs
130+ )
125131
126132 # build first dict / df based on segment start and end
127- segmented_event_list = {
133+ annotated_original_event_list = {
128134 "s_id" : stride_list .index ,
129135 "start" : stride_list ["start" ],
130136 "end" : stride_list ["end" ],
131137 }
132138 for event , event_list in zip (("ic" , "tc" , "min_vel" ), (ic , tc , min_vel )):
133139 if event in events :
134- segmented_event_list [event ] = event_list
135-
136- segmented_event_list = pd .DataFrame (segmented_event_list ).set_index ("s_id" )
137-
140+ annotated_original_event_list [event ] = event_list
141+ annotated_original_event_list = pd .DataFrame (annotated_original_event_list ).set_index ("s_id" )
138142 if self .enforce_consistency :
139143 # check for consistency, remove inconsistent strides
140- segmented_event_list , _ = enforce_stride_list_consistency (
141- segmented_event_list , stride_type = "segmented" , check_stride_list = False
144+ annotated_original_event_list , _ = enforce_stride_list_consistency (
145+ annotated_original_event_list , input_stride_type = self . input_stride_type , check_stride_list = False
142146 )
143147
144148 if "min_vel" not in events or self .enforce_consistency is False :
145149 # do not set min_vel_event_list_ if consistency is not enforced as it would be completely scrambled
146150 # and can not be used for anything anyway
147- return {"segmented_event_list " : segmented_event_list }
151+ return {"annotated_original_event_list " : annotated_original_event_list }
148152
149153 # convert to min_vel event list
150- min_vel_event_list , _ = _segmented_stride_list_to_min_vel_single_sensor (
151- segmented_event_list , target_stride_type = "min_vel"
154+ min_vel_event_list , _ = _stride_list_to_min_vel_single_sensor (
155+ annotated_original_event_list , source_stride_type = self . input_stride_type , target_stride_type = "min_vel"
152156 )
153157
154158 output_order = [c for c in ["start" , "end" , "ic" , "tc" , "min_vel" , "pre_ic" ] if c in min_vel_event_list .columns ]
155159
156160 # We enforce consistency again here, as a valid segmented stride list does not necessarily result in a valid
157161 # min_vel stride list
158162 min_vel_event_list , _ = enforce_stride_list_consistency (
159- min_vel_event_list [output_order ], stride_type = "min_vel" , check_stride_list = False
163+ min_vel_event_list [output_order ], input_stride_type = "min_vel" , check_stride_list = False
160164 )
161165
162- return {"min_vel_event_list" : min_vel_event_list , "segmented_event_list" : segmented_event_list }
166+ return {
167+ "min_vel_event_list" : min_vel_event_list ,
168+ "annotated_original_event_list" : annotated_original_event_list ,
169+ }
170+
171+ @property
172+ def segmented_event_list_ (self ) -> Optional [Union [pd .DataFrame , dict [str , pd .DataFrame ]]]:
173+ warnings .warn (
174+ "`segmented_event_list_` is deprecated and will be removed in a future version. "
175+ "Use `annotated_original_event_list_` instead." ,
176+ DeprecationWarning ,
177+ )
178+ return self .annotated_original_event_list_
163179
164180 def _select_all_event_detection_method (self ) -> Callable :
165181 """Select the function to calculate the all events.
166182
167- This is separate method to make it easy to overwrite by a subclass.
183+ This is a separate method to make it easy to overwrite by a subclass.
168184 """
169185 raise NotImplementedError ()
170186
0 commit comments