[
https://issues.apache.org/jira/browse/BEAM-5521?focusedWorklogId=155803&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-155803
]
ASF GitHub Bot logged work on BEAM-5521:
----------------------------------------
Author: ASF GitHub Bot
Created on: 18/Oct/18 09:45
Start Date: 18/Oct/18 09:45
Worklog Time Spent: 10m
Work Description: robertwb closed pull request #6717: [BEAM-5521] Re-use
bundle processors across bundles.
URL: https://github.com/apache/beam/pull/6717
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git a/sdks/python/apache_beam/metrics/cells.py
b/sdks/python/apache_beam/metrics/cells.py
index b177d2014d7..1836a82eb0c 100644
--- a/sdks/python/apache_beam/metrics/cells.py
+++ b/sdks/python/apache_beam/metrics/cells.py
@@ -147,6 +147,10 @@ def __init__(self, *args):
super(CounterCell, self).__init__(*args)
self.value = CounterAggregator.identity_element()
+ def reset(self):
+ self.commit = CellCommitState()
+ self.value = CounterAggregator.identity_element()
+
def combine(self, other):
result = CounterCell()
result.inc(self.value + other.value)
@@ -189,6 +193,10 @@ def __init__(self, *args):
super(DistributionCell, self).__init__(*args)
self.data = DistributionAggregator.identity_element()
+ def reset(self):
+ self.commit = CellCommitState()
+ self.data = DistributionAggregator.identity_element()
+
def combine(self, other):
result = DistributionCell()
result.data = self.data.combine(other.data)
@@ -230,6 +238,10 @@ def __init__(self, *args):
super(GaugeCell, self).__init__(*args)
self.data = GaugeAggregator.identity_element()
+ def reset(self):
+ self.commit = CellCommitState()
+ self.data = GaugeAggregator.identity_element()
+
def combine(self, other):
result = GaugeCell()
result.data = self.data.combine(other.data)
diff --git a/sdks/python/apache_beam/metrics/execution.py
b/sdks/python/apache_beam/metrics/execution.py
index 2d771394c23..5818ac2efa8 100644
--- a/sdks/python/apache_beam/metrics/execution.py
+++ b/sdks/python/apache_beam/metrics/execution.py
@@ -238,6 +238,14 @@ def to_runner_api_monitoring_infos(self, transform_id):
))
return {monitoring_infos.to_key(mi) : mi for mi in all_user_metrics}
+ def reset(self):
+ for counter in self.counters.values():
+ counter.reset()
+ for distribution in self.distributions.values():
+ distribution.reset()
+ for gauge in self.gauges.values():
+ gauge.reset()
+
class MetricUpdates(object):
"""Contains updates for several metrics.
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py
b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index e4737b4ad09..4fd0aace539 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -151,7 +151,6 @@ def __init__(self, state_handler, transform_id, tag,
side_input_data, coder):
self._element_coder = coder.wrapped_value_coder
self._target_window_coder = coder.window_coder
# TODO(robertwb): Limit the cache size.
- # TODO(robertwb): Cross-bundle caching respecting cache tokens.
self._cache = {}
def __getitem__(self, window):
@@ -205,6 +204,10 @@ def is_globally_windowed(self):
return (self._side_input_data.window_mapping_fn
== sideinputs._global_window_mapping_fn)
+ def reset(self):
+ # TODO(BEAM-5428): Cross-bundle caching respecting cache tokens.
+ self._cache = {}
+
class CombiningValueRuntimeState(userstate.RuntimeState):
def __init__(self, underlying_bag_state, combinefn):
@@ -310,6 +313,10 @@ def get_state(self, state_spec, key, window):
else:
raise NotImplementedError(state_spec)
+ def reset(self):
+ # TODO(BEAM-5428): Implement cross-bundle state caching.
+ pass
+
def memoize(func):
cache = {}
@@ -342,6 +349,8 @@ def __init__(
'fnapi-step-%s' % self.process_bundle_descriptor.id,
self.counter_factory)
self.ops = self.create_execution_tree(self.process_bundle_descriptor)
+ for op in self.ops.values():
+ op.setup()
def create_execution_tree(self, descriptor):
@@ -385,6 +394,13 @@ def topological_height(transform_id):
for transform_id in sorted(
descriptor.transforms, key=topological_height, reverse=True)])
+ def reset(self):
+ self.counter_factory.reset()
+ self.state_sampler.reset()
+ # Side input caches.
+ for op in self.ops.values():
+ op.reset()
+
def process_bundle(self, instruction_id):
expected_inputs = []
for op in self.ops.values():
diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd
b/sdks/python/apache_beam/runners/worker/operations.pxd
index 318ea516810..22134cd69a0 100644
--- a/sdks/python/apache_beam/runners/worker/operations.pxd
+++ b/sdks/python/apache_beam/runners/worker/operations.pxd
@@ -49,6 +49,8 @@ cdef class Operation(object):
# TODO(robertwb): Cythonize FnHarness.
cdef public list receivers
cdef readonly bint debug_logging_enabled
+ # For legacy workers.
+ cdef bint setup_done
cdef public step_name # initialized lazily
diff --git a/sdks/python/apache_beam/runners/worker/operations.py
b/sdks/python/apache_beam/runners/worker/operations.py
index 61dd632aafe..20dec1d69e9 100644
--- a/sdks/python/apache_beam/runners/worker/operations.py
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -143,19 +143,30 @@ def __init__(self, name_context, spec, counter_factory,
state_sampler):
# TODO(ccy): the '-abort' state can be added when the abort is supported in
# Operations.
self.receivers = []
+ # Legacy workers cannot call setup() until after setting additional state
+ # on the operation.
+ self.setup_done = False
+
+ def setup(self):
+ with self.scoped_start_state:
+ self.debug_logging_enabled = logging.getLogger().isEnabledFor(
+ logging.DEBUG)
+ # Everything except WorkerSideInputSource, which is not a
+ # top-level operation, should have output_coders
+ #TODO(pabloem): Define better what step name is used here.
+ if getattr(self.spec, 'output_coders', None):
+ self.receivers = [ConsumerSet(self.counter_factory,
+ self.name_context.logging_name(),
+ i,
+ self.consumers[i], coder)
+ for i, coder in enumerate(self.spec.output_coders)]
+ self.setup_done = True
def start(self):
"""Start operation."""
-    self.debug_logging_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)
- # Everything except WorkerSideInputSource, which is not a
- # top-level operation, should have output_coders
- #TODO(pabloem): Define better what step name is used here.
- if getattr(self.spec, 'output_coders', None):
- self.receivers = [ConsumerSet(self.counter_factory,
- self.name_context.logging_name(),
- i,
- self.consumers[i], coder)
- for i, coder in enumerate(self.spec.output_coders)]
+ if not self.setup_done:
+ # For legacy workers.
+ self.setup()
def process(self, o):
"""Process element in operation."""
@@ -165,6 +176,9 @@ def finish(self):
"""Finish operation."""
pass
+ def reset(self):
+ self.metrics_container.reset()
+
def output(self, windowed_value, output_index=0):
cython.cast(Receiver, self.receivers[output_index]).receive(windowed_value)
@@ -422,9 +436,9 @@ def _read_side_inputs(self, tags_and_types):
yield apache_sideinputs.SideInputMap(
view_class, view_options, sideinputs.EmulatedIterable(iterator_fn))
- def start(self):
+ def setup(self):
with self.scoped_start_state:
- super(DoOperation, self).start()
+ super(DoOperation, self).setup()
# See fn_data in dataflow_runner.py
fn, args, kwargs, tags_and_types, window_fn = (
@@ -474,6 +488,9 @@ def start(self):
if isinstance(self.dofn_runner, Receiver)
else DoFnRunnerReceiver(self.dofn_runner))
+ def start(self):
+ with self.scoped_start_state:
+ super(DoOperation, self).start()
self.dofn_runner.start()
def process(self, o):
@@ -490,6 +507,13 @@ def finish(self):
with self.scoped_finish_state:
self.dofn_runner.finish()
+ def reset(self):
+ super(DoOperation, self).reset()
+ for side_input_map in self.side_input_maps:
+ side_input_map.reset()
+ if self.user_state_context:
+ self.user_state_context.reset()
+
def progress_metrics(self):
metrics = super(DoOperation, self).progress_metrics()
if self.tagged_receivers:
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py
b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 7ecdfee2e9b..fa811262183 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -21,6 +21,7 @@
from __future__ import print_function
import abc
+import collections
import contextlib
import logging
import queue
@@ -204,7 +205,8 @@ def __init__(self, state_handler_factory,
data_channel_factory, fns):
self.fns = fns
self.state_handler_factory = state_handler_factory
self.data_channel_factory = data_channel_factory
- self.bundle_processors = {}
+ self.active_bundle_processors = {}
+ self.cached_bundle_processors = collections.defaultdict(list)
def do_instruction(self, request):
request_type = request.WhichOneof('request')
@@ -223,29 +225,43 @@ def register(self, request, instruction_id):
register=beam_fn_api_pb2.RegisterResponse())
def process_bundle(self, request, instruction_id):
- process_bundle_desc = self.fns[request.process_bundle_descriptor_reference]
- state_handler = self.state_handler_factory.create_state_handler(
- process_bundle_desc.state_api_service_descriptor)
- self.bundle_processors[
- instruction_id] = processor = bundle_processor.BundleProcessor(
- process_bundle_desc,
- state_handler,
- self.data_channel_factory)
+ with self.get_bundle_processor(
+ instruction_id,
+ request.process_bundle_descriptor_reference) as bundle_processor:
+ bundle_processor.process_bundle(instruction_id)
+ return beam_fn_api_pb2.InstructionResponse(
+ instruction_id=instruction_id,
+ process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
+ metrics=bundle_processor.metrics(),
+ monitoring_infos=bundle_processor.monitoring_infos()))
+
+ @contextlib.contextmanager
+ def get_bundle_processor(self, instruction_id, bundle_descriptor_id):
+ try:
+ # pop() is threadsafe
+ processor = self.cached_bundle_processors[bundle_descriptor_id].pop()
+ state_handler = processor.state_handler
+ except IndexError:
+ process_bundle_desc = self.fns[bundle_descriptor_id]
+ state_handler = self.state_handler_factory.create_state_handler(
+ process_bundle_desc.state_api_service_descriptor)
+ processor = bundle_processor.BundleProcessor(
+ process_bundle_desc,
+ state_handler,
+ self.data_channel_factory)
try:
+ self.active_bundle_processors[instruction_id] = processor
with state_handler.process_instruction_id(instruction_id):
- processor.process_bundle(instruction_id)
+ yield processor
finally:
- del self.bundle_processors[instruction_id]
-
- return beam_fn_api_pb2.InstructionResponse(
- instruction_id=instruction_id,
- process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
- metrics=processor.metrics(),
- monitoring_infos=processor.monitoring_infos()))
+ del self.active_bundle_processors[instruction_id]
+ # Outside the finally block as we only want to re-use on success.
+ processor.reset()
+ self.cached_bundle_processors[bundle_descriptor_id].append(processor)
def process_bundle_progress(self, request, instruction_id):
# It is an error to get progress for a not-in-flight bundle.
- processor = self.bundle_processors.get(request.instruction_reference)
+    processor = self.active_bundle_processors.get(request.instruction_reference)
return beam_fn_api_pb2.InstructionResponse(
instruction_id=instruction_id,
process_bundle_progress=beam_fn_api_pb2.ProcessBundleProgressResponse(
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
index 73fc4be2e79..010810be16a 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
+++ b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
@@ -154,6 +154,11 @@ cdef class StateSampler(object):
# pythread doesn't support conditions.
self.sampling_thread.join()
+ def reset(self):
+ for state in self.scoped_states_by_index:
+ (<ScopedState>state)._nsecs = 0
+ self.started = self.finished = False
+
def current_state(self):
return self.scoped_states_by_index[self.current_state_index]
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_slow.py
b/sdks/python/apache_beam/runners/worker/statesampler_slow.py
index 196b35d6003..00918285aee 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_slow.py
+++ b/sdks/python/apache_beam/runners/worker/statesampler_slow.py
@@ -65,6 +65,10 @@ def start(self):
def stop(self):
pass
+ def reset(self):
+ for state in self._states_by_name.values():
+ state.nsecs = 0
+
class ScopedState(object):
diff --git a/sdks/python/apache_beam/utils/counters.py
b/sdks/python/apache_beam/utils/counters.py
index 1df81949c77..f924853b74c 100644
--- a/sdks/python/apache_beam/utils/counters.py
+++ b/sdks/python/apache_beam/utils/counters.py
@@ -155,6 +155,9 @@ def __init__(self, name, combine_fn):
def update(self, value):
self.accumulator = self._add_input(self.accumulator, value)
+ def reset(self, value):
+ self.accumulator = self.combine_fn.create_accumulator()
+
def value(self):
return self.combine_fn.extract_output(self.accumulator)
@@ -175,11 +178,15 @@ class AccumulatorCombineFnCounter(Counter):
def __init__(self, name, combine_fn):
assert isinstance(combine_fn, cy_combiners.AccumulatorCombineFn)
super(AccumulatorCombineFnCounter, self).__init__(name, combine_fn)
- self._fast_add_input = self.accumulator.add_input
+ self.reset()
def update(self, value):
self._fast_add_input(value)
+ def reset(self):
+ self.accumulator = self.combine_fn.create_accumulator()
+ self._fast_add_input = self.accumulator.add_input
+
class CounterFactory(object):
"""Keeps track of unique counters."""
@@ -215,6 +222,12 @@ def get_counter(self, name, combine_fn):
self.counters[name] = counter
return counter
+ def reset(self):
+ # Counters are cached in state sampler states.
+ with self._lock:
+ for counter in self.counters.values():
+ counter.reset()
+
def get_counters(self):
"""Returns the current set of counters.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
Issue Time Tracking
-------------------
Worklog Id: (was: 155803)
Time Spent: 40m (was: 0.5h)
> Cache execution trees in SDK worker
> -----------------------------------
>
> Key: BEAM-5521
> URL: https://issues.apache.org/jira/browse/BEAM-5521
> Project: Beam
> Issue Type: Improvement
> Components: sdk-py-harness
> Reporter: Robert Bradshaw
> Assignee: Robert Bradshaw
> Priority: Major
> Labels: portability-flink
> Time Spent: 40m
> Remaining Estimate: 0h
>
> Currently they are re-constructed from the protos for every bundle, which is
> expensive (especially for 1-element bundles in streaming Flink).
> Care should be taken to ensure the objects can be re-used.
--
This message was sent by Atlassian JIRA
(v7.6.3#76005)