Repository: beam
Updated Branches:
  refs/heads/master 1761d1cab -> 7fd9c6516


Fix GroupByKeyInputVisitor for Direct Runner


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/9e453fab
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/9e453fab
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/9e453fab

Branch: refs/heads/master
Commit: 9e453fabe2bf448552ab5706130495e5ea4cf1c2
Parents: 1761d1c
Author: Vikas Kedigehalli <vika...@google.com>
Authored: Fri Apr 7 19:22:27 2017 -0700
Committer: Ahmet Altay <al...@google.com>
Committed: Mon Apr 10 15:28:00 2017 -0700

----------------------------------------------------------------------
 .../apache_beam/runners/direct/direct_runner.py | 18 ++---
 sdks/python/apache_beam/runners/runner.py       | 71 ++++++++++++--------
 sdks/python/apache_beam/runners/runner_test.py  | 41 +++++++++++
 3 files changed, 93 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/9e453fab/sdks/python/apache_beam/runners/direct/direct_runner.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py
index 1a5775f..9b4e1ac 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner.py
@@ -32,6 +32,7 @@ from apache_beam.runners.runner import PipelineResult
 from apache_beam.runners.runner import PipelineRunner
 from apache_beam.runners.runner import PipelineState
 from apache_beam.runners.runner import PValueCache
+from apache_beam.runners.runner import group_by_key_input_visitor
 from apache_beam.utils.pipeline_options import DirectOptions
 from apache_beam.utils.value_provider import RuntimeValueProvider
 
@@ -68,21 +69,22 @@ class DirectRunner(PipelineRunner):
 
     MetricsEnvironment.set_metrics_supported(True)
     logging.info('Running pipeline with DirectRunner.')
-    self.visitor = ConsumerTrackingPipelineVisitor()
-    pipeline.visit(self.visitor)
+    self.consumer_tracking_visitor = ConsumerTrackingPipelineVisitor()
+    pipeline.visit(group_by_key_input_visitor())
+    pipeline.visit(self.consumer_tracking_visitor)
 
     evaluation_context = EvaluationContext(
         pipeline.options,
         BundleFactory(stacked=pipeline.options.view_as(DirectOptions)
                       .direct_runner_use_stacked_bundle),
-        self.visitor.root_transforms,
-        self.visitor.value_to_consumers,
-        self.visitor.step_names,
-        self.visitor.views)
+        self.consumer_tracking_visitor.root_transforms,
+        self.consumer_tracking_visitor.value_to_consumers,
+        self.consumer_tracking_visitor.step_names,
+        self.consumer_tracking_visitor.views)
 
     evaluation_context.use_pvalue_cache(self._cache)
 
-    executor = Executor(self.visitor.value_to_consumers,
+    executor = Executor(self.consumer_tracking_visitor.value_to_consumers,
                         TransformEvaluatorRegistry(evaluation_context),
                         evaluation_context)
     # Start the executor. This is a non-blocking call, it will start the
@@ -90,7 +92,7 @@ class DirectRunner(PipelineRunner):
 
     if pipeline.options:
       RuntimeValueProvider.set_runtime_options(pipeline.options._options_id, 
{})
-    executor.start(self.visitor.root_transforms)
+    executor.start(self.consumer_tracking_visitor.root_transforms)
     result = DirectPipelineResult(executor, evaluation_context)
 
     if self._cache:

http://git-wip-us.apache.org/repos/asf/beam/blob/9e453fab/sdks/python/apache_beam/runners/runner.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/runner.py b/sdks/python/apache_beam/runners/runner.py
index 528b03f..de9c892 100644
--- a/sdks/python/apache_beam/runners/runner.py
+++ b/sdks/python/apache_beam/runners/runner.py
@@ -86,6 +86,47 @@ def create_runner(runner_name):
             runner_name, ', '.join(_ALL_KNOWN_RUNNERS)))
 
 
+def group_by_key_input_visitor():
+  # Imported here to avoid circular dependencies.
+  from apache_beam.pipeline import PipelineVisitor
+  # Defined inside the factory because its base class is imported lazily.
+  class GroupByKeyInputVisitor(PipelineVisitor):
+    """A visitor that replaces an `Any` (or unset) element type on the input
+    `PCollection` of a `GroupByKey`/`GroupByKeyOnly` with `KV[Any, Any]`.
+
+    TODO(BEAM-115): Once the Python SDK is compatible with the new Runner API,
+    we could directly replace the coder instead of mutating the element type.
+    """
+
+    def visit_transform(self, transform_node):
+      # Imported here to avoid circular dependencies.
+      # pylint: disable=wrong-import-order, wrong-import-position
+      from apache_beam import GroupByKey, GroupByKeyOnly
+      from apache_beam import typehints
+      if (isinstance(transform_node.transform, GroupByKey) or
+          isinstance(transform_node.transform, GroupByKeyOnly)):
+        pcoll = transform_node.inputs[0]
+        input_type = pcoll.element_type
+        # An unset (falsy) element type is treated as `Any`.
+        if not input_type:
+          input_type = typehints.Any
+        # Tuple-typed inputs (e.g. KV[...]) are left unchanged.
+        if not isinstance(input_type, typehints.TupleHint.TupleConstraint):
+          if isinstance(input_type, typehints.AnyTypeConstraint):
+            # `Any` type needs to be replaced with a KV[Any, Any] to
+            # force a KV coder as the main output coder for the PCollection
+            # preceding a GroupByKey.
+            pcoll.element_type = typehints.KV[typehints.Any, typehints.Any]
+          else:
+            # TODO: Handle other valid types,
+            # e.g. Union[KV[str, int], KV[str, float]]
+            raise ValueError(
+                "Input to GroupByKey must be of Tuple or Any type. "
+                "Found %s for %s" % (input_type, pcoll))
+
+  return GroupByKeyInputVisitor()
+
+
 class PipelineRunner(object):
   """A runner of a pipeline object.
 
@@ -119,35 +160,7 @@ class PipelineRunner(object):
           logging.error('Error while visiting %s', transform_node.full_label)
           raise
 
-    class GroupByKeyInputVisitor(PipelineVisitor):
-      """A visitor that replaces `Any` element type for input `PCollection` of
-      a `GroupByKey` with a `KV` type.
-
-      TODO(BEAM-115): Once Python SDk is compatible with the new Runner API,
-      we could directly replace the coder instead of mutating the element type.
-      """
-      def visit_transform(self, transform_node):
-        # Imported here to avoid circular dependencies.
-        # pylint: disable=wrong-import-order, wrong-import-position
-        from apache_beam import GroupByKey
-        from apache_beam import typehints
-        if isinstance(transform_node.transform, GroupByKey):
-          pcoll = transform_node.inputs[0]
-          input_type = pcoll.element_type
-          if not isinstance(input_type, typehints.TupleHint.TupleConstraint):
-            if isinstance(input_type, typehints.AnyTypeConstraint):
-              # `Any` type needs to be replaced with a KV[Any, Any] to
-              # force a KV coder as the main output coder for the pcollection
-              # preceding a GroupByKey.
-              pcoll.element_type = typehints.KV[typehints.Any, typehints.Any]
-            else:
-              # TODO: Handle other valid types,
-              # e.g. Union[KV[str, int], KV[str, float]]
-              raise ValueError(
-                  "Input to GroupByKey must be of Tuple or Any type. "
-                  "Found %s for %s" % (input_type, pcoll))
-
-    pipeline.visit(GroupByKeyInputVisitor())
+    pipeline.visit(group_by_key_input_visitor())
     pipeline.visit(RunVisitor(self))
 
   def clear(self, pipeline, node=None):

http://git-wip-us.apache.org/repos/asf/beam/blob/9e453fab/sdks/python/apache_beam/runners/runner_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/runner_test.py b/sdks/python/apache_beam/runners/runner_test.py
index b161cbb..0bebd66 100644
--- a/sdks/python/apache_beam/runners/runner_test.py
+++ b/sdks/python/apache_beam/runners/runner_test.py
@@ -28,14 +28,19 @@ import hamcrest as hc
 
 import apache_beam as beam
 import apache_beam.transforms as ptransform
+from apache_beam import typehints
 from apache_beam.metrics.cells import DistributionData
 from apache_beam.metrics.cells import DistributionResult
 from apache_beam.metrics.execution import MetricKey
 from apache_beam.metrics.execution import MetricResult
 from apache_beam.metrics.metricbase import MetricName
+from apache_beam.pipeline import AppliedPTransform
 from apache_beam.pipeline import Pipeline
+from apache_beam.pvalue import PCollection
 from apache_beam.runners import DirectRunner
+from apache_beam.runners import runner
 from apache_beam.runners import create_runner
+from apache_beam.test_pipeline import TestPipeline
 from apache_beam.transforms.util import assert_that
 from apache_beam.transforms.util import equal_to
 from apache_beam.utils.pipeline_options import PipelineOptions
@@ -118,6 +123,42 @@ class RunnerTest(unittest.TestCase):
                 DistributionResult(DistributionData(15, 5, 1, 5)),
                 DistributionResult(DistributionData(15, 5, 1, 5)))))
 
+  def test_group_by_key_input_visitor_with_valid_inputs(self):
+    p = TestPipeline()  # only used as the owner of the PCollections
+    pcoll1 = PCollection(p)
+    pcoll2 = PCollection(p)
+    pcoll3 = PCollection(p)
+    for transform in [beam.GroupByKeyOnly(), beam.GroupByKey()]:
+      pcoll1.element_type = None  # unset; reset each pass (visitor mutates)
+      pcoll2.element_type = typehints.Any  # also rewritten to KV[Any, Any]
+      pcoll3.element_type = typehints.KV[typehints.Any, typehints.Any]
+      for pcoll in [pcoll1, pcoll2, pcoll3]:
+        runner.group_by_key_input_visitor().visit_transform(
+            AppliedPTransform(None, transform, "label", [pcoll]))
+        self.assertEqual(pcoll.element_type,
+                         typehints.KV[typehints.Any, typehints.Any])
+
+  def test_group_by_key_input_visitor_with_invalid_inputs(self):
+    p = TestPipeline()
+    pcoll1 = PCollection(p)
+    pcoll2 = PCollection(p)
+    for transform in [beam.GroupByKeyOnly(), beam.GroupByKey()]:
+      pcoll1.element_type = typehints.TupleSequenceConstraint
+      pcoll2.element_type = typehints.Set  # neither a Tuple nor Any hint
+      err_msg = "Input to GroupByKey must be of Tuple or Any type"
+      for pcoll in [pcoll1, pcoll2]:  # each must be rejected
+        with self.assertRaisesRegexp(ValueError, err_msg):
+          runner.group_by_key_input_visitor().visit_transform(
+              AppliedPTransform(None, transform, "label", [pcoll]))
+
+  def test_group_by_key_input_visitor_for_non_gbk_transforms(self):
+    p = TestPipeline()
+    pcoll = PCollection(p)
+    for transform in [beam.Flatten(), beam.Map(lambda x: x)]:  # not GBKs
+      pcoll.element_type = typehints.Any
+      runner.group_by_key_input_visitor().visit_transform(
+          AppliedPTransform(None, transform, "label", [pcoll]))
+      self.assertEqual(pcoll.element_type, typehints.Any)  # left untouched
 
 if __name__ == '__main__':
   unittest.main()

Reply via email to