[ 
https://issues.apache.org/jira/browse/BEAM-6186?focusedWorklogId=176942&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-176942
 ]

ASF GitHub Bot logged work on BEAM-6186:
----------------------------------------

                Author: ASF GitHub Bot
            Created on: 19/Dec/18 11:40
            Start Date: 19/Dec/18 11:40
    Worklog Time Spent: 10m 
      Work Description: robertwb closed pull request #7281: [BEAM-6186] Finish moving optimization phases.
URL: https://github.com/apache/beam/pull/7281
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index 677b0e1af33e..2bc70c1c2504 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -54,7 +54,9 @@
 from apache_beam.runners import pipeline_context
 from apache_beam.runners import runner
 from apache_beam.runners.portability import fn_api_runner_transforms
+from apache_beam.runners.portability.fn_api_runner_transforms import create_buffer_id
 from apache_beam.runners.portability.fn_api_runner_transforms import only_element
+from apache_beam.runners.portability.fn_api_runner_transforms import split_buffer_id
 from apache_beam.runners.portability.fn_api_runner_transforms import unique_name
 from apache_beam.runners.worker import bundle_processor
 from apache_beam.runners.worker import data_plane
@@ -71,8 +73,6 @@
     beam.coders.coders.GlobalWindowCoder()).get_impl().encode_nested(
         beam.transforms.window.GlobalWindows.windowed_value(b''))
 
-IMPULSE_BUFFER = b'impulse'
-
 
 class BeamFnControlServicer(beam_fn_api_pb2_grpc.BeamFnControlServicer):
 
@@ -299,547 +299,9 @@ def maybe_profile(self):
 
   def create_stages(self, pipeline_proto):
 
-    # Some helper functions.
-    # TODO(BEAM-6186): Move these to fn_api_runner_transforms.
-
-    Stage = fn_api_runner_transforms.Stage
-    union = fn_api_runner_transforms.union
-
-    def add_or_get_coder_id(coder_proto):
-      return fn_api_runner_transforms.TransformContext(
-          pipeline_components).add_or_get_coder_id(coder_proto)
-
-    def windowed_coder_id(coder_id, window_coder_id):
-      proto = beam_runner_api_pb2.Coder(
-          spec=beam_runner_api_pb2.SdkFunctionSpec(
-              spec=beam_runner_api_pb2.FunctionSpec(
-                  urn=common_urns.coders.WINDOWED_VALUE.urn)),
-          component_coder_ids=[coder_id, window_coder_id])
-      return add_or_get_coder_id(proto)
-
-    _with_state_iterables_cache = {}
-
-    def with_state_iterables(coder_id):
-      if coder_id not in _with_state_iterables_cache:
-        _with_state_iterables_cache[
-            coder_id] = create_with_state_iterables(coder_id)
-      return _with_state_iterables_cache[coder_id]
-
-    def create_with_state_iterables(coder_id):
-      coder = pipeline_components.coders[coder_id]
-      if coder.spec.spec.urn == common_urns.coders.ITERABLE.urn:
-        new_coder_id = unique_name(pipeline_components.coders, 'coder')
-        new_coder = pipeline_components.coders[new_coder_id]
-        new_coder.CopyFrom(coder)
-        new_coder.spec.spec.urn = common_urns.coders.STATE_BACKED_ITERABLE.urn
-        new_coder.spec.spec.payload = b'1'
-        return new_coder_id
-      else:
-        new_component_ids = [
-            with_state_iterables(c) for c in coder.component_coder_ids]
-        if new_component_ids == coder.component_coder_ids:
-          return coder_id
-        else:
-          new_coder_id = unique_name(pipeline_components.coders, 'coder')
-          new_coder = pipeline_components.coders[new_coder_id]
-          new_coder.CopyFrom(coder)
-          new_coder.component_coder_ids[:] = new_component_ids
-          return new_coder_id
-
-    safe_coders = {}
-
-    def length_prefix_unknown_coders(pcoll, pipeline_components):
-      """Length prefixes coder for the given PCollection.
-
-      Updates pipeline_components to have a length prefixed coder for
-      every component coder within the PCollection that is not understood
-      natively by the runner. Also populates the safe_coders map with
-      a corresponding runner side coder which is also length prefixed but
-      compatible for the runner to instantiate.
-      """
-      good_coder_urns = set(
-          value.urn for value in common_urns.coders.__dict__.values())
-      coders = pipeline_components.coders
-
-      for coder_id, coder_proto in coders.items():
-        if coder_proto.spec.spec.urn == common_urns.coders.BYTES.urn:
-          bytes_coder_id = coder_id
-          break
-      else:
-        bytes_coder_id = unique_name(coders, 'bytes_coder')
-        pipeline_components.coders[bytes_coder_id].CopyFrom(
-            beam.coders.BytesCoder().to_runner_api(None))
-      coder_substitutions = {}
-
-      def wrap_unknown_coders(coder_id, with_bytes):
-        if (coder_id, with_bytes) not in coder_substitutions:
-          wrapped_coder_id = None
-          coder_proto = coders[coder_id]
-          if coder_proto.spec.spec.urn == common_urns.coders.LENGTH_PREFIX.urn:
-            coder_substitutions[coder_id, with_bytes] = (
-                bytes_coder_id if with_bytes else coder_id)
-          elif coder_proto.spec.spec.urn in good_coder_urns:
-            wrapped_components = [wrap_unknown_coders(c, with_bytes)
-                                  for c in coder_proto.component_coder_ids]
-            if wrapped_components == list(coder_proto.component_coder_ids):
-              # Use as is.
-              coder_substitutions[coder_id, with_bytes] = coder_id
-            else:
-              wrapped_coder_id = unique_name(
-                  coders,
-                  coder_id + ("_bytes" if with_bytes else "_len_prefix"))
-              coders[wrapped_coder_id].CopyFrom(coder_proto)
-              coders[wrapped_coder_id].component_coder_ids[:] = [
-                  wrap_unknown_coders(c, with_bytes)
-                  for c in coder_proto.component_coder_ids]
-              coder_substitutions[coder_id, with_bytes] = wrapped_coder_id
-          else:
-            # Not a known coder.
-            if with_bytes:
-              coder_substitutions[coder_id, with_bytes] = bytes_coder_id
-            else:
-              wrapped_coder_id = unique_name(coders, coder_id +  "_len_prefix")
-              len_prefix_coder_proto = beam_runner_api_pb2.Coder(
-                  spec=beam_runner_api_pb2.SdkFunctionSpec(
-                      spec=beam_runner_api_pb2.FunctionSpec(
-                          urn=common_urns.coders.LENGTH_PREFIX.urn)),
-                  component_coder_ids=[coder_id])
-              coders[wrapped_coder_id].CopyFrom(len_prefix_coder_proto)
-              coder_substitutions[coder_id, with_bytes] = wrapped_coder_id
-          # This operation is idempotent.
-          if wrapped_coder_id:
-            coder_substitutions[wrapped_coder_id, with_bytes] = wrapped_coder_id
-        return coder_substitutions[coder_id, with_bytes]
-
-      new_coder_id = wrap_unknown_coders(pcoll.coder_id, False)
-      safe_coders[new_coder_id] = wrap_unknown_coders(pcoll.coder_id, True)
-      pcoll.coder_id = new_coder_id
-
-    # Now define the "optimization" phases.
-
-    def impulse_to_input(stages):
-      bytes_coder_id = add_or_get_coder_id(
-          beam.coders.BytesCoder().to_runner_api(None))
-
-      for stage in stages:
-        # First map Reads, if any, to Impulse + triggered read op.
-        for transform in list(stage.transforms):
-          if transform.spec.urn == common_urns.deprecated_primitives.READ.urn:
-            read_pc = only_element(transform.outputs.values())
-            read_pc_proto = pipeline_components.pcollections[read_pc]
-            impulse_pc = unique_name(
-                pipeline_components.pcollections, 'Impulse')
-            pipeline_components.pcollections[impulse_pc].CopyFrom(
-                beam_runner_api_pb2.PCollection(
-                    unique_name=impulse_pc,
-                    coder_id=bytes_coder_id,
-                    windowing_strategy_id=read_pc_proto.windowing_strategy_id,
-                    is_bounded=read_pc_proto.is_bounded))
-            stage.transforms.remove(transform)
-            # TODO(robertwb): If this goes multi-process before fn-api
-            # read is default, expand into split + reshuffle + read.
-            stage.transforms.append(
-                beam_runner_api_pb2.PTransform(
-                    unique_name=transform.unique_name + '/Impulse',
-                    spec=beam_runner_api_pb2.FunctionSpec(
-                        urn=common_urns.primitives.IMPULSE.urn),
-                    outputs={'out': impulse_pc}))
-            stage.transforms.append(
-                beam_runner_api_pb2.PTransform(
-                    unique_name=transform.unique_name,
-                    spec=beam_runner_api_pb2.FunctionSpec(
-                        urn=python_urns.IMPULSE_READ_TRANSFORM,
-                        payload=transform.spec.payload),
-                    inputs={'in': impulse_pc},
-                    outputs={'out': read_pc}))
-
-        # Now map impulses to inputs.
-        for transform in list(stage.transforms):
-          if transform.spec.urn == common_urns.primitives.IMPULSE.urn:
-            stage.transforms.remove(transform)
-            stage.transforms.append(
-                beam_runner_api_pb2.PTransform(
-                    unique_name=transform.unique_name,
-                    spec=beam_runner_api_pb2.FunctionSpec(
-                        urn=bundle_processor.DATA_INPUT_URN,
-                        payload=IMPULSE_BUFFER),
-                    outputs=transform.outputs))
-
-        yield stage
-
-    def lift_combiners(stages):
-      return fn_api_runner_transforms.lift_combiners(
-          stages,
-          fn_api_runner_transforms.TransformContext(pipeline_components))
-
-    def expand_gbk(stages):
-      """Transforms each GBK into a write followed by a read.
-      """
-      for stage in stages:
-        assert len(stage.transforms) == 1
-        transform = stage.transforms[0]
-        if transform.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn:
-          for pcoll_id in transform.inputs.values():
-            length_prefix_unknown_coders(
-                pipeline_components.pcollections[pcoll_id], pipeline_components)
-          for pcoll_id in transform.outputs.values():
-            if self._use_state_iterables:
-              pipeline_components.pcollections[
-                  pcoll_id].coder_id = with_state_iterables(
-                      pipeline_components.pcollections[pcoll_id].coder_id)
-            length_prefix_unknown_coders(
-                pipeline_components.pcollections[pcoll_id], pipeline_components)
-
-          # This is used later to correlate the read and write.
-          transform_id = stage.name
-          if transform != pipeline_components.transforms.get(transform_id):
-            transform_id = unique_name(
-                pipeline_components.transforms, stage.name)
-            pipeline_components.transforms[transform_id].CopyFrom(transform)
-          grouping_buffer = create_buffer_id(transform_id, kind='group')
-          gbk_write = Stage(
-              transform.unique_name + '/Write',
-              [beam_runner_api_pb2.PTransform(
-                  unique_name=transform.unique_name + '/Write',
-                  inputs=transform.inputs,
-                  spec=beam_runner_api_pb2.FunctionSpec(
-                      urn=bundle_processor.DATA_OUTPUT_URN,
-                      payload=grouping_buffer))],
-              downstream_side_inputs=frozenset(),
-              must_follow=stage.must_follow)
-          yield gbk_write
-
-          yield Stage(
-              transform.unique_name + '/Read',
-              [beam_runner_api_pb2.PTransform(
-                  unique_name=transform.unique_name + '/Read',
-                  outputs=transform.outputs,
-                  spec=beam_runner_api_pb2.FunctionSpec(
-                      urn=bundle_processor.DATA_INPUT_URN,
-                      payload=grouping_buffer))],
-              downstream_side_inputs=stage.downstream_side_inputs,
-              must_follow=union(frozenset([gbk_write]), stage.must_follow))
-        else:
-          yield stage
-
-    def sink_flattens(stages):
-      """Sink flattens and remove them from the graph.
-
-      A flatten that cannot be sunk/fused away becomes multiple writes (to the
-      same logical sink) followed by a read.
-      """
-      # TODO(robertwb): Actually attempt to sink rather than always materialize.
-      # TODO(robertwb): Possibly fuse this into one of the stages.
-      pcollections = pipeline_components.pcollections
-      for stage in stages:
-        assert len(stage.transforms) == 1
-        transform = stage.transforms[0]
-        if transform.spec.urn == common_urns.primitives.FLATTEN.urn:
-          # This is used later to correlate the read and writes.
-          buffer_id = create_buffer_id(transform.unique_name)
-          output_pcoll_id, = list(transform.outputs.values())
-          output_coder_id = pcollections[output_pcoll_id].coder_id
-          flatten_writes = []
-          for local_in, pcoll_in in transform.inputs.items():
-
-            if pcollections[pcoll_in].coder_id != output_coder_id:
-              # Flatten inputs must all be written with the same coder as is
-              # used to read them.
-              pcollections[pcoll_in].coder_id = output_coder_id
-              transcoded_pcollection = (
-                  transform.unique_name + '/Transcode/' + local_in + '/out')
-              yield Stage(
-                  transform.unique_name + '/Transcode/' + local_in,
-                  [beam_runner_api_pb2.PTransform(
-                      unique_name=
-                      transform.unique_name + '/Transcode/' + local_in,
-                      inputs={local_in: pcoll_in},
-                      outputs={'out': transcoded_pcollection},
-                      spec=beam_runner_api_pb2.FunctionSpec(
-                          urn=bundle_processor.IDENTITY_DOFN_URN))],
-                  downstream_side_inputs=frozenset(),
-                  must_follow=stage.must_follow)
-              pcollections[transcoded_pcollection].CopyFrom(
-                  pcollections[pcoll_in])
-              pcollections[transcoded_pcollection].coder_id = output_coder_id
-            else:
-              transcoded_pcollection = pcoll_in
-
-            flatten_write = Stage(
-                transform.unique_name + '/Write/' + local_in,
-                [beam_runner_api_pb2.PTransform(
-                    unique_name=transform.unique_name + '/Write/' + local_in,
-                    inputs={local_in: transcoded_pcollection},
-                    spec=beam_runner_api_pb2.FunctionSpec(
-                        urn=bundle_processor.DATA_OUTPUT_URN,
-                        payload=buffer_id))],
-                downstream_side_inputs=frozenset(),
-                must_follow=stage.must_follow)
-            flatten_writes.append(flatten_write)
-            yield flatten_write
-
-          yield Stage(
-              transform.unique_name + '/Read',
-              [beam_runner_api_pb2.PTransform(
-                  unique_name=transform.unique_name + '/Read',
-                  outputs=transform.outputs,
-                  spec=beam_runner_api_pb2.FunctionSpec(
-                      urn=bundle_processor.DATA_INPUT_URN,
-                      payload=buffer_id))],
-              downstream_side_inputs=stage.downstream_side_inputs,
-              must_follow=union(frozenset(flatten_writes), stage.must_follow))
-
-        else:
-          yield stage
-
-    def annotate_downstream_side_inputs(stages):
-      """Annotate each stage with fusion-prohibiting information.
-
-      Each stage is annotated with the (transitive) set of pcollections that
-      depend on this stage that are also used later in the pipeline as a
-      side input.
-
-      While theoretically this could result in O(n^2) annotations, the size of
-      each set is bounded by the number of side inputs (typically much smaller
-      than the number of total nodes) and the number of *distinct* side-input
-      sets is also generally small (and shared due to the use of union
-      defined above).
-
-      This representation is also amenable to simple recomputation on fusion.
-      """
-      consumers = collections.defaultdict(list)
-      all_side_inputs = set()
-      for stage in stages:
-        for transform in stage.transforms:
-          for input in transform.inputs.values():
-            consumers[input].append(stage)
-        for si in stage.side_inputs():
-          all_side_inputs.add(si)
-      all_side_inputs = frozenset(all_side_inputs)
-
-      downstream_side_inputs_by_stage = {}
-
-      def compute_downstream_side_inputs(stage):
-        if stage not in downstream_side_inputs_by_stage:
-          downstream_side_inputs = frozenset()
-          for transform in stage.transforms:
-            for output in transform.outputs.values():
-              if output in all_side_inputs:
-                downstream_side_inputs = union(
-                    downstream_side_inputs, frozenset([output]))
-              for consumer in consumers[output]:
-                downstream_side_inputs = union(
-                    downstream_side_inputs,
-                    compute_downstream_side_inputs(consumer))
-          downstream_side_inputs_by_stage[stage] = downstream_side_inputs
-        return downstream_side_inputs_by_stage[stage]
-
-      for stage in stages:
-        stage.downstream_side_inputs = compute_downstream_side_inputs(stage)
-      return stages
-
-    def fix_side_input_pcoll_coders(stages):
-      """Length prefix side input PCollection coders.
-      """
-      for stage in stages:
-        for si in stage.side_inputs():
-          length_prefix_unknown_coders(
-              pipeline_components.pcollections[si], pipeline_components)
-      return stages
-
-    def greedily_fuse(stages):
-      """Places transforms sharing an edge in the same stage, whenever 
possible.
-      """
-      producers_by_pcoll = {}
-      consumers_by_pcoll = collections.defaultdict(list)
-
-      # Used to always reference the correct stage as the producer and
-      # consumer maps are not updated when stages are fused away.
-      replacements = {}
-
-      def replacement(s):
-        old_ss = []
-        while s in replacements:
-          old_ss.append(s)
-          s = replacements[s]
-        for old_s in old_ss[:-1]:
-          replacements[old_s] = s
-        return s
-
-      def fuse(producer, consumer):
-        fused = producer.fuse(consumer)
-        replacements[producer] = fused
-        replacements[consumer] = fused
-
-      # First record the producers and consumers of each PCollection.
-      for stage in stages:
-        for transform in stage.transforms:
-          for input in transform.inputs.values():
-            consumers_by_pcoll[input].append(stage)
-          for output in transform.outputs.values():
-            producers_by_pcoll[output] = stage
-
-      logging.debug('consumers\n%s', consumers_by_pcoll)
-      logging.debug('producers\n%s', producers_by_pcoll)
-
-      # Now try to fuse away all pcollections.
-      for pcoll, producer in producers_by_pcoll.items():
-        write_pcoll = None
-        for consumer in consumers_by_pcoll[pcoll]:
-          producer = replacement(producer)
-          consumer = replacement(consumer)
-          # Update consumer.must_follow set, as it's used in can_fuse.
-          consumer.must_follow = frozenset(
-              replacement(s) for s in consumer.must_follow)
-          if producer.can_fuse(consumer):
-            fuse(producer, consumer)
-          else:
-            # If we can't fuse, do a read + write.
-            buffer_id = create_buffer_id(pcoll)
-            if write_pcoll is None:
-              write_pcoll = Stage(
-                  pcoll + '/Write',
-                  [beam_runner_api_pb2.PTransform(
-                      unique_name=pcoll + '/Write',
-                      inputs={'in': pcoll},
-                      spec=beam_runner_api_pb2.FunctionSpec(
-                          urn=bundle_processor.DATA_OUTPUT_URN,
-                          payload=buffer_id))])
-              fuse(producer, write_pcoll)
-            if consumer.has_as_main_input(pcoll):
-              read_pcoll = Stage(
-                  pcoll + '/Read',
-                  [beam_runner_api_pb2.PTransform(
-                      unique_name=pcoll + '/Read',
-                      outputs={'out': pcoll},
-                      spec=beam_runner_api_pb2.FunctionSpec(
-                          urn=bundle_processor.DATA_INPUT_URN,
-                          payload=buffer_id))],
-                  must_follow=frozenset([write_pcoll]))
-              fuse(read_pcoll, consumer)
-            else:
-              consumer.must_follow = union(
-                  consumer.must_follow, frozenset([write_pcoll]))
-
-      # Everything that was originally a stage or a replacement, but wasn't
-      # replaced, should be in the final graph.
-      final_stages = frozenset(stages).union(list(replacements.values()))\
-          .difference(list(replacements))
-
-      for stage in final_stages:
-        # Update all references to their final values before throwing
-        # the replacement data away.
-        stage.must_follow = frozenset(replacement(s) for s in stage.must_follow)
-        # Two reads of the same stage may have been fused.  This is unneeded.
-        stage.deduplicate_read()
-      return final_stages
-
-    def inject_timer_pcollections(stages):
-      for stage in stages:
-        for transform in list(stage.transforms):
-          if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
-            payload = proto_utils.parse_Bytes(
-                transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
-            for tag, spec in payload.timer_specs.items():
-              if len(transform.inputs) > 1:
-                raise NotImplementedError('Timers and side inputs.')
-              input_pcoll = pipeline_components.pcollections[
-                  next(iter(transform.inputs.values()))]
-              # Create the appropriate coder for the timer PCollection.
-              key_coder_id = input_pcoll.coder_id
-              if (pipeline_components.coders[key_coder_id].spec.spec.urn
-                  == common_urns.coders.KV.urn):
-                key_coder_id = pipeline_components.coders[
-                    key_coder_id].component_coder_ids[0]
-              key_timer_coder_id = add_or_get_coder_id(
-                  beam_runner_api_pb2.Coder(
-                      spec=beam_runner_api_pb2.SdkFunctionSpec(
-                          spec=beam_runner_api_pb2.FunctionSpec(
-                              urn=common_urns.coders.KV.urn)),
-                      component_coder_ids=[key_coder_id, spec.timer_coder_id]))
-              # Inject the read and write pcollections.
-              timer_read_pcoll = unique_name(
-                  pipeline_components.pcollections,
-                  '%s_timers_to_read_%s' % (transform.unique_name, tag))
-              timer_write_pcoll = unique_name(
-                  pipeline_components.pcollections,
-                  '%s_timers_to_write_%s' % (transform.unique_name, tag))
-              pipeline_components.pcollections[timer_read_pcoll].CopyFrom(
-                  beam_runner_api_pb2.PCollection(
-                      unique_name=timer_read_pcoll,
-                      coder_id=key_timer_coder_id,
-                      windowing_strategy_id=input_pcoll.windowing_strategy_id,
-                      is_bounded=input_pcoll.is_bounded))
-              pipeline_components.pcollections[timer_write_pcoll].CopyFrom(
-                  beam_runner_api_pb2.PCollection(
-                      unique_name=timer_write_pcoll,
-                      coder_id=key_timer_coder_id,
-                      windowing_strategy_id=input_pcoll.windowing_strategy_id,
-                      is_bounded=input_pcoll.is_bounded))
-              stage.transforms.append(
-                  beam_runner_api_pb2.PTransform(
-                      unique_name=timer_read_pcoll + '/Read',
-                      outputs={'out': timer_read_pcoll},
-                      spec=beam_runner_api_pb2.FunctionSpec(
-                          urn=bundle_processor.DATA_INPUT_URN,
-                          payload=create_buffer_id(
-                              timer_read_pcoll, kind='timers'))))
-              stage.transforms.append(
-                  beam_runner_api_pb2.PTransform(
-                      unique_name=timer_write_pcoll + '/Write',
-                      inputs={'in': timer_write_pcoll},
-                      spec=beam_runner_api_pb2.FunctionSpec(
-                          urn=bundle_processor.DATA_OUTPUT_URN,
-                          payload=create_buffer_id(
-                              timer_write_pcoll, kind='timers'))))
-              assert tag not in transform.inputs
-              transform.inputs[tag] = timer_read_pcoll
-              assert tag not in transform.outputs
-              transform.outputs[tag] = timer_write_pcoll
-              stage.timer_pcollections.append(
-                  (timer_read_pcoll + '/Read', timer_write_pcoll))
-        yield stage
-
-    def sort_stages(stages):
-      """Order stages suitable for sequential execution.
-      """
-      seen = set()
-      ordered = []
-
-      def process(stage):
-        if stage not in seen:
-          seen.add(stage)
-          for prev in stage.must_follow:
-            process(prev)
-          ordered.append(stage)
-      for stage in stages:
-        process(stage)
-      return ordered
-
-    def window_pcollection_coders(stages):
-      # Some SDK workers require windowed coders for their PCollections.
-      # TODO(BEAM-4150): Consistently use unwindowed coders everywhere.
-      for pcoll in pipeline_components.pcollections.values():
-        if (pipeline_components.coders[pcoll.coder_id].spec.spec.urn
-            != common_urns.coders.WINDOWED_VALUE.urn):
-          original_coder_id = pcoll.coder_id
-          pcoll.coder_id = windowed_coder_id(
-              pcoll.coder_id,
-              pipeline_components.windowing_strategies[
-                  pcoll.windowing_strategy_id].window_coder_id)
-          if (original_coder_id in safe_coders
-              and pcoll.coder_id not in safe_coders):
-            # TODO: This assumes the window coder is safe.
-            safe_coders[pcoll.coder_id] = windowed_coder_id(
-                safe_coders[original_coder_id],
-                pipeline_components.windowing_strategies[
-                    pcoll.windowing_strategy_id].window_coder_id)
-
-      return stages
-
-    # Now actually apply the operations.
-
-    pipeline_components = copy.deepcopy(pipeline_proto.components)
+    pipeline_context = fn_api_runner_transforms.TransformContext(
+        copy.deepcopy(pipeline_proto.components),
+        use_state_iterables=self._use_state_iterables)
 
     # Initial set of stages are singleton leaf transforms.
     stages = list(fn_api_runner_transforms.leaf_transform_stages(
@@ -848,16 +310,23 @@ def window_pcollection_coders(stages):
 
     # Apply each phase in order.
     for phase in [
-        annotate_downstream_side_inputs, fix_side_input_pcoll_coders,
-        lift_combiners, expand_gbk, sink_flattens, greedily_fuse,
-        impulse_to_input, inject_timer_pcollections, sort_stages,
-        window_pcollection_coders]:
+        fn_api_runner_transforms.annotate_downstream_side_inputs,
+        fn_api_runner_transforms.fix_side_input_pcoll_coders,
+        fn_api_runner_transforms.lift_combiners,
+        fn_api_runner_transforms.expand_gbk,
+        fn_api_runner_transforms.sink_flattens,
+        fn_api_runner_transforms.greedily_fuse,
+        fn_api_runner_transforms.read_to_impulse,
+        fn_api_runner_transforms.impulse_to_input,
+        fn_api_runner_transforms.inject_timer_pcollections,
+        fn_api_runner_transforms.sort_stages,
+        fn_api_runner_transforms.window_pcollection_coders]:
       logging.info('%s %s %s', '=' * 20, phase, '=' * 20)
-      stages = list(phase(stages))
+      stages = list(phase(stages, pipeline_context))
       logging.debug('Stages: %s', [str(s) for s in stages])
 
     # Return the (possibly mutated) context and ordered set of stages.
-    return pipeline_components, stages, safe_coders
+    return pipeline_context.components, stages, pipeline_context.safe_coders
 
   def run_stages(self, pipeline_components, stages, safe_coders):
     worker_handler_manager = WorkerHandlerManager(
@@ -896,7 +365,7 @@ def iterable_state_write(values, element_coder_impl):
       out = create_OutputStream()
       for element in values:
         element_coder_impl.encode_to_stream(element, out, True)
-      controller.state_handler.blocking_append(
+      controller.state.blocking_append(
           beam_fn_api_pb2.StateKey(
               runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
           out.get())
@@ -919,7 +388,7 @@ def extract_endpoints(stage):
           pcoll_id = transform.spec.payload
           if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
             target = transform.unique_name, only_element(transform.outputs)
-            if pcoll_id == IMPULSE_BUFFER:
+            if pcoll_id == fn_api_runner_transforms.IMPULSE_BUFFER:
               data_input[target] = [ENCODED_IMPULSE_VALUE]
             else:
               data_input[target] = pcoll_buffers[pcoll_id]
@@ -1610,11 +1079,3 @@ def monitoring_metrics(self):
       self._monitoring_metrics = FnApiMetrics(
           self._monitoring_infos_by_stage, user_metrics_only=False)
     return self._monitoring_metrics
-
-
-def create_buffer_id(name, kind='materialize'):
-  return ('%s:%s' % (kind, name)).encode('utf-8')
-
-
-def split_buffer_id(buffer_id):
-  return buffer_id.decode('utf-8').split(':', 1)
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
index 0a252bacf67c..f84ce8e1cbb4 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
@@ -20,10 +20,14 @@
 from __future__ import absolute_import
 from __future__ import print_function
 
+import collections
 import functools
+import logging
 from builtins import object
 
+from apache_beam import coders
 from apache_beam.portability import common_urns
+from apache_beam.portability import python_urns
 from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.runners.worker import bundle_processor
 from apache_beam.utils import proto_utils
@@ -35,6 +39,8 @@
     [common_urns.primitives.GROUP_BY_KEY.urn,
      common_urns.composites.COMBINE_PER_KEY.urn])
 
+IMPULSE_BUFFER = b'impulse'
+
 
 class Stage(object):
   """A set of Transforms that can be sent to the worker for processing."""
@@ -157,18 +163,122 @@ def deduplicate_read(self):
     self.transforms = new_transforms
 
 
+def memoize_on_instance(f):
+  missing = object()
+
+  def wrapper(self, *args):
+    try:
+      cache = getattr(self, '_cache_%s' % f.__name__)
+    except AttributeError:
+      cache = {}
+      setattr(self, '_cache_%s' % f.__name__, cache)
+    result = cache.get(args, missing)
+    if result is missing:
+      result = cache[args] = f(self, *args)
+    return result
+
+  return wrapper
+
+
 class TransformContext(object):
-  def __init__(self, components):
+
+  _KNOWN_CODER_URNS = set(
+      value.urn for value in common_urns.coders.__dict__.values())
+
+  def __init__(self, components, use_state_iterables=False):
     self.components = components
+    self.use_state_iterables = use_state_iterables
+    self.safe_coders = {}
+    self.bytes_coder_id = self.add_or_get_coder_id(
+        coders.BytesCoder().to_runner_api(None), 'bytes_coder')
 
-  def add_or_get_coder_id(self, coder_proto):
+  def add_or_get_coder_id(self, coder_proto, coder_prefix='coder'):
     for coder_id, coder in self.components.coders.items():
       if coder == coder_proto:
         return coder_id
-    new_coder_id = unique_name(self.components.coders, 'coder')
+    new_coder_id = unique_name(self.components.coders, coder_prefix)
     self.components.coders[new_coder_id].CopyFrom(coder_proto)
     return new_coder_id
 
+  @memoize_on_instance
+  def with_state_iterables(self, coder_id):
+    coder = self.components.coders[coder_id]
+    if coder.spec.spec.urn == common_urns.coders.ITERABLE.urn:
+      new_coder_id = unique_name(
+          self.components.coders, coder_id + '_state_backed')
+      new_coder = self.components.coders[new_coder_id]
+      new_coder.CopyFrom(coder)
+      new_coder.spec.spec.urn = common_urns.coders.STATE_BACKED_ITERABLE.urn
+      new_coder.spec.spec.payload = b'1'
+      new_coder.component_coder_ids[0] = self.with_state_iterables(
+          coder.component_coder_ids[0])
+      return new_coder_id
+    else:
+      new_component_ids = [
+          self.with_state_iterables(c) for c in coder.component_coder_ids]
+      if new_component_ids == coder.component_coder_ids:
+        return coder_id
+      else:
+        new_coder_id = unique_name(
+            self.components.coders, coder_id + '_state_backed')
+        self.components.coders[new_coder_id].CopyFrom(
+            beam_runner_api_pb2.Coder(
+                spec=coder.spec,
+                component_coder_ids=new_component_ids))
+        return new_coder_id
+
+  @memoize_on_instance
+  def length_prefixed_coder(self, coder_id):
+    if coder_id in self.safe_coders:
+      return coder_id
+    length_prefixed_id, safe_id = self.length_prefixed_and_safe_coder(coder_id)
+    self.safe_coders[length_prefixed_id] = safe_id
+    return length_prefixed_id
+
+  @memoize_on_instance
+  def length_prefixed_and_safe_coder(self, coder_id):
+    coder = self.components.coders[coder_id]
+    if coder.spec.spec.urn == common_urns.coders.LENGTH_PREFIX.urn:
+      return coder_id, self.bytes_coder_id
+    elif coder.spec.spec.urn in self._KNOWN_CODER_URNS:
+      new_component_ids = [
+          self.length_prefixed_coder(c) for c in coder.component_coder_ids]
+      if new_component_ids == coder.component_coder_ids:
+        new_coder_id = coder_id
+      else:
+        new_coder_id = unique_name(
+            self.components.coders, coder_id + '_length_prefixed')
+        self.components.coders[new_coder_id].CopyFrom(
+            beam_runner_api_pb2.Coder(
+                spec=coder.spec,
+                component_coder_ids=new_component_ids))
+      safe_component_ids = [self.safe_coders[c] for c in new_component_ids]
+      if safe_component_ids == coder.component_coder_ids:
+        safe_coder_id = coder_id
+      else:
+        safe_coder_id = unique_name(
+            self.components.coders, coder_id + '_safe')
+        self.components.coders[safe_coder_id].CopyFrom(
+            beam_runner_api_pb2.Coder(
+                spec=coder.spec,
+                component_coder_ids=safe_component_ids))
+      return new_coder_id, safe_coder_id
+    else:
+      new_coder_id = unique_name(
+          self.components.coders, coder_id + '_length_prefixed')
+      self.components.coders[new_coder_id].CopyFrom(
+          beam_runner_api_pb2.Coder(
+              spec=beam_runner_api_pb2.SdkFunctionSpec(
+                  spec=beam_runner_api_pb2.FunctionSpec(
+                      urn=common_urns.coders.LENGTH_PREFIX.urn)),
+              component_coder_ids=[coder_id]))
+      return new_coder_id, self.bytes_coder_id
+
+  def length_prefix_pcoll_coders(self, pcoll_id):
+    self.components.pcollections[pcoll_id].coder_id = (
+        self.length_prefixed_coder(
+            self.components.pcollections[pcoll_id].coder_id))
+
 
 def leaf_transform_stages(
     root_ids, components, parent=None, known_composites=KNOWN_COMPOSITES):
@@ -222,6 +332,62 @@ def add_parent(child, parent):
   return new_proto
 
 
+def annotate_downstream_side_inputs(stages, pipeline_context):
+  """Annotate each stage with fusion-prohibiting information.
+
+  Each stage is annotated with the (transitive) set of pcollections that
+  depend on this stage that are also used later in the pipeline as a
+  side input.
+
+  While theoretically this could result in O(n^2) annotations, the size of
+  each set is bounded by the number of side inputs (typically much smaller
+  than the number of total nodes) and the number of *distinct* side-input
+  sets is also generally small (and shared due to the use of union
+  defined above).
+
+  This representation is also amenable to simple recomputation on fusion.
+  """
+  consumers = collections.defaultdict(list)
+  all_side_inputs = set()
+  for stage in stages:
+    for transform in stage.transforms:
+      for input in transform.inputs.values():
+        consumers[input].append(stage)
+    for si in stage.side_inputs():
+      all_side_inputs.add(si)
+  all_side_inputs = frozenset(all_side_inputs)
+
+  downstream_side_inputs_by_stage = {}
+
+  def compute_downstream_side_inputs(stage):
+    if stage not in downstream_side_inputs_by_stage:
+      downstream_side_inputs = frozenset()
+      for transform in stage.transforms:
+        for output in transform.outputs.values():
+          if output in all_side_inputs:
+            downstream_side_inputs = union(
+                downstream_side_inputs, frozenset([output]))
+          for consumer in consumers[output]:
+            downstream_side_inputs = union(
+                downstream_side_inputs,
+                compute_downstream_side_inputs(consumer))
+      downstream_side_inputs_by_stage[stage] = downstream_side_inputs
+    return downstream_side_inputs_by_stage[stage]
+
+  for stage in stages:
+    stage.downstream_side_inputs = compute_downstream_side_inputs(stage)
+  return stages
+
+
+def fix_side_input_pcoll_coders(stages, pipeline_context):
+  """Length prefix side input PCollection coders.
+  """
+  for stage in stages:
+    for si in stage.side_inputs():
+      pipeline_context.length_prefix_pcoll_coders(si)
+  return stages
+
+
 def lift_combiners(stages, context):
   """Expands CombinePerKey into pre- and post-grouping stages.
 
@@ -353,6 +519,393 @@ def make_stage(base_stage, transform):
       yield stage
 
 
+def expand_gbk(stages, pipeline_context):
+  """Transforms each GBK into a write followed by a read.
+  """
+  for stage in stages:
+    assert len(stage.transforms) == 1
+    transform = stage.transforms[0]
+    if transform.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn:
+      for pcoll_id in transform.inputs.values():
+        pipeline_context.length_prefix_pcoll_coders(pcoll_id)
+      for pcoll_id in transform.outputs.values():
+        if pipeline_context.use_state_iterables:
+          pipeline_context.components.pcollections[
+              pcoll_id].coder_id = pipeline_context.with_state_iterables(
+                  pipeline_context.components.pcollections[pcoll_id].coder_id)
+        pipeline_context.length_prefix_pcoll_coders(pcoll_id)
+
+      # This is used later to correlate the read and write.
+      transform_id = stage.name
+      if transform != pipeline_context.components.transforms.get(transform_id):
+        transform_id = unique_name(
+            pipeline_context.components.transforms, stage.name)
+        pipeline_context.components.transforms[transform_id].CopyFrom(transform)
+      grouping_buffer = create_buffer_id(transform_id, kind='group')
+      gbk_write = Stage(
+          transform.unique_name + '/Write',
+          [beam_runner_api_pb2.PTransform(
+              unique_name=transform.unique_name + '/Write',
+              inputs=transform.inputs,
+              spec=beam_runner_api_pb2.FunctionSpec(
+                  urn=bundle_processor.DATA_OUTPUT_URN,
+                  payload=grouping_buffer))],
+          downstream_side_inputs=frozenset(),
+          must_follow=stage.must_follow)
+      yield gbk_write
+
+      yield Stage(
+          transform.unique_name + '/Read',
+          [beam_runner_api_pb2.PTransform(
+              unique_name=transform.unique_name + '/Read',
+              outputs=transform.outputs,
+              spec=beam_runner_api_pb2.FunctionSpec(
+                  urn=bundle_processor.DATA_INPUT_URN,
+                  payload=grouping_buffer))],
+          downstream_side_inputs=stage.downstream_side_inputs,
+          must_follow=union(frozenset([gbk_write]), stage.must_follow))
+    else:
+      yield stage
+
+
+def sink_flattens(stages, pipeline_context):
+  """Sink flattens and remove them from the graph.
+
+  A flatten that cannot be sunk/fused away becomes multiple writes (to the
+  same logical sink) followed by a read.
+  """
+  # TODO(robertwb): Actually attempt to sink rather than always materialize.
+  # TODO(robertwb): Possibly fuse this into one of the stages.
+  pcollections = pipeline_context.components.pcollections
+  for stage in stages:
+    assert len(stage.transforms) == 1
+    transform = stage.transforms[0]
+    if transform.spec.urn == common_urns.primitives.FLATTEN.urn:
+      # This is used later to correlate the read and writes.
+      buffer_id = create_buffer_id(transform.unique_name)
+      output_pcoll_id, = list(transform.outputs.values())
+      output_coder_id = pcollections[output_pcoll_id].coder_id
+      flatten_writes = []
+      for local_in, pcoll_in in transform.inputs.items():
+
+        if pcollections[pcoll_in].coder_id != output_coder_id:
+          # Flatten inputs must all be written with the same coder as is
+          # used to read them.
+          pcollections[pcoll_in].coder_id = output_coder_id
+          transcoded_pcollection = (
+              transform.unique_name + '/Transcode/' + local_in + '/out')
+          yield Stage(
+              transform.unique_name + '/Transcode/' + local_in,
+              [beam_runner_api_pb2.PTransform(
+                  unique_name=
+                  transform.unique_name + '/Transcode/' + local_in,
+                  inputs={local_in: pcoll_in},
+                  outputs={'out': transcoded_pcollection},
+                  spec=beam_runner_api_pb2.FunctionSpec(
+                      urn=bundle_processor.IDENTITY_DOFN_URN))],
+              downstream_side_inputs=frozenset(),
+              must_follow=stage.must_follow)
+          pcollections[transcoded_pcollection].CopyFrom(
+              pcollections[pcoll_in])
+          pcollections[transcoded_pcollection].coder_id = output_coder_id
+        else:
+          transcoded_pcollection = pcoll_in
+
+        flatten_write = Stage(
+            transform.unique_name + '/Write/' + local_in,
+            [beam_runner_api_pb2.PTransform(
+                unique_name=transform.unique_name + '/Write/' + local_in,
+                inputs={local_in: transcoded_pcollection},
+                spec=beam_runner_api_pb2.FunctionSpec(
+                    urn=bundle_processor.DATA_OUTPUT_URN,
+                    payload=buffer_id))],
+            downstream_side_inputs=frozenset(),
+            must_follow=stage.must_follow)
+        flatten_writes.append(flatten_write)
+        yield flatten_write
+
+      yield Stage(
+          transform.unique_name + '/Read',
+          [beam_runner_api_pb2.PTransform(
+              unique_name=transform.unique_name + '/Read',
+              outputs=transform.outputs,
+              spec=beam_runner_api_pb2.FunctionSpec(
+                  urn=bundle_processor.DATA_INPUT_URN,
+                  payload=buffer_id))],
+          downstream_side_inputs=stage.downstream_side_inputs,
+          must_follow=union(frozenset(flatten_writes), stage.must_follow))
+
+    else:
+      yield stage
+
+
+def greedily_fuse(stages, pipeline_context):
+  """Places transforms sharing an edge in the same stage, whenever possible.
+  """
+  producers_by_pcoll = {}
+  consumers_by_pcoll = collections.defaultdict(list)
+
+  # Used to always reference the correct stage as the producer and
+  # consumer maps are not updated when stages are fused away.
+  replacements = {}
+
+  def replacement(s):
+    old_ss = []
+    while s in replacements:
+      old_ss.append(s)
+      s = replacements[s]
+    for old_s in old_ss[:-1]:
+      replacements[old_s] = s
+    return s
+
+  def fuse(producer, consumer):
+    fused = producer.fuse(consumer)
+    replacements[producer] = fused
+    replacements[consumer] = fused
+
+  # First record the producers and consumers of each PCollection.
+  for stage in stages:
+    for transform in stage.transforms:
+      for input in transform.inputs.values():
+        consumers_by_pcoll[input].append(stage)
+      for output in transform.outputs.values():
+        producers_by_pcoll[output] = stage
+
+  logging.debug('consumers\n%s', consumers_by_pcoll)
+  logging.debug('producers\n%s', producers_by_pcoll)
+
+  # Now try to fuse away all pcollections.
+  for pcoll, producer in producers_by_pcoll.items():
+    write_pcoll = None
+    for consumer in consumers_by_pcoll[pcoll]:
+      producer = replacement(producer)
+      consumer = replacement(consumer)
+      # Update consumer.must_follow set, as it's used in can_fuse.
+      consumer.must_follow = frozenset(
+          replacement(s) for s in consumer.must_follow)
+      if producer.can_fuse(consumer):
+        fuse(producer, consumer)
+      else:
+        # If we can't fuse, do a read + write.
+        buffer_id = create_buffer_id(pcoll)
+        if write_pcoll is None:
+          write_pcoll = Stage(
+              pcoll + '/Write',
+              [beam_runner_api_pb2.PTransform(
+                  unique_name=pcoll + '/Write',
+                  inputs={'in': pcoll},
+                  spec=beam_runner_api_pb2.FunctionSpec(
+                      urn=bundle_processor.DATA_OUTPUT_URN,
+                      payload=buffer_id))])
+          fuse(producer, write_pcoll)
+        if consumer.has_as_main_input(pcoll):
+          read_pcoll = Stage(
+              pcoll + '/Read',
+              [beam_runner_api_pb2.PTransform(
+                  unique_name=pcoll + '/Read',
+                  outputs={'out': pcoll},
+                  spec=beam_runner_api_pb2.FunctionSpec(
+                      urn=bundle_processor.DATA_INPUT_URN,
+                      payload=buffer_id))],
+              must_follow=frozenset([write_pcoll]))
+          fuse(read_pcoll, consumer)
+        else:
+          consumer.must_follow = union(
+              consumer.must_follow, frozenset([write_pcoll]))
+
+  # Everything that was originally a stage or a replacement, but wasn't
+  # replaced, should be in the final graph.
+  final_stages = frozenset(stages).union(list(replacements.values()))\
+      .difference(list(replacements))
+
+  for stage in final_stages:
+    # Update all references to their final values before throwing
+    # the replacement data away.
+    stage.must_follow = frozenset(replacement(s) for s in stage.must_follow)
+    # Two reads of the same stage may have been fused.  This is unneeded.
+    stage.deduplicate_read()
+  return final_stages
+
+
+def read_to_impulse(stages, pipeline_context):
+  """Translates Read operations into Impulse operations."""
+  for stage in stages:
+    # First map Reads, if any, to Impulse + triggered read op.
+    for transform in list(stage.transforms):
+      if transform.spec.urn == common_urns.deprecated_primitives.READ.urn:
+        read_pc = only_element(transform.outputs.values())
+        read_pc_proto = pipeline_context.components.pcollections[read_pc]
+        impulse_pc = unique_name(
+            pipeline_context.components.pcollections, 'Impulse')
+        pipeline_context.components.pcollections[impulse_pc].CopyFrom(
+            beam_runner_api_pb2.PCollection(
+                unique_name=impulse_pc,
+                coder_id=pipeline_context.bytes_coder_id,
+                windowing_strategy_id=read_pc_proto.windowing_strategy_id,
+                is_bounded=read_pc_proto.is_bounded))
+        stage.transforms.remove(transform)
+        # TODO(robertwb): If this goes multi-process before fn-api
+        # read is default, expand into split + reshuffle + read.
+        stage.transforms.append(
+            beam_runner_api_pb2.PTransform(
+                unique_name=transform.unique_name + '/Impulse',
+                spec=beam_runner_api_pb2.FunctionSpec(
+                    urn=common_urns.primitives.IMPULSE.urn),
+                outputs={'out': impulse_pc}))
+        stage.transforms.append(
+            beam_runner_api_pb2.PTransform(
+                unique_name=transform.unique_name,
+                spec=beam_runner_api_pb2.FunctionSpec(
+                    urn=python_urns.IMPULSE_READ_TRANSFORM,
+                    payload=transform.spec.payload),
+                inputs={'in': impulse_pc},
+                outputs={'out': read_pc}))
+
+    yield stage
+
+
+def impulse_to_input(stages, pipeline_context):
+  """Translates Impulse operations into GRPC reads."""
+  for stage in stages:
+    for transform in list(stage.transforms):
+      if transform.spec.urn == common_urns.primitives.IMPULSE.urn:
+        stage.transforms.remove(transform)
+        stage.transforms.append(
+            beam_runner_api_pb2.PTransform(
+                unique_name=transform.unique_name,
+                spec=beam_runner_api_pb2.FunctionSpec(
+                    urn=bundle_processor.DATA_INPUT_URN,
+                    payload=IMPULSE_BUFFER),
+                outputs=transform.outputs))
+    yield stage
+
+
+def inject_timer_pcollections(stages, pipeline_context):
+  """Create PCollections for fired timers and to-be-set timers.
+
+  At execution time, fired timers and timers-to-set are represented as
+  PCollections that are managed by the runner.  This phase adds the
+  necessary collections, with their reads and writes, to any stages using
+  timers.
+  """
+  for stage in stages:
+    for transform in list(stage.transforms):
+      if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
+        payload = proto_utils.parse_Bytes(
+            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
+        for tag, spec in payload.timer_specs.items():
+          if len(transform.inputs) > 1:
+            raise NotImplementedError('Timers and side inputs.')
+          input_pcoll = pipeline_context.components.pcollections[
+              next(iter(transform.inputs.values()))]
+          # Create the appropriate coder for the timer PCollection.
+          key_coder_id = input_pcoll.coder_id
+          if (pipeline_context.components.coders[key_coder_id].spec.spec.urn
+              == common_urns.coders.KV.urn):
+            key_coder_id = pipeline_context.components.coders[
+                key_coder_id].component_coder_ids[0]
+          key_timer_coder_id = pipeline_context.add_or_get_coder_id(
+              beam_runner_api_pb2.Coder(
+                  spec=beam_runner_api_pb2.SdkFunctionSpec(
+                      spec=beam_runner_api_pb2.FunctionSpec(
+                          urn=common_urns.coders.KV.urn)),
+                  component_coder_ids=[key_coder_id, spec.timer_coder_id]))
+          # Inject the read and write pcollections.
+          timer_read_pcoll = unique_name(
+              pipeline_context.components.pcollections,
+              '%s_timers_to_read_%s' % (transform.unique_name, tag))
+          timer_write_pcoll = unique_name(
+              pipeline_context.components.pcollections,
+              '%s_timers_to_write_%s' % (transform.unique_name, tag))
+          pipeline_context.components.pcollections[timer_read_pcoll].CopyFrom(
+              beam_runner_api_pb2.PCollection(
+                  unique_name=timer_read_pcoll,
+                  coder_id=key_timer_coder_id,
+                  windowing_strategy_id=input_pcoll.windowing_strategy_id,
+                  is_bounded=input_pcoll.is_bounded))
+          pipeline_context.components.pcollections[timer_write_pcoll].CopyFrom(
+              beam_runner_api_pb2.PCollection(
+                  unique_name=timer_write_pcoll,
+                  coder_id=key_timer_coder_id,
+                  windowing_strategy_id=input_pcoll.windowing_strategy_id,
+                  is_bounded=input_pcoll.is_bounded))
+          stage.transforms.append(
+              beam_runner_api_pb2.PTransform(
+                  unique_name=timer_read_pcoll + '/Read',
+                  outputs={'out': timer_read_pcoll},
+                  spec=beam_runner_api_pb2.FunctionSpec(
+                      urn=bundle_processor.DATA_INPUT_URN,
+                      payload=create_buffer_id(
+                          timer_read_pcoll, kind='timers'))))
+          stage.transforms.append(
+              beam_runner_api_pb2.PTransform(
+                  unique_name=timer_write_pcoll + '/Write',
+                  inputs={'in': timer_write_pcoll},
+                  spec=beam_runner_api_pb2.FunctionSpec(
+                      urn=bundle_processor.DATA_OUTPUT_URN,
+                      payload=create_buffer_id(
+                          timer_write_pcoll, kind='timers'))))
+          assert tag not in transform.inputs
+          transform.inputs[tag] = timer_read_pcoll
+          assert tag not in transform.outputs
+          transform.outputs[tag] = timer_write_pcoll
+          stage.timer_pcollections.append(
+              (timer_read_pcoll + '/Read', timer_write_pcoll))
+    yield stage
+
+
+def sort_stages(stages, pipeline_context):
+  """Order stages suitable for sequential execution.
+  """
+  seen = set()
+  ordered = []
+
+  def process(stage):
+    if stage not in seen:
+      seen.add(stage)
+      for prev in stage.must_follow:
+        process(prev)
+      ordered.append(stage)
+  for stage in stages:
+    process(stage)
+  return ordered
+
+
+def window_pcollection_coders(stages, pipeline_context):
+  """Wrap all PCollection coders as windowed value coders.
+
+  This is required as some SDK workers require windowed coders for their
+  PCollections.
+  TODO(BEAM-4150): Consistently use unwindowed coders everywhere.
+  """
+  def windowed_coder_id(coder_id, window_coder_id):
+    proto = beam_runner_api_pb2.Coder(
+        spec=beam_runner_api_pb2.SdkFunctionSpec(
+            spec=beam_runner_api_pb2.FunctionSpec(
+                urn=common_urns.coders.WINDOWED_VALUE.urn)),
+        component_coder_ids=[coder_id, window_coder_id])
+    return pipeline_context.add_or_get_coder_id(
+        proto, coder_id + '_windowed')
+
+  for pcoll in pipeline_context.components.pcollections.values():
+    if (pipeline_context.components.coders[pcoll.coder_id].spec.spec.urn
+        != common_urns.coders.WINDOWED_VALUE.urn):
+      original_coder_id = pcoll.coder_id
+      pcoll.coder_id = windowed_coder_id(
+          pcoll.coder_id,
+          pipeline_context.components.windowing_strategies[
+              pcoll.windowing_strategy_id].window_coder_id)
+      if (original_coder_id in pipeline_context.safe_coders
+          and pcoll.coder_id not in pipeline_context.safe_coders):
+        # TODO: This assumes the window coder is safe.
+        pipeline_context.safe_coders[pcoll.coder_id] = windowed_coder_id(
+            pipeline_context.safe_coders[original_coder_id],
+            pipeline_context.components.windowing_strategies[
+                pcoll.windowing_strategy_id].window_coder_id)
+
+  return stages
+
+
 def union(a, b):
   # Minimize the number of distinct sets.
   if not a or a == b:
@@ -385,3 +938,11 @@ def unique_name(existing, prefix):
 def only_element(iterable):
   element, = iterable
   return element
+
+
+def create_buffer_id(name, kind='materialize'):
+  return ('%s:%s' % (kind, name)).encode('utf-8')
+
+
+def split_buffer_id(buffer_id):
+  return buffer_id.decode('utf-8').split(':', 1)
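
For readers skimming the diff: every relocated phase now shares the
(stages, pipeline_context) signature, so the runner applies them as a simple
loop over a phase list. Below is a minimal standalone sketch of that
composition; the wrapper name build_stages and the exact arguments to
leaf_transform_stages are illustrative assumptions, while the phase list and
TransformContext usage mirror the diff above.

import copy
import logging

from apache_beam.runners.portability import fn_api_runner_transforms


def build_stages(pipeline_proto, use_state_iterables=False):
  # Shared mutable context: a deep copy of the pipeline components plus the
  # safe_coders map and coder-rewriting helpers (see TransformContext above).
  context = fn_api_runner_transforms.TransformContext(
      copy.deepcopy(pipeline_proto.components),
      use_state_iterables=use_state_iterables)

  # Initial set of stages are singleton leaf transforms (assumed arguments).
  stages = list(fn_api_runner_transforms.leaf_transform_stages(
      pipeline_proto.root_transform_ids, context.components))

  # Each phase maps (stages, context) -> stages, so applying them in order
  # is just a fold over the phase list.
  for phase in [
      fn_api_runner_transforms.annotate_downstream_side_inputs,
      fn_api_runner_transforms.fix_side_input_pcoll_coders,
      fn_api_runner_transforms.lift_combiners,
      fn_api_runner_transforms.expand_gbk,
      fn_api_runner_transforms.sink_flattens,
      fn_api_runner_transforms.greedily_fuse,
      fn_api_runner_transforms.read_to_impulse,
      fn_api_runner_transforms.impulse_to_input,
      fn_api_runner_transforms.inject_timer_pcollections,
      fn_api_runner_transforms.sort_stages,
      fn_api_runner_transforms.window_pcollection_coders]:
    logging.info('Applying phase %s', phase)
    stages = list(phase(stages, context))

  # The (possibly mutated) components, ordered stages, and safe coders.
  return context.components, stages, context.safe_coders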


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Issue Time Tracking
-------------------

    Worklog Id:     (was: 176942)
    Time Spent: 2.5h  (was: 2h 20m)

> Cleanup FnApiRunner optimization phases.
> ----------------------------------------
>
>                 Key: BEAM-6186
>                 URL: https://issues.apache.org/jira/browse/BEAM-6186
>             Project: Beam
>          Issue Type: Improvement
>          Components: sdk-py-core
>            Reporter: Robert Bradshaw
>            Assignee: Ahmet Altay
>            Priority: Minor
>          Time Spent: 2.5h
>  Remaining Estimate: 0h
>
> They are currently expressed as nested functions with closures. It would be
> good to pull them out with explicit dependencies, both to make the code
> easier to follow and to make the phases testable and reusable.
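
To illustrate the kind of refactoring the issue asks for (a toy sketch, not
Beam code): a phase written as a nested closure can only be exercised through
its enclosing function, whereas a module-level phase that takes an explicit
context argument can be imported, unit-tested, and reused directly, which is
the shape the fn_api_runner_transforms phases take in the PR above.

# Before: the phase is a nested function that closes over
# pipeline_components, so it can only run inside create_stages().
def create_stages(pipeline_components, stages):

  def annotate(stage):
    # Implicit dependency via the enclosing scope.
    return (stage, pipeline_components[stage])

  return [annotate(s) for s in stages]


# After: the phase lives at module level and receives its dependencies
# explicitly, so it can be imported, tested, and reused on its own.
def annotate(stages, context):
  return [(s, context.components[s]) for s in stages]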



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)
