This is an automated email from the ASF dual-hosted git repository. boyuanz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 578d6f6 Add max buffering duration for GroupIntoBatches in Python SDK new 5620419 Merge pull request #13144 from [BEAM-10475] Add max buffering duration option for GroupIntoBatches transform in Python 578d6f6 is described below commit 578d6f6816311d3a649608a5ec33d40d174d7e7b Author: sychen <syc...@google.com> AuthorDate: Mon Oct 19 16:19:49 2020 -0700 Add max buffering duration for GroupIntoBatches in Python SDK --- .../apache_beam/runners/direct/direct_runner.py | 7 ++ sdks/python/apache_beam/transforms/util.py | 76 ++++++++---- sdks/python/apache_beam/transforms/util_test.py | 132 +++++++++++++++------ 3 files changed, 157 insertions(+), 58 deletions(-) diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index 8f221aa..914cebd 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -46,11 +46,13 @@ from apache_beam.runners.direct.clock import TestClock from apache_beam.runners.runner import PipelineResult from apache_beam.runners.runner import PipelineRunner from apache_beam.runners.runner import PipelineState +from apache_beam.transforms import userstate from apache_beam.transforms.core import CombinePerKey from apache_beam.transforms.core import CombineValuesDoFn from apache_beam.transforms.core import DoFn from apache_beam.transforms.core import ParDo from apache_beam.transforms.ptransform import PTransform +from apache_beam.transforms.timeutil import TimeDomain from apache_beam.typehints import trivial_inference # Note that the BundleBasedDirectRunner and SwitchingDirectRunner names are @@ -107,6 +109,11 @@ class SwitchingDirectRunner(PipelineRunner): if any(isinstance(arg, ArgumentPlaceholder) for arg in args_to_check): self.supported_by_fnapi_runner = False + if userstate.is_stateful_dofn(dofn): + _, timer_specs = userstate.get_dofn_specs(dofn) + for timer in timer_specs: + if timer.time_domain == TimeDomain.REAL_TIME: + self.supported_by_fnapi_runner = False # Check whether all transforms used in the pipeline are supported by the # FnApiRunner, and the pipeline was not meant to be run as streaming. diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index e6662f0..12f9c8a 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -751,24 +751,42 @@ class GroupIntoBatches(PTransform): GroupIntoBatches is experimental. Its use case will depend on the runner if it has support of States and Timers. """ - def __init__(self, batch_size): + def __init__( + self, batch_size, max_buffering_duration_secs=None, clock=time.time): """Create a new GroupIntoBatches with batch size. Arguments: batch_size: (required) How many elements should be in a batch + max_buffering_duration_secs: (optional) How long in seconds at most an + incomplete batch of elements is allowed to be buffered in the states. + The duration must be a positive second duration and should be given as + an int or float. + clock: (optional) an alternative to time.time (mostly for testing) """ self.batch_size = batch_size + if max_buffering_duration_secs is not None: + assert max_buffering_duration_secs > 0, ( + 'max buffering duration should be a positive value') + self.max_buffering_duration_secs = max_buffering_duration_secs + self.clock = clock + def expand(self, pcoll): input_coder = coders.registry.get_coder(pcoll) return pcoll | ParDo( - _pardo_group_into_batches(self.batch_size, input_coder)) + _pardo_group_into_batches( + input_coder, + self.batch_size, + self.max_buffering_duration_secs, + self.clock)) -def _pardo_group_into_batches(batch_size, input_coder): +def _pardo_group_into_batches( + input_coder, batch_size, max_buffering_duration_secs, clock=time.time): ELEMENT_STATE = BagStateSpec('values', input_coder) COUNT_STATE = CombiningValueStateSpec('count', input_coder, CountCombineFn()) - EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK) + WINDOW_TIMER = TimerSpec('window_end', TimeDomain.WATERMARK) + BUFFERING_TIMER = TimerSpec('buffering_end', TimeDomain.REAL_TIME) class _GroupIntoBatchesDoFn(DoFn): def process( @@ -777,33 +795,47 @@ def _pardo_group_into_batches(batch_size, input_coder): window=DoFn.WindowParam, element_state=DoFn.StateParam(ELEMENT_STATE), count_state=DoFn.StateParam(COUNT_STATE), - expiry_timer=DoFn.TimerParam(EXPIRY_TIMER)): + window_timer=DoFn.TimerParam(WINDOW_TIMER), + buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)): # Allowed lateness not supported in Python SDK # https://beam.apache.org/documentation/programming-guide/#watermarks-and-late-data - expiry_timer.set(window.end) + window_timer.set(window.end) element_state.add(element) count_state.add(1) count = count_state.read() + if count == 1 and max_buffering_duration_secs is not None: + # This is the first element in batch. Start counting buffering time if a + # limit was set. + buffering_timer.set(clock() + max_buffering_duration_secs) if count >= batch_size: - batch = [element for element in element_state.read()] - key, _ = batch[0] - batch_values = [v for (k, v) in batch] - yield (key, batch_values) - element_state.clear() - count_state.clear() - - @on_timer(EXPIRY_TIMER) - def expiry( + return self.flush_batch(element_state, count_state, buffering_timer) + + @on_timer(WINDOW_TIMER) + def on_window_timer( + self, + element_state=DoFn.StateParam(ELEMENT_STATE), + count_state=DoFn.StateParam(COUNT_STATE), + buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)): + return self.flush_batch(element_state, count_state, buffering_timer) + + @on_timer(BUFFERING_TIMER) + def on_buffering_timer( self, element_state=DoFn.StateParam(ELEMENT_STATE), - count_state=DoFn.StateParam(COUNT_STATE)): + count_state=DoFn.StateParam(COUNT_STATE), + buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)): + return self.flush_batch(element_state, count_state, buffering_timer) + + def flush_batch(self, element_state, count_state, buffering_timer): batch = [element for element in element_state.read()] - if batch: - key, _ = batch[0] - batch_values = [v for (k, v) in batch] - yield (key, batch_values) - element_state.clear() - count_state.clear() + if not batch: + return + key, _ = batch[0] + batch_values = [v for (k, v) in batch] + element_state.clear() + count_state.clear() + buffering_timer.clear() + yield key, batch_values return _GroupIntoBatchesDoFn() diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index f22ea2c..cbca2a1 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -22,7 +22,6 @@ from __future__ import absolute_import from __future__ import division -import itertools import logging import math import random @@ -38,6 +37,8 @@ import future.tests.base # pylint: disable=unused-import from nose.plugins.attrib import attr import apache_beam as beam +from apache_beam import GroupByKey +from apache_beam import Map from apache_beam import WindowInto from apache_beam.coders import coders from apache_beam.options.pipeline_options import PipelineOptions @@ -48,8 +49,12 @@ from apache_beam.testing.util import TestWindowedValue from apache_beam.testing.util import assert_that from apache_beam.testing.util import contains_in_any_order from apache_beam.testing.util import equal_to +from apache_beam.transforms import trigger from apache_beam.transforms import util from apache_beam.transforms import window +from apache_beam.transforms.core import FlatMapTuple +from apache_beam.transforms.trigger import AfterCount +from apache_beam.transforms.trigger import Repeatedly from apache_beam.transforms.window import FixedWindows from apache_beam.transforms.window import GlobalWindow from apache_beam.transforms.window import GlobalWindows @@ -67,8 +72,8 @@ warnings.filterwarnings( class FakeClock(object): - def __init__(self): - self._now = time.time() + def __init__(self, now=time.time()): + self._now = now def __call__(self): return self._now @@ -110,10 +115,16 @@ class BatchElementsTest(unittest.TestCase): | util.BatchElements( min_batch_size=5, max_batch_size=10, clock=FakeClock()) | beam.Map(len)) - assert_that(res, equal_to([ - 5, 5, 10, 10, # elements in [0, 30) - 10, 7, # elements in [30, 47) - ])) + assert_that( + res, + equal_to([ + 5, + 5, + 10, + 10, # elements in [0, 30) + 10, + 7, # elements in [30, 47) + ])) def test_target_duration(self): clock = FakeClock() @@ -659,45 +670,94 @@ class GroupIntoBatchesTest(unittest.TestCase): GroupIntoBatchesTest.BATCH_SIZE)) ])) - @unittest.skip('BEAM-8748') - def test_in_streaming_mode(self): - timestamp_interval = 1 - offset = itertools.count(0) - start_time = timestamp.Timestamp(0) + def test_buffering_timer_in_fixed_window_streaming(self): window_duration = 6 - test_stream = ( - TestStream().advance_watermark_to(start_time).add_elements([ - TimestampedValue(x, next(offset) * timestamp_interval) - for x in GroupIntoBatchesTest._create_test_data() - ]).advance_watermark_to(start_time + - (window_duration - 1)).advance_watermark_to( - start_time + (window_duration + 1)). - advance_watermark_to( - start_time + - GroupIntoBatchesTest.NUM_ELEMENTS).advance_watermark_to_infinity()) + max_buffering_duration_secs = 100 + + start_time = timestamp.Timestamp(0) + test_stream = TestStream().add_elements( + [TimestampedValue(value, start_time + i) + for i, value in enumerate(GroupIntoBatchesTest._create_test_data())]) \ + .advance_watermark_to( + start_time + GroupIntoBatchesTest.NUM_ELEMENTS + 1) \ + .advance_processing_time(100) \ + .advance_watermark_to_infinity() + with TestPipeline(options=StandardOptions(streaming=True)) as pipeline: - # window duration is 6 and batch size is 5, so output batch size - # should be 5 (flush because of batchSize reached) + # To trigger the processing time timer, use a fake clock with start time + # being Timestamp(0). + fake_clock = FakeClock(now=start_time) + + num_elements_per_batch = ( + pipeline | test_stream + | "fixed window" >> WindowInto(FixedWindows(window_duration)) + | util.GroupIntoBatches( + GroupIntoBatchesTest.BATCH_SIZE, + max_buffering_duration_secs, + fake_clock) + | "count elements in batch" >> Map(lambda x: (None, len(x[1]))) + | "global window" >> WindowInto(GlobalWindows()) + | GroupByKey() + | FlatMapTuple(lambda k, vs: vs)) + + # Window duration is 6 and batch size is 5, so output batch size + # should be 5 (flush because of batch size reached). expected_0 = 5 - # there is only one element left in the window so batch size - # should be 1 (flush because of end of window reached) + # There is only one element left in the window so batch size + # should be 1 (flush because of end of window reached). expected_1 = 1 - # collection is 10 elements, there is only 4 left, so batch size - # should be 4 (flush because end of collection reached) + # Collection has 10 elements, there are only 4 left, so batch size should + # be 4 (flush because of max buffering duration reached). expected_2 = 4 - - collection = pipeline | test_stream \ - | WindowInto(FixedWindows(window_duration)) \ - | util.GroupIntoBatches(GroupIntoBatchesTest.BATCH_SIZE) - num_elements_in_batches = collection | beam.Map(len) assert_that( - num_elements_in_batches, - equal_to([expected_0, expected_1, expected_2])) + num_elements_per_batch, + equal_to([expected_0, expected_1, expected_2]), + "assert2") + + def test_buffering_timer_in_global_window_streaming(self): + max_buffering_duration_secs = 42 + + start_time = timestamp.Timestamp(0) + test_stream = TestStream().advance_watermark_to(start_time) + for i, value in enumerate(GroupIntoBatchesTest._create_test_data()): + test_stream.add_elements( + [TimestampedValue(value, start_time + i)]) \ + .advance_processing_time(5) + test_stream.advance_watermark_to( + start_time + GroupIntoBatchesTest.NUM_ELEMENTS + 1) \ + .advance_watermark_to_infinity() + + with TestPipeline(options=StandardOptions(streaming=True)) as pipeline: + # Set a batch size larger than the total number of elements. + # Since we're in a global window, we would have been waiting + # for all the elements to arrive without the buffering time limit. + batch_size = GroupIntoBatchesTest.NUM_ELEMENTS * 2 + + # To trigger the processing time timer, use a fake clock with start time + # being Timestamp(0). Since the fake clock never really advances during + # the pipeline execution, meaning that the timer is always set to the same + # value, the timer will be fired on every element after the first firing. + fake_clock = FakeClock(now=start_time) + + num_elements_per_batch = ( + pipeline | test_stream + | WindowInto( + GlobalWindows(), + trigger=Repeatedly(AfterCount(1)), + accumulation_mode=trigger.AccumulationMode.DISCARDING) + | util.GroupIntoBatches( + batch_size, max_buffering_duration_secs, fake_clock) + | 'count elements in batch' >> Map(lambda x: (None, len(x[1]))) + | GroupByKey() + | FlatMapTuple(lambda k, vs: vs)) + + # We will flush twice when the max buffering duration is reached and when + # the global window ends. + assert_that(num_elements_per_batch, equal_to([9, 1])) class ToStringTest(unittest.TestCase): def test_tostring_elements(self): - with TestPipeline() as p: result = (p | beam.Create([1, 1, 2, 3]) | util.ToString.Element()) assert_that(result, equal_to(["1", "1", "2", "3"]))