This is an automated email from the ASF dual-hosted git repository.

jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 81517a7fffd Add support for common collections.abc types (#25798)
81517a7fffd is described below

commit 81517a7fffd7c3dbc51accf3c56db8eed4e73339
Author: Jack McCluskey <34928439+jrmcclus...@users.noreply.github.com>
AuthorDate: Tue Mar 14 11:37:17 2023 -0400

    Add support for common collections.abc types (#25798)
    
    * Add support for common collections.abc types
    
    * Change routing to support 3.9 and 3.10
    
    * linting
    
    * more linting
    
    * Remove carryover print
---
 .../typehints/native_type_compatibility.py         | 23 +++++++++++++++++
 .../typehints/native_type_compatibility_test.py    | 29 ++++++++++++++++++++++
 2 files changed, 52 insertions(+)

diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py
index 0fa31bd70ee..f0ef1b2e2c3 100644
--- a/sdks/python/apache_beam/typehints/native_type_compatibility.py
+++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py
@@ -187,6 +187,25 @@ def convert_builtin_to_typing(typ):
   return typ
 
 
+def convert_collections_to_typing(typ):
+  """Converts a parameterized collections.abc type to a typing object.
+
+  Args:
+    typ: a collections.abc generic alias, e.g. collections.abc.Iterable[int].
+
+  Returns:
+    The matching typing object, or typ unchanged if no mapping applies.
+  """
+  args = getattr(typ, '__args__', None)
+  if not args or not hasattr(typ, '__iter__'):
+    return typ  # Bare classes and non-iterables (e.g. Callable) pass through.
+  if hasattr(typ, '__next__'):
+    return typing.Iterator[args[0]]  # Generator[Y, S, R] -> Iterator[Y].
+  if len(args) == 1:
+    return typing.Iterable[args[0]]
+  return typ  # Multi-arg iterables such as Mapping[K, V] are left as-is.
+
+
 def convert_to_beam_type(typ):
   """Convert a given typing type to a Beam type.
 
@@ -212,6 +231,10 @@ def convert_to_beam_type(typ):
   if sys.version_info >= (3, 9) and isinstance(typ, types.GenericAlias):
     typ = convert_builtin_to_typing(typ)
 
+  if sys.version_info >= (3, 9) and getattr(typ, '__module__',
+                                            None) == 'collections.abc':
+    typ = convert_collections_to_typing(typ)
+
   if isinstance(typ, typing.TypeVar):
     # This is a special case, as it's not parameterized by types.
     # Also, identity must be preserved through conversion (i.e. the same
diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py
index e7b69932a08..013fa354b82 100644
--- a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py
+++ b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py
@@ -19,6 +19,7 @@
 
 # pytype: skip-file
 
+import collections.abc
 import sys
 import typing
 import unittest
@@ -136,6 +137,34 @@ class NativeTypeCompatibilityTest(unittest.TestCase):
         converted_beam_type = convert_to_beam_type(builtins_type)
         self.assertEqual(converted_beam_type, expected_beam_type, description)
 
+  def test_convert_to_beam_type_with_collections_types(self):
+    # Subscripting collections.abc classes requires Python 3.9+ (PEP 585).
+    if sys.version_info < (3, 9):
+      return
+    cases = (
+        (
+            'collection iterable',
+            collections.abc.Iterable[int],
+            typehints.Iterable[int]),
+        (
+            'collection generator',
+            collections.abc.Generator[int],
+            typehints.Generator[int]),
+        (
+            'collection iterator',
+            collections.abc.Iterator[int],
+            typehints.Iterator[int]),
+        (
+            'nested iterable',
+            tuple[bytes, collections.abc.Iterable[int]],
+            typehints.Tuple[bytes, typehints.Iterable[int]]),
+    )
+
+    for description, collections_type, expected in cases:
+      # Each collections.abc hint must convert to the listed Beam hint.
+      actual = convert_to_beam_type(collections_type)
+      self.assertEqual(actual, expected, description)
+
   def test_convert_builtin_to_typing(self):
     if sys.version_info >= (3, 9):
       test_cases = [

Reply via email to