This is an automated email from the ASF dual-hosted git repository.
altay pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new ad88aa8 [BEAM-6696] GroupIntoBatches transform for Python SDK (#8914)
ad88aa8 is described below
commit ad88aa83977a99e2b1e602f944d62c966ac24e40
Author: Raheel Khan <[email protected]>
AuthorDate: Sat Jun 22 01:55:29 2019 +0500
[BEAM-6696] GroupIntoBatches transform for Python SDK (#8914)
GroupIntoBatches transform in the Python SDK
---
sdks/python/apache_beam/transforms/util.py | 77 +++++++++++++++++++++++-
sdks/python/apache_beam/transforms/util_test.py | 78 +++++++++++++++++++++++++
2 files changed, 154 insertions(+), 1 deletion(-)
diff --git a/sdks/python/apache_beam/transforms/util.py
b/sdks/python/apache_beam/transforms/util.py
index 4388f6a..dd5817d 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -25,16 +25,19 @@ import collections
import contextlib
import random
import time
+import warnings
from builtins import object
from builtins import range
from builtins import zip
from future.utils import itervalues
+from apache_beam import coders
from apache_beam import typehints
from apache_beam.metrics import Metrics
from apache_beam.portability import common_urns
from apache_beam.transforms import window
+from apache_beam.transforms.combiners import CountCombineFn
from apache_beam.transforms.core import CombinePerKey
from apache_beam.transforms.core import DoFn
from apache_beam.transforms.core import FlatMap
@@ -45,13 +48,19 @@ from apache_beam.transforms.core import ParDo
from apache_beam.transforms.core import Windowing
from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.ptransform import ptransform_fn
+from apache_beam.transforms.timeutil import TimeDomain
from apache_beam.transforms.trigger import AccumulationMode
from apache_beam.transforms.trigger import AfterCount
+from apache_beam.transforms.userstate import BagStateSpec
+from apache_beam.transforms.userstate import CombiningValueStateSpec
+from apache_beam.transforms.userstate import TimerSpec
+from apache_beam.transforms.userstate import on_timer
from apache_beam.transforms.window import NonMergingWindowFn
from apache_beam.transforms.window import TimestampCombiner
from apache_beam.transforms.window import TimestampedValue
from apache_beam.utils import windowed_value
from apache_beam.utils.annotations import deprecated
+from apache_beam.utils.annotations import experimental
__all__ = [
'BatchElements',
@@ -64,7 +73,8 @@ __all__ = [
'Reshuffle',
'ToString',
'Values',
- 'WithKeys'
+ 'WithKeys',
+ 'GroupIntoBatches'
]
K = typehints.TypeVariable('K')
@@ -671,6 +681,71 @@ def WithKeys(pcoll, k):
return pcoll | Map(lambda v: (k, v))
@experimental()
@typehints.with_input_types(typehints.KV[K, V])
class GroupIntoBatches(PTransform):
  """PTransform that groups its keyed input into batches of a desired size.

  Incoming elements are buffered per key until ``batch_size`` of them have
  accumulated, at which point the buffered elements are emitted as a single
  list to the output PCollection.

  Windows are preserved: a batch only ever contains elements from the same
  window.

  GroupIntoBatches is experimental. It only functions on runners that
  support user State and Timers.
  """

  def __init__(self, batch_size):
    """Create a new GroupIntoBatches.

    Arguments:
      batch_size: (required) How many elements should be in a batch
    """
    warnings.warn('Use of GroupIntoBatches transform requires State/Timer '
                  'support from the runner')
    self.batch_size = batch_size

  def expand(self, pcoll):
    # The element coder is needed to declare the bag state that buffers
    # the incoming (key, value) pairs.
    element_coder = coders.registry.get_coder(pcoll)
    return pcoll | ParDo(
        _pardo_group_into_batches(self.batch_size, element_coder))
+
+
def _pardo_group_into_batches(batch_size, input_coder):
  """Build the stateful DoFn backing GroupIntoBatches.

  Args:
    batch_size: flush the buffer each time it reaches this many elements.
    input_coder: coder for the incoming elements; used to declare the
      bag state that buffers them.

  Returns:
    A DoFn instance to be applied with ParDo over a keyed PCollection.
  """
  ELEMENT_STATE = BagStateSpec('values', input_coder)
  # The running count is a plain integer, so it must be encoded with an
  # integer coder, not the input element coder (which encodes KV pairs
  # and cannot represent the count).
  COUNT_STATE = CombiningValueStateSpec(
      'count', coders.VarIntCoder(), CountCombineFn())
  EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK)

  class _GroupIntoBatchesDoFn(DoFn):

    def process(self, element,
                window=DoFn.WindowParam,
                element_state=DoFn.StateParam(ELEMENT_STATE),
                count_state=DoFn.StateParam(COUNT_STATE),
                expiry_timer=DoFn.TimerParam(EXPIRY_TIMER)):
      # Allowed lateness not supported in Python SDK, so flush at the
      # end of the window:
      # https://beam.apache.org/documentation/programming-guide/#watermarks-and-late-data
      expiry_timer.set(window.end)
      element_state.add(element)
      count_state.add(1)
      count = count_state.read()
      if count >= batch_size:
        # Batch is full: emit it and reset both states.
        batch = list(element_state.read())
        yield batch
        element_state.clear()
        count_state.clear()

    @on_timer(EXPIRY_TIMER)
    def expiry(self, element_state=DoFn.StateParam(ELEMENT_STATE),
               count_state=DoFn.StateParam(COUNT_STATE)):
      # Window expired: flush whatever is still buffered, if anything.
      batch = list(element_state.read())
      if batch:
        yield batch
        element_state.clear()
        count_state.clear()

  return _GroupIntoBatchesDoFn()
+
+
class ToString(object):
"""
PTransform for converting a PCollection element, KV or PCollection Iterable
diff --git a/sdks/python/apache_beam/transforms/util_test.py
b/sdks/python/apache_beam/transforms/util_test.py
index b655f66..ae952f6 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -20,7 +20,9 @@
from __future__ import absolute_import
from __future__ import division
+import itertools
import logging
+import math
import random
import time
import unittest
@@ -28,16 +30,19 @@ from builtins import object
from builtins import range
import apache_beam as beam
+from apache_beam import WindowInto
from apache_beam.coders import coders
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.test_stream import TestStream
from apache_beam.testing.util import TestWindowedValue
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import contains_in_any_order
from apache_beam.testing.util import equal_to
from apache_beam.transforms import util
from apache_beam.transforms import window
+from apache_beam.transforms.window import FixedWindows
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import GlobalWindows
from apache_beam.transforms.window import IntervalWindow
@@ -432,6 +437,79 @@ class WithKeysTest(unittest.TestCase):
assert_that(with_keys, equal_to([(1, 1), (4, 2), (9, 3)]))
class GroupIntoBatchesTest(unittest.TestCase):
  NUM_ELEMENTS = 10
  BATCH_SIZE = 5

  @staticmethod
  def _create_test_data():
    """Return NUM_ELEMENTS (key, name) pairs, all sharing the same key."""
    scientists = [
        "Einstein",
        "Darwin",
        "Copernicus",
        "Pasteur",
        "Curie",
        "Faraday",
        "Newton",
        "Bohr",
        "Galilei",
        "Maxwell"
    ]

    data = []
    for i in range(GroupIntoBatchesTest.NUM_ELEMENTS):
      index = i % len(scientists)
      data.append(("key", scientists[index]))
    return data

  def test_in_global_window(self):
    pipeline = TestPipeline()
    collection = pipeline \
        | beam.Create(GroupIntoBatchesTest._create_test_data()) \
        | util.GroupIntoBatches(GroupIntoBatchesTest.BATCH_SIZE)
    num_batches = collection | beam.combiners.Count.Globally()
    assert_that(num_batches,
                equal_to([int(math.ceil(GroupIntoBatchesTest.NUM_ELEMENTS /
                                        GroupIntoBatchesTest.BATCH_SIZE))]))
    pipeline.run()

  def test_in_streaming_mode(self):
    timestamp_interval = 1
    offset = itertools.count(0)
    start_time = timestamp.Timestamp(0)
    window_duration = 6
    test_stream = (TestStream()
                   .advance_watermark_to(start_time)
                   .add_elements(
                       [TimestampedValue(x, next(offset) * timestamp_interval)
                        for x in GroupIntoBatchesTest._create_test_data()])
                   .advance_watermark_to(start_time + (window_duration - 1))
                   .advance_watermark_to(start_time + (window_duration + 1))
                   .advance_watermark_to(start_time +
                                         GroupIntoBatchesTest.NUM_ELEMENTS)
                   .advance_watermark_to_infinity())
    pipeline = TestPipeline()
    # window duration is 6 and batch size is 5, so output batch size should be
    # 5 (flush because of batchSize reached)
    expected_0 = 5
    # there is only one element left in the window so batch size should be 1
    # (flush because of end of window reached)
    expected_1 = 1
    # collection is 10 elements, there is only 4 left, so batch size should be
    # 4 (flush because end of collection reached)
    expected_2 = 4

    collection = pipeline | test_stream \
        | WindowInto(FixedWindows(window_duration)) \
        | util.GroupIntoBatches(GroupIntoBatchesTest.BATCH_SIZE)
    num_elements_in_batches = collection | beam.Map(len)

    # assert_that must be applied BEFORE the pipeline runs; attaching it
    # after run()/wait_until_finish() adds the check to an already-executed
    # pipeline, so it would never actually verify anything.
    assert_that(num_elements_in_batches,
                equal_to([expected_0, expected_1, expected_2]))
    pipeline.run().wait_until_finish()
+
class ToStringTest(unittest.TestCase):
def test_tostring_elements(self):