diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd
index f24b75a720e0..52211e4d8ce8 100644
--- a/sdks/python/apache_beam/runners/worker/operations.pxd
+++ b/sdks/python/apache_beam/runners/worker/operations.pxd
@@ -117,6 +117,7 @@ cdef class DoOperation(Operation):
   cdef dict timer_specs
   cdef public object input_info
   cdef object fn
+  cdef object scoped_timer_processing_state
 
 
 cdef class SdfProcessSizedElements(DoOperation):
diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py
index 2b20bebe0940..7668564b6eb3 100644
--- a/sdks/python/apache_beam/runners/worker/operations.py
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -809,6 +809,11 @@ def __init__(
     self.tagged_receivers = None  # type: Optional[_TaggedReceivers]
     # A mapping of timer tags to the input "PCollections" they come in on.
     self.input_info = None  # type: Optional[OpInputInfo]
+    # Scoped state entered while user timers fire; see process_timer().
+    self.scoped_timer_processing_state = self.state_sampler.scoped_state(
+        self.name_context,
+        'process-timers',
+        metrics_container=self.metrics_container)
 
     # See fn_data in dataflow_runner.py
     # TODO: Store all the items from spec?
@@ -971,14 +976,20 @@ def add_timer_info(self, timer_family_id, timer_info):
     self.user_state_context.add_timer_info(timer_family_id, timer_info)
 
   def process_timer(self, tag, timer_data):
-    timer_spec = self.timer_specs[tag]
-    self.dofn_runner.process_user_timer(
-        timer_spec,
-        timer_data.user_key,
-        timer_data.windows[0],
-        timer_data.fire_timestamp,
-        timer_data.paneinfo,
-        timer_data.dynamic_timer_tag)
+    """Fires the user timer registered under ``tag``.
+
+    The call runs inside the 'process-timers' scoped state so that time
+    spent in user timer callbacks is sampled separately.
+    """
+    with self.scoped_timer_processing_state:
+      timer_spec = self.timer_specs[tag]
+      self.dofn_runner.process_user_timer(
+          timer_spec,
+          timer_data.user_key,
+          timer_data.windows[0],
+          timer_data.fire_timestamp,
+          timer_data.paneinfo,
+          timer_data.dynamic_timer_tag)
 
   def finish(self):
     # type: () -> None
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_test.py b/sdks/python/apache_beam/runners/worker/statesampler_test.py
index c9ea7e8eef97..0d0ce1d2c8dc 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_test.py
+++ b/sdks/python/apache_beam/runners/worker/statesampler_test.py
@@ -21,17 +21,55 @@
 import logging
 import time
 import unittest
+from unittest import mock
 
 from tenacity import retry
 from tenacity import stop_after_attempt
 
+from apache_beam.internal import pickler
+from apache_beam.runners import common
+from apache_beam.runners.worker import operation_specs
+from apache_beam.runners.worker import operations
 from apache_beam.runners.worker import statesampler
+from apache_beam.transforms import core
+from apache_beam.transforms import userstate
+from apache_beam.transforms.core import GlobalWindows
+from apache_beam.transforms.core import Windowing
+from apache_beam.transforms.window import GlobalWindow
 from apache_beam.utils.counters import CounterFactory
 from apache_beam.utils.counters import CounterName
+from apache_beam.utils.windowed_value import PaneInfo
 
 _LOGGER = logging.getLogger(__name__)
 
 
+class TimerDoFn(core.DoFn):
+  """A DoFn whose timer callback optionally sleeps before returning."""
+  TIMER_SPEC = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
+
+  def __init__(self, sleep_duration_s=0):
+    self._sleep_duration_s = sleep_duration_s
+
+  @userstate.on_timer(TIMER_SPEC)
+  def on_timer_f(self):
+    if self._sleep_duration_s:
+      time.sleep(self._sleep_duration_s)
+
+
+class ExceptionTimerDoFn(core.DoFn):
+  """A DoFn that raises an exception when its timer fires."""
+  TIMER_SPEC = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
+
+  def __init__(self, sleep_duration_s=0):
+    self._sleep_duration_s = sleep_duration_s
+
+  @userstate.on_timer(TIMER_SPEC)
+  def on_timer_f(self):
+    if self._sleep_duration_s:
+      time.sleep(self._sleep_duration_s)
+    raise RuntimeError("Test exception from timer")
+
+
 class StateSamplerTest(unittest.TestCase):
 
   # Due to somewhat non-deterministic nature of state sampling and sleep,
@@ -127,6 +165,116 @@ def test_sampler_transition_overhead(self):
     # debug mode).
     self.assertLess(overhead_us, 20.0)
 
+  def _sample_process_timer(
+      self, dofn_cls, stage_name, sleep_ms, expected_exception=None):
+    """Fires one timer on a DoOperation and returns the sampled msecs.
+
+    Builds a DoOperation around an instance of ``dofn_cls`` whose timer
+    callback sleeps for ``sleep_ms``, fires that timer once while the state
+    sampler is running, and returns the committed value of the
+    'process-timers-msecs' counter for the step.
+
+    Args:
+      dofn_cls: DoFn class exposing a ``TIMER_SPEC`` class attribute and
+        accepting ``sleep_duration_s`` in its constructor.
+      stage_name: name passed to the StateSampler; also the stage name of
+        the counter that is looked up.
+      sleep_ms: how long the timer callback sleeps, in milliseconds.
+      expected_exception: exception type that firing the timer must raise,
+        or None if it must return normally.
+    """
+    counter_factory = CounterFactory()
+    sampler = statesampler.StateSampler(
+        stage_name, counter_factory, sampling_period_ms=1)
+
+    fn = dofn_cls(sleep_duration_s=sleep_ms / 1000.0)
+    spec = operation_specs.WorkerDoFn(
+        serialized_fn=pickler.dumps(
+            (fn, [], {}, [], Windowing(GlobalWindows()))),
+        output_tags=[],
+        input=None,
+        side_inputs=[],
+        output_coders=[])
+    op = operations.DoOperation(
+        common.NameContext('step1'),
+        spec,
+        counter_factory,
+        sampler,
+        user_state_context=mock.MagicMock())
+    op.setup()
+
+    timer_data = mock.Mock()
+    timer_data.user_key = None
+    timer_data.windows = [GlobalWindow()]
+    timer_data.fire_timestamp = 0
+    timer_data.paneinfo = PaneInfo(
+        is_first=False,
+        is_last=False,
+        timing=0,
+        index=0,
+        nonspeculative_index=0)
+    timer_data.dynamic_timer_tag = ''
+    timer_tag = dofn_cls.TIMER_SPEC.name
+
+    sampler.start()
+    try:
+      if expected_exception is None:
+        op.process_timer(timer_tag, timer_data=timer_data)
+      else:
+        with self.assertRaises(expected_exception):
+          op.process_timer(timer_tag, timer_data=timer_data)
+    finally:
+      # Stop the sampler even when the call above fails, so that a retried
+      # attempt does not start while this sampler is still running.
+      sampler.stop()
+    sampler.commit_counters()
+
+    expected_name = CounterName(
+        'process-timers-msecs', step_name='step1', stage_name=stage_name)
+    found_counter = next(
+        (c for c in counter_factory.get_counters() if c.name == expected_name),
+        None)
+    self.assertIsNotNone(
+        found_counter, f"Expected counter '{expected_name}' to be created.")
+    return found_counter.value()
+
+  # get_dofn_specs is patched so that setup() sees the DoFn's timer spec.
+  @retry(reraise=True, stop=stop_after_attempt(3))
+  @mock.patch('apache_beam.transforms.userstate.get_dofn_specs')
+  def test_do_operation_process_timer(self, mock_get_dofn_specs):
+    if not statesampler.FAST_SAMPLER:
+      self.skipTest('DoOperation test requires FAST_SAMPLER')
+    mock_get_dofn_specs.return_value = ([], [TimerDoFn.TIMER_SPEC])
+
+    state_duration_ms = 200
+    margin_of_error = 0.75
+
+    actual_value = self._sample_process_timer(
+        TimerDoFn, 'test_do_op', state_duration_ms)
+    _LOGGER.info("Actual value %d", actual_value)
+    self.assertGreater(
+        actual_value, state_duration_ms * (1.0 - margin_of_error))
+
+  @retry(reraise=True, stop=stop_after_attempt(3))
+  @mock.patch('apache_beam.transforms.userstate.get_dofn_specs')
+  def test_do_operation_process_timer_with_exception(self, mock_get_dofn_specs):
+    if not statesampler.FAST_SAMPLER:
+      self.skipTest('DoOperation test requires FAST_SAMPLER')
+    mock_get_dofn_specs.return_value = ([], [ExceptionTimerDoFn.TIMER_SPEC])
+
+    state_duration_ms = 200
+    margin_of_error = 0.50
+
+    # Time slept before the exception is raised must still be sampled.
+    actual_value = self._sample_process_timer(
+        ExceptionTimerDoFn,
+        'test_do_op_exception',
+        state_duration_ms,
+        expected_exception=RuntimeError)
+    self.assertGreater(
+        actual_value, state_duration_ms * (1.0 - margin_of_error))
+    _LOGGER.info("Exception test finished successfully.")
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)