Repository: beam Updated Branches: refs/heads/master 697b19fe5 -> cc5f78dd0
Replace Any type with a KV type for inputs of a GroupByKey step Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/7fad7391 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/7fad7391 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/7fad7391 Branch: refs/heads/master Commit: 7fad7391548b6399066c0d9e79949707d7a5914e Parents: 697b19f Author: Vikas Kedigehalli <vika...@google.com> Authored: Sun Apr 2 20:46:55 2017 -0700 Committer: Ahmet Altay <al...@google.com> Committed: Wed Apr 5 17:21:14 2017 -0700 ---------------------------------------------------------------------- sdks/python/apache_beam/runners/runner.py | 29 ++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/7fad7391/sdks/python/apache_beam/runners/runner.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/runner.py b/sdks/python/apache_beam/runners/runner.py index b203c8b..528b03f 100644 --- a/sdks/python/apache_beam/runners/runner.py +++ b/sdks/python/apache_beam/runners/runner.py @@ -119,6 +119,35 @@ class PipelineRunner(object): logging.error('Error while visiting %s', transform_node.full_label) raise + class GroupByKeyInputVisitor(PipelineVisitor): + """A visitor that replaces `Any` element type for input `PCollection` of + a `GroupByKey` with a `KV` type. + + TODO(BEAM-115): Once Python SDK is compatible with the new Runner API, + we could directly replace the coder instead of mutating the element type. + """ + def visit_transform(self, transform_node): + # Imported here to avoid circular dependencies. 
+ # pylint: disable=wrong-import-order, wrong-import-position + from apache_beam import GroupByKey + from apache_beam import typehints + if isinstance(transform_node.transform, GroupByKey): + pcoll = transform_node.inputs[0] + input_type = pcoll.element_type + if not isinstance(input_type, typehints.TupleHint.TupleConstraint): + if isinstance(input_type, typehints.AnyTypeConstraint): + # `Any` type needs to be replaced with a KV[Any, Any] to + # force a KV coder as the main output coder for the pcollection + # preceding a GroupByKey. + pcoll.element_type = typehints.KV[typehints.Any, typehints.Any] + else: + # TODO: Handle other valid types, + # e.g. Union[KV[str, int], KV[str, float]] + raise ValueError( + "Input to GroupByKey must be of Tuple or Any type. " + "Found %s for %s" % (input_type, pcoll)) + + pipeline.visit(GroupByKeyInputVisitor()) pipeline.visit(RunVisitor(self)) def clear(self, pipeline, node=None):