yifanmai commented on a change in pull request #12787: URL: https://github.com/apache/beam/pull/12787#discussion_r487204174
########## File path: sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py ########## @@ -704,6 +705,71 @@ def fix_side_input_pcoll_coders(stages, pipeline_context): return stages +def _group_stages_by_key(stages, get_stage_key): + grouped_stages = collections.defaultdict(list) + stages_with_none_key = [] + for stage in stages: + stage_key = get_stage_key(stage) + if stage_key is None: + stages_with_none_key.append(stage) + else: + grouped_stages[stage_key].append(stage) + return (grouped_stages, stages_with_none_key) + + +def eliminate_common_key_with_none(stages, context): + # type: (Iterable[Stage], TransformContext) -> Iterable[Stage] + + """Runs common subexpression elimination for sibling KeyWithNone stages. + + If multiple KeyWithNone stages share a common input, then all but one stages + will be eliminated along with their output PCollections. Transforms that + originally read input from the output PCollection of the eliminated + KeyWithNone stages will be remapped to read input from the output PCollection + of the remaining KeyWithNone stage. + """ + + # Partition stages by whether they are eligible for common KeyWithNone + # elimination, and group eligible KeyWithNone stages by parent and + # environment. + def get_stage_key(stage): + if len(stage.transforms) == 1: + transform = only_transform(stage.transforms) + if (transform.spec.urn == common_urns.primitives.PAR_DO.urn and + len(transform.inputs) == 1 and len(transform.outputs) == 1): + pardo_payload = proto_utils.parse_Bytes( + transform.spec.payload, beam_runner_api_pb2.ParDoPayload) + if pardo_payload.do_fn.urn == python_urns.KEY_WITH_NONE_DOFN: + return (only_element(transform.inputs.values()), stage.environment) + return None + + grouped_eligible_stages, ineligible_stages = _group_stages_by_key( + stages, get_stage_key) + + # Eliminate stages and build the PCollection remapping dictionary. + pcoll_id_remap = {} + remaining_stages = [] + for sibling_stages in grouped_eligible_stages.values(): + output_pcoll_ids = [ + only_element(stage.transforms[0].outputs.values()) + for stage in sibling_stages + ] + for to_delete_pcoll_id in output_pcoll_ids[1:]: + pcoll_id_remap[to_delete_pcoll_id] = output_pcoll_ids[0] + del context.components.pcollections[to_delete_pcoll_id] + remaining_stages.append(sibling_stages[0]) + + # Yield stages while remapping input PCollections if needed. + stages_to_yield = itertools.chain(ineligible_stages, remaining_stages) + for stage in stages_to_yield: + for transform in stage.transforms: + for input_key in list(transform.inputs): Review comment: Done. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org