Repository: beam
Updated Branches:
  refs/heads/master fe395900a -> ddc0a7d83


Enable progress request handling in python SDK harness


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/a241eda6
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/a241eda6
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/a241eda6

Branch: refs/heads/master
Commit: a241eda684376fffa80a4f530446ae6cae800a06
Parents: fe39590
Author: Vikas Kedigehalli <vika...@google.com>
Authored: Tue Aug 29 12:32:38 2017 -0700
Committer: Robert Bradshaw <rober...@gmail.com>
Committed: Wed Sep 20 16:10:37 2017 -0700

----------------------------------------------------------------------
 .../apache_beam/runners/worker/sdk_worker.py    | 74 ++++++++++++++------
 1 file changed, 52 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/a241eda6/sdks/python/apache_beam/runners/worker/sdk_worker.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 1481797..97f1f59 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -21,10 +21,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
 import logging
 import Queue as queue
-import threading
 import traceback
+from concurrent import futures
 
 import grpc
 
@@ -38,6 +39,9 @@ class SdkHarness(object):
   def __init__(self, control_address):
     self._control_channel = grpc.insecure_channel(control_address)
     self._data_channel_factory = data_plane.GrpcClientDataChannelFactory()
+    # TODO: Ensure thread safety to run with more than 1 thread.
+    self._default_work_thread_pool = futures.ThreadPoolExecutor(max_workers=1)
+    self._progress_thread_pool = futures.ThreadPoolExecutor(max_workers=1)
 
   def run(self):
     contol_stub = beam_fn_api_pb2.BeamFnControlStub(self._control_channel)
@@ -55,23 +59,44 @@ class SdkHarness(object):
           return
         yield response
 
-    def process_requests():
-      for work_request in contol_stub.Control(get_responses()):
-        logging.info('Got work %s', work_request.instruction_id)
+    for work_request in contol_stub.Control(get_responses()):
+      logging.info('Got work %s', work_request.instruction_id)
+      request_type = work_request.WhichOneof('request')
+      if request_type == ['process_bundle_progress']:
+        thread_pool = self._progress_thread_pool
+      else:
+        thread_pool = self._default_work_thread_pool
+
+      # Need this wrapper to capture the original stack trace.
+      def do_instruction(request):
+        try:
+          return self.worker.do_instruction(request)
+        except Exception as e:  # pylint: disable=broad-except
+          traceback_str = traceback.format_exc(e)
+          raise StandardError("Error processing request. Original traceback "
+                              "is\n%s\n" % traceback_str)
+
+      def handle_response(request, response_future):
         try:
-          response = self.worker.do_instruction(work_request)
-        except Exception:  # pylint: disable=broad-except
+          response = response_future.result()
+        except Exception as e:  # pylint: disable=broad-except
           logging.error(
               'Error processing instruction %s',
-              work_request.instruction_id,
+              request.instruction_id,
               exc_info=True)
           response = beam_fn_api_pb2.InstructionResponse(
-              instruction_id=work_request.instruction_id,
-              error=traceback.format_exc())
+              instruction_id=request.instruction_id,
+              error=str(e))
         responses.put(response)
-    t = threading.Thread(target=process_requests)
-    t.start()
-    t.join()
+
+      thread_pool.submit(do_instruction, work_request).add_done_callback(
+          functools.partial(handle_response, work_request))
+
+    logging.info("No more requests from control plane")
+    logging.info("SDK Harness waiting for in-flight requests to complete")
+    # Wait until existing requests are processed.
+    self._progress_thread_pool.shutdown()
+    self._default_work_thread_pool.shutdown()
     # get_responses may be blocked on responses.get(), but we need to return
     # control to its caller.
     responses.put(no_more_work)
@@ -89,20 +114,18 @@ class SdkWorker(object):
   def do_instruction(self, request):
     request_type = request.WhichOneof('request')
     if request_type:
-      # E.g. if register is set, this will construct
-      # InstructionResponse(register=self.register(request.register))
-      return beam_fn_api_pb2.InstructionResponse(**{
-          'instruction_id': request.instruction_id,
-          request_type: getattr(self, request_type)
-                        (getattr(request, request_type), request.instruction_id)
-      })
+      # E.g. if register is set, this will call self.register(request.register))
+      return getattr(self, request_type)(
+          getattr(request, request_type), request.instruction_id)
     else:
       raise NotImplementedError
 
-  def register(self, request, unused_instruction_id=None):
+  def register(self, request, instruction_id):
     for process_bundle_descriptor in request.process_bundle_descriptor:
       self.fns[process_bundle_descriptor.id] = process_bundle_descriptor
-    return beam_fn_api_pb2.RegisterResponse()
+    return beam_fn_api_pb2.InstructionResponse(**{
+        'instruction_id': instruction_id,
+        'register': beam_fn_api_pb2.RegisterResponse()})
 
   def process_bundle(self, request, instruction_id):
     bundle_processor.BundleProcessor(
@@ -110,4 +133,11 @@ class SdkWorker(object):
         self.state_handler,
         self.data_channel_factory).process_bundle(instruction_id)
 
-    return beam_fn_api_pb2.ProcessBundleResponse()
+    return beam_fn_api_pb2.InstructionResponse(**{
+        'instruction_id': instruction_id,
+        'process_bundle': beam_fn_api_pb2.ProcessBundleResponse()})
+
+  def process_bundle_progress(self, request, instruction_id):
+    return beam_fn_api_pb2.InstructionResponse(**{
+        'instruction_id': instruction_id,
+        'error': 'Not Supported'})

Reply via email to