This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 87de10f36cc Add support for kw_only dataclasses (addresses #36978) (#36979)
87de10f36cc is described below
commit 87de10f36ccbbbfc186f49700d6e641d15af9e35
Author: assaf127 <[email protected]>
AuthorDate: Mon Dec 22 17:07:41 2025 +0200
Add support for kw_only dataclasses (addresses #36978) (#36979)
* add support for kw_only dataclasses (#36978)
* use a different type for kw_only dataclasses
* allow passing positional parameters to create dataclasses when possible
* remove wrong TODO
* add function typehint for pylint
* minor refactoring
---
sdks/python/apache_beam/coders/coder_impl.py | 41 ++++++++++++++++++----
.../apache_beam/coders/coders_test_common.py | 10 ++++++
2 files changed, 45 insertions(+), 6 deletions(-)
diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py
index 03514bb50db..1e3bb2ece92 100644
--- a/sdks/python/apache_beam/coders/coder_impl.py
+++ b/sdks/python/apache_beam/coders/coder_impl.py
@@ -32,6 +32,7 @@ For internal use only; no backwards-compatibility guarantees.
import decimal
import enum
+import functools
import itertools
import json
import logging
@@ -354,6 +355,7 @@ DATACLASS_TYPE = 101
NAMED_TUPLE_TYPE = 102
ENUM_TYPE = 103
NESTED_STATE_TYPE = 104
+DATACLASS_KW_ONLY_TYPE = 105
# Types that can be encoded as iterables, but are not literally
# lists, etc. due to being lazy. The actual type is not preserved
@@ -374,6 +376,18 @@ def _verify_dill_compat():
raise RuntimeError(base_error + f". Found dill version
'{dill.__version__}")
+dataclass_uses_kw_only: Callable[[Any], bool]
+if dataclasses:
+ # Cache the result to avoid multiple checks for the same dataclass type.
+ @functools.cache
+ def dataclass_uses_kw_only(cls) -> bool:
+ return any(
+ field.init and field.kw_only for field in dataclasses.fields(cls))
+
+else:
+ dataclass_uses_kw_only = lambda cls: False
+
+
class FastPrimitivesCoderImpl(StreamCoderImpl):
"""For internal use only; no backwards-compatibility guarantees."""
def __init__(
@@ -497,18 +511,25 @@ class FastPrimitivesCoderImpl(StreamCoderImpl):
self.encode_type(type(value), stream)
stream.write(value.SerializePartialToString(deterministic=True), True)
elif dataclasses and dataclasses.is_dataclass(value):
- stream.write_byte(DATACLASS_TYPE)
if not type(value).__dataclass_params__.frozen:
raise TypeError(
"Unable to deterministically encode non-frozen '%s' of type '%s' "
"for the input of '%s'" %
(value, type(value), self.requires_deterministic_step_label))
- self.encode_type(type(value), stream)
- values = [
- getattr(value, field.name) for field in dataclasses.fields(value)
- ]
+ init_fields = [field for field in dataclasses.fields(value) if
field.init]
try:
- self.iterable_coder_impl.encode_to_stream(values, stream, True)
+ if dataclass_uses_kw_only(type(value)):
+ stream.write_byte(DATACLASS_KW_ONLY_TYPE)
+ self.encode_type(type(value), stream)
+ stream.write_var_int64(len(init_fields))
+ for field in init_fields:
+ stream.write(field.name.encode("utf-8"), True)
+ self.encode_to_stream(getattr(value, field.name), stream, True)
+ else: # Not using kw_only, we can pass parameters by position.
+ stream.write_byte(DATACLASS_TYPE)
+ self.encode_type(type(value), stream)
+ values = [getattr(value, field.name) for field in init_fields]
+ self.iterable_coder_impl.encode_to_stream(values, stream, True)
except Exception as e:
raise TypeError(self._deterministic_encoding_error_msg(value)) from e
elif isinstance(value, tuple) and hasattr(type(value), '_fields'):
@@ -616,6 +637,14 @@ class FastPrimitivesCoderImpl(StreamCoderImpl):
msg = cls()
msg.ParseFromString(stream.read_all(True))
return msg
+ elif t == DATACLASS_KW_ONLY_TYPE:
+ cls = self.decode_type(stream)
+ vlen = stream.read_var_int64()
+ fields = {}
+ for _ in range(vlen):
+ field_name = stream.read_all(True).decode('utf-8')
+ fields[field_name] = self.decode_from_stream(stream, True)
+ return cls(**fields)
elif t == DATACLASS_TYPE or t == NAMED_TUPLE_TYPE:
cls = self.decode_type(stream)
return cls(*self.iterable_coder_impl.decode_from_stream(stream, True))
diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py
index 8a57d1e63e2..8f89ab9602c 100644
--- a/sdks/python/apache_beam/coders/coders_test_common.py
+++ b/sdks/python/apache_beam/coders/coders_test_common.py
@@ -113,6 +113,11 @@ if dataclasses is not None:
a: Any
b: int
+ @dataclasses.dataclass(frozen=True, kw_only=True)
+ class FrozenKwOnlyDataClass:
+ c: int
+ d: int
+
@dataclasses.dataclass
class UnFrozenDataClass:
x: int
@@ -303,9 +308,11 @@ class CodersTest(unittest.TestCase):
if dataclasses is not None:
self.check_coder(deterministic_coder, FrozenDataClass(1, 2))
+ self.check_coder(deterministic_coder, FrozenKwOnlyDataClass(c=1, d=2))
with self.assertRaises(TypeError):
self.check_coder(deterministic_coder, UnFrozenDataClass(1, 2))
+
with self.assertRaises(TypeError):
self.check_coder(
deterministic_coder, FrozenDataClass(UnFrozenDataClass(1, 2), 3))
@@ -742,6 +749,7 @@ class CodersTest(unittest.TestCase):
from apache_beam.coders.coders_test_common import DefinesGetState
from apache_beam.coders.coders_test_common import DefinesGetAndSetState
from apache_beam.coders.coders_test_common import FrozenDataClass
+ from apache_beam.coders.coders_test_common import FrozenKwOnlyDataClass
from apache_beam.coders import proto2_coder_test_messages_pb2 as
test_message
@@ -777,6 +785,8 @@ class CodersTest(unittest.TestCase):
test_cases.extend([
("frozen_dataclass", FrozenDataClass(1, 2)),
("frozen_dataclass_list", [FrozenDataClass(1, 2),
FrozenDataClass(3, 4)]),
+ ("frozen_kwonly_dataclass", FrozenKwOnlyDataClass(c=1, d=2)),
+ ("frozen_kwonly_dataclass_list", [FrozenKwOnlyDataClass(c=1, d=2),
FrozenKwOnlyDataClass(c=3, d=4)]),
])
compat_version = {'"'+ compat_version +'"' if compat_version else None}