This is an automated email from the ASF dual-hosted git repository. robertwb pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new bfc37eb [BEAM-3837] Handle BundleSplitRequests in Python SDK Harness. new f45d674 Merge pull request #7759 [BEAM-3837] Handle BundleSplitRequests in Python SDK Harness. bfc37eb is described below commit bfc37ebd6e6940858b789750d07f4e4b654a6e99 Author: Robert Bradshaw <rober...@google.com> AuthorDate: Mon Jan 28 16:07:21 2019 +0100 [BEAM-3837] Handle BundleSplitRequests in Python SDK Harness. --- .../fn-execution/src/main/proto/beam_fn_api.proto | 57 ++++++++ sdks/python/apache_beam/io/restriction_trackers.py | 19 ++- sdks/python/apache_beam/runners/common.py | 17 +++ .../runners/portability/fn_api_runner.py | 144 ++++++++++++++++++--- .../apache_beam/runners/worker/bundle_processor.py | 70 +++++++++- .../apache_beam/runners/worker/operations.pxd | 4 + .../apache_beam/runners/worker/operations.py | 56 +++++++- .../apache_beam/runners/worker/sdk_worker.py | 14 ++ 8 files changed, 355 insertions(+), 26 deletions(-) diff --git a/model/fn-execution/src/main/proto/beam_fn_api.proto b/model/fn-execution/src/main/proto/beam_fn_api.proto index d681bc1..9a1c86a 100644 --- a/model/fn-execution/src/main/proto/beam_fn_api.proto +++ b/model/fn-execution/src/main/proto/beam_fn_api.proto @@ -746,6 +746,31 @@ message ProcessBundleSplitRequest { // If the backlog is unspecified for a PTransform, the runner would like // the SDK to process all data received for that PTransform. map<string, bytes> backlog_remaining = 2; + + // A message specifying the desired split for a single transform. + message DesiredSplit { + // (Required) The fraction of known work remaining in this bundle + // for this transform that should be kept by the SDK after this split. + // + // Set to 0 to "checkpoint" as soon as possible (keeping as little work as + // possible and returning the remainder). 
+ float fraction_of_remainder = 1; + + // (Required for GrpcRead operations) Number of total elements expected + // to be sent to this GrpcRead operation, required to correctly account + // for unreceived data when determining where to split. + int64 estimated_input_elements = 2; + + // TODO(SDF): Allow providing weights rather than sizes. + // TODO(SDF): Allow specifying allowed/preferred split points. + } + + // (Required) Specifies the desired split for each transform. + // + // Currently only splits at GRPC read operations are supported. + // This may, of course, limit the amount of work downstream operations + // receive. + map<string, DesiredSplit> desired_splits = 3; } // Represents a partition of the bundle: a "primary" and @@ -765,8 +790,40 @@ message ProcessBundleSplitResponse { // have to be executed in a separate bundle (e.g. in parallel on a different // worker, or after the current bundle completes, etc.) repeated DelayedBundleApplication residual_roots = 2; + + // Represents contiguous portions of the data channel that are either + // entirely processed or entirely unprocessed and belong to the primary + // or residual respectively. + // + // This affords both a more efficient representation over the FnAPI + // (if the bundle is large) and often a more efficient representation + // on the runner side (e.g. if the set of elements can be represented + // as some range in an underlying dataset). + message ChannelSplit { + // (Required) The grpc read transform reading this channel. + string ptransform_id = 1; + + // (Required) Name of the transform's input to which to pass the element. + string input_id = 2; + + // The last element of the input channel that should be entirely considered + // part of the primary, identified by its absolute index in the (ordered) + // channel. 
+ int32 last_primary_element = 3; + + // The first element of the input channel that should be entirely considered + // part of the residual, identified by its absolute index in the (ordered) + // channel. + int32 first_residual_element = 4; + } + + // Partitions of input data channels into primary and residual elements, + // if any. Should not include any elements represented in the bundle + // application roots above. + repeated ChannelSplit channel_splits = 3; } + message FinalizeBundleRequest { // (Required) A reference to a completed process bundle request with the given // instruction id. diff --git a/sdks/python/apache_beam/io/restriction_trackers.py b/sdks/python/apache_beam/io/restriction_trackers.py index e72d508..c165d95 100644 --- a/sdks/python/apache_beam/io/restriction_trackers.py +++ b/sdks/python/apache_beam/io/restriction_trackers.py @@ -79,6 +79,7 @@ class OffsetRestrictionTracker(RestrictionTracker): def __init__(self, start_position, stop_position): self._range = OffsetRange(start_position, stop_position) self._current_position = None + self._current_watermark = None self._last_claim_attempt = None self._deferred_residual = None self._checkpointed = False @@ -98,6 +99,9 @@ class OffsetRestrictionTracker(RestrictionTracker): with self._lock: return (self._range.start, self._range.stop) + def current_watermark(self): + return self._current_watermark + def start_position(self): with self._lock: return self._range.start @@ -127,6 +131,19 @@ class OffsetRestrictionTracker(RestrictionTracker): return False + def try_split(self, fraction): + with self._lock: + if not self._checkpointed: + if self._current_position is None: + cur = self._range.start - 1 + else: + cur = self._current_position + split_point = cur + int(max(1, (self._range.stop - cur) * fraction)) + if split_point < self._range.stop: + prev_stop, self._range.stop = self._range.stop, split_point + return (self._range.start, split_point), (split_point, prev_stop) + + # TODO(SDF): Replace
all calls with try_claim(0). def checkpoint(self): with self._lock: # If self._current_position is 'None' no records have been claimed so @@ -143,7 +160,7 @@ class OffsetRestrictionTracker(RestrictionTracker): def defer_remainder(self, watermark=None): with self._lock: - self._deferred_watermark = watermark + self._deferred_watermark = watermark or self._current_watermark self._deferred_residual = self.checkpoint() def deferred_status(self): diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py index 3d9b07f..efdb59f 100644 --- a/sdks/python/apache_beam/runners/common.py +++ b/sdks/python/apache_beam/runners/common.py @@ -633,6 +633,20 @@ class PerWindowInvoker(DoFnInvoker): (windowed_value.value, deferred_restriction)), deferred_watermark) + def try_split(self, fraction): + restriction_tracker = self.restriction_tracker + current_windowed_value = self.current_windowed_value + if restriction_tracker and current_windowed_value: + split = restriction_tracker.try_split(fraction) + if split: + primary, residual = split + element = self.current_windowed_value.value + return ( + (self.current_windowed_value.with_value((element, primary)), + None), + (self.current_windowed_value.with_value((element, residual)), + restriction_tracker.current_watermark())) + class DoFnRunner(Receiver): """For internal use only; no backwards-compatibility guarantees. 
@@ -721,6 +735,9 @@ class DoFnRunner(Receiver): restriction_tracker=self.do_fn_invoker.invoke_create_tracker( restriction)) + def try_split(self, fraction): + return self.do_fn_invoker.try_split(fraction) + def process_user_timer(self, timer_spec, key, window, timestamp): try: self.do_fn_invoker.invoke_user_timer(timer_spec, key, window, timestamp) diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index e908a5c..5f8fa3b 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -495,11 +495,21 @@ class FnApiRunner(runner.PipelineRunner): finally: controller.state.restore() - result = BundleManager( + result, splits = BundleManager( controller, get_buffer, process_bundle_descriptor, self._progress_frequency).process_bundle( data_input, data_output) + def input_for(ptransform_id, input_id): + input_pcoll = process_bundle_descriptor.transforms[ + ptransform_id].inputs[input_id] + for read_id, proto in process_bundle_descriptor.transforms.items(): + if (proto.spec.urn == bundle_processor.DATA_INPUT_URN + and input_pcoll in proto.outputs.values()): + return read_id, 'out' + raise RuntimeError( + 'No IO transform feeds %s' % ptransform_id) + last_result = result while True: deferred_inputs = collections.defaultdict(list) @@ -530,21 +540,60 @@ class FnApiRunner(runner.PipelineRunner): deferred_inputs[transform_id, 'out'] = [out.get()] written_timers[:] = [] - # Queue any delayed bundle applications. + # Queue any process-initiated delayed bundle applications. for delayed_application in last_result.process_bundle.residual_roots: - # Find the io transform that feeds this transform. - # TODO(SDF): Memoize? 
- application = delayed_application.application - input_pcoll = process_bundle_descriptor.transforms[ - application.ptransform_id].inputs[application.input_id] - for input_id, proto in process_bundle_descriptor.transforms.items(): - if (proto.spec.urn == bundle_processor.DATA_INPUT_URN - and input_pcoll in proto.outputs.values()): - deferred_inputs[input_id, 'out'].append(application.element) - break - else: - raise RuntimeError( - 'No IO transform feeds %s' % application.ptransform_id) + deferred_inputs[ + input_for( + delayed_application.application.ptransform_id, + delayed_application.application.input_id) + ].append(delayed_application.application.element) + + # Queue any runner-initiated delayed bundle applications. + prev_stops = collections.defaultdict(lambda: float('inf')) + for split in splits: + for delayed_application in split.residual_roots: + deferred_inputs[ + input_for( + delayed_application.application.ptransform_id, + delayed_application.application.input_id) + ].append(delayed_application.application.element) + for channel_split in split.channel_splits: + transform = process_bundle_descriptor.transforms[ + channel_split.ptransform_id] + coder_id = beam_fn_api_pb2.RemoteGrpcPort.FromString( + transform.spec.payload).coder_id + coder_impl = context.coders[safe_coders[coder_id]].get_impl() + # TODO(SDF): This requires deterministic ordering of buffer iteration. + # TODO(SDF): The returned split is in terms of indices. Ideally, + # a runner could map these back to actual positions to effectively + # describe the two "halves" of the now-split range. Even if we have + # to buffer each element we send (or at the very least a bit of + # metadata, like position, about each of them) this should be doable + # if they're already in memory and we are bounding the buffer size + # (e.g. to 10mb plus whatever is eagerly read from the SDK).
In the + # case of non-split-points, we can either immediately replay the + # "non-split-position" elements or record them as we do the other + # delayed applications. + + # Decode and recode to split the encoded buffer by element index. + buffer = data_input[ + channel_split.ptransform_id, channel_split.input_id] + input_stream = create_InputStream(''.join(buffer)) + output_stream = create_OutputStream() + index = 0 + prev_stop = prev_stops[channel_split.ptransform_id] + while input_stream.size() > 0: + if index > prev_stop: + break + element = coder_impl.decode_from_stream(input_stream, True) + if index >= channel_split.first_residual_element: + coder_impl.encode_to_stream(element, output_stream, True) + index += 1 + deferred_inputs[ + channel_split.ptransform_id, channel_split.input_id].append( + output_stream.get()) + prev_stops[ + channel_split.ptransform_id] = channel_split.last_primary_element if deferred_inputs: # The worker will be waiting on these inputs as well. @@ -552,7 +601,7 @@ class FnApiRunner(runner.PipelineRunner): if other_input not in deferred_inputs: deferred_inputs[other_input] = [] # TODO(robertwb): merge results - last_result = BundleManager( + last_result, splits = BundleManager( controller, get_buffer, process_bundle_descriptor, @@ -1083,7 +1132,7 @@ class BundleManager(object): self._registered = skip_registration self._progress_frequency = progress_frequency - def process_bundle(self, inputs, expected_outputs): + def process_bundle(self, inputs, expected_outputs, test_splits=False): # Unique id for the instruction processing this bundle. BundleManager._uid_counter += 1 process_bundle_id = 'bundle_%s' % BundleManager._uid_counter @@ -1108,6 +1157,17 @@ class BundleManager(object): data_out.write(element_data) data_out.close() + # TODO(robertwb): Control this via a pipeline option. + if test_splits: + # Inject some splits. 
+ random_splitter = BundleSplitter( + self._controller, + process_bundle_id, + self._bundle_descriptor.transforms.keys()) + random_splitter.start() + else: + random_splitter = None + # Actually start the bundle. if registration_future and registration_future.get().error: raise RuntimeError(registration_future.get().error) @@ -1138,9 +1198,16 @@ class BundleManager(object): logging.debug('Wait for the bundle to finish.') result = result_future.get() + if random_splitter: + random_splitter.stop() + split_results = random_splitter.split_results() + else: + split_results = [] + if result.error: raise RuntimeError(result.error) - return result + + return result, split_results class ProgressRequester(threading.Thread): @@ -1181,6 +1248,47 @@ class ProgressRequester(threading.Thread): self._done = True +class BundleSplitter(threading.Thread): + def __init__(self, controller, instruction_id, split_transforms, + frequency=.03, split_fractions=(.5, .25, 0)): + super(BundleSplitter, self).__init__() + self._controller = controller + self._instruction_id = instruction_id + self._split_transforms = split_transforms + self._split_fractions = split_fractions + self._frequency = frequency + self._results = [] + self._done = False + + def run(self): + for fraction in self._split_fractions: + if self._done: + return + split_result = self._controller.control_handler.push( + beam_fn_api_pb2.InstructionRequest( + process_bundle_split=beam_fn_api_pb2.ProcessBundleSplitRequest( + instruction_reference=self._instruction_id, + desired_splits={ + transform_id: + beam_fn_api_pb2.ProcessBundleSplitRequest.DesiredSplit( + fraction_of_remainder=fraction) + for transform_id in self._split_transforms}))).get() + if split_result.error: + logging.info('Unable to split at %s: %s' % ( + fraction, split_result.error)) + elif split_result.process_bundle_split: + self._results.append(split_result.process_bundle_split) + time.sleep(self._frequency) + + def split_results(self): + self.stop() + self.join() 
+ return self._results + + def stop(self): + self._done = True + + class ControlFuture(object): def __init__(self, instruction_id, response=None): self.instruction_id = instruction_id diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index 6cffc02..db2d790 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -27,6 +27,7 @@ import json import logging import random import re +import threading from builtins import next from builtins import object @@ -113,9 +114,15 @@ class DataInputOperation(RunnerIOOperation): windowed_coder, target=input_target, data_channel=data_channel) # We must do this manually as we don't have a spec or spec.output_coders. self.receivers = [ - operations.ConsumerSet( + operations.ConsumerSet.create( self.counter_factory, self.name_context.step_name, 0, next(iter(itervalues(consumers))), self.windowed_coder)] + self.splitting_lock = threading.Lock() + + def start(self): + super(DataInputOperation, self).start() + self.index = -1 + self.stop = float('inf') def process(self, windowed_value): self.output(windowed_value) @@ -123,10 +130,38 @@ class DataInputOperation(RunnerIOOperation): def process_encoded(self, encoded_windowed_values): input_stream = coder_impl.create_InputStream(encoded_windowed_values) while input_stream.size() > 0: + with self.splitting_lock: + if self.index == self.stop - 1: + return + self.index += 1 decoded_value = self.windowed_coder_impl.decode_from_stream( input_stream, True) self.output(decoded_value) + def try_split(self, fraction_of_remainder, total_buffer_size=None): + with self.splitting_lock: + # If total_buffer_size is not provided, pick something. + if not total_buffer_size: + total_buffer_size = self.index + 2 + elif self.stop and total_buffer_size > self.stop: + total_buffer_size = self.stop + # Compute, as a fraction, how much further to go. 
+ # TODO(SDF): Take into account progress of current element. + stop_offset = (total_buffer_size - self.index) * fraction_of_remainder + # If it's less than a whole element, try splitting the current element. + if int(stop_offset) == 0: + split = self.receivers[0].try_split(stop_offset) + if split: + element_primary, element_residual = split + self.stop = self.index + 1 + return self.stop - 2, element_primary, element_residual, self.stop + + # Otherwise, split at the closest element boundary. + desired_stop = max(int(stop_offset), 1) + self.index + if desired_stop < self.stop: + self.stop = desired_stop + return self.stop - 1, None, None, self.stop + class _StateBackedIterable(object): def __init__(self, state_handler, state_key, coder_or_impl): @@ -413,6 +448,7 @@ class BundleProcessor(object): self.ops = self.create_execution_tree(self.process_bundle_descriptor) for op in self.ops.values(): op.setup() + self.splitting_lock = threading.Lock() def create_execution_tree(self, descriptor): @@ -509,8 +545,40 @@ class BundleProcessor(object): for op, residual in execution_context.delayed_applications] finally: + # Ensure any in-flight split attempts complete. 
+ with self.splitting_lock: + pass self.state_sampler.stop_if_still_running() + def try_split(self, bundle_split_request): + split_response = beam_fn_api_pb2.ProcessBundleSplitResponse() + with self.splitting_lock: + for op in self.ops.values(): + if isinstance(op, DataInputOperation): + desired_split = bundle_split_request.desired_splits.get( + op.target.primitive_transform_reference) + if desired_split: + split = op.try_split(desired_split.fraction_of_remainder, + desired_split.estimated_input_elements) + if split: + (primary_end, element_primary, element_residual, residual_start, + ) = split + if element_primary: + split_response.primary_roots.add().CopyFrom( + self.delayed_bundle_application( + *element_primary).application) + if element_residual: + split_response.residual_roots.add().CopyFrom( + self.delayed_bundle_application(*element_residual)) + split_response.channel_splits.extend([ + beam_fn_api_pb2.ProcessBundleSplitResponse.ChannelSplit( + ptransform_id=op.target.primitive_transform_reference, + input_id=op.target.name, + last_primary_element=primary_end, + first_residual_element=residual_start)]) + + return split_response + def delayed_bundle_application(self, op, deferred_remainder): ptransform_id, main_input_tag, main_input_coder, outputs = op.input_info # TODO(SDF): For non-root nodes, need main_input_coder + residual_coder. 
diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd index 10c3c41..9f0c015 100644 --- a/sdks/python/apache_beam/runners/worker/operations.pxd +++ b/sdks/python/apache_beam/runners/worker/operations.pxd @@ -38,6 +38,10 @@ cdef class ConsumerSet(Receiver): cpdef update_counters_finish(self) +cdef class SingletonConsumerSet(ConsumerSet): + cdef Operation consumer + + cdef class Operation(object): cdef readonly name_context cdef readonly operation_name diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py index a6e0c31..c7c767f 100644 --- a/sdks/python/apache_beam/runners/worker/operations.py +++ b/sdks/python/apache_beam/runners/worker/operations.py @@ -73,6 +73,14 @@ class ConsumerSet(Receiver): the other edge. ConsumerSet are attached to the outputting Operation. """ + @staticmethod + def create(counter_factory, step_name, output_index, consumers, coder): + if len(consumers) == 1: + return SingletonConsumerSet( + counter_factory, step_name, output_index, consumers, coder) + else: + return ConsumerSet( + counter_factory, step_name, output_index, consumers, coder) def __init__( self, counter_factory, step_name, output_index, consumers, coder): @@ -90,6 +98,14 @@ class ConsumerSet(Receiver): cython.cast(Operation, consumer).process(windowed_value) self.update_counters_finish() + def try_split(self, fraction_of_remainder): + # TODO(SDF): Consider supporting splitting each consumer individually. + # This would never come up in the existing SDF expansion, but might + # be useful to support fused SDF nodes. + # This would require dedicated delivery of the split results to each + # of the consumers separately. 
+ return None + def update_counters_start(self, windowed_value): self.opcounter.update_from(windowed_value) @@ -102,6 +118,23 @@ class ConsumerSet(Receiver): len(self.consumers)) +class SingletonConsumerSet(ConsumerSet): + def __init__( + self, counter_factory, step_name, output_index, consumers, coder): + assert len(consumers) == 1 + super(SingletonConsumerSet, self).__init__( + counter_factory, step_name, output_index, consumers, coder) + self.consumer = consumers[0] + + def receive(self, windowed_value): + self.update_counters_start(windowed_value) + self.consumer.process(windowed_value) + self.update_counters_finish() + + def try_split(self, fraction_of_remainder): + return self.consumer.try_split(fraction_of_remainder) + + class Operation(object): """An operation representing the live version of a work item specification. @@ -157,11 +190,13 @@ class Operation(object): # top-level operation, should have output_coders #TODO(pabloem): Define better what step name is used here. if getattr(self.spec, 'output_coders', None): - self.receivers = [ConsumerSet(self.counter_factory, - self.name_context.logging_name(), - i, - self.consumers[i], coder) - for i, coder in enumerate(self.spec.output_coders)] + self.receivers = [ + ConsumerSet.create( + self.counter_factory, + self.name_context.logging_name(), + i, + self.consumers[i], coder) + for i, coder in enumerate(self.spec.output_coders)] self.setup_done = True def start(self): @@ -174,6 +209,9 @@ class Operation(object): """Process element in operation.""" pass + def try_split(self, fraction_of_remainder): + return None + def finish(self): """Finish operation.""" pass @@ -327,7 +365,7 @@ class ImpulseReadOperation(Operation): name_context, None, counter_factory, state_sampler) self.source = source self.receivers = [ - ConsumerSet( + ConsumerSet.create( self.counter_factory, self.name_context.step_name, 0, next(iter(consumers.values())), output_coder)] @@ -553,6 +591,12 @@ class SdfProcessElements(DoOperation): 
self.execution_context.delayed_applications.append( (self, delayed_application)) + def try_split(self, fraction_of_remainder): + split = self.dofn_runner.try_split(fraction_of_remainder) + if split: + primary, residual = split + return (self, primary), (self, residual) + class DoFnRunnerReceiver(Receiver): diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index 6067181..1528d23 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -197,7 +197,13 @@ class SdkHarness(object): logging.debug( "Currently using %s threads." % len(self._process_thread_pool._threads)) + def _request_process_bundle_split(self, request): + self._request_process_bundle_action(request) + def _request_process_bundle_progress(self, request): + self._request_process_bundle_action(request) + + def _request_process_bundle_action(self, request): def task(): instruction_reference = getattr( @@ -304,6 +310,14 @@ class SdkWorker(object): processor.reset() self.cached_bundle_processors[bundle_descriptor_id].append(processor) + def process_bundle_split(self, request, instruction_id): + processor = self.active_bundle_processors.get(request.instruction_reference) + if not processor: + raise ValueError('Instruction not running: %s' % instruction_id) + return beam_fn_api_pb2.InstructionResponse( + instruction_id=instruction_id, + process_bundle_split=processor.try_split(request)) + def process_bundle_progress(self, request, instruction_id): # It is an error to get progress for a not-in-flight bundle. processor = self.active_bundle_processors.get(request.instruction_reference)