This is an automated email from the ASF dual-hosted git repository.

jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 2d53926542f Avoid oversizing batch sizes with size estimation function (#31228)
2d53926542f is described below

commit 2d53926542f82a8b955eb541f13475f9bef091a7
Author: Danny McCormick <dannymccorm...@google.com>
AuthorDate: Fri May 10 10:12:17 2024 -0400

    Avoid oversizing batch sizes with size estimation function (#31228)
    
    * Avoid oversizing batch sizes with size estimation function
    
    * lint
---
 sdks/python/apache_beam/transforms/util.py      | 18 ++++++++------
 sdks/python/apache_beam/transforms/util_test.py | 33 ++++++++++++++++++++++---
 2 files changed, 40 insertions(+), 11 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index edf79b7c798..750d98f0789 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -579,14 +579,15 @@ class _GlobalWindowsBatchingDoFn(DoFn):
     self._batch_size_estimator.ignore_next_timing()
 
   def process(self, element):
-    self._batch.append(element)
-    self._running_batch_size += self._element_size_fn(element)
-    if self._running_batch_size >= self._target_batch_size:
+    element_size = self._element_size_fn(element)
+    if self._running_batch_size + element_size > self._target_batch_size:
       with self._batch_size_estimator.record_time(self._running_batch_size):
         yield window.GlobalWindows.windowed_value_at_end_of_window(self._batch)
       self._batch = []
       self._running_batch_size = 0
       self._target_batch_size = self._batch_size_estimator.next_batch_size()
+    self._batch.append(element)
+    self._running_batch_size += element_size
 
   def finish_bundle(self):
     if self._batch:
@@ -621,15 +622,18 @@ class _WindowAwareBatchingDoFn(DoFn):
 
   def process(self, element, window=DoFn.WindowParam):
     batch = self._batches[window]
-    batch.elements.append(element)
-    batch.size += self._element_size_fn(element)
-    if batch.size >= self._target_batch_size:
+    element_size = self._element_size_fn(element)
+    if batch.size + element_size > self._target_batch_size:
       with self._batch_size_estimator.record_time(batch.size):
         yield windowed_value.WindowedValue(
             batch.elements, window.max_timestamp(), (window, ))
       del self._batches[window]
       self._target_batch_size = self._batch_size_estimator.next_batch_size()
-    elif len(self._batches) > self._MAX_LIVE_WINDOWS:
+
+    self._batches[window].elements.append(element)
+    self._batches[window].size += element_size
+
+    if len(self._batches) > self._MAX_LIVE_WINDOWS:
       window, batch = max(
           self._batches.items(),
           key=lambda window_batch: window_batch[1].size)
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index 53898d57998..74d9f438a5d 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -299,15 +299,40 @@ class BatchElementsTest(unittest.TestCase):
       res = (
           p
           | beam.Create([
-              'a', 'a', 'aaaaaaaaaa',  # First batch.
-              'aaaaaa', 'aaaaa',       # Second batch.
-              'a', 'aaaaaaa', 'a', 'a' # Third batch.
+              'a', 'a',                # First batch.
+              'aaaaaaaaaa',            # Second batch.
+              'aaaaa', 'aaaaa',        # Third batch.
+              'a', 'aaaaaaa', 'a', 'a' # Fourth batch.
               ], reshuffle=False)
           | util.BatchElements(
               min_batch_size=10, max_batch_size=10, element_size_fn=len)
           | beam.Map(lambda batch: ''.join(batch))
           | beam.Map(len))
-      assert_that(res, equal_to([12, 11, 10]))
+      assert_that(res, equal_to([2, 10, 10, 10]))
+
+  def test_sized_windowed_batches(self):
+    # Assumes a single bundle, in order...
+    with TestPipeline() as p:
+      res = (
+          p
+          | beam.Create(range(1, 8), reshuffle=False)
+          | beam.Map(lambda t: window.TimestampedValue('a' * t, t))
+          | beam.WindowInto(window.FixedWindows(3))
+          | util.BatchElements(
+              min_batch_size=11,
+              max_batch_size=11,
+              element_size_fn=len,
+              clock=FakeClock())
+          | beam.Map(lambda batch: ''.join(batch)))
+      assert_that(
+          res,
+          equal_to([
+              'a' * (1+2), # Elements in [1, 3)
+              'a' * (3+4), # Elements in [3, 6)
+              'a' * 5,
+              'a' * 6, # Elements in [6, 9)
+              'a' * 7,
+          ]))
 
   def test_target_duration(self):
     clock = FakeClock()

Reply via email to