This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 3ff64c72266 [Python] Implement combiner deferred side inputs (#35601)
3ff64c72266 is described below
commit 3ff64c72266813b296fdbf3241412b84a1d3c0d9
Author: Hai Joey Tran <[email protected]>
AuthorDate: Tue Jul 22 14:59:49 2025 -0400
[Python] Implement combiner deferred side inputs (#35601)
* squash
* revert unnecessary fn_runner change
* improve test
* add a test with streaming
* add streaming with matching window test
* implement streaming support
* add combiner support
* implement args support
* more rigorously test which combinefn methods are called with side inputs
* clean up
* use pack/unpack terminology
* add an explanatory comment
* revert old unneeded changes
* tidy
* use temp dir for json test file
* add combineglobally test
* connect args/kwargs to all combinefn methods after all
* enable streaming for streaming tests
* remove streaming options
* move liftedcombineperkey
* add additional docstring to liftedcombineperkey
* isort
* Update sdks/python/apache_beam/transforms/combiners.py
Co-authored-by: Danny McCormick <[email protected]>
* privatize a couple transforms
---------
Co-authored-by: Danny McCormick <[email protected]>
---
.../apache_beam/runners/direct/direct_runner.py | 2 +-
.../runners/direct/helper_transforms.py | 120 -------------
sdks/python/apache_beam/transforms/combiners.py | 124 ++++++++++++++
.../apache_beam/transforms/combiners_test.py | 186 +++++++++++++++++++++
sdks/python/apache_beam/transforms/core.py | 16 ++
5 files changed, 327 insertions(+), 121 deletions(-)
diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py
b/sdks/python/apache_beam/runners/direct/direct_runner.py
index fcc13ae1024..a629c12a058 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner.py
@@ -389,7 +389,7 @@ def _get_transform_overrides(pipeline_options):
# Importing following locally to avoid a circular dependency.
from apache_beam.pipeline import PTransformOverride
- from apache_beam.runners.direct.helper_transforms import LiftedCombinePerKey
+ from apache_beam.transforms.combiners import LiftedCombinePerKey
from apache_beam.runners.direct.sdf_direct_runner import
ProcessKeyedElementsViaKeyedWorkItemsOverride
from apache_beam.runners.direct.sdf_direct_runner import
SplittableParDoOverride
diff --git a/sdks/python/apache_beam/runners/direct/helper_transforms.py
b/sdks/python/apache_beam/runners/direct/helper_transforms.py
deleted file mode 100644
index 0e88c021e2f..00000000000
--- a/sdks/python/apache_beam/runners/direct/helper_transforms.py
+++ /dev/null
@@ -1,120 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# pytype: skip-file
-
-import collections
-import itertools
-import typing
-
-import apache_beam as beam
-from apache_beam import typehints
-from apache_beam.internal.util import ArgumentPlaceholder
-from apache_beam.transforms.combiners import _CurriedFn
-from apache_beam.utils.windowed_value import WindowedValue
-
-
-class LiftedCombinePerKey(beam.PTransform):
- """An implementation of CombinePerKey that does mapper-side pre-combining.
- """
- def __init__(self, combine_fn, args, kwargs):
- args_to_check = itertools.chain(args, kwargs.values())
- if isinstance(combine_fn, _CurriedFn):
- args_to_check = itertools.chain(
- args_to_check, combine_fn.args, combine_fn.kwargs.values())
- if any(isinstance(arg, ArgumentPlaceholder) for arg in args_to_check):
- # This isn't implemented in dataflow either...
- raise NotImplementedError('Deferred CombineFn side inputs.')
- self._combine_fn = beam.transforms.combiners.curry_combine_fn(
- combine_fn, args, kwargs)
-
- def expand(self, pcoll):
- return (
- pcoll
- | beam.ParDo(PartialGroupByKeyCombiningValues(self._combine_fn))
- | beam.GroupByKey()
- | beam.ParDo(FinishCombine(self._combine_fn)))
-
-
-class PartialGroupByKeyCombiningValues(beam.DoFn):
- """Aggregates values into a per-key-window cache.
-
- As bundles are in-memory-sized, we don't bother flushing until the very end.
- """
- def __init__(self, combine_fn):
- self._combine_fn = combine_fn
-
- def setup(self):
- self._combine_fn.setup()
-
- def start_bundle(self):
- self._cache = collections.defaultdict(self._combine_fn.create_accumulator)
-
- def process(self, element, window=beam.DoFn.WindowParam):
- k, vi = element
- self._cache[k, window] = self._combine_fn.add_input(
- self._cache[k, window], vi)
-
- def finish_bundle(self):
- for (k, w), va in self._cache.items():
- # We compact the accumulator since a GBK (which necessitates encoding)
- # will follow.
- yield WindowedValue((k, self._combine_fn.compact(va)), w.end, (w, ))
-
- def teardown(self):
- self._combine_fn.teardown()
-
- def default_type_hints(self):
- hints = self._combine_fn.get_type_hints()
- K = typehints.TypeVariable('K')
- if hints.input_types:
- args, kwargs = hints.input_types
- args = (typehints.Tuple[K, args[0]], ) + args[1:]
- hints = hints.with_input_types(*args, **kwargs)
- else:
- hints = hints.with_input_types(typehints.Tuple[K, typing.Any])
- hints = hints.with_output_types(typehints.Tuple[K, typing.Any])
- return hints
-
-
-class FinishCombine(beam.DoFn):
- """Merges partially combined results.
- """
- def __init__(self, combine_fn):
- self._combine_fn = combine_fn
-
- def setup(self):
- self._combine_fn.setup()
-
- def process(self, element):
- k, vs = element
- return [(
- k,
- self._combine_fn.extract_output(
- self._combine_fn.merge_accumulators(vs)))]
-
- def teardown(self):
- self._combine_fn.teardown()
-
- def default_type_hints(self):
- hints = self._combine_fn.get_type_hints()
- K = typehints.TypeVariable('K')
- hints = hints.with_input_types(typehints.Tuple[K, typing.Any])
- if hints.output_types:
- main_output_type = hints.simple_output_type('')
- hints = hints.with_output_types(typehints.Tuple[K, main_output_type])
- return hints
diff --git a/sdks/python/apache_beam/transforms/combiners.py
b/sdks/python/apache_beam/transforms/combiners.py
index 58267ef97ac..6e4647fecef 100644
--- a/sdks/python/apache_beam/transforms/combiners.py
+++ b/sdks/python/apache_beam/transforms/combiners.py
@@ -41,6 +41,7 @@ from apache_beam.typehints import with_input_types
from apache_beam.typehints import with_output_types
from apache_beam.utils.timestamp import Duration
from apache_beam.utils.timestamp import Timestamp
+from apache_beam.utils.windowed_value import WindowedValue
__all__ = [
'Count',
@@ -985,3 +986,126 @@ class LatestCombineFn(core.CombineFn):
def extract_output(self, accumulator):
return accumulator[0]
+
+
+class LiftedCombinePerKey(core.PTransform):
+ """An implementation of CombinePerKey that does mapper-side pre-combining.
+
+ This shouldn't generally be used directly except for use-cases where a
+ runner doesn't support CombinePerKey. This implementation manually implements
+ a CombinePerKey using ParDos, as opposed to runner implementations which may
+ use a more efficient implementation.
+ """
+ def __init__(self, combine_fn, args, kwargs):
+ side_inputs = _pack_side_inputs(args, kwargs)
+ self._side_inputs: dict = side_inputs
+ if not isinstance(combine_fn, core.CombineFn):
+ combine_fn = core.CombineFn.from_callable(combine_fn)
+ self._combine_fn = combine_fn
+
+ def expand(self, pcoll):
+ return (
+ pcoll
+ | core.ParDo(
+ _PartialGroupByKeyCombiningValues(self._combine_fn),
+ **self._side_inputs)
+ | core.GroupByKey()
+ | core.ParDo(_FinishCombine(self._combine_fn), **self._side_inputs))
+
+
+def _pack_side_inputs(side_input_args, side_input_kwargs):
+ if len(side_input_args) >= 10:
+ # If we have 10 or more side input args, we can't use
+ # _side_input_arg_{i} as our keys since they won't sort
+ # correctly (lexicographic order breaks at two digits). Just punt
+ # for now; that many positional side inputs probably doesn't
+ # happen often.
+ raise NotImplementedError
+ side_inputs = {}
+ for i, si in enumerate(side_input_args):
+ side_inputs[f'_side_input_arg_{i}'] = si
+ for k, v in side_input_kwargs.items():
+ side_inputs[k] = v
+ return side_inputs
+
+
+def _unpack_side_inputs(side_inputs):
+ side_input_args = []
+ side_input_kwargs = {}
+ for k, v in sorted(side_inputs.items(), key=lambda x: x[0]):
+ if k.startswith('_side_input_arg_'):
+ side_input_args.append(v)
+ else:
+ side_input_kwargs[k] = v
+ return side_input_args, side_input_kwargs
+
+
+class _PartialGroupByKeyCombiningValues(core.DoFn):
+ """Aggregates values into a per-key-window cache.
+
+ As bundles are in-memory-sized, we don't bother flushing until the very end.
+ """
+ def __init__(self, combine_fn):
+ self._combine_fn = combine_fn
+ self.side_input_args = []
+ self.side_input_kwargs = {}
+
+ def setup(self):
+ self._combine_fn.setup()
+
+ def start_bundle(self):
+ self._cache = dict()
+ self._cached_windowed_side_inputs = {}
+
+ def process(self, element, window=core.DoFn.WindowParam, **side_inputs):
+ k, vi = element
+ side_input_args, side_input_kwargs = _unpack_side_inputs(side_inputs)
+ if (k, window) not in self._cache:
+ self._cache[(k, window)] = self._combine_fn.create_accumulator(
+ *side_input_args, **side_input_kwargs)
+
+ self._cache[k, window] = self._combine_fn.add_input(
+ self._cache[k, window], vi, *side_input_args, **side_input_kwargs)
+ self._cached_windowed_side_inputs[window] = (
+ side_input_args, side_input_kwargs)
+
+ def finish_bundle(self):
+ for (k, w), va in self._cache.items():
+ # We compact the accumulator since a GBK (which necessitates encoding)
+ # will follow.
+ side_input_args, side_input_kwargs = (
+ self._cached_windowed_side_inputs[w])
+ yield WindowedValue((
+ k,
+ self._combine_fn.compact(va, *side_input_args, **side_input_kwargs)),
+ w.end, (w, ))
+
+ def teardown(self):
+ self._combine_fn.teardown()
+
+
+class _FinishCombine(core.DoFn):
+ """Merges partially combined results.
+ """
+ def __init__(self, combine_fn):
+ self._combine_fn = combine_fn
+
+ def setup(self):
+ self._combine_fn.setup()
+
+ def process(self, element, window=core.DoFn.WindowParam, **side_inputs):
+
+ k, vs = element
+ side_input_args, side_input_kwargs = _unpack_side_inputs(side_inputs)
+ return [(
+ k,
+ self._combine_fn.extract_output(
+ self._combine_fn.merge_accumulators(
+ vs, *side_input_args, **side_input_kwargs),
+ *side_input_args,
+ **side_input_kwargs))]
+
+ def teardown(self):
+ try:
+ self._combine_fn.teardown()
+ except AttributeError:
+ pass
diff --git a/sdks/python/apache_beam/transforms/combiners_test.py
b/sdks/python/apache_beam/transforms/combiners_test.py
index a8979239f83..ba9e21f8556 100644
--- a/sdks/python/apache_beam/transforms/combiners_test.py
+++ b/sdks/python/apache_beam/transforms/combiners_test.py
@@ -19,15 +19,20 @@
# pytype: skip-file
import itertools
+import json
+import os
import random
+import tempfile
import time
import unittest
+from pathlib import Path
import hamcrest as hc
import pytest
import apache_beam as beam
import apache_beam.transforms.combiners as combine
+from apache_beam import pvalue
from apache_beam.metrics import Metrics
from apache_beam.metrics import MetricsFilter
from apache_beam.options.pipeline_options import PipelineOptions
@@ -1021,5 +1026,186 @@ class CombineGloballyTest(unittest.TestCase):
| beam.CombineGlobally(sum).without_defaults())
+def get_common_items(sets, excluded_chars=""):
+ # set.intersection() takes multiple sets as separate arguments.
+ # We unpack the `sets` list into multiple arguments with the * operator.
+ # The combine transform might give us an empty list of `sets`,
+ # so we use a list with an empty set as a default value.
+ common = set.intersection(*(sets or [set()]))
+ return common.difference(excluded_chars)
+
+
+class CombinerWithSideInputs(unittest.TestCase):
+ def test_cpk_with_side_input(self):
+ test_cases = [(get_common_items, True),
+ (beam.CombineFn.from_callable(get_common_items), True),
+ (get_common_items, False),
+ (beam.CombineFn.from_callable(get_common_items), False)]
+ for combiner, with_kwarg in test_cases:
+ self._check_combineperkey_with_side_input(combiner, with_kwarg)
+ self._check_combineglobally_with_side_input(combiner, with_kwarg)
+
+ def _check_combineperkey_with_side_input(self, combiner, with_kwarg):
+ with beam.Pipeline() as pipeline:
+ pc = (pipeline | beam.Create(['🍅']))
+ if with_kwarg:
+ cpk = beam.CombinePerKey(
+ combiner, excluded_chars=beam.pvalue.AsSingleton(pc))
+ else:
+ cpk = beam.CombinePerKey(combiner, beam.pvalue.AsSingleton(pc))
+ common_items = (
+ pipeline
+ | 'Create produce' >> beam.Create([
+ {'🍓', '🥕', '🍌', '🍅', '🌶️'},
+ {'🍇', '🥕', '🥝', '🍅', '🥔'},
+ {'🍉', '🥕', '🍆', '🍅', '🍍'},
+ {'🥑', '🥕', '🌽', '🍅', '🥥'},
+ ])
+ | beam.WithKeys(lambda x: None)
+ | cpk)
+ assert_that(common_items, equal_to([(None, {'🥕'})]))
+
+ def _check_combineglobally_with_side_input(self, combiner, with_kwarg):
+ with beam.Pipeline() as pipeline:
+ pc = (pipeline | beam.Create(['🍅']))
+ if with_kwarg:
+ cpk = beam.CombineGlobally(
+ combiner, excluded_chars=beam.pvalue.AsSingleton(pc))
+ else:
+ cpk = beam.CombineGlobally(combiner, beam.pvalue.AsSingleton(pc))
+ common_items = (
+ pipeline
+ | 'Create produce' >> beam.Create([
+ {'🍓', '🥕', '🍌', '🍅', '🌶️'},
+ {'🍇', '🥕', '🥝', '🍅', '🥔'},
+ {'🍉', '🥕', '🍆', '🍅', '🍍'},
+ {'🥑', '🥕', '🌽', '🍅', '🥥'},
+ ])
+ | cpk)
+ assert_that(common_items, equal_to([{'🥕'}]))
+
+ def test_combinefn_methods_with_side_input(self):
+ # Test that the expected combinefn methods are called with the
+ # expected arguments when using side inputs in CombinePerKey.
+ with tempfile.TemporaryDirectory() as tmp_dirname:
+ fname = str(Path(tmp_dirname) / "combinefn_calls.json")
+ with open(fname, "w") as f:
+ json.dump({}, f)
+
+ def set_in_json(key, values):
+ current_json = {}
+ if os.path.exists(fname):
+ with open(fname, "r") as f:
+ current_json = json.load(f)
+ current_json[key] = values
+ with open(fname, "w") as f:
+ json.dump(current_json, f)
+
+ class MyCombiner(beam.CombineFn):
+ def create_accumulator(self, *args, **kwargs):
+ set_in_json("create_accumulator_args", args)
+ set_in_json("create_accumulator_kwargs", kwargs)
+ return args, kwargs
+
+ def add_input(self, accumulator, input, *args, **kwargs):
+ set_in_json("add_input_args", args)
+ set_in_json("add_input_kwargs", kwargs)
+ return accumulator
+
+ def merge_accumulators(self, accumulators, *args, **kwargs):
+ set_in_json("merge_accumulators_args", args)
+ set_in_json("merge_accumulators_kwargs", kwargs)
+ return args, kwargs
+
+ def compact(self, accumulator, *args, **kwargs):
+ set_in_json("compact_args", args)
+ set_in_json("compact_kwargs", kwargs)
+ return accumulator
+
+ def extract_output(self, accumulator, *args, **kwargs):
+ set_in_json("extract_output_args", args)
+ set_in_json("extract_output_kwargs", kwargs)
+ return accumulator
+
+ with beam.Pipeline() as p:
+ static_pos_arg = 0
+ deferred_pos_arg = beam.pvalue.AsSingleton(
+ p | "CreateDeferredSideInput" >> beam.Create([1]))
+ static_kwarg = 2
+ deferred_kwarg = beam.pvalue.AsSingleton(
+ p | "CreateDeferredSideInputKwarg" >> beam.Create([3]))
+ res = (
+ p
+ | "CreateInputs" >> beam.Create([(None, None)])
+ | beam.CombinePerKey(
+ MyCombiner(),
+ static_pos_arg,
+ deferred_pos_arg,
+ static_kwarg=static_kwarg,
+ deferred_kwarg=deferred_kwarg))
+ assert_that(
+ res,
+ equal_to([
+ (None, ((0, 1), {
+ 'static_kwarg': 2, 'deferred_kwarg': 3
+ }))
+ ]))
+
+ # Check that the combinefn was called with the expected arguments
+ with open(fname, "r") as f:
+ data = json.load(f)
+ expected_args = [0, 1]
+ expected_kwargs = {"static_kwarg": 2, "deferred_kwarg": 3}
+ method_names = [
+ "create_accumulator",
+ "compact",
+ "add_input",
+ "merge_accumulators",
+ "extract_output"
+ ]
+ for key in method_names:
+ print(f"Checking {key}")
+ self.assertEqual(data[key + "_args"], expected_args)
+ self.assertEqual(data[key + "_kwargs"], expected_kwargs)
+
+ def test_cpk_with_windows(self):
+ # With global window side input
+ with TestPipeline() as p:
+
+ def sum_with_floor(vals, min_value=0):
+ vals_sum = sum(vals)
+ if vals_sum < min_value:
+ vals_sum += min_value
+ return vals_sum
+
+ res = (
+ p
+ | "CreateInputs" >> beam.Create([1, 2, 100, 101, 102])
+ | beam.Map(lambda x: window.TimestampedValue(('k', x), x))
+ | beam.WindowInto(FixedWindows(99))
+ | beam.CombinePerKey(
+ sum_with_floor,
+ min_value=pvalue.AsSingleton(p | beam.Create([100]))))
+ assert_that(res, equal_to([('k', 103), ('k', 303)]))
+
+ # with matching window side input
+ with TestPipeline() as p:
+ min_value = (
+ p
+ | "CreateMinValue" >> beam.Create([
+ window.TimestampedValue(50, 5),
+ window.TimestampedValue(1000, 100)
+ ])
+ | "WindowSideInputs" >> beam.WindowInto(FixedWindows(99)))
+ res = (
+ p
+ | "CreateInputs" >> beam.Create([1, 2, 100, 101, 102])
+ | beam.Map(lambda x: window.TimestampedValue(('k', x), x))
+ | beam.WindowInto(FixedWindows(99))
+ | beam.CombinePerKey(
+ sum_with_floor, min_value=pvalue.AsSingleton(min_value)))
+ assert_that(res, equal_to([('k', 53), ('k', 1303)]))
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/transforms/core.py
b/sdks/python/apache_beam/transforms/core.py
index c043f768574..6e0170c04ea 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -2927,6 +2927,22 @@ class CombinePerKey(PTransformWithSideInputs):
Returns:
A PObject holding the result of the combine operation.
"""
+ def __new__(cls, *args, **kwargs):
+ def has_side_inputs():
+ return (
+ any(isinstance(arg, pvalue.AsSideInput) for arg in args) or
+ any(isinstance(arg, pvalue.AsSideInput) for arg in kwargs.values()))
+
+ if has_side_inputs():
+ # If the CombineFn has deferred side inputs, the python SDK
+ # doesn't implement it.
+ # Use a ParDo-based CombinePerKey instead.
+ from apache_beam.transforms.combiners import \
+ LiftedCombinePerKey
+ combine_fn, *args = args
+ return LiftedCombinePerKey(combine_fn, args, kwargs)
+ return super(CombinePerKey, cls).__new__(cls)
+
def with_hot_key_fanout(self, fanout):
"""A per-key combine operation like self but with two levels of
aggregation.