This is an automated email from the ASF dual-hosted git repository.
tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 5b8743b9366 [BEAM-36736] Add state sampling for timer processing in
the Python SDK (#36737)
5b8743b9366 is described below
commit 5b8743b9366ddc6e0300aa6e5ce6c2fbe1e4274f
Author: Karthik Talluri <[email protected]>
AuthorDate: Tue Nov 25 15:39:18 2025 -0800
[BEAM-36736] Add state sampling for timer processing in the Python SDK
(#36737)
* [BEAM-36736] Add state sampling for timer processing
* Force CI to rebuild
* Fix error with no state found
* Fix error for Regex test
* Resolve linting error
* Add test case to test full functionality
* Fix suffix issue
* Fix formatting issues using tox -e yapf-check
* Add test cases to test code paths
* Address comments and remove extra test case
* Remove user state context variable
* Adjust state duration for test to avoid flakiness
* Add different tests, remove no op scoped state, and address
formatting/lint issues
* Add patch to deal with CI presubmit errors
* Adjust test case to not use dofn_runner
* Test case failing presubmits, attempting to fix
* Fix mocking for tests and ensure all pass
* Remove extra test and increase retries on the process timer tests to
avoid flakiness
* Remove upper bound restriction and reduce retries
* Remove unused suffix param.
---------
Co-authored-by: tvalentyn <[email protected]>
---
.../apache_beam/runners/worker/operations.pxd | 1 +
.../apache_beam/runners/worker/operations.py | 22 ++-
.../runners/worker/statesampler_test.py | 185 +++++++++++++++++++++
3 files changed, 199 insertions(+), 9 deletions(-)
diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd
index f24b75a720e..52211e4d8ce 100644
--- a/sdks/python/apache_beam/runners/worker/operations.pxd
+++ b/sdks/python/apache_beam/runners/worker/operations.pxd
@@ -117,6 +117,7 @@ cdef class DoOperation(Operation):
cdef dict timer_specs
cdef public object input_info
cdef object fn
+ cdef object scoped_timer_processing_state
cdef class SdfProcessSizedElements(DoOperation):
diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py
index 9f490e4ae44..d0f7cceb558 100644
--- a/sdks/python/apache_beam/runners/worker/operations.py
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -809,7 +809,10 @@ class DoOperation(Operation):
self.tagged_receivers = None # type: Optional[_TaggedReceivers]
# A mapping of timer tags to the input "PCollections" they come in on.
self.input_info = None # type: Optional[OpInputInfo]
-
+ self.scoped_timer_processing_state = self.state_sampler.scoped_state(
+ self.name_context,
+ 'process-timers',
+ metrics_container=self.metrics_container)
# See fn_data in dataflow_runner.py
# TODO: Store all the items from spec?
self.fn, _, _, _, _ = (pickler.loads(self.spec.serialized_fn))
@@ -971,14 +974,15 @@ class DoOperation(Operation):
self.user_state_context.add_timer_info(timer_family_id, timer_info)
def process_timer(self, tag, timer_data):
- timer_spec = self.timer_specs[tag]
- self.dofn_runner.process_user_timer(
- timer_spec,
- timer_data.user_key,
- timer_data.windows[0],
- timer_data.fire_timestamp,
- timer_data.paneinfo,
- timer_data.dynamic_timer_tag)
+ with self.scoped_timer_processing_state:
+ timer_spec = self.timer_specs[tag]
+ self.dofn_runner.process_user_timer(
+ timer_spec,
+ timer_data.user_key,
+ timer_data.windows[0],
+ timer_data.fire_timestamp,
+ timer_data.paneinfo,
+ timer_data.dynamic_timer_tag)
def finish(self):
# type: () -> None
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_test.py b/sdks/python/apache_beam/runners/worker/statesampler_test.py
index c9ea7e8eef9..0d0ce1d2c8d 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_test.py
+++ b/sdks/python/apache_beam/runners/worker/statesampler_test.py
@@ -21,17 +21,56 @@
import logging
import time
import unittest
+from unittest import mock
+from unittest.mock import Mock
+from unittest.mock import patch
from tenacity import retry
from tenacity import stop_after_attempt
+from apache_beam.internal import pickler
+from apache_beam.runners import common
+from apache_beam.runners.worker import operation_specs
+from apache_beam.runners.worker import operations
from apache_beam.runners.worker import statesampler
+from apache_beam.transforms import core
+from apache_beam.transforms import userstate
+from apache_beam.transforms.core import GlobalWindows
+from apache_beam.transforms.core import Windowing
+from apache_beam.transforms.window import GlobalWindow
from apache_beam.utils.counters import CounterFactory
from apache_beam.utils.counters import CounterName
+from apache_beam.utils.windowed_value import PaneInfo
_LOGGER = logging.getLogger(__name__)
+class TimerDoFn(core.DoFn):
+ TIMER_SPEC = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
+
+ def __init__(self, sleep_duration_s=0):
+ self._sleep_duration_s = sleep_duration_s
+
+ @userstate.on_timer(TIMER_SPEC)
+ def on_timer_f(self):
+ if self._sleep_duration_s:
+ time.sleep(self._sleep_duration_s)
+
+
+class ExceptionTimerDoFn(core.DoFn):
+ """A DoFn that raises an exception when its timer fires."""
+ TIMER_SPEC = userstate.TimerSpec('ts-timer', userstate.TimeDomain.WATERMARK)
+
+ def __init__(self, sleep_duration_s=0):
+ self._sleep_duration_s = sleep_duration_s
+
+ @userstate.on_timer(TIMER_SPEC)
+ def on_timer_f(self):
+ if self._sleep_duration_s:
+ time.sleep(self._sleep_duration_s)
+ raise RuntimeError("Test exception from timer")
+
+
class StateSamplerTest(unittest.TestCase):
# Due to somewhat non-deterministic nature of state sampling and sleep,
@@ -127,6 +166,152 @@ class StateSamplerTest(unittest.TestCase):
# debug mode).
self.assertLess(overhead_us, 20.0)
+ @retry(reraise=True, stop=stop_after_attempt(3))
+ # Patch the problematic function to return the correct timer spec
+ @patch('apache_beam.transforms.userstate.get_dofn_specs')
+ def test_do_operation_process_timer(self, mock_get_dofn_specs):
+ fn = TimerDoFn()
+ mock_get_dofn_specs.return_value = ([], [fn.TIMER_SPEC])
+
+ if not statesampler.FAST_SAMPLER:
+ self.skipTest('DoOperation test requires FAST_SAMPLER')
+
+ state_duration_ms = 200
+ margin_of_error = 0.75
+
+ counter_factory = CounterFactory()
+ sampler = statesampler.StateSampler(
+ 'test_do_op', counter_factory, sampling_period_ms=1)
+
+ fn_for_spec = TimerDoFn(sleep_duration_s=state_duration_ms / 1000.0)
+
+ spec = operation_specs.WorkerDoFn(
+ serialized_fn=pickler.dumps(
+ (fn_for_spec, [], {}, [], Windowing(GlobalWindows()))),
+ output_tags=[],
+ input=None,
+ side_inputs=[],
+ output_coders=[])
+
+ mock_user_state_context = mock.MagicMock()
+ op = operations.DoOperation(
+ common.NameContext('step1'),
+ spec,
+ counter_factory,
+ sampler,
+ user_state_context=mock_user_state_context)
+
+ op.setup()
+
+ timer_data = Mock()
+ timer_data.user_key = None
+ timer_data.windows = [GlobalWindow()]
+ timer_data.fire_timestamp = 0
+ timer_data.paneinfo = PaneInfo(
+ is_first=False,
+ is_last=False,
+ timing=0,
+ index=0,
+ nonspeculative_index=0)
+ timer_data.dynamic_timer_tag = ''
+
+ sampler.start()
+ op.process_timer('ts-timer', timer_data=timer_data)
+ sampler.stop()
+ sampler.commit_counters()
+
+ expected_name = CounterName(
+ 'process-timers-msecs', step_name='step1', stage_name='test_do_op')
+
+ found_counter = None
+ for counter in counter_factory.get_counters():
+ if counter.name == expected_name:
+ found_counter = counter
+ break
+
+ self.assertIsNotNone(
+ found_counter, f"Expected counter '{expected_name}' to be created.")
+
+ actual_value = found_counter.value()
+ logging.info("Actual value %d", actual_value)
+ self.assertGreater(
+ actual_value, state_duration_ms * (1.0 - margin_of_error))
+
+ @retry(reraise=True, stop=stop_after_attempt(3))
+ @patch('apache_beam.runners.worker.operations.userstate.get_dofn_specs')
+ def test_do_operation_process_timer_with_exception(self, mock_get_dofn_specs):
+ fn = ExceptionTimerDoFn()
+ mock_get_dofn_specs.return_value = ([], [fn.TIMER_SPEC])
+
+ if not statesampler.FAST_SAMPLER:
+ self.skipTest('DoOperation test requires FAST_SAMPLER')
+
+ state_duration_ms = 200
+ margin_of_error = 0.50
+
+ counter_factory = CounterFactory()
+ sampler = statesampler.StateSampler(
+ 'test_do_op_exception', counter_factory, sampling_period_ms=1)
+
+ fn_for_spec = ExceptionTimerDoFn(
+ sleep_duration_s=state_duration_ms / 1000.0)
+
+ spec = operation_specs.WorkerDoFn(
+ serialized_fn=pickler.dumps(
+ (fn_for_spec, [], {}, [], Windowing(GlobalWindows()))),
+ output_tags=[],
+ input=None,
+ side_inputs=[],
+ output_coders=[])
+
+ mock_user_state_context = mock.MagicMock()
+ op = operations.DoOperation(
+ common.NameContext('step1'),
+ spec,
+ counter_factory,
+ sampler,
+ user_state_context=mock_user_state_context)
+
+ op.setup()
+
+ timer_data = Mock()
+ timer_data.user_key = None
+ timer_data.windows = [GlobalWindow()]
+ timer_data.fire_timestamp = 0
+ timer_data.paneinfo = PaneInfo(
+ is_first=False,
+ is_last=False,
+ timing=0,
+ index=0,
+ nonspeculative_index=0)
+ timer_data.dynamic_timer_tag = ''
+
+ sampler.start()
+ # Assert that the expected exception is raised
+ with self.assertRaises(RuntimeError):
+ op.process_timer('ts-ts-timer', timer_data=timer_data)
+ sampler.stop()
+ sampler.commit_counters()
+
+ expected_name = CounterName(
+ 'process-timers-msecs',
+ step_name='step1',
+ stage_name='test_do_op_exception')
+
+ found_counter = None
+ for counter in counter_factory.get_counters():
+ if counter.name == expected_name:
+ found_counter = counter
+ break
+
+ self.assertIsNotNone(
+ found_counter, f"Expected counter '{expected_name}' to be created.")
+
+ actual_value = found_counter.value()
+ self.assertGreater(
+ actual_value, state_duration_ms * (1.0 - margin_of_error))
+ _LOGGER.info("Exception test finished successfully.")
+
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)