[ 
https://issues.apache.org/jira/browse/BEAM-5521?focusedWorklogId=155803&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-155803
 ]

ASF GitHub Bot logged work on BEAM-5521:
----------------------------------------

                Author: ASF GitHub Bot
            Created on: 18/Oct/18 09:45
            Start Date: 18/Oct/18 09:45
    Worklog Time Spent: 10m 
      Work Description: robertwb closed pull request #6717: [BEAM-5521] Re-use bundle processors across bundles.
URL: https://github.com/apache/beam/pull/6717

This is a PR merged from a forked repository. As GitHub hides the original
diff on merge, it is displayed below for the sake of provenance:

diff --git a/sdks/python/apache_beam/metrics/cells.py b/sdks/python/apache_beam/metrics/cells.py
index b177d2014d7..1836a82eb0c 100644
--- a/sdks/python/apache_beam/metrics/cells.py
+++ b/sdks/python/apache_beam/metrics/cells.py
@@ -147,6 +147,10 @@ def __init__(self, *args):
     super(CounterCell, self).__init__(*args)
     self.value = CounterAggregator.identity_element()
 
+  def reset(self):
+    self.commit = CellCommitState()
+    self.value = CounterAggregator.identity_element()
+
   def combine(self, other):
     result = CounterCell()
     result.inc(self.value + other.value)
@@ -189,6 +193,10 @@ def __init__(self, *args):
     super(DistributionCell, self).__init__(*args)
     self.data = DistributionAggregator.identity_element()
 
+  def reset(self):
+    self.commit = CellCommitState()
+    self.data = DistributionAggregator.identity_element()
+
   def combine(self, other):
     result = DistributionCell()
     result.data = self.data.combine(other.data)
@@ -230,6 +238,10 @@ def __init__(self, *args):
     super(GaugeCell, self).__init__(*args)
     self.data = GaugeAggregator.identity_element()
 
+  def reset(self):
+    self.commit = CellCommitState()
+    self.data = GaugeAggregator.identity_element()
+
   def combine(self, other):
     result = GaugeCell()
     result.data = self.data.combine(other.data)
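
The reset() methods added above wind a cell back to its freshly-constructed
state, which is what allows a cached metric cell to be reused for the next
bundle instead of being reallocated. A minimal usage sketch, assuming only
CounterCell.inc and get_cumulative as they exist in
apache_beam.metrics.cells:

    from apache_beam.metrics.cells import CounterCell

    cell = CounterCell()
    cell.inc(3)
    assert cell.get_cumulative() == 3
    cell.reset()                       # commit state and value cleared
    assert cell.get_cumulative() == 0  # back to the identity element
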
diff --git a/sdks/python/apache_beam/metrics/execution.py b/sdks/python/apache_beam/metrics/execution.py
index 2d771394c23..5818ac2efa8 100644
--- a/sdks/python/apache_beam/metrics/execution.py
+++ b/sdks/python/apache_beam/metrics/execution.py
@@ -238,6 +238,14 @@ def to_runner_api_monitoring_infos(self, transform_id):
       ))
     return {monitoring_infos.to_key(mi) : mi for mi in all_user_metrics}
 
+  def reset(self):
+    for counter in self.counters.values():
+      counter.reset()
+    for distribution in self.distributions.values():
+      distribution.reset()
+    for gauge in self.gauges.values():
+      gauge.reset()
+
 
 class MetricUpdates(object):
   """Contains updates for several metrics.
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index e4737b4ad09..4fd0aace539 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -151,7 +151,6 @@ def __init__(self, state_handler, transform_id, tag, side_input_data, coder):
     self._element_coder = coder.wrapped_value_coder
     self._target_window_coder = coder.window_coder
     # TODO(robertwb): Limit the cache size.
-    # TODO(robertwb): Cross-bundle caching respecting cache tokens.
     self._cache = {}
 
   def __getitem__(self, window):
@@ -205,6 +204,10 @@ def is_globally_windowed(self):
     return (self._side_input_data.window_mapping_fn
             == sideinputs._global_window_mapping_fn)
 
+  def reset(self):
+    # TODO(BEAM-5428): Cross-bundle caching respecting cache tokens.
+    self._cache = {}
+
 
 class CombiningValueRuntimeState(userstate.RuntimeState):
   def __init__(self, underlying_bag_state, combinefn):
@@ -310,6 +313,10 @@ def get_state(self, state_spec, key, window):
     else:
       raise NotImplementedError(state_spec)
 
+  def reset(self):
+    # TODO(BEAM-5428): Implement cross-bundle state caching.
+    pass
+
 
 def memoize(func):
   cache = {}
@@ -342,6 +349,8 @@ def __init__(
         'fnapi-step-%s' % self.process_bundle_descriptor.id,
         self.counter_factory)
     self.ops = self.create_execution_tree(self.process_bundle_descriptor)
+    for op in self.ops.values():
+      op.setup()
 
   def create_execution_tree(self, descriptor):
 
@@ -385,6 +394,13 @@ def topological_height(transform_id):
         for transform_id in sorted(
             descriptor.transforms, key=topological_height, reverse=True)])
 
+  def reset(self):
+    self.counter_factory.reset()
+    self.state_sampler.reset()
+    # Side input caches.
+    for op in self.ops.values():
+      op.reset()
+
   def process_bundle(self, instruction_id):
     expected_inputs = []
     for op in self.ops.values():
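
With op.setup() now run once when the BundleProcessor is constructed, and
reset() clearing only per-bundle state (counters, state sampler, side-input
caches), a single processor can serve many bundles. Schematically, with a
stand-in class rather than the real BundleProcessor (whose construction
needs a descriptor, state handler and data channel factory):

    class FakeProcessor(object):
      """Stand-in: expensive to build, cheap to reset."""

      def __init__(self):
        # The real one builds its execution tree here and calls setup()
        # on every operation.
        self.bundles_served = 0

      def process_bundle(self, instruction_id):
        self.bundles_served += 1

      def reset(self):
        pass  # clear per-bundle state only; the execution tree is kept

    processor = FakeProcessor()  # construction cost paid once
    for instruction_id in ('bundle-1', 'bundle-2', 'bundle-3'):
      processor.process_bundle(instruction_id)
      processor.reset()  # so even 1-element bundles stay cheap
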
diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd
index 318ea516810..22134cd69a0 100644
--- a/sdks/python/apache_beam/runners/worker/operations.pxd
+++ b/sdks/python/apache_beam/runners/worker/operations.pxd
@@ -49,6 +49,8 @@ cdef class Operation(object):
   # TODO(robertwb): Cythonize FnHarness.
   cdef public list receivers
   cdef readonly bint debug_logging_enabled
+  # For legacy workers.
+  cdef bint setup_done
 
   cdef public step_name  # initialized lazily
 
diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py
index 61dd632aafe..20dec1d69e9 100644
--- a/sdks/python/apache_beam/runners/worker/operations.py
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -143,19 +143,30 @@ def __init__(self, name_context, spec, counter_factory, state_sampler):
     # TODO(ccy): the '-abort' state can be added when the abort is supported in
     # Operations.
     self.receivers = []
+    # Legacy workers cannot call setup() until after setting additional state
+    # on the operation.
+    self.setup_done = False
+
+  def setup(self):
+    with self.scoped_start_state:
+      self.debug_logging_enabled = logging.getLogger().isEnabledFor(
+          logging.DEBUG)
+      # Everything except WorkerSideInputSource, which is not a
+      # top-level operation, should have output_coders
+      #TODO(pabloem): Define better what step name is used here.
+      if getattr(self.spec, 'output_coders', None):
+        self.receivers = [ConsumerSet(self.counter_factory,
+                                      self.name_context.logging_name(),
+                                      i,
+                                      self.consumers[i], coder)
+                          for i, coder in enumerate(self.spec.output_coders)]
+    self.setup_done = True
 
   def start(self):
     """Start operation."""
-    self.debug_logging_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)
-    # Everything except WorkerSideInputSource, which is not a
-    # top-level operation, should have output_coders
-    #TODO(pabloem): Define better what step name is used here.
-    if getattr(self.spec, 'output_coders', None):
-      self.receivers = [ConsumerSet(self.counter_factory,
-                                    self.name_context.logging_name(),
-                                    i,
-                                    self.consumers[i], coder)
-                        for i, coder in enumerate(self.spec.output_coders)]
+    if not self.setup_done:
+      # For legacy workers.
+      self.setup()
 
   def process(self, o):
     """Process element in operation."""
@@ -165,6 +176,9 @@ def finish(self):
     """Finish operation."""
     pass
 
+  def reset(self):
+    self.metrics_container.reset()
+
   def output(self, windowed_value, output_index=0):
     cython.cast(Receiver, self.receivers[output_index]).receive(windowed_value)
 
@@ -422,9 +436,9 @@ def _read_side_inputs(self, tags_and_types):
       yield apache_sideinputs.SideInputMap(
           view_class, view_options, sideinputs.EmulatedIterable(iterator_fn))
 
-  def start(self):
+  def setup(self):
     with self.scoped_start_state:
-      super(DoOperation, self).start()
+      super(DoOperation, self).setup()
 
       # See fn_data in dataflow_runner.py
       fn, args, kwargs, tags_and_types, window_fn = (
@@ -474,6 +488,9 @@ def start(self):
                             if isinstance(self.dofn_runner, Receiver)
                             else DoFnRunnerReceiver(self.dofn_runner))
 
+  def start(self):
+    with self.scoped_start_state:
+      super(DoOperation, self).start()
       self.dofn_runner.start()
 
   def process(self, o):
@@ -490,6 +507,13 @@ def finish(self):
     with self.scoped_finish_state:
       self.dofn_runner.finish()
 
+  def reset(self):
+    super(DoOperation, self).reset()
+    for side_input_map in self.side_input_maps:
+      side_input_map.reset()
+    if self.user_state_context:
+      self.user_state_context.reset()
+
   def progress_metrics(self):
     metrics = super(DoOperation, self).progress_metrics()
     if self.tagged_receivers:
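
The heart of this file's change is splitting the old start() into a
one-time setup() and a per-bundle start(), with the setup_done flag
covering legacy workers that call start() directly without an explicit
setup(). The guard in miniature (generic Python, not the Beam classes):

    class Operation(object):
      def __init__(self):
        self.setup_done = False

      def setup(self):
        # Expensive one-time initialization (building receivers,
        # constructing the DoFn runner) now lives here.
        self.setup_done = True

      def start(self):
        if not self.setup_done:
          self.setup()  # legacy path: workers that never call setup()
        # Only per-bundle work (e.g. dofn_runner.start()) remains here.

    op = Operation()
    op.start()  # first bundle on a legacy worker also runs setup()
    assert op.setup_done
    op.start()  # later bundles skip the one-time work
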
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 7ecdfee2e9b..fa811262183 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -21,6 +21,7 @@
 from __future__ import print_function
 
 import abc
+import collections
 import contextlib
 import logging
 import queue
@@ -204,7 +205,8 @@ def __init__(self, state_handler_factory, data_channel_factory, fns):
     self.fns = fns
     self.state_handler_factory = state_handler_factory
     self.data_channel_factory = data_channel_factory
-    self.bundle_processors = {}
+    self.active_bundle_processors = {}
+    self.cached_bundle_processors = collections.defaultdict(list)
 
   def do_instruction(self, request):
     request_type = request.WhichOneof('request')
@@ -223,29 +225,43 @@ def register(self, request, instruction_id):
         register=beam_fn_api_pb2.RegisterResponse())
 
   def process_bundle(self, request, instruction_id):
-    process_bundle_desc = self.fns[request.process_bundle_descriptor_reference]
-    state_handler = self.state_handler_factory.create_state_handler(
-        process_bundle_desc.state_api_service_descriptor)
-    self.bundle_processors[
-        instruction_id] = processor = bundle_processor.BundleProcessor(
-            process_bundle_desc,
-            state_handler,
-            self.data_channel_factory)
+    with self.get_bundle_processor(
+        instruction_id,
+        request.process_bundle_descriptor_reference) as bundle_processor:
+      bundle_processor.process_bundle(instruction_id)
+      return beam_fn_api_pb2.InstructionResponse(
+          instruction_id=instruction_id,
+          process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
+              metrics=bundle_processor.metrics(),
+              monitoring_infos=bundle_processor.monitoring_infos()))
+
+  @contextlib.contextmanager
+  def get_bundle_processor(self, instruction_id, bundle_descriptor_id):
+    try:
+      # pop() is threadsafe
+      processor = self.cached_bundle_processors[bundle_descriptor_id].pop()
+      state_handler = processor.state_handler
+    except IndexError:
+      process_bundle_desc = self.fns[bundle_descriptor_id]
+      state_handler = self.state_handler_factory.create_state_handler(
+          process_bundle_desc.state_api_service_descriptor)
+      processor = bundle_processor.BundleProcessor(
+          process_bundle_desc,
+          state_handler,
+          self.data_channel_factory)
     try:
+      self.active_bundle_processors[instruction_id] = processor
       with state_handler.process_instruction_id(instruction_id):
-        processor.process_bundle(instruction_id)
+        yield processor
     finally:
-      del self.bundle_processors[instruction_id]
-
-    return beam_fn_api_pb2.InstructionResponse(
-        instruction_id=instruction_id,
-        process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
-            metrics=processor.metrics(),
-            monitoring_infos=processor.monitoring_infos()))
+      del self.active_bundle_processors[instruction_id]
+    # Outside the finally block as we only want to re-use on success.
+    processor.reset()
+    self.cached_bundle_processors[bundle_descriptor_id].append(processor)
 
   def process_bundle_progress(self, request, instruction_id):
     # It is an error to get progress for a not-in-flight bundle.
-    processor = self.bundle_processors.get(request.instruction_reference)
+    processor = self.active_bundle_processors.get(request.instruction_reference)
     return beam_fn_api_pb2.InstructionResponse(
         instruction_id=instruction_id,
         process_bundle_progress=beam_fn_api_pb2.ProcessBundleProgressResponse(
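
get_bundle_processor() is in effect a free-list keyed by descriptor id:
pop a cached processor when one exists (list.pop() is atomic under the
GIL, hence the threadsafety comment), otherwise build one, and re-cache it
only after a successful bundle; because reset() sits outside the finally
block, a processor whose bundle raised is dropped rather than recycled in
an unknown state. The same pattern in isolation (a generic sketch with a
hypothetical factory argument, not the Beam API):

    import collections
    import contextlib

    cache = collections.defaultdict(list)

    @contextlib.contextmanager
    def borrowed_processor(descriptor_id, make_processor):
      try:
        processor = cache[descriptor_id].pop()  # reuse when possible
      except IndexError:
        processor = make_processor(descriptor_id)
      yield processor
      # Reached only if the with-block exited cleanly; a failed
      # processor is deliberately not returned to the cache.
      processor.reset()
      cache[descriptor_id].append(processor)

    class Recyclable(object):
      def reset(self):
        pass

    with borrowed_processor('desc', lambda _id: Recyclable()) as p:
      pass  # process the bundle here
    assert len(cache['desc']) == 1  # back on the free list after success
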
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
index 73fc4be2e79..010810be16a 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
+++ b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
@@ -154,6 +154,11 @@ cdef class StateSampler(object):
     # pythread doesn't support conditions.
     self.sampling_thread.join()
 
+  def reset(self):
+    for state in self.scoped_states_by_index:
+      (<ScopedState>state)._nsecs = 0
+    self.started = self.finished = False
+
   def current_state(self):
     return self.scoped_states_by_index[self.current_state_index]
 
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_slow.py b/sdks/python/apache_beam/runners/worker/statesampler_slow.py
index 196b35d6003..00918285aee 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_slow.py
+++ b/sdks/python/apache_beam/runners/worker/statesampler_slow.py
@@ -65,6 +65,10 @@ def start(self):
   def stop(self):
     pass
 
+  def reset(self):
+    for state in self._states_by_name.values():
+      state.nsecs = 0
+
 
 class ScopedState(object):
 
diff --git a/sdks/python/apache_beam/utils/counters.py b/sdks/python/apache_beam/utils/counters.py
index 1df81949c77..f924853b74c 100644
--- a/sdks/python/apache_beam/utils/counters.py
+++ b/sdks/python/apache_beam/utils/counters.py
@@ -155,6 +155,9 @@ def __init__(self, name, combine_fn):
   def update(self, value):
     self.accumulator = self._add_input(self.accumulator, value)
 
+  def reset(self):
+    self.accumulator = self.combine_fn.create_accumulator()
+
   def value(self):
     return self.combine_fn.extract_output(self.accumulator)
 
@@ -175,11 +178,15 @@ class AccumulatorCombineFnCounter(Counter):
   def __init__(self, name, combine_fn):
     assert isinstance(combine_fn, cy_combiners.AccumulatorCombineFn)
     super(AccumulatorCombineFnCounter, self).__init__(name, combine_fn)
-    self._fast_add_input = self.accumulator.add_input
+    self.reset()
 
   def update(self, value):
     self._fast_add_input(value)
 
+  def reset(self):
+    self.accumulator = self.combine_fn.create_accumulator()
+    self._fast_add_input = self.accumulator.add_input
+
 
 class CounterFactory(object):
   """Keeps track of unique counters."""
@@ -215,6 +222,12 @@ def get_counter(self, name, combine_fn):
         self.counters[name] = counter
       return counter
 
+  def reset(self):
+    # Counters are cached in state sampler states.
+    with self._lock:
+      for counter in self.counters.values():
+        counter.reset()
+
   def get_counters(self):
     """Returns the current set of counters.
 
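
One subtlety in AccumulatorCombineFnCounter.reset() above: _fast_add_input
is a method bound to the accumulator object itself, so swapping in a fresh
accumulator without re-binding it would keep routing updates to the
discarded object. A stripped-down demonstration of why the two must be
replaced together:

    class Acc(object):
      def __init__(self):
        self.total = 0

      def add_input(self, value):
        self.total += value

    acc = Acc()
    fast_add = acc.add_input  # bound to *this* accumulator instance
    acc = Acc()               # replace the accumulator...
    fast_add(5)               # ...the old binding still gets the update
    assert acc.total == 0     # hence reset() re-creates both at once
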


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Issue Time Tracking
-------------------

    Worklog Id:     (was: 155803)
    Time Spent: 40m  (was: 0.5h)

> Cache execution trees in SDK worker
> -----------------------------------
>
>                 Key: BEAM-5521
>                 URL: https://issues.apache.org/jira/browse/BEAM-5521
>             Project: Beam
>          Issue Type: Improvement
>          Components: sdk-py-harness
>            Reporter: Robert Bradshaw
>            Assignee: Robert Bradshaw
>            Priority: Major
>              Labels: portability-flink
>          Time Spent: 40m
>  Remaining Estimate: 0h
>
> Currently they are reconstructed from the protos for every bundle, which is
> expensive (especially for 1-element bundles in streaming Flink).
> Care should be taken to ensure the objects can be re-used.



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)
