This is an automated email from the ASF dual-hosted git repository.

chamikara pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new a0831e0d4b1 Add AftersynchronizedProcessing Time as continuation trigger (#36285)
a0831e0d4b1 is described below

commit a0831e0d4b1a36f8dd1d9c16ef388c02c6620e1a
Author: Tarun Annapareddy <tannapare...@google.com>
AuthorDate: Wed Oct 1 18:44:45 2025 -0700

    Add AftersynchronizedProcessing Time as continuation trigger (#36285)
    
    * Add AftersynchronizedProcessing Time as continuation trigger
    
    * fix trailing space
    
    * fix trailing space
    
    * fix formatting
---
 sdks/python/apache_beam/transforms/core.py         |  12 +++
 .../apache_beam/transforms/ptransform_test.py      |  19 ++++
 sdks/python/apache_beam/transforms/trigger.py      | 119 ++++++++++++++++++++-
 sdks/python/apache_beam/transforms/trigger_test.py |  50 +++++++++
 4 files changed, 196 insertions(+), 4 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 2304faf478f..cbd78d8222e 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -3341,6 +3341,18 @@ class GroupByKey(PTransform):
       return typehints.KV[
           key_type, typehints.WindowedValue[value_type]]  # type: ignore[misc]
 
+  def get_windowing(self, inputs):
+    # Switch to the continuation trigger associated with the current trigger.
+    windowing = inputs[0].windowing
+    triggerfn = windowing.triggerfn.get_continuation_trigger()
+    return Windowing(
+        windowfn=windowing.windowfn,
+        triggerfn=triggerfn,
+        accumulation_mode=windowing.accumulation_mode,
+        timestamp_combiner=windowing.timestamp_combiner,
+        allowed_lateness=windowing.allowed_lateness,
+        environment_id=windowing.environment_id)
+
   def expand(self, pcoll):
     from apache_beam.transforms.trigger import DataLossReason
     from apache_beam.transforms.trigger import DefaultTrigger
diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py
index 3df33bcd8be..ea736dceddb 100644
--- a/sdks/python/apache_beam/transforms/ptransform_test.py
+++ b/sdks/python/apache_beam/transforms/ptransform_test.py
@@ -47,6 +47,7 @@ from apache_beam.io.iobase import Read
 from apache_beam.metrics import Metrics
 from apache_beam.metrics.metric import MetricsFilter
 from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.options.pipeline_options import StreamingOptions
 from apache_beam.options.pipeline_options import TypeOptions
 from apache_beam.portability import common_urns
@@ -61,6 +62,9 @@ from apache_beam.transforms import window
 from apache_beam.transforms.display import DisplayData
 from apache_beam.transforms.display import DisplayDataItem
 from apache_beam.transforms.ptransform import PTransform
+from apache_beam.transforms.trigger import AccumulationMode
+from apache_beam.transforms.trigger import AfterProcessingTime
+from apache_beam.transforms.trigger import _AfterSynchronizedProcessingTime
 from apache_beam.transforms.window import TimestampedValue
 from apache_beam.typehints import with_input_types
 from apache_beam.typehints import with_output_types
@@ -510,6 +514,21 @@ class PTransformTest(unittest.TestCase):
       with TestPipeline(options=test_options) as pipeline:
         pipeline | TestStream() | beam.GroupByKey()
 
+  def test_group_by_key_trigger(self):
+    options = PipelineOptions(['--allow_unsafe_triggers'])
+    options.view_as(StandardOptions).streaming = True
+    with TestPipeline(runner='BundleBasedDirectRunner',
+                      options=options) as pipeline:
+      pcoll = pipeline | 'Start' >> beam.Create([(0, 0)])
+      triggered = pcoll | 'Trigger' >> beam.WindowInto(
+          window.GlobalWindows(),
+          trigger=AfterProcessingTime(1),
+          accumulation_mode=AccumulationMode.DISCARDING)
+      output = triggered | 'Gbk' >> beam.GroupByKey()
+      self.assertTrue(
+          isinstance(
+              output.windowing.triggerfn, _AfterSynchronizedProcessingTime))
+
   def test_group_by_key_unsafe_trigger(self):
     test_options = PipelineOptions()
     test_options.view_as(TypeOptions).allow_unsafe_triggers = False
diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py
index 7d573a58e3f..cc9922dd158 100644
--- a/sdks/python/apache_beam/transforms/trigger.py
+++ b/sdks/python/apache_beam/transforms/trigger.py
@@ -304,7 +304,7 @@ class TriggerFn(metaclass=ABCMeta):
         'after_each': AfterEach,
         'after_end_of_window': AfterWatermark,
         'after_processing_time': AfterProcessingTime,
-        # after_processing_time, after_synchronized_processing_time
+        'after_synchronized_processing_time': _AfterSynchronizedProcessingTime,
         'always': Always,
         'default': DefaultTrigger,
         'element_count': AfterCount,
@@ -317,6 +317,17 @@ class TriggerFn(metaclass=ABCMeta):
   def to_runner_api(self, unused_context):
     pass
 
+  @abstractmethod
+  def get_continuation_trigger(self):
+    """Returns:
+        Trigger to use after a GroupBy to preserve the intention of this
+        trigger. Specifically, triggers that are time based and intended
+        to provide speculative results should continue providing speculative
+        results. Triggers that fire once (or multiple times) should
+        continue firing once (or multiple times).
+    """
+    pass
+
 
 class DefaultTrigger(TriggerFn):
   """Semantically Repeatedly(AfterWatermark()), but more optimized."""
@@ -366,6 +377,9 @@ class DefaultTrigger(TriggerFn):
   def has_ontime_pane(self):
     return True
 
+  def get_continuation_trigger(self):
+    return self
+
 
 class AfterProcessingTime(TriggerFn):
   """Fire exactly once after a specified delay from processing time."""
@@ -421,6 +435,11 @@ class AfterProcessingTime(TriggerFn):
   def has_ontime_pane(self):
     return False
 
+  def get_continuation_trigger(self):
+    # The continuation of an AfterProcessingTime trigger is an
+    # _AfterSynchronizedProcessingTime trigger.
+    return _AfterSynchronizedProcessingTime()
+
 
 class Always(TriggerFn):
   """Repeatedly invoke the given trigger, never finishing."""
@@ -466,6 +485,9 @@ class Always(TriggerFn):
     return beam_runner_api_pb2.Trigger(
         always=beam_runner_api_pb2.Trigger.Always())
 
+  def get_continuation_trigger(self):
+    return self
+
 
 class _Never(TriggerFn):
   """A trigger that never fires.
@@ -518,6 +540,9 @@ class _Never(TriggerFn):
     return beam_runner_api_pb2.Trigger(
         never=beam_runner_api_pb2.Trigger.Never())
 
+  def get_continuation_trigger(self):
+    return self
+
 
 class AfterWatermark(TriggerFn):
   """Fire exactly once when the watermark passes the end of the window.
@@ -531,9 +556,19 @@ class AfterWatermark(TriggerFn):
   LATE_TAG = _CombiningValueStateTag('is_late', any)
 
   def __init__(self, early=None, late=None):
-    # TODO(zhoufek): Maybe don't wrap early/late if they are already Repeatedly
-    self.early = Repeatedly(early) if early else None
-    self.late = Repeatedly(late) if late else None
+    self.early = self._wrap_if_not_repeatedly(early)
+    self.late = self._wrap_if_not_repeatedly(late)
+
+  @staticmethod
+  def _wrap_if_not_repeatedly(trigger):
+    if trigger and not isinstance(trigger, Repeatedly):
+      return Repeatedly(trigger)
+    return trigger
+
+  def get_continuation_trigger(self):
+    return AfterWatermark(
+        self.early.get_continuation_trigger() if self.early else None,
+        self.late.get_continuation_trigger() if self.late else None)
 
   def __repr__(self):
     qualifiers = []
@@ -692,6 +727,9 @@ class AfterCount(TriggerFn):
   def has_ontime_pane(self):
     return False
 
+  def get_continuation_trigger(self):
+    return AfterCount(1)
+
 
 class Repeatedly(TriggerFn):
   """Repeatedly invoke the given trigger, never finishing."""
@@ -741,6 +779,9 @@ class Repeatedly(TriggerFn):
   def has_ontime_pane(self):
     return self.underlying.has_ontime_pane()
 
+  def get_continuation_trigger(self):
+    return Repeatedly(self.underlying.get_continuation_trigger())
+
 
 class _ParallelTriggerFn(TriggerFn, metaclass=ABCMeta):
   def __init__(self, *triggers):
@@ -831,6 +872,12 @@ class _ParallelTriggerFn(TriggerFn, metaclass=ABCMeta):
   def has_ontime_pane(self):
     return any(t.has_ontime_pane() for t in self.triggers)
 
+  def get_continuation_trigger(self):
+    return self.__class__(
+        *(
+            subtrigger.get_continuation_trigger()
+            for subtrigger in self.triggers))
+
 
 class AfterAny(_ParallelTriggerFn):
   """Fires when any subtrigger fires.
@@ -933,6 +980,13 @@ class AfterEach(TriggerFn):
   def has_ontime_pane(self):
     return any(t.has_ontime_pane() for t in self.triggers)
 
+  def get_continuation_trigger(self):
+    return Repeatedly(
+        AfterAny(
+            *(
+                subtrigger.get_continuation_trigger()
+                for subtrigger in self.triggers)))
+
 
 class OrFinally(AfterAny):
   @staticmethod
@@ -1643,3 +1697,60 @@ class InMemoryUnmergedState(UnmergedState):
     state_str = '\n'.join(
         '%s: %s' % (key, dict(state)) for key, state in self.state.items())
     return 'timers: %s\nstate: %s' % (dict(self.timers), state_str)
+
+
+class _AfterSynchronizedProcessingTime(TriggerFn):
+  """A "runner's-discretion" trigger downstream of a GroupByKey
+  with AfterProcessingTime trigger.
+
+  In runners that directly execute this
+  Python code, the trigger currently always fires,
+  but this behavior is neither guaranteed nor
+  required by runners, regardless of whether they
+  execute triggers via Python.
+
+  _AfterSynchronizedProcessingTime is experimental
+  and internal-only. No backwards compatibility
+  guarantees.
+  """
+  def __init__(self):
+    pass
+
+  def __repr__(self):
+    return '_AfterSynchronizedProcessingTime()'
+
+  def __eq__(self, other):
+    return type(self) == type(other)
+
+  def __hash__(self):
+    return hash(type(self))
+
+  def on_element(self, _element, _window, _context):
+    pass
+
+  def on_merge(self, _to_be_merged, _merge_result, _context):
+    pass
+
+  def should_fire(self, _time_domain, _timestamp, _window, _context):
+    return True
+
+  def on_fire(self, _timestamp, _window, _context):
+    return False
+
+  def reset(self, _window, _context):
+    pass
+
+  @staticmethod
+  def from_runner_api(_proto, _context):
+    return _AfterSynchronizedProcessingTime()
+
+  def to_runner_api(self, _context):
+    return beam_runner_api_pb2.Trigger(
+        after_synchronized_processing_time=beam_runner_api_pb2.Trigger.
+        AfterSynchronizedProcessingTime())
+
+  def has_ontime_pane(self):
+    return False
+
+  def get_continuation_trigger(self):
+    return self
diff --git a/sdks/python/apache_beam/transforms/trigger_test.py b/sdks/python/apache_beam/transforms/trigger_test.py
index b9a8cdc594b..9f9b7fe51a9 100644
--- a/sdks/python/apache_beam/transforms/trigger_test.py
+++ b/sdks/python/apache_beam/transforms/trigger_test.py
@@ -554,6 +554,56 @@ class RunnerApiTest(unittest.TestCase):
           TriggerFn.from_runner_api(trigger_fn.to_runner_api(context), context))
 
 
+class ContinuationTriggerTest(unittest.TestCase):
+  def test_after_all(self):
+    self.assertEqual(
+        AfterAll(AfterCount(2), AfterCount(5)).get_continuation_trigger(),
+        AfterAll(AfterCount(1), AfterCount(1)))
+
+  def test_after_any(self):
+    self.assertEqual(
+        AfterAny(AfterCount(2), AfterCount(5)).get_continuation_trigger(),
+        AfterAny(AfterCount(1), AfterCount(1)))
+
+  def test_after_count(self):
+    self.assertEqual(AfterCount(1).get_continuation_trigger(), AfterCount(1))
+    self.assertEqual(AfterCount(100).get_continuation_trigger(), AfterCount(1))
+
+  def test_after_each(self):
+    self.assertEqual(
+        AfterEach(AfterCount(2), AfterCount(5)).get_continuation_trigger(),
+        Repeatedly(AfterAny(AfterCount(1), AfterCount(1))))
+
+  def test_after_processing_time(self):
+    from apache_beam.transforms.trigger import _AfterSynchronizedProcessingTime
+    self.assertEqual(
+        AfterProcessingTime(10).get_continuation_trigger(),
+        _AfterSynchronizedProcessingTime())
+
+  def test_after_watermark(self):
+    self.assertEqual(
+        AfterWatermark().get_continuation_trigger(), AfterWatermark())
+    self.assertEqual(
+        AfterWatermark(early=AfterCount(10),
+                       late=AfterCount(20)).get_continuation_trigger(),
+        AfterWatermark(early=AfterCount(1), late=AfterCount(1)))
+
+  def test_always(self):
+    self.assertEqual(Always().get_continuation_trigger(), Always())
+
+  def test_default(self):
+    self.assertEqual(
+        DefaultTrigger().get_continuation_trigger(), DefaultTrigger())
+
+  def test_never(self):
+    self.assertEqual(_Never().get_continuation_trigger(), _Never())
+
+  def test_repeatedly(self):
+    self.assertEqual(
+        Repeatedly(AfterCount(10)).get_continuation_trigger(),
+        Repeatedly(AfterCount(1)))
+
+
 class TriggerPipelineTest(unittest.TestCase):
   def test_after_processing_time(self):
     test_options = PipelineOptions(

Reply via email to