[ https://issues.apache.org/jira/browse/BEAM-5521?focusedWorklogId=155803&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-155803 ]
ASF GitHub Bot logged work on BEAM-5521: ---------------------------------------- Author: ASF GitHub Bot Created on: 18/Oct/18 09:45 Start Date: 18/Oct/18 09:45 Worklog Time Spent: 10m Work Description: robertwb closed pull request #6717: [BEAM-5521] Re-use bundle processors across bundles. URL: https://github.com/apache/beam/pull/6717 This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/sdks/python/apache_beam/metrics/cells.py b/sdks/python/apache_beam/metrics/cells.py index b177d2014d7..1836a82eb0c 100644 --- a/sdks/python/apache_beam/metrics/cells.py +++ b/sdks/python/apache_beam/metrics/cells.py @@ -147,6 +147,10 @@ def __init__(self, *args): super(CounterCell, self).__init__(*args) self.value = CounterAggregator.identity_element() + def reset(self): + self.commit = CellCommitState() + self.value = CounterAggregator.identity_element() + def combine(self, other): result = CounterCell() result.inc(self.value + other.value) @@ -189,6 +193,10 @@ def __init__(self, *args): super(DistributionCell, self).__init__(*args) self.data = DistributionAggregator.identity_element() + def reset(self): + self.commit = CellCommitState() + self.data = DistributionAggregator.identity_element() + def combine(self, other): result = DistributionCell() result.data = self.data.combine(other.data) @@ -230,6 +238,10 @@ def __init__(self, *args): super(GaugeCell, self).__init__(*args) self.data = GaugeAggregator.identity_element() + def reset(self): + self.commit = CellCommitState() + self.data = GaugeAggregator.identity_element() + def combine(self, other): result = GaugeCell() result.data = self.data.combine(other.data) diff --git a/sdks/python/apache_beam/metrics/execution.py b/sdks/python/apache_beam/metrics/execution.py index 2d771394c23..5818ac2efa8 100644 --- 
a/sdks/python/apache_beam/metrics/execution.py +++ b/sdks/python/apache_beam/metrics/execution.py @@ -238,6 +238,14 @@ def to_runner_api_monitoring_infos(self, transform_id): )) return {monitoring_infos.to_key(mi) : mi for mi in all_user_metrics} + def reset(self): + for counter in self.counters.values(): + counter.reset() + for distribution in self.distributions.values(): + distribution.reset() + for gauge in self.gauges.values(): + gauge.reset() + class MetricUpdates(object): """Contains updates for several metrics. diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index e4737b4ad09..4fd0aace539 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -151,7 +151,6 @@ def __init__(self, state_handler, transform_id, tag, side_input_data, coder): self._element_coder = coder.wrapped_value_coder self._target_window_coder = coder.window_coder # TODO(robertwb): Limit the cache size. - # TODO(robertwb): Cross-bundle caching respecting cache tokens. self._cache = {} def __getitem__(self, window): @@ -205,6 +204,10 @@ def is_globally_windowed(self): return (self._side_input_data.window_mapping_fn == sideinputs._global_window_mapping_fn) + def reset(self): + # TODO(BEAM-5428): Cross-bundle caching respecting cache tokens. + self._cache = {} + class CombiningValueRuntimeState(userstate.RuntimeState): def __init__(self, underlying_bag_state, combinefn): @@ -310,6 +313,10 @@ def get_state(self, state_spec, key, window): else: raise NotImplementedError(state_spec) + def reset(self): + # TODO(BEAM-5428): Implement cross-bundle state caching. 
+ pass + def memoize(func): cache = {} @@ -342,6 +349,8 @@ def __init__( 'fnapi-step-%s' % self.process_bundle_descriptor.id, self.counter_factory) self.ops = self.create_execution_tree(self.process_bundle_descriptor) + for op in self.ops.values(): + op.setup() def create_execution_tree(self, descriptor): @@ -385,6 +394,13 @@ def topological_height(transform_id): for transform_id in sorted( descriptor.transforms, key=topological_height, reverse=True)]) + def reset(self): + self.counter_factory.reset() + self.state_sampler.reset() + # Side input caches. + for op in self.ops.values(): + op.reset() + def process_bundle(self, instruction_id): expected_inputs = [] for op in self.ops.values(): diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd index 318ea516810..22134cd69a0 100644 --- a/sdks/python/apache_beam/runners/worker/operations.pxd +++ b/sdks/python/apache_beam/runners/worker/operations.pxd @@ -49,6 +49,8 @@ cdef class Operation(object): # TODO(robertwb): Cythonize FnHarness. cdef public list receivers cdef readonly bint debug_logging_enabled + # For legacy workers. + cdef bint setup_done cdef public step_name # initialized lazily diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py index 61dd632aafe..20dec1d69e9 100644 --- a/sdks/python/apache_beam/runners/worker/operations.py +++ b/sdks/python/apache_beam/runners/worker/operations.py @@ -143,19 +143,30 @@ def __init__(self, name_context, spec, counter_factory, state_sampler): # TODO(ccy): the '-abort' state can be added when the abort is supported in # Operations. self.receivers = [] + # Legacy workers cannot call setup() until after setting additional state + # on the operation. 
+ self.setup_done = False + + def setup(self): + with self.scoped_start_state: + self.debug_logging_enabled = logging.getLogger().isEnabledFor( + logging.DEBUG) + # Everything except WorkerSideInputSource, which is not a + # top-level operation, should have output_coders + #TODO(pabloem): Define better what step name is used here. + if getattr(self.spec, 'output_coders', None): + self.receivers = [ConsumerSet(self.counter_factory, + self.name_context.logging_name(), + i, + self.consumers[i], coder) + for i, coder in enumerate(self.spec.output_coders)] + self.setup_done = True def start(self): """Start operation.""" - self.debug_logging_enabled = logging.getLogger().isEnabledFor(logging.DEBUG) - # Everything except WorkerSideInputSource, which is not a - # top-level operation, should have output_coders - #TODO(pabloem): Define better what step name is used here. - if getattr(self.spec, 'output_coders', None): - self.receivers = [ConsumerSet(self.counter_factory, - self.name_context.logging_name(), - i, - self.consumers[i], coder) - for i, coder in enumerate(self.spec.output_coders)] + if not self.setup_done: + # For legacy workers. 
+ self.setup() def process(self, o): """Process element in operation.""" @@ -165,6 +176,9 @@ def finish(self): """Finish operation.""" pass + def reset(self): + self.metrics_container.reset() + def output(self, windowed_value, output_index=0): cython.cast(Receiver, self.receivers[output_index]).receive(windowed_value) @@ -422,9 +436,9 @@ def _read_side_inputs(self, tags_and_types): yield apache_sideinputs.SideInputMap( view_class, view_options, sideinputs.EmulatedIterable(iterator_fn)) - def start(self): + def setup(self): with self.scoped_start_state: - super(DoOperation, self).start() + super(DoOperation, self).setup() # See fn_data in dataflow_runner.py fn, args, kwargs, tags_and_types, window_fn = ( @@ -474,6 +488,9 @@ def start(self): if isinstance(self.dofn_runner, Receiver) else DoFnRunnerReceiver(self.dofn_runner)) + def start(self): + with self.scoped_start_state: + super(DoOperation, self).start() self.dofn_runner.start() def process(self, o): @@ -490,6 +507,13 @@ def finish(self): with self.scoped_finish_state: self.dofn_runner.finish() + def reset(self): + super(DoOperation, self).reset() + for side_input_map in self.side_input_maps: + side_input_map.reset() + if self.user_state_context: + self.user_state_context.reset() + def progress_metrics(self): metrics = super(DoOperation, self).progress_metrics() if self.tagged_receivers: diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index 7ecdfee2e9b..fa811262183 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -21,6 +21,7 @@ from __future__ import print_function import abc +import collections import contextlib import logging import queue @@ -204,7 +205,8 @@ def __init__(self, state_handler_factory, data_channel_factory, fns): self.fns = fns self.state_handler_factory = state_handler_factory self.data_channel_factory = data_channel_factory - 
self.bundle_processors = {} + self.active_bundle_processors = {} + self.cached_bundle_processors = collections.defaultdict(list) def do_instruction(self, request): request_type = request.WhichOneof('request') @@ -223,29 +225,43 @@ def register(self, request, instruction_id): register=beam_fn_api_pb2.RegisterResponse()) def process_bundle(self, request, instruction_id): - process_bundle_desc = self.fns[request.process_bundle_descriptor_reference] - state_handler = self.state_handler_factory.create_state_handler( - process_bundle_desc.state_api_service_descriptor) - self.bundle_processors[ - instruction_id] = processor = bundle_processor.BundleProcessor( - process_bundle_desc, - state_handler, - self.data_channel_factory) + with self.get_bundle_processor( + instruction_id, + request.process_bundle_descriptor_reference) as bundle_processor: + bundle_processor.process_bundle(instruction_id) + return beam_fn_api_pb2.InstructionResponse( + instruction_id=instruction_id, + process_bundle=beam_fn_api_pb2.ProcessBundleResponse( + metrics=bundle_processor.metrics(), + monitoring_infos=bundle_processor.monitoring_infos())) + + @contextlib.contextmanager + def get_bundle_processor(self, instruction_id, bundle_descriptor_id): + try: + # pop() is threadsafe + processor = self.cached_bundle_processors[bundle_descriptor_id].pop() + state_handler = processor.state_handler + except IndexError: + process_bundle_desc = self.fns[bundle_descriptor_id] + state_handler = self.state_handler_factory.create_state_handler( + process_bundle_desc.state_api_service_descriptor) + processor = bundle_processor.BundleProcessor( + process_bundle_desc, + state_handler, + self.data_channel_factory) try: + self.active_bundle_processors[instruction_id] = processor with state_handler.process_instruction_id(instruction_id): - processor.process_bundle(instruction_id) + yield processor finally: - del self.bundle_processors[instruction_id] - - return beam_fn_api_pb2.InstructionResponse( - 
instruction_id=instruction_id, - process_bundle=beam_fn_api_pb2.ProcessBundleResponse( - metrics=processor.metrics(), - monitoring_infos=processor.monitoring_infos())) + del self.active_bundle_processors[instruction_id] + # Outside the finally block as we only want to re-use on success. + processor.reset() + self.cached_bundle_processors[bundle_descriptor_id].append(processor) def process_bundle_progress(self, request, instruction_id): # It is an error to get progress for a not-in-flight bundle. - processor = self.bundle_processors.get(request.instruction_reference) + processor = self.active_bundle_processors.get(request.instruction_reference) return beam_fn_api_pb2.InstructionResponse( instruction_id=instruction_id, process_bundle_progress=beam_fn_api_pb2.ProcessBundleProgressResponse( diff --git a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx index 73fc4be2e79..010810be16a 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx +++ b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx @@ -154,6 +154,11 @@ cdef class StateSampler(object): # pythread doesn't support conditions. 
self.sampling_thread.join() + def reset(self): + for state in self.scoped_states_by_index: + (<ScopedState>state)._nsecs = 0 + self.started = self.finished = False + def current_state(self): return self.scoped_states_by_index[self.current_state_index] diff --git a/sdks/python/apache_beam/runners/worker/statesampler_slow.py b/sdks/python/apache_beam/runners/worker/statesampler_slow.py index 196b35d6003..00918285aee 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler_slow.py +++ b/sdks/python/apache_beam/runners/worker/statesampler_slow.py @@ -65,6 +65,10 @@ def start(self): def stop(self): pass + def reset(self): + for state in self._states_by_name.values(): + state.nsecs = 0 + class ScopedState(object): diff --git a/sdks/python/apache_beam/utils/counters.py b/sdks/python/apache_beam/utils/counters.py index 1df81949c77..f924853b74c 100644 --- a/sdks/python/apache_beam/utils/counters.py +++ b/sdks/python/apache_beam/utils/counters.py @@ -155,6 +155,9 @@ def __init__(self, name, combine_fn): def update(self, value): self.accumulator = self._add_input(self.accumulator, value) + def reset(self, value): + self.accumulator = self.combine_fn.create_accumulator() + def value(self): return self.combine_fn.extract_output(self.accumulator) @@ -175,11 +178,15 @@ class AccumulatorCombineFnCounter(Counter): def __init__(self, name, combine_fn): assert isinstance(combine_fn, cy_combiners.AccumulatorCombineFn) super(AccumulatorCombineFnCounter, self).__init__(name, combine_fn) - self._fast_add_input = self.accumulator.add_input + self.reset() def update(self, value): self._fast_add_input(value) + def reset(self): + self.accumulator = self.combine_fn.create_accumulator() + self._fast_add_input = self.accumulator.add_input + class CounterFactory(object): """Keeps track of unique counters.""" @@ -215,6 +222,12 @@ def get_counter(self, name, combine_fn): self.counters[name] = counter return counter + def reset(self): + # Counters are cached in state sampler states. 
+ with self._lock: + for counter in self.counters.values(): + counter.reset() + def get_counters(self): """Returns the current set of counters. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org Issue Time Tracking ------------------- Worklog Id: (was: 155803) Time Spent: 40m (was: 0.5h) > Cache execution trees in SDK worker > ----------------------------------- > > Key: BEAM-5521 > URL: https://issues.apache.org/jira/browse/BEAM-5521 > Project: Beam > Issue Type: Improvement > Components: sdk-py-harness > Reporter: Robert Bradshaw > Assignee: Robert Bradshaw > Priority: Major > Labels: portability-flink > Time Spent: 40m > Remaining Estimate: 0h > > Currently they are re-constructed from the protos for every bundle, which is > expensive (especially for 1-element bundles in streaming flink). > Care should be taken to ensure the objects can be re-used. -- This message was sent by Atlassian JIRA (v7.6.3#76005)