This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 1e050c9  [BEAM-3837] More complete test for try-split.
     new 75a0350  Merge pull request #7801 [BEAM-3837] More complete test for try-split.
1e050c9 is described below

commit 1e050c9d9c545424794b98e1b15562448812a07f
Author: Robert Bradshaw <rober...@google.com>
AuthorDate: Thu Feb 7 13:38:25 2019 +0100

    [BEAM-3837] More complete test for try-split.
    
    Also re-worked try-split computation for clarity.
---
 sdks/python/apache_beam/coders/coder_impl.py       |  11 +
 .../runners/portability/fn_api_runner.py           | 227 ++++++++++--------
 .../runners/portability/fn_api_runner_test.py      | 260 +++++++++++++++++++++
 .../portability/fn_api_runner_transforms.py        |   2 +-
 .../apache_beam/runners/worker/bundle_processor.py |  47 ++--
 .../apache_beam/runners/worker/data_plane.py       |   3 +
 .../apache_beam/runners/worker/sdk_worker.py       |  13 +-
 7 files changed, 445 insertions(+), 118 deletions(-)

diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py
index a811e3a..643e270 100644
--- a/sdks/python/apache_beam/coders/coder_impl.py
+++ b/sdks/python/apache_beam/coders/coder_impl.py
@@ -95,6 +95,17 @@ class CoderImpl(object):
     """Decodes an object to an unnested string."""
     raise NotImplementedError
 
+  def encode_all(self, values):
+    out = create_OutputStream()
+    for value in values:
+      self.encode_to_stream(value, out, True)
+    return out.get()
+
+  def decode_all(self, encoded):
+    input_stream = create_InputStream(encoded)
+    while input_stream.size() > 0:
+      yield self.decode_from_stream(input_stream, True)
+
   def encode_nested(self, value):
     out = create_OutputStream()
     self.encode_to_stream(value, out, True)
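
For context, the two helpers added above round-trip a list of values through
a single nested-encoded byte string. A minimal sketch of their pairing,
assuming apache_beam is installed (VarIntCoder is just an example coder, not
part of this change):

    from apache_beam.coders import VarIntCoder

    coder_impl = VarIntCoder().get_impl()
    encoded = coder_impl.encode_all([1, 2, 3])  # one concatenated bytes blob
    assert list(coder_impl.decode_all(encoded)) == [1, 2, 3]
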
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index 5f8fa3b..6890af6 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -486,17 +486,25 @@ class FnApiRunner(runner.PipelineRunner):
         raise NotImplementedError(buffer_id)
       return pcoll_buffers[buffer_id]
 
+    def get_input_coder_impl(transform_id):
+      return context.coders[safe_coders[
+          beam_fn_api_pb2.RemoteGrpcPort.FromString(
+              process_bundle_descriptor.transforms[transform_id].spec.payload
+          ).coder_id
+      ]].get_impl()
+
     for k in range(self._bundle_repeat):
       try:
         controller.state.checkpoint()
         BundleManager(
-            controller, lambda pcoll_id: [], process_bundle_descriptor,
-            self._progress_frequency, k).process_bundle(data_input, data_output)
+            controller, lambda pcoll_id: [], get_input_coder_impl,
+            process_bundle_descriptor, self._progress_frequency, k
+        ).process_bundle(data_input, data_output)
       finally:
         controller.state.restore()
 
     result, splits = BundleManager(
-        controller, get_buffer, process_bundle_descriptor,
+        controller, get_buffer, get_input_coder_impl, process_bundle_descriptor,
         self._progress_frequency).process_bundle(
             data_input, data_output)
 
@@ -511,6 +519,8 @@ class FnApiRunner(runner.PipelineRunner):
           'No IO transform feeds %s' % ptransform_id)
 
     last_result = result
+    last_sent = data_input
+
     while True:
       deferred_inputs = collections.defaultdict(list)
       for transform_id, timer_writes in stage.timer_pcollections:
@@ -549,7 +559,7 @@ class FnApiRunner(runner.PipelineRunner):
         ].append(delayed_application.application.element)
 
       # Queue any runner-initiated delayed bundle applications.
-      prev_stops = collections.defaultdict(lambda: float('inf'))
+      prev_stops = {}
       for split in splits:
         for delayed_application in split.residual_roots:
           deferred_inputs[
@@ -558,11 +568,7 @@ class FnApiRunner(runner.PipelineRunner):
                   delayed_application.application.input_id)
           ].append(delayed_application.application.element)
         for channel_split in split.channel_splits:
-          transform = process_bundle_descriptor.transforms[
-              channel_split.ptransform_id]
-          coder_id = beam_fn_api_pb2.RemoteGrpcPort.FromString(
-              transform.spec.payload).coder_id
-          coder_impl = context.coders[safe_coders[coder_id]].get_impl()
+          coder_impl = get_input_coder_impl(channel_split.ptransform_id)
          # TODO(SDF): This requires deterministic ordering of buffer iteration.
           # TODO(SDF): The return split is in terms of indices.  Ideally,
           # a runner could map these back to actual positions to effectively
@@ -576,22 +582,15 @@ class FnApiRunner(runner.PipelineRunner):
           # delayed applications.
 
           # Decode and recode to split the encoded buffer by element index.
-          buffer = data_input[
-              channel_split.ptransform_id, channel_split.input_id]
-          input_stream = create_InputStream(''.join(buffer))
-          output_stream = create_OutputStream()
-          index = 0
-          prev_stop = prev_stops[channel_split.ptransform_id]
-          while input_stream.size() > 0:
-            if index > prev_stop:
-              break
-            element = coder_impl.decode_from_stream(input_stream, True)
-            if index >= channel_split.first_residual_element:
-              coder_impl.encode_to_stream(element, output_stream, True)
-            index += 1
-          deferred_inputs[
-              channel_split.ptransform_id, channel_split.input_id].append(
-                  output_stream.get())
+          all_elements = list(coder_impl.decode_all(b''.join(last_sent[
+              channel_split.ptransform_id, channel_split.input_id])))
+          residual_elements = all_elements[
+              channel_split.first_residual_element : prev_stops.get(
+                  channel_split.ptransform_id, len(all_elements)) + 1]
+          if residual_elements:
+            deferred_inputs[
+                channel_split.ptransform_id, channel_split.input_id].append(
+                    coder_impl.encode_all(residual_elements))
           prev_stops[
               channel_split.ptransform_id] = channel_split.last_primary_element
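
To make the slice bounds above concrete: first_residual_element comes from
the current split, while prev_stops remembers where an earlier split in the
same bundle already truncated this channel, so nothing past that point is
re-sent twice. A toy illustration with stand-in values (not Beam code):

    all_elements = ['a', 'b', 'c', 'd', 'e']
    first_residual_element = 2   # this split keeps elements 0-1 as primary
    prev_stop = 3                # an earlier split already stopped here
    residual = all_elements[first_residual_element:prev_stop + 1]
    assert residual == ['c', 'd']  # only these become deferred input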
 
@@ -604,9 +603,11 @@ class FnApiRunner(runner.PipelineRunner):
         last_result, splits = BundleManager(
             controller,
             get_buffer,
+            get_input_coder_impl,
             process_bundle_descriptor,
             self._progress_frequency,
             True).process_bundle(deferred_inputs, data_output)
+        last_sent = deferred_inputs
       else:
         break
 
@@ -1008,6 +1009,7 @@ class EmbeddedGrpcWorkerHandler(GrpcWorkerHandler):
         self.control_address, worker_count=self._num_threads)
     self.worker_thread = threading.Thread(
         name='run_worker', target=self.worker.run)
+    self.worker_thread.daemon = True
     self.worker_thread.start()
 
   def stop_worker(self):
@@ -1108,7 +1110,10 @@ class WorkerHandlerManager(object):
 
   def close_all(self):
     for controller in set(self._cached_handlers.values()):
-      controller.close()
+      try:
+        controller.close()
+      except Exception:
+        logging.info("Error closing controller %s" % controller, exc_info=True)
     self._cached_handlers = {}
 
 
@@ -1119,20 +1124,41 @@ class ExtendedProvisionInfo(object):
     self.artifact_staging_dir = artifact_staging_dir
 
 
+_split_managers = []
+
+
+@contextlib.contextmanager
+def split_manager(stage_name, split_manager):
+  """Registers a split manager to control the flow of elements to a given 
stage.
+
+  Used for testing.
+
+  A split manager should be a coroutine that yields desired split fractions
+  and receives the corresponding split results. Currently, only one input is
+  supported.
+  """
+  try:
+    _split_managers.append((stage_name, split_manager))
+    yield
+  finally:
+    _split_managers.pop()
+
+
 class BundleManager(object):
 
   _uid_counter = 0
 
   def __init__(
-      self, controller, get_buffer, bundle_descriptor, progress_frequency=None,
-      skip_registration=False):
+      self, controller, get_buffer, get_input_coder_impl, bundle_descriptor,
+      progress_frequency=None, skip_registration=False):
     self._controller = controller
     self._get_buffer = get_buffer
+    self._get_input_coder_impl = get_input_coder_impl
     self._bundle_descriptor = bundle_descriptor
     self._registered = skip_registration
     self._progress_frequency = progress_frequency
 
-  def process_bundle(self, inputs, expected_outputs, test_splits=False):
+  def process_bundle(self, inputs, expected_outputs):
     # Unique id for the instruction processing this bundle.
     BundleManager._uid_counter += 1
     process_bundle_id = 'bundle_%s' % BundleManager._uid_counter
@@ -1148,25 +1174,27 @@ class BundleManager(object):
           process_bundle_registration)
       self._registered = True
 
-    # Write all the input data to the channel.
-    for (transform_id, name), elements in inputs.items():
-      data_out = self._controller.data_plane_handler.output_stream(
-          process_bundle_id, beam_fn_api_pb2.Target(
-              primitive_transform_reference=transform_id, name=name))
-      for element_data in elements:
-        data_out.write(element_data)
-      data_out.close()
-
-    # TODO(robertwb): Control this via a pipeline option.
-    if test_splits:
-      # Inject some splits.
-      random_splitter = BundleSplitter(
-          self._controller,
-          process_bundle_id,
-          self._bundle_descriptor.transforms.keys())
-      random_splitter.start()
+    unique_names = set(
+        t.unique_name for t in self._bundle_descriptor.transforms.values())
+    for stage_name, candidate in reversed(_split_managers):
+      if (stage_name in unique_names
+          or (stage_name + '/Process') in unique_names):
+        split_manager = candidate
+        break
     else:
-      random_splitter = None
+      split_manager = None
+
+    if not split_manager:
+      # Write all the input data to the channel immediately.
+      for (transform_id, name), elements in inputs.items():
+        data_out = self._controller.data_plane_handler.output_stream(
+            process_bundle_id, beam_fn_api_pb2.Target(
+                primitive_transform_reference=transform_id, name=name))
+        for element_data in elements:
+          data_out.write(element_data)
+        data_out.close()
+
+    split_results = []
 
     # Actually start the bundle.
     if registration_future and registration_future.get().error:
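
The _split_managers registry above pairs with the coroutine protocol
described in split_manager's docstring. A minimal sketch of a manager that
requests a single early split ('MyStage' and run_pipeline are hypothetical
placeholders):

    def my_split_manager(num_elements):
      yield                     # cede control so the runner can send the data
      split_result = yield 0.0  # then ask to split as early as possible
      # split_result is the ProcessBundleSplitResponse, or None if the
      # bundle finished before the split request could land.

    with fn_api_runner.split_manager('MyStage', my_split_manager):
      run_pipeline()            # execute the pipeline under test here
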
@@ -1179,6 +1207,64 @@ class BundleManager(object):
 
     with ProgressRequester(
         self._controller, process_bundle_id, self._progress_frequency):
+      if split_manager:
+        (read_transform_id, name), buffer_data = only_element(inputs.items())
+        num_elements = len(list(
+            self._get_input_coder_impl(read_transform_id).decode_all(
+                b''.join(buffer_data))))
+
+        # Start the split manager in case it wants to set any breakpoints.
+        split_manager_generator = split_manager(num_elements)
+        try:
+          split_fraction = next(split_manager_generator)
+          done = False
+        except StopIteration:
+          done = True
+
+        # Send all the data.
+        data_out = self._controller.data_plane_handler.output_stream(
+            process_bundle_id,
+            beam_fn_api_pb2.Target(
+                primitive_transform_reference=read_transform_id, name=name))
+        data_out.write(b''.join(buffer_data))
+        data_out.close()
+
+        # Execute the requested splits.
+        while not done:
+          if split_fraction is None:
+            split_result = None
+          else:
+            split_request = beam_fn_api_pb2.InstructionRequest(
+                process_bundle_split=
+                beam_fn_api_pb2.ProcessBundleSplitRequest(
+                    instruction_reference=process_bundle_id,
+                    desired_splits={
+                        read_transform_id:
+                        beam_fn_api_pb2.ProcessBundleSplitRequest.DesiredSplit(
+                            fraction_of_remainder=split_fraction,
+                            estimated_input_elements=num_elements)
+                    }))
+            split_response = self._controller.control_handler.push(
+                split_request).get()
+            for t in (0.05, 0.1, 0.2):
+              waiting = ('Instruction not running', 'not yet scheduled')
+              if any(msg in split_response.error for msg in waiting):
+                time.sleep(t)
+                split_response = self._controller.control_handler.push(
+                    split_request).get()
+            if 'Unknown process bundle' in split_response.error:
+              # It may have finished too fast.
+              split_result = None
+            elif split_response.error:
+              raise RuntimeError(split_response.error)
+            else:
+              split_result = split_response.process_bundle_split
+              split_results.append(split_result)
+          try:
+            split_fraction = split_manager_generator.send(split_result)
+          except StopIteration:
+            break
+
       # Gather all output data.
       expected_targets = [
           beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
@@ -1198,12 +1284,6 @@ class BundleManager(object):
       logging.debug('Wait for the bundle to finish.')
       result = result_future.get()
 
-    if random_splitter:
-      random_splitter.stop()
-      split_results = random_splitter.split_results()
-    else:
-      split_results = []
-
     if result.error:
       raise RuntimeError(result.error)
 
@@ -1248,47 +1328,6 @@ class ProgressRequester(threading.Thread):
     self._done = True
 
 
-class BundleSplitter(threading.Thread):
-  def __init__(self, controller, instruction_id, split_transforms,
-               frequency=.03, split_fractions=(.5, .25, 0)):
-    super(BundleSplitter, self).__init__()
-    self._controller = controller
-    self._instruction_id = instruction_id
-    self._split_transforms = split_transforms
-    self._split_fractions = split_fractions
-    self._frequency = frequency
-    self._results = []
-    self._done = False
-
-  def run(self):
-    for fraction in self._split_fractions:
-      if self._done:
-        return
-      split_result = self._controller.control_handler.push(
-          beam_fn_api_pb2.InstructionRequest(
-              process_bundle_split=beam_fn_api_pb2.ProcessBundleSplitRequest(
-                  instruction_reference=self._instruction_id,
-                  desired_splits={
-                      transform_id:
-                      beam_fn_api_pb2.ProcessBundleSplitRequest.DesiredSplit(
-                          fraction_of_remainder=fraction)
-                      for transform_id in self._split_transforms}))).get()
-      if split_result.error:
-        logging.info('Unable to split at %s: %s' % (
-            fraction, split_result.error))
-      elif split_result.process_bundle_split:
-        self._results.append(split_result.process_bundle_split)
-      time.sleep(self._frequency)
-
-  def split_results(self):
-    self.stop()
-    self.join()
-    return self._results
-
-  def stop(self):
-    self._done = True
-
-
 class ControlFuture(object):
   def __init__(self, instruction_id, response=None):
     self.instruction_id = instruction_id
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index 6c4cad9..aadf4a8 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -17,13 +17,17 @@
 from __future__ import absolute_import
 from __future__ import print_function
 
+import collections
 import logging
 import os
+import random
 import sys
 import tempfile
+import threading
 import time
 import traceback
 import unittest
+import uuid
 from builtins import range
 
 from tenacity import retry
@@ -752,6 +756,262 @@ class FnApiRunnerTestWithBundleRepeat(FnApiRunnerTest):
         runner=fn_api_runner.FnApiRunner(bundle_repeat=3))
 
 
+class FnApiRunnerSplitTest(unittest.TestCase):
+
+  def create_pipeline(self):
+    # Must be GRPC so we can send data and split requests concurrently
+    # with the bundle process request.
+    return beam.Pipeline(
+        runner=fn_api_runner.FnApiRunner(
+            default_environment=beam_runner_api_pb2.Environment(
+                urn=python_urns.EMBEDDED_PYTHON_GRPC)))
+
+  def test_checkpoint(self):
+    # This split manager will get re-invoked on each smaller split,
+    # so N times for N elements.
+    element_counter = ElementCounter()
+
+    def split_manager(num_elements):
+      # Send at least one element so it can make forward progress.
+      element_counter.reset()
+      breakpoint = element_counter.set_breakpoint(1)
+      # Cede control back to the runner so data can be sent.
+      yield
+      breakpoint.wait()
+      # Split as close to current as possible.
+      split_result = yield 0.0
+      # Verify we split at exactly the first element.
+      self.verify_channel_split(split_result, 0, 1)
+      # Continue processing.
+      breakpoint.clear()
+
+    self.run_split_pipeline(split_manager, list('abc'), element_counter)
+
+  def test_split_half(self):
+    total_num_elements = 25
+    seen_bundle_sizes = []
+    element_counter = ElementCounter()
+
+    def split_manager(num_elements):
+      seen_bundle_sizes.append(num_elements)
+      if num_elements == total_num_elements:
+        element_counter.reset()
+        breakpoint = element_counter.set_breakpoint(5)
+        yield
+        breakpoint.wait()
+        # Split the remainder in half (20 elements, then 10).
+        split1 = yield 0.5
+        self.verify_channel_split(split1, 14, 15)  # remainder is 15 to end
+        split2 = yield 0.5
+        self.verify_channel_split(split2, 9, 10)   # remainder is 10 to end
+        breakpoint.clear()
+
+    self.run_split_pipeline(
+        split_manager, range(total_num_elements), element_counter)
+    self.assertEqual([25, 15], seen_bundle_sizes)
+
+  def run_split_pipeline(self, split_manager, elements, element_counter=None):
+    with fn_api_runner.split_manager('Identity', split_manager):
+      with self.create_pipeline() as p:
+        res = (p
+               | beam.Create(elements)
+               | beam.Reshuffle()
+               | 'Identity' >> beam.Map(lambda x: x)
+               | beam.Map(lambda x: element_counter.increment() or x))
+        assert_that(res, equal_to(elements))
+
+  def test_nosplit_sdf(self):
+    def split_manager(num_elements):
+      yield
+
+    elements = [1, 2, 3]
+    expected_groups = [[(e, k) for k in range(e)] for e in elements]
+    self.run_sdf_split_pipeline(
+        split_manager, elements, ElementCounter(), expected_groups)
+
+  def test_checkpoint_sdf(self):
+    element_counter = ElementCounter()
+
+    def split_manager(num_elements):
+      element_counter.reset()
+      breakpoint = element_counter.set_breakpoint(1)
+      yield
+      breakpoint.wait()
+      yield 0
+      breakpoint.clear()
+
+    # Everything should be perfectly split.
+    elements = [2, 3]
+    expected_groups = [[(2, 0)], [(2, 1)], [(3, 0)], [(3, 1)], [(3, 2)]]
+    self.run_sdf_split_pipeline(
+        split_manager, elements, element_counter, expected_groups)
+
+  def test_split_half_sdf(self):
+
+    element_counter = ElementCounter()
+    is_first_bundle = [True]  # emulate nonlocal for Python 2
+
+    def split_manager(num_elements):
+      if is_first_bundle:
+        del is_first_bundle[:]
+        breakpoint = element_counter.set_breakpoint(1)
+        yield
+        breakpoint.wait()
+        split1 = yield 0.5
+        split2 = yield 0.5
+        split3 = yield 0.5
+        self.verify_channel_split(split1, 0, 1)
+        self.verify_channel_split(split2, -1, 1)
+        self.verify_channel_split(split3, -1, 1)
+        breakpoint.clear()
+
+    elements = [4, 4]
+    expected_groups = [
+        [(4, 0)],
+        [(4, 1)],
+        [(4, 2), (4, 3)],
+        [(4, 0), (4, 1), (4, 2), (4, 3)]]
+
+    self.run_sdf_split_pipeline(
+        split_manager, elements, element_counter, expected_groups)
+
+  def test_split_crazy_sdf(self, seed=None):
+    if seed is None:
+      seed = random.randrange(1 << 20)
+    r = random.Random(seed)
+    element_counter = ElementCounter()
+
+    def split_manager(num_elements):
+      element_counter.reset()
+      wait_for = r.randrange(num_elements)
+      breakpoint = element_counter.set_breakpoint(wait_for)
+      yield
+      breakpoint.wait()
+      yield r.random()
+      yield r.random()
+      breakpoint.clear()
+
+    try:
+      elements = [r.randrange(5, 10) for _ in range(5)]
+      self.run_sdf_split_pipeline(split_manager, elements, element_counter)
+    except Exception:
+      logging.error('test_split_crazy_sdf.seed = %s', seed)
+      raise
+
+  def run_sdf_split_pipeline(
+      self, split_manager, elements, element_counter, expected_groups=None):
+    # Define an SDF that for each input x produces [(x, k) for k in range(x)].
+
+    class EnumerateProvider(beam.transforms.core.RestrictionProvider):
+      def initial_restriction(self, element):
+        return (0, element)
+
+      def create_tracker(self, restriction):
+        return restriction_trackers.OffsetRestrictionTracker(
+            *restriction)
+
+      def split(self, element, restriction):
+        # Don't do any initial splitting, to keep the test simple.
+        return [restriction]
+
+    class EnumerateSdf(beam.DoFn):
+      def process(self, element,
+                  restriction_tracker=beam.DoFn.RestrictionParam(
+                      EnumerateProvider())):
+        to_emit = []
+        for k in range(*restriction_tracker.current_restriction()):
+          if restriction_tracker.try_claim(k):
+            to_emit.append((element, k))
+            element_counter.increment()
+          else:
+            break
+        # Emitting in batches for tighter testing.
+        yield to_emit
+
+    expected = [(e, k) for e in elements for k in range(e)]
+
+    with fn_api_runner.split_manager('SDF', split_manager):
+      with self.create_pipeline() as p:
+        grouped = (
+            p
+            | beam.Create(elements)
+            | 'SDF' >> beam.ParDo(EnumerateSdf()))
+        flat = grouped | beam.FlatMap(lambda x: x)
+        assert_that(flat, equal_to(expected))
+        if expected_groups:
+          assert_that(grouped, equal_to(expected_groups), label='CheckGrouped')
+
+  def verify_channel_split(self, split_result, last_primary, first_residual):
+    self.assertEqual(1, len(split_result.channel_splits), split_result)
+    channel_split, = split_result.channel_splits
+    self.assertEqual(last_primary, channel_split.last_primary_element)
+    self.assertEqual(first_residual, channel_split.first_residual_element)
+    # There should be a primary and residual application for each element
+    # not covered above.
+    self.assertEqual(
+        first_residual - last_primary - 1,
+        len(split_result.primary_roots),
+        split_result.primary_roots)
+    self.assertEqual(
+        first_residual - last_primary - 1,
+        len(split_result.residual_roots),
+        split_result.residual_roots)
+
+
+class ElementCounter(object):
+  """Used to wait until a certain number of elements are seen."""
+
+  def __init__(self):
+    self._cv = threading.Condition()
+    self.reset()
+
+  def reset(self):
+    with self._cv:
+      self._breakpoints = collections.defaultdict(list)
+      self._count = 0
+
+  def increment(self):
+    with self._cv:
+      self._count += 1
+      self._cv.notify_all()
+      breakpoints = list(self._breakpoints[self._count])
+    for breakpoint in breakpoints:
+      breakpoint.wait()
+
+  def set_breakpoint(self, value):
+    with self._cv:
+      event = threading.Event()
+      self._breakpoints[value].append(event)
+
+    class Breakpoint(object):
+      @staticmethod
+      def wait(timeout=10):
+        with self._cv:
+          start = time.time()
+          while self._count < value:
+            elapsed = time.time() - start
+            if elapsed > timeout:
+              raise RuntimeError('Timed out waiting for %s' % value)
+            self._cv.wait(timeout - elapsed)
+
+      @staticmethod
+      def clear():
+        event.set()
+
+    return Breakpoint()
+
+  def __reduce__(self):
+    # Ensure we get the same object back through a pickling round-trip.
+    name = uuid.uuid4().hex
+    _pickled_element_counters[name] = self
+    return _unpickle_element_counter, (name,)
+
+
+_pickled_element_counters = {}
+
+
+def _unpickle_element_counter(name):
+  return _pickled_element_counters[name]
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()
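
The __reduce__ hook above is what lets these tests share one counter with
the pickled DoFn: within a single process, a pickling round-trip hands back
the identical object rather than a copy.

    import pickle

    counter = ElementCounter()
    assert pickle.loads(pickle.dumps(counter)) is counter
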
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
index 2482987..8667a8e 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
@@ -287,9 +287,9 @@ class TransformContext(object):
     self.components = components
     self.known_runner_urns = known_runner_urns
     self.use_state_iterables = use_state_iterables
-    self.safe_coders = {}
     self.bytes_coder_id = self.add_or_get_coder_id(
         coders.BytesCoder().to_runner_api(None), 'bytes_coder')
+    self.safe_coders = {self.bytes_coder_id: self.bytes_coder_id}
 
   def add_or_get_coder_id(self, coder_proto, coder_prefix='coder'):
     for coder_id, coder in self.components.coders.items():
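
The reordering above lets the seed entry reference bytes_coder_id, and the
seed makes safe_coders lookups of the bytes coder (as in get_input_coder_impl
earlier in this patch) succeed: the bytes coder is trivially its own safe
version. Conceptually the mapping looks like this (ids illustrative):

    safe_coders = {
        'bytes_coder': 'bytes_coder',     # runner can decode this directly
        'custom_coder': 'bytes_coder',    # length-prefixed to opaque bytes
    }
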
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index db2d790..0e1782d 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -138,28 +138,39 @@ class DataInputOperation(RunnerIOOperation):
           input_stream, True)
       self.output(decoded_value)
 
-  def try_split(self, fraction_of_remainder, total_buffer_size=None):
+  def try_split(self, fraction_of_remainder, total_buffer_size):
     with self.splitting_lock:
-      # If total_buffer_size is not provided, pick something.
-      if not total_buffer_size:
-        total_buffer_size = self.index + 2
+      if total_buffer_size < self.index + 1:
+        total_buffer_size = self.index + 1
       elif self.stop and total_buffer_size > self.stop:
         total_buffer_size = self.stop
-      # Compute, as a fraction, how much further to go.
-      # TODO(SDF): Take into account progress of current element.
-      stop_offset = (total_buffer_size - self.index) * fraction_of_remainder
-      # If it's less than a whole element, try splitting the current element.
-      if int(stop_offset) == 0:
-        split = self.receivers[0].try_split(stop_offset)
-        if split:
-          element_primary, element_residual = split
-          self.stop = self.index + 1
-          return self.stop - 2, element_primary, element_residual, self.stop
-
+      if self.index == -1:
+        # We are "finished" with the (non-existent) previous element.
+        current_element_progress = 1
+      else:
+        # TODO(SDF): Get actual progress of current element.
+        current_element_progress = 0.5
+      # Now figure out where to split.
+      # The units here (except for keep_of_element_remainder) are all in
+      # terms of number of (possibly fractional) elements.
+      remainder = total_buffer_size - self.index - current_element_progress
+      keep = remainder * fraction_of_remainder
+      if current_element_progress < 1:
+        keep_of_element_remainder = keep / (1 - current_element_progress)
+        # If it's less than what's left of the current element,
+        # try splitting at the current element.
+        if keep_of_element_remainder < 1:
+          split = self.receivers[0].try_split(keep_of_element_remainder)
+          if split:
+            element_primary, element_residual = split
+            self.stop = self.index + 1
+            return self.index - 1, element_primary, element_residual, self.stop
       # Otherwise, split at the closest element boundary.
-      desired_stop = max(int(stop_offset), 1) + self.index
-      if desired_stop < self.stop:
-        self.stop = desired_stop
+      # pylint: disable=round-builtin
+      stop_index = (
+          self.index + max(1, int(round(current_element_progress + keep))))
+      if stop_index < self.stop:
+        self.stop = stop_index
         return self.stop - 1, None, None, self.stop
 
 
diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py
index fc8f9cc..276ca19 100644
--- a/sdks/python/apache_beam/runners/worker/data_plane.py
+++ b/sdks/python/apache_beam/runners/worker/data_plane.py
@@ -211,6 +211,8 @@ class _GrpcDataChannel(DataChannel):
         try:
           data = received.get(timeout=1)
         except queue.Empty:
+          if self._closed:
+            raise RuntimeError('Channel closed prematurely.')
           if abort_callback():
             return
           if self._exc_info:
@@ -275,6 +277,7 @@ class _GrpcDataChannel(DataChannel):
         self._exc_info = sys.exc_info()
         raise
     finally:
+      self._closed = True
       self._reads_finished.set()
 
   def _start_reader(self, elements_iterator):
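
The new _closed flag turns a potential hang into a fast failure: a consumer
blocked on the receive queue now surfaces an error once the read loop has
ended. A condensed, standalone illustration of the interaction (not the
actual class):

    import queue

    received = queue.Queue()
    closed = True  # what the reader's finally block records on exit

    def wait_for_data():
      while True:
        try:
          return received.get(timeout=1)
        except queue.Empty:
          if closed:
            raise RuntimeError('Channel closed prematurely.')
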
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 1528d23..cb981ec 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -312,11 +312,14 @@ class SdkWorker(object):
 
   def process_bundle_split(self, request, instruction_id):
    processor = self.active_bundle_processors.get(request.instruction_reference)
-    if not processor:
-      raise ValueError('Instruction not running: %s' % instruction_id)
-    return beam_fn_api_pb2.InstructionResponse(
-        instruction_id=instruction_id,
-        process_bundle_split=processor.try_split(request))
+    if processor:
+      return beam_fn_api_pb2.InstructionResponse(
+          instruction_id=instruction_id,
+          process_bundle_split=processor.try_split(request))
+    else:
+      return beam_fn_api_pb2.InstructionResponse(
+          instruction_id=instruction_id,
+          error='Instruction not running: %s' % instruction_id)
 
   def process_bundle_progress(self, request, instruction_id):
     # It is an error to get progress for a not-in-flight bundle.
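
Returning an error response rather than raising is what the runner-side
retry loop in BundleManager (earlier in this patch) keys off of. A hedged
sketch of the caller's side (worker, split_request, and handle are
illustrative placeholders):

    import time

    response = worker.process_bundle_split(split_request, instruction_id)
    if 'Instruction not running' in response.error:
      time.sleep(0.05)  # the bundle may not have started yet; retry later
    elif response.error:
      raise RuntimeError(response.error)
    else:
      handle(response.process_bundle_split)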
