This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 87de10f36cc Add support for kw_only dataclasses (addresses #36978) (#36979)
87de10f36cc is described below

commit 87de10f36ccbbbfc186f49700d6e641d15af9e35
Author: assaf127 <[email protected]>
AuthorDate: Mon Dec 22 17:07:41 2025 +0200

    Add support for kw_only dataclasses (addresses #36978) (#36979)
    
    * add support for kw_only dataclasses (#36978)
    
    * use a different type for kw_only dataclasses
    
    * allow passing positional parameters to create dataclasses when possible
    
    * remove wrong TODO
    
    * add function typehint for pylint
    
    * minor refactoring
---
 sdks/python/apache_beam/coders/coder_impl.py       | 41 ++++++++++++++++++----
 .../apache_beam/coders/coders_test_common.py       | 10 ++++++
 2 files changed, 45 insertions(+), 6 deletions(-)

diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py
index 03514bb50db..1e3bb2ece92 100644
--- a/sdks/python/apache_beam/coders/coder_impl.py
+++ b/sdks/python/apache_beam/coders/coder_impl.py
@@ -32,6 +32,7 @@ For internal use only; no backwards-compatibility guarantees.
 
 import decimal
 import enum
+import functools
 import itertools
 import json
 import logging
@@ -354,6 +355,7 @@ DATACLASS_TYPE = 101
 NAMED_TUPLE_TYPE = 102
 ENUM_TYPE = 103
 NESTED_STATE_TYPE = 104
+DATACLASS_KW_ONLY_TYPE = 105
 
 # Types that can be encoded as iterables, but are not literally
 # lists, etc. due to being lazy.  The actual type is not preserved
@@ -374,6 +376,18 @@ def _verify_dill_compat():
     raise RuntimeError(base_error + f". Found dill version '{dill.__version__}")
 
 
+dataclass_uses_kw_only: Callable[[Any], bool]
+if dataclasses:
+  # Cache the result to avoid multiple checks for the same dataclass type.
+  @functools.cache
+  def dataclass_uses_kw_only(cls) -> bool:
+    return any(
+        field.init and field.kw_only for field in dataclasses.fields(cls))
+
+else:
+  dataclass_uses_kw_only = lambda cls: False
+
+
 class FastPrimitivesCoderImpl(StreamCoderImpl):
   """For internal use only; no backwards-compatibility guarantees."""
   def __init__(
@@ -497,18 +511,25 @@ class FastPrimitivesCoderImpl(StreamCoderImpl):
       self.encode_type(type(value), stream)
       stream.write(value.SerializePartialToString(deterministic=True), True)
     elif dataclasses and dataclasses.is_dataclass(value):
-      stream.write_byte(DATACLASS_TYPE)
       if not type(value).__dataclass_params__.frozen:
         raise TypeError(
             "Unable to deterministically encode non-frozen '%s' of type '%s' "
             "for the input of '%s'" %
             (value, type(value), self.requires_deterministic_step_label))
-      self.encode_type(type(value), stream)
-      values = [
-          getattr(value, field.name) for field in dataclasses.fields(value)
-      ]
+      init_fields = [field for field in dataclasses.fields(value) if field.init]
       try:
-        self.iterable_coder_impl.encode_to_stream(values, stream, True)
+        if dataclass_uses_kw_only(type(value)):
+          stream.write_byte(DATACLASS_KW_ONLY_TYPE)
+          self.encode_type(type(value), stream)
+          stream.write_var_int64(len(init_fields))
+          for field in init_fields:
+            stream.write(field.name.encode("utf-8"), True)
+            self.encode_to_stream(getattr(value, field.name), stream, True)
+        else:  # Not using kw_only, we can pass parameters by position.
+          stream.write_byte(DATACLASS_TYPE)
+          self.encode_type(type(value), stream)
+          values = [getattr(value, field.name) for field in init_fields]
+          self.iterable_coder_impl.encode_to_stream(values, stream, True)
       except Exception as e:
         raise TypeError(self._deterministic_encoding_error_msg(value)) from e
     elif isinstance(value, tuple) and hasattr(type(value), '_fields'):
@@ -616,6 +637,14 @@ class FastPrimitivesCoderImpl(StreamCoderImpl):
       msg = cls()
       msg.ParseFromString(stream.read_all(True))
       return msg
+    elif t == DATACLASS_KW_ONLY_TYPE:
+      cls = self.decode_type(stream)
+      vlen = stream.read_var_int64()
+      fields = {}
+      for _ in range(vlen):
+        field_name = stream.read_all(True).decode('utf-8')
+        fields[field_name] = self.decode_from_stream(stream, True)
+      return cls(**fields)
     elif t == DATACLASS_TYPE or t == NAMED_TUPLE_TYPE:
       cls = self.decode_type(stream)
       return cls(*self.iterable_coder_impl.decode_from_stream(stream, True))
diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py
index 8a57d1e63e2..8f89ab9602c 100644
--- a/sdks/python/apache_beam/coders/coders_test_common.py
+++ b/sdks/python/apache_beam/coders/coders_test_common.py
@@ -113,6 +113,11 @@ if dataclasses is not None:
     a: Any
     b: int
 
+  @dataclasses.dataclass(frozen=True, kw_only=True)
+  class FrozenKwOnlyDataClass:
+    c: int
+    d: int
+
   @dataclasses.dataclass
   class UnFrozenDataClass:
     x: int
@@ -303,9 +308,11 @@ class CodersTest(unittest.TestCase):
 
     if dataclasses is not None:
       self.check_coder(deterministic_coder, FrozenDataClass(1, 2))
+      self.check_coder(deterministic_coder, FrozenKwOnlyDataClass(c=1, d=2))
 
       with self.assertRaises(TypeError):
         self.check_coder(deterministic_coder, UnFrozenDataClass(1, 2))
+
       with self.assertRaises(TypeError):
         self.check_coder(
             deterministic_coder, FrozenDataClass(UnFrozenDataClass(1, 2), 3))
@@ -742,6 +749,7 @@ class CodersTest(unittest.TestCase):
         from apache_beam.coders.coders_test_common import DefinesGetState
         from apache_beam.coders.coders_test_common import DefinesGetAndSetState
         from apache_beam.coders.coders_test_common import FrozenDataClass
+        from apache_beam.coders.coders_test_common import FrozenKwOnlyDataClass
 
 
         from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
@@ -777,6 +785,8 @@ class CodersTest(unittest.TestCase):
         test_cases.extend([
             ("frozen_dataclass", FrozenDataClass(1, 2)),
             ("frozen_dataclass_list", [FrozenDataClass(1, 2), FrozenDataClass(3, 4)]),
+            ("frozen_kwonly_dataclass", FrozenKwOnlyDataClass(c=1, d=2)),
+            ("frozen_kwonly_dataclass_list", [FrozenKwOnlyDataClass(c=1, d=2), FrozenKwOnlyDataClass(c=3, d=4)]),
         ])
 
         compat_version = {'"'+ compat_version +'"' if compat_version else None}

Reply via email to