This is an automated email from the ASF dual-hosted git repository.
yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new c05733a3a4d Handle some cases during infer schema from dataclass
(#37855)
c05733a3a4d is described below
commit c05733a3a4dfdec99af6ad68c1c4187f595b5b43
Author: Yi Hu <[email protected]>
AuthorDate: Fri Mar 20 10:27:09 2026 -0400
Handle some cases during infer schema from dataclass (#37855)
* Handle some cases during infer schema from dataclass
* For backward compatibility, only infer schemas for frozen dataclasses
when they are registered with the row coder
* Make sure the Beam schema ID is not inherited from a base class
* Fix IndexError raised when inferring the type of a custom Iterable
without type hint
* Fix #37862: fixed named tuple and effectively fails dataclass inside
union typehint
* Allow non-frozen dataclasses to be registered with other coders as a
fallback for backward compatibility; add tests
* Add upgrade compatibility check for potential coder change
* Update CHANGES.md
Co-authored-by: gemini-code-assist[bot]
<176961590+gemini-code-assist[bot]@users.noreply.github.com>
* Test case for update_compatibility_version
---------
Co-authored-by: gemini-code-assist[bot]
<176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
CHANGES.md | 1 +
sdks/python/apache_beam/coders/coder_impl.py | 2 +-
.../typehints/native_type_compatibility.py | 46 ++++++++++-
.../typehints/native_type_compatibility_test.py | 59 ++++++++++++++
sdks/python/apache_beam/typehints/row_type.py | 7 +-
sdks/python/apache_beam/typehints/row_type_test.py | 89 ++++++++++++++++++++++
sdks/python/apache_beam/typehints/schemas.py | 33 +++++---
7 files changed, 222 insertions(+), 15 deletions(-)
diff --git a/CHANGES.md b/CHANGES.md
index 4072a66e07c..e7f7283b7ef 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -70,6 +70,7 @@
## New Features / Improvements
* Added support for large pipeline options via a file (Python)
([#37370](https://github.com/apache/beam/issues/37370)).
+* Added support for inferring schemas from dataclasses (Python)
([#22085](https://github.com/apache/beam/issues/22085)). The default coder for
non-frozen dataclasses used as type hints (or set via with_output_types) changed
to RowCoder. To preserve the old behavior (fast primitives coder), explicitly
register the type with FastPrimitivesCoder.
## Breaking Changes
diff --git a/sdks/python/apache_beam/coders/coder_impl.py
b/sdks/python/apache_beam/coders/coder_impl.py
index b3e45bc7f35..1270b98f9bc 100644
--- a/sdks/python/apache_beam/coders/coder_impl.py
+++ b/sdks/python/apache_beam/coders/coder_impl.py
@@ -493,7 +493,7 @@ class FastPrimitivesCoderImpl(StreamCoderImpl):
stream.write_byte(PROTO_TYPE)
self.encode_type(type(value), stream)
stream.write(value.SerializePartialToString(deterministic=True), True)
- elif dataclasses and dataclasses.is_dataclass(value):
+ elif dataclasses.is_dataclass(value):
if not type(value).__dataclass_params__.frozen:
raise TypeError(
"Unable to deterministically encode non-frozen '%s' of type '%s' "
diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py
b/sdks/python/apache_beam/typehints/native_type_compatibility.py
index 886b1505ffe..a8a3ad29325 100644
--- a/sdks/python/apache_beam/typehints/native_type_compatibility.py
+++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py
@@ -176,8 +176,50 @@ def match_is_named_tuple(user_type):
hasattr(user_type, '__annotations__') and hasattr(user_type, '_fields'))
-def match_is_dataclass(user_type):
- return dataclasses.is_dataclass(user_type) and isinstance(user_type, type)
+def match_dataclass_for_row(user_type):
+ """Match whether the type is a dataclass handled by row coder.
+
+ For frozen dataclasses, only true when explicitly registered with row coder:
+
+ beam.coders.typecoders.registry.register_coder(
+ MyDataClass, beam.coders.RowCoder)
+
+ (for backward-compatibility reason).
+
+ For non-frozen dataclasses, default to true otherwise explicitly registered
+ with a coder other than the row coder.
+ """
+
+ if not dataclasses.is_dataclass(user_type):
+ return False
+
+ # pylint: disable=wrong-import-position
+ try:
+ from apache_beam.options.pipeline_options_context import
get_pipeline_options # pylint: disable=line-too-long
+ except AttributeError:
+ pass
+ else:
+ opts = get_pipeline_options()
+ if opts and opts.is_compat_version_prior_to("2.73.0"):
+ return False
+
+ is_frozen = user_type.__dataclass_params__.frozen
+ # avoid circular import
+ try:
+ from apache_beam.coders.typecoders import registry as coders_registry
+ from apache_beam.coders import RowCoder
+ except AttributeError:
+ # coder registery not yet initialized so it must be absent
+ return not is_frozen
+
+ if is_frozen:
+ return (
+ user_type in coders_registry._coders and
+ coders_registry._coders[user_type] == RowCoder)
+ else:
+ return (
+ user_type not in coders_registry._coders or
+ coders_registry._coders[user_type] == RowCoder)
def _match_is_optional(user_type):
diff --git
a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py
b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py
index e9ce732d2e9..01f40a29945 100644
--- a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py
+++ b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py
@@ -20,11 +20,17 @@
# pytype: skip-file
import collections.abc
+import dataclasses
import enum
import re
import typing
import unittest
+from parameterized import param
+from parameterized import parameterized
+
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options_context import
scoped_pipeline_options
from apache_beam.typehints import typehints
from apache_beam.typehints.native_type_compatibility import
convert_builtin_to_typing
from apache_beam.typehints.native_type_compatibility import
convert_to_beam_type
@@ -33,6 +39,7 @@ from apache_beam.typehints.native_type_compatibility import
convert_to_python_ty
from apache_beam.typehints.native_type_compatibility import
convert_to_python_types
from apache_beam.typehints.native_type_compatibility import
convert_typing_to_builtin
from apache_beam.typehints.native_type_compatibility import is_any
+from apache_beam.typehints.native_type_compatibility import
match_dataclass_for_row
_TestNamedTuple = typing.NamedTuple(
'_TestNamedTuple', [('age', int), ('name', bytes)])
@@ -509,6 +516,58 @@ class NativeTypeCompatibilityTest(unittest.TestCase):
self.assertEqual(
typehints.Tuple[int, ...], convert_to_beam_type(AliasTuple))
+ def test_dataclass_default(self):
+ @dataclasses.dataclass(frozen=True)
+ class FrozenDC:
+ foo: int
+
+ @dataclasses.dataclass
+ class NonFrozenDC:
+ foo: int
+
+ self.assertFalse(match_dataclass_for_row(FrozenDC))
+ self.assertTrue(match_dataclass_for_row(NonFrozenDC))
+
+ def test_dataclass_registered(self):
+ @dataclasses.dataclass(frozen=True)
+ class FrozenRegisteredDC:
+ foo: int
+
+ @dataclasses.dataclass
+ class NonFrozenRegisteredDC:
+ foo: int
+
+ # pylint: disable=wrong-import-position
+ from apache_beam.coders import RowCoder
+ from apache_beam.coders import typecoders
+ from apache_beam.coders.coders import FastPrimitivesCoder
+
+ typecoders.registry.register_coder(FrozenRegisteredDC, RowCoder)
+ typecoders.registry.register_coder(
+ NonFrozenRegisteredDC, FastPrimitivesCoder)
+
+ self.assertTrue(match_dataclass_for_row(FrozenRegisteredDC))
+ self.assertFalse(match_dataclass_for_row(NonFrozenRegisteredDC))
+
+ @parameterized.expand([
+ param(compat_version="2.72.0"),
+ param(compat_version="2.73.0"),
+ ])
+ def test_dataclass_update_compatibility(self, compat_version):
+ @dataclasses.dataclass(frozen=True)
+ class FrozenDC:
+ foo: int
+
+ @dataclasses.dataclass
+ class NonFrozenDC:
+ foo: int
+
+ with scoped_pipeline_options(
+ PipelineOptions(update_compatibility_version=compat_version)):
+ self.assertFalse(match_dataclass_for_row(FrozenDC))
+ self.assertEqual(
+ compat_version == "2.73.0", match_dataclass_for_row(NonFrozenDC))
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/typehints/row_type.py
b/sdks/python/apache_beam/typehints/row_type.py
index 6f96f6f64e3..0697581cb43 100644
--- a/sdks/python/apache_beam/typehints/row_type.py
+++ b/sdks/python/apache_beam/typehints/row_type.py
@@ -27,7 +27,7 @@ from typing import Sequence
from typing import Tuple
from apache_beam.typehints import typehints
-from apache_beam.typehints.native_type_compatibility import match_is_dataclass
+from apache_beam.typehints.native_type_compatibility import
match_dataclass_for_row
from apache_beam.typehints.native_type_compatibility import
match_is_named_tuple
from apache_beam.typehints.schema_registry import SchemaTypeRegistry
@@ -91,6 +91,9 @@ class RowTypeConstraint(typehints.TypeConstraint):
# Currently registration happens when converting to schema protos, in
# apache_beam.typehints.schemas
self._schema_id = getattr(self._user_type, _BEAM_SCHEMA_ID, None)
+ if self._schema_id and _BEAM_SCHEMA_ID not in self._user_type.__dict__:
+ # schema id does not inherit. Unset if schema id is from base class
+ self._schema_id = None
self._schema_options = schema_options or []
self._field_options = field_options or {}
@@ -105,7 +108,7 @@ class RowTypeConstraint(typehints.TypeConstraint):
if match_is_named_tuple(user_type):
fields = [(name, user_type.__annotations__[name])
for name in user_type._fields]
- elif match_is_dataclass(user_type):
+ elif match_dataclass_for_row(user_type):
fields = [(field.name, field.type)
for field in dataclasses.fields(user_type)]
else:
diff --git a/sdks/python/apache_beam/typehints/row_type_test.py
b/sdks/python/apache_beam/typehints/row_type_test.py
index 73d76fee49c..97012d9561d 100644
--- a/sdks/python/apache_beam/typehints/row_type_test.py
+++ b/sdks/python/apache_beam/typehints/row_type_test.py
@@ -26,6 +26,7 @@ from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.typehints import row_type
+from apache_beam.typehints import schemas
class RowTypeTest(unittest.TestCase):
@@ -85,6 +86,94 @@ class RowTypeTest(unittest.TestCase):
| 'Count Elements' >> beam.Map(self._check_key_type_and_count))
assert_that(result, equal_to([10] * 100))
+ def test_group_by_key_namedtuple_union(self):
+ Tuple1 = typing.NamedTuple("Tuple1", [("id", int)])
+
+ Tuple2 = typing.NamedTuple("Tuple2", [("id", int), ("name", str)])
+
+ def generate(num: int):
+ for i in range(2):
+ yield (Tuple1(i), num)
+ yield (Tuple2(i, 'a'), num)
+
+ pipeline = TestPipeline(is_integration_test=False)
+
+ with pipeline as p:
+ result = (
+ p
+ | 'Create' >> beam.Create([i for i in range(2)])
+ | 'Generate' >> beam.ParDo(generate).with_output_types(
+ tuple[(Tuple1 | Tuple2), int])
+ | 'GBK' >> beam.GroupByKey()
+ | 'Count' >> beam.Map(lambda x: len(x[1])))
+ assert_that(result, equal_to([2] * 4))
+
+ # Union of dataclasses as type hint currently result in FastPrimitiveCoder
+ # fails at GBK
+ @unittest.skip("https://github.com/apache/beam/issues/22085")
+ def test_group_by_key_inherited_dataclass_union(self):
+ @dataclass
+ class DataClassInt:
+ id: int
+
+ @dataclass
+ class DataClassStr(DataClassInt):
+ name: str
+
+ beam.coders.typecoders.registry.register_coder(
+ DataClassInt, beam.coders.RowCoder)
+ beam.coders.typecoders.registry.register_coder(
+ DataClassStr, beam.coders.RowCoder)
+
+ def generate(num: int):
+ for i in range(10):
+ yield (DataClassInt(i), num)
+ yield (DataClassStr(i, 'a'), num)
+
+ pipeline = TestPipeline(is_integration_test=False)
+
+ with pipeline as p:
+ result = (
+ p
+ | 'Create' >> beam.Create([i for i in range(2)])
+ | 'Generate' >> beam.ParDo(generate).with_output_types(
+ tuple[(DataClassInt | DataClassStr), int])
+ | 'GBK' >> beam.GroupByKey()
+ | 'Count Elements' >> beam.Map(self._check_key_type_and_count))
+ assert_that(result, equal_to([2] * 4))
+
+ def test_derived_dataclass_schema_id(self):
+ @dataclass
+ class BaseDataClass:
+ id: int
+
+ @dataclass
+ class DerivedDataClass(BaseDataClass):
+ name: str
+
+ self.assertFalse(hasattr(BaseDataClass, row_type._BEAM_SCHEMA_ID))
+ schema_for_base = schemas.schema_from_element_type(BaseDataClass)
+ self.assertTrue(hasattr(BaseDataClass, row_type._BEAM_SCHEMA_ID))
+ self.assertEqual(
+ schema_for_base.id, getattr(BaseDataClass, row_type._BEAM_SCHEMA_ID))
+
+ # Getting the schema for BaseDataClass sets the _beam_schema_id
+ schemas.typing_to_runner_api(
+ BaseDataClass, schema_registry=schemas.SchemaTypeRegistry())
+
+ # We create a RowTypeConstraint from DerivedDataClass.
+ # It should not inherit the _beam_schema_id from BaseDataClass!
+ derived_row_type = row_type.RowTypeConstraint.from_user_type(
+ DerivedDataClass)
+ self.assertIsNone(derived_row_type._schema_id)
+
+ schema_for_derived = schemas.schema_from_element_type(DerivedDataClass)
+ self.assertTrue(hasattr(DerivedDataClass, row_type._BEAM_SCHEMA_ID))
+ self.assertEqual(
+ schema_for_derived.id,
+ getattr(DerivedDataClass, row_type._BEAM_SCHEMA_ID))
+ self.assertNotEqual(schema_for_derived.id, schema_for_base.id)
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/typehints/schemas.py
b/sdks/python/apache_beam/typehints/schemas.py
index 5dd8ff290c4..d2c4db8cabc 100644
--- a/sdks/python/apache_beam/typehints/schemas.py
+++ b/sdks/python/apache_beam/typehints/schemas.py
@@ -96,7 +96,7 @@ from apache_beam.typehints.native_type_compatibility import
_match_is_optional
from apache_beam.typehints.native_type_compatibility import _safe_issubclass
from apache_beam.typehints.native_type_compatibility import
convert_to_python_type
from apache_beam.typehints.native_type_compatibility import
extract_optional_type
-from apache_beam.typehints.native_type_compatibility import match_is_dataclass
+from apache_beam.typehints.native_type_compatibility import
match_dataclass_for_row
from apache_beam.typehints.native_type_compatibility import
match_is_named_tuple
from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY
from apache_beam.typehints.schema_registry import SchemaTypeRegistry
@@ -335,9 +335,11 @@ class SchemaTranslation(object):
atomic_type=PRIMITIVE_TO_ATOMIC_TYPE[int])))
elif _safe_issubclass(type_, Sequence) and not _safe_issubclass(type_,
str):
- element_type = self.typing_to_runner_api(_get_args(type_)[0])
- return schema_pb2.FieldType(
- array_type=schema_pb2.ArrayType(element_type=element_type))
+ arg_types = _get_args(type_)
+ if len(arg_types) > 0:
+ element_type = self.typing_to_runner_api(arg_types[0])
+ return schema_pb2.FieldType(
+ array_type=schema_pb2.ArrayType(element_type=element_type))
elif _safe_issubclass(type_, Mapping):
key_type, value_type = map(self.typing_to_runner_api, _get_args(type_))
@@ -345,9 +347,11 @@ class SchemaTranslation(object):
map_type=schema_pb2.MapType(key_type=key_type,
value_type=value_type))
elif _safe_issubclass(type_, Iterable) and not _safe_issubclass(type_,
str):
- element_type = self.typing_to_runner_api(_get_args(type_)[0])
- return schema_pb2.FieldType(
- array_type=schema_pb2.ArrayType(element_type=element_type))
+ arg_types = _get_args(type_)
+ if len(arg_types) > 0:
+ element_type = self.typing_to_runner_api(arg_types[0])
+ return schema_pb2.FieldType(
+ array_type=schema_pb2.ArrayType(element_type=element_type))
try:
if LogicalType.is_known_logical_type(type_):
@@ -630,8 +634,10 @@ def schema_from_element_type(element_type: type) ->
schema_pb2.Schema:
Returns schema as a list of (name, python_type) tuples"""
if isinstance(element_type, row_type.RowTypeConstraint):
return named_fields_to_schema(element_type._fields)
- elif match_is_named_tuple(element_type) or match_is_dataclass(element_type):
- if hasattr(element_type, row_type._BEAM_SCHEMA_ID):
+ elif match_is_named_tuple(element_type) or match_dataclass_for_row(
+ element_type):
+ # schema id does not inherit from base classes
+ if row_type._BEAM_SCHEMA_ID in element_type.__dict__:
# if the named tuple's schema is in registry, we just use it instead of
# regenerating one.
schema_id = getattr(element_type, row_type._BEAM_SCHEMA_ID)
@@ -657,8 +663,15 @@ def union_schema_type(element_types):
element_types must be a set of schema-aware types whose fields have the
same naming and ordering.
"""
+ named_fields_and_types = []
+ for t in element_types:
+ n = named_fields_from_element_type(t)
+ if named_fields_and_types and len(named_fields_and_types[-1]) != len(n):
+ raise TypeError("element types has different number of fields")
+ named_fields_and_types.append(n)
+
union_fields_and_types = []
- for field in zip(*[named_fields_from_element_type(t) for t in
element_types]):
+ for field in zip(*named_fields_and_types):
names, types = zip(*field)
name_set = set(names)
if len(name_set) != 1: