This is an automated email from the ASF dual-hosted git repository.

boyuanz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 578d6f6  Add max buffering duration for GroupIntoBatches in Python SDK
     new 5620419  Merge pull request #13144 from [BEAM-10475] Add max buffering duration option for GroupIntoBatches transform in Python
578d6f6 is described below

commit 578d6f6816311d3a649608a5ec33d40d174d7e7b
Author: sychen <syc...@google.com>
AuthorDate: Mon Oct 19 16:19:49 2020 -0700

    Add max buffering duration for GroupIntoBatches in Python SDK
---
 .../apache_beam/runners/direct/direct_runner.py    |   7 ++
 sdks/python/apache_beam/transforms/util.py         |  76 ++++++++----
 sdks/python/apache_beam/transforms/util_test.py    | 132 +++++++++++++++------
 3 files changed, 157 insertions(+), 58 deletions(-)
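
For anyone trying the new option, a minimal usage sketch (the input data and
key names here are illustrative, not part of this commit; assumes a runner
with support for States and Timers):

    import apache_beam as beam
    from apache_beam.transforms import util

    with beam.Pipeline() as p:
      batches = (
          p
          | beam.Create([('key', i) for i in range(10)])  # hypothetical input
          # Flush a batch when 4 elements are buffered, or when an incomplete
          # batch has been buffered for 10 seconds, whichever comes first.
          | util.GroupIntoBatches(4, max_buffering_duration_secs=10))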

diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py
index 8f221aa..914cebd 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner.py
@@ -46,11 +46,13 @@ from apache_beam.runners.direct.clock import TestClock
 from apache_beam.runners.runner import PipelineResult
 from apache_beam.runners.runner import PipelineRunner
 from apache_beam.runners.runner import PipelineState
+from apache_beam.transforms import userstate
 from apache_beam.transforms.core import CombinePerKey
 from apache_beam.transforms.core import CombineValuesDoFn
 from apache_beam.transforms.core import DoFn
 from apache_beam.transforms.core import ParDo
 from apache_beam.transforms.ptransform import PTransform
+from apache_beam.transforms.timeutil import TimeDomain
 from apache_beam.typehints import trivial_inference
 
 # Note that the BundleBasedDirectRunner and SwitchingDirectRunner names are
@@ -107,6 +109,11 @@ class SwitchingDirectRunner(PipelineRunner):
             if any(isinstance(arg, ArgumentPlaceholder)
                    for arg in args_to_check):
               self.supported_by_fnapi_runner = False
+          if userstate.is_stateful_dofn(dofn):
+            _, timer_specs = userstate.get_dofn_specs(dofn)
+            for timer in timer_specs:
+              if timer.time_domain == TimeDomain.REAL_TIME:
+                self.supported_by_fnapi_runner = False
 
     # Check whether all transforms used in the pipeline are supported by the
     # FnApiRunner, and the pipeline was not meant to be run as streaming.
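
The check above routes any pipeline containing a stateful DoFn with a
processing-time (REAL_TIME) timer away from the FnApiRunner, falling back to
the BundleBasedDirectRunner. A minimal sketch of a DoFn that would now
trigger that fallback (the DoFn itself is hypothetical):

    import time

    import apache_beam as beam
    from apache_beam.transforms.timeutil import TimeDomain
    from apache_beam.transforms.userstate import TimerSpec
    from apache_beam.transforms.userstate import on_timer

    class DelayedFlushDoFn(beam.DoFn):  # hypothetical example
      # A REAL_TIME timer spec is what the new check looks for.
      FLUSH_TIMER = TimerSpec('flush', TimeDomain.REAL_TIME)

      def process(self, element, timer=beam.DoFn.TimerParam(FLUSH_TIMER)):
        timer.set(time.time() + 10)  # fire 10 seconds from now
        yield element

      @on_timer(FLUSH_TIMER)
      def flush(self):
        pass
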
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index e6662f0..12f9c8a 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -751,24 +751,42 @@ class GroupIntoBatches(PTransform):
  GroupIntoBatches is experimental. It requires a runner with support for
  States and Timers.
   """
-  def __init__(self, batch_size):
+  def __init__(
+      self, batch_size, max_buffering_duration_secs=None, clock=time.time):
     """Create a new GroupIntoBatches with batch size.
 
     Arguments:
       batch_size: (required) How many elements should be in a batch
+      max_buffering_duration_secs: (optional) The maximum time, in seconds,
+        that an incomplete batch of elements is allowed to be buffered in
+        the states. The duration must be positive and given as an int or
+        float.
+      clock: (optional) an alternative to time.time (mostly for testing)
     """
     self.batch_size = batch_size
 
+    if max_buffering_duration_secs is not None:
+      assert max_buffering_duration_secs > 0, (
+          'max buffering duration should be a positive value')
+    self.max_buffering_duration_secs = max_buffering_duration_secs
+    self.clock = clock
+
   def expand(self, pcoll):
     input_coder = coders.registry.get_coder(pcoll)
     return pcoll | ParDo(
-        _pardo_group_into_batches(self.batch_size, input_coder))
+        _pardo_group_into_batches(
+            input_coder,
+            self.batch_size,
+            self.max_buffering_duration_secs,
+            self.clock))
 
 
-def _pardo_group_into_batches(batch_size, input_coder):
+def _pardo_group_into_batches(
+    input_coder, batch_size, max_buffering_duration_secs, clock=time.time):
   ELEMENT_STATE = BagStateSpec('values', input_coder)
   COUNT_STATE = CombiningValueStateSpec('count', input_coder, CountCombineFn())
-  EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK)
+  WINDOW_TIMER = TimerSpec('window_end', TimeDomain.WATERMARK)
+  BUFFERING_TIMER = TimerSpec('buffering_end', TimeDomain.REAL_TIME)
 
   class _GroupIntoBatchesDoFn(DoFn):
     def process(
@@ -777,33 +795,47 @@ def _pardo_group_into_batches(batch_size, input_coder):
         window=DoFn.WindowParam,
         element_state=DoFn.StateParam(ELEMENT_STATE),
         count_state=DoFn.StateParam(COUNT_STATE),
-        expiry_timer=DoFn.TimerParam(EXPIRY_TIMER)):
+        window_timer=DoFn.TimerParam(WINDOW_TIMER),
+        buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
       # Allowed lateness not supported in Python SDK
      # https://beam.apache.org/documentation/programming-guide/#watermarks-and-late-data
-      expiry_timer.set(window.end)
+      window_timer.set(window.end)
       element_state.add(element)
       count_state.add(1)
       count = count_state.read()
+      if count == 1 and max_buffering_duration_secs is not None:
+        # This is the first element in the batch. Start counting buffering
+        # time if a limit was set.
+        buffering_timer.set(clock() + max_buffering_duration_secs)
       if count >= batch_size:
-        batch = [element for element in element_state.read()]
-        key, _ = batch[0]
-        batch_values = [v for (k, v) in batch]
-        yield (key, batch_values)
-        element_state.clear()
-        count_state.clear()
-
-    @on_timer(EXPIRY_TIMER)
-    def expiry(
+        return self.flush_batch(element_state, count_state, buffering_timer)
+
+    @on_timer(WINDOW_TIMER)
+    def on_window_timer(
+        self,
+        element_state=DoFn.StateParam(ELEMENT_STATE),
+        count_state=DoFn.StateParam(COUNT_STATE),
+        buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
+      return self.flush_batch(element_state, count_state, buffering_timer)
+
+    @on_timer(BUFFERING_TIMER)
+    def on_buffering_timer(
         self,
         element_state=DoFn.StateParam(ELEMENT_STATE),
-        count_state=DoFn.StateParam(COUNT_STATE)):
+        count_state=DoFn.StateParam(COUNT_STATE),
+        buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
+      return self.flush_batch(element_state, count_state, buffering_timer)
+
+    def flush_batch(self, element_state, count_state, buffering_timer):
       batch = [element for element in element_state.read()]
-      if batch:
-        key, _ = batch[0]
-        batch_values = [v for (k, v) in batch]
-        yield (key, batch_values)
-        element_state.clear()
-        count_state.clear()
+      if not batch:
+        return
+      key, _ = batch[0]
+      batch_values = [v for (k, v) in batch]
+      element_state.clear()
+      count_state.clear()
+      buffering_timer.clear()
+      yield key, batch_values
 
   return _GroupIntoBatchesDoFn()
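
With this change a batch is flushed on whichever comes first: the batch size
is reached in process(), the window ends (WINDOW_TIMER fires), or the max
buffering duration elapses (BUFFERING_TIMER fires). All three paths share
flush_batch(), which emits one (key, values) pair and resets the state and
the buffering timer. A small sketch of what flush_batch computes, shown on
plain data rather than Beam state objects (runnable on its own):

    # The buffered state holds (key, value) pairs for a single key.
    batch = [('k', 1), ('k', 2), ('k', 3)]
    key, _ = batch[0]
    batch_values = [v for (_, v) in batch]
    assert (key, batch_values) == ('k', [1, 2, 3])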
 
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index f22ea2c..cbca2a1 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -22,7 +22,6 @@
 from __future__ import absolute_import
 from __future__ import division
 
-import itertools
 import logging
 import math
 import random
@@ -38,6 +37,8 @@ import future.tests.base  # pylint: disable=unused-import
 from nose.plugins.attrib import attr
 
 import apache_beam as beam
+from apache_beam import GroupByKey
+from apache_beam import Map
 from apache_beam import WindowInto
 from apache_beam.coders import coders
 from apache_beam.options.pipeline_options import PipelineOptions
@@ -48,8 +49,12 @@ from apache_beam.testing.util import TestWindowedValue
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import contains_in_any_order
 from apache_beam.testing.util import equal_to
+from apache_beam.transforms import trigger
 from apache_beam.transforms import util
 from apache_beam.transforms import window
+from apache_beam.transforms.core import FlatMapTuple
+from apache_beam.transforms.trigger import AfterCount
+from apache_beam.transforms.trigger import Repeatedly
 from apache_beam.transforms.window import FixedWindows
 from apache_beam.transforms.window import GlobalWindow
 from apache_beam.transforms.window import GlobalWindows
@@ -67,8 +72,8 @@ warnings.filterwarnings(
 
 
 class FakeClock(object):
-  def __init__(self):
-    self._now = time.time()
+  def __init__(self, now=time.time()):
+    self._now = now
 
   def __call__(self):
     return self._now
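
One subtlety with the new default: now=time.time() is evaluated once, when
FakeClock is defined, so every FakeClock() created without an argument shares
the same module-load timestamp. That is harmless for a clock that never
advances, as here, but when reusing the pattern a per-instance default may be
preferable (a sketch, not part of this commit):

    import time

    class FakeClock(object):
      def __init__(self, now=None):
        # Evaluate time.time() per instance, not at class definition time.
        self._now = time.time() if now is None else now

      def __call__(self):
        return self._now
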
@@ -110,10 +115,16 @@ class BatchElementsTest(unittest.TestCase):
           | util.BatchElements(
               min_batch_size=5, max_batch_size=10, clock=FakeClock())
           | beam.Map(len))
-      assert_that(res, equal_to([
-          5, 5, 10, 10,  # elements in [0, 30)
-          10, 7,         # elements in [30, 47)
-      ]))
+      assert_that(
+          res,
+          equal_to([
+              5,
+              5,
+              10,
+              10,  # elements in [0, 30)
+              10,
+              7,  # elements in [30, 47)
+          ]))
 
   def test_target_duration(self):
     clock = FakeClock()
@@ -659,45 +670,94 @@ class GroupIntoBatchesTest(unittest.TestCase):
                       GroupIntoBatchesTest.BATCH_SIZE))
           ]))
 
-  @unittest.skip('BEAM-8748')
-  def test_in_streaming_mode(self):
-    timestamp_interval = 1
-    offset = itertools.count(0)
-    start_time = timestamp.Timestamp(0)
+  def test_buffering_timer_in_fixed_window_streaming(self):
     window_duration = 6
-    test_stream = (
-        TestStream().advance_watermark_to(start_time).add_elements([
-            TimestampedValue(x, next(offset) * timestamp_interval)
-            for x in GroupIntoBatchesTest._create_test_data()
-        ]).advance_watermark_to(start_time +
-                                (window_duration - 1)).advance_watermark_to(
-                                    start_time + (window_duration + 1)).
-        advance_watermark_to(
-            start_time +
-            GroupIntoBatchesTest.NUM_ELEMENTS).advance_watermark_to_infinity())
+    max_buffering_duration_secs = 100
+
+    start_time = timestamp.Timestamp(0)
+    test_stream = TestStream().add_elements(
+        [TimestampedValue(value, start_time + i)
+         for i, value in enumerate(GroupIntoBatchesTest._create_test_data())]) \
+      .advance_watermark_to(
+        start_time + GroupIntoBatchesTest.NUM_ELEMENTS + 1) \
+      .advance_processing_time(100) \
+      .advance_watermark_to_infinity()
+
     with TestPipeline(options=StandardOptions(streaming=True)) as pipeline:
-      # window duration is 6 and batch size is 5, so output batch size
-      # should be 5 (flush because of batchSize reached)
+      # To trigger the processing time timer, use a fake clock with start time
+      # being Timestamp(0).
+      fake_clock = FakeClock(now=start_time)
+
+      num_elements_per_batch = (
+          pipeline | test_stream
+          | "fixed window" >> WindowInto(FixedWindows(window_duration))
+          | util.GroupIntoBatches(
+              GroupIntoBatchesTest.BATCH_SIZE,
+              max_buffering_duration_secs,
+              fake_clock)
+          | "count elements in batch" >> Map(lambda x: (None, len(x[1])))
+          | "global window" >> WindowInto(GlobalWindows())
+          | GroupByKey()
+          | FlatMapTuple(lambda k, vs: vs))
+
+      # Window duration is 6 and batch size is 5, so output batch size
+      # should be 5 (flush because of batch size reached).
       expected_0 = 5
-      # there is only one element left in the window so batch size
-      # should be 1 (flush because of end of window reached)
+      # There is only one element left in the window, so the batch size
+      # should be 1 (flush because of end of window reached).
       expected_1 = 1
-      # collection is 10 elements, there is only 4 left, so batch size
-      # should be 4 (flush because end of collection reached)
+      # The collection has 10 elements; only 4 are left, so the batch size
+      # should be 4 (flush because of max buffering duration reached).
       expected_2 = 4
-
-      collection = pipeline | test_stream \
-                   | WindowInto(FixedWindows(window_duration)) \
-                   | util.GroupIntoBatches(GroupIntoBatchesTest.BATCH_SIZE)
-      num_elements_in_batches = collection | beam.Map(len)
       assert_that(
-          num_elements_in_batches,
-          equal_to([expected_0, expected_1, expected_2]))
+          num_elements_per_batch,
+          equal_to([expected_0, expected_1, expected_2]),
+          "assert2")
+
+  def test_buffering_timer_in_global_window_streaming(self):
+    max_buffering_duration_secs = 42
+
+    start_time = timestamp.Timestamp(0)
+    test_stream = TestStream().advance_watermark_to(start_time)
+    for i, value in enumerate(GroupIntoBatchesTest._create_test_data()):
+      test_stream.add_elements(
+          [TimestampedValue(value, start_time + i)]) \
+        .advance_processing_time(5)
+    test_stream.advance_watermark_to(
+        start_time + GroupIntoBatchesTest.NUM_ELEMENTS + 1) \
+      .advance_watermark_to_infinity()
+
+    with TestPipeline(options=StandardOptions(streaming=True)) as pipeline:
+      # Set a batch size larger than the total number of elements.
+      # Since we're in a global window, without the buffering time limit we
+      # would be waiting for all the elements to arrive before flushing.
+      batch_size = GroupIntoBatchesTest.NUM_ELEMENTS * 2
+
+      # To trigger the processing time timer, use a fake clock with start time
+      # being Timestamp(0). Since the fake clock never really advances during
+      # the pipeline execution, meaning that the timer is always set to the same
+      # value, the timer will be fired on every element after the first firing.
+      fake_clock = FakeClock(now=start_time)
+
+      num_elements_per_batch = (
+          pipeline | test_stream
+          | WindowInto(
+              GlobalWindows(),
+              trigger=Repeatedly(AfterCount(1)),
+              accumulation_mode=trigger.AccumulationMode.DISCARDING)
+          | util.GroupIntoBatches(
+              batch_size, max_buffering_duration_secs, fake_clock)
+          | 'count elements in batch' >> Map(lambda x: (None, len(x[1])))
+          | GroupByKey()
+          | FlatMapTuple(lambda k, vs: vs))
+
+      # We will flush twice: once when the max buffering duration is reached,
+      # and once when the global window ends.
+      assert_that(num_elements_per_batch, equal_to([9, 1]))
 
 
 class ToStringTest(unittest.TestCase):
   def test_tostring_elements(self):
-
     with TestPipeline() as p:
       result = (p | beam.Create([1, 1, 2, 3]) | util.ToString.Element())
       assert_that(result, equal_to(["1", "1", "2", "3"]))
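
As the two new GroupIntoBatches streaming tests above show, the pattern for
exercising a REAL_TIME timer under TestStream is to pin the transform to a
fake clock and advance processing time explicitly. Reduced to its bones
(element values illustrative):

    from apache_beam.testing.test_stream import TestStream

    test_stream = (
        TestStream()
        .add_elements([('k', 0)])          # first element arms the timer
        .advance_processing_time(100)      # fires timers due within 100s
        .advance_watermark_to_infinity())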
