This is an automated email from the ASF dual-hosted git repository.

yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new c05733a3a4d Handle some cases during infer schema from dataclass 
(#37855)
c05733a3a4d is described below

commit c05733a3a4dfdec99af6ad68c1c4187f595b5b43
Author: Yi Hu <[email protected]>
AuthorDate: Fri Mar 20 10:27:09 2026 -0400

    Handle some cases during infer schema from dataclass (#37855)
    
    * Handle some cases during infer schema from dataclass
    
    * For backward compatibility, only infer schemas for frozen dataclasses
      when they are explicitly registered with RowCoder
    
    * Make sure Beam schema ID does not inherit
    
    * Fix an IndexError when inferring the schema type of a custom Iterable
      or Sequence that has no type arguments
    
    * Fix #37862: NamedTuples inside a Union typehint now work; dataclasses
      inside a Union typehint still fail
    
    * Allow non-frozen dataclasses to be registered with other coders as a
      fallback for backward compatibility; add tests
    
    * Add upgrade compatibility check for potential coder change
    
    * Update CHANGES.md
    
    Co-authored-by: gemini-code-assist[bot] 
<176961590+gemini-code-assist[bot]@users.noreply.github.com>
    
    * Test case for update_compatibility_version
    
    ---------
    
    Co-authored-by: gemini-code-assist[bot] 
<176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 CHANGES.md                                         |  1 +
 sdks/python/apache_beam/coders/coder_impl.py       |  2 +-
 .../typehints/native_type_compatibility.py         | 46 ++++++++++-
 .../typehints/native_type_compatibility_test.py    | 59 ++++++++++++++
 sdks/python/apache_beam/typehints/row_type.py      |  7 +-
 sdks/python/apache_beam/typehints/row_type_test.py | 89 ++++++++++++++++++++++
 sdks/python/apache_beam/typehints/schemas.py       | 33 +++++---
 7 files changed, 222 insertions(+), 15 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 4072a66e07c..e7f7283b7ef 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -70,6 +70,7 @@
 ## New Features / Improvements
 
 * Added support for large pipeline options via a file (Python) 
([#37370](https://github.com/apache/beam/issues/37370)).
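+* Added support for inferring schemas from dataclasses (Python)
([#22085](https://github.com/apache/beam/issues/22085)). The default coder for
non-frozen dataclasses used as type hints (including via with_output_types)
changed to RowCoder. To preserve the old behavior, explicitly register the
type with FastPrimitivesCoder, or set update_compatibility_version to a
version earlier than 2.73.0. Frozen dataclasses keep their previous coder
unless explicitly registered with RowCoder.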
 
 ## Breaking Changes
 
diff --git a/sdks/python/apache_beam/coders/coder_impl.py 
b/sdks/python/apache_beam/coders/coder_impl.py
index b3e45bc7f35..1270b98f9bc 100644
--- a/sdks/python/apache_beam/coders/coder_impl.py
+++ b/sdks/python/apache_beam/coders/coder_impl.py
@@ -493,7 +493,7 @@ class FastPrimitivesCoderImpl(StreamCoderImpl):
       stream.write_byte(PROTO_TYPE)
       self.encode_type(type(value), stream)
       stream.write(value.SerializePartialToString(deterministic=True), True)
-    elif dataclasses and dataclasses.is_dataclass(value):
+    elif dataclasses.is_dataclass(value):
       if not type(value).__dataclass_params__.frozen:
         raise TypeError(
             "Unable to deterministically encode non-frozen '%s' of type '%s' "
diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py 
b/sdks/python/apache_beam/typehints/native_type_compatibility.py
index 886b1505ffe..a8a3ad29325 100644
--- a/sdks/python/apache_beam/typehints/native_type_compatibility.py
+++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py
@@ -176,8 +176,50 @@ def match_is_named_tuple(user_type):
       hasattr(user_type, '__annotations__') and hasattr(user_type, '_fields'))
 
 
-def match_is_dataclass(user_type):
-  return dataclasses.is_dataclass(user_type) and isinstance(user_type, type)
+def match_dataclass_for_row(user_type):
+  """Match whether the type is a dataclass handled by row coder.
+
+  For frozen dataclasses, only true when explicitly registered with row coder:
+
+    beam.coders.typecoders.registry.register_coder(
+        MyDataClass, beam.coders.RowCoder)
+
+  (for backward-compatibility reason).
+
+  For non-frozen dataclasses, default to true otherwise explicitly registered
+  with a coder other than the row coder.
+  """
+
+  if not dataclasses.is_dataclass(user_type):
+    return False
+
+  # pylint: disable=wrong-import-position
+  try:
+    from apache_beam.options.pipeline_options_context import 
get_pipeline_options  # pylint: disable=line-too-long
+  except AttributeError:
+    pass
+  else:
+    opts = get_pipeline_options()
+    if opts and opts.is_compat_version_prior_to("2.73.0"):
+      return False
+
+  is_frozen = user_type.__dataclass_params__.frozen
+  # avoid circular import
+  try:
+    from apache_beam.coders.typecoders import registry as coders_registry
+    from apache_beam.coders import RowCoder
+  except AttributeError:
+    # coder registery not yet initialized so it must be absent
+    return not is_frozen
+
+  if is_frozen:
+    return (
+        user_type in coders_registry._coders and
+        coders_registry._coders[user_type] == RowCoder)
+  else:
+    return (
+        user_type not in coders_registry._coders or
+        coders_registry._coders[user_type] == RowCoder)
 
 
 def _match_is_optional(user_type):
diff --git 
a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py 
b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py
index e9ce732d2e9..01f40a29945 100644
--- a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py
+++ b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py
@@ -20,11 +20,17 @@
 # pytype: skip-file
 
 import collections.abc
+import dataclasses
 import enum
 import re
 import typing
 import unittest
 
+from parameterized import param
+from parameterized import parameterized
+
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options_context import 
scoped_pipeline_options
 from apache_beam.typehints import typehints
 from apache_beam.typehints.native_type_compatibility import 
convert_builtin_to_typing
 from apache_beam.typehints.native_type_compatibility import 
convert_to_beam_type
@@ -33,6 +39,7 @@ from apache_beam.typehints.native_type_compatibility import 
convert_to_python_ty
 from apache_beam.typehints.native_type_compatibility import 
convert_to_python_types
 from apache_beam.typehints.native_type_compatibility import 
convert_typing_to_builtin
 from apache_beam.typehints.native_type_compatibility import is_any
+from apache_beam.typehints.native_type_compatibility import 
match_dataclass_for_row
 
 _TestNamedTuple = typing.NamedTuple(
     '_TestNamedTuple', [('age', int), ('name', bytes)])
@@ -509,6 +516,58 @@ class NativeTypeCompatibilityTest(unittest.TestCase):
     self.assertEqual(
         typehints.Tuple[int, ...], convert_to_beam_type(AliasTuple))
 
+  def test_dataclass_default(self):
+    @dataclasses.dataclass(frozen=True)
+    class FrozenDC:
+      foo: int
+
+    @dataclasses.dataclass
+    class NonFrozenDC:
+      foo: int
+
+    self.assertFalse(match_dataclass_for_row(FrozenDC))
+    self.assertTrue(match_dataclass_for_row(NonFrozenDC))
+
+  def test_dataclass_registered(self):
+    @dataclasses.dataclass(frozen=True)
+    class FrozenRegisteredDC:
+      foo: int
+
+    @dataclasses.dataclass
+    class NonFrozenRegisteredDC:
+      foo: int
+
+    # pylint: disable=wrong-import-position
+    from apache_beam.coders import RowCoder
+    from apache_beam.coders import typecoders
+    from apache_beam.coders.coders import FastPrimitivesCoder
+
+    typecoders.registry.register_coder(FrozenRegisteredDC, RowCoder)
+    typecoders.registry.register_coder(
+        NonFrozenRegisteredDC, FastPrimitivesCoder)
+
+    self.assertTrue(match_dataclass_for_row(FrozenRegisteredDC))
+    self.assertFalse(match_dataclass_for_row(NonFrozenRegisteredDC))
+
+  @parameterized.expand([
+      param(compat_version="2.72.0"),
+      param(compat_version="2.73.0"),
+  ])
+  def test_dataclass_update_compatibility(self, compat_version):
+    @dataclasses.dataclass(frozen=True)
+    class FrozenDC:
+      foo: int
+
+    @dataclasses.dataclass
+    class NonFrozenDC:
+      foo: int
+
+    with scoped_pipeline_options(
+        PipelineOptions(update_compatibility_version=compat_version)):
+      self.assertFalse(match_dataclass_for_row(FrozenDC))
+      self.assertEqual(
+          compat_version == "2.73.0", match_dataclass_for_row(NonFrozenDC))
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/typehints/row_type.py 
b/sdks/python/apache_beam/typehints/row_type.py
index 6f96f6f64e3..0697581cb43 100644
--- a/sdks/python/apache_beam/typehints/row_type.py
+++ b/sdks/python/apache_beam/typehints/row_type.py
@@ -27,7 +27,7 @@ from typing import Sequence
 from typing import Tuple
 
 from apache_beam.typehints import typehints
-from apache_beam.typehints.native_type_compatibility import match_is_dataclass
+from apache_beam.typehints.native_type_compatibility import 
match_dataclass_for_row
 from apache_beam.typehints.native_type_compatibility import 
match_is_named_tuple
 from apache_beam.typehints.schema_registry import SchemaTypeRegistry
 
@@ -91,6 +91,9 @@ class RowTypeConstraint(typehints.TypeConstraint):
     # Currently registration happens when converting to schema protos, in
     # apache_beam.typehints.schemas
     self._schema_id = getattr(self._user_type, _BEAM_SCHEMA_ID, None)
+    if self._schema_id and _BEAM_SCHEMA_ID not in self._user_type.__dict__:
+      # schema id does not inherit. Unset if schema id is from base class
+      self._schema_id = None
 
     self._schema_options = schema_options or []
     self._field_options = field_options or {}
@@ -105,7 +108,7 @@ class RowTypeConstraint(typehints.TypeConstraint):
     if match_is_named_tuple(user_type):
       fields = [(name, user_type.__annotations__[name])
                 for name in user_type._fields]
-    elif match_is_dataclass(user_type):
+    elif match_dataclass_for_row(user_type):
       fields = [(field.name, field.type)
                 for field in dataclasses.fields(user_type)]
     else:
diff --git a/sdks/python/apache_beam/typehints/row_type_test.py 
b/sdks/python/apache_beam/typehints/row_type_test.py
index 73d76fee49c..97012d9561d 100644
--- a/sdks/python/apache_beam/typehints/row_type_test.py
+++ b/sdks/python/apache_beam/typehints/row_type_test.py
@@ -26,6 +26,7 @@ from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 from apache_beam.typehints import row_type
+from apache_beam.typehints import schemas
 
 
 class RowTypeTest(unittest.TestCase):
@@ -85,6 +86,94 @@ class RowTypeTest(unittest.TestCase):
           | 'Count Elements' >> beam.Map(self._check_key_type_and_count))
       assert_that(result, equal_to([10] * 100))
 
+  def test_group_by_key_namedtuple_union(self):
+    Tuple1 = typing.NamedTuple("Tuple1", [("id", int)])
+
+    Tuple2 = typing.NamedTuple("Tuple2", [("id", int), ("name", str)])
+
+    def generate(num: int):
+      for i in range(2):
+        yield (Tuple1(i), num)
+        yield (Tuple2(i, 'a'), num)
+
+    pipeline = TestPipeline(is_integration_test=False)
+
+    with pipeline as p:
+      result = (
+          p
+          | 'Create' >> beam.Create([i for i in range(2)])
+          | 'Generate' >> beam.ParDo(generate).with_output_types(
+              tuple[(Tuple1 | Tuple2), int])
+          | 'GBK' >> beam.GroupByKey()
+          | 'Count' >> beam.Map(lambda x: len(x[1])))
+      assert_that(result, equal_to([2] * 4))
+
+  # Union of dataclasses as type hint currently result in FastPrimitiveCoder
+  # fails at GBK
+  @unittest.skip("https://github.com/apache/beam/issues/22085";)
+  def test_group_by_key_inherited_dataclass_union(self):
+    @dataclass
+    class DataClassInt:
+      id: int
+
+    @dataclass
+    class DataClassStr(DataClassInt):
+      name: str
+
+    beam.coders.typecoders.registry.register_coder(
+        DataClassInt, beam.coders.RowCoder)
+    beam.coders.typecoders.registry.register_coder(
+        DataClassStr, beam.coders.RowCoder)
+
+    def generate(num: int):
+      for i in range(10):
+        yield (DataClassInt(i), num)
+        yield (DataClassStr(i, 'a'), num)
+
+    pipeline = TestPipeline(is_integration_test=False)
+
+    with pipeline as p:
+      result = (
+          p
+          | 'Create' >> beam.Create([i for i in range(2)])
+          | 'Generate' >> beam.ParDo(generate).with_output_types(
+              tuple[(DataClassInt | DataClassStr), int])
+          | 'GBK' >> beam.GroupByKey()
+          | 'Count Elements' >> beam.Map(self._check_key_type_and_count))
+      assert_that(result, equal_to([2] * 4))
+
+  def test_derived_dataclass_schema_id(self):
+    @dataclass
+    class BaseDataClass:
+      id: int
+
+    @dataclass
+    class DerivedDataClass(BaseDataClass):
+      name: str
+
+    self.assertFalse(hasattr(BaseDataClass, row_type._BEAM_SCHEMA_ID))
+    schema_for_base = schemas.schema_from_element_type(BaseDataClass)
+    self.assertTrue(hasattr(BaseDataClass, row_type._BEAM_SCHEMA_ID))
+    self.assertEqual(
+        schema_for_base.id, getattr(BaseDataClass, row_type._BEAM_SCHEMA_ID))
+
+    # Getting the schema for BaseDataClass sets the _beam_schema_id
+    schemas.typing_to_runner_api(
+        BaseDataClass, schema_registry=schemas.SchemaTypeRegistry())
+
+    # We create a RowTypeConstraint from DerivedDataClass.
+    # It should not inherit the _beam_schema_id from BaseDataClass!
+    derived_row_type = row_type.RowTypeConstraint.from_user_type(
+        DerivedDataClass)
+    self.assertIsNone(derived_row_type._schema_id)
+
+    schema_for_derived = schemas.schema_from_element_type(DerivedDataClass)
+    self.assertTrue(hasattr(DerivedDataClass, row_type._BEAM_SCHEMA_ID))
+    self.assertEqual(
+        schema_for_derived.id,
+        getattr(DerivedDataClass, row_type._BEAM_SCHEMA_ID))
+    self.assertNotEqual(schema_for_derived.id, schema_for_base.id)
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/typehints/schemas.py 
b/sdks/python/apache_beam/typehints/schemas.py
index 5dd8ff290c4..d2c4db8cabc 100644
--- a/sdks/python/apache_beam/typehints/schemas.py
+++ b/sdks/python/apache_beam/typehints/schemas.py
@@ -96,7 +96,7 @@ from apache_beam.typehints.native_type_compatibility import 
_match_is_optional
 from apache_beam.typehints.native_type_compatibility import _safe_issubclass
 from apache_beam.typehints.native_type_compatibility import 
convert_to_python_type
 from apache_beam.typehints.native_type_compatibility import 
extract_optional_type
-from apache_beam.typehints.native_type_compatibility import match_is_dataclass
+from apache_beam.typehints.native_type_compatibility import 
match_dataclass_for_row
 from apache_beam.typehints.native_type_compatibility import 
match_is_named_tuple
 from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY
 from apache_beam.typehints.schema_registry import SchemaTypeRegistry
@@ -335,9 +335,11 @@ class SchemaTranslation(object):
                   atomic_type=PRIMITIVE_TO_ATOMIC_TYPE[int])))
 
     elif _safe_issubclass(type_, Sequence) and not _safe_issubclass(type_, 
str):
-      element_type = self.typing_to_runner_api(_get_args(type_)[0])
-      return schema_pb2.FieldType(
-          array_type=schema_pb2.ArrayType(element_type=element_type))
+      arg_types = _get_args(type_)
+      if len(arg_types) > 0:
+        element_type = self.typing_to_runner_api(arg_types[0])
+        return schema_pb2.FieldType(
+            array_type=schema_pb2.ArrayType(element_type=element_type))
 
     elif _safe_issubclass(type_, Mapping):
       key_type, value_type = map(self.typing_to_runner_api, _get_args(type_))
@@ -345,9 +347,11 @@ class SchemaTranslation(object):
           map_type=schema_pb2.MapType(key_type=key_type, 
value_type=value_type))
 
     elif _safe_issubclass(type_, Iterable) and not _safe_issubclass(type_, 
str):
-      element_type = self.typing_to_runner_api(_get_args(type_)[0])
-      return schema_pb2.FieldType(
-          array_type=schema_pb2.ArrayType(element_type=element_type))
+      arg_types = _get_args(type_)
+      if len(arg_types) > 0:
+        element_type = self.typing_to_runner_api(arg_types[0])
+        return schema_pb2.FieldType(
+            array_type=schema_pb2.ArrayType(element_type=element_type))
 
     try:
       if LogicalType.is_known_logical_type(type_):
@@ -630,8 +634,10 @@ def schema_from_element_type(element_type: type) -> 
schema_pb2.Schema:
   Returns schema as a list of (name, python_type) tuples"""
   if isinstance(element_type, row_type.RowTypeConstraint):
     return named_fields_to_schema(element_type._fields)
-  elif match_is_named_tuple(element_type) or match_is_dataclass(element_type):
-    if hasattr(element_type, row_type._BEAM_SCHEMA_ID):
+  elif match_is_named_tuple(element_type) or match_dataclass_for_row(
+      element_type):
+    # schema id does not inherit from base classes
+    if row_type._BEAM_SCHEMA_ID in element_type.__dict__:
       # if the named tuple's schema is in registry, we just use it instead of
       # regenerating one.
       schema_id = getattr(element_type, row_type._BEAM_SCHEMA_ID)
@@ -657,8 +663,15 @@ def union_schema_type(element_types):
   element_types must be a set of schema-aware types whose fields have the
   same naming and ordering.
   """
+  named_fields_and_types = []
+  for t in element_types:
+    n = named_fields_from_element_type(t)
+    if named_fields_and_types and len(named_fields_and_types[-1]) != len(n):
+      raise TypeError("element types has different number of fields")
+    named_fields_and_types.append(n)
+
   union_fields_and_types = []
-  for field in zip(*[named_fields_from_element_type(t) for t in 
element_types]):
+  for field in zip(*named_fields_and_types):
     names, types = zip(*field)
     name_set = set(names)
     if len(name_set) != 1:

Reply via email to