Repository: beam
Updated Branches:
  refs/heads/master fe395900a -> ddc0a7d83
Enable progress request handling in python SDK harness Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/a241eda6 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/a241eda6 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/a241eda6 Branch: refs/heads/master Commit: a241eda684376fffa80a4f530446ae6cae800a06 Parents: fe39590 Author: Vikas Kedigehalli <vika...@google.com> Authored: Tue Aug 29 12:32:38 2017 -0700 Committer: Robert Bradshaw <rober...@gmail.com> Committed: Wed Sep 20 16:10:37 2017 -0700 ---------------------------------------------------------------------- .../apache_beam/runners/worker/sdk_worker.py | 74 ++++++++++++++------ 1 file changed, 52 insertions(+), 22 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/a241eda6/sdks/python/apache_beam/runners/worker/sdk_worker.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index 1481797..97f1f59 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -21,10 +21,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import logging import Queue as queue -import threading import traceback +from concurrent import futures import grpc @@ -38,6 +39,9 @@ class SdkHarness(object): def __init__(self, control_address): self._control_channel = grpc.insecure_channel(control_address) self._data_channel_factory = data_plane.GrpcClientDataChannelFactory() + # TODO: Ensure thread safety to run with more than 1 thread. 
+ self._default_work_thread_pool = futures.ThreadPoolExecutor(max_workers=1) + self._progress_thread_pool = futures.ThreadPoolExecutor(max_workers=1) def run(self): contol_stub = beam_fn_api_pb2.BeamFnControlStub(self._control_channel) @@ -55,23 +59,44 @@ class SdkHarness(object): return yield response - def process_requests(): - for work_request in contol_stub.Control(get_responses()): - logging.info('Got work %s', work_request.instruction_id) + for work_request in contol_stub.Control(get_responses()): + logging.info('Got work %s', work_request.instruction_id) + request_type = work_request.WhichOneof('request') + if request_type == ['process_bundle_progress']: + thread_pool = self._progress_thread_pool + else: + thread_pool = self._default_work_thread_pool + + # Need this wrapper to capture the original stack trace. + def do_instruction(request): + try: + return self.worker.do_instruction(request) + except Exception as e: # pylint: disable=broad-except + traceback_str = traceback.format_exc(e) + raise StandardError("Error processing request. 
Original traceback " + "is\n%s\n" % traceback_str) + + def handle_response(request, response_future): try: - response = self.worker.do_instruction(work_request) - except Exception: # pylint: disable=broad-except + response = response_future.result() + except Exception as e: # pylint: disable=broad-except logging.error( 'Error processing instruction %s', - work_request.instruction_id, + request.instruction_id, exc_info=True) response = beam_fn_api_pb2.InstructionResponse( - instruction_id=work_request.instruction_id, - error=traceback.format_exc()) + instruction_id=request.instruction_id, + error=str(e)) responses.put(response) - t = threading.Thread(target=process_requests) - t.start() - t.join() + + thread_pool.submit(do_instruction, work_request).add_done_callback( + functools.partial(handle_response, work_request)) + + logging.info("No more requests from control plane") + logging.info("SDK Harness waiting for in-flight requests to complete") + # Wait until existing requests are processed. + self._progress_thread_pool.shutdown() + self._default_work_thread_pool.shutdown() # get_responses may be blocked on responses.get(), but we need to return # control to its caller. responses.put(no_more_work) @@ -89,20 +114,18 @@ class SdkWorker(object): def do_instruction(self, request): request_type = request.WhichOneof('request') if request_type: - # E.g. if register is set, this will construct - # InstructionResponse(register=self.register(request.register)) - return beam_fn_api_pb2.InstructionResponse(**{ - 'instruction_id': request.instruction_id, - request_type: getattr(self, request_type) - (getattr(request, request_type), request.instruction_id) - }) + # E.g. 
if register is set, this will call self.register(request.register)) + return getattr(self, request_type)( + getattr(request, request_type), request.instruction_id) else: raise NotImplementedError - def register(self, request, unused_instruction_id=None): + def register(self, request, instruction_id): for process_bundle_descriptor in request.process_bundle_descriptor: self.fns[process_bundle_descriptor.id] = process_bundle_descriptor - return beam_fn_api_pb2.RegisterResponse() + return beam_fn_api_pb2.InstructionResponse(**{ + 'instruction_id': instruction_id, + 'register': beam_fn_api_pb2.RegisterResponse()}) def process_bundle(self, request, instruction_id): bundle_processor.BundleProcessor( @@ -110,4 +133,11 @@ class SdkWorker(object): self.state_handler, self.data_channel_factory).process_bundle(instruction_id) - return beam_fn_api_pb2.ProcessBundleResponse() + return beam_fn_api_pb2.InstructionResponse(**{ + 'instruction_id': instruction_id, + 'process_bundle': beam_fn_api_pb2.ProcessBundleResponse()}) + + def process_bundle_progress(self, request, instruction_id): + return beam_fn_api_pb2.InstructionResponse(**{ + 'instruction_id': instruction_id, + 'error': 'Not Supported'})