[ https://issues.apache.org/jira/browse/BEAM-3189?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16293311#comment-16293311 ]
ASF GitHub Bot commented on BEAM-3189: -------------------------------------- aaltay closed pull request #4134: [BEAM-3189] Sdk worker multithreading URL: https://github.com/apache/beam/pull/4134 This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index e40faa5a38d..62522eea80f 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -850,7 +850,7 @@ def __init__(self): self.data_plane_handler = data_plane.InMemoryDataChannel() self.worker = sdk_worker.SdkWorker( self.state_handler, data_plane.InMemoryDataChannelFactory( - self.data_plane_handler.inverse())) + self.data_plane_handler.inverse()), {}) def push(self, request): logging.debug('CONTROL REQUEST %s', request) @@ -902,8 +902,11 @@ def __init__(self, sdk_harness_factory=None): self.data_server.start() self.control_server.start() - self.worker = (self.sdk_harness_factory or sdk_worker.SdkHarness)( - 'localhost:%s' % self.control_port) + self.worker = self.sdk_harness_factory( + 'localhost:%s' % self.control_port + ) if self.sdk_harness_factory else sdk_worker.SdkHarness( + 'localhost:%s' % self.control_port, worker_count=1) + self.worker_thread = threading.Thread( name='run_worker', target=self.worker.run) logging.info('starting worker') diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py index 249eeceae1d..eb297abfe63 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py @@ -14,7 +14,7 @@ # 
See the License for the specific language governing permissions and # limitations under the License. # - +import functools import logging import time import unittest @@ -22,6 +22,7 @@ import apache_beam as beam from apache_beam.runners.portability import fn_api_runner from apache_beam.runners.portability import maptask_executor_runner_test +from apache_beam.runners.worker import sdk_worker from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to from apache_beam.transforms import window @@ -155,6 +156,16 @@ def create_pipeline(self): runner=fn_api_runner.FnApiRunner(use_grpc=True)) +class FnApiRunnerTestWithGrpcMultiThreaded(FnApiRunnerTest): + + def create_pipeline(self): + return beam.Pipeline( + runner=fn_api_runner.FnApiRunner( + use_grpc=True, + sdk_harness_factory=functools.partial( + sdk_worker.SdkHarness, worker_count=2))) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index f2a3751dc2a..b3d4854d897 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -277,20 +277,23 @@ class GrpcClientDataChannelFactory(DataChannelFactory): def __init__(self): self._data_channel_cache = {} + self._lock = threading.Lock() def create_data_channel(self, remote_grpc_port): url = remote_grpc_port.api_service_descriptor.url if url not in self._data_channel_cache: - logging.info('Creating channel for %s', url) - grpc_channel = grpc.insecure_channel( - url, - # Options to have no limits (-1) on the size of the messages - # received or sent over the data plane. The actual buffer size is - # controlled in a layer above. 
- options=[("grpc.max_receive_message_length", -1), - ("grpc.max_send_message_length", -1)]) - self._data_channel_cache[url] = GrpcClientDataChannel( - beam_fn_api_pb2_grpc.BeamFnDataStub(grpc_channel)) + with self._lock: + if url not in self._data_channel_cache: + logging.info('Creating channel for %s', url) + grpc_channel = grpc.insecure_channel( + url, + # Options to have no limits (-1) on the size of the messages + # received or sent over the data plane. The actual buffer size is + # controlled in a layer above. + options=[("grpc.max_receive_message_length", -1), + ("grpc.max_send_message_length", -1)]) + self._data_channel_cache[url] = GrpcClientDataChannel( + beam_fn_api_pb2_grpc.BeamFnDataStub(grpc_channel)) return self._data_channel_cache[url] def close(self): diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index fec844e6e0e..980357ee221 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -14,14 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # - """SDK harness for executing Python Fns via the Fn API.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools import logging import Queue as queue import sys @@ -38,27 +36,54 @@ class SdkHarness(object): + REQUEST_METHOD_PREFIX = '_request_' - def __init__(self, control_address): + def __init__(self, control_address, worker_count): + self._worker_count = worker_count + self._worker_index = 0 self._control_channel = grpc.insecure_channel(control_address) self._data_channel_factory = data_plane.GrpcClientDataChannelFactory() - # TODO: Ensure thread safety to run with more than 1 thread. 
- self._default_work_thread_pool = futures.ThreadPoolExecutor(max_workers=1) + self.workers = queue.Queue() + # one thread is enough for getting the progress report. + # Assumption: + # Progress report generation should not do IO or wait on other resources. + # Without wait, having multiple threads will not improve performance and + # will only add complexity. self._progress_thread_pool = futures.ThreadPoolExecutor(max_workers=1) + self._process_thread_pool = futures.ThreadPoolExecutor( + max_workers=self._worker_count) + self._instruction_id_vs_worker = {} + self._fns = {} + self._responses = queue.Queue() + self._process_bundle_queue = queue.Queue() + logging.info('Initializing SDKHarness with %s workers.', self._worker_count) def run(self): control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(self._control_channel) - state_stub = beam_fn_api_pb2_grpc.BeamFnStateStub(self._control_channel) - state_handler = GrpcStateHandler(state_stub) - state_handler.start() - self.worker = SdkWorker(state_handler, self._data_channel_factory) - - responses = queue.Queue() no_more_work = object() + # Create workers + for _ in range(self._worker_count): + state_handler = GrpcStateHandler( + beam_fn_api_pb2_grpc.BeamFnStateStub(self._control_channel)) + state_handler.start() + # SdkHarness manage function registration and share self._fns with all + # the workers. This is needed because function registration (register) + # and exceution(process_bundle) are send over different request and we + # do not really know which woker is going to process bundle + # for a function till we get process_bundle request. Moreover + # same function is reused by different process bundle calls and + # potentially get executed by different worker. Hence we need a + # centralized function list shared among all the workers. 
+ self.workers.put( + SdkWorker( + state_handler=state_handler, + data_channel_factory=self._data_channel_factory, + fns=self._fns)) + def get_responses(): while True: - response = responses.get() + response = self._responses.get() if response is no_more_work: return yield response @@ -66,54 +91,86 @@ def get_responses(): for work_request in control_stub.Control(get_responses()): logging.info('Got work %s', work_request.instruction_id) request_type = work_request.WhichOneof('request') - # WhichOneOf returns the name of the set field as a single string - if request_type in ['process_bundle_progress']: - thread_pool = self._progress_thread_pool - else: - thread_pool = self._default_work_thread_pool - - # Need this wrapper to capture the original stack trace. - def do_instruction(request): - try: - return self.worker.do_instruction(request) - except Exception as e: # pylint: disable=broad-except - traceback_str = traceback.format_exc(e) - raise Exception("Error processing request. Original traceback " - "is\n%s\n" % traceback_str) - - def handle_response(request, response_future): - try: - response = response_future.result() - except Exception as e: # pylint: disable=broad-except - logging.error( - 'Error processing instruction %s', - request.instruction_id, - exc_info=True) - response = beam_fn_api_pb2.InstructionResponse( - instruction_id=request.instruction_id, - error=str(e)) - responses.put(response) - - thread_pool.submit(do_instruction, work_request).add_done_callback( - functools.partial(handle_response, work_request)) - - logging.info("No more requests from control plane") - logging.info("SDK Harness waiting for in-flight requests to complete") + # Name spacing the request method with 'request_'. 
The called method + # will be like self.request_register(request) + getattr(self, SdkHarness.REQUEST_METHOD_PREFIX + request_type)( + work_request) + + logging.info('No more requests from control plane') + logging.info('SDK Harness waiting for in-flight requests to complete') # Wait until existing requests are processed. self._progress_thread_pool.shutdown() - self._default_work_thread_pool.shutdown() + self._process_thread_pool.shutdown() # get_responses may be blocked on responses.get(), but we need to return # control to its caller. - responses.put(no_more_work) + self._responses.put(no_more_work) self._data_channel_factory.close() - state_handler.done() + # Stop all the workers and clean all the associated resources + for worker in self.workers.queue: + worker.state_handler.done() logging.info('Done consuming work.') + def _execute(self, task, request): + try: + response = task() + except Exception as e: # pylint: disable=broad-except + traceback.print_exc(file=sys.stderr) + logging.error( + 'Error processing instruction %s. Original traceback is\n%s\n', + request.instruction_id, + traceback.format_exc(e), + exc_info=True) + response = beam_fn_api_pb2.InstructionResponse( + instruction_id=request.instruction_id, error=str(e)) + self._responses.put(response) + + def _request_register(self, request): + + def task(): + for process_bundle_descriptor in getattr( + request, request.WhichOneof('request')).process_bundle_descriptor: + self._fns[process_bundle_descriptor.id] = process_bundle_descriptor + + return beam_fn_api_pb2.InstructionResponse( + instruction_id=request.instruction_id, + register=beam_fn_api_pb2.RegisterResponse()) + + self._execute(task, request) + + def _request_process_bundle(self, request): + + def task(): + # Take the free worker. Wait till a worker is free. 
+ worker = self.workers.get() + # Get the first work item in the queue + work = self._process_bundle_queue.get() + # add the instuction_id vs worker map for progress reporting lookup + self._instruction_id_vs_worker[work.instruction_id] = worker + try: + self._execute(lambda: worker.do_instruction(work), work) + finally: + # Delete the instruction_id <-> worker mapping + self._instruction_id_vs_worker.pop(work.instruction_id, None) + # Put the worker back in the free worker pool + self.workers.put(worker) + + # Create a task for each process_bundle request and schedule it + self._process_bundle_queue.put(request) + self._process_thread_pool.submit(task) + + def _request_process_bundle_progress(self, request): + worker = self._instruction_id_vs_worker[request.instruction_id] + + def task(): + self._execute(lambda: worker.do_instruction(request), request) + + self._progress_thread_pool.submit(task) + class SdkWorker(object): - def __init__(self, state_handler, data_channel_factory): - self.fns = {} + def __init__(self, state_handler, data_channel_factory, fns): + self.fns = fns self.state_handler = state_handler self.data_channel_factory = data_channel_factory self.bundle_processors = {} @@ -122,8 +179,8 @@ def do_instruction(self, request): request_type = request.WhichOneof('request') if request_type: # E.g. 
if register is set, this will call self.register(request.register)) - return getattr(self, request_type)( - getattr(request, request_type), request.instruction_id) + return getattr(self, request_type)(getattr(request, request_type), + request.instruction_id) else: raise NotImplementedError @@ -138,8 +195,7 @@ def process_bundle(self, request, instruction_id): self.bundle_processors[ instruction_id] = processor = bundle_processor.BundleProcessor( self.fns[request.process_bundle_descriptor_reference], - self.state_handler, - self.data_channel_factory) + self.state_handler, self.data_channel_factory) try: processor.process_bundle(instruction_id) finally: @@ -176,6 +232,7 @@ def request_iter(): if request is self._DONE or self._done: break yield request + responses = self._state_stub.State(request_iter()) def pull_responses(): @@ -187,6 +244,7 @@ def pull_responses(): except: # pylint: disable=bare-except self._exc_info = sys.exc_info() raise + reader = threading.Thread(target=pull_responses, name='read_state') reader.daemon = True reader.start() @@ -239,6 +297,7 @@ def _next_id(self): class _Future(object): """A simple future object to implement blocking requests. """ + def __init__(self): self._event = threading.Event() diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py index 1db8b29175f..8671005d4c8 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py @@ -20,6 +20,7 @@ import json import logging import os +import re import sys import threading import traceback @@ -128,7 +129,9 @@ def main(unused_argv): service_descriptor) # TODO(robertwb): Support credentials. 
assert not service_descriptor.oauth2_client_credentials_grant.url - SdkHarness(service_descriptor.url).run() + SdkHarness( + control_address=service_descriptor.url, + worker_count=_get_worker_count(sdk_pipeline_options)).run() logging.info('Python sdk harness exiting.') except: # pylint: disable=broad-except logging.exception('Python sdk harness failed: ') @@ -138,6 +141,39 @@ def main(unused_argv): fn_log_handler.close() +def _get_worker_count(pipeline_options): + """Extract worker count from the pipeline_options. + + This defines how many SdkWorkers will be started in this Python process. + And each SdkWorker will have its own thread to process data. Name of the + experimental parameter is 'worker_threads' + Example Usage in the Command Line: + --experimental worker_threads=1 + + Note: worker_threads is an experimental flag and might not be available in + future releases. + + Returns: + an int containing the worker_threads to use. Default is 1 + """ + pipeline_options = pipeline_options.get( + 'options') if pipeline_options.has_key('options') else {} + experiments = pipeline_options.get( + 'experiments' + ) if pipeline_options and pipeline_options.has_key('experiments') else [] + + experiments = experiments if experiments else [] + + for experiment in experiments: + # There should only be 1 match so returning from the loop + if re.match(r'worker_threads=', experiment): + return int( + re.match(r'worker_threads=(?P<worker_threads>.*)', + experiment).group('worker_threads')) + + return 1 + + def _load_main_session(semi_persistent_directory): """Loads a pickled main session from the path specified.""" if semi_persistent_directory: diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py index 9305c990b10..0e312f5be9d 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main_test.py @@ -20,6 +20,7 @@ from 
__future__ import division from __future__ import print_function +import json import logging import unittest @@ -34,11 +35,39 @@ def test_status_server(self): def wrapped_method_for_test(): lines = sdk_worker_main.StatusServer.get_thread_dump() threaddump = '\n'.join(lines) - self.assertRegexpMatches(threaddump, ".*wrapped_method_for_test.*") + self.assertRegexpMatches(threaddump, '.*wrapped_method_for_test.*') wrapped_method_for_test() + def test_work_count_default_value(self): + self._check_worker_count('{}', 1) -if __name__ == "__main__": + def test_work_count_custom_value(self): + self._check_worker_count( + '{"options": {"experiments":["worker_threads=1"]}}', 1) + self._check_worker_count( + '{"options": {"experiments":["worker_threads=4"]}}', 4) + self._check_worker_count( + '{"options": {"experiments":["worker_threads=12"]}}', 12) + + def test_work_count_wrong_format(self): + self._check_worker_count( + '{"options": {"experiments":["worker_threads="]}}', exception=True) + self._check_worker_count( + '{"options": {"experiments":["worker_threads=a"]}}', exception=True) + self._check_worker_count( + '{"options": {"experiments":["worker_threads=1a"]}}', exception=True) + + def _check_worker_count(self, pipeline_options, expected=0, exception=False): + if exception: + self.assertRaises(Exception, sdk_worker_main._get_worker_count, + json.loads(pipeline_options)) + else: + self.assertEquals( + sdk_worker_main._get_worker_count(json.loads(pipeline_options)), + expected) + + +if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py index 2532341df93..c229d6450ef 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the 
License. # - """Tests for apache_beam.runners.worker.sdk_worker.""" from __future__ import absolute_import @@ -62,29 +61,54 @@ def Control(self, response_iterator, context): class SdkWorkerTest(unittest.TestCase): - def test_fn_registration(self): - process_bundle_descriptors = [ + def _get_process_bundles(self, prefix, size): + return [ beam_fn_api_pb2.ProcessBundleDescriptor( - id=str(100+ix), + id=str(str(prefix) + "-" + str(ix)), transforms={ - str(ix): beam_runner_api_pb2.PTransform(unique_name=str(ix))}) - for ix in range(4)] - - test_controller = BeamFnControlServicer([beam_fn_api_pb2.InstructionRequest( - register=beam_fn_api_pb2.RegisterRequest( - process_bundle_descriptor=process_bundle_descriptors))]) - - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server( - test_controller, server) - test_port = server.add_insecure_port("[::]:0") - server.start() - - harness = sdk_worker.SdkHarness("localhost:%s" % test_port) - harness.run() - self.assertEqual( - harness.worker.fns, - {item.id: item for item in process_bundle_descriptors}) + str(ix): beam_runner_api_pb2.PTransform(unique_name=str(ix)) + }) for ix in range(size) + ] + + def _check_fn_registration_multi_request(self, *args): + """Check the function registration calls to the sdk_harness. + + Args: + tuple of request_count, number of process_bundles per request and workers + counts to process the request. 
+ """ + for (request_count, process_bundles_per_request, worker_count) in args: + requests = [] + process_bundle_descriptors = [] + + for i in range(request_count): + pbd = self._get_process_bundles(i, process_bundles_per_request) + process_bundle_descriptors.extend(pbd) + requests.append( + beam_fn_api_pb2.InstructionRequest( + instruction_id=str(i), + register=beam_fn_api_pb2.RegisterRequest( + process_bundle_descriptor=process_bundle_descriptors))) + + test_controller = BeamFnControlServicer(requests) + + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server( + test_controller, server) + test_port = server.add_insecure_port("[::]:0") + server.start() + + harness = sdk_worker.SdkHarness( + "localhost:%s" % test_port, worker_count=worker_count) + harness.run() + + for worker in harness.workers.queue: + self.assertEqual(worker.fns, + {item.id: item + for item in process_bundle_descriptors}) + + def test_fn_registration(self): + self._check_fn_registration_multi_request((1, 4, 1), (4, 4, 1), (4, 4, 2)) if __name__ == "__main__": ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org > Python Fnapi - Worker speedup > ----------------------------- > > Key: BEAM-3189 > URL: https://issues.apache.org/jira/browse/BEAM-3189 > Project: Beam > Issue Type: Improvement > Components: sdk-py-harness > Affects Versions: 2.3.0 > Reporter: Ankur Goenka > Assignee: Ankur Goenka > Priority: Minor > Labels: performance, portability > > The Beam Python SDK is a couple of orders of magnitude slower than the Java SDK for > stream processing.
> There are two related issues: > # Given a single core, we currently do not fully utilize the core because > the single thread spends a lot of time on IO. This is more a > limitation of our implementation than a limitation of Python. > # Given a machine with multiple cores, a single Python process can only > utilize one core. > In this task we will only address issue 1; issue 2 is a good candidate for future optimization. -- This message was sent by Atlassian JIRA (v6.4.14#64029)