This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
     new e933ddc  Adding helper functions, and documentation to fn_api_runner classes. (#8793)
e933ddc is described below

commit e933ddca967f897b4256e9689d7699fc9031df62
Author: Pablo <pabl...@users.noreply.github.com>
AuthorDate: Thu Jun 13 12:34:25 2019 -0700

    Adding helper functions, and documentation to fn_api_runner classes. (#8793)
---
 .../runners/portability/fn_api_runner.py           | 464 +++++++++++++--------
 .../runners/portability/fn_api_runner_test.py      |  31 ++
 .../apache_beam/runners/worker/bundle_processor.py |   4 +-
 3 files changed, 326 insertions(+), 173 deletions(-)

diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index 1fbb0ed..66aedfc 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -173,6 +173,10 @@ class _GroupingBuffer(object):
       windowed_key_values = lambda key, values: [
           globally_window((key, values))]
     else:
+      # TODO(pabloem, BEAM-7514): The trigger driver needs access to the
+      # clock. Note that this only comes through if windowing is default;
+      # but what about having multiple firings on the global window?
+      # May need to revise.
       trigger_driver = trigger.create_trigger_driver(self._windowing, True)
       windowed_key_values = trigger_driver.process_entire_key
     coder_impl = self._post_grouped_coder.get_impl()
@@ -291,7 +295,10 @@ class FnApiRunner(runner.PipelineRunner):
     return self._latest_run_result
 
   def run_via_runner_api(self, pipeline_proto):
-    return self.run_stages(*self.create_stages(pipeline_proto))
+    stage_context, stages = self.create_stages(pipeline_proto)
+    # TODO(pabloem, BEAM-7514): Create a watermark manager (that has access
+    # to the teststream (if any), and all the stages).
+    return self.run_stages(stage_context, stages)
 
   @contextlib.contextmanager
   def maybe_profile(self):
@@ -353,7 +360,7 @@ class FnApiRunner(runner.PipelineRunner):
         use_state_iterables=self._use_state_iterables)
 
   def run_stages(self, stage_context, stages):
-    """Run all of the stages.
+    """Run a list of topologically-sorted stages in batch mode.
 
     Args:
       stage_context (fn_api_runner_transforms.TransformContext)
@@ -382,6 +389,156 @@
     return RunnerResult(
         runner.PipelineState.DONE, monitoring_infos_by_stage,
         metrics_by_stage)
 
+  def _store_side_inputs_in_state(self,
+                                  controller,
+                                  context,
+                                  pipeline_components,
+                                  data_side_input,
+                                  pcoll_buffers,
+                                  safe_coders):
+    for (transform_id, tag), (buffer_id, si) in data_side_input.items():
+      _, pcoll_id = split_buffer_id(buffer_id)
+      value_coder = context.coders[safe_coders[
+          pipeline_components.pcollections[pcoll_id].coder_id]]
+      elements_by_window = _WindowGroupingBuffer(si, value_coder)
+      for element_data in pcoll_buffers[buffer_id]:
+        elements_by_window.append(element_data)
+      for key, window, elements_data in elements_by_window.encoded_items():
+        state_key = beam_fn_api_pb2.StateKey(
+            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
+                ptransform_id=transform_id,
+                side_input_id=tag,
+                window=window,
+                key=key))
+        controller.state.blocking_append(state_key, elements_data)
+
+  def _run_bundle_multiple_times_for_testing(self,
+                                             controller,
+                                             process_bundle_descriptor,
+                                             data_input,
+                                             data_output,
+                                             get_input_coder_callable):
+    for k in range(self._bundle_repeat):
+      try:
+        controller.state.checkpoint()
+        BundleManager(
+            controller, lambda pcoll_id: [], get_input_coder_callable,
+            process_bundle_descriptor, self._progress_frequency, k
+        ).process_bundle(data_input, data_output)
+      finally:
+        controller.state.restore()
+
+  def _collect_written_timers_and_add_to_deferred_inputs(self,
+                                                         context,
+                                                         pipeline_components,
+                                                         stage,
+                                                         get_buffer_callable,
+                                                         deferred_inputs):
+
+    for transform_id, timer_writes in stage.timer_pcollections:
+
+      # Queue any set timers as new inputs.
+      windowed_timer_coder_impl = context.coders[
+          pipeline_components.pcollections[timer_writes].coder_id].get_impl()
+      written_timers = get_buffer_callable(
+          create_buffer_id(timer_writes, kind='timers'))
+      if written_timers:
+        # Keep only the "last" timer set per key and window.
+        timers_by_key_and_window = {}
+        for elements_data in written_timers:
+          input_stream = create_InputStream(elements_data)
+          while input_stream.size() > 0:
+            windowed_key_timer = windowed_timer_coder_impl.decode_from_stream(
+                input_stream, True)
+            key, _ = windowed_key_timer.value
+            # TODO: Explode and merge windows.
+            assert len(windowed_key_timer.windows) == 1
+            timers_by_key_and_window[
+                key, windowed_key_timer.windows[0]] = windowed_key_timer
+        out = create_OutputStream()
+        for windowed_key_timer in timers_by_key_and_window.values():
+          windowed_timer_coder_impl.encode_to_stream(
+              windowed_key_timer, out, True)
+        deferred_inputs[transform_id] = [out.get()]
+        written_timers[:] = []
+
+  def _add_residuals_and_channel_splits_to_deferred_inputs(
+      self, splits, get_input_coder_callable,
+      input_for_callable, last_sent, deferred_inputs):
+    prev_stops = {}
+    for split in splits:
+      for delayed_application in split.residual_roots:
+        deferred_inputs[
+            input_for_callable(
+                delayed_application.application.ptransform_id,
+                delayed_application.application.input_id)
+        ].append(delayed_application.application.element)
+      for channel_split in split.channel_splits:
+        coder_impl = get_input_coder_callable(channel_split.ptransform_id)
+        # TODO(SDF): This requires deterministic ordering of buffer iteration.
+        # TODO(SDF): The return split is in terms of indices. Ideally,
+        # a runner could map these back to actual positions to effectively
+        # describe the two "halves" of the now-split range.
+        # Even if we have to buffer each element we send (or at the very
+        # least a bit of metadata, like position, about each of them) this
+        # should be doable if they're already in memory and we are bounding
+        # the buffer size (e.g. to 10mb plus whatever is eagerly read from
+        # the SDK). In the case of non-split-points, we can either
+        # immediately replay the "non-split-position" elements or record
+        # them as we do the other delayed applications.
+
+        # Decode and recode to split the encoded buffer by element index.
+        all_elements = list(coder_impl.decode_all(b''.join(last_sent[
+            channel_split.ptransform_id])))
+        residual_elements = all_elements[
+            channel_split.first_residual_element : prev_stops.get(
+                channel_split.ptransform_id, len(all_elements)) + 1]
+        if residual_elements:
+          deferred_inputs[channel_split.ptransform_id].append(
+              coder_impl.encode_all(residual_elements))
+        prev_stops[
+            channel_split.ptransform_id] = channel_split.last_primary_element
+
+  @staticmethod
+  def _extract_stage_data_endpoints(
+      stage, pipeline_components, data_api_service_descriptor, pcoll_buffers):
+    # Returns maps of transform names to PCollection identifiers.
+    # Also mutates IO stages to point to the data ApiServiceDescriptor.
+    data_input = {}
+    data_side_input = {}
+    data_output = {}
+    for transform in stage.transforms:
+      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
+                                bundle_processor.DATA_OUTPUT_URN):
+        pcoll_id = transform.spec.payload
+        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
+          target = transform.unique_name, only_element(transform.outputs)
+          if pcoll_id == fn_api_runner_transforms.IMPULSE_BUFFER:
+            data_input[target] = [ENCODED_IMPULSE_VALUE]
+          else:
+            data_input[target] = pcoll_buffers[pcoll_id]
+          coder_id = pipeline_components.pcollections[
+              only_element(transform.outputs.values())].coder_id
+        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
+          target = transform.unique_name, only_element(transform.inputs)
+          data_output[target] = pcoll_id
+          coder_id = pipeline_components.pcollections[
+              only_element(transform.inputs.values())].coder_id
+        else:
+          raise NotImplementedError
+        data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
+        if data_api_service_descriptor:
+          data_spec.api_service_descriptor.url = (
+              data_api_service_descriptor.url)
+        transform.spec.payload = data_spec.SerializeToString()
+      elif transform.spec.urn in fn_api_runner_transforms.PAR_DO_URNS:
+        payload = proto_utils.parse_Bytes(
+            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
+        for tag, si in payload.side_inputs.items():
+          data_side_input[transform.unique_name, tag] = (
+              create_buffer_id(transform.inputs[tag]), si.access_pattern)
+    return data_input, data_side_input, data_output
+
   def _run_stage(self,
                  worker_handler_factory,
                  pipeline_components,
@@ -435,22 +592,14 @@
       process_bundle_descriptor.state_api_service_descriptor.url = (
           controller.state_api_service_descriptor().url)
 
-    # Store the required side inputs into state.
-    for (transform_id, tag), (buffer_id, si) in data_side_input.items():
-      _, pcoll_id = split_buffer_id(buffer_id)
-      value_coder = context.coders[safe_coders[
-          pipeline_components.pcollections[pcoll_id].coder_id]]
-      elements_by_window = _WindowGroupingBuffer(si, value_coder)
-      for element_data in pcoll_buffers[buffer_id]:
-        elements_by_window.append(element_data)
-      for key, window, elements_data in elements_by_window.encoded_items():
-        state_key = beam_fn_api_pb2.StateKey(
-            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
-                ptransform_id=transform_id,
-                side_input_id=tag,
-                window=window,
-                key=key))
-        controller.state.blocking_append(state_key, elements_data)
+    # Store the required side inputs into state so they are accessible to
+    # the worker when it runs this bundle.
+    self._store_side_inputs_in_state(controller,
+                                     context,
+                                     pipeline_components,
+                                     data_side_input,
+                                     pcoll_buffers,
+                                     safe_coders)
 
     def get_buffer(buffer_id):
       """Returns the buffer for a given (operation_type, PCollection ID).
@@ -493,20 +642,17 @@
               ).coder_id
           ]].get_impl()
 
-    for k in range(self._bundle_repeat):
-      try:
-        controller.state.checkpoint()
-        BundleManager(
-            controller, lambda pcoll_id: [], get_input_coder_impl,
-            process_bundle_descriptor, self._progress_frequency, k
-        ).process_bundle(data_input, data_output)
-      finally:
-        controller.state.restore()
+    self._run_bundle_multiple_times_for_testing(controller,
+                                                process_bundle_descriptor,
+                                                data_input,
+                                                data_output,
+                                                get_input_coder_impl)
 
-    result, splits = BundleManager(
+    bundle_manager = BundleManager(
         controller, get_buffer, get_input_coder_impl,
         process_bundle_descriptor,
-        self._progress_frequency).process_bundle(
-            data_input, data_output)
+        self._progress_frequency)
+
+    result, splits = bundle_manager.process_bundle(data_input, data_output)
 
     def input_for(ptransform_id, input_id):
       input_pcoll = process_bundle_descriptor.transforms[
@@ -523,32 +669,9 @@
     while True:
       deferred_inputs = collections.defaultdict(list)
 
-      for transform_id, timer_writes in stage.timer_pcollections:
-
-        # Queue any set timers as new inputs.
-        windowed_timer_coder_impl = context.coders[
-            pipeline_components.pcollections[timer_writes].coder_id].get_impl()
-        written_timers = get_buffer(
-            create_buffer_id(timer_writes, kind='timers'))
-        if written_timers:
-          # Keep only the "last" timer set per key and window.
-          timers_by_key_and_window = {}
-          for elements_data in written_timers:
-            input_stream = create_InputStream(elements_data)
-            while input_stream.size() > 0:
-              windowed_key_timer = windowed_timer_coder_impl.decode_from_stream(
-                  input_stream, True)
-              key, _ = windowed_key_timer.value
-              # TODO: Explode and merge windows.
-              assert len(windowed_key_timer.windows) == 1
-              timers_by_key_and_window[
-                  key, windowed_key_timer.windows[0]] = windowed_key_timer
-          out = create_OutputStream()
-          for windowed_key_timer in timers_by_key_and_window.values():
-            windowed_timer_coder_impl.encode_to_stream(
-                windowed_key_timer, out, True)
-          deferred_inputs[transform_id] = [out.get()]
-          written_timers[:] = []
+
+      self._collect_written_timers_and_add_to_deferred_inputs(
+          context, pipeline_components, stage, get_buffer, deferred_inputs)
 
       # Queue any process-initiated delayed bundle applications.
       for delayed_application in last_result.process_bundle.residual_roots:
         deferred_inputs[
             input_for(
                 delayed_application.application.ptransform_id,
                 delayed_application.application.input_id)
         ].append(delayed_application.application.element)
 
       # Queue any runner-initiated delayed bundle applications.
-      prev_stops = {}
-      for split in splits:
-        for delayed_application in split.residual_roots:
-          deferred_inputs[
-              input_for(
-                  delayed_application.application.ptransform_id,
-                  delayed_application.application.input_id)
-          ].append(delayed_application.application.element)
-        for channel_split in split.channel_splits:
-          coder_impl = get_input_coder_impl(channel_split.ptransform_id)
-          # TODO(SDF): This requires determanistic ordering of buffer
-          # iteration.
-          # TODO(SDF): The return split is in terms of indices. Ideally,
-          # a runner could map these back to actual positions to effectively
-          # describe the two "halves" of the now-split range. Even if we have
-          # to buffer each element we send (or at the very least a bit of
-          # metadata, like position, about each of them) this should be doable
-          # if they're already in memory and we are bounding the buffer size
-          # (e.g. to 10mb plus whatever is eagerly read from the SDK). In the
-          # case of non-split-points, we can either immediately replay the
-          # "non-split-position" elements or record them as we do the other
-          # delayed applications.
-
-          # Decode and recode to split the encoded buffer by element index.
-          all_elements = list(coder_impl.decode_all(b''.join(last_sent[
-              channel_split.ptransform_id])))
-          residual_elements = all_elements[
-              channel_split.first_residual_element : prev_stops.get(
-                  channel_split.ptransform_id, len(all_elements)) + 1]
-          if residual_elements:
-            deferred_inputs[channel_split.ptransform_id].append(
-                coder_impl.encode_all(residual_elements))
-          prev_stops[
-              channel_split.ptransform_id] = channel_split.last_primary_element
+      self._add_residuals_and_channel_splits_to_deferred_inputs(
+          splits, get_input_coder_impl, input_for, last_sent, deferred_inputs)
 
       if deferred_inputs:
         # The worker will be waiting on these inputs as well.
@@ -1221,6 +1313,17 @@ def split_manager(stage_name, split_manager):
 
 
 class BundleManager(object):
+  """Manages the execution of a bundle from the runner side.
+
+  This class receives a bundle descriptor, and performs the following tasks:
+  - Registration of the bundle with the worker.
+  - Splitting of the bundle.
+  - Setting up any other bundle requirements (e.g. side inputs).
+  - Submitting the bundle to the worker for execution.
+  - Passing bundle input data to the worker.
+  - Collecting bundle output data from the worker.
+  - Finalizing the bundle.
+  """
 
   _uid_counter = 0
 
@@ -1244,20 +1347,17 @@
     self._registered = skip_registration
     self._progress_frequency = progress_frequency
 
-  def process_bundle(self, inputs, expected_outputs):
-    """Submit a bundle for processing by the SDK.
-
-    Args:
-      inputs (dict): A map from PCollection id (expressed as a tuple
-        (PTransform, output), to a buffer with elements in the PCollection.
-      expected_outputs (dict): A map from PCollection id (expressed as a tuple
-        (PTransform, output), to TODO: What does this map to?
-    """
-    # Unique id for the instruction processing this bundle.
-    BundleManager._uid_counter += 1
-    process_bundle_id = 'bundle_%s' % BundleManager._uid_counter
-
-    # Register the bundle descriptor, if needed.
+  def _send_input_to_worker(self,
+                            process_bundle_id,
+                            read_transform_id,
+                            byte_streams):
+    data_out = self._controller.data_plane_handler.output_stream(
+        process_bundle_id, read_transform_id)
+    for byte_stream in byte_streams:
+      data_out.write(byte_stream)
+    data_out.close()
+
+  def _register_bundle_descriptor(self):
     if self._registered:
       registration_future = None
     else:
@@ -1268,6 +1368,10 @@
           process_bundle_registration)
       self._registered = True
 
+    return registration_future
+
+  def _select_split_manager(self):
+    """TODO(pabloem) WHAT DOES THIS DO"""
     unique_names = set(
         t.unique_name for t in self._bundle_descriptor.transforms.values())
     for stage_name, candidate in reversed(_split_managers):
@@ -1278,88 +1382,103 @@
       else:
         split_manager = None
 
-    if not split_manager:
-      # Write all the input data to the channel immediately.
-      for transform_id, elements in inputs.items():
-        data_out = self._controller.data_plane_handler.output_stream(
-            process_bundle_id, transform_id)
-        for element_data in elements:
-          data_out.write(element_data)
-        data_out.close()
+    return split_manager
 
+  def _generate_splits_for_testing(self,
+                                   split_manager,
+                                   inputs,
+                                   process_bundle_id):
     split_results = []
+    read_transform_id, buffer_data = only_element(inputs.items())
 
-    # Actually start the bundle.
+    byte_stream = b''.join(buffer_data)
+    num_elements = len(list(
+        self._get_input_coder_impl(read_transform_id).decode_all(byte_stream)))
+
+    # Start the split manager in case it wants to set any breakpoints.
+    split_manager_generator = split_manager(num_elements)
+    try:
+      split_fraction = next(split_manager_generator)
+      done = False
+    except StopIteration:
+      done = True
+
+    # Send all the data.
+    self._send_input_to_worker(
+        process_bundle_id, read_transform_id, [byte_stream])
+
+    # Execute the requested splits.
+    while not done:
+      if split_fraction is None:
+        split_result = None
+      else:
+        split_request = beam_fn_api_pb2.InstructionRequest(
+            process_bundle_split=
+            beam_fn_api_pb2.ProcessBundleSplitRequest(
+                instruction_reference=process_bundle_id,
+                desired_splits={
+                    read_transform_id:
+                    beam_fn_api_pb2.ProcessBundleSplitRequest.DesiredSplit(
+                        fraction_of_remainder=split_fraction,
+                        estimated_input_elements=num_elements)
+                }))
+        split_response = self._controller.control_handler.push(
+            split_request).get()
+        for t in (0.05, 0.1, 0.2):
+          waiting = ('Instruction not running', 'not yet scheduled')
+          if any(msg in split_response.error for msg in waiting):
+            time.sleep(t)
+            split_response = self._controller.control_handler.push(
+                split_request).get()
+        if 'Unknown process bundle' in split_response.error:
+          # It may have finished too fast.
+          split_result = None
+        elif split_response.error:
+          raise RuntimeError(split_response.error)
+        else:
+          split_result = split_response.process_bundle_split
+      split_results.append(split_result)
+      try:
+        split_fraction = split_manager_generator.send(split_result)
+      except StopIteration:
+        break
+    return split_results
+
+  def process_bundle(self, inputs, expected_outputs):
+    # Unique id for the instruction processing this bundle.
+    BundleManager._uid_counter += 1
+    process_bundle_id = 'bundle_%s' % BundleManager._uid_counter
+
+    # Register the bundle descriptor, if needed - a no-op if already
+    # registered.
+    registration_future = self._register_bundle_descriptor()
+
+    split_manager = self._select_split_manager()
+    if not split_manager:
+      # If there is no split_manager, write all input data to the channel.
+      for transform_id, elements in inputs.items():
+        self._send_input_to_worker(
+            process_bundle_id, transform_id, elements)
+
+    # Check that the bundle was successfully registered.
     if registration_future and registration_future.get().error:
       raise RuntimeError(registration_future.get().error)
 
-    process_bundle_request = beam_fn_api_pb2.InstructionRequest(
+    # Actually start the bundle.
+    process_bundle_req = beam_fn_api_pb2.InstructionRequest(
         instruction_id=process_bundle_id,
         process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
            process_bundle_descriptor_reference=self._bundle_descriptor.id))
-    result_future = self._controller.control_handler.push(
-        process_bundle_request)
+    result_future = self._controller.control_handler.push(process_bundle_req)
 
+    split_results = []
     with ProgressRequester(
         self._controller, process_bundle_id, self._progress_frequency):
+
       if split_manager:
-        read_transform_id, buffer_data = only_element(inputs.items())
-        num_elements = len(list(
-            self._get_input_coder_impl(read_transform_id).decode_all(
-                b''.join(buffer_data))))
-
-        # Start the split manager in case it wants to set any breakpoints.
-        split_manager_generator = split_manager(num_elements)
-        try:
-          split_fraction = next(split_manager_generator)
-          done = False
-        except StopIteration:
-          done = True
-
-        # Send all the data.
-        data_out = self._controller.data_plane_handler.output_stream(
-            process_bundle_id, read_transform_id)
-        data_out.write(b''.join(buffer_data))
-        data_out.close()
-
-        # Execute the requested splits.
-        while not done:
-          if split_fraction is None:
-            split_result = None
-          else:
-            split_request = beam_fn_api_pb2.InstructionRequest(
-                process_bundle_split=
-                beam_fn_api_pb2.ProcessBundleSplitRequest(
-                    instruction_reference=process_bundle_id,
-                    desired_splits={
-                        read_transform_id:
-                        beam_fn_api_pb2.ProcessBundleSplitRequest.DesiredSplit(
-                            fraction_of_remainder=split_fraction,
-                            estimated_input_elements=num_elements)
-                    }))
-            split_response = self._controller.control_handler.push(
-                split_request).get()
-            for t in (0.05, 0.1, 0.2):
-              waiting = ('Instruction not running', 'not yet scheduled')
-              if any(msg in split_response.error for msg in waiting):
-                time.sleep(t)
-                split_response = self._controller.control_handler.push(
-                    split_request).get()
-            if 'Unknown process bundle' in split_response.error:
-              # It may have finished too fast.
-              split_result = None
-            elif split_response.error:
-              raise RuntimeError(split_response.error)
-            else:
-              split_result = split_response.process_bundle_split
-          split_results.append(split_result)
-          try:
-            split_fraction = split_manager_generator.send(split_result)
-          except StopIteration:
-            break
+        split_results = self._generate_splits_for_testing(split_manager,
+                                                          inputs,
+                                                          process_bundle_id)
 
       # Gather all output data.
-      logging.debug('Gather all output data from %s.', expected_outputs)
       for output in self._controller.data_plane_handler.input_elements(
           process_bundle_id,
           expected_outputs.keys(),
@@ -1388,6 +1507,11 @@ class BundleManager(object):
 
 
 class ProgressRequester(threading.Thread):
+  """Thread that asks the SDK worker for progress reports at a set frequency.
+
+  A callback can be passed in; it will be called with each progress update.
+ """ + def __init__(self, controller, instruction_id, frequency, callback=None): super(ProgressRequester, self).__init__() self._controller = controller diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py index f303722..e03864c 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py @@ -50,6 +50,7 @@ from apache_beam.runners.portability import fn_api_runner from apache_beam.runners.worker import data_plane from apache_beam.runners.worker import statesampler from apache_beam.testing.synthetic_pipeline import SyntheticSDFAsSource +from apache_beam.testing.test_stream import TestStream from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to from apache_beam.transforms import userstate @@ -306,6 +307,36 @@ class FnApiRunnerTest(unittest.TestCase): assert_that(p | beam.Create(inputs) | beam.ParDo(AddIndex()), equal_to(expected)) + @unittest.skip('TestStream not yet supported') + def test_teststream_pardo_timers(self): + timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK) + + class TimerDoFn(beam.DoFn): + def process(self, element, timer=beam.DoFn.TimerParam(timer_spec)): + unused_key, ts = element + timer.set(ts) + timer.set(2 * ts) + + @userstate.on_timer(timer_spec) + def process_timer(self): + yield 'fired' + + ts = (TestStream() + .add_elements([('k1', 10)]) # Set timer for 20 + .advance_watermark_to(100) + .add_elements([('k2', 100)]) # Set timer for 200 + .advance_watermark_to(1000)) + + with self.create_pipeline() as p: + _ = ( + p + | ts + | beam.ParDo(TimerDoFn()) + | beam.Map(lambda x, ts=beam.DoFn.TimestampParam: (x, ts))) + + #expected = [('fired', ts) for ts in (20, 200)] + #assert_that(actual, equal_to(expected)) + def test_pardo_timers(self): timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK) diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index ffc7379..8d788fc 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -484,9 +484,7 @@ def only_element(iterable): class BundleProcessor(object): - """A class for processing bundles of elements. - - """ + """ A class for processing bundles of elements. """ def __init__( self, process_bundle_descriptor, state_handler, data_channel_factory):