This is an automated email from the ASF dual-hosted git repository. robertwb pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 1e050c9 [BEAM-3837] More complete test for try-split. new 75a0350 Merge pull request #7801 [BEAM-3837] More complete test for try-split. 1e050c9 is described below commit 1e050c9d9c545424794b98e1b15562448812a07f Author: Robert Bradshaw <rober...@google.com> AuthorDate: Thu Feb 7 13:38:25 2019 +0100 [BEAM-3837] More complete test for try-split. Also re-worked try-split computation for clarity. --- sdks/python/apache_beam/coders/coder_impl.py | 11 + .../runners/portability/fn_api_runner.py | 227 ++++++++++-------- .../runners/portability/fn_api_runner_test.py | 260 +++++++++++++++++++++ .../portability/fn_api_runner_transforms.py | 2 +- .../apache_beam/runners/worker/bundle_processor.py | 47 ++-- .../apache_beam/runners/worker/data_plane.py | 3 + .../apache_beam/runners/worker/sdk_worker.py | 13 +- 7 files changed, 445 insertions(+), 118 deletions(-) diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index a811e3a..643e270 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -95,6 +95,17 @@ class CoderImpl(object): """Decodes an object to an unnested string.""" raise NotImplementedError + def encode_all(self, values): + out = create_OutputStream() + for value in values: + self.encode_to_stream(value, out, True) + return out.get() + + def decode_all(self, encoded): + input_stream = create_InputStream(encoded) + while input_stream.size() > 0: + yield self.decode_from_stream(input_stream, True) + def encode_nested(self, value): out = create_OutputStream() self.encode_to_stream(value, out, True) diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index 5f8fa3b..6890af6 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -486,17 +486,25 @@ class FnApiRunner(runner.PipelineRunner): raise NotImplementedError(buffer_id) return pcoll_buffers[buffer_id] + def get_input_coder_impl(transform_id): + return context.coders[safe_coders[ + beam_fn_api_pb2.RemoteGrpcPort.FromString( + process_bundle_descriptor.transforms[transform_id].spec.payload + ).coder_id + ]].get_impl() + for k in range(self._bundle_repeat): try: controller.state.checkpoint() BundleManager( - controller, lambda pcoll_id: [], process_bundle_descriptor, - self._progress_frequency, k).process_bundle(data_input, data_output) + controller, lambda pcoll_id: [], get_input_coder_impl, + process_bundle_descriptor, self._progress_frequency, k + ).process_bundle(data_input, data_output) finally: controller.state.restore() result, splits = BundleManager( - controller, get_buffer, process_bundle_descriptor, + controller, get_buffer, get_input_coder_impl, process_bundle_descriptor, self._progress_frequency).process_bundle( data_input, data_output) @@ -511,6 +519,8 @@ class FnApiRunner(runner.PipelineRunner): 'No IO transform feeds %s' % ptransform_id) last_result = result + last_sent = data_input + while True: deferred_inputs = collections.defaultdict(list) for transform_id, timer_writes in stage.timer_pcollections: @@ -549,7 +559,7 @@ class FnApiRunner(runner.PipelineRunner): ].append(delayed_application.application.element) # Queue any runner-initiated delayed bundle applications. - prev_stops = collections.defaultdict(lambda: float('inf')) + prev_stops = {} for split in splits: for delayed_application in split.residual_roots: deferred_inputs[ @@ -558,11 +568,7 @@ class FnApiRunner(runner.PipelineRunner): delayed_application.application.input_id) ].append(delayed_application.application.element) for channel_split in split.channel_splits: - transform = process_bundle_descriptor.transforms[ - channel_split.ptransform_id] - coder_id = beam_fn_api_pb2.RemoteGrpcPort.FromString( - transform.spec.payload).coder_id - coder_impl = context.coders[safe_coders[coder_id]].get_impl() + coder_impl = get_input_coder_impl(channel_split.ptransform_id) # TODO(SDF): This requires determanistic ordering of buffer iteration. # TODO(SDF): The return split is in terms of indices. Ideally, # a runner could map these back to actual positions to effectively @@ -576,22 +582,15 @@ class FnApiRunner(runner.PipelineRunner): # delayed applications. # Decode and recode to split the encoded buffer by element index. - buffer = data_input[ - channel_split.ptransform_id, channel_split.input_id] - input_stream = create_InputStream(''.join(buffer)) - output_stream = create_OutputStream() - index = 0 - prev_stop = prev_stops[channel_split.ptransform_id] - while input_stream.size() > 0: - if index > prev_stop: - break - element = coder_impl.decode_from_stream(input_stream, True) - if index >= channel_split.first_residual_element: - coder_impl.encode_to_stream(element, output_stream, True) - index += 1 - deferred_inputs[ - channel_split.ptransform_id, channel_split.input_id].append( - output_stream.get()) + all_elements = list(coder_impl.decode_all(b''.join(last_sent[ + channel_split.ptransform_id, channel_split.input_id]))) + residual_elements = all_elements[ + channel_split.first_residual_element : prev_stops.get( + channel_split.ptransform_id, len(all_elements)) + 1] + if residual_elements: + deferred_inputs[ + channel_split.ptransform_id, channel_split.input_id].append( + coder_impl.encode_all(residual_elements)) prev_stops[ channel_split.ptransform_id] = channel_split.last_primary_element @@ -604,9 +603,11 @@ class FnApiRunner(runner.PipelineRunner): last_result, splits = BundleManager( controller, get_buffer, + get_input_coder_impl, process_bundle_descriptor, self._progress_frequency, True).process_bundle(deferred_inputs, data_output) + last_sent = deferred_inputs else: break @@ -1008,6 +1009,7 @@ class EmbeddedGrpcWorkerHandler(GrpcWorkerHandler): self.control_address, worker_count=self._num_threads) self.worker_thread = threading.Thread( name='run_worker', target=self.worker.run) + self.worker_thread.daemon = True self.worker_thread.start() def stop_worker(self): @@ -1108,7 +1110,10 @@ class WorkerHandlerManager(object): def close_all(self): for controller in set(self._cached_handlers.values()): - controller.close() + try: + controller.close() + except Exception: + logging.info("Error closing controller %s" % controller, exc_info=True) self._cached_handlers = {} @@ -1119,20 +1124,41 @@ class ExtendedProvisionInfo(object): self.artifact_staging_dir = artifact_staging_dir +_split_managers = [] + + +@contextlib.contextmanager +def split_manager(stage_name, split_manager): + """Registers a split manager to control the flow of elements to a given stage. + + Used for testing. + + A split manager should be a coroutine yielding desired split fractions, + receiving the corresponding split results. Currently, only one input is + supported. + """ + try: + _split_managers.append((stage_name, split_manager)) + yield + finally: + _split_managers.pop() + + class BundleManager(object): _uid_counter = 0 def __init__( - self, controller, get_buffer, bundle_descriptor, progress_frequency=None, - skip_registration=False): + self, controller, get_buffer, get_input_coder_impl, bundle_descriptor, + progress_frequency=None, skip_registration=False): self._controller = controller self._get_buffer = get_buffer + self._get_input_coder_impl = get_input_coder_impl self._bundle_descriptor = bundle_descriptor self._registered = skip_registration self._progress_frequency = progress_frequency - def process_bundle(self, inputs, expected_outputs, test_splits=False): + def process_bundle(self, inputs, expected_outputs): # Unique id for the instruction processing this bundle. BundleManager._uid_counter += 1 process_bundle_id = 'bundle_%s' % BundleManager._uid_counter @@ -1148,25 +1174,27 @@ class BundleManager(object): process_bundle_registration) self._registered = True - # Write all the input data to the channel. - for (transform_id, name), elements in inputs.items(): - data_out = self._controller.data_plane_handler.output_stream( - process_bundle_id, beam_fn_api_pb2.Target( - primitive_transform_reference=transform_id, name=name)) - for element_data in elements: - data_out.write(element_data) - data_out.close() - - # TODO(robertwb): Control this via a pipeline option. - if test_splits: - # Inject some splits. - random_splitter = BundleSplitter( - self._controller, - process_bundle_id, - self._bundle_descriptor.transforms.keys()) - random_splitter.start() + unique_names = set( + t.unique_name for t in self._bundle_descriptor.transforms.values()) + for stage_name, candidate in reversed(_split_managers): + if (stage_name in unique_names + or (stage_name + '/Process') in unique_names): + split_manager = candidate + break else: - random_splitter = None + split_manager = None + + if not split_manager: + # Write all the input data to the channel immediately. + for (transform_id, name), elements in inputs.items(): + data_out = self._controller.data_plane_handler.output_stream( + process_bundle_id, beam_fn_api_pb2.Target( + primitive_transform_reference=transform_id, name=name)) + for element_data in elements: + data_out.write(element_data) + data_out.close() + + split_results = [] # Actually start the bundle. if registration_future and registration_future.get().error: @@ -1179,6 +1207,64 @@ class BundleManager(object): with ProgressRequester( self._controller, process_bundle_id, self._progress_frequency): + if split_manager: + (read_transform_id, name), buffer_data = only_element(inputs.items()) + num_elements = len(list( + self._get_input_coder_impl(read_transform_id).decode_all( + b''.join(buffer_data)))) + + # Start the split manager in case it wants to set any breakpoints. + split_manager_generator = split_manager(num_elements) + try: + split_fraction = next(split_manager_generator) + done = False + except StopIteration: + done = True + + # Send all the data. + data_out = self._controller.data_plane_handler.output_stream( + process_bundle_id, + beam_fn_api_pb2.Target( + primitive_transform_reference=read_transform_id, name=name)) + data_out.write(b''.join(buffer_data)) + data_out.close() + + # Execute the requested splits. + while not done: + if split_fraction is None: + split_result = None + else: + split_request = beam_fn_api_pb2.InstructionRequest( + process_bundle_split= + beam_fn_api_pb2.ProcessBundleSplitRequest( + instruction_reference=process_bundle_id, + desired_splits={ + read_transform_id: + beam_fn_api_pb2.ProcessBundleSplitRequest.DesiredSplit( + fraction_of_remainder=split_fraction, + estimated_input_elements=num_elements) + })) + split_response = self._controller.control_handler.push( + split_request).get() + for t in (0.05, 0.1, 0.2): + waiting = ('Instruction not running', 'not yet scheduled') + if any(msg in split_response.error for msg in waiting): + time.sleep(t) + split_response = self._controller.control_handler.push( + split_request).get() + if 'Unknown process bundle' in split_response.error: + # It may have finished too fast. + split_result = None + elif split_response.error: + raise RuntimeError(split_response.error) + else: + split_result = split_response.process_bundle_split + split_results.append(split_result) + try: + split_fraction = split_manager_generator.send(split_result) + except StopIteration: + break + # Gather all output data. expected_targets = [ beam_fn_api_pb2.Target(primitive_transform_reference=transform_id, @@ -1198,12 +1284,6 @@ class BundleManager(object): logging.debug('Wait for the bundle to finish.') result = result_future.get() - if random_splitter: - random_splitter.stop() - split_results = random_splitter.split_results() - else: - split_results = [] - if result.error: raise RuntimeError(result.error) @@ -1248,47 +1328,6 @@ class ProgressRequester(threading.Thread): self._done = True -class BundleSplitter(threading.Thread): - def __init__(self, controller, instruction_id, split_transforms, - frequency=.03, split_fractions=(.5, .25, 0)): - super(BundleSplitter, self).__init__() - self._controller = controller - self._instruction_id = instruction_id - self._split_transforms = split_transforms - self._split_fractions = split_fractions - self._frequency = frequency - self._results = [] - self._done = False - - def run(self): - for fraction in self._split_fractions: - if self._done: - return - split_result = self._controller.control_handler.push( - beam_fn_api_pb2.InstructionRequest( - process_bundle_split=beam_fn_api_pb2.ProcessBundleSplitRequest( - instruction_reference=self._instruction_id, - desired_splits={ - transform_id: - beam_fn_api_pb2.ProcessBundleSplitRequest.DesiredSplit( - fraction_of_remainder=fraction) - for transform_id in self._split_transforms}))).get() - if split_result.error: - logging.info('Unable to split at %s: %s' % ( - fraction, split_result.error)) - elif split_result.process_bundle_split: - self._results.append(split_result.process_bundle_split) - time.sleep(self._frequency) - - def split_results(self): - self.stop() - self.join() - return self._results - - def stop(self): - self._done = True - - class ControlFuture(object): def __init__(self, instruction_id, response=None): self.instruction_id = instruction_id diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py index 6c4cad9..aadf4a8 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py @@ -17,13 +17,17 @@ from __future__ import absolute_import from __future__ import print_function +import collections import logging import os +import random import sys import tempfile +import threading import time import traceback import unittest +import uuid from builtins import range from tenacity import retry @@ -752,6 +756,262 @@ class FnApiRunnerTestWithBundleRepeat(FnApiRunnerTest): runner=fn_api_runner.FnApiRunner(bundle_repeat=3)) +class FnApiRunnerSplitTest(unittest.TestCase): + + def create_pipeline(self): + # Must be GRPC so we can send data and split requests concurrent + # to the bundle process request. + return beam.Pipeline( + runner=fn_api_runner.FnApiRunner( + default_environment=beam_runner_api_pb2.Environment( + urn=python_urns.EMBEDDED_PYTHON_GRPC))) + + def test_checkpoint(self): + # This split manager will get re-invoked on each smaller split, + # so N times for N elements. + element_counter = ElementCounter() + + def split_manager(num_elements): + # Send at least one element so it can make forward progress. + element_counter.reset() + breakpoint = element_counter.set_breakpoint(1) + # Cede control back to the runner so data can be sent. + yield + breakpoint.wait() + # Split as close to current as possible. + split_result = yield 0.0 + # Verify we split at exactly the first element. + self.verify_channel_split(split_result, 0, 1) + # Continue processing. + breakpoint.clear() + + self.run_split_pipeline(split_manager, list('abc'), element_counter) + + def test_split_half(self): + total_num_elements = 25 + seen_bundle_sizes = [] + element_counter = ElementCounter() + + def split_manager(num_elements): + seen_bundle_sizes.append(num_elements) + if num_elements == total_num_elements: + element_counter.reset() + breakpoint = element_counter.set_breakpoint(5) + yield + breakpoint.wait() + # Split the remainder (20, then 10, elements) in half. + split1 = yield 0.5 + self.verify_channel_split(split1, 14, 15) # remainder is 15 to end + split2 = yield 0.5 + self.verify_channel_split(split2, 9, 10) # remainder is 10 to end + breakpoint.clear() + + self.run_split_pipeline( + split_manager, range(total_num_elements), element_counter) + self.assertEqual([25, 15], seen_bundle_sizes) + + def run_split_pipeline(self, split_manager, elements, element_counter=None): + with fn_api_runner.split_manager('Identity', split_manager): + with self.create_pipeline() as p: + res = (p + | beam.Create(elements) + | beam.Reshuffle() + | 'Identity' >> beam.Map(lambda x: x) + | beam.Map(lambda x: element_counter.increment() or x)) + assert_that(res, equal_to(elements)) + + def test_nosplit_sdf(self): + def split_manager(num_elements): + yield + + elements = [1, 2, 3] + expected_groups = [[(e, k) for k in range(e)] for e in elements] + self.run_sdf_split_pipeline( + split_manager, elements, ElementCounter(), expected_groups) + + def test_checkpoint_sdf(self): + element_counter = ElementCounter() + + def split_manager(num_elements): + element_counter.reset() + breakpoint = element_counter.set_breakpoint(1) + yield + breakpoint.wait() + yield 0 + breakpoint.clear() + + # Everything should be perfectly split. + elements = [2, 3] + expected_groups = [[(2, 0)], [(2, 1)], [(3, 0)], [(3, 1)], [(3, 2)]] + self.run_sdf_split_pipeline( + split_manager, elements, element_counter, expected_groups) + + def test_split_half_sdf(self): + + element_counter = ElementCounter() + is_first_bundle = [True] # emulate nonlocal for Python 2 + + def split_manager(num_elements): + if is_first_bundle: + del is_first_bundle[:] + breakpoint = element_counter.set_breakpoint(1) + yield + breakpoint.wait() + split1 = yield 0.5 + split2 = yield 0.5 + split3 = yield 0.5 + self.verify_channel_split(split1, 0, 1) + self.verify_channel_split(split2, -1, 1) + self.verify_channel_split(split3, -1, 1) + breakpoint.clear() + + elements = [4, 4] + expected_groups = [ + [(4, 0)], + [(4, 1)], + [(4, 2), (4, 3)], + [(4, 0), (4, 1), (4, 2), (4, 3)]] + + self.run_sdf_split_pipeline( + split_manager, elements, element_counter, expected_groups) + + def test_split_crazy_sdf(self, seed=None): + if seed is None: + seed = random.randrange(1 << 20) + r = random.Random(seed) + element_counter = ElementCounter() + + def split_manager(num_elements): + element_counter.reset() + wait_for = r.randrange(num_elements) + breakpoint = element_counter.set_breakpoint(wait_for) + yield + breakpoint.wait() + yield r.random() + yield r.random() + breakpoint.clear() + + try: + elements = [r.randrange(5, 10) for _ in range(5)] + self.run_sdf_split_pipeline(split_manager, elements, element_counter) + except Exception: + logging.error('test_split_crazy_sdf.seed = %s', seed) + raise + + def run_sdf_split_pipeline( + self, split_manager, elements, element_counter, expected_groups=None): + # Define an SDF that for each input x produces [(x, k) for k in range(x)]. + + class EnumerateProvider(beam.transforms.core.RestrictionProvider): + def initial_restriction(self, element): + return (0, element) + + def create_tracker(self, restriction): + return restriction_trackers.OffsetRestrictionTracker( + *restriction) + + def split(self, element, restriction): + # Don't do any initial splitting to simplify test. + return [restriction] + + class EnumerateSdf(beam.DoFn): + def process(self, element, restriction_tracker=EnumerateProvider()): + to_emit = [] + for k in range(*restriction_tracker.current_restriction()): + if restriction_tracker.try_claim(k): + to_emit.append((element, k)) + element_counter.increment() + else: + break + # Emitting in batches for tighter testing. + yield to_emit + + expected = [(e, k) for e in elements for k in range(e)] + + with fn_api_runner.split_manager('SDF', split_manager): + with self.create_pipeline() as p: + grouped = ( + p + | beam.Create(elements) + | 'SDF' >> beam.ParDo(EnumerateSdf())) + flat = grouped | beam.FlatMap(lambda x: x) + assert_that(flat, equal_to(expected)) + if expected_groups: + assert_that(grouped, equal_to(expected_groups), label='CheckGrouped') + + def verify_channel_split(self, split_result, last_primary, first_residual): + self.assertEqual(1, len(split_result.channel_splits), split_result) + channel_split, = split_result.channel_splits + self.assertEqual(last_primary, channel_split.last_primary_element) + self.assertEqual(first_residual, channel_split.first_residual_element) + # There should be a primary and residual application for each element + # not covered above. + self.assertEqual( + first_residual - last_primary - 1, + len(split_result.primary_roots), + split_result.primary_roots) + self.assertEqual( + first_residual - last_primary - 1, + len(split_result.residual_roots), + split_result.residual_roots) + + +class ElementCounter(object): + """Used to wait until a certain number of elements are seen.""" + + def __init__(self): + self._cv = threading.Condition() + self.reset() + + def reset(self): + with self._cv: + self._breakpoints = collections.defaultdict(list) + self._count = 0 + + def increment(self): + with self._cv: + self._count += 1 + self._cv.notify_all() + breakpoints = list(self._breakpoints[self._count]) + for breakpoint in breakpoints: + breakpoint.wait() + + def set_breakpoint(self, value): + with self._cv: + event = threading.Event() + self._breakpoints[value].append(event) + + class Breakpoint(object): + @staticmethod + def wait(timeout=10): + with self._cv: + start = time.time() + while self._count < value: + elapsed = time.time() - start + if elapsed > timeout: + raise RuntimeError('Timed out waiting for %s' % value) + self._cv.wait(timeout - elapsed) + + @staticmethod + def clear(): + event.set() + + return Breakpoint() + + def __reduce__(self): + # Ensure we get the same element back through a pickling round-trip. + name = uuid.uuid4().hex + _pickled_element_counters[name] = self + return _unpickle_element_counter, (name,) + + +_pickled_element_counters = {} + + +def _unpickle_element_counter(name): + return _pickled_element_counters[name] + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py index 2482987..8667a8e 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py @@ -287,9 +287,9 @@ class TransformContext(object): self.components = components self.known_runner_urns = known_runner_urns self.use_state_iterables = use_state_iterables - self.safe_coders = {} self.bytes_coder_id = self.add_or_get_coder_id( coders.BytesCoder().to_runner_api(None), 'bytes_coder') + self.safe_coders = {self.bytes_coder_id: self.bytes_coder_id} def add_or_get_coder_id(self, coder_proto, coder_prefix='coder'): for coder_id, coder in self.components.coders.items(): diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index db2d790..0e1782d 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -138,28 +138,39 @@ class DataInputOperation(RunnerIOOperation): input_stream, True) self.output(decoded_value) - def try_split(self, fraction_of_remainder, total_buffer_size=None): + def try_split(self, fraction_of_remainder, total_buffer_size): with self.splitting_lock: - # If total_buffer_size is not provided, pick something. - if not total_buffer_size: - total_buffer_size = self.index + 2 + if total_buffer_size < self.index + 1: + total_buffer_size = self.index + 1 elif self.stop and total_buffer_size > self.stop: total_buffer_size = self.stop - # Compute, as a fraction, how much further to go. - # TODO(SDF): Take into account progress of current element. - stop_offset = (total_buffer_size - self.index) * fraction_of_remainder - # If it's less than a whole element, try splitting the current element. - if int(stop_offset) == 0: - split = self.receivers[0].try_split(stop_offset) - if split: - element_primary, element_residual = split - self.stop = self.index + 1 - return self.stop - 2, element_primary, element_residual, self.stop - + if self.index == -1: + # We are "finished" with the (non-existent) previous element. + current_element_progress = 1 + else: + # TODO(SDF): Get actual progress of current element. + current_element_progress = 0.5 + # Now figure out where to split. + # The units here (except for keep_of_element_remainder) are all in + # terms of number of (possibly fractional) elements. + remainder = total_buffer_size - self.index - current_element_progress + keep = remainder * fraction_of_remainder + if current_element_progress < 1: + keep_of_element_remainder = keep / (1 - current_element_progress) + # If it's less than what's left of the current element, + # try splitting at the current element. + if keep_of_element_remainder < 1: + split = self.receivers[0].try_split(keep_of_element_remainder) + if split: + element_primary, element_residual = split + self.stop = self.index + 1 + return self.index - 1, element_primary, element_residual, self.stop # Otherwise, split at the closest element boundary. - desired_stop = max(int(stop_offset), 1) + self.index - if desired_stop < self.stop: - self.stop = desired_stop + # pylint: disable=round-builtin + stop_index = ( + self.index + max(1, int(round(current_element_progress + keep)))) + if stop_index < self.stop: + self.stop = stop_index return self.stop - 1, None, None, self.stop diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index fc8f9cc..276ca19 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -211,6 +211,8 @@ class _GrpcDataChannel(DataChannel): try: data = received.get(timeout=1) except queue.Empty: + if self._closed: + raise RuntimeError('Channel closed prematurely.') if abort_callback(): return if self._exc_info: @@ -275,6 +277,7 @@ class _GrpcDataChannel(DataChannel): self._exc_info = sys.exc_info() raise finally: + self._closed = True self._reads_finished.set() def _start_reader(self, elements_iterator): diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index 1528d23..cb981ec 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -312,11 +312,14 @@ class SdkWorker(object): def process_bundle_split(self, request, instruction_id): processor = self.active_bundle_processors.get(request.instruction_reference) - if not processor: - raise ValueError('Instruction not running: %s' % instruction_id) - return beam_fn_api_pb2.InstructionResponse( - instruction_id=instruction_id, - process_bundle_split=processor.try_split(request)) + if processor: + return beam_fn_api_pb2.InstructionResponse( + instruction_id=instruction_id, + process_bundle_split=processor.try_split(request)) + else: + return beam_fn_api_pb2.InstructionResponse( + instruction_id=instruction_id, + error='Instruction not running: %s' % instruction_id) def process_bundle_progress(self, request, instruction_id): # It is an error to get progress for a not-in-flight bundle.