This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new e933ddc  Adding helper functions, and documentation to fn_api_runner classes. (#8793)
e933ddc is described below

commit e933ddca967f897b4256e9689d7699fc9031df62
Author: Pablo <pabl...@users.noreply.github.com>
AuthorDate: Thu Jun 13 12:34:25 2019 -0700

    Adding helper functions, and documentation to fn_api_runner classes. (#8793)
---
 .../runners/portability/fn_api_runner.py           | 464 +++++++++++++--------
 .../runners/portability/fn_api_runner_test.py      |  31 ++
 .../apache_beam/runners/worker/bundle_processor.py |   4 +-
 3 files changed, 326 insertions(+), 173 deletions(-)

diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index 1fbb0ed..66aedfc 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -173,6 +173,10 @@ class _GroupingBuffer(object):
         windowed_key_values = lambda key, values: [
             globally_window((key, values))]
       else:
+        # TODO(pabloem, BEAM-7514): Trigger driver needs access to the clock
+        #   note that this only comes through if windowing is default - but
+        #   what about having multiple firings on the global window?
+        #   May need to revise.
         trigger_driver = trigger.create_trigger_driver(self._windowing, True)
         windowed_key_values = trigger_driver.process_entire_key
       coder_impl = self._post_grouped_coder.get_impl()
@@ -291,7 +295,10 @@ class FnApiRunner(runner.PipelineRunner):
     return self._latest_run_result
 
   def run_via_runner_api(self, pipeline_proto):
-    return self.run_stages(*self.create_stages(pipeline_proto))
+    stage_context, stages = self.create_stages(pipeline_proto)
+    # TODO(pabloem, BEAM-7514): Create a watermark manager (that has access to
+    #   the teststream (if any), and all the stages).
+    return self.run_stages(stage_context, stages)
 
   @contextlib.contextmanager
   def maybe_profile(self):
@@ -353,7 +360,7 @@ class FnApiRunner(runner.PipelineRunner):
         use_state_iterables=self._use_state_iterables)
 
   def run_stages(self, stage_context, stages):
-    """Run all of the stages.
+    """Run a list of topologically-sorted stages in batch mode.
 
     Args:
       stage_context (fn_api_runner_transforms.TransformContext)
@@ -382,6 +389,156 @@ class FnApiRunner(runner.PipelineRunner):
     return RunnerResult(
         runner.PipelineState.DONE, monitoring_infos_by_stage, metrics_by_stage)
 
+  def _store_side_inputs_in_state(self,
+                                  controller,
+                                  context,
+                                  pipeline_components,
+                                  data_side_input,
+                                  pcoll_buffers,
+                                  safe_coders):
+    for (transform_id, tag), (buffer_id, si) in data_side_input.items():
+      _, pcoll_id = split_buffer_id(buffer_id)
+      value_coder = context.coders[safe_coders[
+          pipeline_components.pcollections[pcoll_id].coder_id]]
+      elements_by_window = _WindowGroupingBuffer(si, value_coder)
+      for element_data in pcoll_buffers[buffer_id]:
+        elements_by_window.append(element_data)
+      for key, window, elements_data in elements_by_window.encoded_items():
+        state_key = beam_fn_api_pb2.StateKey(
+            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
+                ptransform_id=transform_id,
+                side_input_id=tag,
+                window=window,
+                key=key))
+        controller.state.blocking_append(state_key, elements_data)
+
+  def _run_bundle_multiple_times_for_testing(self,
+                                             controller,
+                                             process_bundle_descriptor,
+                                             data_input,
+                                             data_output,
+                                             get_input_coder_callable):
+    for k in range(self._bundle_repeat):
+      try:
+        controller.state.checkpoint()
+        BundleManager(
+            controller, lambda pcoll_id: [], get_input_coder_callable,
+            process_bundle_descriptor, self._progress_frequency, k
+        ).process_bundle(data_input, data_output)
+      finally:
+        controller.state.restore()
+
+  def _collect_written_timers_and_add_to_deferred_inputs(self,
+                                                         context,
+                                                         pipeline_components,
+                                                         stage,
+                                                         get_buffer_callable,
+                                                         deferred_inputs):
+
+    for transform_id, timer_writes in stage.timer_pcollections:
+
+      # Queue any set timers as new inputs.
+      windowed_timer_coder_impl = context.coders[
+          pipeline_components.pcollections[timer_writes].coder_id].get_impl()
+      written_timers = get_buffer_callable(
+          create_buffer_id(timer_writes, kind='timers'))
+      if written_timers:
+        # Keep only the "last" timer set per key and window.
+        timers_by_key_and_window = {}
+        for elements_data in written_timers:
+          input_stream = create_InputStream(elements_data)
+          while input_stream.size() > 0:
+            windowed_key_timer = windowed_timer_coder_impl.decode_from_stream(
+                input_stream, True)
+            key, _ = windowed_key_timer.value
+            # TODO: Explode and merge windows.
+            assert len(windowed_key_timer.windows) == 1
+            timers_by_key_and_window[
+                key, windowed_key_timer.windows[0]] = windowed_key_timer
+        out = create_OutputStream()
+        for windowed_key_timer in timers_by_key_and_window.values():
+          windowed_timer_coder_impl.encode_to_stream(
+              windowed_key_timer, out, True)
+        deferred_inputs[transform_id] = [out.get()]
+        written_timers[:] = []
+
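[Editor's note: the "last timer set wins" rule above can be shown in isolation
with a self-contained toy, not part of the commit; the keys, windows and timer
values are made up.]

    # Later writes for the same (key, window) overwrite earlier ones,
    # mirroring set-timer semantics where only the latest set timer fires.
    timers_by_key_and_window = {}
    for key, window, timer in [('k1', 'w0', 'fire@10'),
                               ('k1', 'w0', 'fire@20'),
                               ('k2', 'w0', 'fire@15')]:
        timers_by_key_and_window[key, window] = timer
    assert sorted(timers_by_key_and_window.values()) == ['fire@15', 'fire@20']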
+  def _add_residuals_and_channel_splits_to_deferred_inputs(
+      self, splits, get_input_coder_callable,
+      input_for_callable, last_sent, deferred_inputs):
+    prev_stops = {}
+    for split in splits:
+      for delayed_application in split.residual_roots:
+        deferred_inputs[
+            input_for_callable(
+                delayed_application.application.ptransform_id,
+                delayed_application.application.input_id)
+        ].append(delayed_application.application.element)
+      for channel_split in split.channel_splits:
+        coder_impl = get_input_coder_callable(channel_split.ptransform_id)
+        # TODO(SDF): This requires deterministic ordering of buffer iteration.
+        # TODO(SDF): The return split is in terms of indices.  Ideally,
+        # a runner could map these back to actual positions to effectively
+        # describe the two "halves" of the now-split range.  Even if we have
+        # to buffer each element we send (or at the very least a bit of
+        # metadata, like position, about each of them) this should be doable
+        # if they're already in memory and we are bounding the buffer size
+        # (e.g. to 10mb plus whatever is eagerly read from the SDK).  In the
+        # case of non-split-points, we can either immediately replay the
+        # "non-split-position" elements or record them as we do the other
+        # delayed applications.
+
+        # Decode and recode to split the encoded buffer by element index.
+        all_elements = list(coder_impl.decode_all(b''.join(last_sent[
+            channel_split.ptransform_id])))
+        residual_elements = all_elements[
+            channel_split.first_residual_element : prev_stops.get(
+                channel_split.ptransform_id, len(all_elements)) + 1]
+        if residual_elements:
+          deferred_inputs[channel_split.ptransform_id].append(
+              coder_impl.encode_all(residual_elements))
+        prev_stops[
+            channel_split.ptransform_id] = channel_split.last_primary_element
+
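[Editor's note: the index arithmetic above is easier to follow with concrete
numbers. A self-contained illustration; the element values and split indices
are hypothetical.]

    # A channel split reports the last element kept in the primary and the
    # first element belonging to the residual; everything from that index up
    # to the previously recorded stop (or the end) is deferred for re-run.
    all_elements = ['e0', 'e1', 'e2', 'e3', 'e4']
    first_residual_element = 2           # hypothetical split point
    prev_stop = len(all_elements)        # no earlier split for this transform
    residual = all_elements[first_residual_element:prev_stop + 1]
    assert residual == ['e2', 'e3', 'e4']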
+  @staticmethod
+  def _extract_stage_data_endpoints(
+      stage, pipeline_components, data_api_service_descriptor, pcoll_buffers):
+    # Returns maps of transform names to PCollection identifiers.
+    # Also mutates IO stages to point to the data ApiServiceDescriptor.
+    data_input = {}
+    data_side_input = {}
+    data_output = {}
+    for transform in stage.transforms:
+      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
+                                bundle_processor.DATA_OUTPUT_URN):
+        pcoll_id = transform.spec.payload
+        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
+          target = transform.unique_name, only_element(transform.outputs)
+          if pcoll_id == fn_api_runner_transforms.IMPULSE_BUFFER:
+            data_input[target] = [ENCODED_IMPULSE_VALUE]
+          else:
+            data_input[target] = pcoll_buffers[pcoll_id]
+          coder_id = pipeline_components.pcollections[
+              only_element(transform.outputs.values())].coder_id
+        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
+          target = transform.unique_name, only_element(transform.inputs)
+          data_output[target] = pcoll_id
+          coder_id = pipeline_components.pcollections[
+              only_element(transform.inputs.values())].coder_id
+        else:
+          raise NotImplementedError
+        data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
+        if data_api_service_descriptor:
+          data_spec.api_service_descriptor.url = (
+              data_api_service_descriptor.url)
+        transform.spec.payload = data_spec.SerializeToString()
+      elif transform.spec.urn in fn_api_runner_transforms.PAR_DO_URNS:
+        payload = proto_utils.parse_Bytes(
+            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
+        for tag, si in payload.side_inputs.items():
+          data_side_input[transform.unique_name, tag] = (
+              create_buffer_id(transform.inputs[tag]), si.access_pattern)
+    return data_input, data_side_input, data_output
+
   def _run_stage(self,
                  worker_handler_factory,
                  pipeline_components,
@@ -435,22 +592,14 @@ class FnApiRunner(runner.PipelineRunner):
       process_bundle_descriptor.state_api_service_descriptor.url = (
           controller.state_api_service_descriptor().url)
 
-    # Store the required side inputs into state.
-    for (transform_id, tag), (buffer_id, si) in data_side_input.items():
-      _, pcoll_id = split_buffer_id(buffer_id)
-      value_coder = context.coders[safe_coders[
-          pipeline_components.pcollections[pcoll_id].coder_id]]
-      elements_by_window = _WindowGroupingBuffer(si, value_coder)
-      for element_data in pcoll_buffers[buffer_id]:
-        elements_by_window.append(element_data)
-      for key, window, elements_data in elements_by_window.encoded_items():
-        state_key = beam_fn_api_pb2.StateKey(
-            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
-                ptransform_id=transform_id,
-                side_input_id=tag,
-                window=window,
-                key=key))
-        controller.state.blocking_append(state_key, elements_data)
+    # Store the required side inputs into state so they are accessible to the
+    # worker when it runs this bundle.
+    self._store_side_inputs_in_state(controller,
+                                     context,
+                                     pipeline_components,
+                                     data_side_input,
+                                     pcoll_buffers,
+                                     safe_coders)
 
     def get_buffer(buffer_id):
       """Returns the buffer for a given (operation_type, PCollection ID).
@@ -493,20 +642,17 @@ class FnApiRunner(runner.PipelineRunner):
           ).coder_id
       ]].get_impl()
 
-    for k in range(self._bundle_repeat):
-      try:
-        controller.state.checkpoint()
-        BundleManager(
-            controller, lambda pcoll_id: [], get_input_coder_impl,
-            process_bundle_descriptor, self._progress_frequency, k
-        ).process_bundle(data_input, data_output)
-      finally:
-        controller.state.restore()
+    self._run_bundle_multiple_times_for_testing(controller,
+                                                process_bundle_descriptor,
+                                                data_input,
+                                                data_output,
+                                                get_input_coder_impl)
 
-    result, splits = BundleManager(
+    bundle_manager = BundleManager(
        controller, get_buffer, get_input_coder_impl, process_bundle_descriptor,
-        self._progress_frequency).process_bundle(
-            data_input, data_output)
+        self._progress_frequency)
+
+    result, splits = bundle_manager.process_bundle(data_input, data_output)
 
     def input_for(ptransform_id, input_id):
       input_pcoll = process_bundle_descriptor.transforms[
@@ -523,32 +669,9 @@ class FnApiRunner(runner.PipelineRunner):
 
     while True:
       deferred_inputs = collections.defaultdict(list)
-      for transform_id, timer_writes in stage.timer_pcollections:
-
-        # Queue any set timers as new inputs.
-        windowed_timer_coder_impl = context.coders[
-            pipeline_components.pcollections[timer_writes].coder_id].get_impl()
-        written_timers = get_buffer(
-            create_buffer_id(timer_writes, kind='timers'))
-        if written_timers:
-          # Keep only the "last" timer set per key and window.
-          timers_by_key_and_window = {}
-          for elements_data in written_timers:
-            input_stream = create_InputStream(elements_data)
-            while input_stream.size() > 0:
-            windowed_key_timer = windowed_timer_coder_impl.decode_from_stream(
-                  input_stream, True)
-              key, _ = windowed_key_timer.value
-              # TODO: Explode and merge windows.
-              assert len(windowed_key_timer.windows) == 1
-              timers_by_key_and_window[
-                  key, windowed_key_timer.windows[0]] = windowed_key_timer
-          out = create_OutputStream()
-          for windowed_key_timer in timers_by_key_and_window.values():
-            windowed_timer_coder_impl.encode_to_stream(
-                windowed_key_timer, out, True)
-          deferred_inputs[transform_id] = [out.get()]
-          written_timers[:] = []
+
+      self._collect_written_timers_and_add_to_deferred_inputs(
+          context, pipeline_components, stage, get_buffer, deferred_inputs)
 
       # Queue any process-initiated delayed bundle applications.
       for delayed_application in last_result.process_bundle.residual_roots:
@@ -559,39 +682,8 @@ class FnApiRunner(runner.PipelineRunner):
         ].append(delayed_application.application.element)
 
       # Queue any runner-initiated delayed bundle applications.
-      prev_stops = {}
-      for split in splits:
-        for delayed_application in split.residual_roots:
-          deferred_inputs[
-              input_for(
-                  delayed_application.application.ptransform_id,
-                  delayed_application.application.input_id)
-          ].append(delayed_application.application.element)
-        for channel_split in split.channel_splits:
-          coder_impl = get_input_coder_impl(channel_split.ptransform_id)
-          # TODO(SDF): This requires determanistic ordering of buffer iteration.
-          # TODO(SDF): The return split is in terms of indices.  Ideally,
-          # a runner could map these back to actual positions to effectively
-          # describe the two "halves" of the now-split range.  Even if we have
-          # to buffer each element we send (or at the very least a bit of
-          # metadata, like position, about each of them) this should be doable
-          # if they're already in memory and we are bounding the buffer size
-          # (e.g. to 10mb plus whatever is eagerly read from the SDK).  In the
-          # case of non-split-points, we can either immediately replay the
-          # "non-split-position" elements or record them as we do the other
-          # delayed applications.
-
-          # Decode and recode to split the encoded buffer by element index.
-          all_elements = list(coder_impl.decode_all(b''.join(last_sent[
-              channel_split.ptransform_id])))
-          residual_elements = all_elements[
-              channel_split.first_residual_element : prev_stops.get(
-                  channel_split.ptransform_id, len(all_elements)) + 1]
-          if residual_elements:
-            deferred_inputs[channel_split.ptransform_id].append(
-                coder_impl.encode_all(residual_elements))
-          prev_stops[
-              channel_split.ptransform_id] = channel_split.last_primary_element
+      self._add_residuals_and_channel_splits_to_deferred_inputs(
+          splits, get_input_coder_impl, input_for, last_sent, deferred_inputs)
 
       if deferred_inputs:
         # The worker will be waiting on these inputs as well.
@@ -1221,6 +1313,17 @@ def split_manager(stage_name, split_manager):
 
 
 class BundleManager(object):
+  """Manages the execution of a bundle from the runner-side.
+
+  This class receives a bundle descriptor, and performs the following tasks:
+  - Registration of the bundle with the worker.
+  - Splitting of the bundle
+  - Setting up any other bundle requirements (e.g. side inputs).
+  - Submitting the bundle to worker for execution
+  - Passing bundle input data to the worker
+  - Collecting bundle output data from the worker
+  - Finalizing the bundle.
+  """
 
   _uid_counter = 0
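[Editor's note: a minimal usage sketch, assuming a worker-handler controller,
the get_buffer and get_input_coder_impl callables, and a bundle descriptor as
built in _run_stage above. Illustrative only; it mirrors the call sites in
this file rather than defining the API.]

    # Construct a manager for one bundle and run it, collecting the bundle
    # result and any runner-initiated split results.
    bundle_manager = BundleManager(
        controller, get_buffer, get_input_coder_impl,
        process_bundle_descriptor, progress_frequency)
    result, splits = bundle_manager.process_bundle(data_input, data_output)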
 
@@ -1244,20 +1347,17 @@ class BundleManager(object):
     self._registered = skip_registration
     self._progress_frequency = progress_frequency
 
-  def process_bundle(self, inputs, expected_outputs):
-    """Submit a bundle for processing by the SDK.
-
-    Args:
-      inputs (dict): A map from PCollection id (expressed as a tuple
-        (PTransform, output), to a buffer with elements in the PCollection.
-      expected_outputs (dict): A map from PCollection id (expressed as a tuple
-        (PTransform, output), to TODO: What does this map to?
-    """
-    # Unique id for the instruction processing this bundle.
-    BundleManager._uid_counter += 1
-    process_bundle_id = 'bundle_%s' % BundleManager._uid_counter
-
-    # Register the bundle descriptor, if needed.
+  def _send_input_to_worker(self,
+                            process_bundle_id,
+                            read_transform_id,
+                            byte_streams):
+    data_out = self._controller.data_plane_handler.output_stream(
+        process_bundle_id, read_transform_id)
+    for byte_stream in byte_streams:
+      data_out.write(byte_stream)
+    data_out.close()
+
+  def _register_bundle_descriptor(self):
     if self._registered:
       registration_future = None
     else:
@@ -1268,6 +1368,10 @@ class BundleManager(object):
           process_bundle_registration)
       self._registered = True
 
+    return registration_future
+
+  def _select_split_manager(self):
+    """TODO(pabloem) WHAT DOES THIS DO"""
     unique_names = set(
         t.unique_name for t in self._bundle_descriptor.transforms.values())
     for stage_name, candidate in reversed(_split_managers):
@@ -1278,88 +1382,103 @@ class BundleManager(object):
     else:
       split_manager = None
 
-    if not split_manager:
-      # Write all the input data to the channel immediately.
-      for transform_id, elements in inputs.items():
-        data_out = self._controller.data_plane_handler.output_stream(
-            process_bundle_id, transform_id)
-        for element_data in elements:
-          data_out.write(element_data)
-        data_out.close()
+    return split_manager
 
+  def _generate_splits_for_testing(self,
+                                   split_manager,
+                                   inputs,
+                                   process_bundle_id):
     split_results = []
+    read_transform_id, buffer_data = only_element(inputs.items())
 
-    # Actually start the bundle.
+    byte_stream = b''.join(buffer_data)
+    num_elements = len(list(
+        self._get_input_coder_impl(read_transform_id).decode_all(byte_stream)))
+
+    # Start the split manager in case it wants to set any breakpoints.
+    split_manager_generator = split_manager(num_elements)
+    try:
+      split_fraction = next(split_manager_generator)
+      done = False
+    except StopIteration:
+      done = True
+
+    # Send all the data.
+    self._send_input_to_worker(
+        process_bundle_id, read_transform_id, [byte_stream])
+
+    # Execute the requested splits.
+    while not done:
+      if split_fraction is None:
+        split_result = None
+      else:
+        split_request = beam_fn_api_pb2.InstructionRequest(
+            process_bundle_split=
+            beam_fn_api_pb2.ProcessBundleSplitRequest(
+                instruction_reference=process_bundle_id,
+                desired_splits={
+                    read_transform_id:
+                    beam_fn_api_pb2.ProcessBundleSplitRequest.DesiredSplit(
+                        fraction_of_remainder=split_fraction,
+                        estimated_input_elements=num_elements)
+                }))
+        split_response = self._controller.control_handler.push(
+            split_request).get()
+        for t in (0.05, 0.1, 0.2):
+          waiting = ('Instruction not running', 'not yet scheduled')
+          if any(msg in split_response.error for msg in waiting):
+            time.sleep(t)
+            split_response = self._controller.control_handler.push(
+                split_request).get()
+        if 'Unknown process bundle' in split_response.error:
+          # It may have finished too fast.
+          split_result = None
+        elif split_response.error:
+          raise RuntimeError(split_response.error)
+        else:
+          split_result = split_response.process_bundle_split
+          split_results.append(split_result)
+      try:
+        split_fraction = split_manager_generator.send(split_result)
+      except StopIteration:
+        break
+    return split_results
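[Editor's note: to make the generator protocol concrete, a split manager
receives the input element count, yields fractions of the remaining work to
split off, and is sent back each split result (or None). A toy example,
hedged; real managers are registered via the split_manager(stage_name,
split_manager) helper that appears further below in this file.]

    def halving_split_manager(num_elements):
        # Ask the SDK to split off half of the remaining work once.
        split_result = yield 0.5
        if split_result is None:
            # The bundle may have finished before the split could apply.
            return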
+
+  def process_bundle(self, inputs, expected_outputs):
+    # Unique id for the instruction processing this bundle.
+    BundleManager._uid_counter += 1
+    process_bundle_id = 'bundle_%s' % BundleManager._uid_counter
+
+    # Register the bundle descriptor, if needed - noop if already registered.
+    registration_future = self._register_bundle_descriptor()
+
+    split_manager = self._select_split_manager()
+    if not split_manager:
+      # If there is no split_manager, write all input data to the channel.
+      for transform_id, elements in inputs.items():
+        self._send_input_to_worker(
+            process_bundle_id, transform_id, elements)
+
+    # Check that the bundle was successfully registered.
     if registration_future and registration_future.get().error:
       raise RuntimeError(registration_future.get().error)
 
-    process_bundle_request = beam_fn_api_pb2.InstructionRequest(
+    # Actually start the bundle.
+    process_bundle_req = beam_fn_api_pb2.InstructionRequest(
         instruction_id=process_bundle_id,
         process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
             process_bundle_descriptor_reference=self._bundle_descriptor.id))
-    result_future = self._controller.control_handler.push(
-        process_bundle_request)
+    result_future = self._controller.control_handler.push(process_bundle_req)
 
+    split_results = []
     with ProgressRequester(
         self._controller, process_bundle_id, self._progress_frequency):
+
       if split_manager:
-        read_transform_id, buffer_data = only_element(inputs.items())
-        num_elements = len(list(
-            self._get_input_coder_impl(read_transform_id).decode_all(
-                b''.join(buffer_data))))
-
-        # Start the split manager in case it wants to set any breakpoints.
-        split_manager_generator = split_manager(num_elements)
-        try:
-          split_fraction = next(split_manager_generator)
-          done = False
-        except StopIteration:
-          done = True
-
-        # Send all the data.
-        data_out = self._controller.data_plane_handler.output_stream(
-            process_bundle_id, read_transform_id)
-        data_out.write(b''.join(buffer_data))
-        data_out.close()
-
-        # Execute the requested splits.
-        while not done:
-          if split_fraction is None:
-            split_result = None
-          else:
-            split_request = beam_fn_api_pb2.InstructionRequest(
-                process_bundle_split=
-                beam_fn_api_pb2.ProcessBundleSplitRequest(
-                    instruction_reference=process_bundle_id,
-                    desired_splits={
-                        read_transform_id:
-                        beam_fn_api_pb2.ProcessBundleSplitRequest.DesiredSplit(
-                            fraction_of_remainder=split_fraction,
-                            estimated_input_elements=num_elements)
-                    }))
-            split_response = self._controller.control_handler.push(
-                split_request).get()
-            for t in (0.05, 0.1, 0.2):
-              waiting = ('Instruction not running', 'not yet scheduled')
-              if any(msg in split_response.error for msg in waiting):
-                time.sleep(t)
-                split_response = self._controller.control_handler.push(
-                    split_request).get()
-            if 'Unknown process bundle' in split_response.error:
-              # It may have finished too fast.
-              split_result = None
-            elif split_response.error:
-              raise RuntimeError(split_response.error)
-            else:
-              split_result = split_response.process_bundle_split
-              split_results.append(split_result)
-          try:
-            split_fraction = split_manager_generator.send(split_result)
-          except StopIteration:
-            break
+        split_results = self._generate_splits_for_testing(split_manager, inputs,
+                                                          process_bundle_id)
 
       # Gather all output data.
-      logging.debug('Gather all output data from %s.', expected_outputs)
       for output in self._controller.data_plane_handler.input_elements(
           process_bundle_id,
           expected_outputs.keys(),
@@ -1388,6 +1507,11 @@ class BundleManager(object):
 
 
 class ProgressRequester(threading.Thread):
+  """ Thread that asks SDK Worker for progress reports with a certain 
frequency.
+
+  A callback can be passed to call with progress updates.
+  """
+
   def __init__(self, controller, instruction_id, frequency, callback=None):
     super(ProgressRequester, self).__init__()
     self._controller = controller
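[Editor's note: a hedged usage sketch; the frequency and callback are
illustrative. process_bundle above uses the same context-manager pattern,
without a callback.]

    import logging

    def log_progress(progress_response):
        logging.debug('Bundle progress: %s', progress_response)

    with ProgressRequester(
        controller, process_bundle_id, 0.5, callback=log_progress):
        result = result_future.get()  # progress is polled while we block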
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index f303722..e03864c 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -50,6 +50,7 @@ from apache_beam.runners.portability import fn_api_runner
 from apache_beam.runners.worker import data_plane
 from apache_beam.runners.worker import statesampler
 from apache_beam.testing.synthetic_pipeline import SyntheticSDFAsSource
+from apache_beam.testing.test_stream import TestStream
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 from apache_beam.transforms import userstate
@@ -306,6 +307,36 @@ class FnApiRunnerTest(unittest.TestCase):
       assert_that(p | beam.Create(inputs) | beam.ParDo(AddIndex()),
                   equal_to(expected))
 
+  @unittest.skip('TestStream not yet supported')
+  def test_teststream_pardo_timers(self):
+    timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
+
+    class TimerDoFn(beam.DoFn):
+      def process(self, element, timer=beam.DoFn.TimerParam(timer_spec)):
+        unused_key, ts = element
+        timer.set(ts)
+        timer.set(2 * ts)
+
+      @userstate.on_timer(timer_spec)
+      def process_timer(self):
+        yield 'fired'
+
+    ts = (TestStream()
+          .add_elements([('k1', 10)])  # Set timer for 20
+          .advance_watermark_to(100)
+          .add_elements([('k2', 100)])  # Set timer for 200
+          .advance_watermark_to(1000))
+
+    with self.create_pipeline() as p:
+      actual = (
+          p
+          | ts
+          | beam.ParDo(TimerDoFn())
+          | beam.Map(lambda x, ts=beam.DoFn.TimestampParam: (x, ts)))
+
+      # expected = [('fired', ts) for ts in (20, 200)]
+      # assert_that(actual, equal_to(expected))
+
   def test_pardo_timers(self):
     timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
 
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index ffc7379..8d788fc 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -484,9 +484,7 @@ def only_element(iterable):
 
 
 class BundleProcessor(object):
-  """A class for processing bundles of elements.
-
-  """
+  """ A class for processing bundles of elements. """
 
   def __init__(
       self, process_bundle_descriptor, state_handler, data_channel_factory):
