This is an automated email from the ASF dual-hosted git repository.

damondouglas pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 92d82dbf7e2 Python data sampling optimization  (#27157)
92d82dbf7e2 is described below

commit 92d82dbf7e27e856cfc3b389b5abad7b67657550
Author: Sam Rohde <rohde.sam...@gmail.com>
AuthorDate: Fri Jun 23 14:23:57 2023 -0700

    Python data sampling optimization  (#27157)
    
    * Python optimization work
    
    * Exception Sampling perf tests
    
    * add better element sampling microbenchmark
    
    * slowly move towards plumbing the samplers to the bundle processors
    
    * cleaned up
    
    * starting clean up and more testing
    
    * finish tests
    
    * fix unused data_sampler args and comments
    
    * yapf, comments, and simplifications
    
    * linter
    
    * lint and mypy
    
    * linter
    
    * run tests
    
    * address review comments
    
    * run tests
    
    ---------
    
    Co-authored-by: Sam Rohde <sro...@google.com>
---
 .../apache_beam/runners/worker/bundle_processor.py |  90 +----
 .../runners/worker/bundle_processor_test.py        |  89 +++--
 .../apache_beam/runners/worker/data_sampler.py     | 193 ++++++---
 .../runners/worker/data_sampler_test.py            | 435 ++++++++++++++++-----
 .../apache_beam/runners/worker/operations.pxd      |   1 +
 .../apache_beam/runners/worker/operations.py       |  70 +++-
 .../apache_beam/runners/worker/sdk_worker.py       |   4 +-
 .../apache_beam/runners/worker/sdk_worker_test.py  |  29 +-
 8 files changed, 600 insertions(+), 311 deletions(-)

diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 6d814331fd6..3adb26552d0 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -66,7 +66,6 @@ from apache_beam.runners.worker import data_sampler
 from apache_beam.runners.worker import operation_specs
 from apache_beam.runners.worker import operations
 from apache_beam.runners.worker import statesampler
-from apache_beam.runners.worker.data_sampler import OutputSampler
 from apache_beam.transforms import TimeDomain
 from apache_beam.transforms import core
 from apache_beam.transforms import environments
@@ -194,8 +193,8 @@ class DataInputOperation(RunnerIOOperation):
     self.stop = float('inf')
     self.started = False
 
-  def setup(self):
-    super().setup()
+  def setup(self, data_sampler=None):
+    super().setup(data_sampler)
     # We must do this manually as we don't have a spec or spec.output_coders.
     self.receivers = [
         operations.ConsumerSet.create(
@@ -897,39 +896,11 @@ class BundleProcessor(object):
         'fnapi-step-%s' % self.process_bundle_descriptor.id,
         self.counter_factory)
 
-    if self.data_sampler:
-      self.add_data_sampling_operations(process_bundle_descriptor)
-
     self.ops = self.create_execution_tree(self.process_bundle_descriptor)
     for op in reversed(self.ops.values()):
-      op.setup()
+      op.setup(self.data_sampler)
     self.splitting_lock = threading.Lock()
 
-  def add_data_sampling_operations(self, pbd):
-    # type: (beam_fn_api_pb2.ProcessBundleDescriptor) -> None
-
-    """Adds a DataSamplingOperation to every PCollection.
-
-    Implementation note: the alternative to this, is to add modify each
-    Operation and forward a DataSampler to manually sample when an element is
-    processed. This gets messy very quickly and is not future-proof as new
-    operation types will need to be updated. This is the cleanest way of adding
-    new operations to the final execution tree.
-    """
-    coder = coders.FastPrimitivesCoder()
-
-    for pcoll_id in pbd.pcollections:
-      transform_id = 'synthetic-data-sampling-transform-{}'.format(pcoll_id)
-      transform_proto: beam_runner_api_pb2.PTransform = pbd.transforms[
-          transform_id]
-      transform_proto.unique_name = transform_id
-      transform_proto.spec.urn = SYNTHETIC_DATA_SAMPLING_URN
-
-      coder_id = pbd.pcollections[pcoll_id].coder_id
-      transform_proto.spec.payload = coder.encode((pcoll_id, coder_id))
-
-      transform_proto.inputs['None'] = pcoll_id
-
   def create_execution_tree(
       self,
       descriptor  # type: beam_fn_api_pb2.ProcessBundleDescriptor
@@ -966,6 +937,12 @@ class BundleProcessor(object):
           for tag,
           pcoll_id in descriptor.transforms[transform_id].outputs.items()
       }
+
+      # Initialize transform-specific state in the Data Sampler.
+      if self.data_sampler:
+        self.data_sampler.initialize_samplers(
+            transform_id, descriptor, transform_factory.get_coder)
+
       return transform_factory.create_operation(
           transform_id, transform_consumers)
 
@@ -1987,52 +1964,3 @@ def create_to_string_fn(
 
   return _create_simple_pardo_operation(
       factory, transform_id, transform_proto, consumers, ToString())
-
-
-class DataSamplingOperation(operations.Operation):
-  """Operation that samples incoming elements."""
-
-  def __init__(
-      self,
-      name_context,  # type: common.NameContext
-      counter_factory,  # type: counters.CounterFactory
-      state_sampler,  # type: statesampler.StateSampler
-      pcoll_id,  # type: str
-      sample_coder,  # type: coders.Coder
-      data_sampler,  # type: data_sampler.DataSampler
-  ):
-    # type: (...) -> None
-    super().__init__(name_context, None, counter_factory, state_sampler)
-    self._coder = sample_coder  # type: coders.Coder
-    self._pcoll_id = pcoll_id  # type: str
-
-    self._sampler: OutputSampler = data_sampler.sample_output(
-        self._pcoll_id, sample_coder)
-
-  def process(self, windowed_value):
-    # type: (windowed_value.WindowedValue) -> None
-    self._sampler.sample(windowed_value)
-
-
-@BeamTransformFactory.register_urn(SYNTHETIC_DATA_SAMPLING_URN, (bytes))
-def create_data_sampling_op(
-    factory,  # type: BeamTransformFactory
-    transform_id,  # type: str
-    transform_proto,  # type: beam_runner_api_pb2.PTransform
-    pcoll_and_coder_id,  # type: bytes
-    consumers,  # type: Dict[str, List[operations.Operation]]
-):
-  # Creating this operation should only occur when data sampling is enabled.
-  data_sampler = factory.data_sampler
-  assert data_sampler is not None
-
-  coder = coders.FastPrimitivesCoder()
-  pcoll_id, coder_id = coder.decode(pcoll_and_coder_id)
-  return DataSamplingOperation(
-      common.NameContext(transform_proto.unique_name, transform_id),
-      factory.counter_factory,
-      factory.state_sampler,
-      pcoll_id,
-      factory.get_coder(coder_id),
-      data_sampler,
-  )
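
The net effect of this file's change is that sampling moves out of synthetic
DataSamplingOperations and into each operation's setup(). A minimal sketch of
the new control flow, using simplified stand-ins rather than the real Beam
classes (FakeDataSampler and FakeOperation are illustrative names only):

    from typing import Optional

    class FakeDataSampler:  # hypothetical stand-in for data_sampler.DataSampler
        def sampler_for_output(self, transform_id: str, output_index: int):
            print(f'looking up sampler for {transform_id}[{output_index}]')
            return object()

    class FakeOperation:  # hypothetical stand-in for operations.Operation
        def __init__(self, transform_id: str, num_outputs: int):
            self.transform_id = transform_id
            self.num_outputs = num_outputs

        def setup(self, data_sampler: Optional[FakeDataSampler] = None) -> None:
            # Mirrors Operation.setup: resolve one ElementSampler per output,
            # or None when sampling is disabled.
            self.samplers = [
                data_sampler.sampler_for_output(self.transform_id, i)
                if data_sampler is not None else None
                for i in range(self.num_outputs)
            ]

    ops = [FakeOperation('t0', 1), FakeOperation('t1', 2)]
    sampler = FakeDataSampler()
    for op in reversed(ops):  # BundleProcessor sets up ops in reverse order
        op.setup(sampler)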
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py
index 6b21071a875..db9e35a0baf 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py
@@ -18,14 +18,16 @@
 """Unit tests for bundle processing."""
 # pytype: skip-file
 
+import time
 import unittest
+from typing import Dict
+from typing import List
 
 from apache_beam.coders.coders import FastPrimitivesCoder
 from apache_beam.portability import common_urns
 from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.runners import common
 from apache_beam.runners.worker import operations
-from apache_beam.runners.worker.bundle_processor import SYNTHETIC_DATA_SAMPLING_URN
 from apache_beam.runners.worker.bundle_processor import BeamTransformFactory
 from apache_beam.runners.worker.bundle_processor import BundleProcessor
 from apache_beam.runners.worker.bundle_processor import DataInputOperation
@@ -190,18 +192,25 @@ def element_split(frac, index):
 class TestOperation(operations.Operation):
   """Test operation that forwards its payload to consumers."""
   class Spec:
-    def __init__(self):
-      self.output_coders = [FastPrimitivesCoder()]
+    def __init__(self, transform_proto):
+      self.output_coders = [
+          FastPrimitivesCoder() for _ in transform_proto.outputs
+      ]
 
   def __init__(
       self,
+      transform_proto,
       name_context,
       counter_factory,
       state_sampler,
       consumers,
       payload,
   ):
-    super().__init__(name_context, self.Spec(), counter_factory, state_sampler)
+    super().__init__(
+        name_context,
+        self.Spec(transform_proto),
+        counter_factory,
+        state_sampler)
     self.payload = payload
 
     for _, consumer_ops in consumers.items():
@@ -212,8 +221,9 @@ class TestOperation(operations.Operation):
     super().start()
 
     # Not using windowing logic, so just using simple defaults here.
-    self.process(
-        WindowedValue(self.payload, timestamp=0, windows=[GlobalWindow()]))
+    if self.payload:
+      self.process(
+          WindowedValue(self.payload, timestamp=0, windows=[GlobalWindow()]))
 
   def process(self, windowed_value):
     self.output(windowed_value)
@@ -222,6 +232,7 @@ class TestOperation(operations.Operation):
 @BeamTransformFactory.register_urn('beam:internal:testop:v1', bytes)
 def create_test_op(factory, transform_id, transform_proto, payload, consumers):
   return TestOperation(
+      transform_proto,
       common.NameContext(transform_proto.unique_name, transform_id),
       factory.counter_factory,
       factory.state_sampler,
@@ -241,37 +252,24 @@ class DataSamplingTest(unittest.TestCase):
     _ = BundleProcessor(descriptor, None, None)
     self.assertEqual(len(descriptor.transforms), 0)
 
-  def test_adds_data_sampling_operations(self):
-    """Test that providing the sampler creates sampling PTransforms.
+  def wait_for_samples(self, data_sampler: DataSampler,
+                       pcollection_id: str) -> Dict[str, List[bytes]]:
+    """Waits for samples from the given PCollection to exist."""
+    now = time.time()
+    end = now + 30
 
-    Data sampling is implemented by modifying the ProcessBundleDescriptor with
-    additional sampling PTransforms reading from each PCllection.
-    """
-    data_sampler = DataSampler()
+    samples = {}
+    while now < end:
+      time.sleep(0.1)
+      now = time.time()
+      samples.update(data_sampler.samples([pcollection_id]))
 
-    # Data sampling samples the PCollections, which adds a PTransform to read
-    # from each PCollection. So add a simple PCollection here to create the
-    # DataSamplingOperation.
-    PCOLLECTION_ID = 'pc'
-    CODER_ID = 'c'
-    descriptor = beam_fn_api_pb2.ProcessBundleDescriptor()
-    descriptor.pcollections[PCOLLECTION_ID].unique_name = PCOLLECTION_ID
-    descriptor.pcollections[PCOLLECTION_ID].coder_id = CODER_ID
-    descriptor.coders[
-        CODER_ID].spec.urn = common_urns.StandardCoders.Enum.BYTES.urn
-
-    _ = BundleProcessor(descriptor, None, None, data_sampler=data_sampler)
-
-    # Assert that the data sampling transform was created.
-    self.assertEqual(len(descriptor.transforms), 1)
-    sampling_transform = list(descriptor.transforms.values())[0]
+      if samples:
+        return samples
 
-    # Ensure that the data sampling transform has the correct spec and that it's
-    # sampling the correct PCollection.
-    self.assertEqual(
-        sampling_transform.unique_name, 'synthetic-data-sampling-transform-pc')
-    self.assertEqual(sampling_transform.spec.urn, SYNTHETIC_DATA_SAMPLING_URN)
-    self.assertEqual(sampling_transform.inputs, {'None': PCOLLECTION_ID})
+    self.assertLess(
+        now, end, 'Timed out waiting for samples for {}'.format(pcollection_id))
+    return {}
 
   def test_can_sample(self):
     """Test that elements are sampled.
@@ -281,7 +279,7 @@ class DataSamplingTest(unittest.TestCase):
     DataSamplingOperations and samples are taken from in-flight elements. These
     elements are then finally queried.
     """
-    data_sampler = DataSampler()
+    data_sampler = DataSampler(sample_every_sec=0.1)
     descriptor = beam_fn_api_pb2.ProcessBundleDescriptor()
 
     # Create the PCollection to sample from.
@@ -295,18 +293,23 @@ class DataSamplingTest(unittest.TestCase):
     # Add a simple transform to inject an element into the data sampler. This
     # doesn't use the FnApi, so this uses a simple operation to forward its
     # payload to consumers.
-    test_transform = descriptor.transforms['test_transform']
+    TRANSFORM_ID = 'test_transform'
+    test_transform = descriptor.transforms[TRANSFORM_ID]
     test_transform.outputs['None'] = PCOLLECTION_ID
     test_transform.spec.urn = 'beam:internal:testop:v1'
     test_transform.spec.payload = b'hello, world!'
 
-    # Create and process a fake bundle. The instruction id doesn't matter here.
-    processor = BundleProcessor(
-        descriptor, None, None, data_sampler=data_sampler)
-    processor.process_bundle('instruction_id')
-
-    self.assertEqual(
-        data_sampler.samples(), {PCOLLECTION_ID: [b'\rhello, world!']})
+    try:
+      # Create and process a fake bundle. The instruction id doesn't matter
+      # here.
+      processor = BundleProcessor(
+          descriptor, None, None, data_sampler=data_sampler)
+      processor.process_bundle('instruction_id')
+
+      samples = self.wait_for_samples(data_sampler, PCOLLECTION_ID)
+      self.assertEqual(samples, {PCOLLECTION_ID: [b'\rhello, world!']})
+    finally:
+      data_sampler.stop()
 
 
 if __name__ == '__main__':
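
Because sampling now happens on a background timer thread, the updated tests
poll until samples appear instead of asserting immediately after
process_bundle. A generic version of that wait loop (illustrative helper only;
the real one is the wait_for_samples method above):

    import time
    from typing import Callable, Optional, TypeVar

    T = TypeVar('T')

    def poll_until(fetch: Callable[[], Optional[T]],
                   timeout_secs: float = 30,
                   interval_secs: float = 0.1) -> Optional[T]:
        """Polls fetch() until it returns a truthy value or the timeout ends."""
        deadline = time.time() + timeout_secs
        while time.time() < deadline:
            result = fetch()
            if result:
                return result
            time.sleep(interval_secs)
        return None

    # Usage mirroring the test: fetch is a closure over the sampler.
    # samples = poll_until(lambda: data_sampler.samples([pcollection_id]))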
diff --git a/sdks/python/apache_beam/runners/worker/data_sampler.py b/sdks/python/apache_beam/runners/worker/data_sampler.py
index 7cc8152693d..9c74188a699 100644
--- a/sdks/python/apache_beam/runners/worker/data_sampler.py
+++ b/sdks/python/apache_beam/runners/worker/data_sampler.py
@@ -19,9 +19,12 @@
 
 # pytype: skip-file
 
+from __future__ import annotations
+
 import collections
+import logging
 import threading
-import time
+from threading import Timer
 from typing import Any
 from typing import DefaultDict
 from typing import Deque
@@ -34,28 +37,73 @@ from typing import Union
 from apache_beam.coders.coder_impl import CoderImpl
 from apache_beam.coders.coder_impl import WindowedValueCoderImpl
 from apache_beam.coders.coders import Coder
+from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.utils.windowed_value import WindowedValue
 
+_LOGGER = logging.getLogger(__name__)
+
+
+class SampleTimer:
+  """Periodic timer for sampling elements."""
+  def __init__(self, timeout_secs: float, sampler: OutputSampler) -> None:
+    self._timeout_secs = timeout_secs
+    self._timer = Timer(self._timeout_secs, self.sample)
+    self._sampler = sampler
+
+  def reset(self):
+    self._timer.cancel()
+    self._timer = Timer(self._timeout_secs, self.sample)
+    self._timer.start()
+
+  def stop(self):
+    self._timer.cancel()
+
+  def sample(self):
+    self._sampler.sample()
+    self.reset()
+
+
+class ElementSampler:
+  """Record class to hold sampled elements.
+
+  This class is used as an optimization to quickly sample elements. This is a
+  shared reference between the Operation and the OutputSampler.
+  """
+
+  # Is true iff the `el` has been set with a sample.
+  has_element: bool
+
+  # The sampled element. Note that `None` is a valid element and cannot be used
+  # as a sentinel to check if there is a sample. Use the `has_element` flag to
+  # check for this case.
+  el: Any
+
 
 class OutputSampler:
   """Represents a way to sample an output of a PTransform.
 
-  This is configurable to only keep max_samples (see constructor) sampled
-  elements in memory. The first 10 elements are always sampled, then after each
-  sample_every_sec (see constructor).
+  This is configurable to only keep `max_samples` (see constructor) sampled
+  elements in memory. Samples are taken every `sample_every_sec`.
   """
   def __init__(
       self,
       coder: Coder,
       max_samples: int = 10,
-      sample_every_sec: float = 30,
-      clock=None) -> None:
+      sample_every_sec: float = 5) -> None:
     self._samples: Deque[Any] = collections.deque(maxlen=max_samples)
+    self._samples_lock: threading.Lock = threading.Lock()
     self._coder_impl: CoderImpl = coder.get_impl()
-    self._sample_count: int = 0
-    self._sample_every_sec: float = sample_every_sec
-    self._clock = clock
-    self._last_sample_sec: float = self.time()
+    self._sample_timer = SampleTimer(sample_every_sec, self)
+    self.element_sampler = ElementSampler()
+    self.element_sampler.has_element = False
+
+    # For testing, it's easier to disable the Timer and manually sample.
+    if sample_every_sec > 0:
+      self._sample_timer.reset()
+
+  def stop(self) -> None:
+    """Stops sampling."""
+    self._sample_timer.stop()
 
   def remove_windowed_value(self, el: Union[WindowedValue, Any]) -> Any:
     """Retrieves the value from the WindowedValue.
@@ -63,39 +111,30 @@ class OutputSampler:
     The Python SDK passes elements as WindowedValues, which may not match the
     coder for that particular PCollection.
     """
-    if isinstance(el, WindowedValue):
-      return self.remove_windowed_value(el.value)
+    while isinstance(el, WindowedValue):
+      el = el.value
     return el
 
-  def time(self) -> float:
-    """Returns the current time. Used for mocking out the clock for testing."""
-    return self._clock.time() if self._clock else time.time()
-
-  def flush(self) -> List[bytes]:
-    """Returns all samples and clears buffer."""
-    if isinstance(self._coder_impl, WindowedValueCoderImpl):
-      samples = [s for s in self._samples]
-    else:
-      samples = [self.remove_windowed_value(s) for s in self._samples]
-
-    # Encode in the nested context b/c this ensures that the SDK can decode the
-    # bytes with the ToStringFn.
-    self._samples.clear()
-    return [self._coder_impl.encode_nested(s) for s in samples]
+  def flush(self, clear: bool = True) -> List[bytes]:
+    """Returns all samples and optionally clears buffer if clear is True."""
+    with self._samples_lock:
+      if isinstance(self._coder_impl, WindowedValueCoderImpl):
+        samples = [s for s in self._samples]
+      else:
+        samples = [self.remove_windowed_value(s) for s in self._samples]
 
-  def sample(self, element: Any) -> None:
-    """Samples the given element to an internal buffer.
+      # Encode in the nested context b/c this ensures that the SDK can decode
+      # the bytes with the ToStringFn.
+      if clear:
+        self._samples.clear()
+      return [self._coder_impl.encode_nested(s) for s in samples]
 
-    Samples are only taken for the first 10 elements then every
-    `self._sample_every_sec` second after.
-    """
-    self._sample_count += 1
-    now = self.time()
-    sample_diff = now - self._last_sample_sec
-
-    if self._sample_count <= 10 or sample_diff >= self._sample_every_sec:
-      self._samples.append(element)
-      self._last_sample_sec = now
+  def sample(self) -> None:
+    """Samples the given element to an internal buffer."""
+    with self._samples_lock:
+      if self.element_sampler.has_element:
+        self.element_sampler.has_element = False
+        self._samples.append(self.element_sampler.el)
 
 
 class DataSampler:
@@ -103,14 +142,17 @@ class DataSampler:
 
   This class is meant to be a singleton with regard to a particular
   `sdk_worker.SdkHarness`. When creating the operators, individual
-  `OutputSampler`s are created from `DataSampler.sample_output`. This allows for
-  multi-threaded sampling of a PCollection across the SdkHarness.
+  `OutputSampler`s are created from `DataSampler.initialize_samplers`. This
+  allows for multi-threaded sampling of a PCollection across the SdkHarness.
 
   Samples generated during execution can then be sampled with the `samples`
   method. This filters samples from the given pcollection ids.
   """
   def __init__(
-      self, max_samples: int = 10, sample_every_sec: float = 30) -> None:
+      self,
+      max_samples: int = 10,
+      sample_every_sec: float = 30,
+      clock=None) -> None:
     # Key is PCollection id. Is guarded by the _samplers_lock.
     self._samplers: Dict[str, OutputSampler] = {}
     # Bundles are processed in parallel, so new samplers may be added when the
@@ -118,17 +160,72 @@ class DataSampler:
     self._samplers_lock: threading.Lock = threading.Lock()
     self._max_samples = max_samples
     self._sample_every_sec = sample_every_sec
+    self._element_samplers: Dict[str, List[ElementSampler]] = {}
+    self._clock = clock
 
-  def sample_output(self, pcoll_id: str, coder: Coder) -> OutputSampler:
-    """Create or get an OutputSampler for a pcoll_id."""
+  def stop(self) -> None:
+    """Stops all sampling, does not clear samplers in case there are 
outstanding
+    samples.
+    """
     with self._samplers_lock:
-      if pcoll_id in self._samplers:
-        sampler = self._samplers[pcoll_id]
-      else:
+      for sampler in self._samplers.values():
+        sampler.stop()
+
+  def sampler_for_output(
+      self, transform_id: str, output_index: int) -> ElementSampler:
+    """Returns the ElementSampler for the given output."""
+    try:
+      return self._element_samplers[transform_id][output_index]
+    except (IndexError, KeyError):
+      _LOGGER.warning(
+          f'Out-of-bounds access for transform "{transform_id}" '
+          f'and output "{output_index}" ElementSampler. This may '
+          'indicate that the transform was improperly '
+          'initialized with the DataSampler.')
+      return ElementSampler()
+
+  def initialize_samplers(
+      self,
+      transform_id: str,
+      descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
+      coder_factory) -> List[ElementSampler]:
+    """Creates the OutputSamplers for the given PTransform.
+
+    This initializes the samplers only once per PCollection Id. Note that an
+    OutputSampler is created per PCollection and an ElementSampler is created
+    per OutputSampler. This means that multiple ProcessBundles can and will
+    share the same ElementSampler for a given PCollection.
+    """
+    transform_proto = descriptor.transforms[transform_id]
+    with self._samplers_lock:
+      # Initialize the samplers.
+      for pcoll_id in transform_proto.outputs.values():
+        # Only initialize new PCollections.
+        if pcoll_id in self._samplers:
+          continue
+
+        # Create the sampler with the corresponding coder.
+        coder_id = descriptor.pcollections[pcoll_id].coder_id
+        coder = coder_factory(coder_id)
         sampler = OutputSampler(
             coder, self._max_samples, self._sample_every_sec)
         self._samplers[pcoll_id] = sampler
-      return sampler
+
+      # Next update the lookup table for ElementSamplers for a given PTransform.
+      # Operations look up the ElementSampler for an output based on the index
+      # of the tag in the PTransform's outputs. The following code initializes
+      # the array with ElementSamplers in the correct indices.
+      if transform_id in self._element_samplers:
+        return self._element_samplers[transform_id]
+
+      outputs = transform_proto.outputs
+      samplers = [
+          self._samplers[pcoll_id].element_sampler
+          for pcoll_id in outputs.values()
+      ]
+      self._element_samplers[transform_id] = samplers
+
+      return samplers
 
   def samples(
       self,
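
The heart of the optimization is a lock-free, single-slot handoff: the hot
path performs two plain attribute stores, and a periodic timer thread drains
the slot into a bounded deque. A standalone sketch of that pattern under
simplified assumptions (Slot and run_sampler are illustrative names, not the
real Beam classes):

    import collections
    import threading
    import time
    from typing import Any

    class Slot:
        """Single-slot mailbox: written by the hot path, drained by a timer."""
        has_element: bool = False
        el: Any = None

    def run_sampler(slot: Slot, stop: threading.Event,
                    period_secs: float = 0.05, max_samples: int = 10):
        samples: collections.deque = collections.deque(maxlen=max_samples)

        def tick():
            # No lock here: a racing write may be dropped, which is the
            # documented trade-off for keeping the hot path cheap.
            if slot.has_element:
                slot.has_element = False
                samples.append(slot.el)
            if not stop.is_set():
                timer = threading.Timer(period_secs, tick)
                timer.daemon = True
                timer.start()

        tick()
        return samples

    stop = threading.Event()
    slot = Slot()
    buf = run_sampler(slot, stop)
    slot.el = 'element'      # hot path: two attribute stores per element,
    slot.has_element = True  # no locks and no function calls
    time.sleep(0.2)
    stop.set()
    print(list(buf))  # usually ['element']; a sample can drop by design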
diff --git a/sdks/python/apache_beam/runners/worker/data_sampler_test.py b/sdks/python/apache_beam/runners/worker/data_sampler_test.py
index 6a163c83b28..346251ee216 100644
--- a/sdks/python/apache_beam/runners/worker/data_sampler_test.py
+++ b/sdks/python/apache_beam/runners/worker/data_sampler_test.py
@@ -17,159 +17,385 @@
 
 # pytype: skip-file
 
+import time
 import unittest
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
 
 from apache_beam.coders import FastPrimitivesCoder
 from apache_beam.coders import WindowedValueCoder
-from apache_beam.coders.coders import Coder
+from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.runners.worker.data_sampler import DataSampler
 from apache_beam.runners.worker.data_sampler import OutputSampler
 from apache_beam.transforms.window import GlobalWindow
 from apache_beam.utils.windowed_value import WindowedValue
 
+MAIN_TRANSFORM_ID = 'transform'
+MAIN_PCOLLECTION_ID = 'pcoll'
+PRIMITIVES_CODER = FastPrimitivesCoder()
+
+
+class FakeClock:
+  def __init__(self):
+    self.clock = 0
+
+  def time(self):
+    return self.clock
+
 
 class DataSamplerTest(unittest.TestCase):
+  def make_test_descriptor(
+      self,
+      outputs: Optional[List[str]] = None,
+      transforms: Optional[List[str]] = None
+  ) -> beam_fn_api_pb2.ProcessBundleDescriptor:
+    outputs = outputs or [MAIN_PCOLLECTION_ID]
+    transforms = transforms or [MAIN_TRANSFORM_ID]
+
+    descriptor = beam_fn_api_pb2.ProcessBundleDescriptor()
+    for transform_id in transforms:
+      transform = descriptor.transforms[transform_id]
+      for output in outputs:
+        transform.outputs[output] = output
+
+    return descriptor
+
+  def setUp(self):
+    self.data_sampler = DataSampler(sample_every_sec=0.1)
+
+  def tearDown(self):
+    self.data_sampler.stop()
+
+  def wait_for_samples(
+      self, data_sampler: DataSampler,
+      pcollection_ids: List[str]) -> Dict[str, List[bytes]]:
+    """Waits for samples to exist for the given PCollections."""
+    now = time.time()
+    end = now + 30
+
+    samples = {}
+    while now < end:
+      time.sleep(0.1)
+      now = time.time()
+      samples.update(data_sampler.samples(pcollection_ids))
+
+      if not samples:
+        continue
+
+      has_all = all(pcoll_id in samples for pcoll_id in pcollection_ids)
+      if has_all:
+        return samples
+
+    self.assertLess(
+        now,
+        end,
+        'Timed out waiting for samples for {}'.format(pcollection_ids))
+    return {}
+
+  def primitives_coder_factory(self, _):
+    return PRIMITIVES_CODER
+
+  def gen_sample(
+      self,
+      data_sampler: DataSampler,
+      element: Any,
+      output_index: int,
+      transform_id: str = MAIN_TRANSFORM_ID):
+    """Generates a sample for the given transform's output."""
+    element_sampler = self.data_sampler.sampler_for_output(
+        transform_id, output_index)
+    element_sampler.el = element
+    element_sampler.has_element = True
+
   def test_single_output(self):
     """Simple test for a single sample."""
-    data_sampler = DataSampler()
-    coder = FastPrimitivesCoder()
+    descriptor = self.make_test_descriptor()
+    self.data_sampler.initialize_samplers(
+        MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory)
+
+    self.gen_sample(self.data_sampler, 'a', output_index=0)
+
+    expected_sample = {
+        MAIN_PCOLLECTION_ID: [PRIMITIVES_CODER.encode_nested('a')]
+    }
+    samples = self.wait_for_samples(self.data_sampler, [MAIN_PCOLLECTION_ID])
+    self.assertEqual(samples, expected_sample)
+
+  def test_not_initialized(self):
+    """Tests that transforms fail gracefully if not properly initialized."""
+    with self.assertLogs() as cm:
+      self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, 0)
+    self.assertRegex(cm.output[0], 'Out-of-bounds access.*')
+
+  def map_outputs_to_indices(
+      self, outputs, descriptor, transform_id=MAIN_TRANSFORM_ID):
+    tag_list = list(descriptor.transforms[transform_id].outputs)
+    return {output: tag_list.index(output) for output in outputs}
+
+  def test_sampler_mapping(self):
+    """Tests that the ElementSamplers are created for the correct output."""
+    # Initialize the DataSampler with the following outputs. The order here may
+    # get shuffled when inserting into the descriptor.
+    pcollection_ids = ['o0', 'o1', 'o2']
+    descriptor = self.make_test_descriptor(outputs=pcollection_ids)
+    samplers = self.data_sampler.initialize_samplers(
+        MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory)
+
+    # Create a map from the PCollection id to the index into the transform
+    # output. This mirrors what happens when operators are created. The index of
+    # an output is where in the PTransform.outputs it is located (when the map
+    # is converted to a list).
+    outputs = self.map_outputs_to_indices(pcollection_ids, descriptor)
+
+    # Assert that the mapping is correct, i.e. that we can go from the
+    # PCollection id -> output index and that this is the same as the created
+    # samplers.
+    index = outputs['o0']
+    self.assertEqual(
+        self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, index),
+        samplers[index])
 
-    output_sampler = data_sampler.sample_output('1', coder)
-    output_sampler.sample('a')
+    index = outputs['o1']
+    self.assertEqual(
+        self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, index),
+        samplers[index])
 
-    self.assertEqual(data_sampler.samples(), {'1': [coder.encode_nested('a')]})
+    index = outputs['o2']
+    self.assertEqual(
+        self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, index),
+        samplers[index])
 
   def test_multiple_outputs(self):
     """Tests that multiple PCollections have their own sampler."""
-    data_sampler = DataSampler()
-    coder = FastPrimitivesCoder()
-
-    data_sampler.sample_output('1', coder).sample('a')
-    data_sampler.sample_output('2', coder).sample('a')
-
-    self.assertEqual(
-        data_sampler.samples(), {
-            '1': [coder.encode_nested('a')], '2': [coder.encode_nested('a')]
-        })
-
-  def gen_samples(self, data_sampler: DataSampler, coder: Coder):
-    data_sampler.sample_output('a', coder).sample('1')
-    data_sampler.sample_output('a', coder).sample('2')
-    data_sampler.sample_output('b', coder).sample('3')
-    data_sampler.sample_output('b', coder).sample('4')
-    data_sampler.sample_output('c', coder).sample('5')
-    data_sampler.sample_output('c', coder).sample('6')
+    pcollection_ids = ['o0', 'o1', 'o2']
+    descriptor = self.make_test_descriptor(outputs=pcollection_ids)
+    outputs = self.map_outputs_to_indices(pcollection_ids, descriptor)
+
+    self.data_sampler.initialize_samplers(
+        MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory)
+
+    self.gen_sample(self.data_sampler, 'a', output_index=outputs['o0'])
+    self.gen_sample(self.data_sampler, 'b', output_index=outputs['o1'])
+    self.gen_sample(self.data_sampler, 'c', output_index=outputs['o2'])
+
+    samples = self.wait_for_samples(self.data_sampler, ['o0', 'o1', 'o2'])
+    expected_samples = {
+        'o0': [PRIMITIVES_CODER.encode_nested('a')],
+        'o1': [PRIMITIVES_CODER.encode_nested('b')],
+        'o2': [PRIMITIVES_CODER.encode_nested('c')],
+    }
+    self.assertEqual(samples, expected_samples)
+
+  def test_multiple_transforms(self):
+    """Test that multiple transforms with the same PCollections can be sampled.
+    """
+    # Initialize two transform both with the same two outputs.
+    pcollection_ids = ['o0', 'o1']
+    descriptor = self.make_test_descriptor(
+        outputs=pcollection_ids, transforms=['t0', 't1'])
+    t0_outputs = self.map_outputs_to_indices(
+        pcollection_ids, descriptor, transform_id='t0')
+    t1_outputs = self.map_outputs_to_indices(
+        pcollection_ids, descriptor, transform_id='t1')
+
+    self.data_sampler.initialize_samplers(
+        't0', descriptor, self.primitives_coder_factory)
+
+    self.data_sampler.initialize_samplers(
+        't1', descriptor, self.primitives_coder_factory)
+
+    # The OutputSampler runs on a separate thread, so each round samples a
+    # different PCollection from each transform to ensure no data race occurs.
+    self.gen_sample(
+        self.data_sampler,
+        'a',
+        output_index=t0_outputs['o0'],
+        transform_id='t0')
+    self.gen_sample(
+        self.data_sampler,
+        'd',
+        output_index=t1_outputs['o1'],
+        transform_id='t1')
+    expected_samples = {
+        'o0': [PRIMITIVES_CODER.encode_nested('a')],
+        'o1': [PRIMITIVES_CODER.encode_nested('d')],
+    }
+    samples = self.wait_for_samples(self.data_sampler, ['o0', 'o1'])
+    self.assertEqual(samples, expected_samples)
+
+    self.gen_sample(
+        self.data_sampler,
+        'b',
+        output_index=t0_outputs['o1'],
+        transform_id='t0')
+    self.gen_sample(
+        self.data_sampler,
+        'c',
+        output_index=t1_outputs['o0'],
+        transform_id='t1')
+    expected_samples = {
+        'o0': [PRIMITIVES_CODER.encode_nested('c')],
+        'o1': [PRIMITIVES_CODER.encode_nested('b')],
+    }
+    samples = self.wait_for_samples(self.data_sampler, ['o0', 'o1'])
+    self.assertEqual(samples, expected_samples)
 
   def test_sample_filters_single_pcollection_ids(self):
     """Tests the samples can be filtered based on a single pcollection id."""
-    data_sampler = DataSampler()
-    coder = FastPrimitivesCoder()
-
-    self.gen_samples(data_sampler, coder)
-    self.assertEqual(
-        data_sampler.samples(pcollection_ids=['a']),
-        {'a': [coder.encode_nested('1'), coder.encode_nested('2')]})
-
-    self.assertEqual(
-        data_sampler.samples(pcollection_ids=['b']),
-        {'b': [coder.encode_nested('3'), coder.encode_nested('4')]})
+    pcollection_ids = ['o0', 'o1', 'o2']
+    descriptor = self.make_test_descriptor(outputs=pcollection_ids)
+    outputs = self.map_outputs_to_indices(pcollection_ids, descriptor)
+
+    self.data_sampler.initialize_samplers(
+        MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory)
+
+    self.gen_sample(self.data_sampler, 'a', output_index=outputs['o0'])
+    self.gen_sample(self.data_sampler, 'b', output_index=outputs['o1'])
+    self.gen_sample(self.data_sampler, 'c', output_index=outputs['o2'])
+
+    samples = self.wait_for_samples(self.data_sampler, ['o0'])
+    expected_samples = {
+        'o0': [PRIMITIVES_CODER.encode_nested('a')],
+    }
+    self.assertEqual(samples, expected_samples)
+
+    samples = self.wait_for_samples(self.data_sampler, ['o1'])
+    expected_samples = {
+        'o1': [PRIMITIVES_CODER.encode_nested('b')],
+    }
+    self.assertEqual(samples, expected_samples)
+
+    samples = self.wait_for_samples(self.data_sampler, ['o2'])
+    expected_samples = {
+        'o2': [PRIMITIVES_CODER.encode_nested('c')],
+    }
+    self.assertEqual(samples, expected_samples)
 
   def test_sample_filters_multiple_pcollection_ids(self):
     """Tests the samples can be filtered based on a multiple pcollection 
ids."""
-    data_sampler = DataSampler()
-    coder = FastPrimitivesCoder()
+    pcollection_ids = ['o0', 'o1', 'o2']
+    descriptor = self.make_test_descriptor(outputs=pcollection_ids)
+    outputs = self.map_outputs_to_indices(pcollection_ids, descriptor)
 
-    self.gen_samples(data_sampler, coder)
-    self.assertEqual(
-        data_sampler.samples(pcollection_ids=['a', 'c']),
-        {
-            'a': [coder.encode_nested('1'), coder.encode_nested('2')],
-            'c': [coder.encode_nested('5'), coder.encode_nested('6')]
-        })
+    self.data_sampler.initialize_samplers(
+        MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory)
 
+    self.gen_sample(self.data_sampler, 'a', output_index=outputs['o0'])
+    self.gen_sample(self.data_sampler, 'b', output_index=outputs['o1'])
+    self.gen_sample(self.data_sampler, 'c', output_index=outputs['o2'])
 
-class FakeClock:
-  def __init__(self):
-    self.clock = 0
-
-  def time(self):
-    return self.clock
+    samples = self.wait_for_samples(self.data_sampler, ['o0', 'o2'])
+    expected_samples = {
+        'o0': [PRIMITIVES_CODER.encode_nested('a')],
+        'o2': [PRIMITIVES_CODER.encode_nested('c')],
+    }
+    self.assertEqual(samples, expected_samples)
 
 
 class OutputSamplerTest(unittest.TestCase):
   def setUp(self):
     self.fake_clock = FakeClock()
 
+  def tearDown(self):
+    self.sampler.stop()
+
   def control_time(self, new_time):
     self.fake_clock.clock = new_time
 
-  def test_samples_first_n(self):
-    """Tests that the first elements are always sampled."""
-    coder = FastPrimitivesCoder()
-    sampler = OutputSampler(coder)
+  def wait_for_samples(self, output_sampler: OutputSampler, expected_num: int):
+    """Waits for the expected number of samples for the given sampler."""
+    now = time.time()
+    end = now + 30
 
-    for i in range(15):
-      sampler.sample(i)
+    while now < end:
+      time.sleep(0.1)
+      now = time.time()
+      samples = output_sampler.flush(clear=False)
 
-    self.assertEqual(
-        sampler.flush(), [coder.encode_nested(i) for i in range(10)])
+      if not samples:
+        continue
+
+      if len(samples) == expected_num:
+        return samples
+
+    self.assertLess(now, end, 'Timed out waiting for samples')
+
+  def ensure_sample(
+      self, output_sampler: OutputSampler, sample: Any, expected_num: int):
+    """Generates a sample and waits for it to be available."""
+
+    element_sampler = output_sampler.element_sampler
+
+    now = time.time()
+    end = now + 30
+
+    while now < end:
+      element_sampler.el = sample
+      element_sampler.has_element = True
+      time.sleep(0.1)
+      now = time.time()
+      samples = output_sampler.flush(clear=False)
+
+      if not samples:
+        continue
+
+      if len(samples) == expected_num:
+        return samples
+
+    self.assertLess(
+        now, end, f'Timed out waiting for sample "{sample}" to be generated.')
+
+  def test_can_sample(self):
+    """Tests that the underlying timer can sample."""
+    self.sampler = OutputSampler(PRIMITIVES_CODER, sample_every_sec=0.05)
+    element_sampler = self.sampler.element_sampler
+    element_sampler.el = 'a'
+    element_sampler.has_element = True
+
+    samples = self.wait_for_samples(self.sampler, expected_num=1)
+    self.assertEqual(samples, [PRIMITIVES_CODER.encode_nested('a')])
 
   def test_acts_like_circular_buffer(self):
     """Tests that the buffer overwrites old samples."""
-    coder = FastPrimitivesCoder()
-    sampler = OutputSampler(coder, max_samples=2)
+    self.sampler = OutputSampler(
+        PRIMITIVES_CODER, max_samples=2, sample_every_sec=0)
+    element_sampler = self.sampler.element_sampler
 
     for i in range(10):
-      sampler.sample(i)
+      element_sampler.el = i
+      element_sampler.has_element = True
+      self.sampler.sample()
 
-    self.assertEqual(sampler.flush(), [coder.encode_nested(i) for i in (8, 9)])
+    self.assertEqual(
+        self.sampler.flush(),
+        [PRIMITIVES_CODER.encode_nested(i) for i in (8, 9)])
 
-  def test_samples_every_n_secs(self):
+  def test_samples_multiple_times(self):
     """Tests that the buffer overwrites old samples."""
-    coder = FastPrimitivesCoder()
-    sampler = OutputSampler(
-        coder, max_samples=1, sample_every_sec=10, clock=self.fake_clock)
+    self.sampler = OutputSampler(
+        PRIMITIVES_CODER, max_samples=10, sample_every_sec=0.05)
 
     # Always samples the first ten.
     for i in range(10):
-      sampler.sample(i)
-    self.assertEqual(sampler.flush(), [coder.encode_nested(9)])
-
-    # Start at t=0
-    sampler.sample(10)
-    self.assertEqual(len(sampler.flush()), 0)
-
-    # Still not over threshold yet.
-    self.control_time(9)
-    for i in range(100):
-      sampler.sample(i)
-    self.assertEqual(len(sampler.flush()), 0)
-
-    # First sample after 10s.
-    self.control_time(10)
-    sampler.sample(10)
-    self.assertEqual(sampler.flush(), [coder.encode_nested(10)])
-
-    # No samples between tresholds.
-    self.control_time(15)
-    for i in range(100):
-      sampler.sample(i)
-    self.assertEqual(len(sampler.flush()), 0)
-
-    # Second sample after 20s.
-    self.control_time(20)
-    sampler.sample(11)
-    self.assertEqual(sampler.flush(), [coder.encode_nested(11)])
+      self.ensure_sample(self.sampler, i, i + 1)
+    self.assertEqual(
+        self.sampler.flush(),
+        [PRIMITIVES_CODER.encode_nested(i) for i in range(10)])
 
   def test_can_sample_windowed_value(self):
     """Tests that values with WindowedValueCoders are sampled wholesale."""
-    data_sampler = DataSampler()
     coder = WindowedValueCoder(FastPrimitivesCoder())
     value = WindowedValue('Hello, World!', 0, [GlobalWindow()])
-    data_sampler.sample_output('1', coder).sample(value)
 
-    self.assertEqual(
-        data_sampler.samples(), {'1': [coder.encode_nested(value)]})
+    self.sampler = OutputSampler(coder, sample_every_sec=0)
+    element_sampler = self.sampler.element_sampler
+    element_sampler.el = value
+    element_sampler.has_element = True
+    self.sampler.sample()
+
+    self.assertEqual(self.sampler.flush(), [coder.encode_nested(value)])
 
   def test_can_sample_non_windowed_value(self):
     """Tests that windowed values with WindowedValueCoders sample only the
@@ -179,13 +405,16 @@ class OutputSamplerTest(unittest.TestCase):
     even if the coder is not a WindowedValueCoder. In this case, the value must
     be retrieved from the WindowedValue to match the correct coder.
     """
-    data_sampler = DataSampler()
-    coder = FastPrimitivesCoder()
-    data_sampler.sample_output('1', coder).sample(
-        WindowedValue('Hello, World!', 0, [GlobalWindow()]))
+    value = WindowedValue('Hello, World!', 0, [GlobalWindow()])
+
+    self.sampler = OutputSampler(PRIMITIVES_CODER, sample_every_sec=0)
+    element_sampler = self.sampler.element_sampler
+    element_sampler.el = value
+    element_sampler.has_element = True
+    self.sampler.sample()
 
     self.assertEqual(
-        data_sampler.samples(), {'1': [coder.encode_nested('Hello, World!')]})
+        self.sampler.flush(), [PRIMITIVES_CODER.encode_nested('Hello, World!')])
 
 
 if __name__ == '__main__':
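
These tests lean on one invariant: an operation addresses its ElementSampler
by the positional index of the output tag in the transform's outputs map. A
tiny illustration of that tag-to-index mapping, using a plain dict in place of
the proto map (the tags here are made up):

    # Outputs as (tag -> pcollection id), as in PTransform.outputs.
    outputs = {'None': 'o0', 'side': 'o1', 'errors': 'o2'}

    # An output's index is its position among the map's keys; this mirrors
    # map_outputs_to_indices in the tests above.
    tag_list = list(outputs)
    tag_to_index = {tag: tag_list.index(tag) for tag in outputs}
    assert tag_to_index == {'None': 0, 'side': 1, 'errors': 2}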
diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd
index cf1f1b3fb51..725c5d2346a 100644
--- a/sdks/python/apache_beam/runners/worker/operations.pxd
+++ b/sdks/python/apache_beam/runners/worker/operations.pxd
@@ -34,6 +34,7 @@ cdef class ConsumerSet(Receiver):
   cdef public step_name
   cdef public output_index
   cdef public coder
+  cdef public object element_sampler
 
   cpdef update_counters_start(self, WindowedValue windowed_value)
   cpdef update_counters_finish(self)
diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py
index fdde55042ec..ca5e6552973 100644
--- a/sdks/python/apache_beam/runners/worker/operations.py
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -53,6 +53,7 @@ from apache_beam.runners.common import Receiver
 from apache_beam.runners.worker import opcounters
 from apache_beam.runners.worker import operation_specs
 from apache_beam.runners.worker import sideinputs
+from apache_beam.runners.worker.data_sampler import ElementSampler
 from apache_beam.transforms import sideinputs as apache_sideinputs
 from apache_beam.transforms import combiners
 from apache_beam.transforms import core
@@ -69,6 +70,8 @@ if TYPE_CHECKING:
   from apache_beam.runners.sdf_utils import SplitResultPrimary
   from apache_beam.runners.sdf_utils import SplitResultResidual
   from apache_beam.runners.worker.bundle_processor import ExecutionContext
+  from apache_beam.runners.worker.data_sampler import DataSampler
+  from apache_beam.runners.worker.data_sampler import OutputSampler
   from apache_beam.runners.worker.statesampler import StateSampler
   from apache_beam.transforms.userstate import TimerSpec
 
@@ -123,6 +126,7 @@ class ConsumerSet(Receiver):
              coder,
              producer_type_hints,
              producer_batch_converter, # type: Optional[BatchConverter]
+             element_sampler=None,  # type: Optional[ElementSampler]
              ):
     # type: (...) -> ConsumerSet
     if len(consumers) == 1:
@@ -139,7 +143,8 @@ class ConsumerSet(Receiver):
             output_index,
             consumer,
             coder,
-            producer_type_hints)
+            producer_type_hints,
+            element_sampler)
 
     return GeneralPurposeConsumerSet(
         counter_factory,
@@ -148,7 +153,8 @@ class ConsumerSet(Receiver):
         coder,
         producer_type_hints,
         consumers,
-        producer_batch_converter)
+        producer_batch_converter,
+        element_sampler)
 
   def __init__(self,
                counter_factory,
@@ -157,7 +163,8 @@ class ConsumerSet(Receiver):
                consumers,
                coder,
                producer_type_hints,
-               producer_batch_converter
+               producer_batch_converter,
+               element_sampler
                ):
     self.opcounter = opcounters.OperationCounters(
         counter_factory,
@@ -171,6 +178,7 @@ class ConsumerSet(Receiver):
     self.output_index = output_index
     self.coder = coder
     self.consumers = consumers
+    self.element_sampler = element_sampler
 
   def try_split(self, fraction_of_remainder):
     # type: (...) -> Optional[Any]
@@ -197,6 +205,16 @@ class ConsumerSet(Receiver):
     # type: (WindowedValue) -> None
     self.opcounter.update_from(windowed_value)
 
+    # The following code inlines what would otherwise be a function call.
+    # Because this runs for every element, a function call is too expensive
+    # (on the order of 100s of nanoseconds). Furthermore, a lock between here
+    # and the DataSampler was purposely omitted to avoid an extra operation
+    # per element. The trade-off is that some samples might be dropped, but
+    # that is better than the alternative of double-sampling the same element.
+    if self.element_sampler is not None:
+      self.element_sampler.el = windowed_value
+      self.element_sampler.has_element = True
+
   def update_counters_finish(self):
     # type: () -> None
     self.opcounter.update_collect()
@@ -223,7 +241,8 @@ class SingletonElementConsumerSet(ConsumerSet):
                output_index,
                consumer,  # type: Operation
                coder,
-               producer_type_hints
+               producer_type_hints,
+               element_sampler
                ):
     super().__init__(
         counter_factory,
@@ -231,7 +250,8 @@ class SingletonElementConsumerSet(ConsumerSet):
         output_index, [consumer],
         coder,
         producer_type_hints,
-        None)
+        None,
+        element_sampler)
     self.consumer = consumer
 
   def receive(self, windowed_value):
@@ -268,7 +288,8 @@ class GeneralPurposeConsumerSet(ConsumerSet):
                coder,
                producer_type_hints,
                consumers,  # type: List[Operation]
-               producer_batch_converter):
+               producer_batch_converter,
+               element_sampler):
     super().__init__(
         counter_factory,
         step_name,
@@ -276,7 +297,8 @@ class GeneralPurposeConsumerSet(ConsumerSet):
         consumers,
         coder,
         producer_type_hints,
-        producer_batch_converter)
+        producer_batch_converter,
+        element_sampler)
 
     self.producer_batch_converter = producer_batch_converter
 
@@ -432,8 +454,8 @@ class Operation(object):
     self.setup_done = False
     self.step_name = None  # type: Optional[str]
 
-  def setup(self):
-    # type: () -> None
+  def setup(self, data_sampler=None):
+    # type: (Optional[DataSampler]) -> None
 
     """Set up operation.
 
@@ -441,10 +463,18 @@ class Operation(object):
     with self.scoped_start_state:
       self.debug_logging_enabled = logging.getLogger().isEnabledFor(
           logging.DEBUG)
+      transform_id = self.name_context.transform_id
+
       # Everything except WorkerSideInputSource, which is not a
       # top-level operation, should have output_coders
       #TODO(pabloem): Define better what step name is used here.
       if getattr(self.spec, 'output_coders', None):
+
+        def get_element_sampler(output_num):
+          if data_sampler is None:
+            return None
+          return data_sampler.sampler_for_output(transform_id, output_num)
+
         self.receivers = [
             ConsumerSet.create(
                 self.counter_factory,
@@ -454,7 +484,7 @@ class Operation(object):
                 coder,
                 self._get_runtime_performance_hints(),
                 self.get_output_batch_converter(),
-            ) for i,
+                get_element_sampler(i)) for i,
             coder in enumerate(self.spec.output_coders)
         ]
     self.setup_done = True
@@ -759,7 +789,7 @@ class DoOperation(Operation):
                counter_factory,
                sampler,
                side_input_maps=None,
-               user_state_context=None
+               user_state_context=None,
               ):
     super(DoOperation, self).__init__(name, spec, counter_factory, sampler)
     self.side_input_maps = side_input_maps
@@ -828,10 +858,10 @@ class DoOperation(Operation):
       yield apache_sideinputs.SideInputMap(
           view_class, view_options, sideinputs.EmulatedIterable(iterator_fn))
 
-  def setup(self):
-    # type: () -> None
+  def setup(self, data_sampler=None):
+    # type: (Optional[DataSampler]) -> None
     with self.scoped_start_state:
-      super(DoOperation, self).setup()
+      super(DoOperation, self).setup(data_sampler)
 
       # See fn_data in dataflow_runner.py
       fn, args, kwargs, tags_and_types, window_fn = (
@@ -1107,11 +1137,11 @@ class CombineOperation(Operation):
     self.phased_combine_fn = (
         PhasedCombineFnExecutor(self.spec.phase, fn, args, kwargs))
 
-  def setup(self):
-    # type: () -> None
+  def setup(self, data_sampler=None):
+    # type: (Optional[DataSampler]) -> None
     with self.scoped_start_state:
       _LOGGER.debug('Setup called for %s', self)
-      super(CombineOperation, self).setup()
+      super(CombineOperation, self).setup(data_sampler)
       self.phased_combine_fn.combine_fn.setup()
 
   def process(self, o):
@@ -1226,11 +1256,11 @@ class PGBKCVOperation(Operation):
     self.key_count = 0
     self.table = {}
 
-  def setup(self):
-    # type: () -> None
+  def setup(self, data_sampler=None):
+    # type: (Optional[DataSampler]) -> None
     with self.scoped_start_state:
       _LOGGER.debug('Setup called for %s', self)
-      super(PGBKCVOperation, self).setup()
+      super(PGBKCVOperation, self).setup(data_sampler)
       self.combine_fn.setup()
 
   def process(self, wkv):
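
The receive-path change above trades a per-element method call for two plain
attribute stores. A rough micro-benchmark of why that matters (timings are
machine-dependent; Sampler here is an illustrative stand-in):

    import timeit

    class Sampler:
        has_element = False
        el = None

        def sample(self, value):  # the call the inlined version avoids
            self.el = value
            self.has_element = True

    s = Sampler()
    call = timeit.timeit('s.sample(42)', globals={'s': s}, number=1_000_000)
    inline = timeit.timeit('s.el = 42; s.has_element = True',
                           globals={'s': s}, number=1_000_000)
    print(f'method call: {call:.3f}s  inlined stores: {inline:.3f}s '
          '(per 1e6 elements)')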
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index f7da86234d5..f5b1456c251 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -268,6 +268,8 @@ class SdkHarness(object):
             work_request)
     finally:
       self._alive = False
+      if self.data_sampler:
+        self.data_sampler.stop()
 
     _LOGGER.info('No more requests from control plane')
     _LOGGER.info('SDK Harness waiting for in-flight requests to complete')
@@ -379,7 +381,7 @@ class SdkHarness(object):
     def get_samples(request):
       # type: (beam_fn_api_pb2.InstructionRequest) -> beam_fn_api_pb2.InstructionResponse
       samples: Dict[str, List[bytes]] = {}
-      if self.data_sampler:
+      if self.data_sampler is not None:
+        samples = self.data_sampler.samples(request.sample_data.pcollection_ids)
 
       sample_response = beam_fn_api_pb2.SampleDataResponse()
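
The harness change makes sure sampling threads do not outlive the control
loop: stop() is called in a finally block so timers are cancelled even when
the loop exits with an error. The shape of that guard in isolation (a minimal
sketch; FakeSampler and run_control_loop are illustrative stand-ins):

    class FakeSampler:  # hypothetical stand-in for DataSampler
        def stop(self):
            print('per-PCollection timers cancelled')

    def run_control_loop(data_sampler=None):
        try:
            pass  # ... serve control-plane requests until the stream ends ...
        finally:
            # Sampling threads must not outlive the harness.
            if data_sampler is not None:
                data_sampler.stop()

    run_control_loop(FakeSampler())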
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
index b3ad349be75..d12dc48ed29 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
@@ -23,6 +23,7 @@ import contextlib
 import logging
 import unittest
 from collections import namedtuple
+from typing import Any
 
 import grpc
 import hamcrest as hc
@@ -39,7 +40,6 @@ from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.portability.api import metrics_pb2
 from apache_beam.runners.worker import sdk_worker
 from apache_beam.runners.worker import statecache
-from apache_beam.runners.worker.data_sampler import DataSampler
 from apache_beam.runners.worker.sdk_worker import BundleProcessorCache
 from apache_beam.runners.worker.sdk_worker import GlobalCachingStateHandler
 from apache_beam.runners.worker.sdk_worker import SdkWorker
@@ -279,18 +279,19 @@ class SdkWorkerTest(unittest.TestCase):
   def test_data_sampling_response(self):
     # Create a data sampler with some fake sampled data. This data will be seen
     # in the sample response.
-    data_sampler = DataSampler()
     coder = FastPrimitivesCoder()
 
-    # Sample from two fake PCollections to test that all sampled PCollections
-    # are present in the response. Also adds an extra sample to test that
-    # filtering is forwarded to the DataSampler.
-    data_sampler.sample_output('pcoll_id_1',
-                               coder).sample('hello, world from pcoll_id_1!')
-    data_sampler.sample_output('pcoll_id_2',
-                               coder).sample('hello, world from pcoll_id_2!')
-    data_sampler.sample_output('bad_pcoll_id',
-                               coder).sample('if present bug in filter')
+    class FakeDataSampler:
+      def samples(self, pcollection_ids):
+        return {
+            'pcoll_id_1': [coder.encode_nested('a')],
+            'pcoll_id_2': [coder.encode_nested('b')],
+        }
+
+      def stop(self):
+        pass
+
+    data_sampler = FakeDataSampler()
 
     # Create and send the fake response. The SdkHarness should query the
     # DataSampler and fill out the sample response.
@@ -310,14 +311,12 @@ class SdkWorkerTest(unittest.TestCase):
                 'pcoll_id_1': beam_fn_api_pb2.SampleDataResponse.ElementList(
                     elements=[
                         beam_fn_api_pb2.SampledElement(
-                            element=coder.encode_nested(
-                                'hello, world from pcoll_id_1!'))
+                            element=coder.encode_nested('a'))
                     ]),
                 'pcoll_id_2': beam_fn_api_pb2.SampleDataResponse.ElementList(
                     elements=[
                         beam_fn_api_pb2.SampledElement(
-                            element=coder.encode_nested(
-                                'hello, world from pcoll_id_2!'))
+                            element=coder.encode_nested('b'))
                     ])
             }))
 
