This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new d38645a  [BEAM-6778] Enable Bundle Finalization in Python SDK harness over FnApi (#7937)
d38645a is described below

commit d38645ae8758d834c3e819b715a66dd82c78f6d4
Author: Boyuan Zhang <36090911+boyua...@users.noreply.github.com>
AuthorDate: Thu Mar 21 01:02:39 2019 -0700

    [BEAM-6778] Enable Bundle Finalization in Python SDK harness over FnApi (#7937)
---
 sdks/python/apache_beam/runners/common.pxd         |  3 +-
 sdks/python/apache_beam/runners/common.py          | 23 +++++--
 .../runners/portability/flink_runner_test.py       |  6 ++
 .../runners/portability/fn_api_runner.py           |  9 +++
 .../runners/portability/fn_api_runner_test.py      | 76 +++++++++++++++++++++
 .../apache_beam/runners/worker/bundle_processor.py | 15 +++--
 .../apache_beam/runners/worker/operations.py       | 13 ++++
 .../apache_beam/runners/worker/sdk_worker.py       | 77 ++++++++++++++++------
 sdks/python/apache_beam/transforms/core.py         | 30 ++++++++-
 sdks/python/scripts/generate_pydoc.sh              |  1 +
 10 files changed, 222 insertions(+), 31 deletions(-)

diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd
index b5ab88d..b28e852 100644
--- a/sdks/python/apache_beam/runners/common.pxd
+++ b/sdks/python/apache_beam/runners/common.pxd
@@ -56,6 +56,7 @@ cdef class DoFnInvoker(object):
   cdef public DoFnSignature signature
   cdef OutputProcessor output_processor
   cdef object user_state_context
+  cdef public object bundle_finalizer_param
 
   cpdef invoke_process(self, WindowedValue windowed_value,
                        restriction_tracker=*,
@@ -92,7 +93,7 @@ cdef class DoFnRunner(Receiver):
   cdef object step_name
   cdef list side_inputs
   cdef DoFnInvoker do_fn_invoker
-
+  cdef public object bundle_finalizer_param
   cpdef process(self, WindowedValue windowed_value)
 
 
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index b07242c..3438790 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -292,6 +292,7 @@ class DoFnInvoker(object):
     self.output_processor = output_processor
     self.signature = signature
     self.user_state_context = None
+    self.bundle_finalizer_param = None
 
   @staticmethod
   def create_invoker(
@@ -299,7 +300,8 @@ class DoFnInvoker(object):
       output_processor=None,
       context=None, side_inputs=None, input_args=None, input_kwargs=None,
       process_invocation=True,
-      user_state_context=None):
+      user_state_context=None,
+      bundle_finalizer_param=None):
     """ Creates a new DoFnInvoker based on given arguments.
 
     Args:
@@ -321,6 +323,8 @@ class DoFnInvoker(object):
                             method efficiently.
         user_state_context: The UserStateContext instance for the current
                             Stateful DoFn.
+        bundle_finalizer_param: The param that passed to a process method, which
+                                allows a callback to be registered.
     """
     side_inputs = side_inputs or []
     default_arg_values = signature.process_method.defaults
@@ -333,7 +337,7 @@ class DoFnInvoker(object):
       return PerWindowInvoker(
           output_processor,
           signature, context, side_inputs, input_args, input_kwargs,
-          user_state_context)
+          user_state_context, bundle_finalizer_param)
 
   def invoke_process(self, windowed_value, restriction_tracker=None,
                      output_processor=None,
@@ -423,7 +427,8 @@ class PerWindowInvoker(DoFnInvoker):
   """An invoker that processes elements considering windowing information."""
 
   def __init__(self, output_processor, signature, context,
-               side_inputs, input_args, input_kwargs, user_state_context):
+               side_inputs, input_args, input_kwargs, user_state_context,
+               bundle_finalizer_param):
     super(PerWindowInvoker, self).__init__(output_processor, signature)
     self.side_inputs = side_inputs
     self.context = context
@@ -437,6 +442,7 @@ class PerWindowInvoker(DoFnInvoker):
     self.is_splittable = signature.is_splittable_dofn()
     self.restriction_tracker = None
     self.current_windowed_value = None
+    self.bundle_finalizer_param = bundle_finalizer_param
 
     # Try to prepare all the arguments that can just be filled in
     # without any additional work. in the process function.
@@ -487,6 +493,8 @@ class PerWindowInvoker(DoFnInvoker):
         args_with_placeholders.append(ArgPlaceholder(d))
       elif isinstance(d, core.DoFn.TimerParam):
         args_with_placeholders.append(ArgPlaceholder(d))
+      elif d == core.DoFn.BundleFinalizerParam:
+        args_with_placeholders.append(ArgPlaceholder(d))
       else:
         # If no more args are present then the value must be passed via kwarg
         try:
@@ -608,6 +616,8 @@ class PerWindowInvoker(DoFnInvoker):
       elif isinstance(p, core.DoFn.TimerParam):
         args_for_process[i] = (
             self.user_state_context.get_timer(p.timer_spec, key, window))
+      elif p == core.DoFn.BundleFinalizerParam:
+        args_for_process[i] = self.bundle_finalizer_param
 
     if additional_kwargs:
       if kwargs_for_process is None:
@@ -694,6 +704,7 @@ class DoFnRunner(Receiver):
 
     self.step_name = step_name
     self.context = DoFnContext(step_name, state=state)
+    self.bundle_finalizer_param = DoFn.BundleFinalizerParam()
 
     do_fn_signature = DoFnSignature(fn)
 
@@ -722,7 +733,8 @@ class DoFnRunner(Receiver):
 
     self.do_fn_invoker = DoFnInvoker.create_invoker(
         do_fn_signature, output_processor, self.context, side_inputs, args,
-        kwargs, user_state_context=user_state_context)
+        kwargs, user_state_context=user_state_context,
+        bundle_finalizer_param=self.bundle_finalizer_param)
 
   def receive(self, windowed_value):
     self.process(windowed_value)
@@ -733,6 +745,9 @@ class DoFnRunner(Receiver):
     except BaseException as exn:
       self._reraise_augmented(exn)
 
+  def finalize(self):
+    self.bundle_finalizer_param.finalize_bundle()
+
   def process_with_restriction(self, windowed_value):
     element, restriction = windowed_value.value
     return self.do_fn_invoker.invoke_process(
diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
index cb27f55..d67b5fb 100644
--- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
@@ -220,6 +220,12 @@ if __name__ == '__main__':
     def test_sdf(self):
       raise unittest.SkipTest("BEAM-2939")
 
+    def test_callbacks_with_exception(self):
+      raise unittest.SkipTest("BEAM-6868")
+
+    def test_register_finalizations(self):
+      raise unittest.SkipTest("BEAM-6868")
+
     # Inherits all other tests.
 
   # Run the tests.
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index 36b1a22..bdbbda2 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -1298,6 +1298,15 @@ class BundleManager(object):
     if result.error:
       raise RuntimeError(result.error)
 
+    if result.process_bundle.requires_finalization:
+      finalize_request = beam_fn_api_pb2.InstructionRequest(
+          finalize_bundle=
+          beam_fn_api_pb2.FinalizeBundleRequest(
+              instruction_reference=process_bundle_id
+          ))
+      self._controller.control_handler.push(
+          finalize_request)
+
     return result, split_results
 
 
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index 0e24bda..f298215 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -21,6 +21,7 @@ import collections
 import logging
 import os
 import random
+import shutil
 import sys
 import tempfile
 import threading
@@ -801,6 +802,50 @@ class FnApiRunnerTest(unittest.TestCase):
       print(res._monitoring_infos_by_stage)
       raise
 
+  def test_callbacks_with_exception(self):
+    elements_list = ['1', '2']
+
+    def raise_expetion():
+      raise Exception('raise exception when calling callback')
+
+    class FinalizebleDoFnWithException(beam.DoFn):
+
+      def process(
+          self,
+          element,
+          bundle_finalizer=beam.DoFn.BundleFinalizerParam):
+        bundle_finalizer.register(raise_expetion)
+        yield element
+
+    with self.create_pipeline() as p:
+      res = (p
+             | beam.Create(elements_list)
+             | beam.ParDo(FinalizebleDoFnWithException()))
+      assert_that(res, equal_to(['1', '2']))
+
+  def test_register_finalizations(self):
+    event_recorder = EventRecorder(tempfile.gettempdir())
+    elements_list = ['2', '1']
+
+    class FinalizableDoFn(beam.DoFn):
+      def process(
+          self,
+          element,
+          bundle_finalizer=beam.DoFn.BundleFinalizerParam):
+        bundle_finalizer.register(lambda: event_recorder.record(element))
+        yield element
+
+    with self.create_pipeline() as p:
+      res = (p
+             | beam.Create(elements_list)
+             | beam.ParDo(FinalizableDoFn()))
+
+      assert_that(res, equal_to(elements_list))
+
+    results = event_recorder.events()
+    event_recorder.cleanup()
+    self.assertEquals(results, sorted(elements_list))
+
 
 class FnApiRunnerTestWithGrpc(FnApiRunnerTest):
 
@@ -827,6 +872,9 @@ class FnApiRunnerTestWithBundleRepeat(FnApiRunnerTest):
     return beam.Pipeline(
         runner=fn_api_runner.FnApiRunner(bundle_repeat=3))
 
+  def test_register_finalizations(self):
+    raise unittest.SkipTest("TODO: Avoid bundle finalizations on repeat.")
+
 
 class FnApiRunnerSplitTest(unittest.TestCase):
 
@@ -1084,6 +1132,34 @@ def _unpickle_element_counter(name):
   return _pickled_element_counters[name]
 
 
+class EventRecorder(object):
+  """Used to be registered as a callback in bundle finalization.
+
+  The reason why records are written into a tmp file is, the in-memory dataset
+  cannot keep callback records when passing into one DoFn.
+  """
+  def __init__(self, tmp_dir):
+    self.tmp_dir = os.path.join(tmp_dir, uuid.uuid4().hex)
+    os.mkdir(self.tmp_dir)
+
+  def record(self, content):
+    file_path = os.path.join(self.tmp_dir, uuid.uuid4().hex + '.txt')
+    with open(file_path, 'w') as f:
+      f.write(content)
+
+  def events(self):
+    content = []
+    record_files = [f for f in os.listdir(self.tmp_dir) if os.path.isfile(
+        os.path.join(self.tmp_dir, f))]
+    for file in record_files:
+      with open(os.path.join(self.tmp_dir, file), 'r') as f:
+        content.append(f.read())
+    return sorted(content)
+
+  def cleanup(self):
+    shutil.rmtree(self.tmp_dir)
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 63aece7..06b0322 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -469,7 +469,6 @@ class BundleProcessor(object):
     self.splitting_lock = threading.Lock()
 
   def create_execution_tree(self, descriptor):
-
     transform_factory = BeamTransformFactory(
         descriptor, self.data_channel_factory, self.counter_factory,
         self.state_sampler, self.state_handler)
@@ -559,9 +558,9 @@ class BundleProcessor(object):
         logging.debug('finish %s', op)
         op.finish()
 
-      return [
-          self.delayed_bundle_application(op, residual)
-          for op, residual in execution_context.delayed_applications]
+      return ([self.delayed_bundle_application(op, residual)
+               for op, residual in execution_context.delayed_applications],
+              self.requires_finalization())
 
     finally:
       # Ensure any in-flight split attempts complete.
@@ -569,6 +568,14 @@ class BundleProcessor(object):
         pass
       self.state_sampler.stop_if_still_running()
 
+  def finalize_bundle(self):
+    for op in self.ops.values():
+      op.finalize_bundle()
+    return beam_fn_api_pb2.FinalizeBundleResponse()
+
+  def requires_finalization(self):
+    return any(op.needs_finalization() for op in self.ops.values())
+
   def try_split(self, bundle_split_request):
     split_response = beam_fn_api_pb2.ProcessBundleSplitResponse()
     with self.splitting_lock:
diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py
index 1412403..657efb7 100644
--- a/sdks/python/apache_beam/runners/worker/operations.py
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -223,6 +223,12 @@ class Operation(object):
     """Process element in operation."""
     pass
 
+  def finalize_bundle(self):
+    pass
+
+  def needs_finalization(self):
+    return False
+
   def try_split(self, fraction_of_remainder):
     return None
 
@@ -557,6 +563,12 @@ class DoOperation(Operation):
         self.execution_context.delayed_applications.append(
             (self, delayed_application))
 
+  def finalize_bundle(self):
+    self.dofn_receiver.finalize()
+
+  def needs_finalization(self):
+    return self.dofn_receiver.bundle_finalizer_param.has_callbacks()
+
   def process_timer(self, tag, windowed_timer):
     key, timer_data = windowed_timer.value
     timer_spec = self.timer_specs[tag]
@@ -575,6 +587,7 @@ class DoOperation(Operation):
       side_input_map.reset()
     if self.user_state_context:
       self.user_state_context.reset()
+    self.dofn_receiver.bundle_finalizer_param.reset()
 
   def progress_metrics(self):
     metrics = super(DoOperation, self).progress_metrics()
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 9a10325..7be8cf5 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -73,17 +73,26 @@ class SdkHarness(object):
         credentials)
     self._state_handler_factory = GrpcStateHandlerFactory(credentials)
     self._profiler_factory = profiler_factory
+    self._fns = {}
+    # BundleProcessor cache across all workers.
+    self._bundle_processor_cache = BundleProcessorCache(
+        state_handler_factory=self._state_handler_factory,
+        data_channel_factory=self._data_channel_factory,
+        fns=self._fns)
+    # workers for process/finalize bundle.
     self.workers = queue.Queue()
+    # one worker for progress/split request.
+    self.progress_worker = SdkWorker(self._bundle_processor_cache,
+                                     profiler_factory=self._profiler_factory)
     # one thread is enough for getting the progress report.
     # Assumption:
     # Progress report generation should not do IO or wait on other resources.
     #  Without wait, having multiple threads will not improve performance and
     #  will only add complexity.
     self._progress_thread_pool = futures.ThreadPoolExecutor(max_workers=1)
+    # finalize and process share one thread pool.
     self._process_thread_pool = futures.ThreadPoolExecutor(
         max_workers=self._worker_count)
-    self._instruction_id_vs_worker = {}
-    self._fns = {}
     self._responses = queue.Queue()
     self._process_bundle_queue = queue.Queue()
     self._unscheduled_process_bundle = {}
@@ -93,11 +102,7 @@ class SdkHarness(object):
     control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(self._control_channel)
     no_more_work = object()
 
-    # Create workers
-    bundle_processor_cache = BundleProcessorCache(
-        state_handler_factory=self._state_handler_factory,
-        data_channel_factory=self._data_channel_factory,
-        fns=self._fns)
+    # Create process workers
     for _ in range(self._worker_count):
       # SdkHarness manage function registration and share self._fns with all
       # the workers. This is needed because function registration (register)
@@ -108,7 +113,7 @@ class SdkHarness(object):
       # potentially get executed by different worker. Hence we need a
       # centralized function list shared among all the workers.
       self.workers.put(
-          SdkWorker(bundle_processor_cache,
+          SdkWorker(self._bundle_processor_cache,
                     profiler_factory=self._profiler_factory))
 
     def get_responses():
@@ -180,17 +185,12 @@ class SdkHarness(object):
       worker = self.workers.get()
       # Get the first work item in the queue
       work = self._process_bundle_queue.get()
-      # add the instuction_id vs worker map for progress reporting lookup
-      self._instruction_id_vs_worker[work.instruction_id] = worker
       self._unscheduled_process_bundle.pop(work.instruction_id, None)
       try:
         self._execute(lambda: worker.do_instruction(work), work)
       finally:
-        # Delete the instruction_id <-> worker mapping
-        self._instruction_id_vs_worker.pop(work.instruction_id, None)
         # Put the worker back in the free worker pool
         self.workers.put(worker)
-
     # Create a task for each process_bundle request and schedule it
     self._process_bundle_queue.put(request)
     self._unscheduled_process_bundle[request.instruction_id] = time.time()
@@ -209,11 +209,11 @@ class SdkHarness(object):
     def task():
       instruction_reference = getattr(
           request, request.WhichOneof('request')).instruction_reference
-      if instruction_reference in self._instruction_id_vs_worker:
+      # only process progress/split request when a bundle is in processing.
+      if (instruction_reference in
+          self._bundle_processor_cache.active_bundle_processors):
         self._execute(
-            lambda: self._instruction_id_vs_worker[
-                instruction_reference
-            ].do_instruction(request), request)
+            lambda: self.progress_worker.do_instruction(request), request)
       else:
         self._execute(lambda: beam_fn_api_pb2.InstructionResponse(
             instruction_id=request.instruction_id, error=(
@@ -224,6 +224,20 @@ class SdkHarness(object):
 
     self._progress_thread_pool.submit(task)
 
+  def _request_finalize_bundle(self, request):
+
+    def task():
+      # Get one available worker.
+      worker = self.workers.get()
+      try:
+        self._execute(
+            lambda: worker.do_instruction(request), request)
+      finally:
+        # Put the worker back in the free worker pool.
+        self.workers.put(worker)
+
+    self._process_thread_pool.submit(task)
+
   def _monitor_process_bundle(self):
     """
     Monitor the unscheduled bundles and log if a bundle is not scheduled for
@@ -313,15 +327,18 @@ class SdkWorker(object):
       with bundle_processor.state_handler.process_instruction_id(
           instruction_id):
         with self.maybe_profile(instruction_id):
-          delayed_applications = bundle_processor.process_bundle(instruction_id)
+          delayed_applications, requests_finalization = (
+              bundle_processor.process_bundle(instruction_id))
           response = beam_fn_api_pb2.InstructionResponse(
               instruction_id=instruction_id,
               process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
                   residual_roots=delayed_applications,
                   metrics=bundle_processor.metrics(),
-                  monitoring_infos=bundle_processor.monitoring_infos()))
-      # TODO(boyuanz): Don't release here if finalize is needed.
-      self.bundle_processor_cache.release(instruction_id)
+                  monitoring_infos=bundle_processor.monitoring_infos(),
+                  requires_finalization=requests_finalization))
+      # Don't release here if finalize is needed.
+      if not requests_finalization:
+        self.bundle_processor_cache.release(instruction_id)
       return response
     except:  # pylint: disable=broad-except
       # Don't re-use bundle processors on failure.
@@ -350,6 +367,24 @@ class SdkWorker(object):
             metrics=processor.metrics() if processor else None,
             monitoring_infos=processor.monitoring_infos() if processor else []))
 
+  def finalize_bundle(self, request, instruction_id):
+    processor = self.bundle_processor_cache.lookup(
+        request.instruction_reference)
+    if processor:
+      try:
+        finalize_response = processor.finalize_bundle()
+        self.bundle_processor_cache.release(request.instruction_reference)
+        return beam_fn_api_pb2.InstructionResponse(
+            instruction_id=instruction_id,
+            finalize_bundle=finalize_response)
+      except:
+        self.bundle_processor_cache.discard(request.instruction_reference)
+        raise
+    else:
+      return beam_fn_api_pb2.InstructionResponse(
+          instruction_id=instruction_id,
+          error='Instruction not running: %s' % instruction_id)
+
   @contextlib.contextmanager
   def maybe_profile(self, instruction_id):
     if self.profiler_factory:
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 27b67cd..1d095e8 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -326,6 +326,32 @@ class _TimerDoFnParam(_DoFnParam):
     self.param_id = 'TimerParam(%s)' % timer_spec.name
 
 
+class _BundleFinalizerParam(_DoFnParam):
+  """Bundle Finalization DoFn parameter."""
+
+  def __init__(self):
+    self._callbacks = []
+    self.param_id = "FinalizeBundle"
+
+  def register(self, callback):
+    self._callbacks.append(callback)
+
+  # Log errors when calling callback to make sure all callbacks get called
+  # though there are errors. And errors should not fail pipeline.
+  def finalize_bundle(self):
+    for callback in self._callbacks:
+      try:
+        callback()
+      except Exception as e:
+        logging.warn("Got exception from finalization call: %s", e)
+
+  def has_callbacks(self):
+    return len(self._callbacks) > 0
+
+  def reset(self):
+    del self._callbacks[:]
+
+
 class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
   """A function object used by a transform with custom processing.
 
@@ -344,9 +370,11 @@ class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
   TimestampParam = _DoFnParam('TimestampParam')
   WindowParam = _DoFnParam('WindowParam')
   WatermarkReporterParam = _DoFnParam('WatermarkReporterParam')
+  BundleFinalizerParam = _BundleFinalizerParam
 
   DoFnProcessParams = [ElementParam, SideInputParam, TimestampParam,
-                       WindowParam, WatermarkReporterParam]
+                       WindowParam, WatermarkReporterParam,
+                       BundleFinalizerParam]
 
   # Parameters to access state and timers.  Not restricted to use only in the
   # .process() method. Usage: DoFn.StateParam(state_spec),
diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh
index 4929c5e..6184227 100755
--- a/sdks/python/scripts/generate_pydoc.sh
+++ b/sdks/python/scripts/generate_pydoc.sh
@@ -171,6 +171,7 @@ ignore_identifiers = [
   # DoFn param inner classes, due to a Sphinx misparsing of inner classes
   '_StateDoFnParam',
   '_TimerDoFnParam',
+  '_BundleFinalizerParam',
 
   # Sphinx cannot find this py:class reference target
   'typing.Generic',

Reply via email to