jrmccluskey commented on code in PR #37532:
URL: https://github.com/apache/beam/pull/37532#discussion_r3094244519


##########
sdks/python/apache_beam/transforms/util_test.py:
##########
@@ -1026,6 +1026,384 @@ def test_stateful_grows_to_max_batch(self):
       assert_that(res, equal_to([1, 1, 2, 4, 8, 16, 32, 50, 50]))
 
 
+class SortAndBatchElementsTest(unittest.TestCase):
+  """Tests for SortAndBatchElements transform."""
+  def test_elements_are_sorted_by_size(self):
+    """Test that elements are sorted by size within batches."""
+    with TestPipeline() as p:
+      # Create elements with varying sizes
+      data = ['aaaaa', 'bb', 'cccc', 'a', 'ddd']
+      res = (
+          p
+          | beam.Create(data, reshuffle=False)
+          | util.SortAndBatchElements(
+              min_batch_size=1, max_batch_size=5, max_batch_weight=100))
+
+      def check_sorted(batch):
+        lengths = [len(s) for s in batch]
+        assert lengths == sorted(lengths), (

Review Comment:
   I would recommend listing the expected data here just for clarity, even if 
just in a comment block



##########
sdks/python/apache_beam/transforms/util_test.py:
##########
@@ -1026,6 +1026,384 @@ def test_stateful_grows_to_max_batch(self):
       assert_that(res, equal_to([1, 1, 2, 4, 8, 16, 32, 50, 50]))
 
 
+class SortAndBatchElementsTest(unittest.TestCase):
+  """Tests for SortAndBatchElements transform."""
+  def test_elements_are_sorted_by_size(self):
+    """Test that elements are sorted by size within batches."""
+    with TestPipeline() as p:
+      # Create elements with varying sizes
+      data = ['aaaaa', 'bb', 'cccc', 'a', 'ddd']
+      res = (
+          p
+          | beam.Create(data, reshuffle=False)
+          | util.SortAndBatchElements(
+              min_batch_size=1, max_batch_size=5, max_batch_weight=100))
+
+      def check_sorted(batch):
+        lengths = [len(s) for s in batch]
+        assert lengths == sorted(lengths), (
+            f'Batch not sorted by size: {lengths}')
+        return batch
+
+      _ = res | beam.Map(check_sorted)
+
+  def test_batch_respects_max_batch_size(self):
+    """Test that batches do not exceed max_batch_size."""
+    with TestPipeline() as p:
+      res = (
+          p
+          | beam.Create(['a'] * 10, reshuffle=False)
+          | util.SortAndBatchElements(
+              min_batch_size=1, max_batch_size=3, max_batch_weight=100)
+          | beam.Map(len))
+      assert_that(res, equal_to([3, 3, 3, 1]))
+
+  def test_batch_respects_max_batch_weight(self):
+    """Test that batches do not exceed max_batch_weight."""
+    with TestPipeline() as p:
+      # Each element has size 5, max_batch_weight is 12
+      # So we can fit at most 2 elements per batch
+      data = ['aaaaa', 'bbbbb', 'ccccc', 'ddddd']
+      res = (
+          p
+          | beam.Create(data, reshuffle=False)
+          | util.SortAndBatchElements(
+              min_batch_size=1, max_batch_size=10, max_batch_weight=12)
+          | beam.Map(len))
+      assert_that(res, equal_to([2, 2]))
+
+  def test_default_element_size_fn_with_strings(self):
+    """Test default element_size_fn works with strings."""
+    with TestPipeline() as p:
+      data = ['a', 'bbb', 'cc']
+      res = (
+          p
+          | beam.Create(data, reshuffle=False)
+          | util.SortAndBatchElements(
+              min_batch_size=1, max_batch_size=3, max_batch_weight=100)
+          | beam.FlatMap(lambda batch: [len(s) for s in batch]))
+      # Elements should be sorted by length: 'a'(1), 'cc'(2), 'bbb'(3)
+      assert_that(res, equal_to([1, 2, 3]))
+
+  def test_default_element_size_fn_with_integers(self):
+    """Test default element_size_fn falls back to 1 for integers."""
+    with TestPipeline() as p:
+      data = [10, 20, 30, 40, 50]
+      res = (
+          p
+          | beam.Create(data, reshuffle=False)
+          | util.SortAndBatchElements(
+              min_batch_size=1, max_batch_size=3, max_batch_weight=100)
+          | beam.Map(len))
+      # With size=1 for all, should batch by max_batch_size
+      assert_that(res, equal_to([3, 2]))
+
+  def test_custom_element_size_fn(self):
+    """Test using a custom element_size_fn."""
+    with TestPipeline() as p:
+      data = [{'text': 'a'}, {'text': 'bbb'}, {'text': 'cc'}]
+      res = (
+          p
+          | beam.Create(data, reshuffle=False)
+          | util.SortAndBatchElements(
+              min_batch_size=1,
+              max_batch_size=3,
+              max_batch_weight=100,
+              element_size_fn=lambda x: len(x['text']))
+          | beam.FlatMap(lambda batch: [len(e['text']) for e in batch]))
+      # Should be sorted by text length
+      assert_that(res, equal_to([1, 2, 3]))
+
+  def test_empty_input(self):
+    """Test with empty input produces no output."""
+    with TestPipeline() as p:
+      res = (
+          p
+          | beam.Create([], reshuffle=False)
+          | util.SortAndBatchElements(
+              min_batch_size=1, max_batch_size=10, max_batch_weight=100)
+          | beam.Map(len))
+      assert_that(res, equal_to([]))
+
+  def test_single_element(self):
+    """Test with a single element."""
+    with TestPipeline() as p:
+      res = (
+          p
+          | beam.Create(['hello'], reshuffle=False)
+          | util.SortAndBatchElements(
+              min_batch_size=1, max_batch_size=10, max_batch_weight=100))
+      assert_that(res, equal_to([['hello']]))
+
+  def test_windowed_batches(self):
+    """Test that windowed elements are batched per window."""
+    with TestPipeline('FnApiRunner') as p:
+      res = (
+          p
+          | beam.Create(range(1, 8), reshuffle=False)
+          | beam.Map(lambda t: window.TimestampedValue('a' * t, t))
+          | beam.WindowInto(window.FixedWindows(3))
+          | util.SortAndBatchElements(
+              min_batch_size=1, max_batch_size=10, max_batch_weight=100)
+          | beam.Map(lambda batch: ''.join(batch)))
+      # FixedWindows(3) with default offset 0 produces:
+      # Window [0, 3): elements at t=1,2 with sizes 1,2
+      # Window [3, 6): elements at t=3,4,5 with sizes 3,4,5
+      # Window [6, 9): elements at t=6,7 with sizes 6,7
+      assert_that(
+          res,
+          equal_to([
+              'a' * (1 + 2),  # Window [0, 3)
+              'a' * (3 + 4 + 5),  # Window [3, 6)
+              'a' * (6 + 7),  # Window [6, 9)
+          ]))
+
+  def test_validation_min_batch_size(self):
+    """Test that min_batch_size validation raises ValueError."""
+    with self.assertRaises(ValueError) as cm:
+      util.SortAndBatchElements(
+          min_batch_size=0, max_batch_size=10, max_batch_weight=100)
+    self.assertIn('min_batch_size must be >= 1', str(cm.exception))
+
+  def test_validation_max_batch_size(self):
+    """Test that max_batch_size < min_batch_size raises ValueError."""
+    with self.assertRaises(ValueError) as cm:
+      util.SortAndBatchElements(
+          min_batch_size=10, max_batch_size=5, max_batch_weight=100)
+    self.assertIn('max_batch_size', str(cm.exception))
+    self.assertIn('min_batch_size', str(cm.exception))
+
+  def test_validation_max_batch_weight(self):
+    """Test that max_batch_weight validation raises ValueError."""
+    with self.assertRaises(ValueError) as cm:
+      util.SortAndBatchElements(
+          min_batch_size=1, max_batch_size=10, max_batch_weight=0)
+    self.assertIn('max_batch_weight must be >= 1', str(cm.exception))
+
+  def test_validation_element_size_fn_callable(self):
+    """Test that a non-callable element_size_fn raises TypeError."""
+    with self.assertRaises(TypeError) as cm:
+      util.SortAndBatchElements(
+          min_batch_size=1,
+          max_batch_size=10,
+          max_batch_weight=100,
+          element_size_fn=123)
+    self.assertIn('element_size_fn must be callable', str(cm.exception))
+
+  def test_batch_timestamps(self):
+    """Test that batches have correct timestamps."""
+    with TestPipeline('FnApiRunner') as p:
+      res = (
+          p
+          | beam.Create(['a', 'bb', 'ccc'], reshuffle=False)
+          | util.SortAndBatchElements(
+              min_batch_size=1, max_batch_size=10, max_batch_weight=100)
+          |
+          beam.Map(lambda batch, ts=beam.DoFn.TimestampParam: (len(batch), 
ts)))
+      assert_that(res, equal_to([(3, GlobalWindow().max_timestamp())]))
+
+  def test_padding_efficiency_improvement(self):

Review Comment:
   I'm not sold on this test in particular, since I think there's a bit of an 
incongruity as far as the batching approaches being used. Using BatchElements 
with a default element weighting of 1 compared to weighting based on length 
does create a favorable outcome for this test, but putting the approaches on 
the same weighting function actually produces a better padding overhead in the
case of traditional BatchElements (albeit creating five batches of 1, which by 
the definition here has a padding overhead of 0.) 
   
   I don't disagree that sorting and batching has benefits, but I don't think 
we necessarily need a unit test to prove it. 



##########
sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py:
##########
@@ -0,0 +1,652 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Benchmark: BatchElements vs SortAndBatchElements (weight-based splitting).

Review Comment:
   I generally like having something like this for proof-of-concept work and 
helping users pick the best options for their data. I don't love that the 
benchmark here doesn't actually use Beam or the BatchElements / 
SortAndBatchElements implementations directly, but considering that those 
implementations are generally pretty static and don't change often I'm okay 
including this code in this way.



##########
sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py:
##########
@@ -0,0 +1,652 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Benchmark: BatchElements vs SortAndBatchElements (weight-based splitting).
+
+Compares two batching strategies for variable-length inference workloads:
+
+- Baseline (BatchElements): fixed-count chunking, ignores element sizes.
+- Stateless (SortAndBatchElements): within each bundle, sorts elements
+  by size, then splits batches using max_batch_weight so that each batch
+  has a bounded total weight.  The improvement comes from sorting
+  combined with weight-based splitting: sorting clusters similar-sized
+  elements together, and the weight constraint then produces tighter
+  batches.  Sorting alone with fixed count-based boundaries yields ~0%
+  gain (verified by strict-control ablation).
+
+Padding ratio::
+
+  padding_ratio = sum(max_len_in_batch * batch_size) / sum(actual_lengths)
+  Lower is better.  1.0 = no padding waste.
+
+Methodology:
+
+- N=20 independent trials per condition (3 warmup trials excluded).
+- Same input corpus (seed=42) for A/B comparison.
+- Percentile method: linear interpolation between adjacent ranks
+  (equivalent to numpy.percentile with method='linear').
+  For N=20 trials: P50 interpolates ranks 10-11 (0-indexed 9-10),
+  P95 interpolates ranks 19-20 (0-indexed 18-19),
+  P99 interpolates near rank 20 (0-indexed 18.81).
+- Reports median [IQR] and P95 for each metric.
+- Inference model: latency = batch_size * (max_seq_len / 50)^1.5 ms
+  (simulates transformer-like scaling).
+
+Run::
+
+  python3 -m apache_beam.testing.benchmarks.sort_and_batch_benchmark
+"""
+
+import math
+import random
+import statistics
+import time
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple

Review Comment:
   Swap to the collections.abc equivalent for Callable and Sequence, use the 
native built-ins for dict, list, and tuple



##########
sdks/python/apache_beam/transforms/util.py:
##########
@@ -1319,6 +1322,283 @@ def expand(self, pcoll):
               self._batch_size_estimator, self._element_size_fn))
 
 
+class _SortAndBatchElementsDoFn(DoFn):
+  """DoFn that buffers, sorts by element size, and batches elements.
+
+  This DoFn is used internally by ``SortAndBatchElements`` for
+  PCollections with the default (global) window. It accumulates all
+  elements in the current bundle, sorts them by size in ascending order,
+  and emits optimally-sized batches on ``finish_bundle``.
+
+  Args:
+    min_batch_size: The minimum number of elements per batch. Must be >= 1.
+    max_batch_size: The maximum number of elements per batch.
+        Must be >= ``min_batch_size``.
+    max_batch_weight: The maximum total weight of elements in a batch,
+        where weight is computed by ``element_size_fn``. Must be >= 1.
+    element_size_fn: A callable mapping an element to its integer
+        size/weight.
+  """
+  def __init__(
+      self,
+      min_batch_size: int,
+      max_batch_size: int,
+      max_batch_weight: int,
+      element_size_fn: Callable[[Any], int]):
+    self._min_batch_size = min_batch_size
+    self._max_batch_size = max_batch_size
+    self._max_batch_weight = max_batch_weight
+    self._element_size_fn = element_size_fn or self._default_element_size
+    self._has_warned_type_error = False
+    self._buffer = []
+
+  def _default_element_size(self, element):
+    try:
+      return len(element)
+    except TypeError:
+      if not self._has_warned_type_error:
+        _LOGGER.warning(
+            'Element of type %s does not support len(). Falling back to '
+            'size 1. Consider providing a custom element_size_fn to '
+            'SortAndBatchElements for meaningful size-based batching.',
+            type(element).__name__)
+        self._has_warned_type_error = True
+      return 1
+
+  def start_bundle(self):
+    self._buffer = []
+
+  def process(self, element):
+    self._buffer.append(element)
+
+  def finish_bundle(self):
+    if not self._buffer:
+      return
+
+    # Sort elements by size (ascending) for optimal batching
+    # Elements of similar sizes will be grouped together
+    sorted_elements = sorted(self._buffer, key=self._element_size_fn)
+
+    batch = []
+    batch_weight = 0
+
+    for element in sorted_elements:
+      element_size = self._element_size_fn(element)
+
+      # Check if adding this element would exceed limits
+      would_exceed_count = len(batch) >= self._max_batch_size
+      would_exceed_weight = (
+          batch_weight + element_size >= self._max_batch_weight and batch)
+
+      if would_exceed_count or would_exceed_weight:
+        # Emit current batch
+        yield window.GlobalWindows.windowed_value_at_end_of_window(batch)
+        batch = []
+        batch_weight = 0
+
+      batch.append(element)
+      batch_weight += element_size
+
+    # Emit remaining elements
+    if batch:
+      yield window.GlobalWindows.windowed_value_at_end_of_window(batch)
+
+    self._buffer = None
+
+
+class _WindowAwareSortAndBatchElementsDoFn(DoFn):
+  """DoFn that buffers, sorts by element size, and batches elements per window.
+
+  This DoFn is used internally by ``SortAndBatchElements`` for
+  PCollections with non-default (e.g. fixed, sliding, or session) windows.
+  Elements are buffered per window and each window is flushed independently.
+  To prevent unbounded memory growth, when the number of live windows
+  exceeds ``_MAX_LIVE_WINDOWS`` the largest window buffer is flushed early.
+
+  Args:
+    min_batch_size: The minimum number of elements per batch. Must be >= 1.
+    max_batch_size: The maximum number of elements per batch.
+        Must be >= ``min_batch_size``.
+    max_batch_weight: The maximum total weight of elements in a batch,
+        where weight is computed by ``element_size_fn``. Must be >= 1.
+    element_size_fn: A callable mapping an element to its integer
+        size/weight.
+  """
+
+  _MAX_LIVE_WINDOWS = 10

Review Comment:
   Is 10 an arbitrary number, or were there any experiences while testing that 
led you to this number? Would it be worth making this configurable as a kwarg 
in the DoFn?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to