This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 63dbbf4  [BEAM-3742] Checkpointing for SDF over FnAPI in Python SDK and FnApiRunner. (#7722)
63dbbf4 is described below

commit 63dbbf4affc2082015d35f27f393f1b6d0a0493b
Author: Robert Bradshaw <rober...@gmail.com>
AuthorDate: Tue Feb 5 22:27:24 2019 +0100

    [BEAM-3742] Checkpointing for SDF over FnAPI in Python SDK and FnApiRunner. (#7722)
---
 sdks/python/apache_beam/io/restriction_trackers.py |  12 +-
 sdks/python/apache_beam/portability/common_urns.py |   2 +
 sdks/python/apache_beam/runners/common.pxd         |   3 +
 sdks/python/apache_beam/runners/common.py          |  46 ++++++-
 .../apache_beam/runners/direct/direct_runner.py    |   4 -
 .../runners/direct/sdf_direct_runner_test.py       |   3 +-
 .../runners/portability/flink_runner_test.py       |   3 +
 .../runners/portability/fn_api_runner.py           |  40 ++++--
 .../runners/portability/fn_api_runner_test.py      |  38 ++++++
 .../portability/fn_api_runner_transforms.py        | 137 +++++++++++++++++++--
 .../apache_beam/runners/worker/bundle_processor.py | 126 +++++++++++++++++--
 .../apache_beam/runners/worker/operations.pxd      |   6 +
 .../apache_beam/runners/worker/operations.py       |  16 ++-
 .../apache_beam/runners/worker/sdk_worker.py       |   3 +-
 sdks/python/apache_beam/transforms/core.py         |  10 ++
 15 files changed, 405 insertions(+), 44 deletions(-)

diff --git a/sdks/python/apache_beam/io/restriction_trackers.py b/sdks/python/apache_beam/io/restriction_trackers.py
index b9b1f17..e72d508 100644
--- a/sdks/python/apache_beam/io/restriction_trackers.py
+++ b/sdks/python/apache_beam/io/restriction_trackers.py
@@ -80,8 +80,9 @@ class OffsetRestrictionTracker(RestrictionTracker):
     self._range = OffsetRange(start_position, stop_position)
     self._current_position = None
     self._last_claim_attempt = None
+    self._deferred_residual = None
     self._checkpointed = False
-    self._lock = threading.Lock()
+    self._lock = threading.RLock()
 
   def check_done(self):
     with self._lock:
@@ -139,3 +140,12 @@ class OffsetRestrictionTracker(RestrictionTracker):
 
       self._range = OffsetRange(self._range.start, end_position)
       return residual_range
+
+  def defer_remainder(self, watermark=None):
+    with self._lock:
+      self._deferred_watermark = watermark
+      self._deferred_residual = self.checkpoint()
+
+  def deferred_status(self):
+    if self._deferred_residual:
+      return (self._deferred_residual, self._deferred_watermark)
diff --git a/sdks/python/apache_beam/portability/common_urns.py b/sdks/python/apache_beam/portability/common_urns.py
index 4ee32a7..a000f81 100644
--- a/sdks/python/apache_beam/portability/common_urns.py
+++ b/sdks/python/apache_beam/portability/common_urns.py
@@ -49,6 +49,8 @@ composites = PropertiesFromEnumType(
     beam_runner_api_pb2.StandardPTransforms.Composites)
 combine_components = PropertiesFromEnumType(
     beam_runner_api_pb2.StandardPTransforms.CombineComponents)
+sdf_components = PropertiesFromEnumType(
+    beam_runner_api_pb2.StandardPTransforms.SplittableParDoComponents)
 
 side_inputs = PropertiesFromEnumType(
     beam_runner_api_pb2.StandardSideInputTypes.Enum)
diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd
index 49f4c44..b5ab88d 100644
--- a/sdks/python/apache_beam/runners/common.pxd
+++ b/sdks/python/apache_beam/runners/common.pxd
@@ -82,6 +82,9 @@ cdef class PerWindowInvoker(DoFnInvoker):
   cdef bint has_windowed_inputs
   cdef bint cache_globally_windowed_args
   cdef object process_method
+  cdef bint is_splittable
+  cdef object restriction_tracker
+  cdef WindowedValue current_windowed_value
 
 
 cdef class DoFnRunner(Receiver):
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 99a4bca..3d9b07f 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -211,7 +211,7 @@ class DoFnSignature(object):
     self.start_bundle_method = MethodWrapper(do_fn, 'start_bundle')
     self.finish_bundle_method = MethodWrapper(do_fn, 'finish_bundle')
 
-    restriction_provider = self._get_restriction_provider(do_fn)
+    restriction_provider = self.get_restriction_provider()
     self.initial_restriction_method = (
         MethodWrapper(restriction_provider, 'initial_restriction')
         if restriction_provider else None)
@@ -237,7 +237,7 @@ class DoFnSignature(object):
         method = timer_spec._attached_callback
         self.timer_methods[timer_spec] = MethodWrapper(do_fn, method.__name__)
 
-  def _get_restriction_provider(self, do_fn):
+  def get_restriction_provider(self):
     result = _find_param_with_default(self.process_method,
                                       default_as_type=RestrictionProvider)
     return result[1] if result else None
@@ -434,6 +434,9 @@ class PerWindowInvoker(DoFnInvoker):
         (core.DoFn.WindowParam in default_arg_values) or
         signature.is_stateful_dofn())
     self.user_state_context = user_state_context
+    self.is_splittable = signature.is_splittable_dofn()
+    self.restriction_tracker = None
+    self.current_windowed_value = None
 
     # Try to prepare all the arguments that can just be filled in
     # without any additional work. in the process function.
@@ -515,7 +518,16 @@ class PerWindowInvoker(DoFnInvoker):
     # or if the process accesses the window parameter. We can just call it once
     # otherwise as none of the arguments are changing
 
+    if self.is_splittable and not restriction_tracker:
+      restriction = self.invoke_initial_restriction(windowed_value.value)
+      restriction_tracker = self.invoke_create_tracker(restriction)
+
     if restriction_tracker:
+      if len(windowed_value.windows) > 1 and self.has_windowed_inputs:
+        # Should never get here due to window explosion in
+        # the upstream pair-with-restriction.
+        raise NotImplementedError(
+            'SDFs in multiply-windowed values with windowed arguments.')
       restriction_tracker_param = _find_param_with_default(
           self.signature.process_method,
           default_as_type=core.RestrictionProvider)[0]
@@ -524,7 +536,17 @@ class PerWindowInvoker(DoFnInvoker):
             'A RestrictionTracker %r was provided but DoFn does not have a '
             'RestrictionTrackerParam defined' % restriction_tracker)
       additional_kwargs[restriction_tracker_param] = restriction_tracker
-    if self.has_windowed_inputs and len(windowed_value.windows) != 1:
+      try:
+        self.current_windowed_value = windowed_value
+        self.restriction_tracker = restriction_tracker
+        return self._invoke_per_window(
+            windowed_value, additional_args, additional_kwargs,
+            output_processor)
+      finally:
+        self.restriction_tracker = None
+        self.current_windowed_value = windowed_value
+
+    elif self.has_windowed_inputs and len(windowed_value.windows) != 1:
       for w in windowed_value.windows:
         self._invoke_per_window(
             WindowedValue(windowed_value.value, windowed_value.timestamp, (w,)),
@@ -602,6 +624,15 @@ class PerWindowInvoker(DoFnInvoker):
       output_processor.process_outputs(
           windowed_value, self.process_method(*args_for_process))
 
+    if self.is_splittable:
+      deferred_status = self.restriction_tracker.deferred_status()
+      if deferred_status:
+        deferred_restriction, deferred_watermark = deferred_status
+        return (
+            windowed_value.with_value(
+                (windowed_value.value, deferred_restriction)),
+            deferred_watermark)
+
 
 class DoFnRunner(Receiver):
   """For internal use only; no backwards-compatibility guarantees.
@@ -679,10 +710,17 @@ class DoFnRunner(Receiver):
 
   def process(self, windowed_value):
     try:
-      self.do_fn_invoker.invoke_process(windowed_value)
+      return self.do_fn_invoker.invoke_process(windowed_value)
     except BaseException as exn:
       self._reraise_augmented(exn)
 
+  def process_with_restriction(self, windowed_value):
+    element, restriction = windowed_value.value
+    return self.do_fn_invoker.invoke_process(
+        windowed_value.with_value(element),
+        restriction_tracker=self.do_fn_invoker.invoke_create_tracker(
+            restriction))
+
   def process_user_timer(self, timer_spec, key, window, timestamp):
     try:
       self.do_fn_invoker.invoke_user_timer(timer_spec, key, window, timestamp)
diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py
index 770044c..43e8c7f 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner.py
@@ -76,7 +76,6 @@ class SwitchingDirectRunner(PipelineRunner):
       use_fnapi_runner = False
 
     from apache_beam.pipeline import PipelineVisitor
-    from apache_beam.runners.common import DoFnSignature
     from apache_beam.runners.dataflow.native_io.iobase import NativeSource
     from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite
     from apache_beam.testing.test_stream import TestStream
@@ -103,9 +102,6 @@ class SwitchingDirectRunner(PipelineRunner):
           self.supported_by_fnapi_runner = False
         if isinstance(transform, beam.ParDo):
           dofn = transform.dofn
-          # The FnApiRunner does not support execution of SplittableDoFns.
-          if DoFnSignature(dofn).is_splittable_dofn():
-            self.supported_by_fnapi_runner = False
           # The FnApiRunner does not support execution of CombineFns with
           # deferred side inputs.
           if isinstance(dofn, CombineValuesDoFn):
diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
index d5924cb..eae38bc 100644
--- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
+++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
@@ -35,7 +35,6 @@ from apache_beam.pvalue import AsSingleton
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
-from apache_beam.transforms.core import ProcessContinuation
 from apache_beam.transforms.core import RestrictionProvider
 from apache_beam.transforms.trigger import AccumulationMode
 from apache_beam.transforms.window import SlidingWindows
@@ -83,7 +82,7 @@ class ReadFiles(DoFn):
         output_count += 1
 
         if self._resume_count and output_count == self._resume_count:
-          yield ProcessContinuation()
+          restriction_tracker.defer_remainder()
           break
 
         pos += len_line
diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
index 958bb8e..767203b 100644
--- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
@@ -197,6 +197,9 @@ if __name__ == '__main__':
                 counter_name, line)
         )
 
+    def test_sdf(self):
+      raise unittest.SkipTest("BEAM-2939")
+
     # Inherits all other tests.
 
   # Run the tests.
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index be8799b..e908a5c 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -322,6 +322,7 @@ class FnApiRunner(runner.PipelineRunner):
         phases=[fn_api_runner_transforms.annotate_downstream_side_inputs,
                 fn_api_runner_transforms.fix_side_input_pcoll_coders,
                 fn_api_runner_transforms.lift_combiners,
+                fn_api_runner_transforms.expand_sdf,
                 fn_api_runner_transforms.expand_gbk,
                 fn_api_runner_transforms.sink_flattens,
                 fn_api_runner_transforms.greedily_fuse,
@@ -413,7 +414,7 @@ class FnApiRunner(runner.PipelineRunner):
             data_spec.api_service_descriptor.url = (
                 data_api_service_descriptor.url)
           transform.spec.payload = data_spec.SerializeToString()
-        elif transform.spec.urn == common_urns.primitives.PAR_DO.urn:
+        elif transform.spec.urn in fn_api_runner_transforms.PAR_DO_URNS:
           payload = proto_utils.parse_Bytes(
               transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
           for tag, si in payload.side_inputs.items():
@@ -496,11 +497,15 @@ class FnApiRunner(runner.PipelineRunner):
 
     result = BundleManager(
         controller, get_buffer, process_bundle_descriptor,
-        self._progress_frequency).process_bundle(data_input, data_output)
+        self._progress_frequency).process_bundle(
+            data_input, data_output)
 
+    last_result = result
     while True:
-      timer_inputs = {}
+      deferred_inputs = collections.defaultdict(list)
       for transform_id, timer_writes in stage.timer_pcollections:
+
+        # Queue any set timers as new inputs.
         windowed_timer_coder_impl = context.coders[
             pipeline_components.pcollections[timer_writes].coder_id].get_impl()
         written_timers = get_buffer(
@@ -522,20 +527,37 @@ class FnApiRunner(runner.PipelineRunner):
           for windowed_key_timer in timers_by_key_and_window.values():
             windowed_timer_coder_impl.encode_to_stream(
                 windowed_key_timer, out, True)
-          timer_inputs[transform_id, 'out'] = [out.get()]
+          deferred_inputs[transform_id, 'out'] = [out.get()]
           written_timers[:] = []
-      if timer_inputs:
+
+      # Queue any delayed bundle applications.
+      for delayed_application in last_result.process_bundle.residual_roots:
+        # Find the io transform that feeds this transform.
+        # TODO(SDF): Memoize?
+        application = delayed_application.application
+        input_pcoll = process_bundle_descriptor.transforms[
+            application.ptransform_id].inputs[application.input_id]
+        for input_id, proto in process_bundle_descriptor.transforms.items():
+          if (proto.spec.urn == bundle_processor.DATA_INPUT_URN
+              and input_pcoll in proto.outputs.values()):
+            deferred_inputs[input_id, 'out'].append(application.element)
+            break
+        else:
+          raise RuntimeError(
+              'No IO transform feeds %s' % application.ptransform_id)
+
+      if deferred_inputs:
         # The worker will be waiting on these inputs as well.
         for other_input in data_input:
-          if other_input not in timer_inputs:
-            timer_inputs[other_input] = []
+          if other_input not in deferred_inputs:
+            deferred_inputs[other_input] = []
         # TODO(robertwb): merge results
-        BundleManager(
+        last_result = BundleManager(
             controller,
             get_buffer,
             process_bundle_descriptor,
             self._progress_frequency,
-            True).process_bundle(timer_inputs, data_output)
+            True).process_bundle(deferred_inputs, data_output)
       else:
         break
 
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index f08c13b..6c4cad9 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -30,6 +30,7 @@ from tenacity import retry
 from tenacity import stop_after_attempt
 
 import apache_beam as beam
+from apache_beam.io import restriction_trackers
 from apache_beam.metrics import monitoring_infos
 from apache_beam.metrics.execution import MetricKey
 from apache_beam.metrics.execution import MetricsEnvironment
@@ -351,6 +352,43 @@ class FnApiRunnerTest(unittest.TestCase):
 
       assert_that(actual, is_buffered_correctly)
 
+  def test_sdf(self):
+
+    class ExpandStringsProvider(beam.transforms.core.RestrictionProvider):
+      def initial_restriction(self, element):
+        return (0, len(element))
+
+      def create_tracker(self, restriction):
+        return restriction_trackers.OffsetRestrictionTracker(
+            restriction[0], restriction[1])
+
+      def split(self, element, restriction):
+        start, end = restriction
+        middle = (end - start) // 2
+        return [(start, middle), (middle, end)]
+
+    class ExpandStringsDoFn(beam.DoFn):
+      def process(self, element, restriction_tracker=ExpandStringsProvider()):
+        assert isinstance(
+            restriction_tracker,
+            restriction_trackers.OffsetRestrictionTracker), restriction_tracker
+        for k in range(*restriction_tracker.current_restriction()):
+          if not restriction_tracker.try_claim(k):
+            return
+          yield element[k]
+          if k % 2 == 1:
+            restriction_tracker.defer_remainder()
+            return
+
+    with self.create_pipeline() as p:
+      data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
+      actual = (
+          p
+          | beam.Create(data)
+          | beam.ParDo(ExpandStringsDoFn()))
+
+      assert_that(actual, equal_to(list(''.join(data))))
+
   def test_group_by_key(self):
     with self.create_pipeline() as p:
       res = (p
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
index 21f8fa2..2482987 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
@@ -25,6 +25,8 @@ import functools
 import logging
 from builtins import object
 
+from past.builtins import unicode
+
 from apache_beam import coders
 from apache_beam.portability import common_urns
 from apache_beam.portability import python_urns
@@ -35,9 +37,21 @@ from apache_beam.utils import proto_utils
 # This module is experimental. No backwards-compatibility guarantees.
 
 
-KNOWN_COMPOSITES = frozenset(
-    [common_urns.primitives.GROUP_BY_KEY.urn,
-     common_urns.composites.COMBINE_PER_KEY.urn])
+KNOWN_COMPOSITES = frozenset([
+    common_urns.primitives.GROUP_BY_KEY.urn,
+    common_urns.composites.COMBINE_PER_KEY.urn])
+
+COMBINE_URNS = frozenset([
+    common_urns.composites.COMBINE_PER_KEY.urn,
+    common_urns.combine_components.COMBINE_PGBKCV.urn,
+    common_urns.combine_components.COMBINE_MERGE_ACCUMULATORS.urn,
+    common_urns.combine_components.COMBINE_EXTRACT_OUTPUTS.urn])
+
+PAR_DO_URNS = frozenset([
+    common_urns.primitives.PAR_DO.urn,
+    common_urns.sdf_components.PAIR_WITH_RESTRICTION.urn,
+    common_urns.sdf_components.SPLIT_RESTRICTION.urn,
+    common_urns.sdf_components.PROCESS_ELEMENTS.urn])
 
 IMPULSE_BUFFER = b'impulse'
 
@@ -76,15 +90,11 @@ class Stage(object):
 
   @staticmethod
   def _extract_environment(transform):
-    if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
+    if transform.spec.urn in PAR_DO_URNS:
       pardo_payload = proto_utils.parse_Bytes(
           transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
       return pardo_payload.do_fn.environment_id
-    elif transform.spec.urn in (
-        common_urns.composites.COMBINE_PER_KEY.urn,
-        common_urns.combine_components.COMBINE_PGBKCV.urn,
-        common_urns.combine_components.COMBINE_MERGE_ACCUMULATORS.urn,
-        common_urns.combine_components.COMBINE_EXTRACT_OUTPUTS.urn):
+    elif transform.spec.urn in COMBINE_URNS:
       combine_payload = proto_utils.parse_Bytes(
           transform.spec.payload, beam_runner_api_pb2.CombinePayload)
       return combine_payload.combine_fn.environment_id
@@ -137,7 +147,7 @@ class Stage(object):
 
   def side_inputs(self):
     for transform in self.transforms:
-      if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
+      if transform.spec.urn in PAR_DO_URNS:
         payload = proto_utils.parse_Bytes(
             transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
         for side_input in payload.side_inputs:
@@ -145,7 +155,7 @@ class Stage(object):
 
   def has_as_main_input(self, pcoll):
     for transform in self.transforms:
-      if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
+      if transform.spec.urn in PAR_DO_URNS:
         payload = proto_utils.parse_Bytes(
             transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
         local_side_inputs = payload.side_inputs
@@ -681,6 +691,108 @@ def lift_combiners(stages, context):
       yield stage
 
 
+def expand_sdf(stages, context):
+  """Transforms splitable DoFns into pair+split+read."""
+  for stage in stages:
+    assert len(stage.transforms) == 1
+    transform = stage.transforms[0]
+    if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
+
+      pardo_payload = proto_utils.parse_Bytes(
+          transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
+
+      if pardo_payload.splittable:
+
+        def copy_like(protos, original, suffix='_copy', **kwargs):
+          if isinstance(original, (str, unicode)):
+            key = original
+            original = protos[original]
+          else:
+            key = 'component'
+          new_id = unique_name(protos, key + suffix)
+          protos[new_id].CopyFrom(original)
+          proto = protos[new_id]
+          for name, value in kwargs.items():
+            if isinstance(value, dict):
+              getattr(proto, name).clear()
+              getattr(proto, name).update(value)
+            elif isinstance(value, list):
+              del getattr(proto, name)[:]
+              getattr(proto, name).extend(value)
+            elif name == 'urn':
+              proto.spec.urn = value
+            else:
+              setattr(proto, name, value)
+          return new_id
+
+        def make_stage(base_stage, transform_id, extra_must_follow=()):
+          transform = context.components.transforms[transform_id]
+          return Stage(
+              transform.unique_name,
+              [transform],
+              base_stage.downstream_side_inputs,
+              union(base_stage.must_follow, frozenset(extra_must_follow)),
+              parent=base_stage,
+              environment=base_stage.environment)
+
+        main_input_tag = only_element(tag for tag in transform.inputs.keys()
+                                      if tag not in pardo_payload.side_inputs)
+        main_input_id = transform.inputs[main_input_tag]
+        element_coder_id = context.components.pcollections[
+            main_input_id].coder_id
+        paired_coder_id = context.add_or_get_coder_id(
+            beam_runner_api_pb2.Coder(
+                spec=beam_runner_api_pb2.SdkFunctionSpec(
+                    spec=beam_runner_api_pb2.FunctionSpec(
+                        urn=common_urns.coders.KV.urn)),
+                component_coder_ids=[element_coder_id,
+                                     pardo_payload.restriction_coder_id]))
+
+        paired_pcoll_id = copy_like(
+            context.components.pcollections,
+            main_input_id,
+            '_paired',
+            coder_id=paired_coder_id)
+        pair_transform_id = copy_like(
+            context.components.transforms,
+            transform,
+            unique_name=transform.unique_name + '/PairWithRestriction',
+            urn=common_urns.sdf_components.PAIR_WITH_RESTRICTION.urn,
+            outputs={'out': paired_pcoll_id})
+
+        split_pcoll_id = copy_like(
+            context.components.pcollections,
+            main_input_id,
+            '_split',
+            coder_id=paired_coder_id)
+        split_transform_id = copy_like(
+            context.components.transforms,
+            transform,
+            unique_name=transform.unique_name + '/SplitRestriction',
+            urn=common_urns.sdf_components.SPLIT_RESTRICTION.urn,
+            inputs=dict(transform.inputs, **{main_input_tag: paired_pcoll_id}),
+            outputs={'out': split_pcoll_id})
+
+        process_transform_id = copy_like(
+            context.components.transforms,
+            transform,
+            unique_name=transform.unique_name + '/Process',
+            urn=common_urns.sdf_components.PROCESS_ELEMENTS.urn,
+            inputs=dict(transform.inputs, **{main_input_tag: split_pcoll_id}))
+
+        yield make_stage(stage, pair_transform_id)
+        split_stage = make_stage(stage, split_transform_id)
+        yield split_stage
+        yield make_stage(
+            stage, process_transform_id, extra_must_follow=[split_stage])
+
+      else:
+        yield stage
+
+    else:
+      yield stage
+
+
 def expand_gbk(stages, pipeline_context):
   """Transforms each GBK into a write followed by a read.
   """
@@ -861,6 +973,7 @@ def greedily_fuse(stages, pipeline_context):
         fuse(producer, consumer)
       else:
         # If we can't fuse, do a read + write.
+        pipeline_context.length_prefix_pcoll_coders(pcoll)
         buffer_id = create_buffer_id(pcoll)
         if write_pcoll is None:
           write_pcoll = Stage(
@@ -994,7 +1107,7 @@ def inject_timer_pcollections(stages, pipeline_context):
   """
   for stage in stages:
     for transform in list(stage.transforms):
-      if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
+      if transform.spec.urn in PAR_DO_URNS:
         payload = proto_utils.parse_Bytes(
             transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
         for tag, spec in payload.timer_specs.items():
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 7126215..6cffc02 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -31,6 +31,7 @@ from builtins import next
 from builtins import object
 
 from future.utils import itervalues
+from google import protobuf
 
 import apache_beam as beam
 from apache_beam import coders
@@ -43,6 +44,7 @@ from apache_beam.portability import common_urns
 from apache_beam.portability import python_urns
 from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.portability.api import beam_runner_api_pb2
+from apache_beam.runners import common
 from apache_beam.runners import pipeline_context
 from apache_beam.runners.dataflow import dataflow_runner
 from apache_beam.runners.worker import operation_specs
@@ -474,10 +476,12 @@ class BundleProcessor(object):
         expected_inputs.append(op)
 
     try:
+      execution_context = ExecutionContext()
       self.state_sampler.start()
       # Start all operations.
       for op in reversed(self.ops.values()):
         logging.debug('start %s', op)
+        op.execution_context = execution_context
         op.start()
 
       # Inject inputs from data plane.
@@ -499,9 +503,32 @@ class BundleProcessor(object):
       for op in self.ops.values():
         logging.debug('finish %s', op)
         op.finish()
+
+      return [
+          self.delayed_bundle_application(op, residual)
+          for op, residual in execution_context.delayed_applications]
+
     finally:
       self.state_sampler.stop_if_still_running()
 
+  def delayed_bundle_application(self, op, deferred_remainder):
+    ptransform_id, main_input_tag, main_input_coder, outputs = op.input_info
+    # TODO(SDF): For non-root nodes, need main_input_coder + residual_coder.
+    element_and_restriction, watermark = deferred_remainder
+    if watermark:
+      proto_watermark = protobuf.Timestamp()
+      proto_watermark.FromMicroseconds(watermark.micros)
+      output_watermarks = {output: proto_watermark for output in outputs}
+    else:
+      output_watermarks = None
+    return beam_fn_api_pb2.DelayedBundleApplication(
+        application=beam_fn_api_pb2.BundleApplication(
+            ptransform_id=ptransform_id,
+            input_id=main_input_tag,
+            output_watermarks=output_watermarks,
+            element=main_input_coder.get_impl().encode_nested(
+                element_and_restriction)))
+
   def metrics(self):
     # DEPRECATED
     return beam_fn_api_pb2.Metrics(
@@ -553,6 +580,11 @@ class BundleProcessor(object):
     return monitoring_info
 
 
+class ExecutionContext(object):
+  def __init__(self):
+    self.delayed_applications = []
+
+
 class BeamTransformFactory(object):
   """Factory for turning transform_protos into executable operations."""
   def __init__(self, descriptor, data_channel_factory, counter_factory,
@@ -766,6 +798,70 @@ def create(factory, transform_id, transform_proto, serialized_fn, consumers):
 
 
 @BeamTransformFactory.register_urn(
+    common_urns.sdf_components.PAIR_WITH_RESTRICTION.urn,
+    beam_runner_api_pb2.ParDoPayload)
+def create(*args):
+
+  class CreateRestriction(beam.DoFn):
+    def __init__(self, fn, restriction_provider):
+      self.restriction_provider = restriction_provider
+
+    # An unused window is requested to force explosion of multi-window
+    # WindowedValues.
+    def process(
+        self, element, _unused_window=beam.DoFn.WindowParam, *args, **kwargs):
+      # TODO(SDF): Do we want to allow mutation of the element?
+      # (E.g. it could be nice to shift bulky description to the portion
+      # that can be distributed.)
+      yield element, self.restriction_provider.initial_restriction(element)
+
+  return _create_sdf_operation(CreateRestriction, *args)
+
+
+@BeamTransformFactory.register_urn(
+    common_urns.sdf_components.SPLIT_RESTRICTION.urn,
+    beam_runner_api_pb2.ParDoPayload)
+def create(*args):
+
+  class SplitRestriction(beam.DoFn):
+    def __init__(self, fn, restriction_provider):
+      self.restriction_provider = restriction_provider
+
+    def process(self, element_restriction, *args, **kwargs):
+      element, restriction = element_restriction
+      for part in self.restriction_provider.split(element, restriction):
+        yield element, part
+
+  return _create_sdf_operation(SplitRestriction, *args)
+
+
+@BeamTransformFactory.register_urn(
+    common_urns.sdf_components.PROCESS_ELEMENTS.urn,
+    beam_runner_api_pb2.ParDoPayload)
+def create(factory, transform_id, transform_proto, parameter, consumers):
+  assert parameter.do_fn.spec.urn == python_urns.PICKLED_DOFN_INFO
+  serialized_fn = parameter.do_fn.spec.payload
+  return _create_pardo_operation(
+      factory, transform_id, transform_proto, consumers,
+      serialized_fn, parameter,
+      operation_cls=operations.SdfProcessElements)
+
+
+def _create_sdf_operation(
+    proxy_dofn,
+    factory, transform_id, transform_proto, parameter, consumers):
+
+  dofn_data = pickler.loads(parameter.do_fn.spec.payload)
+  dofn = dofn_data[0]
+  restriction_provider = common.DoFnSignature(dofn).get_restriction_provider()
+  serialized_fn = pickler.dumps(
+      (proxy_dofn(dofn, restriction_provider),) + dofn_data[1:])
+  return _create_pardo_operation(
+      factory, transform_id, transform_proto, consumers,
+      serialized_fn, parameter)
+
+
+@BeamTransformFactory.register_urn(
     common_urns.primitives.PAR_DO.urn, beam_runner_api_pb2.ParDoPayload)
 def create(factory, transform_id, transform_proto, parameter, consumers):
   assert parameter.do_fn.spec.urn == python_urns.PICKLED_DOFN_INFO
@@ -777,7 +873,7 @@ def create(factory, transform_id, transform_proto, parameter, consumers):
 
 def _create_pardo_operation(
     factory, transform_id, transform_proto, consumers,
-    serialized_fn, pardo_proto=None):
+    serialized_fn, pardo_proto=None, operation_cls=operations.DoOperation):
 
   if pardo_proto and pardo_proto.side_inputs:
     input_tags_to_coders = factory.get_input_coders(transform_proto)
@@ -824,7 +920,8 @@ def _create_pardo_operation(
         factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
     serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,))
 
-  if pardo_proto and (pardo_proto.timer_specs or pardo_proto.state_specs):
+  if pardo_proto and (pardo_proto.timer_specs or pardo_proto.state_specs
+                      or pardo_proto.splittable):
     main_input_coder = None
     timer_inputs = {}
     for tag, pcoll_id in transform_proto.inputs.items():
@@ -835,15 +932,19 @@ def _create_pardo_operation(
       else:
         # Must be the main input
         assert main_input_coder is None
+        main_input_tag = tag
         main_input_coder = factory.get_windowed_coder(pcoll_id)
     assert main_input_coder is not None
 
-    user_state_context = FnApiUserStateContext(
-        factory.state_handler,
-        transform_id,
-        main_input_coder.key_coder(),
-        main_input_coder.window_coder,
-        timer_specs=pardo_proto.timer_specs)
+    if pardo_proto.timer_specs or pardo_proto.state_specs:
+      user_state_context = FnApiUserStateContext(
+          factory.state_handler,
+          transform_id,
+          main_input_coder.key_coder(),
+          main_input_coder.window_coder,
+          timer_specs=pardo_proto.timer_specs)
+    else:
+      user_state_context = None
   else:
     user_state_context = None
     timer_inputs = None
@@ -856,8 +957,8 @@ def _create_pardo_operation(
       side_inputs=None,  # Fn API uses proto definitions and the Fn State API
       output_coders=[output_coders[tag] for tag in output_tags])
 
-  return factory.augment_oldstyle_op(
-      operations.DoOperation(
+  result = factory.augment_oldstyle_op(
+      operation_cls(
           transform_proto.unique_name,
           spec,
           factory.counter_factory,
@@ -868,6 +969,11 @@ def _create_pardo_operation(
       transform_proto.unique_name,
       consumers,
       output_tags)
+  if pardo_proto and pardo_proto.splittable:
+    result.input_info = (
+        transform_id, main_input_tag, main_input_coder,
+        transform_proto.outputs.keys())
+  return result
 
 
 def _create_simple_pardo_operation(
diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd
index cb3477a..10c3c41 100644
--- a/sdks/python/apache_beam/runners/worker/operations.pxd
+++ b/sdks/python/apache_beam/runners/worker/operations.pxd
@@ -45,6 +45,7 @@ cdef class Operation(object):
   cdef object consumers
   cdef readonly counter_factory
   cdef public metrics_container
+  cdef public execution_context
   # Public for access by Fn harness operations.
   # TODO(robertwb): Cythonize FnHarness.
   cdef public list receivers
@@ -89,6 +90,11 @@ cdef class DoOperation(Operation):
   cdef object user_state_context
   cdef public dict timer_inputs
   cdef dict timer_specs
+  cdef public object input_info
+
+
+cdef class SdfProcessElements(DoOperation):
+  pass
 
 
 cdef class CombineOperation(Operation):
diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py
index 9ab4fcf..a6e0c31 100644
--- a/sdks/python/apache_beam/runners/worker/operations.py
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -129,6 +129,7 @@ class Operation(object):
 
     self.spec = spec
     self.counter_factory = counter_factory
+    self.execution_context = None
     self.consumers = collections.defaultdict(list)
 
     # These are overwritten in the legacy harness.
@@ -496,7 +497,10 @@ class DoOperation(Operation):
 
   def process(self, o):
     with self.scoped_process_state:
-      self.dofn_receiver.receive(o)
+      delayed_application = self.dofn_receiver.receive(o)
+      if delayed_application:
+        self.execution_context.delayed_applications.append(
+            (self, delayed_application))
 
   def process_timer(self, tag, windowed_timer):
     key, timer_data = windowed_timer.value
@@ -540,6 +544,16 @@ class DoOperation(Operation):
     return infos
 
 
+class SdfProcessElements(DoOperation):
+
+  def process(self, o):
+    with self.scoped_process_state:
+      delayed_application = self.dofn_runner.process_with_restriction(o)
+      if delayed_application:
+        self.execution_context.delayed_applications.append(
+            (self, delayed_application))
+
+
 class DoFnRunnerReceiver(Receiver):
 
   def __init__(self, dofn_runner):
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 4eb22a8..6067181 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -272,10 +272,11 @@ class SdkWorker(object):
         instruction_id,
         request.process_bundle_descriptor_reference) as bundle_processor:
       with self.maybe_profile(instruction_id):
-        bundle_processor.process_bundle(instruction_id)
+        delayed_applications = bundle_processor.process_bundle(instruction_id)
       return beam_fn_api_pb2.InstructionResponse(
           instruction_id=instruction_id,
           process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
+              residual_roots=delayed_applications,
               metrics=bundle_processor.metrics(),
               monitoring_infos=bundle_processor.monitoring_infos()))
 
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 09091ed..27b67cd 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -1029,6 +1029,14 @@ class ParDo(PTransformWithSideInputs):
         "expected instance of ParDo, but got %s" % self.__class__
     picked_pardo_fn_data = pickler.dumps(self._pardo_fn_data())
     state_specs, timer_specs = userstate.get_dofn_specs(self.fn)
+    from apache_beam.runners.common import DoFnSignature
+    is_splittable = DoFnSignature(self.fn).is_splittable_dofn()
+    if is_splittable:
+      restriction_coder = (
+          DoFnSignature(self.fn).get_restriction_provider().restriction_coder())
+      restriction_coder_id = context.coders.get_id(restriction_coder)
+    else:
+      restriction_coder_id = None
     return (
         common_urns.primitives.PAR_DO.urn,
         beam_runner_api_pb2.ParDoPayload(
@@ -1037,6 +1045,8 @@ class ParDo(PTransformWithSideInputs):
                 spec=beam_runner_api_pb2.FunctionSpec(
                     urn=python_urns.PICKLED_DOFN_INFO,
                     payload=picked_pardo_fn_data)),
+            splittable=is_splittable,
+            restriction_coder_id=restriction_coder_id,
             state_specs={spec.name: spec.to_runner_api(context)
                          for spec in state_specs},
             timer_specs={spec.name: spec.to_runner_api(context)

Reply via email to