[ https://issues.apache.org/jira/browse/BEAM-2732?focusedWorklogId=94470&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-94470 ]

ASF GitHub Bot logged work on BEAM-2732:
----------------------------------------

                Author: ASF GitHub Bot
            Created on: 24/Apr/18 06:10
            Start Date: 24/Apr/18 06:10
    Worklog Time Spent: 10m 
      Work Description: robertwb closed pull request #4387: [BEAM-2732] Metrics rely on statesampler state
URL: https://github.com/apache/beam/pull/4387

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/sdks/python/apache_beam/metrics/execution.py b/sdks/python/apache_beam/metrics/execution.py
index f6c790de5d4..310faf6c9c8 100644
--- a/sdks/python/apache_beam/metrics/execution.py
+++ b/sdks/python/apache_beam/metrics/execution.py
@@ -127,25 +127,34 @@ def set_metrics_supported(self, supported):
     with self._METRICS_SUPPORTED_LOCK:
       self.METRICS_SUPPORTED = supported
 
-  def current_container(self):
+  def _old_style_container(self):
+    """Gets the current MetricsContainer based on the container stack.
+
+    The container stack is the old method, and will be deprecated. Should
+    rely on StateSampler instead."""
     self.set_container_stack()
     index = len(self.PER_THREAD.container) - 1
     if index < 0:
       return None
     return self.PER_THREAD.container[index]
 
-  def set_current_container(self, container):
-    self.set_container_stack()
-    self.PER_THREAD.container.append(container)
-
-  def unset_current_container(self):
-    self.set_container_stack()
-    self.PER_THREAD.container.pop()
+  def current_container(self):
+    """Returns the current MetricsContainer."""
+    sampler = statesampler.get_current_tracker()
+    if sampler is None:
+      return self._old_style_container()
+    return sampler.current_state().metrics_container
 
 
 MetricsEnvironment = _MetricsEnvironment()
 
 
+def metrics_startup():
+  """Initialize metrics context to run."""
+  global statesampler  # pylint: disable=global-variable-not-assigned
+  from apache_beam.runners.worker import statesampler
+
+
 class MetricsContainer(object):
   """Holds the metrics of a single step and a single bundle."""
   def __init__(self, step_name):
@@ -227,10 +236,12 @@ def __init__(self, container=None):
     self._container = container
 
   def enter(self):
-    self._stack.append(self._container)
+    if self._container:
+      self._stack.append(self._container)
 
   def exit(self):
-    self._stack.pop()
+    if self._container:
+      self._stack.pop()
 
   def __enter__(self):
     self.enter()
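
The behavioral core of this change: current_container() now prefers the
scoped state of the thread's StateSampler and only falls back to the old
thread-local container stack. A minimal sketch of that lookup order
(_FakeTracker and _FakeState are hypothetical stand-ins, not Beam APIs):

    class _FakeState(object):
      def __init__(self, metrics_container):
        self.metrics_container = metrics_container

    class _FakeTracker(object):
      def __init__(self, state):
        self._state = state

      def current_state(self):
        return self._state

    def resolve_container(tracker, legacy_stack):
      # New path: ask the state sampler for the active state's container.
      if tracker is not None:
        return tracker.current_state().metrics_container
      # Old path: top of the soon-to-be-deprecated container stack.
      return legacy_stack[-1] if legacy_stack else None
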
diff --git a/sdks/python/apache_beam/metrics/execution_test.py b/sdks/python/apache_beam/metrics/execution_test.py
index 2367e35df4d..37d24f3407b 100644
--- a/sdks/python/apache_beam/metrics/execution_test.py
+++ b/sdks/python/apache_beam/metrics/execution_test.py
@@ -18,11 +18,7 @@
 import unittest
 
 from apache_beam.metrics.cells import CellCommitState
-from apache_beam.metrics.execution import MetricKey
 from apache_beam.metrics.execution import MetricsContainer
-from apache_beam.metrics.execution import MetricsEnvironment
-from apache_beam.metrics.execution import ScopedMetricsContainer
-from apache_beam.metrics.metric import Metrics
 from apache_beam.metrics.metricbase import MetricName
 
 
@@ -33,29 +29,6 @@ def test_create_new_counter(self):
     mc.get_counter(MetricName('namespace', 'name'))
     self.assertTrue(MetricName('namespace', 'name') in mc.counters)
 
-  def test_scoped_container(self):
-    c1 = MetricsContainer('mystep')
-    c2 = MetricsContainer('myinternalstep')
-    with ScopedMetricsContainer(c1):
-      self.assertEqual(c1, MetricsEnvironment.current_container())
-      counter = Metrics.counter('ns', 'name')
-      counter.inc(2)
-
-      with ScopedMetricsContainer(c2):
-        self.assertEqual(c2, MetricsEnvironment.current_container())
-        counter = Metrics.counter('ns', 'name')
-        counter.inc(3)
-        self.assertEqual(
-            list(c2.get_cumulative().counters.items()),
-            [(MetricKey('myinternalstep', MetricName('ns', 'name')), 3)])
-
-      self.assertEqual(c1, MetricsEnvironment.current_container())
-      counter = Metrics.counter('ns', 'name')
-      counter.inc(4)
-      self.assertEqual(
-          list(c1.get_cumulative().counters.items()),
-          [(MetricKey('mystep', MetricName('ns', 'name')), 6)])
-
   def test_add_to_counter(self):
     mc = MetricsContainer('astep')
     counter = mc.get_counter(MetricName('namespace', 'name'))
@@ -118,29 +91,5 @@ def test_get_cumulative_or_updates(self):
                      set([v.value for _, v in cumulative.gauges.items()]))
 
 
-class TestMetricsEnvironment(unittest.TestCase):
-  def test_uses_right_container(self):
-    c1 = MetricsContainer('step1')
-    c2 = MetricsContainer('step2')
-    counter = Metrics.counter('ns', 'name')
-    MetricsEnvironment.set_current_container(c1)
-    counter.inc()
-    MetricsEnvironment.set_current_container(c2)
-    counter.inc(3)
-    MetricsEnvironment.unset_current_container()
-
-    self.assertEqual(
-        list(c1.get_cumulative().counters.items()),
-        [(MetricKey('step1', MetricName('ns', 'name')), 1)])
-
-    self.assertEqual(
-        list(c2.get_cumulative().counters.items()),
-        [(MetricKey('step2', MetricName('ns', 'name')), 3)])
-
-  def test_no_container(self):
-    self.assertEqual(MetricsEnvironment.current_container(),
-                     None)
-
-
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/metrics/metric_test.py b/sdks/python/apache_beam/metrics/metric_test.py
index eaad1574c73..385d2708996 100644
--- a/sdks/python/apache_beam/metrics/metric_test.py
+++ b/sdks/python/apache_beam/metrics/metric_test.py
@@ -25,6 +25,8 @@
 from apache_beam.metrics.metric import Metrics
 from apache_beam.metrics.metric import MetricsFilter
 from apache_beam.metrics.metricbase import MetricName
+from apache_beam.runners.worker import statesampler
+from apache_beam.utils import counters
 
 
 class NameTest(unittest.TestCase):
@@ -115,37 +117,33 @@ def test_distribution_empty_namespace(self):
       Metrics.distribution("", "names")
 
   def test_create_counter_distribution(self):
-    MetricsEnvironment.set_current_container(MetricsContainer('mystep'))
-    counter_ns = 'aCounterNamespace'
-    distro_ns = 'aDistributionNamespace'
-    gauge_ns = 'aGaugeNamespace'
-    name = 'a_name'
-    counter = Metrics.counter(counter_ns, name)
-    distro = Metrics.distribution(distro_ns, name)
-    gauge = Metrics.gauge(gauge_ns, name)
-    counter.inc(10)
-    counter.dec(3)
-    distro.update(10)
-    distro.update(2)
-    gauge.set(10)
-    self.assertTrue(isinstance(counter, Metrics.DelegatingCounter))
-    self.assertTrue(isinstance(distro, Metrics.DelegatingDistribution))
-    self.assertTrue(isinstance(gauge, Metrics.DelegatingGauge))
-
-    del distro
-    del counter
-    del gauge
-
-    container = MetricsEnvironment.current_container()
-    self.assertEqual(
-        container.counters[MetricName(counter_ns, name)].get_cumulative(),
-        7)
-    self.assertEqual(
-        container.distributions[MetricName(distro_ns, name)].get_cumulative(),
-        DistributionData(12, 2, 2, 10))
-    self.assertEqual(
-        container.gauges[MetricName(gauge_ns, name)].get_cumulative().value,
-        10)
+    sampler = statesampler.StateSampler('', counters.CounterFactory())
+    statesampler.set_current_tracker(sampler)
+    state1 = sampler.scoped_state('mystep', 'myState',
+                                  metrics_container=MetricsContainer('mystep'))
+    with state1:
+      counter_ns = 'aCounterNamespace'
+      distro_ns = 'aDistributionNamespace'
+      name = 'a_name'
+      counter = Metrics.counter(counter_ns, name)
+      distro = Metrics.distribution(distro_ns, name)
+      counter.inc(10)
+      counter.dec(3)
+      distro.update(10)
+      distro.update(2)
+      self.assertTrue(isinstance(counter, Metrics.DelegatingCounter))
+      self.assertTrue(isinstance(distro, Metrics.DelegatingDistribution))
+
+      del distro
+      del counter
+
+      container = MetricsEnvironment.current_container()
+      self.assertEqual(
+          container.counters[MetricName(counter_ns, name)].get_cumulative(),
+          7)
+      self.assertEqual(
+          container.distributions[MetricName(distro_ns, name)].get_cumulative(),
+          DistributionData(12, 2, 2, 10))
 
 
 if __name__ == '__main__':
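
The rewritten test above doubles as the usage pattern for the new API.
Condensed to its essentials (these are the same calls the diff introduces;
nothing here is new API surface):

    from apache_beam.metrics.execution import MetricsContainer
    from apache_beam.metrics.metric import Metrics
    from apache_beam.runners.worker import statesampler
    from apache_beam.utils import counters

    # Install a sampler for this thread, then meter work under a state.
    sampler = statesampler.StateSampler('', counters.CounterFactory())
    statesampler.set_current_tracker(sampler)
    state = sampler.scoped_state(
        'mystep', 'myState', metrics_container=MetricsContainer('mystep'))
    with state:
      Metrics.counter('ns', 'name').inc()  # attributed to 'mystep'
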
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 44f9083a52b..fbc137cec4f 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -28,7 +28,6 @@
 import six
 
 from apache_beam.internal import util
-from apache_beam.metrics.execution import ScopedMetricsContainer
 from apache_beam.pvalue import TaggedOutput
 from apache_beam.transforms import DoFn
 from apache_beam.transforms import core
@@ -537,6 +536,8 @@ def __init__(self,
     # Need to support multiple iterations.
     side_inputs = list(side_inputs)
 
+    from apache_beam.metrics.execution import ScopedMetricsContainer
+
     self.scoped_metrics_container = (
         scoped_metrics_container or ScopedMetricsContainer())
     self.step_name = step_name
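
Moving the ScopedMetricsContainer import from module scope into __init__
defers it until a runner is actually constructed, presumably to avoid a
circular import now that metrics.execution reaches into
runners.worker.statesampler. The pattern in isolation (a sketch, not Beam
code):

    def build_scoped_container(scoped_metrics_container=None):
      # Imported lazily so module import order no longer matters.
      from apache_beam.metrics.execution import ScopedMetricsContainer
      return scoped_metrics_container or ScopedMetricsContainer()
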
diff --git a/sdks/python/apache_beam/runners/direct/executor.py b/sdks/python/apache_beam/runners/direct/executor.py
index 8ab30b4046c..ef6469644c2 100644
--- a/sdks/python/apache_beam/runners/direct/executor.py
+++ b/sdks/python/apache_beam/runners/direct/executor.py
@@ -31,8 +31,9 @@
 import six
 
 from apache_beam.metrics.execution import MetricsContainer
-from apache_beam.metrics.execution import ScopedMetricsContainer
+from apache_beam.runners.worker import statesampler
 from apache_beam.transforms import sideinputs
+from apache_beam.utils import counters
 
 
 class _ExecutorService(object):
@@ -40,7 +41,7 @@ class _ExecutorService(object):
 
   class CallableTask(object):
 
-    def call(self):
+    def call(self, state_sampler):
       pass
 
     @property
@@ -83,13 +84,15 @@ def _get_task_or_none(self):
         return None
 
     def run(self):
+      state_sampler = statesampler.StateSampler('', counters.CounterFactory())
+      statesampler.set_current_tracker(state_sampler)
       while not self.shutdown_requested:
         task = self._get_task_or_none()
         if task:
           try:
             if not self.shutdown_requested:
               self._update_name(task)
-              task.call()
+              task.call(state_sampler)
               self._update_name()
           finally:
             self.queue.task_done()
@@ -290,35 +293,52 @@ def __init__(self, transform_evaluator_registry, evaluation_context,
     self._retry_count = 0
     self._max_retries_per_bundle = TransformExecutor._MAX_RETRY_PER_BUNDLE
 
-  def call(self):
+  def call(self, state_sampler):
     self._call_count += 1
     assert self._call_count <= (1 + len(self._applied_ptransform.side_inputs))
     metrics_container = MetricsContainer(self._applied_ptransform.full_label)
-    scoped_metrics_container = ScopedMetricsContainer(metrics_container)
-
-    for side_input in self._applied_ptransform.side_inputs:
-      # Find the projection of main's window onto the side input's window.
-      window_mapping_fn = side_input._view_options().get(
-          'window_mapping_fn', sideinputs._global_window_mapping_fn)
-      main_onto_side_window = window_mapping_fn(self._latest_main_input_window)
-      block_until = main_onto_side_window.end
-
-      if side_input not in self._side_input_values:
-        value = self._evaluation_context.get_value_or_block_until_ready(
-            side_input, self, block_until)
-        if not value:
-          # Monitor task will reschedule this executor once the side input is
-          # available.
-          return
-        self._side_input_values[side_input] = value
-    side_input_values = [self._side_input_values[side_input]
-                         for side_input in self._applied_ptransform.side_inputs]
+    start_state = state_sampler.scoped_state(
+        self._applied_ptransform.full_label,
+        'start',
+        metrics_container=metrics_container)
+    process_state = state_sampler.scoped_state(
+        self._applied_ptransform.full_label,
+        'process',
+        metrics_container=metrics_container)
+    finish_state = state_sampler.scoped_state(
+        self._applied_ptransform.full_label,
+        'finish',
+        metrics_container=metrics_container)
+
+    with start_state:
+      # Side input initialization should be accounted for in start_state.
+      for side_input in self._applied_ptransform.side_inputs:
+        # Find the projection of main's window onto the side input's window.
+        window_mapping_fn = side_input._view_options().get(
+            'window_mapping_fn', sideinputs._global_window_mapping_fn)
+        main_onto_side_window = window_mapping_fn(
+            self._latest_main_input_window)
+        block_until = main_onto_side_window.end
+
+        if side_input not in self._side_input_values:
+          value = self._evaluation_context.get_value_or_block_until_ready(
+              side_input, self, block_until)
+          if not value:
+            # Monitor task will reschedule this executor once the side input is
+            # available.
+            return
+          self._side_input_values[side_input] = value
+      side_input_values = [
+          self._side_input_values[side_input]
+          for side_input in self._applied_ptransform.side_inputs]
 
     while self._retry_count < self._max_retries_per_bundle:
       try:
         self.attempt_call(metrics_container,
-                          scoped_metrics_container,
-                          side_input_values)
+                          side_input_values,
+                          start_state,
+                          process_state,
+                          finish_state)
         break
       except Exception as e:
         self._retry_count += 1
@@ -336,24 +356,28 @@ def call(self):
     self._transform_evaluation_state.complete(self)
 
   def attempt_call(self, metrics_container,
-                   scoped_metrics_container,
-                   side_input_values):
+                   side_input_values,
+                   start_state,
+                   process_state,
+                   finish_state):
+    """Attempts to run a bundle."""
     evaluator = self._transform_evaluator_registry.get_evaluator(
         self._applied_ptransform, self._input_bundle,
-        side_input_values, scoped_metrics_container)
+        side_input_values)
 
-    with scoped_metrics_container:
+    with start_state:
       evaluator.start_bundle()
 
-    if self._fired_timers:
-      for timer_firing in self._fired_timers:
-        evaluator.process_timer_wrapper(timer_firing)
+    with process_state:
+      if self._fired_timers:
+        for timer_firing in self._fired_timers:
+          evaluator.process_timer_wrapper(timer_firing)
 
-    if self._input_bundle:
-      for value in self._input_bundle.get_elements_iterable():
-        evaluator.process_element(value)
+      if self._input_bundle:
+        for value in self._input_bundle.get_elements_iterable():
+          evaluator.process_element(value)
 
-    with scoped_metrics_container:
+    with finish_state:
       result = evaluator.finish_bundle()
       result.logical_metric_updates = metrics_container.get_cumulative()
 
@@ -525,7 +549,7 @@ def __init__(self, executor):
     def name(self):
       return 'monitor'
 
-    def call(self):
+    def call(self, state_sampler):
       try:
         update = self._executor.all_updates.poll()
         while update:
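
The executor now charges each bundle phase to its own scoped state rather
than to a single ScopedMetricsContainer. A minimal sketch of the resulting
shape (run_bundle and elements are hypothetical; the three states are built
exactly as in call() above):

    def run_bundle(evaluator, elements,
                   start_state, process_state, finish_state):
      # msecs counters and metrics now attribute to start/process/finish
      # separately, mirroring attempt_call() above.
      with start_state:
        evaluator.start_bundle()
      with process_state:
        for element in elements:
          evaluator.process_element(element)
      with finish_state:
        return evaluator.finish_bundle()
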
diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
index eb1ccd5c36c..984bacad1c5 100644
--- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py
+++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
@@ -89,7 +89,7 @@ def __init__(self, evaluation_context):
 
   def get_evaluator(
       self, applied_ptransform, input_committed_bundle,
-      side_inputs, scoped_metrics_container):
+      side_inputs):
     """Returns a TransformEvaluator suitable for processing given inputs."""
     assert applied_ptransform
     assert bool(applied_ptransform.side_inputs) == bool(side_inputs)
@@ -106,8 +106,7 @@ def get_evaluator(
           'Execution of [%s] not implemented in runner %s.' % (
               type(applied_ptransform.transform), self))
     return evaluator(self._evaluation_context, applied_ptransform,
-                     input_committed_bundle, side_inputs,
-                     scoped_metrics_container)
+                     input_committed_bundle, side_inputs)
 
   def get_root_bundle_provider(self, applied_ptransform):
     provider_cls = None
@@ -189,7 +188,7 @@ class _TransformEvaluator(object):
   """An evaluator of a specific application of a transform."""
 
   def __init__(self, evaluation_context, applied_ptransform,
-               input_committed_bundle, side_inputs, scoped_metrics_container):
+               input_committed_bundle, side_inputs):
     self._evaluation_context = evaluation_context
     self._applied_ptransform = applied_ptransform
     self._input_committed_bundle = input_committed_bundle
@@ -197,7 +196,6 @@ def __init__(self, evaluation_context, applied_ptransform,
     self._expand_outputs()
     self._execution_context = evaluation_context.get_execution_context(
         applied_ptransform)
-    self.scoped_metrics_container = scoped_metrics_container
 
   def _expand_outputs(self):
     outputs = set()
@@ -279,13 +277,13 @@ class _BoundedReadEvaluator(_TransformEvaluator):
   MAX_ELEMENT_PER_BUNDLE = 1000
 
   def __init__(self, evaluation_context, applied_ptransform,
-               input_committed_bundle, side_inputs, scoped_metrics_container):
+               input_committed_bundle, side_inputs):
     assert not side_inputs
     self._source = applied_ptransform.transform.source
     self._source.pipeline_options = evaluation_context.pipeline_options
     super(_BoundedReadEvaluator, self).__init__(
         evaluation_context, applied_ptransform, input_committed_bundle,
-        side_inputs, scoped_metrics_container)
+        side_inputs)
 
   def finish_bundle(self):
     assert len(self._outputs) == 1
@@ -314,12 +312,12 @@ class _TestStreamEvaluator(_TransformEvaluator):
   """TransformEvaluator for the TestStream transform."""
 
   def __init__(self, evaluation_context, applied_ptransform,
-               input_committed_bundle, side_inputs, scoped_metrics_container):
+               input_committed_bundle, side_inputs):
     assert not side_inputs
     self.test_stream = applied_ptransform.transform
     super(_TestStreamEvaluator, self).__init__(
         evaluation_context, applied_ptransform, input_committed_bundle,
-        side_inputs, scoped_metrics_container)
+        side_inputs)
 
   def start_bundle(self):
     self.current_index = -1
@@ -383,11 +381,11 @@ class _PubSubReadEvaluator(_TransformEvaluator):
   _subscription_cache = {}
 
   def __init__(self, evaluation_context, applied_ptransform,
-               input_committed_bundle, side_inputs, scoped_metrics_container):
+               input_committed_bundle, side_inputs):
     assert not side_inputs
     super(_PubSubReadEvaluator, self).__init__(
         evaluation_context, applied_ptransform, input_committed_bundle,
-        side_inputs, scoped_metrics_container)
+        side_inputs)
 
     self.source = self._applied_ptransform.transform._source
     self._subscription = _PubSubReadEvaluator.get_subscription(
@@ -481,11 +479,11 @@ class _FlattenEvaluator(_TransformEvaluator):
   """TransformEvaluator for Flatten transform."""
 
   def __init__(self, evaluation_context, applied_ptransform,
-               input_committed_bundle, side_inputs, scoped_metrics_container):
+               input_committed_bundle, side_inputs):
     assert not side_inputs
     super(_FlattenEvaluator, self).__init__(
         evaluation_context, applied_ptransform, input_committed_bundle,
-        side_inputs, scoped_metrics_container)
+        side_inputs)
 
   def start_bundle(self):
     assert len(self._outputs) == 1
@@ -534,11 +532,11 @@ class _ParDoEvaluator(_TransformEvaluator):
   """TransformEvaluator for ParDo transform."""
 
   def __init__(self, evaluation_context, applied_ptransform,
-               input_committed_bundle, side_inputs, scoped_metrics_container,
+               input_committed_bundle, side_inputs,
                perform_dofn_pickle_test=True):
     super(_ParDoEvaluator, self).__init__(
         evaluation_context, applied_ptransform, input_committed_bundle,
-        side_inputs, scoped_metrics_container)
+        side_inputs)
     # This is a workaround for SDF implementation. SDF implementation adds state
     # to the SDF that is not picklable.
     self._perform_dofn_pickle_test = perform_dofn_pickle_test
@@ -569,8 +567,7 @@ def start_bundle(self):
         self._applied_ptransform.inputs[0].windowing,
         tagged_receivers=self._tagged_receivers,
         step_name=self._applied_ptransform.full_label,
-        state=DoFnState(self._counter_factory),
-        scoped_metrics_container=self.scoped_metrics_container)
+        state=DoFnState(self._counter_factory))
     self.runner.start()
 
   def process_element(self, element):
@@ -592,11 +589,11 @@ class _GroupByKeyOnlyEvaluator(_TransformEvaluator):
   COMPLETION_TAG = _CombiningValueStateTag('completed', any)
 
   def __init__(self, evaluation_context, applied_ptransform,
-               input_committed_bundle, side_inputs, scoped_metrics_container):
+               input_committed_bundle, side_inputs):
     assert not side_inputs
     super(_GroupByKeyOnlyEvaluator, self).__init__(
         evaluation_context, applied_ptransform, input_committed_bundle,
-        side_inputs, scoped_metrics_container)
+        side_inputs)
 
   def _is_final_bundle(self):
     return (self._execution_context.watermarks.input_watermark
@@ -687,11 +684,11 @@ class _StreamingGroupByKeyOnlyEvaluator(_TransformEvaluator):
   MAX_ELEMENT_PER_BUNDLE = None
 
   def __init__(self, evaluation_context, applied_ptransform,
-               input_committed_bundle, side_inputs, scoped_metrics_container):
+               input_committed_bundle, side_inputs):
     assert not side_inputs
     super(_StreamingGroupByKeyOnlyEvaluator, self).__init__(
         evaluation_context, applied_ptransform, input_committed_bundle,
-        side_inputs, scoped_metrics_container)
+        side_inputs)
 
   def start_bundle(self):
     self.gbk_items = collections.defaultdict(list)
@@ -739,11 +736,11 @@ class _StreamingGroupAlsoByWindowEvaluator(_TransformEvaluator):
   """
 
   def __init__(self, evaluation_context, applied_ptransform,
-               input_committed_bundle, side_inputs, scoped_metrics_container):
+               input_committed_bundle, side_inputs):
     assert not side_inputs
     super(_StreamingGroupAlsoByWindowEvaluator, self).__init__(
         evaluation_context, applied_ptransform, input_committed_bundle,
-        side_inputs, scoped_metrics_container)
+        side_inputs)
 
   def start_bundle(self):
     assert len(self._outputs) == 1
@@ -798,11 +795,11 @@ class _NativeWriteEvaluator(_TransformEvaluator):
   ELEMENTS_TAG = _ListStateTag('elements')
 
   def __init__(self, evaluation_context, applied_ptransform,
-               input_committed_bundle, side_inputs, scoped_metrics_container):
+               input_committed_bundle, side_inputs):
     assert not side_inputs
     super(_NativeWriteEvaluator, self).__init__(
         evaluation_context, applied_ptransform, input_committed_bundle,
-        side_inputs, scoped_metrics_container)
+        side_inputs)
 
     assert applied_ptransform.transform.sink
     self._sink = applied_ptransform.transform.sink
@@ -867,10 +864,10 @@ class _ProcessElementsEvaluator(_TransformEvaluator):
   DEFAULT_MAX_DURATION = 1
 
   def __init__(self, evaluation_context, applied_ptransform,
-               input_committed_bundle, side_inputs, scoped_metrics_container):
+               input_committed_bundle, side_inputs):
     super(_ProcessElementsEvaluator, self).__init__(
         evaluation_context, applied_ptransform, input_committed_bundle,
-        side_inputs, scoped_metrics_container)
+        side_inputs)
 
     process_elements_transform = applied_ptransform.transform
     assert isinstance(process_elements_transform, ProcessElements)
@@ -895,7 +892,7 @@ def __init__(self, evaluation_context, applied_ptransform,
 
     self._par_do_evaluator = _ParDoEvaluator(
         evaluation_context, applied_ptransform, input_committed_bundle,
-        side_inputs, scoped_metrics_container, perform_dofn_pickle_test=False)
+        side_inputs, perform_dofn_pickle_test=False)
     self.keyed_holds = {}
 
   def start_bundle(self):
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 14b25a6035e..51ffac2728f 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -211,8 +211,7 @@ def only_element(iterable):
 
 
 class BundleProcessor(object):
-  """A class for processing bundles of elements.
-  """
+  """A class for processing bundles of elements."""
   def __init__(
       self, process_bundle_descriptor, state_handler, data_channel_factory):
     self.process_bundle_descriptor = process_bundle_descriptor
diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py
index 977d4bba095..aac56402db8 100644
--- a/sdks/python/apache_beam/runners/worker/operations.py
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -133,24 +133,27 @@ def __init__(self, name_context, spec, counter_factory, state_sampler):
 
     # These are overwritten in the legacy harness.
     self.metrics_container = MetricsContainer(self.name_context.metrics_name())
-    self.scoped_metrics_container = ScopedMetricsContainer(
-        self.metrics_container)
+    # TODO(BEAM-4094): Remove ScopedMetricsContainer after Dataflow no longer
+    # depends on it.
+    self.scoped_metrics_container = ScopedMetricsContainer()
 
     self.state_sampler = state_sampler
     self.scoped_start_state = self.state_sampler.scoped_state(
-        self.name_context.metrics_name(), 'start')
+        self.name_context.metrics_name(), 'start',
+        metrics_container=self.metrics_container)
     self.scoped_process_state = self.state_sampler.scoped_state(
-        self.name_context.metrics_name(), 'process')
+        self.name_context.metrics_name(), 'process',
+        metrics_container=self.metrics_container)
     self.scoped_finish_state = self.state_sampler.scoped_state(
-        self.name_context.metrics_name(), 'finish')
+        self.name_context.metrics_name(), 'finish',
+        metrics_container=self.metrics_container)
     # TODO(ccy): the '-abort' state can be added when the abort is supported in
     # Operations.
     self.receivers = []
 
   def start(self):
     """Start operation."""
-    self.debug_logging_enabled = logging.getLogger().isEnabledFor(
-        logging.DEBUG)
+    self.debug_logging_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)
     # Everything except WorkerSideInputSource, which is not a
     # top-level operation, should have output_coders
     #TODO(pabloem): Define better what step name is used here.
@@ -240,16 +243,15 @@ class ReadOperation(Operation):
 
   def start(self):
     with self.scoped_start_state:
-      with self.scoped_metrics_container:
-        super(ReadOperation, self).start()
-        range_tracker = self.spec.source.source.get_range_tracker(
-            self.spec.source.start_position, self.spec.source.stop_position)
-        for value in self.spec.source.source.read(range_tracker):
-          if isinstance(value, WindowedValue):
-            windowed_value = value
-          else:
-            windowed_value = _globally_windowed_value.with_value(value)
-          self.output(windowed_value)
+      super(ReadOperation, self).start()
+      range_tracker = self.spec.source.source.get_range_tracker(
+          self.spec.source.start_position, self.spec.source.stop_position)
+      for value in self.spec.source.source.read(range_tracker):
+        if isinstance(value, WindowedValue):
+          windowed_value = value
+        else:
+          windowed_value = _globally_windowed_value.with_value(value)
+        self.output(windowed_value)
 
 
 class InMemoryWriteOperation(Operation):
@@ -390,7 +392,7 @@ def start(self):
           logging_context=logger.PerThreadLoggingContext(
               step_name=self.name_context.logging_name()),
           state=state,
-          scoped_metrics_container=self.scoped_metrics_container)
+          scoped_metrics_container=None)
       self.dofn_receiver = (self.dofn_runner
                             if isinstance(self.dofn_runner, Receiver)
                             else DoFnRunnerReceiver(self.dofn_runner))
@@ -444,9 +446,8 @@ def process(self, o):
     if self.debug_logging_enabled:
       logging.debug('Processing [%s] in %s', o, self)
     key, values = o.value
-    with self.scoped_metrics_container:
-      self.output(
-          o.with_value((key, self.phased_combine_fn.apply(values))))
+    self.output(
+        o.with_value((key, self.phased_combine_fn.apply(values))))
 
 
 def create_pgbk_op(step_name, spec, counter_factory, state_sampler):
@@ -725,8 +726,6 @@ def execute(self):
 
     for ix, op in reversed(list(enumerate(self._ops))):
       logging.debug('Starting op %d %s', ix, op)
-      with op.scoped_metrics_container:
-        op.start()
+      op.start()
     for op in self._ops:
-      with op.scoped_metrics_container:
-        op.finish()
+      op.finish()
diff --git a/sdks/python/apache_beam/runners/worker/statesampler.py b/sdks/python/apache_beam/runners/worker/statesampler.py
index 03af644846d..d3980928ac7 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler.py
+++ b/sdks/python/apache_beam/runners/worker/statesampler.py
@@ -16,8 +16,10 @@
 #
 
 # This module is experimental. No backwards-compatibility guarantees.
+import threading
 from collections import namedtuple
 
+from apache_beam.metrics import execution
 from apache_beam.utils.counters import Counter
 from apache_beam.utils.counters import CounterName
 
@@ -29,6 +31,20 @@
   FAST_SAMPLER = False
 
 
+_STATE_SAMPLERS = threading.local()
+
+
+def set_current_tracker(tracker):
+  _STATE_SAMPLERS.tracker = tracker
+
+
+def get_current_tracker():
+  try:
+    return _STATE_SAMPLERS.tracker
+  except AttributeError:
+    return None
+
+
 StateSamplerInfo = namedtuple(
     'StateSamplerInfo',
     ['state_name', 'transition_count', 'time_since_transition'])
@@ -53,6 +69,12 @@ def stop_if_still_running(self):
     if self.started and not self.finished:
       self.stop()
 
+  def start(self):
+    set_current_tracker(self)
+    execution.metrics_startup()
+    super(StateSampler, self).start()
+    self.started = True
+
   def get_info(self):
     """Returns StateSamplerInfo with transition statistics."""
     return StateSamplerInfo(
@@ -60,7 +82,11 @@ def get_info(self):
         self.state_transition_count,
         self.time_since_transition)
 
-  def scoped_state(self, step_name, state_name, io_target=None):
+  def scoped_state(self,
+                   step_name,
+                   state_name,
+                   io_target=None,
+                   metrics_container=None):
     counter_name = CounterName(state_name + '-msecs',
                                stage_name=self._prefix,
                                step_name=step_name,
@@ -71,7 +97,9 @@ def scoped_state(self, step_name, state_name, io_target=None):
       output_counter = self._counter_factory.get_counter(counter_name,
                                                          Counter.SUM)
       self._states_by_name[counter_name] = super(
-          StateSampler, self)._scoped_state(counter_name, output_counter)
+          StateSampler, self)._scoped_state(counter_name,
+                                            output_counter,
+                                            metrics_container)
       return self._states_by_name[counter_name]
 
   def commit_counters(self):
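
The tracker registry added here is a plain threading.local, so each worker
thread sees only the sampler it installed. An equivalent sketch of the
get/set pair, using getattr instead of try/except:

    import threading

    _STATE_SAMPLERS = threading.local()

    def set_current_tracker(tracker):
      _STATE_SAMPLERS.tracker = tracker

    def get_current_tracker():
      # Threads that never called set_current_tracker() get None.
      return getattr(_STATE_SAMPLERS, 'tracker', None)
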
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
index 37d7c09da46..0fc58445f3b 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
+++ b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
@@ -37,6 +37,7 @@ runtime profile to be produced.
 import threading
 
 from apache_beam.utils.counters import CounterName
+from apache_beam.metrics.execution cimport MetricsContainer
 
 cimport cython
 from cpython cimport pythread
@@ -140,7 +141,6 @@ cdef class StateSampler(object):
 
   def start(self):
     assert not self.started
-    self.started = True
     self.sampling_thread = threading.Thread(target=self.run)
     self.sampling_thread.start()
 
@@ -156,17 +156,24 @@ cdef class StateSampler(object):
   def current_state(self):
     return self.scoped_states_by_index[self.current_state_index]
 
-  cpdef _scoped_state(self, counter_name, output_counter):
+  cpdef _scoped_state(self, counter_name, output_counter,
+                      metrics_container=None):
     """Returns a context manager managing transitions for a given state.
     Args:
-     TODO(pabloem)
+     counter_name: A CounterName object with information about the execution
+       state.
+     output_counter: A Beam Counter to which msecs are committed for reporting.
+     metrics_container: A MetricsContainer for the current step.
 
     Returns:
       A ScopedState for the set of step-state-io_target.
     """
     new_state_index = len(self.scoped_states_by_index)
-    scoped_state = ScopedState(self, counter_name,
-                               new_state_index, output_counter)
+    scoped_state = ScopedState(self,
+                               counter_name,
+                               new_state_index,
+                               output_counter,
+                               metrics_container)
     # Both scoped_states_by_index and scoped_state.nsecs are accessed
     # by the sampling thread; initialize them under the lock.
     pythread.PyThread_acquire_lock(self.lock, pythread.WAIT_LOCK)
@@ -185,12 +192,15 @@ cdef class ScopedState(object):
   cdef readonly object name
   cdef readonly int64_t _nsecs
   cdef int32_t old_state_index
+  cdef readonly MetricsContainer _metrics_container
 
-  def __init__(self, sampler, name, state_index, counter=None):
+  def __init__(
+      self, sampler, name, state_index, counter=None, metrics_container=None):
     self.sampler = sampler
     self.name = name
     self.state_index = state_index
     self.counter = counter
+    self._metrics_container = metrics_container
 
   @property
   def nsecs(self):
@@ -214,3 +224,7 @@ cdef class ScopedState(object):
     self.sampler.current_state_index = self.old_state_index
     self.sampler.state_transition_count += 1
     pythread.PyThread_release_lock(self.sampler.lock)
+
+  @property
+  def metrics_container(self):
+    return self._metrics_container
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_slow.py b/sdks/python/apache_beam/runners/worker/statesampler_slow.py
index dafe3b46887..59f84f7891f 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_slow.py
+++ b/sdks/python/apache_beam/runners/worker/statesampler_slow.py
@@ -28,11 +28,17 @@ def __init__(self, sampling_period_ms):
     self.finished = False
 
   def current_state(self):
-    """Returns the current execution state."""
+    """Returns the current execution state.
+
+    This operation is not thread safe, and should only be called from the
+    execution thread."""
     return self._state_stack[-1]
 
-  def _scoped_state(self, counter_name, output_counter):
-    return ScopedState(self, counter_name, output_counter)
+  def _scoped_state(self,
+                    counter_name,
+                    output_counter,
+                    metrics_container=None):
+    return ScopedState(self, counter_name, output_counter, metrics_container)
 
   def _enter_state(self, state):
     self.state_transition_count += 1
@@ -44,24 +50,20 @@ def _exit_state(self):
 
   def start(self):
     # Sampling not yet supported. Only state tracking at the moment.
-    self.started = True
+    pass
 
   def stop(self):
     self.finished = True
 
-  def get_info(self):
-    """Returns StateSamplerInfo with transition statistics."""
-    return StateSamplerInfo(
-        self.current_state().name, self.transition_count, 0)
-
 
 class ScopedState(object):
 
-  def __init__(self, sampler, name, counter=None):
+  def __init__(self, sampler, name, counter=None, metrics_container=None):
     self.state_sampler = sampler
     self.name = name
     self.counter = counter
     self.nsecs = 0
+    self.metrics_container = metrics_container
 
   def sampled_seconds(self):
     return 1e-9 * self.nsecs
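
In pure-Python terms (the slow sampler above), a ScopedState is a
stack-manipulating context manager that now also carries its step's
MetricsContainer. A simplified model, ignoring counters and sampling:

    class ScopedStateSketch(object):
      # Push on enter, pop on exit; expose the container so that
      # current_container() can find it via current_state().
      def __init__(self, stack, metrics_container=None):
        self._stack = stack
        self.metrics_container = metrics_container

      def __enter__(self):
        self._stack.append(self)
        return self

      def __exit__(self, exc_type, exc_value, traceback):
        self._stack.pop()
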
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index 61c2eaffd8f..1aeb5974df5 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -221,6 +221,7 @@ def __init__(self,
     self._clock = clock
     self._data = []
     self._ignore_next_timing = False
+
     self._size_distribution = Metrics.distribution(
         'BatchElements', 'batch_size')
     self._time_distribution = Metrics.distribution(


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Issue Time Tracking
-------------------

    Worklog Id:     (was: 94470)
    Time Spent: 13h 10m  (was: 13h)

> State tracking in Python is inefficient and has duplicated code
> ---------------------------------------------------------------
>
>                 Key: BEAM-2732
>                 URL: https://issues.apache.org/jira/browse/BEAM-2732
>             Project: Beam
>          Issue Type: Bug
>          Components: sdk-py-core
>            Reporter: Pablo Estrada
>            Assignee: Pablo Estrada
>            Priority: Major
>          Time Spent: 13h 10m
>  Remaining Estimate: 0h
>
> e.g. logging and metrics keep state separately. State tracking should be
> unified.



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)
