Add special handling for indexable return types. Introduce a base IndexableTypeConstraint class.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/f364248b Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/f364248b Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/f364248b Branch: refs/heads/master Commit: f364248b984e8ff884a34b37d49a17fa1a50cc26 Parents: 5cc6b5f Author: Holden Karau <hol...@us.ibm.com> Authored: Fri Sep 1 20:51:18 2017 -0700 Committer: Robert Bradshaw <rober...@gmail.com> Committed: Thu Oct 12 15:50:08 2017 -0700 ---------------------------------------------------------------------- .../apache_beam/typehints/trivial_inference.py | 20 +++++++++++++- .../typehints/trivial_inference_test.py | 5 ++++ sdks/python/apache_beam/typehints/typehints.py | 28 +++++++++++++++++--- 3 files changed, 49 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/f364248b/sdks/python/apache_beam/typehints/trivial_inference.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/typehints/trivial_inference.py b/sdks/python/apache_beam/typehints/trivial_inference.py index 51d3db2..a68bd18 100644 --- a/sdks/python/apache_beam/typehints/trivial_inference.py +++ b/sdks/python/apache_beam/typehints/trivial_inference.py @@ -107,7 +107,10 @@ class FrameState(object): def __init__(self, f, local_vars=None, stack=()): self.f = f - self.co = f.__code__ + if sys.version_info[0] >= 3: + self.co = f.__code__ + else: + self.co = f.func_code self.vars = list(local_vars) self.stack = list(stack) @@ -362,7 +365,22 @@ def infer_return_type_func(f, input_types, debug=False, depth=0): else: return_type = Any state.stack[-pop_count:] = [return_type] + elif (opname == 'BINARY_SUBSCR' + and isinstance(state.stack[1], Const) + and isinstance(state.stack[0], typehints.IndexableTypeConstraint)): + if debug: + print("Executing special case binary subscript") + idx = 
state.stack.pop() + src = state.stack.pop() + try: + state.stack.append(src._constraint_for_index(idx.value)) + except Exception as e: + if debug: + print("Exception {0} during special case indexing".format(e)) + state.stack.append(Any) elif opname in simple_ops: + if debug: + print("Executing simple op " + opname) simple_ops[opname](state, arg) elif opname == 'RETURN_VALUE': returns.add(state.stack[-1]) http://git-wip-us.apache.org/repos/asf/beam/blob/f364248b/sdks/python/apache_beam/typehints/trivial_inference_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/typehints/trivial_inference_test.py b/sdks/python/apache_beam/typehints/trivial_inference_test.py index 8af9dd6..7b7b6a8 100644 --- a/sdks/python/apache_beam/typehints/trivial_inference_test.py +++ b/sdks/python/apache_beam/typehints/trivial_inference_test.py @@ -32,6 +32,11 @@ class TrivialInferenceTest(unittest.TestCase): def testIdentity(self): self.assertReturnType(int, lambda x: x, [int]) + def testIndexing(self): + self.assertReturnType(int, lambda x: x[0], [typehints.Tuple[int, str]]) + self.assertReturnType(str, lambda x: x[1], [typehints.Tuple[int, str]]) + self.assertReturnType(str, lambda x: x[1], [typehints.List[str]]) + def testTuples(self): self.assertReturnType( typehints.Tuple[typehints.Tuple[()], int], lambda x: ((), x), [int]) http://git-wip-us.apache.org/repos/asf/beam/blob/f364248b/sdks/python/apache_beam/typehints/typehints.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index a27dd7e..b78ead2 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -65,6 +65,7 @@ In addition, type-hints can be used to implement run-time type-checking via the import collections import copy +import sys import types __all__ = [ @@ -184,7 +185,17 @@ 
def bind_type_variables(type_constraint, bindings): return type_constraint -class SequenceTypeConstraint(TypeConstraint): +class IndexableTypeConstraint(TypeConstraint): + """An internal common base-class for all type constraints with indexing. + E.G. SequenceTypeConstraint + Tuple's of fixed size. + """ + + def _constraint_for_index(self, idx): + """Returns the type at the given index.""" + raise NotImplementedError + + +class SequenceTypeConstraint(IndexableTypeConstraint): """A common base-class for all sequence related type-constraint classes. A sequence is defined as an arbitrary length homogeneous container type. Type @@ -214,6 +225,10 @@ class SequenceTypeConstraint(TypeConstraint): def _inner_types(self): yield self.inner_type + def _constraint_for_index(self, idx): + """Returns the type at the given index.""" + return self.inner_type + def _consistent_with_check_(self, sub): return (isinstance(sub, self.__class__) and is_consistent_with(sub.inner_type, self.inner_type)) @@ -314,8 +329,11 @@ def validate_composite_type_param(type_param, error_msg_prefix): parameter for a :class:`CompositeTypeHint`. """ # Must either be a TypeConstraint instance or a basic Python type. 
+ possible_classes = [type, TypeConstraint] + if sys.version_info[0] == 2: + possible_classes.append(types.ClassType) is_not_type_constraint = ( - not isinstance(type_param, (type, TypeConstraint)) + not isinstance(type_param, tuple(possible_classes)) and type_param is not None) is_forbidden_type = (isinstance(type_param, type) and type_param in DISALLOWED_PRIMITIVE_TYPES) @@ -546,7 +564,7 @@ class TupleHint(CompositeTypeHint): for elem in sub.tuple_types) return super(TupleSequenceConstraint, self)._consistent_with_check_(sub) - class TupleConstraint(TypeConstraint): + class TupleConstraint(IndexableTypeConstraint): def __init__(self, type_params): self.tuple_types = tuple(type_params) @@ -566,6 +584,10 @@ class TupleHint(CompositeTypeHint): for t in self.tuple_types: yield t + def _constraint_for_index(self, idx): + """Returns the type at the given index.""" + return self.tuple_types[idx] + def _consistent_with_check_(self, sub): return (isinstance(sub, self.__class__) and len(sub.tuple_types) == len(self.tuple_types)