3535from apache_beam .typehints .typehints import TupleConstraint
3636from apache_beam .utils .timestamp import MAX_TIMESTAMP
3737from apache_beam .utils .timestamp import MIN_TIMESTAMP
38+ from apache_beam .utils .timestamp import Duration
3839from apache_beam .utils .timestamp import DurationTypes # pylint: disable=unused-import
3940from apache_beam .utils .timestamp import Timestamp
4041from apache_beam .utils .timestamp import TimestampTypes # pylint: disable=unused-import
@@ -89,7 +90,7 @@ def __init__(
8990 self ,
9091 duration : DurationTypes ,
9192 slide_interval : DurationTypes ,
92- offset : TimestampTypes ,
93+ offset : DurationTypes ,
9394 allowed_lateness : DurationTypes ,
9495 default_start_value ,
9596 fill_start_if_missing : bool ,
@@ -200,11 +201,23 @@ def process(
200201
201202 timer_started = timer_state .read ()
202203 if not timer_started :
204+ offset_duration = Duration .of (self .offset )
205+ slide_duration = Duration .of (self .slide_interval )
206+ duration_duration = Duration .of (self .duration )
207+
208+ # Align the timestamp with the windowing scheme.
209+ aligned_micros = (timestamp - offset_duration ).micros
210+
211+ # Calculate the start of the last window that could contain this timestamp
212+ last_window_start_aligned_micros = (
213+ (aligned_micros // slide_duration .micros ) * slide_duration .micros )
214+
215+ last_window_start = Timestamp (
216+ micros = last_window_start_aligned_micros ) + offset_duration
217+ n = (duration_duration .micros - 1 ) // slide_duration .micros
203218 # Calculate the start of the first sliding window.
204- first_slide_start = int (
205- (timestamp .micros / 1e6 - self .offset ) //
206- self .slide_interval ) * self .slide_interval + self .offset
207- first_slide_start_ts = Timestamp .of (first_slide_start )
219+ first_slide_start_ts = last_window_start - Duration (
220+ micros = n * slide_duration .micros )
208221
209222 # Set the initial timer to fire at the end of the first window plus
210223 # allowed lateness.
@@ -256,14 +269,16 @@ def _get_windowed_values_from_state(
256269 if not windowed_values :
257270 # If the window is empty, use the last value.
258271 last_value = last_value_state .read ()
259- windowed_values .append (last_value )
272+ value_to_insert = (window_start_ts , last_value [1 ])
273+ windowed_values .append (value_to_insert )
260274 else :
261275 first_timestamp = windowed_values [0 ][0 ]
262276 last_value = last_value_state .read ()
263277 if first_timestamp > window_start_ts and last_value :
264278 # Prepend the last value if there's a gap between the first element
265279 # in the window and the start of the window.
266- windowed_values = [last_value ] + windowed_values
280+ value_to_insert = (window_start_ts , last_value [1 ])
281+ windowed_values = [value_to_insert ] + windowed_values
267282
268283 # Find the last element before the beginning of the next window to update
269284 # last_value_state.
@@ -334,8 +349,7 @@ def on_timer(
334349 windowed_values = self ._get_windowed_values_from_state (
335350 buffer_state , late_start_ts , late_end_ts , last_value_state )
336351 yield TimestampedValue (
337- ((key , late_start_ts , late_end_ts ), [v [1 ]
338- for v in windowed_values ]),
352+ (key , ((late_start_ts , late_end_ts ), windowed_values )),
339353 late_end_ts - 1 )
340354 late_start_ts += self .slide_interval
341355
@@ -347,8 +361,7 @@ def on_timer(
347361 windowed_values = self ._get_windowed_values_from_state (
348362 buffer_state , window_start_ts , window_end_ts , last_value_state )
349363 yield TimestampedValue (
350- ((key , window_start_ts , window_end_ts ), [v [1 ]
351- for v in windowed_values ]),
364+ (key , ((window_start_ts , window_end_ts ), windowed_values )),
352365 window_end_ts - 1 )
353366
354367 # Post-emit actions for the current window:
@@ -532,7 +545,7 @@ def __init__(
532545 self ,
533546 duration : DurationTypes ,
534547 slide_interval : Optional [DurationTypes ] = None ,
535- offset : TimestampTypes = 0 ,
548+ offset : DurationTypes = 0 ,
536549 allowed_lateness : DurationTypes = 0 ,
537550 default_start_value = None ,
538551 fill_start_if_missing : bool = False ,
@@ -617,7 +630,7 @@ def expand(self, input):
617630 self .stop_timestamp )))
618631
619632 if isinstance (input .element_type , TupleConstraint ):
620- ret = keyed_output | beam . MapTuple ( lambda x , y : ( x [ 0 ], y ))
633+ ret = keyed_output
621634 else :
622635 # Remove the default key if the input PCollection was originally unkeyed.
623636 ret = keyed_output | beam .Values ()
0 commit comments