udim commented on a change in pull request #12009:
URL: https://github.com/apache/beam/pull/12009#discussion_r445246250



##########
File path: sdks/python/apache_beam/typehints/typed_pipeline_test_py3.py
##########
@@ -40,6 +40,14 @@ def process(self, element: int) -> typehints.Tuple[str]:
     with self.assertRaisesRegex(typehints.TypeCheckError,
                                 r'requires.*int.*got.*str'):
       _ = ['a', 'b', 'c'] | beam.ParDo(MyDoFn())
+  def test_pardo_dofn(self):

Review comment:
       Did you mean to leave this test here? It looks like a copy of the one in 
AnnotationsTest.

##########
File path: sdks/python/apache_beam/typehints/typehints_test_py3.py
##########
@@ -46,11 +51,61 @@ class MyDoFn(DoFn):
       def process(self, element: int) -> Iterable[str]:
         pass
 
-    print(MyDoFn().get_type_hints())
     th = MyDoFn().get_type_hints()
     self.assertEqual(th.input_types, ((int, ), {}))
     self.assertEqual(th.output_types, ((str, ), {}))
 
 
+class TestPTransformAnnotations(unittest.TestCase):
+  def test_pep484_annotations(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll: PCollection[int]) -> PCollection[str]:
+        return pcoll | Map(lambda num: str(num))
+
+    th = MyPTransform().get_type_hints()
+    self.assertEqual(th.input_types, ((int, ), {}))
+    self.assertEqual(th.output_types, ((str, ), {}))
+
+  def test_annotations_without_pcollection_wrapper(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll: int) -> str:
+        return pcoll | Map(lambda num: str(num))
+
+    with self.assertRaises(TypeCheckError) as error:
+      _th = MyPTransform().get_type_hints()
+
+    self.assertEqual(str(error.exception), 'An input typehint to a PTransform 
must be a single (or nested) type '
+                                           'wrapped by a PCollection.')
+
+  def test_annotations_without_internal_type(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll: PCollection) -> PCollection:

Review comment:
       This is valid. The type hint should convert to `Any`.
   
   Quoting from https://docs.python.org/3/library/typing.html:
   > Using a generic class without specifying type parameters assumes Any for 
each position.

##########
File path: sdks/python/apache_beam/typehints/typed_pipeline_test_py3.py
##########
@@ -257,6 +265,65 @@ def fn2(element: int) -> int:
     result = [1, 2, 3] | beam.FlatMap(fn) | beam.Map(fn2)
     self.assertCountEqual([4, 6], result)
 
+  def test_typed_ptransform_with_no_error(self):
+    class StrToInt(beam.PTransform):
+      def expand(self, pcoll: beam.pvalue.PCollection[str]) -> 
beam.pvalue.PCollection[int]:
+        return pcoll | beam.Map(lambda x: int(x))
+
+    class IntToStr(beam.PTransform):
+      def expand(self, pcoll: beam.pvalue.PCollection[int]) -> 
beam.pvalue.PCollection[str]:
+        return pcoll | beam.Map(lambda x: str(x))
+
+    try:
+      _ = ['1', '2', '3'] | StrToInt() | IntToStr()
+    except Exception:
+      self.fail('An unexpected error was raised during a pipeline with correct 
typehints.')
+
+  def test_typed_ptransform_with_bad_typehints(self):
+    class StrToInt(beam.PTransform):
+      def expand(self, pcoll: beam.pvalue.PCollection[str]) -> 
beam.pvalue.PCollection[int]:
+        return pcoll | beam.Map(lambda x: int(x))
+
+    class IntToStr(beam.PTransform):
+      def expand(self, pcoll: beam.pvalue.PCollection[str]) -> 
beam.pvalue.PCollection[str]:
+        return pcoll | beam.Map(lambda x: str(x))
+
+    with self.assertRaises(typehints.TypeCheckError) as error:
+      # raises error because of mismatched typehints between StrToInt and 
IntToStr
+      _ = ['1', '2', '3'] | StrToInt() | IntToStr()
+
+    self.assertTrue("Input type hint violation at IntToStr: expected <class 
'str'>, got <class 'int'>" in str(error.exception))
+
+  def test_typed_ptransform_with_bad_input(self):
+    class StrToInt(beam.PTransform):
+      def expand(self, pcoll: beam.pvalue.PCollection[str]) -> 
beam.pvalue.PCollection[int]:
+        return pcoll | beam.Map(lambda x: int(x))
+
+    class IntToStr(beam.PTransform):
+      def expand(self, pcoll: beam.pvalue.PCollection[int]) -> 
beam.pvalue.PCollection[str]:
+        return pcoll | beam.Map(lambda x: str(x))
+
+    with self.assertRaises(typehints.TypeCheckError) as error:
+      # Feed integers to a PTransform that expects strings
+      _ = [1, 2, 3] | StrToInt() | IntToStr()
+
+    self.assertTrue("Input type hint violation at StrToInt: expected <class 
'str'>, got <class 'int'>" in str(error.exception))

Review comment:
       Please use `with self.assertRaisesRegex(..)` above instead of separately 
checking the exception text.

##########
File path: sdks/python/apache_beam/transforms/ptransform.py
##########
@@ -364,6 +366,15 @@ def default_label(self):
     # type: () -> str
     return self.__class__.__name__
 
+  def default_type_hints(self):
+    fn_type_hints = IOTypeHints.from_callable(self.expand)
+    if fn_type_hints is not None:
+      fn_type_hints = fn_type_hints.strip_pcoll_input()
+      fn_type_hints = fn_type_hints.strip_pcoll_output()

Review comment:
       You can chain the 2 function calls.

##########
File path: sdks/python/apache_beam/typehints/decorators.py
##########
@@ -378,6 +378,43 @@ def has_simple_output_type(self):
         self.output_types and len(self.output_types[0]) == 1 and
         not self.output_types[1])
 
+  def strip_pcoll_input(self):
+    # type: () -> IOTypeHints
+
+    input_type = self.input_types[0][0]
+    if isinstance(input_type, typehints.AnyTypeConstraint):

Review comment:
       Also verify that input_type is a PCollection or PBegin,
   and that the output type is a PCollection or PDone.

##########
File path: sdks/python/apache_beam/typehints/typed_pipeline_test_py3.py
##########
@@ -40,6 +40,14 @@ def process(self, element: int) -> typehints.Tuple[str]:
     with self.assertRaisesRegex(typehints.TypeCheckError,
                                 r'requires.*int.*got.*str'):
       _ = ['a', 'b', 'c'] | beam.ParDo(MyDoFn())
+  def test_pardo_dofn(self):

Review comment:
       Or perhaps this was also the result of a bad git merge?

##########
File path: sdks/python/apache_beam/typehints/typehints_test_py3.py
##########
@@ -46,11 +51,61 @@ class MyDoFn(DoFn):
       def process(self, element: int) -> Iterable[str]:
         pass
 
-    print(MyDoFn().get_type_hints())
     th = MyDoFn().get_type_hints()
     self.assertEqual(th.input_types, ((int, ), {}))
     self.assertEqual(th.output_types, ((str, ), {}))
 
 
+class TestPTransformAnnotations(unittest.TestCase):
+  def test_pep484_annotations(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll: PCollection[int]) -> PCollection[str]:
+        return pcoll | Map(lambda num: str(num))
+
+    th = MyPTransform().get_type_hints()
+    self.assertEqual(th.input_types, ((int, ), {}))
+    self.assertEqual(th.output_types, ((str, ), {}))
+
+  def test_annotations_without_pcollection_wrapper(self):
+    class MyPTransform(PTransform):
+      def expand(self, pcoll: int) -> str:
+        return pcoll | Map(lambda num: str(num))
+
+    with self.assertRaises(TypeCheckError) as error:
+      _th = MyPTransform().get_type_hints()
+
+    self.assertEqual(str(error.exception), 'An input typehint to a PTransform 
must be a single (or nested) type '

Review comment:
       Also test when the output typehint is unsupported.

##########
File path: sdks/python/apache_beam/typehints/decorators.py
##########
@@ -378,6 +378,43 @@ def has_simple_output_type(self):
         self.output_types and len(self.output_types[0]) == 1 and
         not self.output_types[1])
 
+  def strip_pcoll_input(self):
+    # type: () -> IOTypeHints
+
+    input_type = self.input_types[0][0]
+    if isinstance(input_type, typehints.AnyTypeConstraint):
+      return self
+
+    try:
+      input_type = input_type.__args__[0]
+    except:

Review comment:
       As a general rule, don't catch all exceptions; catch only the ones you 
expect to be raised. A bare `except:` also swallows KeyboardInterrupt and 
SystemExit.

##########
File path: sdks/python/apache_beam/transforms/ptransform.py
##########
@@ -616,24 +627,14 @@ def register_urn(
     # type: (...) -> Callable[[Union[type, 
Callable[[beam_runner_api_pb2.PTransform, T, PipelineContext], Any]]], 
Callable[[T, PipelineContext], Any]]
     pass
 
-  @classmethod
-  @overload
-  def register_urn(
-      cls,
-      urn,  # type: str
-      parameter_type,  # type: None
-  ):
-    # type: (...) -> Callable[[Union[type, 
Callable[[beam_runner_api_pb2.PTransform, bytes, PipelineContext], Any]]], 
Callable[[bytes, PipelineContext], Any]]
-    pass
-
   @classmethod
   @overload
   def register_urn(cls,
                    urn,  # type: str
                    parameter_type,  # type: Type[T]
                    constructor  # type: 
Callable[[beam_runner_api_pb2.PTransform, T, PipelineContext], Any]
                   ):
-    # type: (...) -> None
+    # type: (...) -> Callable[[Union[type, 
Callable[[beam_runner_api_pb2.PTransform, bytes, PipelineContext], Any]]], 
Callable[[bytes, PipelineContext], Any]]

Review comment:
       Bad merge?

##########
File path: sdks/python/apache_beam/typehints/decorators.py
##########
@@ -378,6 +378,43 @@ def has_simple_output_type(self):
         self.output_types and len(self.output_types[0]) == 1 and
         not self.output_types[1])
 
+  def strip_pcoll_input(self):
+    # type: () -> IOTypeHints
+
+    input_type = self.input_types[0][0]

Review comment:
       Please handle cases where self.input_types is None or the number of 
arguments is not 1.
   

##########
File path: sdks/python/apache_beam/typehints/typed_pipeline_test_py3.py
##########
@@ -257,6 +265,65 @@ def fn2(element: int) -> int:
     result = [1, 2, 3] | beam.FlatMap(fn) | beam.Map(fn2)
     self.assertCountEqual([4, 6], result)
 
+  def test_typed_ptransform_with_no_error(self):
+    class StrToInt(beam.PTransform):
+      def expand(self, pcoll: beam.pvalue.PCollection[str]) -> 
beam.pvalue.PCollection[int]:
+        return pcoll | beam.Map(lambda x: int(x))
+
+    class IntToStr(beam.PTransform):
+      def expand(self, pcoll: beam.pvalue.PCollection[int]) -> 
beam.pvalue.PCollection[str]:
+        return pcoll | beam.Map(lambda x: str(x))
+
+    try:
+      _ = ['1', '2', '3'] | StrToInt() | IntToStr()
+    except Exception:
+      self.fail('An unexpected error was raised during a pipeline with correct 
typehints.')

Review comment:
       There is no need to assert that no exceptions are raised. The test will 
already fail if exceptions are raised.
   
   If I need to be explicit, I usually put a comment above the line that 
shouldn't fail.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to