yifanmai commented on a change in pull request #12185:
URL: https://github.com/apache/beam/pull/12185#discussion_r470309725



##########
File path: 
sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
##########
@@ -690,6 +692,196 @@ def fix_side_input_pcoll_coders(stages, pipeline_context):
   return stages
 
 
+def pack_combiners(stages, context):
+  # type: (Iterable[Stage], TransformContext) -> Iterator[Stage]
+  """Packs sibling CombinePerKey stages into a single CombinePerKey.
+
+  If CombinePerKey stages have a common input, one input each, and one output
+  each, pack the stages into a single stage that runs all CombinePerKeys and
+  outputs resulting tuples to a new PCollection. A subsequent stage unpacks
+  tuples from this PCollection and sends them to the original output
+  PCollections.
+  """
+
+  class _UnpackFn(core.DoFn):
+    """A DoFn that unpacks a packed tuple of values into multiple tagged outputs.
+
+    Each input element is a (key, values) pair, where values[i] belongs to the
+    output tagged tags[i].
+
+    Example:
+      tags = (T1, T2, ...)
+      input = (K, (V1, V2, ...))
+      output = TaggedOutput(T1, (K, V1)), TaggedOutput(T2, (K, V2)), ...
+    """
+
+    def __init__(self, tags):
+      # Output tags, matched by position to the packed values of each element.
+      self._tags = tags
+
+    def process(self, element):
+      # Emits one TaggedOutput per (tag, value) pair, re-attaching the key.
+      # zip() stops at the shorter of self._tags and values, so any surplus
+      # entries on either side are silently dropped.
+      key, values = element
+      return [
+          core.pvalue.TaggedOutput(tag, (key, value))
+          for tag, value in zip(self._tags, values)
+      ]
+
+  def _get_fallback_coder_id():
+    # Returns the context's coder id for the coder that coders.registry
+    # provides for `object`. Used as a fallback when a component coder id
+    # cannot be read from a KV coder. (add_or_get_coder_id presumably reuses
+    # an existing id when the coder is already registered -- confirm.)
+    return context.add_or_get_coder_id(
+        coders.registry.get_coder(object).to_runner_api(None))
+
+  def _get_component_coder_id_from_kv_coder(coder, index):
+    # Returns component_coder_ids[index] (0 = key, 1 = value) when `coder`
+    # has the KV urn and exactly two components; otherwise returns the
+    # fallback `object` coder id from _get_fallback_coder_id().
+    assert index < 2
+    if coder.spec.urn == common_urns.coders.KV.urn and len(
+        coder.component_coder_ids) == 2:
+      return coder.component_coder_ids[index]
+    return _get_fallback_coder_id()
+
+  def _get_key_coder_id_from_kv_coder(coder):
+    # Key (index 0) coder id of a KV coder, or the fallback coder id.
+    return _get_component_coder_id_from_kv_coder(coder, 0)
+
+  def _get_value_coder_id_from_kv_coder(coder):
+    # Value (index 1) coder id of a KV coder, or the fallback coder id.
+    return _get_component_coder_id_from_kv_coder(coder, 1)
+
+  def _try_fuse_stages(a, b):
+    # Reducer for functools.reduce over a group of sibling stages: returns
+    # a.fuse(b) when a.can_fuse(b, context) holds, otherwise raises ValueError.
+    # The caller catches that ValueError and yields the group's stages
+    # unpacked.
+    if a.can_fuse(b, context):
+      return a.fuse(b)
+    else:
+      raise ValueError
+
+  # Group stages by parent, yielding ineligible stages.
+  combine_stages_by_input_pcoll_id = collections.defaultdict(list)
+  for stage in stages:
+    transform = only_transform(stage.transforms)
+    if transform.spec.urn == common_urns.composites.COMBINE_PER_KEY.urn and 
len(
+        transform.inputs) == 1 and len(transform.outputs) == 1:
+      input_pcoll_id = only_element(transform.inputs.values())
+      combine_stages_by_input_pcoll_id[input_pcoll_id].append(stage)
+    else:
+      yield stage
+
+  for input_pcoll_id, packable_stages in 
combine_stages_by_input_pcoll_id.items(
+  ):
+    # Yield stage and continue if it has no siblings.
+    if len(packable_stages) == 1:
+      yield packable_stages[0]
+      continue
+
+    transforms = [only_transform(stage.transforms) for stage in 
packable_stages]
+    combine_payloads = [
+        proto_utils.parse_Bytes(transform.spec.payload,
+                                beam_runner_api_pb2.CombinePayload)
+        for transform in transforms
+    ]
+
+    # Yield stages and continue if they cannot be packed.
+    try:
+      # Fused stage is used as template and is not yielded.
+      fused_stage = functools.reduce(_try_fuse_stages, packable_stages)
+      merged_transform_environment_id = functools.reduce(
+          Stage._merge_environments,
+          [transform.environment_id or None for transform in transforms])
+      # Combiner packing only supports Python CombineFns.
+      for combine_payload in combine_payloads:
+        if combine_payload.combine_fn.urn != python_urns.PICKLED_COMBINE_FN:
+          raise ValueError('Combiner packing only supports Python CombineFns')
+    except ValueError:
+      for stage in packable_stages:
+        yield stage
+      continue

Review comment:
       Moved most of the checks earlier in the function, and cleaned up the
logic for this part (which is still needed, because we must make sure the
whole group is fusible).




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to