damccorm commented on code in PR #35601: URL: https://github.com/apache/beam/pull/35601#discussion_r2222622732
########## sdks/python/apache_beam/transforms/combiners.py: ########## @@ -985,3 +986,126 @@ def merge_accumulators(self, accumulators): def extract_output(self, accumulator): return accumulator[0] + + +class LiftedCombinePerKey(core.PTransform): Review Comment: Should this be part of `__all__` for imports? ########## sdks/python/apache_beam/transforms/core.py: ########## @@ -2927,6 +2927,22 @@ class CombinePerKey(PTransformWithSideInputs): Returns: A PObject holding the result of the combine operation. """ + def __new__(cls, *args, **kwargs): + def has_side_inputs(): + return ( + any(isinstance(arg, pvalue.AsSideInput) for arg in args) or + any(isinstance(arg, pvalue.AsSideInput) for arg in kwargs.values())) + + if has_side_inputs(): + # If the CombineFn has deferred side inputs, the python SDK + # doesn't implement it. + # Use a ParDo-based CombinePerKey instead. + from apache_beam.transforms.combiners import \ + LiftedCombinePerKey + combine_fn, *args = args + return LiftedCombinePerKey(combine_fn, args, kwargs) Review Comment: I'd probably vote to leave it as is. We could warn and have a kwarg to silence the warning (or something similar), but I tend to agree that this is a rare case; while the perf implication is non-obvious, adding side inputs basically always comes with some perf cost ########## sdks/python/apache_beam/transforms/combiners.py: ########## @@ -985,3 +986,126 @@ def merge_accumulators(self, accumulators): def extract_output(self, accumulator): return accumulator[0] + + +class LiftedCombinePerKey(core.PTransform): + """An implementation of CombinePerKey that does mapper-side pre-combining. + + This shouldn't generally be used directly except for use-cases where a + runner doesn't support CombinePerKey. This implementation manually implements + a CombinePerKey using ParDos, as opposed to runner implementations which may + use a more efficient implementation. + """ + def __init__(self, combine_fn, args, kwargs): + side_inputs = _pack_side_inputs(args, kwargs) + self._side_inputs: dict = side_inputs + if not isinstance(combine_fn, core.CombineFn): + combine_fn = core.CombineFn.from_callable(combine_fn) + self._combine_fn = combine_fn + + def expand(self, pcoll): + return ( + pcoll + | core.ParDo( + PartialGroupByKeyCombiningValues(self._combine_fn), + **self._side_inputs) + | core.GroupByKey() + | core.ParDo(FinishCombine(self._combine_fn), **self._side_inputs)) + + +def _pack_side_inputs(side_input_args, side_input_kwargs): + if len(side_input_args) >= 10: + # If we have more than 10 side inputs, we can't use the + # _side_input_arg_{i} as our keys since they won't sort + # correctly. Just punt for now, more than 10 args probably + # doesn't happen often. + raise NotImplementedError + side_inputs = {} + for i, si in enumerate(side_input_args): + side_inputs[f'_side_input_arg_{i}'] = si + for k, v in side_input_kwargs.items(): + side_inputs[k] = v + return side_inputs + + +def _unpack_side_inputs(side_inputs): + side_input_args = [] + side_input_kwargs = {} + for k, v in sorted(side_inputs.items(), key=lambda x: x[0]): + if k.startswith('_side_input_arg_'): + side_input_args.append(v) + else: + side_input_kwargs[k] = v + return side_input_args, side_input_kwargs + + +class PartialGroupByKeyCombiningValues(core.DoFn): Review Comment: ```suggestion class _PartialGroupByKeyCombiningValues(core.DoFn): ``` We never expect people to use this (or `FinishCombine`) directly, right? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@beam.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org