This is an automated email from the ASF dual-hosted git repository. damondouglas pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 92d82dbf7e2 Python data sampling optimization (#27157) 92d82dbf7e2 is described below commit 92d82dbf7e27e856cfc3b389b5abad7b67657550 Author: Sam Rohde <rohde.sam...@gmail.com> AuthorDate: Fri Jun 23 14:23:57 2023 -0700 Python data sampling optimization (#27157) * Python optimization work * Exception Sampling perf tests * add better element sampling microbenchmark * slowly move towards plumbing the samplers to the bundle processors * cleaned up * starting clean up and more testing * finish tests * fix unused data_sampler args and comments * yapf, comments, and simplifications * linter * lint and mypy * linter * run tests * address review comments * run tests --------- Co-authored-by: Sam Rohde <sro...@google.com> --- .../apache_beam/runners/worker/bundle_processor.py | 90 +---- .../runners/worker/bundle_processor_test.py | 89 +++-- .../apache_beam/runners/worker/data_sampler.py | 193 ++++++--- .../runners/worker/data_sampler_test.py | 435 ++++++++++++++++----- .../apache_beam/runners/worker/operations.pxd | 1 + .../apache_beam/runners/worker/operations.py | 70 +++- .../apache_beam/runners/worker/sdk_worker.py | 4 +- .../apache_beam/runners/worker/sdk_worker_test.py | 29 +- 8 files changed, 600 insertions(+), 311 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index 6d814331fd6..3adb26552d0 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -66,7 +66,6 @@ from apache_beam.runners.worker import data_sampler from apache_beam.runners.worker import operation_specs from apache_beam.runners.worker import operations from apache_beam.runners.worker import statesampler -from apache_beam.runners.worker.data_sampler import OutputSampler from apache_beam.transforms import TimeDomain from apache_beam.transforms import core from apache_beam.transforms import environments @@ -194,8 +193,8 @@ class DataInputOperation(RunnerIOOperation): self.stop = float('inf') self.started = False - def setup(self): - super().setup() + def setup(self, data_sampler=None): + super().setup(data_sampler) # We must do this manually as we don't have a spec or spec.output_coders. self.receivers = [ operations.ConsumerSet.create( @@ -897,39 +896,11 @@ class BundleProcessor(object): 'fnapi-step-%s' % self.process_bundle_descriptor.id, self.counter_factory) - if self.data_sampler: - self.add_data_sampling_operations(process_bundle_descriptor) - self.ops = self.create_execution_tree(self.process_bundle_descriptor) for op in reversed(self.ops.values()): - op.setup() + op.setup(self.data_sampler) self.splitting_lock = threading.Lock() - def add_data_sampling_operations(self, pbd): - # type: (beam_fn_api_pb2.ProcessBundleDescriptor) -> None - - """Adds a DataSamplingOperation to every PCollection. - - Implementation note: the alternative to this, is to add modify each - Operation and forward a DataSampler to manually sample when an element is - processed. This gets messy very quickly and is not future-proof as new - operation types will need to be updated. This is the cleanest way of adding - new operations to the final execution tree. 
- """ - coder = coders.FastPrimitivesCoder() - - for pcoll_id in pbd.pcollections: - transform_id = 'synthetic-data-sampling-transform-{}'.format(pcoll_id) - transform_proto: beam_runner_api_pb2.PTransform = pbd.transforms[ - transform_id] - transform_proto.unique_name = transform_id - transform_proto.spec.urn = SYNTHETIC_DATA_SAMPLING_URN - - coder_id = pbd.pcollections[pcoll_id].coder_id - transform_proto.spec.payload = coder.encode((pcoll_id, coder_id)) - - transform_proto.inputs['None'] = pcoll_id - def create_execution_tree( self, descriptor # type: beam_fn_api_pb2.ProcessBundleDescriptor @@ -966,6 +937,12 @@ class BundleProcessor(object): for tag, pcoll_id in descriptor.transforms[transform_id].outputs.items() } + + # Initialize transform-specific state in the Data Sampler. + if self.data_sampler: + self.data_sampler.initialize_samplers( + transform_id, descriptor, transform_factory.get_coder) + return transform_factory.create_operation( transform_id, transform_consumers) @@ -1987,52 +1964,3 @@ def create_to_string_fn( return _create_simple_pardo_operation( factory, transform_id, transform_proto, consumers, ToString()) - - -class DataSamplingOperation(operations.Operation): - """Operation that samples incoming elements.""" - - def __init__( - self, - name_context, # type: common.NameContext - counter_factory, # type: counters.CounterFactory - state_sampler, # type: statesampler.StateSampler - pcoll_id, # type: str - sample_coder, # type: coders.Coder - data_sampler, # type: data_sampler.DataSampler - ): - # type: (...) -> None - super().__init__(name_context, None, counter_factory, state_sampler) - self._coder = sample_coder # type: coders.Coder - self._pcoll_id = pcoll_id # type: str - - self._sampler: OutputSampler = data_sampler.sample_output( - self._pcoll_id, sample_coder) - - def process(self, windowed_value): - # type: (windowed_value.WindowedValue) -> None - self._sampler.sample(windowed_value) - - -@BeamTransformFactory.register_urn(SYNTHETIC_DATA_SAMPLING_URN, (bytes)) -def create_data_sampling_op( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - pcoll_and_coder_id, # type: bytes - consumers, # type: Dict[str, List[operations.Operation]] -): - # Creating this operation should only occur when data sampling is enabled. 
- data_sampler = factory.data_sampler - assert data_sampler is not None - - coder = coders.FastPrimitivesCoder() - pcoll_id, coder_id = coder.decode(pcoll_and_coder_id) - return DataSamplingOperation( - common.NameContext(transform_proto.unique_name, transform_id), - factory.counter_factory, - factory.state_sampler, - pcoll_id, - factory.get_coder(coder_id), - data_sampler, - ) diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py index 6b21071a875..db9e35a0baf 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py @@ -18,14 +18,16 @@ """Unit tests for bundle processing.""" # pytype: skip-file +import time import unittest +from typing import Dict +from typing import List from apache_beam.coders.coders import FastPrimitivesCoder from apache_beam.portability import common_urns from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.runners import common from apache_beam.runners.worker import operations -from apache_beam.runners.worker.bundle_processor import SYNTHETIC_DATA_SAMPLING_URN from apache_beam.runners.worker.bundle_processor import BeamTransformFactory from apache_beam.runners.worker.bundle_processor import BundleProcessor from apache_beam.runners.worker.bundle_processor import DataInputOperation @@ -190,18 +192,25 @@ def element_split(frac, index): class TestOperation(operations.Operation): """Test operation that forwards its payload to consumers.""" class Spec: - def __init__(self): - self.output_coders = [FastPrimitivesCoder()] + def __init__(self, transform_proto): + self.output_coders = [ + FastPrimitivesCoder() for _ in transform_proto.outputs + ] def __init__( self, + transform_proto, name_context, counter_factory, state_sampler, consumers, payload, ): - super().__init__(name_context, self.Spec(), counter_factory, state_sampler) + super().__init__( + name_context, + self.Spec(transform_proto), + counter_factory, + state_sampler) self.payload = payload for _, consumer_ops in consumers.items(): @@ -212,8 +221,9 @@ class TestOperation(operations.Operation): super().start() # Not using windowing logic, so just using simple defaults here. - self.process( - WindowedValue(self.payload, timestamp=0, windows=[GlobalWindow()])) + if self.payload: + self.process( + WindowedValue(self.payload, timestamp=0, windows=[GlobalWindow()])) def process(self, windowed_value): self.output(windowed_value) @@ -222,6 +232,7 @@ class TestOperation(operations.Operation): @BeamTransformFactory.register_urn('beam:internal:testop:v1', bytes) def create_test_op(factory, transform_id, transform_proto, payload, consumers): return TestOperation( + transform_proto, common.NameContext(transform_proto.unique_name, transform_id), factory.counter_factory, factory.state_sampler, @@ -241,37 +252,24 @@ class DataSamplingTest(unittest.TestCase): _ = BundleProcessor(descriptor, None, None) self.assertEqual(len(descriptor.transforms), 0) - def test_adds_data_sampling_operations(self): - """Test that providing the sampler creates sampling PTransforms. + def wait_for_samples(self, data_sampler: DataSampler, + pcollection_id: str) -> Dict[str, List[bytes]]: + """Waits for samples from the given PCollection to exist.""" + now = time.time() + end = now + 30 - Data sampling is implemented by modifying the ProcessBundleDescriptor with - additional sampling PTransforms reading from each PCllection. 
- """ - data_sampler = DataSampler() + samples = {} + while now < end: + time.sleep(0.1) + now = time.time() + samples.update(data_sampler.samples([pcollection_id])) - # Data sampling samples the PCollections, which adds a PTransform to read - # from each PCollection. So add a simple PCollection here to create the - # DataSamplingOperation. - PCOLLECTION_ID = 'pc' - CODER_ID = 'c' - descriptor = beam_fn_api_pb2.ProcessBundleDescriptor() - descriptor.pcollections[PCOLLECTION_ID].unique_name = PCOLLECTION_ID - descriptor.pcollections[PCOLLECTION_ID].coder_id = CODER_ID - descriptor.coders[ - CODER_ID].spec.urn = common_urns.StandardCoders.Enum.BYTES.urn - - _ = BundleProcessor(descriptor, None, None, data_sampler=data_sampler) - - # Assert that the data sampling transform was created. - self.assertEqual(len(descriptor.transforms), 1) - sampling_transform = list(descriptor.transforms.values())[0] + if samples: + return samples - # Ensure that the data sampling transform has the correct spec and that it's - # sampling the correct PCollection. - self.assertEqual( - sampling_transform.unique_name, 'synthetic-data-sampling-transform-pc') - self.assertEqual(sampling_transform.spec.urn, SYNTHETIC_DATA_SAMPLING_URN) - self.assertEqual(sampling_transform.inputs, {'None': PCOLLECTION_ID}) + self.assertLess( + now, end, 'Timed out waiting for samples for {}'.format(pcollection_id)) + return {} def test_can_sample(self): """Test that elements are sampled. @@ -281,7 +279,7 @@ class DataSamplingTest(unittest.TestCase): DataSamplingOperations and samples are taken from in-flight elements. These elements are then finally queried. """ - data_sampler = DataSampler() + data_sampler = DataSampler(sample_every_sec=0.1) descriptor = beam_fn_api_pb2.ProcessBundleDescriptor() # Create the PCollection to sample from. @@ -295,18 +293,23 @@ class DataSamplingTest(unittest.TestCase): # Add a simple transform to inject an element into the data sampler. This # doesn't use the FnApi, so this uses a simple operation to forward its # payload to consumers. - test_transform = descriptor.transforms['test_transform'] + TRANSFORM_ID = 'test_transform' + test_transform = descriptor.transforms[TRANSFORM_ID] test_transform.outputs['None'] = PCOLLECTION_ID test_transform.spec.urn = 'beam:internal:testop:v1' test_transform.spec.payload = b'hello, world!' - # Create and process a fake bundle. The instruction id doesn't matter here. - processor = BundleProcessor( - descriptor, None, None, data_sampler=data_sampler) - processor.process_bundle('instruction_id') - - self.assertEqual( - data_sampler.samples(), {PCOLLECTION_ID: [b'\rhello, world!']}) + try: + # Create and process a fake bundle. The instruction id doesn't matter + # here. 
+ processor = BundleProcessor( + descriptor, None, None, data_sampler=data_sampler) + processor.process_bundle('instruction_id') + + samples = self.wait_for_samples(data_sampler, PCOLLECTION_ID) + self.assertEqual(samples, {PCOLLECTION_ID: [b'\rhello, world!']}) + finally: + data_sampler.stop() if __name__ == '__main__': diff --git a/sdks/python/apache_beam/runners/worker/data_sampler.py b/sdks/python/apache_beam/runners/worker/data_sampler.py index 7cc8152693d..9c74188a699 100644 --- a/sdks/python/apache_beam/runners/worker/data_sampler.py +++ b/sdks/python/apache_beam/runners/worker/data_sampler.py @@ -19,9 +19,12 @@ # pytype: skip-file +from __future__ import annotations + import collections +import logging import threading -import time +from threading import Timer from typing import Any from typing import DefaultDict from typing import Deque @@ -34,28 +37,73 @@ from typing import Union from apache_beam.coders.coder_impl import CoderImpl from apache_beam.coders.coder_impl import WindowedValueCoderImpl from apache_beam.coders.coders import Coder +from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.utils.windowed_value import WindowedValue +_LOGGER = logging.getLogger(__name__) + + +class SampleTimer: + """Periodic timer for sampling elements.""" + def __init__(self, timeout_secs: float, sampler: OutputSampler) -> None: + self._timeout_secs = timeout_secs + self._timer = Timer(self._timeout_secs, self.sample) + self._sampler = sampler + + def reset(self): + self._timer.cancel() + self._timer = Timer(self._timeout_secs, self.sample) + self._timer.start() + + def stop(self): + self._timer.cancel() + + def sample(self): + self._sampler.sample() + self.reset() + + +class ElementSampler: + """Record class to hold sampled elements. + + This class is used as an optimization to quickly sample elements. This is a + shared reference between the Operation and the OutputSampler. + """ + + # Is true iff the `el` has been set with a sample. + has_element: bool + + # The sampled element. Note that `None` is a valid element and cannot be used + # as a sentinel to check if there is a sample. Use the `has_element` flag to + # check for this case. + el: Any class OutputSampler: """Represents a way to sample an output of a PTransform. - This is configurable to only keep max_samples (see constructor) sampled - elements in memory. The first 10 elements are always sampled, then after each - sample_every_sec (see constructor). + This is configurable to only keep `max_samples` (see constructor) sampled + elements in memory. Samples are taken every `sample_every_sec`. """ def __init__( self, coder: Coder, max_samples: int = 10, - sample_every_sec: float = 30, - clock=None) -> None: + sample_every_sec: float = 5) -> None: self._samples: Deque[Any] = collections.deque(maxlen=max_samples) + self._samples_lock: threading.Lock = threading.Lock() self._coder_impl: CoderImpl = coder.get_impl() - self._sample_count: int = 0 - self._sample_every_sec: float = sample_every_sec - self._clock = clock - self._last_sample_sec: float = self.time() + self._sample_timer = SampleTimer(sample_every_sec, self) + self.element_sampler = ElementSampler() + self.element_sampler.has_element = False + + # For testing, it's easier to disable the Timer and manually sample.
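# (Schematic restatement of the handshake that ElementSampler enables,
# with hypothetical names -- a minimal sketch, not code from this diff:
#
#   sampler = ElementSampler()
#   sampler.has_element = False
#
#   # Producer, on the per-element hot path: two attribute writes, no lock.
#   sampler.el = element
#   sampler.has_element = True
#
#   # Consumer, on the SampleTimer thread, once per sample_every_sec;
#   # only the deque is guarded by _samples_lock.
#   if sampler.has_element:
#     sampler.has_element = False
#     buffer.append(sampler.el)
#
# A non-positive sample_every_sec leaves the timer stopped, hence the
# guard below.)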
+ if sample_every_sec > 0: + self._sample_timer.reset() + + def stop(self) -> None: + """Stops sampling.""" + self._sample_timer.stop() def remove_windowed_value(self, el: Union[WindowedValue, Any]) -> Any: """Retrieves the value from the WindowedValue. @@ -63,39 +111,30 @@ class OutputSampler: The Python SDK passes elements as WindowedValues, which may not match the coder for that particular PCollection. """ - if isinstance(el, WindowedValue): - return self.remove_windowed_value(el.value) + while isinstance(el, WindowedValue): + el = el.value return el - def time(self) -> float: - """Returns the current time. Used for mocking out the clock for testing.""" - return self._clock.time() if self._clock else time.time() - - def flush(self) -> List[bytes]: - """Returns all samples and clears buffer.""" - if isinstance(self._coder_impl, WindowedValueCoderImpl): - samples = [s for s in self._samples] - else: - samples = [self.remove_windowed_value(s) for s in self._samples] - - # Encode in the nested context b/c this ensures that the SDK can decode the - # bytes with the ToStringFn. - self._samples.clear() - return [self._coder_impl.encode_nested(s) for s in samples] + def flush(self, clear: bool = True) -> List[bytes]: + """Returns all samples and optionally clears buffer if clear is True.""" + with self._samples_lock: + if isinstance(self._coder_impl, WindowedValueCoderImpl): + samples = [s for s in self._samples] + else: + samples = [self.remove_windowed_value(s) for s in self._samples] - def sample(self, element: Any) -> None: - """Samples the given element to an internal buffer. + # Encode in the nested context b/c this ensures that the SDK can decode + # the bytes with the ToStringFn. + if clear: + self._samples.clear() + return [self._coder_impl.encode_nested(s) for s in samples] - Samples are only taken for the first 10 elements then every - `self._sample_every_sec` second after. - """ - self._sample_count += 1 - now = self.time() - sample_diff = now - self._last_sample_sec - - if self._sample_count <= 10 or sample_diff >= self._sample_every_sec: - self._samples.append(element) - self._last_sample_sec = now + def sample(self) -> None: + """Samples the given element to an internal buffer.""" + with self._samples_lock: + if self.element_sampler.has_element: + self.element_sampler.has_element = False + self._samples.append(self.element_sampler.el) class DataSampler: @@ -103,14 +142,17 @@ class DataSampler: This class is meant to be a singleton with regard to a particular `sdk_worker.SdkHarness`. When creating the operators, individual - `OutputSampler`s are created from `DataSampler.sample_output`. This allows for - multi-threaded sampling of a PCollection across the SdkHarness. + `OutputSampler`s are created from `DataSampler.initialize_samplers`. This + allows for multi-threaded sampling of a PCollection across the SdkHarness. Samples generated during execution can then be sampled with the `samples` method. This filters samples from the given pcollection ids. """ def __init__( - self, max_samples: int = 10, sample_every_sec: float = 30) -> None: + self, + max_samples: int = 10, + sample_every_sec: float = 30, + clock=None) -> None: # Key is PCollection id. Is guarded by the _samplers_lock. 
self._samplers: Dict[str, OutputSampler] = {} # Bundles are processed in parallel, so new samplers may be added when the @@ -118,17 +160,72 @@ class DataSampler: self._samplers_lock: threading.Lock = threading.Lock() self._max_samples = max_samples self._sample_every_sec = sample_every_sec + self._element_samplers: Dict[str, List[ElementSampler]] = {} + self._clock = clock - def sample_output(self, pcoll_id: str, coder: Coder) -> OutputSampler: - """Create or get an OutputSampler for a pcoll_id.""" + def stop(self) -> None: + """Stops all sampling; does not clear samplers, in case there are + outstanding samples. + """ with self._samplers_lock: - if pcoll_id in self._samplers: - sampler = self._samplers[pcoll_id] - else: + for sampler in self._samplers.values(): + sampler.stop() + + def sampler_for_output( + self, transform_id: str, output_index: int) -> ElementSampler: + """Returns the ElementSampler for the given output.""" + try: + return self._element_samplers[transform_id][output_index] + except KeyError: + _LOGGER.warning( + f'Out-of-bounds access for transform "{transform_id}" ' + + f'and output "{output_index}" ElementSampler. This may ' + + 'indicate that the transform was improperly ' + + 'initialized with the DataSampler.') + return ElementSampler() + + def initialize_samplers( + self, + transform_id: str, + descriptor: beam_fn_api_pb2.ProcessBundleDescriptor, + coder_factory) -> List[ElementSampler]: + """Creates the OutputSamplers for the given PTransform. + + This initializes the samplers only once per PCollection Id. Note that an + OutputSampler is created per PCollection and an ElementSampler is created + per OutputSampler. This means that multiple ProcessBundles can and will + share the same ElementSampler for a given PCollection. + """ + transform_proto = descriptor.transforms[transform_id] + with self._samplers_lock: + # Initialize the samplers. + for pcoll_id in transform_proto.outputs.values(): + # Only initialize new PCollections. + if pcoll_id in self._samplers: + continue + + # Create the sampler with the corresponding coder. + coder_id = descriptor.pcollections[pcoll_id].coder_id + coder = coder_factory(coder_id) sampler = OutputSampler( coder, self._max_samples, self._sample_every_sec) self._samplers[pcoll_id] = sampler - return sampler + + # Next, update the lookup table for ElementSamplers for a given PTransform. + # Operations look up the ElementSampler for an output based on the index + # of the tag in the PTransform's outputs. The following code initializes + # the array with ElementSamplers in the correct indices.
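# (For example, with hypothetical ids: if transform_proto.outputs is
# {'out0': 'pcollA', 'out1': 'pcollB'}, the lookup table built below is
# [samplers['pcollA'].element_sampler, samplers['pcollB'].element_sampler],
# so sampler_for_output(transform_id, 1) resolves to pcollB's
# ElementSampler.)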
+ if transform_id in self._element_samplers: + return self._element_samplers[transform_id] + + outputs = transform_proto.outputs + samplers = [ + self._samplers[pcoll_id].element_sampler + for pcoll_id in outputs.values() + ] + self._element_samplers[transform_id] = samplers + + return samplers def samples( self, diff --git a/sdks/python/apache_beam/runners/worker/data_sampler_test.py b/sdks/python/apache_beam/runners/worker/data_sampler_test.py index 6a163c83b28..346251ee216 100644 --- a/sdks/python/apache_beam/runners/worker/data_sampler_test.py +++ b/sdks/python/apache_beam/runners/worker/data_sampler_test.py @@ -17,159 +17,385 @@ # pytype: skip-file +import time import unittest +from typing import Any +from typing import Dict +from typing import List +from typing import Optional from apache_beam.coders import FastPrimitivesCoder from apache_beam.coders import WindowedValueCoder -from apache_beam.coders.coders import Coder +from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.runners.worker.data_sampler import DataSampler from apache_beam.runners.worker.data_sampler import OutputSampler from apache_beam.transforms.window import GlobalWindow from apache_beam.utils.windowed_value import WindowedValue +MAIN_TRANSFORM_ID = 'transform' +MAIN_PCOLLECTION_ID = 'pcoll' +PRIMITIVES_CODER = FastPrimitivesCoder() + + +class FakeClock: + def __init__(self): + self.clock = 0 + + def time(self): + return self.clock + class DataSamplerTest(unittest.TestCase): + def make_test_descriptor( + self, + outputs: Optional[List[str]] = None, + transforms: Optional[List[str]] = None + ) -> beam_fn_api_pb2.ProcessBundleDescriptor: + outputs = outputs or [MAIN_PCOLLECTION_ID] + transforms = transforms or [MAIN_TRANSFORM_ID] + + descriptor = beam_fn_api_pb2.ProcessBundleDescriptor() + for transform_id in transforms: + transform = descriptor.transforms[transform_id] + for output in outputs: + transform.outputs[output] = output + + return descriptor + + def setUp(self): + self.data_sampler = DataSampler(sample_every_sec=0.1) + + def tearDown(self): + self.data_sampler.stop() + + def wait_for_samples( + self, data_sampler: DataSampler, + pcollection_ids: List[str]) -> Dict[str, List[bytes]]: + """Waits for samples to exist for the given PCollections.""" + now = time.time() + end = now + 30 + + samples = {} + while now < end: + time.sleep(0.1) + now = time.time() + samples.update(data_sampler.samples(pcollection_ids)) + + if not samples: + continue + + has_all = all(pcoll_id in samples for pcoll_id in pcollection_ids) + if has_all: + return samples + + self.assertLess( + now, + end, + 'Timed out waiting for samples for {}'.format(pcollection_ids)) + return {} + + def primitives_coder_factory(self, _): + return PRIMITIVES_CODER + + def gen_sample( + self, + data_sampler: DataSampler, + element: Any, + output_index: int, + transform_id: str = MAIN_TRANSFORM_ID): + """Generates a sample for the given transform's output.""" + element_sampler = self.data_sampler.sampler_for_output( + transform_id, output_index) + element_sampler.el = element + element_sampler.has_element = True + def test_single_output(self): """Simple test for a single sample.""" - data_sampler = DataSampler() - coder = FastPrimitivesCoder() + descriptor = self.make_test_descriptor() + self.data_sampler.initialize_samplers( + MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory) + + self.gen_sample(self.data_sampler, 'a', output_index=0) + + expected_sample = { + MAIN_PCOLLECTION_ID: [PRIMITIVES_CODER.encode_nested('a')] + } + 
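# The SampleTimer fires on its own thread (every 0.1s per setUp), so the
# element written by gen_sample reaches the buffer asynchronously; poll
# with wait_for_samples below instead of asserting immediately.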
samples = self.wait_for_samples(self.data_sampler, [MAIN_PCOLLECTION_ID]) + self.assertEqual(samples, expected_sample) + + def test_not_initialized(self): + """Tests that transforms fail gracefully if not properly initialized.""" + with self.assertLogs() as cm: + self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, 0) + self.assertRegex(cm.output[0], 'Out-of-bounds access.*') + + def map_outputs_to_indices( + self, outputs, descriptor, transform_id=MAIN_TRANSFORM_ID): + tag_list = list(descriptor.transforms[transform_id].outputs) + return {output: tag_list.index(output) for output in outputs} + + def test_sampler_mapping(self): + """Tests that the ElementSamplers are created for the correct output.""" + # Initialize the DataSampler with the following outputs. The order here may + # get shuffled when inserting into the descriptor. + pcollection_ids = ['o0', 'o1', 'o2'] + descriptor = self.make_test_descriptor(outputs=pcollection_ids) + samplers = self.data_sampler.initialize_samplers( + MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory) + + # Create a map from the PCollection id to the index into the transform + # output. This mirrors what happens when operators are created. The index of + # an output is where in the PTransform.outputs it is located (when the map + # is converted to a list). + outputs = self.map_outputs_to_indices(pcollection_ids, descriptor) + + # Assert that the mapping is correct, i.e. that we can go from the + # PCollection id -> output index and that this is the same as the created + # samplers. + index = outputs['o0'] + self.assertEqual( + self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, index), + samplers[index]) - output_sampler = data_sampler.sample_output('1', coder) - output_sampler.sample('a') + index = outputs['o1'] + self.assertEqual( + self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, index), + samplers[index]) - self.assertEqual(data_sampler.samples(), {'1': [coder.encode_nested('a')]}) + index = outputs['o2'] + self.assertEqual( + self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, index), + samplers[index]) def test_multiple_outputs(self): """Tests that multiple PCollections have their own sampler.""" - data_sampler = DataSampler() - coder = FastPrimitivesCoder() - - data_sampler.sample_output('1', coder).sample('a') - data_sampler.sample_output('2', coder).sample('a') - - self.assertEqual( - data_sampler.samples(), { - '1': [coder.encode_nested('a')], '2': [coder.encode_nested('a')] - }) - - def gen_samples(self, data_sampler: DataSampler, coder: Coder): - data_sampler.sample_output('a', coder).sample('1') - data_sampler.sample_output('a', coder).sample('2') - data_sampler.sample_output('b', coder).sample('3') - data_sampler.sample_output('b', coder).sample('4') - data_sampler.sample_output('c', coder).sample('5') - data_sampler.sample_output('c', coder).sample('6') + pcollection_ids = ['o0', 'o1', 'o2'] + descriptor = self.make_test_descriptor(outputs=pcollection_ids) + outputs = self.map_outputs_to_indices(pcollection_ids, descriptor) + + self.data_sampler.initialize_samplers( + MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory) + + self.gen_sample(self.data_sampler, 'a', output_index=outputs['o0']) + self.gen_sample(self.data_sampler, 'b', output_index=outputs['o1']) + self.gen_sample(self.data_sampler, 'c', output_index=outputs['o2']) + + samples = self.wait_for_samples(self.data_sampler, ['o0', 'o1', 'o2']) + expected_samples = { + 'o0': [PRIMITIVES_CODER.encode_nested('a')], + 'o1': 
[PRIMITIVES_CODER.encode_nested('b')], + 'o2': [PRIMITIVES_CODER.encode_nested('c')], + } + self.assertEqual(samples, expected_samples) + + def test_multiple_transforms(self): + """Test that multiple transforms with the same PCollections can be sampled. + """ + # Initialize two transforms, both with the same two outputs. + pcollection_ids = ['o0', 'o1'] + descriptor = self.make_test_descriptor( + outputs=pcollection_ids, transforms=['t0', 't1']) + t0_outputs = self.map_outputs_to_indices( + pcollection_ids, descriptor, transform_id='t0') + t1_outputs = self.map_outputs_to_indices( + pcollection_ids, descriptor, transform_id='t1') + + self.data_sampler.initialize_samplers( + 't0', descriptor, self.primitives_coder_factory) + + self.data_sampler.initialize_samplers( + 't1', descriptor, self.primitives_coder_factory) + + # The OutputSampler is on a different thread, so we don't sample the same + # PCollections; this ensures that no data race occurs. + self.gen_sample( + self.data_sampler, + 'a', + output_index=t0_outputs['o0'], + transform_id='t0') + self.gen_sample( + self.data_sampler, + 'd', + output_index=t1_outputs['o1'], + transform_id='t1') + expected_samples = { + 'o0': [PRIMITIVES_CODER.encode_nested('a')], + 'o1': [PRIMITIVES_CODER.encode_nested('d')], + } + samples = self.wait_for_samples(self.data_sampler, ['o0', 'o1']) + self.assertEqual(samples, expected_samples) + + self.gen_sample( + self.data_sampler, + 'b', + output_index=t0_outputs['o1'], + transform_id='t0') + self.gen_sample( + self.data_sampler, + 'c', + output_index=t1_outputs['o0'], + transform_id='t1') + expected_samples = { + 'o0': [PRIMITIVES_CODER.encode_nested('c')], + 'o1': [PRIMITIVES_CODER.encode_nested('b')], + } + samples = self.wait_for_samples(self.data_sampler, ['o0', 'o1']) + self.assertEqual(samples, expected_samples) def test_sample_filters_single_pcollection_ids(self): """Tests that samples can be filtered based on a single pcollection id.""" - data_sampler = DataSampler() - coder = FastPrimitivesCoder() - - self.gen_samples(data_sampler, coder) - self.assertEqual( - data_sampler.samples(pcollection_ids=['a']), - {'a': [coder.encode_nested('1'), coder.encode_nested('2')]}) - - self.assertEqual( - data_sampler.samples(pcollection_ids=['b']), - {'b': [coder.encode_nested('3'), coder.encode_nested('4')]}) + pcollection_ids = ['o0', 'o1', 'o2'] + descriptor = self.make_test_descriptor(outputs=pcollection_ids) + outputs = self.map_outputs_to_indices(pcollection_ids, descriptor) + + self.data_sampler.initialize_samplers( + MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory) + + self.gen_sample(self.data_sampler, 'a', output_index=outputs['o0']) + self.gen_sample(self.data_sampler, 'b', output_index=outputs['o1']) + self.gen_sample(self.data_sampler, 'c', output_index=outputs['o2']) + + samples = self.wait_for_samples(self.data_sampler, ['o0']) + expected_samples = { + 'o0': [PRIMITIVES_CODER.encode_nested('a')], + } + self.assertEqual(samples, expected_samples) + + samples = self.wait_for_samples(self.data_sampler, ['o1']) + expected_samples = { + 'o1': [PRIMITIVES_CODER.encode_nested('b')], + } + self.assertEqual(samples, expected_samples) + + samples = self.wait_for_samples(self.data_sampler, ['o2']) + expected_samples = { + 'o2': [PRIMITIVES_CODER.encode_nested('c')], + } + self.assertEqual(samples, expected_samples) def test_sample_filters_multiple_pcollection_ids(self): """Tests that samples can be filtered based on multiple pcollection ids.""" - data_sampler = DataSampler() - coder =
FastPrimitivesCoder() + pcollection_ids = ['o0', 'o1', 'o2'] + descriptor = self.make_test_descriptor(outputs=pcollection_ids) + outputs = self.map_outputs_to_indices(pcollection_ids, descriptor) - self.gen_samples(data_sampler, coder) - self.assertEqual( - data_sampler.samples(pcollection_ids=['a', 'c']), - { - 'a': [coder.encode_nested('1'), coder.encode_nested('2')], - 'c': [coder.encode_nested('5'), coder.encode_nested('6')] - }) + self.data_sampler.initialize_samplers( + MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory) + self.gen_sample(self.data_sampler, 'a', output_index=outputs['o0']) + self.gen_sample(self.data_sampler, 'b', output_index=outputs['o1']) + self.gen_sample(self.data_sampler, 'c', output_index=outputs['o2']) -class FakeClock: - def __init__(self): - self.clock = 0 - - def time(self): - return self.clock + samples = self.wait_for_samples(self.data_sampler, ['o0', 'o2']) + expected_samples = { + 'o0': [PRIMITIVES_CODER.encode_nested('a')], + 'o2': [PRIMITIVES_CODER.encode_nested('c')], + } + self.assertEqual(samples, expected_samples) class OutputSamplerTest(unittest.TestCase): def setUp(self): self.fake_clock = FakeClock() + def tearDown(self): + self.sampler.stop() + def control_time(self, new_time): self.fake_clock.clock = new_time - def test_samples_first_n(self): - """Tests that the first elements are always sampled.""" - coder = FastPrimitivesCoder() - sampler = OutputSampler(coder) + def wait_for_samples(self, output_sampler: OutputSampler, expected_num: int): + """Waits for the expected number of samples for the given sampler.""" + now = time.time() + end = now + 30 - for i in range(15): - sampler.sample(i) + while now < end: + time.sleep(0.1) + now = time.time() + samples = output_sampler.flush(clear=False) - self.assertEqual( - sampler.flush(), [coder.encode_nested(i) for i in range(10)]) + if not samples: + continue + + if len(samples) == expected_num: + return samples + + self.assertLess(now, end, 'Timed out waiting for samples') + + def ensure_sample( + self, output_sampler: OutputSampler, sample: Any, expected_num: int): + """Generates a sample and waits for it to be available.""" + + element_sampler = output_sampler.element_sampler + + now = time.time() + end = now + 30 + + while now < end: + element_sampler.el = sample + element_sampler.has_element = True + time.sleep(0.1) + now = time.time() + samples = output_sampler.flush(clear=False) + + if not samples: + continue + + if len(samples) == expected_num: + return samples + + self.assertLess( + now, end, f'Timed out waiting for sample "{sample}" to be generated.') + + def test_can_sample(self): + """Tests that the underlying timer can sample.""" + self.sampler = OutputSampler(PRIMITIVES_CODER, sample_every_sec=0.05) + element_sampler = self.sampler.element_sampler + element_sampler.el = 'a' + element_sampler.has_element = True + + samples = self.wait_for_samples(self.sampler, expected_num=1) + self.assertEqual(samples, [PRIMITIVES_CODER.encode_nested('a')]) def test_acts_like_circular_buffer(self): """Tests that the buffer overwrites old samples.""" - coder = FastPrimitivesCoder() - sampler = OutputSampler(coder, max_samples=2) + self.sampler = OutputSampler( + PRIMITIVES_CODER, max_samples=2, sample_every_sec=0) + element_sampler = self.sampler.element_sampler for i in range(10): - sampler.sample(i) + element_sampler.el = i + element_sampler.has_element = True + self.sampler.sample() - self.assertEqual(sampler.flush(), [coder.encode_nested(i) for i in (8, 9)]) + self.assertEqual(
self.sampler.flush(), + [PRIMITIVES_CODER.encode_nested(i) for i in (8, 9)]) - def test_samples_every_n_secs(self): + def test_samples_multiple_times(self): """Tests that the sampler takes repeated samples over time.""" - coder = FastPrimitivesCoder() - sampler = OutputSampler( - coder, max_samples=1, sample_every_sec=10, clock=self.fake_clock) + self.sampler = OutputSampler( + PRIMITIVES_CODER, max_samples=10, sample_every_sec=0.05) # Sample the first ten elements, one at a time. for i in range(10): - sampler.sample(i) - self.assertEqual(sampler.flush(), [coder.encode_nested(9)]) - - # Start at t=0 - sampler.sample(10) - self.assertEqual(len(sampler.flush()), 0) - - # Still not over threshold yet. - self.control_time(9) - for i in range(100): - sampler.sample(i) - self.assertEqual(len(sampler.flush()), 0) - - # First sample after 10s. - self.control_time(10) - sampler.sample(10) - self.assertEqual(sampler.flush(), [coder.encode_nested(10)]) - - # No samples between tresholds. - self.control_time(15) - for i in range(100): - sampler.sample(i) - self.assertEqual(len(sampler.flush()), 0) - - # Second sample after 20s. - self.control_time(20) - sampler.sample(11) - self.assertEqual(sampler.flush(), [coder.encode_nested(11)]) + self.ensure_sample(self.sampler, i, i + 1) + self.assertEqual( + self.sampler.flush(), + [PRIMITIVES_CODER.encode_nested(i) for i in range(10)]) def test_can_sample_windowed_value(self): """Tests that values with WindowedValueCoders are sampled wholesale.""" - data_sampler = DataSampler() coder = WindowedValueCoder(FastPrimitivesCoder()) value = WindowedValue('Hello, World!', 0, [GlobalWindow()]) - data_sampler.sample_output('1', coder).sample(value) - self.assertEqual( - data_sampler.samples(), {'1': [coder.encode_nested(value)]}) + self.sampler = OutputSampler(coder, sample_every_sec=0) + element_sampler = self.sampler.element_sampler + element_sampler.el = value + element_sampler.has_element = True + self.sampler.sample() + + self.assertEqual(self.sampler.flush(), [coder.encode_nested(value)]) def test_can_sample_non_windowed_value(self): """Tests that windowed values with WindowedValueCoders sample only the @@ -179,13 +405,16 @@ class OutputSamplerTest(unittest.TestCase): even if the coder is not a WindowedValueCoder. In this case, the value must be retrieved from the WindowedValue to match the correct coder.
""" - data_sampler = DataSampler() - coder = FastPrimitivesCoder() - data_sampler.sample_output('1', coder).sample( - WindowedValue('Hello, World!', 0, [GlobalWindow()])) + value = WindowedValue('Hello, World!', 0, [GlobalWindow()]) + + self.sampler = OutputSampler(PRIMITIVES_CODER, sample_every_sec=0) + element_sampler = self.sampler.element_sampler + element_sampler.el = value + element_sampler.has_element = True + self.sampler.sample() self.assertEqual( - data_sampler.samples(), {'1': [coder.encode_nested('Hello, World!')]}) + self.sampler.flush(), [PRIMITIVES_CODER.encode_nested('Hello, World!')]) if __name__ == '__main__': diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd index cf1f1b3fb51..725c5d2346a 100644 --- a/sdks/python/apache_beam/runners/worker/operations.pxd +++ b/sdks/python/apache_beam/runners/worker/operations.pxd @@ -34,6 +34,7 @@ cdef class ConsumerSet(Receiver): cdef public step_name cdef public output_index cdef public coder + cdef public object element_sampler cpdef update_counters_start(self, WindowedValue windowed_value) cpdef update_counters_finish(self) diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py index fdde55042ec..ca5e6552973 100644 --- a/sdks/python/apache_beam/runners/worker/operations.py +++ b/sdks/python/apache_beam/runners/worker/operations.py @@ -53,6 +53,7 @@ from apache_beam.runners.common import Receiver from apache_beam.runners.worker import opcounters from apache_beam.runners.worker import operation_specs from apache_beam.runners.worker import sideinputs +from apache_beam.runners.worker.data_sampler import ElementSampler from apache_beam.transforms import sideinputs as apache_sideinputs from apache_beam.transforms import combiners from apache_beam.transforms import core @@ -69,6 +70,8 @@ if TYPE_CHECKING: from apache_beam.runners.sdf_utils import SplitResultPrimary from apache_beam.runners.sdf_utils import SplitResultResidual from apache_beam.runners.worker.bundle_processor import ExecutionContext + from apache_beam.runners.worker.data_sampler import DataSampler + from apache_beam.runners.worker.data_sampler import OutputSampler from apache_beam.runners.worker.statesampler import StateSampler from apache_beam.transforms.userstate import TimerSpec @@ -123,6 +126,7 @@ class ConsumerSet(Receiver): coder, producer_type_hints, producer_batch_converter, # type: Optional[BatchConverter] + element_sampler=None, # type: Optional[ElementSampler] ): # type: (...) -> ConsumerSet if len(consumers) == 1: @@ -139,7 +143,8 @@ class ConsumerSet(Receiver): output_index, consumer, coder, - producer_type_hints) + producer_type_hints, + element_sampler) return GeneralPurposeConsumerSet( counter_factory, @@ -148,7 +153,8 @@ class ConsumerSet(Receiver): coder, producer_type_hints, consumers, - producer_batch_converter) + producer_batch_converter, + element_sampler) def __init__(self, counter_factory, @@ -157,7 +163,8 @@ class ConsumerSet(Receiver): consumers, coder, producer_type_hints, - producer_batch_converter + producer_batch_converter, + element_sampler ): self.opcounter = opcounters.OperationCounters( counter_factory, @@ -171,6 +178,7 @@ class ConsumerSet(Receiver): self.output_index = output_index self.coder = coder self.consumers = consumers + self.element_sampler = element_sampler def try_split(self, fraction_of_remainder): # type: (...) 
-> Optional[Any] @@ -197,6 +205,16 @@ class ConsumerSet(Receiver): # type: (WindowedValue) -> None self.opcounter.update_from(windowed_value) + # The following code is optimized by inlining a function call. Because this + # is called for every element, a function call is too expensive (order of + # 100s of nanoseconds). Furthermore, a lock was purposefully not used + # between here and the DataSampler as an additional operation. The tradeoff + # is that some samples might be dropped, but it is better than the + # alternative which is double sampling the same element. + if self.element_sampler is not None: + self.element_sampler.el = windowed_value + self.element_sampler.has_element = True + def update_counters_finish(self): # type: () -> None self.opcounter.update_collect() @@ -223,7 +241,8 @@ class SingletonElementConsumerSet(ConsumerSet): output_index, consumer, # type: Operation coder, - producer_type_hints + producer_type_hints, + element_sampler ): super().__init__( counter_factory, @@ -231,7 +250,8 @@ class SingletonElementConsumerSet(ConsumerSet): output_index, [consumer], coder, producer_type_hints, - None) + None, + element_sampler) self.consumer = consumer def receive(self, windowed_value): @@ -268,7 +288,8 @@ class GeneralPurposeConsumerSet(ConsumerSet): coder, producer_type_hints, consumers, # type: List[Operation] - producer_batch_converter): + producer_batch_converter, + element_sampler): super().__init__( counter_factory, step_name, @@ -276,7 +297,8 @@ class GeneralPurposeConsumerSet(ConsumerSet): consumers, coder, producer_type_hints, - producer_batch_converter) + producer_batch_converter, + element_sampler) self.producer_batch_converter = producer_batch_converter @@ -432,8 +454,8 @@ class Operation(object): self.setup_done = False self.step_name = None # type: Optional[str] - def setup(self): - # type: () -> None + def setup(self, data_sampler=None): + # type: (Optional[DataSampler]) -> None """Set up operation. @@ -441,10 +463,18 @@ class Operation(object): with self.scoped_start_state: self.debug_logging_enabled = logging.getLogger().isEnabledFor( logging.DEBUG) + transform_id = self.name_context.transform_id + # Everything except WorkerSideInputSource, which is not a # top-level operation, should have output_coders #TODO(pabloem): Define better what step name is used here. 
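# When a data_sampler is provided, each output receiver i below is handed
# the ElementSampler registered under (transform_id, i), so the unlocked
# per-element write in ConsumerSet.update_counters_start above lands in
# the buffer of the matching PCollection.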
if getattr(self.spec, 'output_coders', None): + + def get_element_sampler(output_num): + if data_sampler is None: + return None + return data_sampler.sampler_for_output(transform_id, output_num) + self.receivers = [ ConsumerSet.create( self.counter_factory, @@ -454,7 +484,7 @@ class Operation(object): coder, self._get_runtime_performance_hints(), self.get_output_batch_converter(), - ) for i, + get_element_sampler(i)) for i, coder in enumerate(self.spec.output_coders) ] self.setup_done = True @@ -759,7 +789,7 @@ class DoOperation(Operation): counter_factory, sampler, side_input_maps=None, - user_state_context=None + user_state_context=None, ): super(DoOperation, self).__init__(name, spec, counter_factory, sampler) self.side_input_maps = side_input_maps @@ -828,10 +858,10 @@ class DoOperation(Operation): yield apache_sideinputs.SideInputMap( view_class, view_options, sideinputs.EmulatedIterable(iterator_fn)) - def setup(self): - # type: () -> None + def setup(self, data_sampler=None): + # type: (Optional[DataSampler]) -> None with self.scoped_start_state: - super(DoOperation, self).setup() + super(DoOperation, self).setup(data_sampler) # See fn_data in dataflow_runner.py fn, args, kwargs, tags_and_types, window_fn = ( @@ -1107,11 +1137,11 @@ class CombineOperation(Operation): self.phased_combine_fn = ( PhasedCombineFnExecutor(self.spec.phase, fn, args, kwargs)) - def setup(self): - # type: () -> None + def setup(self, data_sampler=None): + # type: (Optional[DataSampler]) -> None with self.scoped_start_state: _LOGGER.debug('Setup called for %s', self) - super(CombineOperation, self).setup() + super(CombineOperation, self).setup(data_sampler) self.phased_combine_fn.combine_fn.setup() def process(self, o): @@ -1226,11 +1256,11 @@ class PGBKCVOperation(Operation): self.key_count = 0 self.table = {} - def setup(self): - # type: () -> None + def setup(self, data_sampler=None): + # type: (Optional[DataSampler]) -> None with self.scoped_start_state: _LOGGER.debug('Setup called for %s', self) - super(PGBKCVOperation, self).setup() + super(PGBKCVOperation, self).setup(data_sampler) self.combine_fn.setup() def process(self, wkv): diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index f7da86234d5..f5b1456c251 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -268,6 +268,8 @@ class SdkHarness(object): work_request) finally: self._alive = False + if self.data_sampler: + self.data_sampler.stop() _LOGGER.info('No more requests from control plane') _LOGGER.info('SDK Harness waiting for in-flight requests to complete') @@ -379,7 +381,7 @@ class SdkHarness(object): def get_samples(request): # type: (beam_fn_api_pb2.InstructionRequest) -> beam_fn_api_pb2.InstructionResponse samples: Dict[str, List[bytes]] = {} - if self.data_sampler: + if self.data_sampler is not None: samples = self.data_sampler.samples(request.sample_data.pcollection_ids) sample_response = beam_fn_api_pb2.SampleDataResponse() diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py index b3ad349be75..d12dc48ed29 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py @@ -23,6 +23,7 @@ import contextlib import logging import unittest from collections import namedtuple +from typing import Any import grpc import hamcrest as hc @@ -39,7 +40,6 @@ 
from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.portability.api import metrics_pb2 from apache_beam.runners.worker import sdk_worker from apache_beam.runners.worker import statecache -from apache_beam.runners.worker.data_sampler import DataSampler from apache_beam.runners.worker.sdk_worker import BundleProcessorCache from apache_beam.runners.worker.sdk_worker import GlobalCachingStateHandler from apache_beam.runners.worker.sdk_worker import SdkWorker @@ -279,18 +279,19 @@ class SdkWorkerTest(unittest.TestCase): def test_data_sampling_response(self): # Create a data sampler with some fake sampled data. This data will be seen # in the sample response. - data_sampler = DataSampler() coder = FastPrimitivesCoder() - # Sample from two fake PCollections to test that all sampled PCollections - # are present in the response. Also adds an extra sample to test that - # filtering is forwarded to the DataSampler. - data_sampler.sample_output('pcoll_id_1', - coder).sample('hello, world from pcoll_id_1!') - data_sampler.sample_output('pcoll_id_2', - coder).sample('hello, world from pcoll_id_2!') - data_sampler.sample_output('bad_pcoll_id', - coder).sample('if present bug in filter') + class FakeDataSampler: + def samples(self, pcollection_ids): + return { + 'pcoll_id_1': [coder.encode_nested('a')], + 'pcoll_id_2': [coder.encode_nested('b')], + } + + def stop(self): + pass + + data_sampler = FakeDataSampler() # Create and send the fake response. The SdkHarness should query the # DataSampler and fill out the sample response. @@ -310,14 +311,12 @@ class SdkWorkerTest(unittest.TestCase): 'pcoll_id_1': beam_fn_api_pb2.SampleDataResponse.ElementList( elements=[ beam_fn_api_pb2.SampledElement( - element=coder.encode_nested( - 'hello, world from pcoll_id_1!')) + element=coder.encode_nested('a')) ]), 'pcoll_id_2': beam_fn_api_pb2.SampleDataResponse.ElementList( elements=[ beam_fn_api_pb2.SampledElement( - element=coder.encode_nested( - 'hello, world from pcoll_id_2!')) + element=coder.encode_nested('b')) ]) }))
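For reference, the end-to-end shape of the new API, distilled from data_sampler_test.py above, looks roughly like this; the transform/PCollection ids and the lambda coder factory are illustrative stand-ins, not values from the commit:

import time

from apache_beam.coders import FastPrimitivesCoder
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.runners.worker.data_sampler import DataSampler

# Describe one transform 't0' with a single output PCollection 'pcoll0'.
descriptor = beam_fn_api_pb2.ProcessBundleDescriptor()
descriptor.transforms['t0'].outputs['None'] = 'pcoll0'

data_sampler = DataSampler(sample_every_sec=0.1)
# The factory maps a coder id to a Coder; like the tests, ignore the id.
data_sampler.initialize_samplers(
    't0', descriptor, lambda _: FastPrimitivesCoder())

# Hot path: hand an element to the shared ElementSampler (no lock taken).
element_sampler = data_sampler.sampler_for_output('t0', output_index=0)
element_sampler.el = 'hello'
element_sampler.has_element = True

time.sleep(0.5)  # let the SampleTimer thread move the element to the buffer
print(data_sampler.samples(['pcoll0']))  # {'pcoll0': [<nested-encoded bytes>]}
data_sampler.stop()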