This is an automated email from the ASF dual-hosted git repository.

yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 9f5904b2b59 Support inferring schemas from Python dataclasses (#37728)
9f5904b2b59 is described below

commit 9f5904b2b59c6b75d9c161ee7bc1828cbd08e8f4
Author: Yi Hu <[email protected]>
AuthorDate: Wed Mar 11 12:16:49 2026 -0400

    Support inferring schemas from Python dataclasses (#37728)
    
    * Support inferring schemas from Python dataclasses
    
    * Address comments; Revert native_type_compatibility _TypeMapEntry change
    
    * Add unit test for named tuple and dataclasses encoded by RowCoder and passing through GBK
    
    * Fix lint
---
 sdks/python/apache_beam/coders/coder_impl.py       |  6 +-
 .../typehints/native_type_compatibility.py         |  5 ++
 sdks/python/apache_beam/typehints/row_type.py      | 47 +++++------
 sdks/python/apache_beam/typehints/row_type_test.py | 90 ++++++++++++++++++++++
 sdks/python/apache_beam/typehints/schemas.py       |  3 +-
 sdks/python/apache_beam/typehints/schemas_test.py  | 61 +++++++++++++++
 6 files changed, 183 insertions(+), 29 deletions(-)

diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py
index 3e0b5218b16..b3e45bc7f35 100644
--- a/sdks/python/apache_beam/coders/coder_impl.py
+++ b/sdks/python/apache_beam/coders/coder_impl.py
@@ -30,6 +30,7 @@ For internal use only; no backwards-compatibility guarantees.
 """
 # pytype: skip-file
 
+import dataclasses
 import decimal
 import enum
 import itertools
@@ -67,11 +68,6 @@ from apache_beam.utils.timestamp import MAX_TIMESTAMP
 from apache_beam.utils.timestamp import MIN_TIMESTAMP
 from apache_beam.utils.timestamp import Timestamp
 
-try:
-  import dataclasses
-except ImportError:
-  dataclasses = None  # type: ignore
-
 try:
   import dill
 except ImportError:
diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py
index 345c04706d6..886b1505ffe 100644
--- a/sdks/python/apache_beam/typehints/native_type_compatibility.py
+++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py
@@ -21,6 +21,7 @@
 
 import collections
 import collections.abc
+import dataclasses
 import logging
 import sys
 import types
@@ -175,6 +176,10 @@ def match_is_named_tuple(user_type):
       hasattr(user_type, '__annotations__') and hasattr(user_type, '_fields'))
 
 
+def match_is_dataclass(user_type):
+  return dataclasses.is_dataclass(user_type) and isinstance(user_type, type)
+
+
 def _match_is_optional(user_type):
   return _match_is_union(user_type) and sum(
       tp is type(None) for tp in _get_args(user_type)) == 1
diff --git a/sdks/python/apache_beam/typehints/row_type.py b/sdks/python/apache_beam/typehints/row_type.py
index 08838c84a05..6f96f6f64e3 100644
--- a/sdks/python/apache_beam/typehints/row_type.py
+++ b/sdks/python/apache_beam/typehints/row_type.py
@@ -19,6 +19,7 @@
 
 from __future__ import annotations
 
+import dataclasses
 from typing import Any
 from typing import Dict
 from typing import Optional
@@ -26,6 +27,7 @@ from typing import Sequence
 from typing import Tuple
 
 from apache_beam.typehints import typehints
+from apache_beam.typehints.native_type_compatibility import match_is_dataclass
 from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
 from apache_beam.typehints.schema_registry import SchemaTypeRegistry
 
@@ -56,18 +58,14 @@ class RowTypeConstraint(typehints.TypeConstraint):
     for guidance on creating PCollections with inferred schemas.
 
     Note RowTypeConstraint does not currently store arbitrary functions for
-    converting to/from the user type. Instead, we only support ``NamedTuple``
-    user types and make the follow assumptions:
+    converting to/from the user type. Instead, we support ``NamedTuple`` and
+    ``dataclasses`` user types and make the follow assumptions:
 
     - The user type can be constructed with field values as arguments in order
       (i.e. ``constructor(*field_values)``).
     - Field values can be accessed from instances of the user type by attribute
       (i.e. with ``getattr(obj, field_name)``).
 
-    In the future we will add support for dataclasses
-    ([#22085](https://github.com/apache/beam/issues/22085)) which also satisfy
-    these assumptions.
-
     The RowTypeConstraint constructor should not be called directly (even
     internally to Beam). Prefer static methods ``from_user_type`` or
     ``from_fields``.
@@ -107,27 +105,30 @@ class RowTypeConstraint(typehints.TypeConstraint):
     if match_is_named_tuple(user_type):
       fields = [(name, user_type.__annotations__[name])
                 for name in user_type._fields]
-
-      field_descriptions = getattr(user_type, '_field_descriptions', None)
-
-      if _user_type_is_generated(user_type):
-        return RowTypeConstraint.from_fields(
-            fields,
-            schema_id=getattr(user_type, _BEAM_SCHEMA_ID),
-            schema_options=schema_options,
-            field_options=field_options,
-            field_descriptions=field_descriptions)
-
-      # TODO(https://github.com/apache/beam/issues/22125): Add user API for
-      # specifying schema/field options
-      return RowTypeConstraint(
-          fields=fields,
-          user_type=user_type,
+    elif match_is_dataclass(user_type):
+      fields = [(field.name, field.type)
+                for field in dataclasses.fields(user_type)]
+    else:
+      return None
+
+    field_descriptions = getattr(user_type, '_field_descriptions', None)
+
+    if _user_type_is_generated(user_type):
+      return RowTypeConstraint.from_fields(
+          fields,
+          schema_id=getattr(user_type, _BEAM_SCHEMA_ID),
           schema_options=schema_options,
           field_options=field_options,
           field_descriptions=field_descriptions)
 
-    return None
+    # TODO(https://github.com/apache/beam/issues/22125): Add user API for
+    # specifying schema/field options
+    return RowTypeConstraint(
+        fields=fields,
+        user_type=user_type,
+        schema_options=schema_options,
+        field_options=field_options,
+        field_descriptions=field_descriptions)
 
   @staticmethod
   def from_fields(
diff --git a/sdks/python/apache_beam/typehints/row_type_test.py b/sdks/python/apache_beam/typehints/row_type_test.py
new file mode 100644
index 00000000000..73d76fee49c
--- /dev/null
+++ b/sdks/python/apache_beam/typehints/row_type_test.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for the Beam Row typing functionality."""
+
+import typing
+import unittest
+from dataclasses import dataclass
+
+import apache_beam as beam
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.typehints import row_type
+
+
+class RowTypeTest(unittest.TestCase):
+  @staticmethod
+  def _check_key_type_and_count(x) -> int:
+    key_type = type(x[0])
+    if not row_type._user_type_is_generated(key_type):
+      raise RuntimeError("Expect type after GBK to be generated user type")
+
+    return len(x[1])
+
+  def test_group_by_key_namedtuple(self):
+    MyNamedTuple = typing.NamedTuple(
+        "MyNamedTuple", [("id", int), ("name", str)])
+
+    beam.coders.typecoders.registry.register_coder(
+        MyNamedTuple, beam.coders.RowCoder)
+
+    def generate(num: int):
+      for i in range(100):
+        yield (MyNamedTuple(i, 'a'), num)
+
+    pipeline = TestPipeline(is_integration_test=False)
+
+    with pipeline as p:
+      result = (
+          p
+          | 'Create' >> beam.Create([i for i in range(10)])
+          | 'Generate' >> beam.ParDo(generate).with_output_types(
+              tuple[MyNamedTuple, int])
+          | 'GBK' >> beam.GroupByKey()
+          | 'Count Elements' >> beam.Map(self._check_key_type_and_count))
+      assert_that(result, equal_to([10] * 100))
+
+  def test_group_by_key_dataclass(self):
+    @dataclass
+    class MyDataClass:
+      id: int
+      name: str
+
+    beam.coders.typecoders.registry.register_coder(
+        MyDataClass, beam.coders.RowCoder)
+
+    def generate(num: int):
+      for i in range(100):
+        yield (MyDataClass(i, 'a'), num)
+
+    pipeline = TestPipeline(is_integration_test=False)
+
+    with pipeline as p:
+      result = (
+          p
+          | 'Create' >> beam.Create([i for i in range(10)])
+          | 'Generate' >> beam.ParDo(generate).with_output_types(
+              tuple[MyDataClass, int])
+          | 'GBK' >> beam.GroupByKey()
+          | 'Count Elements' >> beam.Map(self._check_key_type_and_count))
+      assert_that(result, equal_to([10] * 100))
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py
index e9674fa5bc2..5dd8ff290c4 100644
--- a/sdks/python/apache_beam/typehints/schemas.py
+++ b/sdks/python/apache_beam/typehints/schemas.py
@@ -96,6 +96,7 @@ from apache_beam.typehints.native_type_compatibility import _match_is_optional
 from apache_beam.typehints.native_type_compatibility import _safe_issubclass
 from apache_beam.typehints.native_type_compatibility import convert_to_python_type
 from apache_beam.typehints.native_type_compatibility import extract_optional_type
+from apache_beam.typehints.native_type_compatibility import match_is_dataclass
 from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
 from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY
 from apache_beam.typehints.schema_registry import SchemaTypeRegistry
@@ -629,7 +630,7 @@ def schema_from_element_type(element_type: type) -> schema_pb2.Schema:
   Returns schema as a list of (name, python_type) tuples"""
   if isinstance(element_type, row_type.RowTypeConstraint):
     return named_fields_to_schema(element_type._fields)
-  elif match_is_named_tuple(element_type):
+  elif match_is_named_tuple(element_type) or match_is_dataclass(element_type):
     if hasattr(element_type, row_type._BEAM_SCHEMA_ID):
       # if the named tuple's schema is in registry, we just use it instead of
       # regenerating one.
diff --git a/sdks/python/apache_beam/typehints/schemas_test.py b/sdks/python/apache_beam/typehints/schemas_test.py
index 73db06b9a8d..5a5d7396ab3 100644
--- a/sdks/python/apache_beam/typehints/schemas_test.py
+++ b/sdks/python/apache_beam/typehints/schemas_test.py
@@ -19,6 +19,7 @@
 
 # pytype: skip-file
 
+import dataclasses
 import itertools
 import pickle
 import unittest
@@ -388,6 +389,24 @@ class SchemaTest(unittest.TestCase):
     self.assertIsInstance(roundtripped, row_type.RowTypeConstraint)
     self.assert_namedtuple_equivalent(roundtripped.user_type, user_type)
 
+  def test_dataclass_roundtrip(self):
+    @dataclasses.dataclass
+    class SimpleDataclass:
+      id: np.int64
+      name: str
+
+    roundtripped = typing_from_runner_api(
+        typing_to_runner_api(
+            SimpleDataclass, schema_registry=SchemaTypeRegistry()),
+        schema_registry=SchemaTypeRegistry())
+
+    self.assertIsInstance(roundtripped, row_type.RowTypeConstraint)
+    # The roundtripped user_type is generated as a NamedTuple, so we can't test
+    # equivalence directly with the dataclass.
+    # Instead, let's verify annotations.
+    self.assertEqual(
+        roundtripped.user_type.__annotations__, SimpleDataclass.__annotations__)
+
   def test_row_type_constraint_to_schema(self):
     result_type = typing_to_runner_api(
         row_type.RowTypeConstraint.from_fields([
@@ -646,6 +665,48 @@ class SchemaTest(unittest.TestCase):
         expected.row_type.schema.fields,
         typing_to_runner_api(MyCuteClass).row_type.schema.fields)
 
+  def test_trivial_example_dataclass(self):
+    @dataclasses.dataclass
+    class MyCuteDataclass:
+      name: str
+      age: Optional[int]
+      interests: List[str]
+      height: float
+      blob: ByteString
+
+    expected = schema_pb2.FieldType(
+        row_type=schema_pb2.RowType(
+            schema=schema_pb2.Schema(
+                fields=[
+                    schema_pb2.Field(
+                        name='name',
+                        type=schema_pb2.FieldType(
+                            atomic_type=schema_pb2.STRING),
+                    ),
+                    schema_pb2.Field(
+                        name='age',
+                        type=schema_pb2.FieldType(
+                            nullable=True, atomic_type=schema_pb2.INT64)),
+                    schema_pb2.Field(
+                        name='interests',
+                        type=schema_pb2.FieldType(
+                            array_type=schema_pb2.ArrayType(
+                                element_type=schema_pb2.FieldType(
+                                    atomic_type=schema_pb2.STRING)))),
+                    schema_pb2.Field(
+                        name='height',
+                        type=schema_pb2.FieldType(
+                            atomic_type=schema_pb2.DOUBLE)),
+                    schema_pb2.Field(
+                        name='blob',
+                        type=schema_pb2.FieldType(
+                            atomic_type=schema_pb2.BYTES)),
+                ])))
+
+    self.assertEqual(
+        expected.row_type.schema.fields,
+        typing_to_runner_api(MyCuteDataclass).row_type.schema.fields)
+
   def test_user_type_annotated_with_id_after_conversion(self):
     MyCuteClass = NamedTuple('MyCuteClass', [
         ('name', str),

Reply via email to