http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/logger_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/logger_test.py b/sdks/python/apache_beam/runners/worker/logger_test.py new file mode 100644 index 0000000..cf3f692 --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/logger_test.py @@ -0,0 +1,182 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for worker logging utilities.""" + +import json +import logging +import sys +import threading +import unittest + +from apache_beam.runners.worker import logger + + +class PerThreadLoggingContextTest(unittest.TestCase): + + def thread_check_attribute(self, name): + self.assertFalse(name in logger.per_thread_worker_data.get_data()) + with logger.PerThreadLoggingContext(**{name: 'thread-value'}): + self.assertEqual( + logger.per_thread_worker_data.get_data()[name], 'thread-value') + self.assertFalse(name in logger.per_thread_worker_data.get_data()) + + def test_per_thread_attribute(self): + self.assertFalse('xyz' in logger.per_thread_worker_data.get_data()) + with logger.PerThreadLoggingContext(xyz='value'): + self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value') + thread = threading.Thread( + target=self.thread_check_attribute, args=('xyz',)) + thread.start() + thread.join() + self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value') + self.assertFalse('xyz' in logger.per_thread_worker_data.get_data()) + + def test_set_when_undefined(self): + self.assertFalse('xyz' in logger.per_thread_worker_data.get_data()) + with logger.PerThreadLoggingContext(xyz='value'): + self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value') + self.assertFalse('xyz' in logger.per_thread_worker_data.get_data()) + + def test_set_when_already_defined(self): + self.assertFalse('xyz' in logger.per_thread_worker_data.get_data()) + with logger.PerThreadLoggingContext(xyz='value'): + self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value') + with logger.PerThreadLoggingContext(xyz='value2'): + self.assertEqual( + logger.per_thread_worker_data.get_data()['xyz'], 'value2') + self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value') + self.assertFalse('xyz' in logger.per_thread_worker_data.get_data()) + + +class JsonLogFormatterTest(unittest.TestCase): + + SAMPLE_RECORD = { + 'created': 123456.789, 'msecs': 789.654321, + 'msg': '%s:%d:%.2f', 'args': ('xyz', 4, 3.14), + 'levelname': 'WARNING', + 'process': 'pid', 'thread': 'tid', + 'name': 'name', 'filename': 'file', 'funcName': 'func', + 'exc_info': None} + + SAMPLE_OUTPUT = { + 'timestamp': {'seconds': 123456, 'nanos': 789654321}, + 'severity': 'WARN', 
'message': 'xyz:4:3.14', 'thread': 'pid:tid', + 'job': 'jobid', 'worker': 'workerid', 'logger': 'name:file:func'} + + def create_log_record(self, **kwargs): + + class Record(object): + + def __init__(self, **kwargs): + for k, v in kwargs.iteritems(): + setattr(self, k, v) + + return Record(**kwargs) + + def test_basic_record(self): + formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid') + record = self.create_log_record(**self.SAMPLE_RECORD) + self.assertEqual(json.loads(formatter.format(record)), self.SAMPLE_OUTPUT) + + def execute_multiple_cases(self, test_cases): + record = self.SAMPLE_RECORD + output = self.SAMPLE_OUTPUT + formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid') + + for case in test_cases: + record['msg'] = case['msg'] + record['args'] = case['args'] + output['message'] = case['expected'] + + self.assertEqual( + json.loads(formatter.format(self.create_log_record(**record))), + output) + + def test_record_with_format_character(self): + test_cases = [ + {'msg': '%A', 'args': (), 'expected': '%A'}, + {'msg': '%s', 'args': (), 'expected': '%s'}, + {'msg': '%A%s', 'args': ('xy'), 'expected': '%A%s with args (xy)'}, + {'msg': '%s%s', 'args': (1), 'expected': '%s%s with args (1)'}, + ] + + self.execute_multiple_cases(test_cases) + + def test_record_with_arbitrary_messages(self): + test_cases = [ + {'msg': ImportError('abc'), 'args': (), 'expected': 'abc'}, + {'msg': TypeError('abc %s'), 'args': ('def'), 'expected': 'abc def'}, + ] + + self.execute_multiple_cases(test_cases) + + def test_record_with_per_thread_info(self): + with logger.PerThreadLoggingContext( + work_item_id='workitem', stage_name='stage', step_name='step'): + formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid') + record = self.create_log_record(**self.SAMPLE_RECORD) + log_output = json.loads(formatter.format(record)) + expected_output = dict(self.SAMPLE_OUTPUT) + expected_output.update( + {'work': 'workitem', 'stage': 'stage', 'step': 'step'}) + self.assertEqual(log_output, expected_output) + + def test_nested_with_per_thread_info(self): + formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid') + with logger.PerThreadLoggingContext( + work_item_id='workitem', stage_name='stage', step_name='step1'): + record = self.create_log_record(**self.SAMPLE_RECORD) + log_output1 = json.loads(formatter.format(record)) + + with logger.PerThreadLoggingContext(step_name='step2'): + record = self.create_log_record(**self.SAMPLE_RECORD) + log_output2 = json.loads(formatter.format(record)) + + record = self.create_log_record(**self.SAMPLE_RECORD) + log_output3 = json.loads(formatter.format(record)) + + record = self.create_log_record(**self.SAMPLE_RECORD) + log_output4 = json.loads(formatter.format(record)) + + self.assertEqual(log_output1, dict( + self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step1')) + self.assertEqual(log_output2, dict( + self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step2')) + self.assertEqual(log_output3, dict( + self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step1')) + self.assertEqual(log_output4, self.SAMPLE_OUTPUT) + + def test_exception_record(self): + formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid') + try: + raise ValueError('Something') + except ValueError: + attribs = dict(self.SAMPLE_RECORD) + attribs.update({'exc_info': sys.exc_info()}) + record = self.create_log_record(**attribs) + log_output = json.loads(formatter.format(record)) + # Check if exception type, 
its message, and stack trace information are included. + exn_output = log_output.pop('exception') + self.assertNotEqual(exn_output.find('ValueError: Something'), -1) + self.assertNotEqual(exn_output.find('logger_test.py'), -1) + self.assertEqual(log_output, self.SAMPLE_OUTPUT) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main()
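The tests above pin down the contract of the two utilities: values set via PerThreadLoggingContext are visible only on the thread that set them, and nested contexts restore the previous value on exit. A minimal sketch of that contract using only the standard library follows; the names mirror the logger module under test, but this is an illustrative reconstruction, not the committed implementation.

import threading

class _PerThreadWorkerData(threading.local):
  # threading.local runs __init__ once per thread on first access,
  # so each thread sees its own independent dict.
  def __init__(self):
    super(_PerThreadWorkerData, self).__init__()
    self.data = {}

  def get_data(self):
    return self.data

per_thread_worker_data = _PerThreadWorkerData()

class PerThreadLoggingContext(object):
  # Context manager that sets per-thread attributes on entry and
  # restores (or removes) them on exit, which is what makes the
  # nesting in test_set_when_already_defined work.
  def __init__(self, **kwargs):
    self.kwargs = kwargs
    self.previous = {}

  def __enter__(self):
    data = per_thread_worker_data.get_data()
    for key, value in self.kwargs.items():
      if key in data:
        self.previous[key] = data[key]
      data[key] = value
    return self

  def __exit__(self, exn_type, exn_value, exn_traceback):
    data = per_thread_worker_data.get_data()
    for key in self.kwargs:
      if key in self.previous:
        data[key] = self.previous[key]
      else:
        del data[key]

A JSON formatter built on this shape only has to merge per_thread_worker_data.get_data() into each record it formats, which is what test_record_with_per_thread_info asserts through the 'work', 'stage', and 'step' keys.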
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/opcounters.pxd ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/opcounters.pxd b/sdks/python/apache_beam/runners/worker/opcounters.pxd new file mode 100644 index 0000000..5c1079f --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/opcounters.pxd @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +cimport cython +cimport libc.stdint + +from apache_beam.utils.counters cimport Counter + + +cdef class SumAccumulator(object): + cdef libc.stdint.int64_t _value + cpdef update(self, libc.stdint.int64_t value) + cpdef libc.stdint.int64_t value(self) + + +cdef class OperationCounters(object): + cdef public _counter_factory + cdef public Counter element_counter + cdef public Counter mean_byte_counter + cdef public coder_impl + cdef public SumAccumulator active_accumulator + cdef public libc.stdint.int64_t _sample_counter + cdef public libc.stdint.int64_t _next_sample + + cpdef update_from(self, windowed_value) + cdef inline do_sample(self, windowed_value) + cpdef update_collect(self) + + cdef libc.stdint.int64_t _compute_next_sample(self, libc.stdint.int64_t i) + cdef inline bint _should_sample(self) + cpdef bint should_sample(self) http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/opcounters.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/opcounters.py b/sdks/python/apache_beam/runners/worker/opcounters.py new file mode 100644 index 0000000..56ce0db --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/opcounters.py @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# cython: profile=True + +"""Counters collect the progress of the Worker for reporting to the service.""" + +from __future__ import absolute_import +import math +import random + +from apache_beam.utils.counters import Counter + + +class SumAccumulator(object): + """Accumulator for collecting byte counts.""" + + def __init__(self): + self._value = 0 + + def update(self, value): + self._value += value + + def value(self): + return self._value + + +class OperationCounters(object): + """The set of basic counters to attach to an Operation.""" + + def __init__(self, counter_factory, step_name, coder, output_index): + self._counter_factory = counter_factory + self.element_counter = counter_factory.get_counter( + '%s-out%d-ElementCount' % (step_name, output_index), Counter.SUM) + self.mean_byte_counter = counter_factory.get_counter( + '%s-out%d-MeanByteCount' % (step_name, output_index), Counter.MEAN) + self.coder_impl = coder.get_impl() + self.active_accumulator = None + self._sample_counter = 0 + self._next_sample = 0 + + def update_from(self, windowed_value): + """Add one value to this counter.""" + self.element_counter.update(1) + if self._should_sample(): + self.do_sample(windowed_value) + + def _observable_callback(self, inner_coder_impl, accumulator): + def _observable_callback_inner(value, is_encoded=False): + # TODO(ccy): If this stream is large, sample it as well. + # To do this, we'll need to compute the average size of elements + # in this stream to add the *total* size of this stream to the accumulator. + # We'll also want to make sure we sample at least some of this stream + # (as self.should_sample() may be sampling very sparsely by now). + if is_encoded: + size = len(value) + accumulator.update(size) + else: + accumulator.update(inner_coder_impl.estimate_size(value)) + return _observable_callback_inner + + def do_sample(self, windowed_value): + size, observables = ( + self.coder_impl.get_estimated_size_and_observables(windowed_value)) + if not observables: + self.mean_byte_counter.update(size) + else: + self.active_accumulator = SumAccumulator() + self.active_accumulator.update(size) + for observable, inner_coder_impl in observables: + observable.register_observer( + self._observable_callback( + inner_coder_impl, self.active_accumulator)) + + def update_collect(self): + """Collects the accumulated size estimates. + + Now that the element has been processed, we ask our accumulator + for the total and store the result in a counter. + """ + if self.active_accumulator is not None: + self.mean_byte_counter.update(self.active_accumulator.value()) + self.active_accumulator = None + + def _compute_next_sample(self, i): + # https://en.wikipedia.org/wiki/Reservoir_sampling#Fast_Approximation + gap = math.log(1.0 - random.random()) / math.log(1.0 - 10.0/i) + return i + math.floor(gap) + + def _should_sample(self): + """Determines whether to sample the next element. + + Size calculation can be expensive, so we don't do it for each element. + Because we need only an estimate of average size, we sample. + + We always sample the first 10 elements, then the sampling rate + is approximately 10/N. After reading N elements, of the next N, + we will sample approximately 10*ln(2) (about 7) elements. + + This algorithm samples at the same rate as Reservoir Sampling, but + it never throws away early results. (Because we keep only a + running accumulation, storage is not a problem, so there is no + need to discard earlier calculations.)
+ + Because we accumulate and do not replace, our statistics are + biased toward early data. If the data are distributed uniformly, + this is not a problem. If the data change over time (i.e., the + element size tends to grow or shrink over time), our estimate will + show the bias. We could correct this by giving weight N to each + sample, since each sample is a stand-in for the N/(10*ln(2)) + samples around it, which is proportional to N. Since we do not + expect biased data, for efficiency we omit the extra multiplication. + We could reduce the early-data bias by putting a lower bound on + the sampling rate. + + Computing random.randint(1, self._sample_counter) for each element + is too slow, so when the sample size is big enough (we estimate 30 + is big enough), we estimate the size of the gap after each sample. + This estimation allows us to call random much less often. + + Returns: + True if it is time to compute another element's size. + """ + + self._sample_counter += 1 + if self._next_sample == 0: + if random.randint(1, self._sample_counter) <= 10: + if self._sample_counter > 30: + self._next_sample = self._compute_next_sample(self._sample_counter) + return True + return False + elif self._sample_counter >= self._next_sample: + self._next_sample = self._compute_next_sample(self._sample_counter) + return True + return False + + def should_sample(self): + # We create this separate method because the above "_should_sample()" method + # is marked as inline in Cython and thus can't be exposed to Python code. + return self._should_sample() + + def __str__(self): + return '<%s [%s]>' % (self.__class__.__name__, + ', '.join([str(x) for x in self.__iter__()])) + + def __repr__(self): + return '<%s %s at %s>' % (self.__class__.__name__, + [x for x in self.__iter__()], hex(id(self))) http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/opcounters_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/opcounters_test.py b/sdks/python/apache_beam/runners/worker/opcounters_test.py new file mode 100644 index 0000000..74561b8 --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/opcounters_test.py @@ -0,0 +1,149 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import math +import random +import unittest + +from apache_beam import coders +from apache_beam.runners.worker.opcounters import OperationCounters +from apache_beam.transforms.window import GlobalWindows +from apache_beam.utils.counters import CounterFactory + + +# Classes to test that we can handle a variety of objects. +# These have to be at top level so the pickler can find them. 
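The sampling policy documented in _should_sample() above is easy to check empirically before reading the test below. This standalone sketch replays the same branch structure in plain Python: always sample the first 10 elements, fall back to random.randint up to element 30, then jump by gaps from the fast reservoir-sampling approximation. The expected count is derived from the 10/N rate, not taken from the commit.

import math
import random

def compute_next_sample(i):
  # Gap until the next sample when each element is kept with
  # probability ~10/i (fast approximation of reservoir sampling).
  gap = math.log(1.0 - random.random()) / math.log(1.0 - 10.0 / i)
  return i + math.floor(gap)

def count_samples(n):
  # Replays the branch structure of OperationCounters._should_sample().
  samples, next_sample = 0, 0
  for i in range(1, n + 1):  # i plays the role of _sample_counter.
    if next_sample == 0:
      if random.randint(1, i) <= 10:
        if i > 30:
          next_sample = compute_next_sample(i)
        samples += 1
    elif i >= next_sample:
      next_sample = compute_next_sample(i)
      samples += 1
  return samples

random.seed(1717)
runs = [count_samples(10000) for _ in range(100)]
# Roughly 10 + 10 * ln(10000 / 10) ~= 79 samples expected per run.
print(sum(runs) / float(len(runs)))

This mirrors what test_should_sample below verifies with its bucket ratios: every run samples the first 10 elements, and bucket i thereafter is hit at roughly a 10/i rate.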
+ + +class OldClassThatDoesNotImplementLen: # pylint: disable=old-style-class + + def __init__(self): + pass + + +class ObjectThatDoesNotImplementLen(object): + + def __init__(self): + pass + + +class OperationCountersTest(unittest.TestCase): + + def verify_counters(self, opcounts, expected_elements, expected_size=None): + self.assertEqual(expected_elements, opcounts.element_counter.value()) + if expected_size is not None: + if math.isnan(expected_size): + self.assertTrue(math.isnan(opcounts.mean_byte_counter.value())) + else: + self.assertEqual(expected_size, opcounts.mean_byte_counter.value()) + + def test_update_int(self): + opcounts = OperationCounters(CounterFactory(), 'some-name', + coders.PickleCoder(), 0) + self.verify_counters(opcounts, 0) + opcounts.update_from(GlobalWindows.windowed_value(1)) + self.verify_counters(opcounts, 1) + + def test_update_str(self): + coder = coders.PickleCoder() + opcounts = OperationCounters(CounterFactory(), 'some-name', + coder, 0) + self.verify_counters(opcounts, 0, float('nan')) + value = GlobalWindows.windowed_value('abcde') + opcounts.update_from(value) + estimated_size = coder.estimate_size(value) + self.verify_counters(opcounts, 1, estimated_size) + + def test_update_old_object(self): + coder = coders.PickleCoder() + opcounts = OperationCounters(CounterFactory(), 'some-name', + coder, 0) + self.verify_counters(opcounts, 0, float('nan')) + obj = OldClassThatDoesNotImplementLen() + value = GlobalWindows.windowed_value(obj) + opcounts.update_from(value) + estimated_size = coder.estimate_size(value) + self.verify_counters(opcounts, 1, estimated_size) + + def test_update_new_object(self): + coder = coders.PickleCoder() + opcounts = OperationCounters(CounterFactory(), 'some-name', + coder, 0) + self.verify_counters(opcounts, 0, float('nan')) + + obj = ObjectThatDoesNotImplementLen() + value = GlobalWindows.windowed_value(obj) + opcounts.update_from(value) + estimated_size = coder.estimate_size(value) + self.verify_counters(opcounts, 1, estimated_size) + + def test_update_multiple(self): + coder = coders.PickleCoder() + total_size = 0 + opcounts = OperationCounters(CounterFactory(), 'some-name', + coder, 0) + self.verify_counters(opcounts, 0, float('nan')) + value = GlobalWindows.windowed_value('abcde') + opcounts.update_from(value) + total_size += coder.estimate_size(value) + value = GlobalWindows.windowed_value('defghij') + opcounts.update_from(value) + total_size += coder.estimate_size(value) + self.verify_counters(opcounts, 2, float(total_size) / 2) + value = GlobalWindows.windowed_value('klmnop') + opcounts.update_from(value) + total_size += coder.estimate_size(value) + self.verify_counters(opcounts, 3, float(total_size) / 3) + + def test_should_sample(self): + # Order of magnitude more buckets than highest constant in code under test. + buckets = [0] * 300 + # The seed is arbitrary and exists just to ensure this test is robust. + # If you don't like this seed, try your own; the test should still pass. + random.seed(1717) + # Do enough runs that the expected hits even in the last buckets + # is big enough to expect some statistical smoothing. + total_runs = 10 * len(buckets) + + # Fill the buckets. + for _ in xrange(total_runs): + opcounts = OperationCounters(CounterFactory(), 'some-name', + coders.PickleCoder(), 0) + for i in xrange(len(buckets)): + if opcounts.should_sample(): + buckets[i] += 1 + + # Look at the buckets to see if they are likely. 
+ for i in xrange(10): + self.assertEqual(total_runs, buckets[i]) + for i in xrange(10, len(buckets)): + self.assertTrue(buckets[i] > 7 * total_runs / i, + 'i=%d, buckets[i]=%d, expected=%d, ratio=%f' % ( + i, buckets[i], + 10 * total_runs / i, + buckets[i] / (10.0 * total_runs / i))) + self.assertTrue(buckets[i] < 14 * total_runs / i, + 'i=%d, buckets[i]=%d, expected=%d, ratio=%f' % ( + i, buckets[i], + 10 * total_runs / i, + buckets[i] / (10.0 * total_runs / i))) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/operation_specs.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/operation_specs.py b/sdks/python/apache_beam/runners/worker/operation_specs.py new file mode 100644 index 0000000..977e165 --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/operation_specs.py @@ -0,0 +1,368 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Worker utilities for representing MapTasks. + +Each MapTask represents a sequence of ParallelInstruction(s): read from a +source, write to a sink, parallel do, etc. +""" + +import collections + +from apache_beam import coders + + +def build_worker_instruction(*args): + """Create an object representing a ParallelInstruction protobuf. + + This will be a collections.namedtuple with a custom __str__ method. + + Alas, this wrapper is not known to pylint, which thinks it creates + constants. You may have to put a disable=invalid-name pylint + annotation on any use of this, depending on your names. + + Args: + *args: first argument is the name of the type to create. Should + start with "Worker". The second argument is a list of the + attributes of this object. + Returns: + A new class, a subclass of tuple, that represents the protobuf.
+ """ + tuple_class = collections.namedtuple(*args) + tuple_class.__str__ = worker_object_to_string + tuple_class.__repr__ = worker_object_to_string + return tuple_class + + +def worker_printable_fields(workerproto): + """Returns the interesting fields of a Worker* object.""" + return ['%s=%s' % (name, value) + # _asdict is the only way and cannot subclass this generated class + # pylint: disable=protected-access + for name, value in workerproto._asdict().iteritems() + # want to output value 0 but not None nor [] + if (value or value == 0) + and name not in + ('coder', 'coders', 'output_coders', + 'elements', + 'combine_fn', 'serialized_fn', 'window_fn', + 'append_trailing_newlines', 'strip_trailing_newlines', + 'compression_type', 'context', + 'start_shuffle_position', 'end_shuffle_position', + 'shuffle_reader_config', 'shuffle_writer_config')] + + +def worker_object_to_string(worker_object): + """Returns a string compactly representing a Worker* object.""" + return '%s(%s)' % (worker_object.__class__.__name__, + ', '.join(worker_printable_fields(worker_object))) + + +# All the following Worker* definitions will have these lint problems: +# pylint: disable=invalid-name +# pylint: disable=pointless-string-statement + + +WorkerRead = build_worker_instruction( + 'WorkerRead', ['source', 'output_coders']) +"""Worker details needed to read from a source. + +Attributes: + source: a source object. + output_coders: 1-tuple of the coder for the output. +""" + + +WorkerSideInputSource = build_worker_instruction( + 'WorkerSideInputSource', ['source', 'tag']) +"""Worker details needed to read from a side input source. + +Attributes: + source: a source object. + tag: string tag for this side input. +""" + + +WorkerGroupingShuffleRead = build_worker_instruction( + 'WorkerGroupingShuffleRead', + ['start_shuffle_position', 'end_shuffle_position', + 'shuffle_reader_config', 'coder', 'output_coders']) +"""Worker details needed to read from a grouping shuffle source. + +Attributes: + start_shuffle_position: An opaque string to be passed to the shuffle + source to indicate where to start reading. + end_shuffle_position: An opaque string to be passed to the shuffle + source to indicate where to stop reading. + shuffle_reader_config: An opaque string used to initialize the shuffle + reader. Contains things like connection endpoints for the shuffle + server appliance and various options. + coder: The KV coder used to decode shuffle entries. + output_coders: 1-tuple of the coder for the output. +""" + + +WorkerUngroupedShuffleRead = build_worker_instruction( + 'WorkerUngroupedShuffleRead', + ['start_shuffle_position', 'end_shuffle_position', + 'shuffle_reader_config', 'coder', 'output_coders']) +"""Worker details needed to read from an ungrouped shuffle source. + +Attributes: + start_shuffle_position: An opaque string to be passed to the shuffle + source to indicate where to start reading. + end_shuffle_position: An opaque string to be passed to the shuffle + source to indicate where to stop reading. + shuffle_reader_config: An opaque string used to initialize the shuffle + reader. Contains things like connection endpoints for the shuffle + server appliance and various options. + coder: The value coder used to decode shuffle entries. +""" + + +WorkerWrite = build_worker_instruction( + 'WorkerWrite', ['sink', 'input', 'output_coders']) +"""Worker details needed to write to a sink. + +Attributes: + sink: a sink object. 
+ input: A (producer index, output index) tuple representing the + ParallelInstruction operation whose output feeds into this operation. + The output index is 0 except for multi-output operations (like ParDo). + output_coders: 1-tuple, coder to use to estimate bytes written. +""" + + +WorkerInMemoryWrite = build_worker_instruction( + 'WorkerInMemoryWrite', + ['output_buffer', 'write_windowed_values', 'input', 'output_coders']) +"""Worker details needed to write to an in-memory sink. + +Used only for unit testing. It makes worker tests less cluttered with code like +"write to a file and then check file contents". + +Attributes: + output_buffer: list to which output elements will be appended + write_windowed_values: whether to record the entire WindowedValue outputs, + or just the raw (unwindowed) value + input: A (producer index, output index) tuple representing the + ParallelInstruction operation whose output feeds into this operation. + The output index is 0 except for multi-output operations (like ParDo). + output_coders: 1-tuple, coder to use to estimate bytes written. +""" + + +WorkerShuffleWrite = build_worker_instruction( + 'WorkerShuffleWrite', + ['shuffle_kind', 'shuffle_writer_config', 'input', 'output_coders']) +"""Worker details needed to write to a shuffle sink. + +Attributes: + shuffle_kind: A string describing the shuffle kind. This can control the + way the worker interacts with the shuffle sink. The possible values are: + 'ungrouped', 'group_keys', and 'group_keys_and_sort_values'. + shuffle_writer_config: An opaque string used to initialize the shuffle + write. Contains things like connection endpoints for the shuffle + server appliance and various options. + input: A (producer index, output index) tuple representing the + ParallelInstruction operation whose output feeds into this operation. + The output index is 0 except for multi-output operations (like ParDo). + output_coders: 1-tuple of the coder for input elements. If the + shuffle_kind is grouping, this is expected to be a KV coder. +""" + + +WorkerDoFn = build_worker_instruction( + 'WorkerDoFn', + ['serialized_fn', 'output_tags', 'input', 'side_inputs', 'output_coders']) +"""Worker details needed to run a DoFn. +Attributes: + serialized_fn: A serialized DoFn object to be run for each input element. + output_tags: The string tags used to identify the outputs of a ParDo + operation. The tag is present even if the ParDo has just one output + (e.g., ['out']). + output_coders: array of coders, one for each output. + input: A (producer index, output index) tuple representing the + ParallelInstruction operation whose output feeds into this operation. + The output index is 0 except for multi-output operations (like ParDo). + side_inputs: A list of Worker...Read instances describing sources to be + used for getting values. The types supported right now are + WorkerInMemoryRead and WorkerTextRead. +""" + + +WorkerReifyTimestampAndWindows = build_worker_instruction( + 'WorkerReifyTimestampAndWindows', + ['output_tags', 'input', 'output_coders']) +"""Worker details needed to run a WindowInto. +Attributes: + output_tags: The string tags used to identify the outputs of a ParDo + operation. The tag is present even if the ParDo has just one output + (e.g., ['out']). + output_coders: array of coders, one for each output. + input: A (producer index, output index) tuple representing the + ParallelInstruction operation whose output feeds into this operation. + The output index is 0 except for multi-output operations (like ParDo).
+""" + + +WorkerMergeWindows = build_worker_instruction( + 'WorkerMergeWindows', + ['window_fn', 'combine_fn', 'phase', 'output_tags', 'input', 'coders', + 'context', 'output_coders']) +"""Worker details needed to run a MergeWindows (aka. GroupAlsoByWindows). +Attributes: + window_fn: A serialized Windowing object representing the windowing strategy. + combine_fn: A serialized CombineFn object to be used after executing the + GroupAlsoByWindows operation. May be None if not a combining operation. + phase: Possible values are 'all', 'add', 'merge', and 'extract'. + A runner optimizer may split the user combiner in 3 separate + phases (ADD, MERGE, and EXTRACT), on separate VMs, as it sees + fit. The phase attribute dictates which DoFn is actually running in + the worker. May be None if not a combining operation. + output_tags: The string tags used to identify the outputs of a ParDo + operation. The tag is present even if the ParDo has just one output + (e.g., ['out']. + output_coders: array of coders, one for each output. + input: A (producer index, output index) tuple representing the + ParallelInstruction operation whose output feeds into this operation. + The output index is 0 except for multi-output operations (like ParDo). + coders: A 2-tuple of coders (key, value) to encode shuffle entries. + context: The ExecutionContext object for the current work item. +""" + + +WorkerCombineFn = build_worker_instruction( + 'WorkerCombineFn', + ['serialized_fn', 'phase', 'input', 'output_coders']) +"""Worker details needed to run a CombineFn. +Attributes: + serialized_fn: A serialized CombineFn object to be used. + phase: Possible values are 'all', 'add', 'merge', and 'extract'. + A runner optimizer may split the user combiner in 3 separate + phases (ADD, MERGE, and EXTRACT), on separate VMs, as it sees + fit. The phase attribute dictates which DoFn is actually running in + the worker. + input: A (producer index, output index) tuple representing the + ParallelInstruction operation whose output feeds into this operation. + The output index is 0 except for multi-output operations (like ParDo). + output_coders: 1-tuple of the coder for the output. +""" + + +WorkerPartialGroupByKey = build_worker_instruction( + 'WorkerPartialGroupByKey', + ['combine_fn', 'input', 'output_coders']) +"""Worker details needed to run a partial group-by-key. +Attributes: + combine_fn: A serialized CombineFn object to be used. + input: A (producer index, output index) tuple representing the + ParallelInstruction operation whose output feeds into this operation. + The output index is 0 except for multi-output operations (like ParDo). + output_coders: 1-tuple of the coder for the output. +""" + + +WorkerFlatten = build_worker_instruction( + 'WorkerFlatten', + ['inputs', 'output_coders']) +"""Worker details needed to run a Flatten. +Attributes: + inputs: A list of tuples, each (producer index, output index), representing + the ParallelInstruction operations whose output feeds into this operation. + The output index is 0 unless the input is from a multi-output + operation (such as ParDo). + output_coders: 1-tuple of the coder for the output. +""" + + +def get_coder_from_spec(coder_spec): + """Return a coder instance from a coder spec. + + Args: + coder_spec: A dict where the value of the '@type' key is a pickled instance + of a Coder instance. + + Returns: + A coder instance (has encode/decode methods). + """ + assert coder_spec is not None + + # Ignore the wrappers in these encodings. 
# TODO(silviuc): Make sure with all the renamings that names below are ok. + if coder_spec['@type'] in ignored_wrappers: + assert len(coder_spec['component_encodings']) == 1 + coder_spec = coder_spec['component_encodings'][0] + return get_coder_from_spec(coder_spec) + + # Handle a few well known types of coders. + if coder_spec['@type'] == 'kind:pair': + assert len(coder_spec['component_encodings']) == 2 + component_coders = [ + get_coder_from_spec(c) for c in coder_spec['component_encodings']] + return coders.TupleCoder(component_coders) + elif coder_spec['@type'] == 'kind:stream': + assert len(coder_spec['component_encodings']) == 1 + return coders.IterableCoder( + get_coder_from_spec(coder_spec['component_encodings'][0])) + elif coder_spec['@type'] == 'kind:windowed_value': + assert len(coder_spec['component_encodings']) == 2 + value_coder, window_coder = [ + get_coder_from_spec(c) for c in coder_spec['component_encodings']] + return coders.WindowedValueCoder(value_coder, window_coder=window_coder) + elif coder_spec['@type'] == 'kind:interval_window': + assert ('component_encodings' not in coder_spec + or len(coder_spec['component_encodings']) == 0) + return coders.IntervalWindowCoder() + elif coder_spec['@type'] == 'kind:global_window': + assert ('component_encodings' not in coder_spec + or not coder_spec['component_encodings']) + return coders.GlobalWindowCoder() + elif coder_spec['@type'] == 'kind:length_prefix': + assert len(coder_spec['component_encodings']) == 1 + return coders.LengthPrefixCoder( + get_coder_from_spec(coder_spec['component_encodings'][0])) + + # We pass coders in the form "<coder_name>$<pickled_data>" to make the job + # description JSON more readable. + return coders.deserialize_coder(coder_spec['@type']) + + +class MapTask(object): + """A map task decoded into operations and ready to be executed. + + Attributes: + operations: A list of Worker* objects created by parsing the instructions + within the map task. + stage_name: The name of this map task execution stage. + system_names: The system names of the step corresponding to each map task + operation in the execution graph. + step_names: The names of the step corresponding to each map task operation. + original_names: The internal name of a step in the original workflow graph. + """ + + def __init__( + self, operations, stage_name, system_names, step_names, original_names): + self.operations = operations + self.stage_name = stage_name + self.system_names = system_names + self.step_names = step_names + self.original_names = original_names + + def __str__(self): + return '<%s %s steps=%s>' % (self.__class__.__name__, self.stage_name, + '+'.join(self.step_names)) http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/operations.pxd ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd new file mode 100644 index 0000000..2b4e526 --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/operations.pxd @@ -0,0 +1,89 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +cimport cython + +from apache_beam.runners.common cimport Receiver +from apache_beam.runners.worker cimport opcounters +from apache_beam.utils.windowed_value cimport WindowedValue +from apache_beam.metrics.execution cimport ScopedMetricsContainer + + +cdef WindowedValue _globally_windowed_value +cdef type _global_window_type + +cdef class ConsumerSet(Receiver): + cdef list consumers + cdef opcounters.OperationCounters opcounter + cdef public step_name + cdef public output_index + cdef public coder + + cpdef receive(self, WindowedValue windowed_value) + cpdef update_counters_start(self, WindowedValue windowed_value) + cpdef update_counters_finish(self) + + +cdef class Operation(object): + cdef readonly operation_name + cdef readonly spec + cdef object consumers + cdef readonly counter_factory + cdef public metrics_container + cdef public ScopedMetricsContainer scoped_metrics_container + # Public for access by Fn harness operations. + # TODO(robertwb): Cythonize FnHarness. + cdef public list receivers + cdef readonly bint debug_logging_enabled + + cdef public step_name # initialized lazily + + cdef readonly object state_sampler + + cdef readonly object scoped_start_state + cdef readonly object scoped_process_state + cdef readonly object scoped_finish_state + + cpdef start(self) + cpdef process(self, WindowedValue windowed_value) + cpdef finish(self) + + cpdef output(self, WindowedValue windowed_value, int output_index=*) + +cdef class ReadOperation(Operation): + @cython.locals(windowed_value=WindowedValue) + cpdef start(self) + +cdef class DoOperation(Operation): + cdef object dofn_runner + cdef Receiver dofn_receiver + +cdef class CombineOperation(Operation): + cdef object phased_combine_fn + +cdef class FlattenOperation(Operation): + pass + +cdef class PGBKCVOperation(Operation): + cdef public object combine_fn + cdef public object combine_fn_add_input + cdef dict table + cdef long max_keys + cdef long key_count + + cpdef output_key(self, tuple wkey, value) + http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/operations.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py new file mode 100644 index 0000000..5dbe57e --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/operations.py @@ -0,0 +1,651 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +# cython: profile=True + +"""Worker operations executor.""" + +import collections +import itertools +import logging + +from apache_beam import pvalue +from apache_beam.internal import pickler +from apache_beam.io import iobase +from apache_beam.metrics.execution import MetricsContainer +from apache_beam.metrics.execution import ScopedMetricsContainer +from apache_beam.runners import common +from apache_beam.runners.common import Receiver +from apache_beam.runners.dataflow.internal.names import PropertyNames +from apache_beam.runners.worker import logger +from apache_beam.runners.worker import opcounters +from apache_beam.runners.worker import operation_specs +from apache_beam.runners.worker import sideinputs +from apache_beam.transforms import combiners +from apache_beam.transforms import core +from apache_beam.transforms import sideinputs as apache_sideinputs +from apache_beam.transforms.combiners import curry_combine_fn +from apache_beam.transforms.combiners import PhasedCombineFnExecutor +from apache_beam.transforms.window import GlobalWindows +from apache_beam.utils.windowed_value import WindowedValue + +# Allow some "pure mode" declarations. +try: + import cython +except ImportError: + class FakeCython(object): + @staticmethod + def cast(type, value): + return value + globals()['cython'] = FakeCython() + + +_globally_windowed_value = GlobalWindows.windowed_value(None) +_global_window_type = type(_globally_windowed_value.windows[0]) + + +class ConsumerSet(Receiver): + """A ConsumerSet represents a graph edge between two Operation nodes. + + The ConsumerSet object collects information from the output of the + Operation at one end of its edge and the input of the Operation at + the other end. + ConsumerSets are attached to the outputting Operation. + """ + + def __init__( + self, counter_factory, step_name, output_index, consumers, coder): + self.consumers = consumers + self.opcounter = opcounters.OperationCounters( + counter_factory, step_name, coder, output_index) + # Used in repr. + self.step_name = step_name + self.output_index = output_index + self.coder = coder + + def output(self, windowed_value): # For old SDKs. + self.receive(windowed_value) + + def receive(self, windowed_value): + self.update_counters_start(windowed_value) + for consumer in self.consumers: + cython.cast(Operation, consumer).process(windowed_value) + self.update_counters_finish() + + def update_counters_start(self, windowed_value): + self.opcounter.update_from(windowed_value) + + def update_counters_finish(self): + self.opcounter.update_collect() + + def __repr__(self): + return '%s[%s.out%s, coder=%s, len(consumers)=%s]' % ( + self.__class__.__name__, self.step_name, self.output_index, self.coder, + len(self.consumers)) + + +class Operation(object): + """An operation representing the live version of a work item specification. + + An operation can have one or more outputs and for each output it can have + one or more receiver operations that will take that as input. + """ + + def __init__(self, operation_name, spec, counter_factory, state_sampler): + """Initializes a worker operation instance. + + Args: + operation_name: The system name assigned by the runner for this + operation. + spec: An operation_specs.Worker* instance. + counter_factory: The CounterFactory to use for our counters. + state_sampler: The StateSampler for the current operation.
+ """ + self.operation_name = operation_name + self.spec = spec + self.counter_factory = counter_factory + self.consumers = collections.defaultdict(list) + + self.state_sampler = state_sampler + self.scoped_start_state = self.state_sampler.scoped_state( + self.operation_name + '-start') + self.scoped_process_state = self.state_sampler.scoped_state( + self.operation_name + '-process') + self.scoped_finish_state = self.state_sampler.scoped_state( + self.operation_name + '-finish') + # TODO(ccy): the '-abort' state can be added when the abort is supported in + # Operations. + + def start(self): + """Start operation.""" + self.debug_logging_enabled = logging.getLogger().isEnabledFor( + logging.DEBUG) + # Everything except WorkerSideInputSource, which is not a + # top-level operation, should have output_coders + if getattr(self.spec, 'output_coders', None): + self.receivers = [ConsumerSet(self.counter_factory, self.step_name, + i, self.consumers[i], coder) + for i, coder in enumerate(self.spec.output_coders)] + + def finish(self): + """Finish operation.""" + pass + + def process(self, o): + """Process element in operation.""" + pass + + def output(self, windowed_value, output_index=0): + cython.cast(Receiver, self.receivers[output_index]).receive(windowed_value) + + def add_receiver(self, operation, output_index=0): + """Adds a receiver operation for the specified output.""" + self.consumers[output_index].append(operation) + + def __str__(self): + """Generates a useful string for this object. + + Compactly displays interesting fields. In particular, pickled + fields are not displayed. Note that we collapse the fields of the + contained Worker* object into this object, since there is a 1-1 + mapping between Operation and operation_specs.Worker*. + + Returns: + Compact string representing this object. + """ + return self.str_internal() + + def str_internal(self, is_recursive=False): + """Internal helper for __str__ that supports recursion. + + When recursing on receivers, keep the output short. + Args: + is_recursive: whether to omit some details, particularly receivers. + Returns: + Compact string representing this object. + """ + printable_name = self.__class__.__name__ + if hasattr(self, 'step_name'): + printable_name += ' %s' % self.step_name + if is_recursive: + # If we have a step name, stop here, no more detail needed. 
+ return '<%s>' % printable_name + + if self.spec is None: + printable_fields = [] + else: + printable_fields = operation_specs.worker_printable_fields(self.spec) + + if not is_recursive and getattr(self, 'receivers', []): + printable_fields.append('receivers=[%s]' % ', '.join([ + str(receiver) for receiver in self.receivers])) + + return '<%s %s>' % (printable_name, ', '.join(printable_fields)) + + +class ReadOperation(Operation): + + def start(self): + with self.scoped_start_state: + super(ReadOperation, self).start() + range_tracker = self.spec.source.source.get_range_tracker( + self.spec.source.start_position, self.spec.source.stop_position) + for value in self.spec.source.source.read(range_tracker): + if isinstance(value, WindowedValue): + windowed_value = value + else: + windowed_value = _globally_windowed_value.with_value(value) + self.output(windowed_value) + + +class InMemoryWriteOperation(Operation): + """A write operation that will write to an in-memory sink.""" + + def process(self, o): + with self.scoped_process_state: + if self.debug_logging_enabled: + logging.debug('Processing [%s] in %s', o, self) + self.spec.output_buffer.append( + o if self.spec.write_windowed_values else o.value) + + +class _TaggedReceivers(dict): + + class NullReceiver(Receiver): + + def receive(self, element): + pass + + # For old SDKs. + def output(self, element): + pass + + def __missing__(self, unused_key): + if not getattr(self, '_null_receiver', None): + self._null_receiver = _TaggedReceivers.NullReceiver() + return self._null_receiver + + +class DoOperation(Operation): + """A Do operation that will execute a custom DoFn for each input element.""" + + def _read_side_inputs(self, tags_and_types): + """Generator reading side inputs in the order prescribed by tags_and_types. + + Args: + tags_and_types: List of tuples (tag, type). Each side input has a string + tag that is specified in the worker instruction. The type is actually + a boolean which is True for singleton input (read just the first value) + and False for collection input (read all values). + + Yields: + With each iteration it yields the result of reading an entire side source + either in singleton or collection mode according to the tags_and_types + argument. + """ + # We will read the side inputs in the order prescribed by the + # tags_and_types argument because this is exactly the order needed to + # replace the ArgumentPlaceholder objects in the args/kwargs of the DoFn + # getting the side inputs. + # + # Note that for each tag there could be several read operations in the + # specification. This can happen for instance if the source has been + # sharded into several files. + for side_tag, view_class, view_options in tags_and_types: + sources = [] + # Using the side_tag in the lambda below will trigger a pylint warning. + # However in this case it is fine because the lambda is used right away + # while the variable has the value assigned by the current iteration of + # the for loop. + # pylint: disable=cell-var-from-loop + for si in itertools.ifilter( + lambda o: o.tag == side_tag, self.spec.side_inputs): + if not isinstance(si, operation_specs.WorkerSideInputSource): + raise NotImplementedError('Unknown side input type: %r' % si) + sources.append(si.source) + iterator_fn = sideinputs.get_iterator_fn_for_sources(sources) + + # Backwards compatibility for pre BEAM-733 SDKs.
+ if isinstance(view_options, tuple): + if view_class == pvalue.SingletonPCollectionView: + has_default, default = view_options + view_options = {'default': default} if has_default else {} + else: + view_options = {} + + yield apache_sideinputs.SideInputMap( + view_class, view_options, sideinputs.EmulatedIterable(iterator_fn)) + + def start(self): + with self.scoped_start_state: + super(DoOperation, self).start() + + # See fn_data in dataflow_runner.py + fn, args, kwargs, tags_and_types, window_fn = ( + pickler.loads(self.spec.serialized_fn)) + + state = common.DoFnState(self.counter_factory) + state.step_name = self.step_name + + # TODO(silviuc): What is the proper label here? PCollection being + # processed? + context = common.DoFnContext('label', state=state) + # Tag to output index map used to dispatch the side output values emitted + # by the DoFn function to the appropriate receivers. The main output is + # tagged with None and is associated with its corresponding index. + tagged_receivers = _TaggedReceivers() + + output_tag_prefix = PropertyNames.OUT + '_' + for index, tag in enumerate(self.spec.output_tags): + if tag == PropertyNames.OUT: + original_tag = None + elif tag.startswith(output_tag_prefix): + original_tag = tag[len(output_tag_prefix):] + else: + raise ValueError('Unexpected output name for operation: %s' % tag) + tagged_receivers[original_tag] = self.receivers[index] + + self.dofn_runner = common.DoFnRunner( + fn, args, kwargs, self._read_side_inputs(tags_and_types), + window_fn, context, tagged_receivers, + logger, self.step_name, + scoped_metrics_container=self.scoped_metrics_container) + self.dofn_receiver = (self.dofn_runner + if isinstance(self.dofn_runner, Receiver) + else DoFnRunnerReceiver(self.dofn_runner)) + + self.dofn_runner.start() + + def finish(self): + with self.scoped_finish_state: + self.dofn_runner.finish() + + def process(self, o): + with self.scoped_process_state: + self.dofn_receiver.receive(o) + + +class DoFnRunnerReceiver(Receiver): + + def __init__(self, dofn_runner): + self.dofn_runner = dofn_runner + + def receive(self, windowed_value): + self.dofn_runner.process(windowed_value) + + +class CombineOperation(Operation): + """A Combine operation executing a CombineFn for each input element.""" + + def __init__(self, operation_name, spec, counter_factory, state_sampler): + super(CombineOperation, self).__init__( + operation_name, spec, counter_factory, state_sampler) + # Combiners do not accept deferred side-inputs (the ignored fourth argument) + # and therefore the code to handle the extra args/kwargs is simpler than for + # the DoFn's of ParDo. + fn, args, kwargs = pickler.loads(self.spec.serialized_fn)[:3] + self.phased_combine_fn = ( + PhasedCombineFnExecutor(self.spec.phase, fn, args, kwargs)) + + def finish(self): + logging.debug('Finishing %s', self) + + def process(self, o): + if self.debug_logging_enabled: + logging.debug('Processing [%s] in %s', o, self) + key, values = o.value + with self.scoped_metrics_container: + self.output( + o.with_value((key, self.phased_combine_fn.apply(values)))) + + +def create_pgbk_op(step_name, spec, counter_factory, state_sampler): + if spec.combine_fn: + return PGBKCVOperation(step_name, spec, counter_factory, state_sampler) + else: + return PGBKOperation(step_name, spec, counter_factory, state_sampler) + + +class PGBKOperation(Operation): + """Partial group-by-key operation. 
+ + This takes (windowed) input (key, value) tuples and outputs + (key, [value]) tuples, performing a best effort group-by-key for + values in this bundle, memory permitting. + """ + + def __init__(self, operation_name, spec, counter_factory, state_sampler): + super(PGBKOperation, self).__init__( + operation_name, spec, counter_factory, state_sampler) + assert not self.spec.combine_fn + self.table = collections.defaultdict(list) + self.size = 0 + # TODO(robertwb) Make this configurable. + self.max_size = 10 * 1000 + + def process(self, o): + # TODO(robertwb): Structural (hashable) values. + key = o.value[0], tuple(o.windows) + self.table[key].append(o) + self.size += 1 + if self.size > self.max_size: + self.flush(9 * self.max_size // 10) + + def finish(self): + self.flush(0) + + def flush(self, target): + limit = self.size - target + for ix, (kw, vs) in enumerate(self.table.items()): + if ix >= limit: + break + del self.table[kw] + key, windows = kw + output_value = [v.value[1] for v in vs] + windowed_value = WindowedValue( + (key, output_value), + vs[0].timestamp, windows) + self.output(windowed_value) + + +class PGBKCVOperation(Operation): + + def __init__(self, operation_name, spec, counter_factory, state_sampler): + super(PGBKCVOperation, self).__init__( + operation_name, spec, counter_factory, state_sampler) + # Combiners do not accept deferred side-inputs (the ignored fourth + # argument) and therefore the code to handle the extra args/kwargs is + # simpler than for the DoFn's of ParDo. + fn, args, kwargs = pickler.loads(self.spec.combine_fn)[:3] + self.combine_fn = curry_combine_fn(fn, args, kwargs) + if (getattr(fn.add_input, 'im_func', None) + is core.CombineFn.add_input.im_func): + # Old versions of the SDK have CombineFns that don't implement add_input. + self.combine_fn_add_input = ( + lambda a, e: self.combine_fn.add_inputs(a, [e])) + else: + self.combine_fn_add_input = self.combine_fn.add_input + # Optimization for the (known tiny accumulator, often wide keyspace) + # combine functions. + # TODO(b/36567833): Bound by in-memory size rather than key count. + self.max_keys = ( + 1000 * 1000 if + isinstance(fn, (combiners.CountCombineFn, combiners.MeanCombineFn)) or + # TODO(b/36597732): Replace this 'or' part by adding the 'cy' optimized + # combiners to the short list above. + (isinstance(fn, core.CallableWrapperCombineFn) and + fn._fn in (min, max, sum)) else 100 * 1000) # pylint: disable=protected-access + self.key_count = 0 + self.table = {} + + def process(self, wkv): + key, value = wkv.value + # pylint: disable=unidiomatic-typecheck + # Optimization for the global window case. + if len(wkv.windows) == 1 and type(wkv.windows[0]) is _global_window_type: + wkey = 0, key + else: + wkey = tuple(wkv.windows), key + entry = self.table.get(wkey, None) + if entry is None: + if self.key_count >= self.max_keys: + target = self.key_count * 9 // 10 + old_wkeys = [] + # TODO(robertwb): Use an LRU cache? + for old_wkey, old_wvalue in self.table.iteritems(): + old_wkeys.append(old_wkey) # Can't mutate while iterating. + self.output_key(old_wkey, old_wvalue[0]) + self.key_count -= 1 + if self.key_count <= target: + break + for old_wkey in reversed(old_wkeys): + del self.table[old_wkey] + self.key_count += 1 + # We save the accumulator as a one element list so we can efficiently + # mutate when new values are added without searching the cache again. 
+      entry = self.table[wkey] = [self.combine_fn.create_accumulator()]
+    entry[0] = self.combine_fn_add_input(entry[0], value)
+
+  def finish(self):
+    for wkey, value in self.table.iteritems():
+      self.output_key(wkey, value[0])
+    self.table = {}
+    self.key_count = 0
+
+  def output_key(self, wkey, value):
+    windows, key = wkey
+    if windows == 0:
+      self.output(_globally_windowed_value.with_value((key, value)))
+    else:
+      self.output(WindowedValue((key, value), windows[0].end, windows))
+
+
+class FlattenOperation(Operation):
+  """Flatten operation.
+
+  Receives input from one or more producer operations and emits every
+  item on a single output stream.
+  """
+
+  def process(self, o):
+    if self.debug_logging_enabled:
+      logging.debug('Processing [%s] in %s', o, self)
+    self.output(o)
+
+
+def create_operation(operation_name, spec, counter_factory, step_name,
+                     state_sampler, test_shuffle_source=None,
+                     test_shuffle_sink=None, is_streaming=False):
+  """Create Operation object for given operation specification."""
+  if isinstance(spec, operation_specs.WorkerRead):
+    if isinstance(spec.source, iobase.SourceBundle):
+      op = ReadOperation(
+          operation_name, spec, counter_factory, state_sampler)
+    else:
+      from dataflow_worker.native_operations import NativeReadOperation
+      op = NativeReadOperation(
+          operation_name, spec, counter_factory, state_sampler)
+  elif isinstance(spec, operation_specs.WorkerWrite):
+    from dataflow_worker.native_operations import NativeWriteOperation
+    op = NativeWriteOperation(
+        operation_name, spec, counter_factory, state_sampler)
+  elif isinstance(spec, operation_specs.WorkerCombineFn):
+    op = CombineOperation(
+        operation_name, spec, counter_factory, state_sampler)
+  elif isinstance(spec, operation_specs.WorkerPartialGroupByKey):
+    op = create_pgbk_op(operation_name, spec, counter_factory, state_sampler)
+  elif isinstance(spec, operation_specs.WorkerDoFn):
+    op = DoOperation(operation_name, spec, counter_factory, state_sampler)
+  elif isinstance(spec, operation_specs.WorkerGroupingShuffleRead):
+    from dataflow_worker.shuffle_operations import GroupedShuffleReadOperation
+    op = GroupedShuffleReadOperation(
+        operation_name, spec, counter_factory, state_sampler,
+        shuffle_source=test_shuffle_source)
+  elif isinstance(spec, operation_specs.WorkerUngroupedShuffleRead):
+    from dataflow_worker.shuffle_operations import UngroupedShuffleReadOperation
+    op = UngroupedShuffleReadOperation(
+        operation_name, spec, counter_factory, state_sampler,
+        shuffle_source=test_shuffle_source)
+  elif isinstance(spec, operation_specs.WorkerInMemoryWrite):
+    op = InMemoryWriteOperation(
+        operation_name, spec, counter_factory, state_sampler)
+  elif isinstance(spec, operation_specs.WorkerShuffleWrite):
+    from dataflow_worker.shuffle_operations import ShuffleWriteOperation
+    op = ShuffleWriteOperation(
+        operation_name, spec, counter_factory, state_sampler,
+        shuffle_sink=test_shuffle_sink)
+  elif isinstance(spec, operation_specs.WorkerFlatten):
+    op = FlattenOperation(
+        operation_name, spec, counter_factory, state_sampler)
+  elif isinstance(spec, operation_specs.WorkerMergeWindows):
+    from dataflow_worker.shuffle_operations import BatchGroupAlsoByWindowsOperation
+    from dataflow_worker.shuffle_operations import StreamingGroupAlsoByWindowsOperation
+    if is_streaming:
+      op = StreamingGroupAlsoByWindowsOperation(
+          operation_name, spec, counter_factory, state_sampler)
+    else:
+      op = BatchGroupAlsoByWindowsOperation(
+          operation_name, spec, counter_factory, state_sampler)
+  elif isinstance(spec,
+                  operation_specs.WorkerReifyTimestampAndWindows):
+    from dataflow_worker.shuffle_operations import ReifyTimestampAndWindowsOperation
+    op = ReifyTimestampAndWindowsOperation(
+        operation_name, spec, counter_factory, state_sampler)
+  else:
+    raise TypeError('Expected an instance of operation_specs.Worker* class '
+                    'instead of %s' % (spec,))
+  op.step_name = step_name
+  op.metrics_container = MetricsContainer(step_name)
+  op.scoped_metrics_container = ScopedMetricsContainer(op.metrics_container)
+  return op
+
+
+class SimpleMapTaskExecutor(object):
+  """An executor for map tasks.
+
+  Stores progress of the read operation that is the first operation of a map
+  task.
+  """
+
+  def __init__(
+      self, map_task, counter_factory, state_sampler,
+      test_shuffle_source=None, test_shuffle_sink=None):
+    """Initializes SimpleMapTaskExecutor.
+
+    Args:
+      map_task: The map task we are to run.
+      counter_factory: The CounterFactory instance for the work item.
+      state_sampler: The StateSampler tracking the execution step.
+      test_shuffle_source: Used during tests for dependency injection into
+        shuffle read operation objects.
+      test_shuffle_sink: Used during tests for dependency injection into
+        shuffle write operation objects.
+    """
+
+    self._map_task = map_task
+    self._counter_factory = counter_factory
+    self._ops = []
+    self._state_sampler = state_sampler
+    self._test_shuffle_source = test_shuffle_source
+    self._test_shuffle_sink = test_shuffle_sink
+
+  def operations(self):
+    return self._ops[:]
+
+  def execute(self):
+    """Executes all the operation_specs.Worker* instructions in a map task.
+
+    We update the map_task with the execution status, expressed as counters.
+
+    Raises:
+      RuntimeError: if we find more than one read instruction in the task
+        spec.
+      TypeError: if the spec parameter is not an instance of the recognized
+        operation_specs.Worker* classes.
+    """
+
+    # operations is a list of operation_specs.Worker* instances.
+    # The order of the elements is important because the inputs use
+    # list indexes as references.
+
+    step_names = (
+        self._map_task.step_names or [None] * len(self._map_task.operations))
+    for ix, spec in enumerate(self._map_task.operations):
+      # This is used for logging and assigning names to counters.
+      operation_name = self._map_task.system_names[ix]
+      step_name = step_names[ix]
+      op = create_operation(
+          operation_name, spec, self._counter_factory, step_name,
+          self._state_sampler,
+          test_shuffle_source=self._test_shuffle_source,
+          test_shuffle_sink=self._test_shuffle_sink)
+      self._ops.append(op)
+
+      # Add receiver operations to the appropriate producers.
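+      # Each input is a (producer_index, output_index) pair; the producer is
+      # an operation created on an earlier iteration of this loop, so it is
+      # already present in self._ops.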
+ if hasattr(op.spec, 'input'): + producer, output_index = op.spec.input + self._ops[producer].add_receiver(op, output_index) + # Flatten has 'inputs', not 'input' + if hasattr(op.spec, 'inputs'): + for producer, output_index in op.spec.inputs: + self._ops[producer].add_receiver(op, output_index) + + for ix, op in reversed(list(enumerate(self._ops))): + logging.debug('Starting op %d %s', ix, op) + with op.scoped_metrics_container: + op.start() + for op in self._ops: + with op.scoped_metrics_container: + op.finish() http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/sdk_worker.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py new file mode 100644 index 0000000..6907f6e --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -0,0 +1,451 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""SDK harness for executing Python Fns via the Fn API.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import json +import logging +import Queue as queue +import threading +import traceback +import zlib + +import dill +from google.protobuf import wrappers_pb2 + +from apache_beam.coders import coder_impl +from apache_beam.coders import WindowedValueCoder +from apache_beam.internal import pickler +from apache_beam.runners.dataflow.native_io import iobase +from apache_beam.utils import counters +from apache_beam.runners.api import beam_fn_api_pb2 +from apache_beam.runners.worker import operation_specs +from apache_beam.runners.worker import operations +try: + from apache_beam.runners.worker import statesampler +except ImportError: + from apache_beam.runners.worker import statesampler_fake as statesampler +from apache_beam.runners.worker.data_plane import GrpcClientDataChannelFactory + + +DATA_INPUT_URN = 'urn:org.apache.beam:source:runner:0.1' +DATA_OUTPUT_URN = 'urn:org.apache.beam:sink:runner:0.1' +IDENTITY_DOFN_URN = 'urn:org.apache.beam:dofn:identity:0.1' +PYTHON_ITERABLE_VIEWFN_URN = 'urn:org.apache.beam:viewfn:iterable:python:0.1' +PYTHON_CODER_URN = 'urn:org.apache.beam:coder:python:0.1' +# TODO(vikasrk): Fix this once runner sends appropriate python urns. 
+PYTHON_DOFN_URN = 'urn:org.apache.beam:dofn:java:0.1' +PYTHON_SOURCE_URN = 'urn:org.apache.beam:source:java:0.1' + + +class RunnerIOOperation(operations.Operation): + """Common baseclass for runner harness IO operations.""" + + def __init__(self, operation_name, step_name, consumers, counter_factory, + state_sampler, windowed_coder, target, data_channel): + super(RunnerIOOperation, self).__init__( + operation_name, None, counter_factory, state_sampler) + self.windowed_coder = windowed_coder + self.step_name = step_name + # target represents the consumer for the bytes in the data plane for a + # DataInputOperation or a producer of these bytes for a DataOutputOperation. + self.target = target + self.data_channel = data_channel + for _, consumer_ops in consumers.items(): + for consumer in consumer_ops: + self.add_receiver(consumer, 0) + + +class DataOutputOperation(RunnerIOOperation): + """A sink-like operation that gathers outputs to be sent back to the runner. + """ + + def set_output_stream(self, output_stream): + self.output_stream = output_stream + + def process(self, windowed_value): + self.windowed_coder.get_impl().encode_to_stream( + windowed_value, self.output_stream, True) + + def finish(self): + self.output_stream.close() + super(DataOutputOperation, self).finish() + + +class DataInputOperation(RunnerIOOperation): + """A source-like operation that gathers input from the runner. + """ + + def __init__(self, operation_name, step_name, consumers, counter_factory, + state_sampler, windowed_coder, input_target, data_channel): + super(DataInputOperation, self).__init__( + operation_name, step_name, consumers, counter_factory, state_sampler, + windowed_coder, target=input_target, data_channel=data_channel) + # We must do this manually as we don't have a spec or spec.output_coders. + self.receivers = [ + operations.ConsumerSet(self.counter_factory, self.step_name, 0, + consumers.itervalues().next(), + self.windowed_coder)] + + def process(self, windowed_value): + self.output(windowed_value) + + def process_encoded(self, encoded_windowed_values): + input_stream = coder_impl.create_InputStream(encoded_windowed_values) + while input_stream.size() > 0: + decoded_value = self.windowed_coder.get_impl().decode_from_stream( + input_stream, True) + self.output(decoded_value) + + +# TODO(robertwb): Revise side input API to not be in terms of native sources. +# This will enable lookups, but there's an open question as to how to handle +# custom sources without forcing intermediate materialization. This seems very +# related to the desire to inject key and window preserving [Splittable]DoFns +# into the view computation. +class SideInputSource(iobase.NativeSource, iobase.NativeSourceReader): + """A 'source' for reading side inputs via state API calls. + """ + + def __init__(self, state_handler, state_key, coder): + self._state_handler = state_handler + self._state_key = state_key + self._coder = coder + + def reader(self): + return self + + @property + def returns_windowed_values(self): + return True + + def __enter__(self): + return self + + def __exit__(self, *exn_info): + pass + + def __iter__(self): + # TODO(robertwb): Support pagination. 
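+    # For now the whole side input arrives in a single Get() response;
+    # elements were encoded as nested values, so decode until the stream
+    # is exhausted.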
+    input_stream = coder_impl.create_InputStream(
+        self._state_handler.Get(self._state_key).data)
+    while input_stream.size() > 0:
+      yield self._coder.get_impl().decode_from_stream(input_stream, True)
+
+
+def unpack_and_deserialize_py_fn(function_spec):
+  """Returns unpacked and deserialized object from function spec proto."""
+  return pickler.loads(unpack_function_spec_data(function_spec))
+
+
+def unpack_function_spec_data(function_spec):
+  """Returns unpacked data from function spec proto."""
+  data = wrappers_pb2.BytesValue()
+  function_spec.data.Unpack(data)
+  return data.value
+
+
+# pylint: disable=redefined-builtin
+def serialize_and_pack_py_fn(fn, urn, id=None):
+  """Returns serialized and packed function in a function spec proto."""
+  return pack_function_spec_data(pickler.dumps(fn), urn, id)
+# pylint: enable=redefined-builtin
+
+
+# pylint: disable=redefined-builtin
+def pack_function_spec_data(value, urn, id=None):
+  """Returns packed data in a function spec proto."""
+  data = wrappers_pb2.BytesValue(value=value)
+  fn_proto = beam_fn_api_pb2.FunctionSpec(urn=urn)
+  fn_proto.data.Pack(data)
+  if id:
+    fn_proto.id = id
+  return fn_proto
+# pylint: enable=redefined-builtin
+
+
+# TODO(vikasrk): Move this function to ``coders.py`` in the SDK.
+def load_compressed(compressed_data):
+  """Returns a decompressed and deserialized python object."""
+  # Note: SDK uses ``pickler.dumps`` to serialize certain python objects
+  # (like sources), which involves serialization, compression and base64
+  # encoding. We cannot directly use ``pickler.loads`` for
+  # deserialization, as the runner would have already base64 decoded the
+  # data. So we only need to decompress and deserialize.
+
+  data = zlib.decompress(compressed_data)
+  try:
+    return dill.loads(data)
+  except Exception:          # pylint: disable=broad-except
+    dill.dill._trace(True)   # pylint: disable=protected-access
+    return dill.loads(data)
+  finally:
+    dill.dill._trace(False)  # pylint: disable=protected-access
+
+
+class SdkHarness(object):
+
+  def __init__(self, control_channel):
+    self._control_channel = control_channel
+    self._data_channel_factory = GrpcClientDataChannelFactory()
+
+  def run(self):
+    control_stub = beam_fn_api_pb2.BeamFnControlStub(self._control_channel)
+    # TODO(robertwb): Wire up to new state api.
+    state_stub = None
+    self.worker = SdkWorker(state_stub, self._data_channel_factory)
+
+    responses = queue.Queue()
+    no_more_work = object()
+
+    def get_responses():
+      while True:
+        response = responses.get()
+        if response is no_more_work:
+          return
+        yield response
+
+    def process_requests():
+      for work_request in control_stub.Control(get_responses()):
+        logging.info('Got work %s', work_request.instruction_id)
+        try:
+          response = self.worker.do_instruction(work_request)
+        except Exception:  # pylint: disable=broad-except
+          response = beam_fn_api_pb2.InstructionResponse(
+              instruction_id=work_request.instruction_id,
+              error=traceback.format_exc())
+        responses.put(response)
+    t = threading.Thread(target=process_requests)
+    t.start()
+    t.join()
+    # get_responses may be blocked on responses.get(), but we need to return
+    # control to its caller.
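+    # Enqueueing the sentinel unblocks it, so the Control stream can close
+    # cleanly.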
+ responses.put(no_more_work) + self._data_channel_factory.close() + logging.info('Done consuming work.') + + +class SdkWorker(object): + + def __init__(self, state_handler, data_channel_factory): + self.fns = {} + self.state_handler = state_handler + self.data_channel_factory = data_channel_factory + + def do_instruction(self, request): + request_type = request.WhichOneof('request') + if request_type: + # E.g. if register is set, this will construct + # InstructionResponse(register=self.register(request.register)) + return beam_fn_api_pb2.InstructionResponse(**{ + 'instruction_id': request.instruction_id, + request_type: getattr(self, request_type) + (getattr(request, request_type), request.instruction_id) + }) + else: + raise NotImplementedError + + def register(self, request, unused_instruction_id=None): + for process_bundle_descriptor in request.process_bundle_descriptor: + self.fns[process_bundle_descriptor.id] = process_bundle_descriptor + for p_transform in list(process_bundle_descriptor.primitive_transform): + self.fns[p_transform.function_spec.id] = p_transform.function_spec + return beam_fn_api_pb2.RegisterResponse() + + def initial_source_split(self, request, unused_instruction_id=None): + source_spec = self.fns[request.source_reference] + assert source_spec.urn == PYTHON_SOURCE_URN + source_bundle = unpack_and_deserialize_py_fn( + self.fns[request.source_reference]) + splits = source_bundle.source.split(request.desired_bundle_size_bytes, + source_bundle.start_position, + source_bundle.stop_position) + response = beam_fn_api_pb2.InitialSourceSplitResponse() + response.splits.extend([ + beam_fn_api_pb2.SourceSplit( + source=serialize_and_pack_py_fn(split, PYTHON_SOURCE_URN), + relative_size=split.weight, + ) + for split in splits + ]) + return response + + def create_execution_tree(self, descriptor): + # TODO(vikasrk): Add an id field to Coder proto and use that instead. + coders = {coder.function_spec.id: operation_specs.get_coder_from_spec( + json.loads(unpack_function_spec_data(coder.function_spec))) + for coder in descriptor.coders} + + counter_factory = counters.CounterFactory() + # TODO(robertwb): Figure out the correct prefix to use for output counters + # from StateSampler. + state_sampler = statesampler.StateSampler( + 'fnapi-step%s-' % descriptor.id, counter_factory) + consumers = collections.defaultdict(lambda: collections.defaultdict(list)) + ops_by_id = {} + reversed_ops = [] + + for transform in reversed(descriptor.primitive_transform): + # TODO(robertwb): Figure out how to plumb through the operation name (e.g. + # "s3") from the service through the FnAPI so that msec counters can be + # reported and correctly plumbed through the service and the UI. 
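+      # Until then, derive a synthetic name from the primitive transform id.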
+ operation_name = 'fnapis%s' % transform.id + + def only_element(iterable): + element, = iterable + return element + + if transform.function_spec.urn == DATA_OUTPUT_URN: + target = beam_fn_api_pb2.Target( + primitive_transform_reference=transform.id, + name=only_element(transform.outputs.keys())) + + op = DataOutputOperation( + operation_name, + transform.step_name, + consumers[transform.id], + counter_factory, + state_sampler, + coders[only_element(transform.outputs.values()).coder_reference], + target, + self.data_channel_factory.create_data_channel( + transform.function_spec)) + + elif transform.function_spec.urn == DATA_INPUT_URN: + target = beam_fn_api_pb2.Target( + primitive_transform_reference=transform.id, + name=only_element(transform.inputs.keys())) + op = DataInputOperation( + operation_name, + transform.step_name, + consumers[transform.id], + counter_factory, + state_sampler, + coders[only_element(transform.outputs.values()).coder_reference], + target, + self.data_channel_factory.create_data_channel( + transform.function_spec)) + + elif transform.function_spec.urn == PYTHON_DOFN_URN: + def create_side_input(tag, si): + # TODO(robertwb): Extract windows (and keys) out of element data. + return operation_specs.WorkerSideInputSource( + tag=tag, + source=SideInputSource( + self.state_handler, + beam_fn_api_pb2.StateKey( + function_spec_reference=si.view_fn.id), + coder=unpack_and_deserialize_py_fn(si.view_fn))) + output_tags = list(transform.outputs.keys()) + spec = operation_specs.WorkerDoFn( + serialized_fn=unpack_function_spec_data(transform.function_spec), + output_tags=output_tags, + input=None, + side_inputs=[create_side_input(tag, si) + for tag, si in transform.side_inputs.items()], + output_coders=[coders[transform.outputs[out].coder_reference] + for out in output_tags]) + + op = operations.DoOperation(operation_name, spec, counter_factory, + state_sampler) + # TODO(robertwb): Move these to the constructor. + op.step_name = transform.step_name + for tag, op_consumers in consumers[transform.id].items(): + for consumer in op_consumers: + op.add_receiver( + consumer, output_tags.index(tag)) + + elif transform.function_spec.urn == IDENTITY_DOFN_URN: + op = operations.FlattenOperation(operation_name, None, counter_factory, + state_sampler) + # TODO(robertwb): Move these to the constructor. + op.step_name = transform.step_name + for tag, op_consumers in consumers[transform.id].items(): + for consumer in op_consumers: + op.add_receiver(consumer, 0) + + elif transform.function_spec.urn == PYTHON_SOURCE_URN: + source = load_compressed(unpack_function_spec_data( + transform.function_spec)) + # TODO(vikasrk): Remove this once custom source is implemented with + # splittable dofn via the data plane. + spec = operation_specs.WorkerRead( + iobase.SourceBundle(1.0, source, None, None), + [WindowedValueCoder(source.default_output_coder())]) + op = operations.ReadOperation(operation_name, spec, counter_factory, + state_sampler) + op.step_name = transform.step_name + output_tags = list(transform.outputs.keys()) + for tag, op_consumers in consumers[transform.id].items(): + for consumer in op_consumers: + op.add_receiver( + consumer, output_tags.index(tag)) + + else: + raise NotImplementedError + + # Record consumers. 
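+      # Transforms are visited in reverse topological order, so a consumer
+      # registers itself here before the producer it reads from is
+      # constructed above.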
+ for _, inputs in transform.inputs.items(): + for target in inputs.target: + consumers[target.primitive_transform_reference][target.name].append( + op) + + reversed_ops.append(op) + ops_by_id[transform.id] = op + + return list(reversed(reversed_ops)), ops_by_id + + def process_bundle(self, request, instruction_id): + ops, ops_by_id = self.create_execution_tree( + self.fns[request.process_bundle_descriptor_reference]) + + expected_inputs = [] + for _, op in ops_by_id.items(): + if isinstance(op, DataOutputOperation): + # TODO(robertwb): Is there a better way to pass the instruction id to + # the operation? + op.set_output_stream(op.data_channel.output_stream( + instruction_id, op.target)) + elif isinstance(op, DataInputOperation): + # We must wait until we receive "end of stream" for each of these ops. + expected_inputs.append(op) + + # Start all operations. + for op in reversed(ops): + logging.info('start %s', op) + op.start() + + # Inject inputs from data plane. + for input_op in expected_inputs: + for data in input_op.data_channel.input_elements( + instruction_id, [input_op.target]): + # ignores input name + target_op = ops_by_id[data.target.primitive_transform_reference] + # lacks coder for non-input ops + target_op.process_encoded(data.data) + + # Finish all operations. + for op in ops: + logging.info('finish %s', op) + op.finish() + + return beam_fn_api_pb2.ProcessBundleResponse() http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/sdk_worker_main.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py new file mode 100644 index 0000000..28828c3 --- /dev/null +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""SDK Fn Harness entry point.""" + +import logging +import os +import sys + +import grpc +from google.protobuf import text_format + +from apache_beam.runners.api import beam_fn_api_pb2 +from apache_beam.runners.worker.log_handler import FnApiLogRecordHandler +from apache_beam.runners.worker.sdk_worker import SdkHarness + + +def main(unused_argv): + """Main entry point for SDK Fn Harness.""" + logging_service_descriptor = beam_fn_api_pb2.ApiServiceDescriptor() + text_format.Merge(os.environ['LOGGING_API_SERVICE_DESCRIPTOR'], + logging_service_descriptor) + + # Send all logs to the runner. + fn_log_handler = FnApiLogRecordHandler(logging_service_descriptor) + # TODO(vikasrk): This should be picked up from pipeline options. 
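+  # Default to INFO for now so that harness lifecycle messages are visible.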
+ logging.getLogger().setLevel(logging.INFO) + logging.getLogger().addHandler(fn_log_handler) + + try: + logging.info('Python sdk harness started.') + service_descriptor = beam_fn_api_pb2.ApiServiceDescriptor() + text_format.Merge(os.environ['CONTROL_API_SERVICE_DESCRIPTOR'], + service_descriptor) + # TODO(robertwb): Support credentials. + assert not service_descriptor.oauth2_client_credentials_grant.url + channel = grpc.insecure_channel(service_descriptor.url) + SdkHarness(channel).run() + logging.info('Python sdk harness exiting.') + except: # pylint: disable=broad-except + logging.exception('Python sdk harness failed: ') + raise + finally: + fn_log_handler.close() + + +if __name__ == '__main__': + main(sys.argv)
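For readers who want to exercise this entry point outside of a full runner, the following is a minimal sketch, assuming hypothetical endpoint addresses (localhost:8090 and localhost:8091 are placeholders; in practice the runner provisions both services and sets these environment variables itself before launching the harness):

    # Hypothetical harness launch. LOGGING_API_SERVICE_DESCRIPTOR and
    # CONTROL_API_SERVICE_DESCRIPTOR hold text-format ApiServiceDescriptor
    # protos; both endpoints must already have a listening Fn API service,
    # otherwise startup fails when the harness connects.
    import os
    import sys

    from apache_beam.runners.worker import sdk_worker_main

    os.environ['LOGGING_API_SERVICE_DESCRIPTOR'] = 'url: "localhost:8090"'
    os.environ['CONTROL_API_SERVICE_DESCRIPTOR'] = 'url: "localhost:8091"'
    sdk_worker_main.main(sys.argv)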