This is an automated email from the ASF dual-hosted git repository.

chamikara pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new a0831e0d4b1 Add AftersynchronizedProcessing Time as continuation trigger (#36285) a0831e0d4b1 is described below commit a0831e0d4b1a36f8dd1d9c16ef388c02c6620e1a Author: Tarun Annapareddy <tannapare...@google.com> AuthorDate: Wed Oct 1 18:44:45 2025 -0700 Add AftersynchronizedProcessing Time as continuation trigger (#36285) * Add AftersynchronizedProcessing Time as continuation trigger * fix trailing space * fix trailing space * fix formatting --- sdks/python/apache_beam/transforms/core.py | 12 +++ .../apache_beam/transforms/ptransform_test.py | 19 ++++ sdks/python/apache_beam/transforms/trigger.py | 119 ++++++++++++++++++++- sdks/python/apache_beam/transforms/trigger_test.py | 50 +++++++++ 4 files changed, 196 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 2304faf478f..cbd78d8222e 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -3341,6 +3341,18 @@ class GroupByKey(PTransform): return typehints.KV[ key_type, typehints.WindowedValue[value_type]] # type: ignore[misc] + def get_windowing(self, inputs): + # Switch to the continuation trigger associated with the current trigger. 
+ windowing = inputs[0].windowing + triggerfn = windowing.triggerfn.get_continuation_trigger() + return Windowing( + windowfn=windowing.windowfn, + triggerfn=triggerfn, + accumulation_mode=windowing.accumulation_mode, + timestamp_combiner=windowing.timestamp_combiner, + allowed_lateness=windowing.allowed_lateness, + environment_id=windowing.environment_id) + def expand(self, pcoll): from apache_beam.transforms.trigger import DataLossReason from apache_beam.transforms.trigger import DefaultTrigger diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index 3df33bcd8be..ea736dceddb 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -47,6 +47,7 @@ from apache_beam.io.iobase import Read from apache_beam.metrics import Metrics from apache_beam.metrics.metric import MetricsFilter from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import StandardOptions from apache_beam.options.pipeline_options import StreamingOptions from apache_beam.options.pipeline_options import TypeOptions from apache_beam.portability import common_urns @@ -61,6 +62,9 @@ from apache_beam.transforms import window from apache_beam.transforms.display import DisplayData from apache_beam.transforms.display import DisplayDataItem from apache_beam.transforms.ptransform import PTransform +from apache_beam.transforms.trigger import AccumulationMode +from apache_beam.transforms.trigger import AfterProcessingTime +from apache_beam.transforms.trigger import _AfterSynchronizedProcessingTime from apache_beam.transforms.window import TimestampedValue from apache_beam.typehints import with_input_types from apache_beam.typehints import with_output_types @@ -510,6 +514,21 @@ class PTransformTest(unittest.TestCase): with TestPipeline(options=test_options) as pipeline: pipeline | TestStream() | beam.GroupByKey() + def 
test_group_by_key_trigger(self): + options = PipelineOptions(['--allow_unsafe_triggers']) + options.view_as(StandardOptions).streaming = True + with TestPipeline(runner='BundleBasedDirectRunner', + options=options) as pipeline: + pcoll = pipeline | 'Start' >> beam.Create([(0, 0)]) + triggered = pcoll | 'Trigger' >> beam.WindowInto( + window.GlobalWindows(), + trigger=AfterProcessingTime(1), + accumulation_mode=AccumulationMode.DISCARDING) + output = triggered | 'Gbk' >> beam.GroupByKey() + self.assertTrue( + isinstance( + output.windowing.triggerfn, _AfterSynchronizedProcessingTime)) + def test_group_by_key_unsafe_trigger(self): test_options = PipelineOptions() test_options.view_as(TypeOptions).allow_unsafe_triggers = False diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py index 7d573a58e3f..cc9922dd158 100644 --- a/sdks/python/apache_beam/transforms/trigger.py +++ b/sdks/python/apache_beam/transforms/trigger.py @@ -304,7 +304,7 @@ class TriggerFn(metaclass=ABCMeta): 'after_each': AfterEach, 'after_end_of_window': AfterWatermark, 'after_processing_time': AfterProcessingTime, - # after_processing_time, after_synchronized_processing_time + 'after_synchronized_processing_time': _AfterSynchronizedProcessingTime, 'always': Always, 'default': DefaultTrigger, 'element_count': AfterCount, @@ -317,6 +317,17 @@ class TriggerFn(metaclass=ABCMeta): def to_runner_api(self, unused_context): pass + @abstractmethod + def get_continuation_trigger(self): + """Returns: + Trigger to use after a GroupBy to preserve the intention of this + trigger. Specifically, triggers that are time based and intended + to provide speculative results should continue providing speculative + results. Triggers that fire once (or multiple times) should + continue firing once (or multiple times). 
+ """ + pass + class DefaultTrigger(TriggerFn): """Semantically Repeatedly(AfterWatermark()), but more optimized.""" @@ -366,6 +377,9 @@ class DefaultTrigger(TriggerFn): def has_ontime_pane(self): return True + def get_continuation_trigger(self): + return self + class AfterProcessingTime(TriggerFn): """Fire exactly once after a specified delay from processing time.""" @@ -421,6 +435,11 @@ class AfterProcessingTime(TriggerFn): def has_ontime_pane(self): return False + def get_continuation_trigger(self): + # The continuation of an AfterProcessingTime trigger is an + # _AfterSynchronizedProcessingTime trigger. + return _AfterSynchronizedProcessingTime() + class Always(TriggerFn): """Repeatedly invoke the given trigger, never finishing.""" @@ -466,6 +485,9 @@ class Always(TriggerFn): return beam_runner_api_pb2.Trigger( always=beam_runner_api_pb2.Trigger.Always()) + def get_continuation_trigger(self): + return self + class _Never(TriggerFn): """A trigger that never fires. @@ -518,6 +540,9 @@ class _Never(TriggerFn): return beam_runner_api_pb2.Trigger( never=beam_runner_api_pb2.Trigger.Never()) + def get_continuation_trigger(self): + return self + class AfterWatermark(TriggerFn): """Fire exactly once when the watermark passes the end of the window. 
@@ -531,9 +556,19 @@ class AfterWatermark(TriggerFn): LATE_TAG = _CombiningValueStateTag('is_late', any) def __init__(self, early=None, late=None): - # TODO(zhoufek): Maybe don't wrap early/late if they are already Repeatedly - self.early = Repeatedly(early) if early else None - self.late = Repeatedly(late) if late else None + self.early = self._wrap_if_not_repeatedly(early) + self.late = self._wrap_if_not_repeatedly(late) + + @staticmethod + def _wrap_if_not_repeatedly(trigger): + if trigger and not isinstance(trigger, Repeatedly): + return Repeatedly(trigger) + return trigger + + def get_continuation_trigger(self): + return AfterWatermark( + self.early.get_continuation_trigger() if self.early else None, + self.late.get_continuation_trigger() if self.late else None) def __repr__(self): qualifiers = [] @@ -692,6 +727,9 @@ class AfterCount(TriggerFn): def has_ontime_pane(self): return False + def get_continuation_trigger(self): + return AfterCount(1) + class Repeatedly(TriggerFn): """Repeatedly invoke the given trigger, never finishing.""" @@ -741,6 +779,9 @@ class Repeatedly(TriggerFn): def has_ontime_pane(self): return self.underlying.has_ontime_pane() + def get_continuation_trigger(self): + return Repeatedly(self.underlying.get_continuation_trigger()) + class _ParallelTriggerFn(TriggerFn, metaclass=ABCMeta): def __init__(self, *triggers): @@ -831,6 +872,12 @@ class _ParallelTriggerFn(TriggerFn, metaclass=ABCMeta): def has_ontime_pane(self): return any(t.has_ontime_pane() for t in self.triggers) + def get_continuation_trigger(self): + return self.__class__( + *( + subtrigger.get_continuation_trigger() + for subtrigger in self.triggers)) + class AfterAny(_ParallelTriggerFn): """Fires when any subtrigger fires. 
@@ -933,6 +980,13 @@ class AfterEach(TriggerFn): def has_ontime_pane(self): return any(t.has_ontime_pane() for t in self.triggers) + def get_continuation_trigger(self): + return Repeatedly( + AfterAny( + *( + subtrigger.get_continuation_trigger() + for subtrigger in self.triggers))) + class OrFinally(AfterAny): @staticmethod @@ -1643,3 +1697,60 @@ class InMemoryUnmergedState(UnmergedState): state_str = '\n'.join( '%s: %s' % (key, dict(state)) for key, state in self.state.items()) return 'timers: %s\nstate: %s' % (dict(self.timers), state_str) + + +class _AfterSynchronizedProcessingTime(TriggerFn): + """A "runner's-discretion" trigger downstream of a GroupByKey + with AfterProcessingTime trigger. + + In runners that directly execute this + Python code, the trigger currently always fires, + but this behavior is neither guaranteed nor + required by runners, regardless of whether they + execute triggers via Python. + + _AfterSynchronizedProcessingTime is experimental + and internal-only. No backwards compatibility + guarantees. + """ + def __init__(self): + pass + + def __repr__(self): + return '_AfterSynchronizedProcessingTime()' + + def __eq__(self, other): + return type(self) == type(other) + + def __hash__(self): + return hash(type(self)) + + def on_element(self, _element, _window, _context): + pass + + def on_merge(self, _to_be_merged, _merge_result, _context): + pass + + def should_fire(self, _time_domain, _timestamp, _window, _context): + return True + + def on_fire(self, _timestamp, _window, _context): + return False + + def reset(self, _window, _context): + pass + + @staticmethod + def from_runner_api(_proto, _context): + return _AfterSynchronizedProcessingTime() + + def to_runner_api(self, _context): + return beam_runner_api_pb2.Trigger( + after_synchronized_processing_time=beam_runner_api_pb2.Trigger. 
+ AfterSynchronizedProcessingTime()) + + def has_ontime_pane(self): + return False + + def get_continuation_trigger(self): + return self diff --git a/sdks/python/apache_beam/transforms/trigger_test.py b/sdks/python/apache_beam/transforms/trigger_test.py index b9a8cdc594b..9f9b7fe51a9 100644 --- a/sdks/python/apache_beam/transforms/trigger_test.py +++ b/sdks/python/apache_beam/transforms/trigger_test.py @@ -554,6 +554,56 @@ class RunnerApiTest(unittest.TestCase): TriggerFn.from_runner_api(trigger_fn.to_runner_api(context), context)) +class ContinuationTriggerTest(unittest.TestCase): + def test_after_all(self): + self.assertEqual( + AfterAll(AfterCount(2), AfterCount(5)).get_continuation_trigger(), + AfterAll(AfterCount(1), AfterCount(1))) + + def test_after_any(self): + self.assertEqual( + AfterAny(AfterCount(2), AfterCount(5)).get_continuation_trigger(), + AfterAny(AfterCount(1), AfterCount(1))) + + def test_after_count(self): + self.assertEqual(AfterCount(1).get_continuation_trigger(), AfterCount(1)) + self.assertEqual(AfterCount(100).get_continuation_trigger(), AfterCount(1)) + + def test_after_each(self): + self.assertEqual( + AfterEach(AfterCount(2), AfterCount(5)).get_continuation_trigger(), + Repeatedly(AfterAny(AfterCount(1), AfterCount(1)))) + + def test_after_processing_time(self): + from apache_beam.transforms.trigger import _AfterSynchronizedProcessingTime + self.assertEqual( + AfterProcessingTime(10).get_continuation_trigger(), + _AfterSynchronizedProcessingTime()) + + def test_after_watermark(self): + self.assertEqual( + AfterWatermark().get_continuation_trigger(), AfterWatermark()) + self.assertEqual( + AfterWatermark(early=AfterCount(10), + late=AfterCount(20)).get_continuation_trigger(), + AfterWatermark(early=AfterCount(1), late=AfterCount(1))) + + def test_always(self): + self.assertEqual(Always().get_continuation_trigger(), Always()) + + def test_default(self): + self.assertEqual( + DefaultTrigger().get_continuation_trigger(), DefaultTrigger()) + 
+ def test_never(self): + self.assertEqual(_Never().get_continuation_trigger(), _Never()) + + def test_repeatedly(self): + self.assertEqual( + Repeatedly(AfterCount(10)).get_continuation_trigger(), + Repeatedly(AfterCount(1))) + + class TriggerPipelineTest(unittest.TestCase): def test_after_processing_time(self): test_options = PipelineOptions(