This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
     new 63dbbf4  [BEAM-3742] Checkpointing for SDF over FnAPI in Python SDK and FnApiRunner. (#7722)
63dbbf4 is described below

commit 63dbbf4affc2082015d35f27f393f1b6d0a0493b
Author: Robert Bradshaw <rober...@gmail.com>
AuthorDate: Tue Feb 5 22:27:24 2019 +0100

    [BEAM-3742] Checkpointing for SDF over FnAPI in Python SDK and FnApiRunner. (#7722)
---
 sdks/python/apache_beam/io/restriction_trackers.py |  12 +-
 sdks/python/apache_beam/portability/common_urns.py |   2 +
 sdks/python/apache_beam/runners/common.pxd         |   3 +
 sdks/python/apache_beam/runners/common.py          |  46 ++++++-
 .../apache_beam/runners/direct/direct_runner.py    |   4 -
 .../runners/direct/sdf_direct_runner_test.py       |   3 +-
 .../runners/portability/flink_runner_test.py       |   3 +
 .../runners/portability/fn_api_runner.py           |  40 ++++--
 .../runners/portability/fn_api_runner_test.py      |  38 ++++++
 .../portability/fn_api_runner_transforms.py        | 137 +++++++++++++++++++--
 .../apache_beam/runners/worker/bundle_processor.py | 126 +++++++++++++++++--
 .../apache_beam/runners/worker/operations.pxd      |   6 +
 .../apache_beam/runners/worker/operations.py       |  16 ++-
 .../apache_beam/runners/worker/sdk_worker.py       |   3 +-
 sdks/python/apache_beam/transforms/core.py         |  10 ++
 15 files changed, 405 insertions(+), 44 deletions(-)

diff --git a/sdks/python/apache_beam/io/restriction_trackers.py b/sdks/python/apache_beam/io/restriction_trackers.py
index b9b1f17..e72d508 100644
--- a/sdks/python/apache_beam/io/restriction_trackers.py
+++ b/sdks/python/apache_beam/io/restriction_trackers.py
@@ -80,8 +80,9 @@ class OffsetRestrictionTracker(RestrictionTracker):
     self._range = OffsetRange(start_position, stop_position)
     self._current_position = None
     self._last_claim_attempt = None
+    self._deferred_residual = None
     self._checkpointed = False
-    self._lock = threading.Lock()
+    self._lock = threading.RLock()
 
   def check_done(self):
     with self._lock:
@@ -139,3 +140,12 @@ class OffsetRestrictionTracker(RestrictionTracker):
       self._range = OffsetRange(self._range.start, end_position)
 
       return residual_range
+
+  def defer_remainder(self, watermark=None):
+    with self._lock:
+      self._deferred_watermark = watermark
+      self._deferred_residual = self.checkpoint()
+
+  def deferred_status(self):
+    if self._deferred_residual:
+      return (self._deferred_residual, self._deferred_watermark)
diff --git a/sdks/python/apache_beam/portability/common_urns.py b/sdks/python/apache_beam/portability/common_urns.py
index 4ee32a7..a000f81 100644
--- a/sdks/python/apache_beam/portability/common_urns.py
+++ b/sdks/python/apache_beam/portability/common_urns.py
@@ -49,6 +49,8 @@ composites = PropertiesFromEnumType(
     beam_runner_api_pb2.StandardPTransforms.Composites)
 combine_components = PropertiesFromEnumType(
     beam_runner_api_pb2.StandardPTransforms.CombineComponents)
+sdf_components = PropertiesFromEnumType(
+    beam_runner_api_pb2.StandardPTransforms.SplittableParDoComponents)
 side_inputs = PropertiesFromEnumType(
     beam_runner_api_pb2.StandardSideInputTypes.Enum)
 
diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd
index 49f4c44..b5ab88d 100644
--- a/sdks/python/apache_beam/runners/common.pxd
+++ b/sdks/python/apache_beam/runners/common.pxd
@@ -82,6 +82,9 @@ cdef class PerWindowInvoker(DoFnInvoker):
   cdef bint has_windowed_inputs
   cdef bint cache_globally_windowed_args
   cdef object process_method
+  cdef bint is_splittable
+  cdef object restriction_tracker
+  cdef WindowedValue current_windowed_value
 
 
 cdef class DoFnRunner(Receiver):
diff
--git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py index 99a4bca..3d9b07f 100644 --- a/sdks/python/apache_beam/runners/common.py +++ b/sdks/python/apache_beam/runners/common.py @@ -211,7 +211,7 @@ class DoFnSignature(object): self.start_bundle_method = MethodWrapper(do_fn, 'start_bundle') self.finish_bundle_method = MethodWrapper(do_fn, 'finish_bundle') - restriction_provider = self._get_restriction_provider(do_fn) + restriction_provider = self.get_restriction_provider() self.initial_restriction_method = ( MethodWrapper(restriction_provider, 'initial_restriction') if restriction_provider else None) @@ -237,7 +237,7 @@ class DoFnSignature(object): method = timer_spec._attached_callback self.timer_methods[timer_spec] = MethodWrapper(do_fn, method.__name__) - def _get_restriction_provider(self, do_fn): + def get_restriction_provider(self): result = _find_param_with_default(self.process_method, default_as_type=RestrictionProvider) return result[1] if result else None @@ -434,6 +434,9 @@ class PerWindowInvoker(DoFnInvoker): (core.DoFn.WindowParam in default_arg_values) or signature.is_stateful_dofn()) self.user_state_context = user_state_context + self.is_splittable = signature.is_splittable_dofn() + self.restriction_tracker = None + self.current_windowed_value = None # Try to prepare all the arguments that can just be filled in # without any additional work. in the process function. @@ -515,7 +518,16 @@ class PerWindowInvoker(DoFnInvoker): # or if the process accesses the window parameter. We can just call it once # otherwise as none of the arguments are changing + if self.is_splittable and not restriction_tracker: + restriction = self.invoke_initial_restriction(windowed_value.value) + restriction_tracker = self.invoke_create_tracker(restriction) + if restriction_tracker: + if len(windowed_value.windows) > 1 and self.has_windowed_inputs: + # Should never get here due to window explosion in + # the upstream pair-with-restriction. 
+ raise NotImplementedError( + 'SDFs in multiply-windowed values with windowed arguments.') restriction_tracker_param = _find_param_with_default( self.signature.process_method, default_as_type=core.RestrictionProvider)[0] @@ -524,7 +536,17 @@ class PerWindowInvoker(DoFnInvoker): 'A RestrictionTracker %r was provided but DoFn does not have a ' 'RestrictionTrackerParam defined' % restriction_tracker) additional_kwargs[restriction_tracker_param] = restriction_tracker - if self.has_windowed_inputs and len(windowed_value.windows) != 1: + try: + self.current_windowed_value = windowed_value + self.restriction_tracker = restriction_tracker + return self._invoke_per_window( + windowed_value, additional_args, additional_kwargs, + output_processor) + finally: + self.restriction_tracker = None + self.current_windowed_value = windowed_value + + elif self.has_windowed_inputs and len(windowed_value.windows) != 1: for w in windowed_value.windows: self._invoke_per_window( WindowedValue(windowed_value.value, windowed_value.timestamp, (w,)), @@ -602,6 +624,15 @@ class PerWindowInvoker(DoFnInvoker): output_processor.process_outputs( windowed_value, self.process_method(*args_for_process)) + if self.is_splittable: + deferred_status = self.restriction_tracker.deferred_status() + if deferred_status: + deferred_restriction, deferred_watermark = deferred_status + return ( + windowed_value.with_value( + (windowed_value.value, deferred_restriction)), + deferred_watermark) + class DoFnRunner(Receiver): """For internal use only; no backwards-compatibility guarantees. @@ -679,10 +710,17 @@ class DoFnRunner(Receiver): def process(self, windowed_value): try: - self.do_fn_invoker.invoke_process(windowed_value) + return self.do_fn_invoker.invoke_process(windowed_value) except BaseException as exn: self._reraise_augmented(exn) + def process_with_restriction(self, windowed_value): + element, restriction = windowed_value.value + return self.do_fn_invoker.invoke_process( + windowed_value.with_value(element), + restriction_tracker=self.do_fn_invoker.invoke_create_tracker( + restriction)) + def process_user_timer(self, timer_spec, key, window, timestamp): try: self.do_fn_invoker.invoke_user_timer(timer_spec, key, window, timestamp) diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index 770044c..43e8c7f 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -76,7 +76,6 @@ class SwitchingDirectRunner(PipelineRunner): use_fnapi_runner = False from apache_beam.pipeline import PipelineVisitor - from apache_beam.runners.common import DoFnSignature from apache_beam.runners.dataflow.native_io.iobase import NativeSource from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite from apache_beam.testing.test_stream import TestStream @@ -103,9 +102,6 @@ class SwitchingDirectRunner(PipelineRunner): self.supported_by_fnapi_runner = False if isinstance(transform, beam.ParDo): dofn = transform.dofn - # The FnApiRunner does not support execution of SplittableDoFns. - if DoFnSignature(dofn).is_splittable_dofn(): - self.supported_by_fnapi_runner = False # The FnApiRunner does not support execution of CombineFns with # deferred side inputs. 
if isinstance(dofn, CombineValuesDoFn): diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py index d5924cb..eae38bc 100644 --- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py +++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py @@ -35,7 +35,6 @@ from apache_beam.pvalue import AsSingleton from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to -from apache_beam.transforms.core import ProcessContinuation from apache_beam.transforms.core import RestrictionProvider from apache_beam.transforms.trigger import AccumulationMode from apache_beam.transforms.window import SlidingWindows @@ -83,7 +82,7 @@ class ReadFiles(DoFn): output_count += 1 if self._resume_count and output_count == self._resume_count: - yield ProcessContinuation() + restriction_tracker.defer_remainder() break pos += len_line diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py index 958bb8e..767203b 100644 --- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py @@ -197,6 +197,9 @@ if __name__ == '__main__': counter_name, line) ) + def test_sdf(self): + raise unittest.SkipTest("BEAM-2939") + # Inherits all other tests. # Run the tests. diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index be8799b..e908a5c 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -322,6 +322,7 @@ class FnApiRunner(runner.PipelineRunner): phases=[fn_api_runner_transforms.annotate_downstream_side_inputs, fn_api_runner_transforms.fix_side_input_pcoll_coders, fn_api_runner_transforms.lift_combiners, + fn_api_runner_transforms.expand_sdf, fn_api_runner_transforms.expand_gbk, fn_api_runner_transforms.sink_flattens, fn_api_runner_transforms.greedily_fuse, @@ -413,7 +414,7 @@ class FnApiRunner(runner.PipelineRunner): data_spec.api_service_descriptor.url = ( data_api_service_descriptor.url) transform.spec.payload = data_spec.SerializeToString() - elif transform.spec.urn == common_urns.primitives.PAR_DO.urn: + elif transform.spec.urn in fn_api_runner_transforms.PAR_DO_URNS: payload = proto_utils.parse_Bytes( transform.spec.payload, beam_runner_api_pb2.ParDoPayload) for tag, si in payload.side_inputs.items(): @@ -496,11 +497,15 @@ class FnApiRunner(runner.PipelineRunner): result = BundleManager( controller, get_buffer, process_bundle_descriptor, - self._progress_frequency).process_bundle(data_input, data_output) + self._progress_frequency).process_bundle( + data_input, data_output) + last_result = result while True: - timer_inputs = {} + deferred_inputs = collections.defaultdict(list) for transform_id, timer_writes in stage.timer_pcollections: + + # Queue any set timers as new inputs. 
windowed_timer_coder_impl = context.coders[ pipeline_components.pcollections[timer_writes].coder_id].get_impl() written_timers = get_buffer( @@ -522,20 +527,37 @@ class FnApiRunner(runner.PipelineRunner): for windowed_key_timer in timers_by_key_and_window.values(): windowed_timer_coder_impl.encode_to_stream( windowed_key_timer, out, True) - timer_inputs[transform_id, 'out'] = [out.get()] + deferred_inputs[transform_id, 'out'] = [out.get()] written_timers[:] = [] - if timer_inputs: + + # Queue any delayed bundle applications. + for delayed_application in last_result.process_bundle.residual_roots: + # Find the io transform that feeds this transform. + # TODO(SDF): Memoize? + application = delayed_application.application + input_pcoll = process_bundle_descriptor.transforms[ + application.ptransform_id].inputs[application.input_id] + for input_id, proto in process_bundle_descriptor.transforms.items(): + if (proto.spec.urn == bundle_processor.DATA_INPUT_URN + and input_pcoll in proto.outputs.values()): + deferred_inputs[input_id, 'out'].append(application.element) + break + else: + raise RuntimeError( + 'No IO transform feeds %s' % application.ptransform_id) + + if deferred_inputs: # The worker will be waiting on these inputs as well. for other_input in data_input: - if other_input not in timer_inputs: - timer_inputs[other_input] = [] + if other_input not in deferred_inputs: + deferred_inputs[other_input] = [] # TODO(robertwb): merge results - BundleManager( + last_result = BundleManager( controller, get_buffer, process_bundle_descriptor, self._progress_frequency, - True).process_bundle(timer_inputs, data_output) + True).process_bundle(deferred_inputs, data_output) else: break diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py index f08c13b..6c4cad9 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py @@ -30,6 +30,7 @@ from tenacity import retry from tenacity import stop_after_attempt import apache_beam as beam +from apache_beam.io import restriction_trackers from apache_beam.metrics import monitoring_infos from apache_beam.metrics.execution import MetricKey from apache_beam.metrics.execution import MetricsEnvironment @@ -351,6 +352,43 @@ class FnApiRunnerTest(unittest.TestCase): assert_that(actual, is_buffered_correctly) + def test_sdf(self): + + class ExpandStringsProvider(beam.transforms.core.RestrictionProvider): + def initial_restriction(self, element): + return (0, len(element)) + + def create_tracker(self, restriction): + return restriction_trackers.OffsetRestrictionTracker( + restriction[0], restriction[1]) + + def split(self, element, restriction): + start, end = restriction + middle = (end - start) // 2 + return [(start, middle), (middle, end)] + + class ExpandStringsDoFn(beam.DoFn): + def process(self, element, restriction_tracker=ExpandStringsProvider()): + assert isinstance( + restriction_tracker, + restriction_trackers.OffsetRestrictionTracker), restriction_tracker + for k in range(*restriction_tracker.current_restriction()): + if not restriction_tracker.try_claim(k): + return + yield element[k] + if k % 2 == 1: + restriction_tracker.defer_remainder() + return + + with self.create_pipeline() as p: + data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz'] + actual = ( + p + | beam.Create(data) + | beam.ParDo(ExpandStringsDoFn())) + + assert_that(actual, equal_to(list(''.join(data)))) + def 
test_group_by_key(self): with self.create_pipeline() as p: res = (p diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py index 21f8fa2..2482987 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py @@ -25,6 +25,8 @@ import functools import logging from builtins import object +from past.builtins import unicode + from apache_beam import coders from apache_beam.portability import common_urns from apache_beam.portability import python_urns @@ -35,9 +37,21 @@ from apache_beam.utils import proto_utils # This module is experimental. No backwards-compatibility guarantees. -KNOWN_COMPOSITES = frozenset( - [common_urns.primitives.GROUP_BY_KEY.urn, - common_urns.composites.COMBINE_PER_KEY.urn]) +KNOWN_COMPOSITES = frozenset([ + common_urns.primitives.GROUP_BY_KEY.urn, + common_urns.composites.COMBINE_PER_KEY.urn]) + +COMBINE_URNS = frozenset([ + common_urns.composites.COMBINE_PER_KEY.urn, + common_urns.combine_components.COMBINE_PGBKCV.urn, + common_urns.combine_components.COMBINE_MERGE_ACCUMULATORS.urn, + common_urns.combine_components.COMBINE_EXTRACT_OUTPUTS.urn]) + +PAR_DO_URNS = frozenset([ + common_urns.primitives.PAR_DO.urn, + common_urns.sdf_components.PAIR_WITH_RESTRICTION.urn, + common_urns.sdf_components.SPLIT_RESTRICTION.urn, + common_urns.sdf_components.PROCESS_ELEMENTS.urn]) IMPULSE_BUFFER = b'impulse' @@ -76,15 +90,11 @@ class Stage(object): @staticmethod def _extract_environment(transform): - if transform.spec.urn == common_urns.primitives.PAR_DO.urn: + if transform.spec.urn in PAR_DO_URNS: pardo_payload = proto_utils.parse_Bytes( transform.spec.payload, beam_runner_api_pb2.ParDoPayload) return pardo_payload.do_fn.environment_id - elif transform.spec.urn in ( - common_urns.composites.COMBINE_PER_KEY.urn, - common_urns.combine_components.COMBINE_PGBKCV.urn, - common_urns.combine_components.COMBINE_MERGE_ACCUMULATORS.urn, - common_urns.combine_components.COMBINE_EXTRACT_OUTPUTS.urn): + elif transform.spec.urn in COMBINE_URNS: combine_payload = proto_utils.parse_Bytes( transform.spec.payload, beam_runner_api_pb2.CombinePayload) return combine_payload.combine_fn.environment_id @@ -137,7 +147,7 @@ class Stage(object): def side_inputs(self): for transform in self.transforms: - if transform.spec.urn == common_urns.primitives.PAR_DO.urn: + if transform.spec.urn in PAR_DO_URNS: payload = proto_utils.parse_Bytes( transform.spec.payload, beam_runner_api_pb2.ParDoPayload) for side_input in payload.side_inputs: @@ -145,7 +155,7 @@ class Stage(object): def has_as_main_input(self, pcoll): for transform in self.transforms: - if transform.spec.urn == common_urns.primitives.PAR_DO.urn: + if transform.spec.urn in PAR_DO_URNS: payload = proto_utils.parse_Bytes( transform.spec.payload, beam_runner_api_pb2.ParDoPayload) local_side_inputs = payload.side_inputs @@ -681,6 +691,108 @@ def lift_combiners(stages, context): yield stage +def expand_sdf(stages, context): + """Transforms splitable DoFns into pair+split+read.""" + for stage in stages: + assert len(stage.transforms) == 1 + transform = stage.transforms[0] + if transform.spec.urn == common_urns.primitives.PAR_DO.urn: + + pardo_payload = proto_utils.parse_Bytes( + transform.spec.payload, beam_runner_api_pb2.ParDoPayload) + + if pardo_payload.splittable: + + def copy_like(protos, original, suffix='_copy', **kwargs): + if isinstance(original, 
(str, unicode)): + key = original + original = protos[original] + else: + key = 'component' + new_id = unique_name(protos, key + suffix) + protos[new_id].CopyFrom(original) + proto = protos[new_id] + for name, value in kwargs.items(): + if isinstance(value, dict): + getattr(proto, name).clear() + getattr(proto, name).update(value) + elif isinstance(value, list): + del getattr(proto, name)[:] + getattr(proto, name).extend(value) + elif name == 'urn': + proto.spec.urn = value + else: + setattr(proto, name, value) + return new_id + + def make_stage(base_stage, transform_id, extra_must_follow=()): + transform = context.components.transforms[transform_id] + return Stage( + transform.unique_name, + [transform], + base_stage.downstream_side_inputs, + union(base_stage.must_follow, frozenset(extra_must_follow)), + parent=base_stage, + environment=base_stage.environment) + + main_input_tag = only_element(tag for tag in transform.inputs.keys() + if tag not in pardo_payload.side_inputs) + main_input_id = transform.inputs[main_input_tag] + element_coder_id = context.components.pcollections[ + main_input_id].coder_id + paired_coder_id = context.add_or_get_coder_id( + beam_runner_api_pb2.Coder( + spec=beam_runner_api_pb2.SdkFunctionSpec( + spec=beam_runner_api_pb2.FunctionSpec( + urn=common_urns.coders.KV.urn)), + component_coder_ids=[element_coder_id, + pardo_payload.restriction_coder_id])) + + paired_pcoll_id = copy_like( + context.components.pcollections, + main_input_id, + '_paired', + coder_id=paired_coder_id) + pair_transform_id = copy_like( + context.components.transforms, + transform, + unique_name=transform.unique_name + '/PairWithRestriction', + urn=common_urns.sdf_components.PAIR_WITH_RESTRICTION.urn, + outputs={'out': paired_pcoll_id}) + + split_pcoll_id = copy_like( + context.components.pcollections, + main_input_id, + '_split', + coder_id=paired_coder_id) + split_transform_id = copy_like( + context.components.transforms, + transform, + unique_name=transform.unique_name + '/SplitRestriction', + urn=common_urns.sdf_components.SPLIT_RESTRICTION.urn, + inputs=dict(transform.inputs, **{main_input_tag: paired_pcoll_id}), + outputs={'out': split_pcoll_id}) + + process_transform_id = copy_like( + context.components.transforms, + transform, + unique_name=transform.unique_name + '/Process', + urn=common_urns.sdf_components.PROCESS_ELEMENTS.urn, + inputs=dict(transform.inputs, **{main_input_tag: split_pcoll_id})) + + yield make_stage(stage, pair_transform_id) + split_stage = make_stage(stage, split_transform_id) + yield split_stage + yield make_stage( + stage, process_transform_id, extra_must_follow=[split_stage]) + + else: + yield stage + + else: + yield stage + + def expand_gbk(stages, pipeline_context): """Transforms each GBK into a write followed by a read. """ @@ -861,6 +973,7 @@ def greedily_fuse(stages, pipeline_context): fuse(producer, consumer) else: # If we can't fuse, do a read + write. 
+ pipeline_context.length_prefix_pcoll_coders(pcoll) buffer_id = create_buffer_id(pcoll) if write_pcoll is None: write_pcoll = Stage( @@ -994,7 +1107,7 @@ def inject_timer_pcollections(stages, pipeline_context): """ for stage in stages: for transform in list(stage.transforms): - if transform.spec.urn == common_urns.primitives.PAR_DO.urn: + if transform.spec.urn in PAR_DO_URNS: payload = proto_utils.parse_Bytes( transform.spec.payload, beam_runner_api_pb2.ParDoPayload) for tag, spec in payload.timer_specs.items(): diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index 7126215..6cffc02 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -31,6 +31,7 @@ from builtins import next from builtins import object from future.utils import itervalues +from google import protobuf import apache_beam as beam from apache_beam import coders @@ -43,6 +44,7 @@ from apache_beam.portability import common_urns from apache_beam.portability import python_urns from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.portability.api import beam_runner_api_pb2 +from apache_beam.runners import common from apache_beam.runners import pipeline_context from apache_beam.runners.dataflow import dataflow_runner from apache_beam.runners.worker import operation_specs @@ -474,10 +476,12 @@ class BundleProcessor(object): expected_inputs.append(op) try: + execution_context = ExecutionContext() self.state_sampler.start() # Start all operations. for op in reversed(self.ops.values()): logging.debug('start %s', op) + op.execution_context = execution_context op.start() # Inject inputs from data plane. @@ -499,9 +503,32 @@ class BundleProcessor(object): for op in self.ops.values(): logging.debug('finish %s', op) op.finish() + + return [ + self.delayed_bundle_application(op, residual) + for op, residual in execution_context.delayed_applications] + finally: self.state_sampler.stop_if_still_running() + def delayed_bundle_application(self, op, deferred_remainder): + ptransform_id, main_input_tag, main_input_coder, outputs = op.input_info + # TODO(SDF): For non-root nodes, need main_input_coder + residual_coder. 
+ element_and_restriction, watermark = deferred_remainder + if watermark: + proto_watermark = protobuf.Timestamp() + proto_watermark.FromMicroseconds(watermark.micros) + output_watermarks = {output: proto_watermark for output in outputs} + else: + output_watermarks = None + return beam_fn_api_pb2.DelayedBundleApplication( + application=beam_fn_api_pb2.BundleApplication( + ptransform_id=ptransform_id, + input_id=main_input_tag, + output_watermarks=output_watermarks, + element=main_input_coder.get_impl().encode_nested( + element_and_restriction))) + def metrics(self): # DEPRECATED return beam_fn_api_pb2.Metrics( @@ -553,6 +580,11 @@ class BundleProcessor(object): return monitoring_info +class ExecutionContext(object): + def __init__(self): + self.delayed_applications = [] + + class BeamTransformFactory(object): """Factory for turning transform_protos into executable operations.""" def __init__(self, descriptor, data_channel_factory, counter_factory, @@ -766,6 +798,70 @@ def create(factory, transform_id, transform_proto, serialized_fn, consumers): @BeamTransformFactory.register_urn( + common_urns.sdf_components.PAIR_WITH_RESTRICTION.urn, + beam_runner_api_pb2.ParDoPayload) +def create(*args): + + class CreateRestriction(beam.DoFn): + def __init__(self, fn, restriction_provider): + self.restriction_provider = restriction_provider + + # An unused window is requested to force explosion of multi-window + # WindowedValues. + def process( + self, element, _unused_window=beam.DoFn.WindowParam, *args, **kwargs): + # TODO(SDF): Do we want to allow mutation of the element? + # (E.g. it could be nice to shift bulky description to the portion + # that can be distributed.) + yield element, self.restriction_provider.initial_restriction(element) + + return _create_sdf_operation(CreateRestriction, *args) + + +@BeamTransformFactory.register_urn( + common_urns.sdf_components.SPLIT_RESTRICTION.urn, + beam_runner_api_pb2.ParDoPayload) +def create(*args): + + class SplitRestriction(beam.DoFn): + def __init__(self, fn, restriction_provider): + self.restriction_provider = restriction_provider + + def process(self, element_restriction, *args, **kwargs): + element, restriction = element_restriction + for part in self.restriction_provider.split(element, restriction): + yield element, part + + return _create_sdf_operation(SplitRestriction, *args) + + +@BeamTransformFactory.register_urn( + common_urns.sdf_components.PROCESS_ELEMENTS.urn, + beam_runner_api_pb2.ParDoPayload) +def create(factory, transform_id, transform_proto, parameter, consumers): + assert parameter.do_fn.spec.urn == python_urns.PICKLED_DOFN_INFO + serialized_fn = parameter.do_fn.spec.payload + return _create_pardo_operation( + factory, transform_id, transform_proto, consumers, + serialized_fn, parameter, + operation_cls=operations.SdfProcessElements) + + +def _create_sdf_operation( + proxy_dofn, + factory, transform_id, transform_proto, parameter, consumers): + + dofn_data = pickler.loads(parameter.do_fn.spec.payload) + dofn = dofn_data[0] + restriction_provider = common.DoFnSignature(dofn).get_restriction_provider() + serialized_fn = pickler.dumps( + (proxy_dofn(dofn, restriction_provider),) + dofn_data[1:]) + return _create_pardo_operation( + factory, transform_id, transform_proto, consumers, + serialized_fn, parameter) + + +@BeamTransformFactory.register_urn( common_urns.primitives.PAR_DO.urn, beam_runner_api_pb2.ParDoPayload) def create(factory, transform_id, transform_proto, parameter, consumers): assert parameter.do_fn.spec.urn == 
python_urns.PICKLED_DOFN_INFO @@ -777,7 +873,7 @@ def create(factory, transform_id, transform_proto, parameter, consumers): def _create_pardo_operation( factory, transform_id, transform_proto, consumers, - serialized_fn, pardo_proto=None): + serialized_fn, pardo_proto=None, operation_cls=operations.DoOperation): if pardo_proto and pardo_proto.side_inputs: input_tags_to_coders = factory.get_input_coders(transform_proto) @@ -824,7 +920,8 @@ def _create_pardo_operation( factory.descriptor.pcollections[pcoll_id].windowing_strategy_id) serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,)) - if pardo_proto and (pardo_proto.timer_specs or pardo_proto.state_specs): + if pardo_proto and (pardo_proto.timer_specs or pardo_proto.state_specs + or pardo_proto.splittable): main_input_coder = None timer_inputs = {} for tag, pcoll_id in transform_proto.inputs.items(): @@ -835,15 +932,19 @@ def _create_pardo_operation( else: # Must be the main input assert main_input_coder is None + main_input_tag = tag main_input_coder = factory.get_windowed_coder(pcoll_id) assert main_input_coder is not None - user_state_context = FnApiUserStateContext( - factory.state_handler, - transform_id, - main_input_coder.key_coder(), - main_input_coder.window_coder, - timer_specs=pardo_proto.timer_specs) + if pardo_proto.timer_specs or pardo_proto.state_specs: + user_state_context = FnApiUserStateContext( + factory.state_handler, + transform_id, + main_input_coder.key_coder(), + main_input_coder.window_coder, + timer_specs=pardo_proto.timer_specs) + else: + user_state_context = None else: user_state_context = None timer_inputs = None @@ -856,8 +957,8 @@ def _create_pardo_operation( side_inputs=None, # Fn API uses proto definitions and the Fn State API output_coders=[output_coders[tag] for tag in output_tags]) - return factory.augment_oldstyle_op( - operations.DoOperation( + result = factory.augment_oldstyle_op( + operation_cls( transform_proto.unique_name, spec, factory.counter_factory, @@ -868,6 +969,11 @@ def _create_pardo_operation( transform_proto.unique_name, consumers, output_tags) + if pardo_proto and pardo_proto.splittable: + result.input_info = ( + transform_id, main_input_tag, main_input_coder, + transform_proto.outputs.keys()) + return result def _create_simple_pardo_operation( diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd index cb3477a..10c3c41 100644 --- a/sdks/python/apache_beam/runners/worker/operations.pxd +++ b/sdks/python/apache_beam/runners/worker/operations.pxd @@ -45,6 +45,7 @@ cdef class Operation(object): cdef object consumers cdef readonly counter_factory cdef public metrics_container + cdef public execution_context # Public for access by Fn harness operations. # TODO(robertwb): Cythonize FnHarness. 
cdef public list receivers @@ -89,6 +90,11 @@ cdef class DoOperation(Operation): cdef object user_state_context cdef public dict timer_inputs cdef dict timer_specs + cdef public object input_info + + +cdef class SdfProcessElements(DoOperation): + pass cdef class CombineOperation(Operation): diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py index 9ab4fcf..a6e0c31 100644 --- a/sdks/python/apache_beam/runners/worker/operations.py +++ b/sdks/python/apache_beam/runners/worker/operations.py @@ -129,6 +129,7 @@ class Operation(object): self.spec = spec self.counter_factory = counter_factory + self.execution_context = None self.consumers = collections.defaultdict(list) # These are overwritten in the legacy harness. @@ -496,7 +497,10 @@ class DoOperation(Operation): def process(self, o): with self.scoped_process_state: - self.dofn_receiver.receive(o) + delayed_application = self.dofn_receiver.receive(o) + if delayed_application: + self.execution_context.delayed_applications.append( + (self, delayed_application)) def process_timer(self, tag, windowed_timer): key, timer_data = windowed_timer.value @@ -540,6 +544,16 @@ class DoOperation(Operation): return infos +class SdfProcessElements(DoOperation): + + def process(self, o): + with self.scoped_process_state: + delayed_application = self.dofn_runner.process_with_restriction(o) + if delayed_application: + self.execution_context.delayed_applications.append( + (self, delayed_application)) + + class DoFnRunnerReceiver(Receiver): def __init__(self, dofn_runner): diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index 4eb22a8..6067181 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -272,10 +272,11 @@ class SdkWorker(object): instruction_id, request.process_bundle_descriptor_reference) as bundle_processor: with self.maybe_profile(instruction_id): - bundle_processor.process_bundle(instruction_id) + delayed_applications = bundle_processor.process_bundle(instruction_id) return beam_fn_api_pb2.InstructionResponse( instruction_id=instruction_id, process_bundle=beam_fn_api_pb2.ProcessBundleResponse( + residual_roots=delayed_applications, metrics=bundle_processor.metrics(), monitoring_infos=bundle_processor.monitoring_infos())) diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 09091ed..27b67cd 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -1029,6 +1029,14 @@ class ParDo(PTransformWithSideInputs): "expected instance of ParDo, but got %s" % self.__class__ picked_pardo_fn_data = pickler.dumps(self._pardo_fn_data()) state_specs, timer_specs = userstate.get_dofn_specs(self.fn) + from apache_beam.runners.common import DoFnSignature + is_splittable = DoFnSignature(self.fn).is_splittable_dofn() + if is_splittable: + restriction_coder = ( + DoFnSignature(self.fn).get_restriction_provider().restriction_coder()) + restriction_coder_id = context.coders.get_id(restriction_coder) + else: + restriction_coder_id = None return ( common_urns.primitives.PAR_DO.urn, beam_runner_api_pb2.ParDoPayload( @@ -1037,6 +1045,8 @@ class ParDo(PTransformWithSideInputs): spec=beam_runner_api_pb2.FunctionSpec( urn=python_urns.PICKLED_DOFN_INFO, payload=picked_pardo_fn_data)), + splittable=is_splittable, + restriction_coder_id=restriction_coder_id, 
state_specs={spec.name: spec.to_runner_api(context) for spec in state_specs}, timer_specs={spec.name: spec.to_runner_api(context)
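
For readers new to this API: the checkpointing added here is driven entirely from user code through the restriction tracker. The sketch below is adapted from the test_sdf case added in fn_api_runner_test.py in this patch; the pipeline wiring around it is illustrative only and not part of the commit.

    import apache_beam as beam
    from apache_beam.io import restriction_trackers
    from apache_beam.runners.portability import fn_api_runner


    class ExpandStringsProvider(beam.transforms.core.RestrictionProvider):
      """Restriction provider over the character offsets of a string."""
      def initial_restriction(self, element):
        return (0, len(element))

      def create_tracker(self, restriction):
        return restriction_trackers.OffsetRestrictionTracker(
            restriction[0], restriction[1])

      def split(self, element, restriction):
        start, end = restriction
        middle = start + (end - start) // 2
        return [(start, middle), (middle, end)]


    class ExpandStringsDoFn(beam.DoFn):
      def process(self, element, restriction_tracker=ExpandStringsProvider()):
        for k in range(*restriction_tracker.current_restriction()):
          if not restriction_tracker.try_claim(k):
            return
          yield element[k]
          if k % 2 == 1:
            # Checkpoint: the unclaimed remainder of the restriction is
            # handed back to the runner and resumed in a later bundle.
            restriction_tracker.defer_remainder()
            return


    with beam.Pipeline(runner=fn_api_runner.FnApiRunner()) as p:
      _ = (p
           | beam.Create(['abc', 'defghijklmno', 'pqrstuv', 'wxyz'])
           | beam.ParDo(ExpandStringsDoFn()))

As in the new test, the output is the concatenation of all input characters regardless of how often the DoFn checkpoints, since the FnApiRunner re-queues each residual as a deferred input and reruns the process stage until no residuals remain.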
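At the tracker level, defer_remainder() records a checkpoint of the unclaimed portion of the restriction, and deferred_status() hands that residual (plus an optional resume watermark) back to the SDK harness, which reports it to the runner as a DelayedBundleApplication in residual_roots. A minimal interaction, using only the OffsetRestrictionTracker methods shown in the diff above (the concrete offsets are illustrative assumptions):

    from apache_beam.io.restriction_trackers import OffsetRestrictionTracker

    tracker = OffsetRestrictionTracker(0, 10)  # restriction over offsets [0, 10)
    assert tracker.try_claim(0)                # claim offset 0 before emitting output for it
    tracker.defer_remainder()                  # checkpoint: offsets [1, 10) become the residual
    residual, watermark = tracker.deferred_status()
    # residual is an OffsetRange covering the deferred work; watermark is None
    # unless a resume watermark was passed to defer_remainder().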