This is an automated email from the ASF dual-hosted git repository.
tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new e8fab2669a5 Pass update compat through as_deterministic_coder and use
cloudpickle for deterministic special types. (#35725)
e8fab2669a5 is described below
commit e8fab2669a53b1f4ee18b202983bb4a9f8a32fec
Author: claudevdm <[email protected]>
AuthorDate: Sat Aug 30 03:51:26 2025 -0400
Pass update compat through as_deterministic_coder and use cloudpickle for
deterministic special types. (#35725)
* Pass update compat through as_deterministic_coder.
* Coder changes.
* Pass update compat through pipeline options.
* Add tests.
* asd
* Fix test.
* Trigger tests.
* Undo disable tests.
* Fix test.
* Fix tests and lint.
* Refactor.
* Rebase and comments.
* Lint fix.
* Update the base message to include possible lower values for the flag.
* Update sdks/python/apache_beam/coders/coders.py
---------
Co-authored-by: Claude <[email protected]>
Co-authored-by: tvalentyn <[email protected]>
---
...PostCommit_Python_ValidatesRunner_Dataflow.json | 3 +-
sdks/python/apache_beam/coders/coder_impl.pxd | 2 +
sdks/python/apache_beam/coders/coder_impl.py | 61 ++++++++-
sdks/python/apache_beam/coders/coders.py | 72 +++++++++-
.../apache_beam/coders/coders_test_common.py | 146 ++++++++++++---------
sdks/python/apache_beam/coders/typecoders.py | 1 +
.../apache_beam/internal/cloudpickle_pickler.py | 2 +
sdks/python/apache_beam/pipeline.py | 5 +
.../apache_beam/transforms/ptransform_test.py | 67 +++++++++-
sdks/python/apache_beam/transforms/util.py | 20 +--
sdks/python/apache_beam/transforms/util_test.py | 62 +++++++++
11 files changed, 363 insertions(+), 78 deletions(-)
diff --git
a/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json
b/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json
index 19ebbfb9ad9..e3d6056a5de 100644
--- a/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json
+++ b/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json
@@ -1,3 +1,4 @@
{
- "https://github.com/apache/beam/pull/35951": "triggering sideinput test"
+ "comment": "Modify this file in a trivial way to cause this test suite to
run",
+ "modification": 1
}
diff --git a/sdks/python/apache_beam/coders/coder_impl.pxd
b/sdks/python/apache_beam/coders/coder_impl.pxd
index 27cffe7b62d..2db7b963d15 100644
--- a/sdks/python/apache_beam/coders/coder_impl.pxd
+++ b/sdks/python/apache_beam/coders/coder_impl.pxd
@@ -81,6 +81,7 @@ cdef class FastPrimitivesCoderImpl(StreamCoderImpl):
cdef CoderImpl iterable_coder_impl
cdef object requires_deterministic_step_label
cdef bint warn_deterministic_fallback
+ cdef bint force_use_dill
@cython.locals(dict_value=dict, int_value=libc.stdint.int64_t,
unicode_value=unicode)
@@ -88,6 +89,7 @@ cdef class FastPrimitivesCoderImpl(StreamCoderImpl):
@cython.locals(t=int)
cpdef decode_from_stream(self, InputStream stream, bint nested)
cdef encode_special_deterministic(self, value, OutputStream stream)
+ cdef encode_type_2_67_0(self, t, OutputStream stream)
cdef encode_type(self, t, OutputStream stream)
cdef decode_type(self, InputStream stream)
diff --git a/sdks/python/apache_beam/coders/coder_impl.py
b/sdks/python/apache_beam/coders/coder_impl.py
index 807d083d8a3..4f28fb3c916 100644
--- a/sdks/python/apache_beam/coders/coder_impl.py
+++ b/sdks/python/apache_beam/coders/coder_impl.py
@@ -50,7 +50,6 @@ from typing import Set
from typing import Tuple
from typing import Type
-import dill
import numpy as np
from fastavro import parse_schema
from fastavro import schemaless_reader
@@ -58,6 +57,7 @@ from fastavro import schemaless_writer
from apache_beam.coders import observable
from apache_beam.coders.avro_record import AvroRecord
+from apache_beam.internal import cloudpickle_pickler
from apache_beam.typehints.schemas import named_tuple_from_schema
from apache_beam.utils import proto_utils
from apache_beam.utils import windowed_value
@@ -71,6 +71,11 @@ try:
except ImportError:
dataclasses = None # type: ignore
+try:
+ import dill
+except ImportError:
+ dill = None
+
if TYPE_CHECKING:
import proto
from apache_beam.transforms import userstate
@@ -354,14 +359,30 @@ NESTED_STATE_TYPE = 104
_ITERABLE_LIKE_TYPES = set() # type: Set[Type]
+def _verify_dill_compat():
+ base_error = (
+ "This pipeline runs with the pipeline option "
+ "--update_compatibility_version=2.67.0 or earlier. "
+ "When running with this option on SDKs 2.68.0 or "
+ "later, you must ensure dill==0.3.1.1 is installed.")
+ if not dill:
+ raise RuntimeError(base_error + ". Dill is not installed.")
+ if dill.__version__ != "0.3.1.1":
+ raise RuntimeError(base_error + f". Found dill version
'{dill.__version__}'")
+
+
class FastPrimitivesCoderImpl(StreamCoderImpl):
"""For internal use only; no backwards-compatibility guarantees."""
def __init__(
- self, fallback_coder_impl, requires_deterministic_step_label=None):
+ self,
+ fallback_coder_impl,
+ requires_deterministic_step_label=None,
+ force_use_dill=False):
self.fallback_coder_impl = fallback_coder_impl
self.iterable_coder_impl = IterableCoderImpl(self)
self.requires_deterministic_step_label = requires_deterministic_step_label
self.warn_deterministic_fallback = True
+ self.force_use_dill = force_use_dill
@staticmethod
def register_iterable_like_type(t):
@@ -525,10 +546,23 @@ class FastPrimitivesCoderImpl(StreamCoderImpl):
"please provide a type hint for the input of '%s'" %
(value, type(value), self.requires_deterministic_step_label))
- def encode_type(self, t, stream):
+ def encode_type_2_67_0(self, t, stream):
+ """
+ Encode special type with <=2.67.0 compatibility.
+ """
+ _verify_dill_compat()
stream.write(dill.dumps(t), True)
+ def encode_type(self, t, stream):
+ if self.force_use_dill:
+ return self.encode_type_2_67_0(t, stream)
+ bs = cloudpickle_pickler.dumps(
+ t, config=cloudpickle_pickler.NO_DYNAMIC_CLASS_TRACKING_CONFIG)
+ stream.write(bs, True)
+
def decode_type(self, stream):
+ if self.force_use_dill:
+ return _unpickle_type_2_67_0(stream.read_all(True))
return _unpickle_type(stream.read_all(True))
def decode_from_stream(self, stream, nested):
@@ -589,19 +623,35 @@ class FastPrimitivesCoderImpl(StreamCoderImpl):
_unpickled_types = {} # type: Dict[bytes, type]
-def _unpickle_type(bs):
+def _unpickle_type_2_67_0(bs):
+ """
+ Decode special type with <=2.67.0 compatibility.
+ """
t = _unpickled_types.get(bs, None)
if t is None:
+ _verify_dill_compat()
t = _unpickled_types[bs] = dill.loads(bs)
# Fix unpicklable anonymous named tuples for Python 3.6.
if t.__base__ is tuple and hasattr(t, '_fields'):
try:
pickle.loads(pickle.dumps(t))
except pickle.PicklingError:
- t.__reduce__ = lambda self: (_unpickle_named_tuple, (bs, tuple(self)))
+ t.__reduce__ = lambda self: (
+ _unpickle_named_tuple_2_67_0, (bs, tuple(self)))
return t
+def _unpickle_named_tuple_2_67_0(bs, items):
+ return _unpickle_type_2_67_0(bs)(*items)
+
+
+def _unpickle_type(bs):
+ if not _unpickled_types.get(bs, None):
+ _unpickled_types[bs] = cloudpickle_pickler.loads(bs)
+
+ return _unpickled_types[bs]
+
+
def _unpickle_named_tuple(bs, items):
return _unpickle_type(bs)(*items)
@@ -837,6 +887,7 @@ class IntervalWindowCoderImpl(StreamCoderImpl):
if IntervalWindow is None:
from apache_beam.transforms.window import IntervalWindow
# instantiating with None is not part of the public interface
+ # pylint: disable=too-many-function-args
typed_value = IntervalWindow(None, None) # type: ignore[arg-type]
typed_value._end_micros = (
1000 * self._to_normal_time(in_.read_bigendian_uint64()))
diff --git a/sdks/python/apache_beam/coders/coders.py
b/sdks/python/apache_beam/coders/coders.py
index 2691857bf0a..e527185bd57 100644
--- a/sdks/python/apache_beam/coders/coders.py
+++ b/sdks/python/apache_beam/coders/coders.py
@@ -911,6 +911,44 @@ class CloudpickleCoder(_PickleCoderBase):
cloudpickle_pickler.dumps, cloudpickle_pickler.loads)
+class DeterministicFastPrimitivesCoderV2(FastCoder):
+ """Throws runtime errors when encoding non-deterministic values."""
+ def __init__(self, coder, step_label):
+ self._underlying_coder = coder
+ self._step_label = step_label
+
+ def _create_impl(self):
+
+ return coder_impl.FastPrimitivesCoderImpl(
+ self._underlying_coder.get_impl(),
+ requires_deterministic_step_label=self._step_label,
+ force_use_dill=False)
+
+ def is_deterministic(self):
+ # type: () -> bool
+ return True
+
+ def is_kv_coder(self):
+ # type: () -> bool
+ return True
+
+ def key_coder(self):
+ return self
+
+ def value_coder(self):
+ return self
+
+ def to_type_hint(self):
+ return Any
+
+ def to_runner_api_parameter(self, context):
+ # type: (Optional[PipelineContext]) -> Tuple[str, Any, Sequence[Coder]]
+ return (
+ python_urns.PICKLED_CODER,
+ google.protobuf.wrappers_pb2.BytesValue(value=serialize_coder(self)),
+ ())
+
+
class DeterministicFastPrimitivesCoder(FastCoder):
"""Throws runtime errors when encoding non-deterministic values."""
def __init__(self, coder, step_label):
@@ -920,7 +958,8 @@ class DeterministicFastPrimitivesCoder(FastCoder):
def _create_impl(self):
return coder_impl.FastPrimitivesCoderImpl(
self._underlying_coder.get_impl(),
- requires_deterministic_step_label=self._step_label)
+ requires_deterministic_step_label=self._step_label,
+ force_use_dill=True)
def is_deterministic(self):
# type: () -> bool
@@ -940,6 +979,34 @@ class DeterministicFastPrimitivesCoder(FastCoder):
return Any
+def _should_force_use_dill():
+ from apache_beam.coders import typecoders
+ from apache_beam.transforms.util import is_v1_prior_to_v2
+ update_compat_version = typecoders.registry.update_compatibility_version
+
+ if not update_compat_version:
+ return False
+
+ if not is_v1_prior_to_v2(v1=update_compat_version, v2="2.68.0"):
+ return False
+
+ try:
+ import dill
+ assert dill.__version__ == "0.3.1.1"
+ except Exception as e:
+ raise RuntimeError("This pipeline runs with the pipeline option " \
+ "--update_compatibility_version=2.67.0 or earlier. When running with " \
+ "this option on SDKs 2.68.0 or later, you must ensure dill==0.3.1.1 " \
+ f"is installed. Error {e}")
+ return True
+
+
+def _update_compatible_deterministic_fast_primitives_coder(coder, step_label):
+ if _should_force_use_dill():
+ return DeterministicFastPrimitivesCoder(coder, step_label)
+ return DeterministicFastPrimitivesCoderV2(coder, step_label)
+
+
class FastPrimitivesCoder(FastCoder):
"""Encodes simple primitives (e.g. str, int) efficiently.
@@ -960,7 +1027,8 @@ class FastPrimitivesCoder(FastCoder):
if self.is_deterministic():
return self
else:
- return DeterministicFastPrimitivesCoder(self, step_label)
+ return _update_compatible_deterministic_fast_primitives_coder(
+ self, step_label)
def to_type_hint(self):
return Any
diff --git a/sdks/python/apache_beam/coders/coders_test_common.py
b/sdks/python/apache_beam/coders/coders_test_common.py
index dbd0a301bb0..587e5d87522 100644
--- a/sdks/python/apache_beam/coders/coders_test_common.py
+++ b/sdks/python/apache_beam/coders/coders_test_common.py
@@ -34,6 +34,8 @@ from typing import List
from typing import NamedTuple
import pytest
+from parameterized import param
+from parameterized import parameterized
from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
from apache_beam.coders import coders
@@ -58,6 +60,7 @@ except ImportError:
dataclasses = None # type: ignore
MyNamedTuple = collections.namedtuple('A', ['x', 'y']) # type:
ignore[name-match]
+AnotherNamedTuple = collections.namedtuple('AnotherNamedTuple', ['x', 'y'])
MyTypedNamedTuple = NamedTuple('MyTypedNamedTuple', [('f1', int), ('f2', str)])
@@ -175,6 +178,9 @@ class CodersTest(unittest.TestCase):
assert not standard - cls.seen, str(standard - cls.seen)
assert not cls.seen_nested - standard, str(cls.seen_nested - standard)
+ def tearDown(self):
+ typecoders.registry.update_compatibility_version = None
+
@classmethod
def _observe(cls, coder):
cls.seen.add(type(coder))
@@ -230,9 +236,15 @@ class CodersTest(unittest.TestCase):
coder = coders._MemoizingPickleCoder()
self.check_coder(coder, *self.test_values)
- def test_deterministic_coder(self):
+ @parameterized.expand([
+ param(compat_version=None),
+ param(compat_version="2.67.0"),
+ ])
+ def test_deterministic_coder(self, compat_version):
+ typecoders.registry.update_compatibility_version = compat_version
coder = coders.FastPrimitivesCoder()
- deterministic_coder = coders.DeterministicFastPrimitivesCoder(coder,
'step')
+ deterministic_coder = coder.as_deterministic_coder(step_label="step")
+
self.check_coder(deterministic_coder, *self.test_values_deterministic)
for v in self.test_values_deterministic:
self.check_coder(coders.TupleCoder((deterministic_coder, )), (v, ))
@@ -254,8 +266,16 @@ class CodersTest(unittest.TestCase):
self.check_coder(deterministic_coder,
test_message.MessageA(field1='value'))
+ # Skip this test when using cloudpickle. Dill monkey patches the __reduce__
+ # method for anonymous named tuples (MyNamedTuple) which is not pickleable.
+ # Since the test is parameterized the type gets clobbered.
+ if compat_version:
+ self.check_coder(
+ deterministic_coder, [MyNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')])
+
self.check_coder(
- deterministic_coder, [MyNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')])
+ deterministic_coder,
+ [AnotherNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')])
if dataclasses is not None:
self.check_coder(deterministic_coder, FrozenDataClass(1, 2))
@@ -265,9 +285,10 @@ class CodersTest(unittest.TestCase):
with self.assertRaises(TypeError):
self.check_coder(
deterministic_coder, FrozenDataClass(UnFrozenDataClass(1, 2), 3))
- with self.assertRaises(TypeError):
- self.check_coder(
- deterministic_coder, MyNamedTuple(UnFrozenDataClass(1, 2), 3))
+ with self.assertRaises(TypeError):
+ self.check_coder(
+ deterministic_coder,
+ AnotherNamedTuple(UnFrozenDataClass(1, 2), 3))
self.check_coder(deterministic_coder, list(MyEnum))
self.check_coder(deterministic_coder, list(MyIntEnum))
@@ -286,6 +307,29 @@ class CodersTest(unittest.TestCase):
1: 'x', 'y': 2
}))
+ @parameterized.expand([
+ param(compat_version=None),
+ param(compat_version="2.67.0"),
+ ])
+ def test_deterministic_map_coder_is_update_compatible(self, compat_version):
+ typecoders.registry.update_compatibility_version = compat_version
+ values = [{
+ MyTypedNamedTuple(i, 'a'): MyTypedNamedTuple('a', i)
+ for i in range(10)
+ }]
+
+ coder = coders.MapCoder(
+ coders.FastPrimitivesCoder(), coders.FastPrimitivesCoder())
+
+ deterministic_coder = coder.as_deterministic_coder(step_label="step")
+
+ assert isinstance(
+ deterministic_coder._key_coder,
+ coders.DeterministicFastPrimitivesCoderV2
+ if not compat_version else coders.DeterministicFastPrimitivesCoder)
+
+ self.check_coder(deterministic_coder, *values)
+
def test_dill_coder(self):
cell_value = (lambda x: lambda: x)(0).__closure__[0]
self.check_coder(coders.DillCoder(), 'a', 1, cell_value)
@@ -610,15 +654,21 @@ class CodersTest(unittest.TestCase):
1, (window.IntervalWindow(11, 21), ),
PaneInfo(True, False, 1, 2, 3))))
- def test_cross_process_encoding_of_special_types_is_deterministic(self):
+ @parameterized.expand([
+ param(compat_version=None),
+ param(compat_version="2.67.0"),
+ ])
+ def test_cross_process_encoding_of_special_types_is_deterministic(
+ self, compat_version):
"""Test cross-process determinism for all special deterministic types"""
if sys.executable is None:
self.skipTest('No Python interpreter found')
+ typecoders.registry.update_compatibility_version = compat_version
# pylint: disable=line-too-long
script = textwrap.dedent(
- '''\
+ f'''\
import pickle
import sys
import collections
@@ -626,13 +676,19 @@ class CodersTest(unittest.TestCase):
import logging
from apache_beam.coders import coders
- from apache_beam.coders import proto2_coder_test_messages_pb2 as
test_message
- from typing import NamedTuple
+ from apache_beam.coders import typecoders
+ from apache_beam.coders.coders_test_common import MyNamedTuple
+ from apache_beam.coders.coders_test_common import MyTypedNamedTuple
+ from apache_beam.coders.coders_test_common import MyEnum
+ from apache_beam.coders.coders_test_common import MyIntEnum
+ from apache_beam.coders.coders_test_common import MyIntFlag
+ from apache_beam.coders.coders_test_common import MyFlag
+ from apache_beam.coders.coders_test_common import DefinesGetState
+ from apache_beam.coders.coders_test_common import DefinesGetAndSetState
+ from apache_beam.coders.coders_test_common import FrozenDataClass
- try:
- import dataclasses
- except ImportError:
- dataclasses = None
+
+ from apache_beam.coders import proto2_coder_test_messages_pb2 as
test_message
logging.basicConfig(
level=logging.INFO,
@@ -640,38 +696,6 @@ class CodersTest(unittest.TestCase):
stream=sys.stderr,
force=True
)
-
- # Define all the special types that encode_special_deterministic
handles
- MyNamedTuple = collections.namedtuple('A', ['x', 'y'])
- MyTypedNamedTuple = NamedTuple('MyTypedNamedTuple', [('f1', int),
('f2', str)])
-
- class MyEnum(enum.Enum):
- E1 = 5
- E2 = enum.auto()
- E3 = 'abc'
-
- MyIntEnum = enum.IntEnum('MyIntEnum', 'I1 I2 I3')
- MyIntFlag = enum.IntFlag('MyIntFlag', 'F1 F2 F3')
- MyFlag = enum.Flag('MyFlag', 'F1 F2 F3')
-
- if dataclasses is not None:
- @dataclasses.dataclass(frozen=True)
- class FrozenDataClass:
- a: int
- b: int
-
- class DefinesGetAndSetState:
- def __init__(self, value):
- self.value = value
-
- def __getstate__(self):
- return self.value
-
- def __setstate__(self, value):
- self.value = value
-
- def __eq__(self, other):
- return type(other) is type(self) and other.value == self.value
# Test cases for all special deterministic types
# NOTE: When this script runs in a subprocess the module is considered
@@ -683,26 +707,28 @@ class CodersTest(unittest.TestCase):
("named_tuple_simple", MyNamedTuple(1, 2)),
("typed_named_tuple", MyTypedNamedTuple(1, 'a')),
("named_tuple_list", [MyNamedTuple(1, 2), MyTypedNamedTuple(1,
'a')]),
- # ("enum_single", MyEnum.E1),
- # ("enum_list", list(MyEnum)),
- # ("int_enum_list", list(MyIntEnum)),
- # ("int_flag_list", list(MyIntFlag)),
- # ("flag_list", list(MyFlag)),
+ ("enum_single", MyEnum.E1),
+ ("enum_list", list(MyEnum)),
+ ("int_enum_list", list(MyIntEnum)),
+ ("int_flag_list", list(MyIntFlag)),
+ ("flag_list", list(MyFlag)),
("getstate_setstate_simple", DefinesGetAndSetState(1)),
("getstate_setstate_complex", DefinesGetAndSetState((1, 2, 3))),
("getstate_setstate_list", [DefinesGetAndSetState(1),
DefinesGetAndSetState((1, 2, 3))]),
]
- if dataclasses is not None:
- test_cases.extend([
- ("frozen_dataclass", FrozenDataClass(1, 2)),
- ("frozen_dataclass_list", [FrozenDataClass(1, 2),
FrozenDataClass(3, 4)]),
- ])
+
+ test_cases.extend([
+ ("frozen_dataclass", FrozenDataClass(1, 2)),
+ ("frozen_dataclass_list", [FrozenDataClass(1, 2),
FrozenDataClass(3, 4)]),
+ ])
+ compat_version = {'"'+ compat_version +'"' if compat_version else None}
+ typecoders.registry.update_compatibility_version = compat_version
coder = coders.FastPrimitivesCoder()
- deterministic_coder = coders.DeterministicFastPrimitivesCoder(coder,
'step')
+ deterministic_coder = coder.as_deterministic_coder("step")
- results = {}
+ results = dict()
for test_name, value in test_cases:
try:
encoded = deterministic_coder.encode(value)
@@ -730,7 +756,7 @@ class CodersTest(unittest.TestCase):
results2 = run_subprocess()
coder = coders.FastPrimitivesCoder()
- deterministic_coder = coders.DeterministicFastPrimitivesCoder(coder,
'step')
+ deterministic_coder = coder.as_deterministic_coder("step")
for test_name in results1:
data1 = results1[test_name]
@@ -861,7 +887,7 @@ class CodersTest(unittest.TestCase):
{
i: str(i)
for i in range(5000)
- }
+ },
]
map_coder = coders.MapCoder(coders.VarIntCoder(), coders.StrUtf8Coder())
self.check_coder(map_coder, *values)
diff --git a/sdks/python/apache_beam/coders/typecoders.py
b/sdks/python/apache_beam/coders/typecoders.py
index 19300c67559..779c65dc772 100644
--- a/sdks/python/apache_beam/coders/typecoders.py
+++ b/sdks/python/apache_beam/coders/typecoders.py
@@ -84,6 +84,7 @@ class CoderRegistry(object):
self._coders: Dict[Any, Type[coders.Coder]] = {}
self.custom_types: List[Any] = []
self.register_standard_coders(fallback_coder)
+ self.update_compatibility_version = None
def register_standard_coders(self, fallback_coder):
"""Register coders for all basic and composite types."""
diff --git a/sdks/python/apache_beam/internal/cloudpickle_pickler.py
b/sdks/python/apache_beam/internal/cloudpickle_pickler.py
index 63038e770f2..e55818bfb22 100644
--- a/sdks/python/apache_beam/internal/cloudpickle_pickler.py
+++ b/sdks/python/apache_beam/internal/cloudpickle_pickler.py
@@ -39,6 +39,8 @@ from apache_beam.internal.cloudpickle import cloudpickle
DEFAULT_CONFIG = cloudpickle.CloudPickleConfig(
skip_reset_dynamic_type_state=True)
+NO_DYNAMIC_CLASS_TRACKING_CONFIG = cloudpickle.CloudPickleConfig(
+ id_generator=None, skip_reset_dynamic_type_state=True)
try:
from absl import flags
diff --git a/sdks/python/apache_beam/pipeline.py
b/sdks/python/apache_beam/pipeline.py
index 83a0bee8145..0ed5a435e78 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -76,6 +76,7 @@ from typing import Union
from google.protobuf import message
from apache_beam import pvalue
+from apache_beam.coders import typecoders
from apache_beam.internal import pickler
from apache_beam.io.filesystems import FileSystems
from apache_beam.options.pipeline_options import CrossLanguageOptions
@@ -83,6 +84,7 @@ from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.options.pipeline_options import StandardOptions
+from apache_beam.options.pipeline_options import StreamingOptions
from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.options.pipeline_options_validator import
PipelineOptionsValidator
from apache_beam.portability import common_urns
@@ -229,6 +231,9 @@ class Pipeline(HasDisplayData):
raise ValueError(
'Pipeline has validations errors: \n' + '\n'.join(errors))
+ typecoders.registry.update_compatibility_version = self._options.view_as(
+ StreamingOptions).update_compatibility_version
+
# set default experiments for portable runners
# (needs to occur prior to pipeline construction)
if runner.is_fnapi_compatible():
diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py
b/sdks/python/apache_beam/transforms/ptransform_test.py
index e1c84c7dc9a..39d216c4b3b 100644
--- a/sdks/python/apache_beam/transforms/ptransform_test.py
+++ b/sdks/python/apache_beam/transforms/ptransform_test.py
@@ -34,16 +34,20 @@ from unittest.mock import patch
import hamcrest as hc
import numpy as np
import pytest
+from parameterized import param
+from parameterized import parameterized
from parameterized import parameterized_class
import apache_beam as beam
import apache_beam.transforms.combiners as combine
from apache_beam import pvalue
from apache_beam import typehints
+from apache_beam.coders import coders_test_common
from apache_beam.io.iobase import Read
from apache_beam.metrics import Metrics
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import StreamingOptions
from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.portability import common_urns
from apache_beam.testing.test_pipeline import TestPipeline
@@ -572,7 +576,7 @@ class PTransformTest(unittest.TestCase):
def decode(self, encoded):
return MyObject(pickle.loads(encoded)[0])
- def as_deterministic_coder(self, *args):
+ def as_deterministic_coder(self, *args, **kwargs):
return MydeterministicObjectCoder()
def to_type_hint(self):
@@ -719,6 +723,67 @@ class PTransformTest(unittest.TestCase):
result = (pcoll, ) | 'Single Flatten' >> beam.Flatten()
assert_that(result, equal_to(input))
+ @parameterized.expand([
+ param(compat_version=None),
+ param(compat_version="2.66.0"),
+ ])
+ @pytest.mark.it_validatesrunner
+ def test_group_by_key_importable_special_types(self, compat_version):
+ def generate(_):
+ for _ in range(100):
+ yield (coders_test_common.MyTypedNamedTuple(1, 'a'), 1)
+
+ pipeline = TestPipeline(is_integration_test=True)
+ if compat_version:
+ pipeline.get_pipeline_options().view_as(
+ StreamingOptions).update_compatibility_version = compat_version
+ with pipeline as p:
+ result = (
+ p
+ | 'Create' >> beam.Create([i for i in range(100)])
+ | 'Generate' >> beam.ParDo(generate)
+ | 'Reshuffle' >> beam.Reshuffle()
+ | 'GBK' >> beam.GroupByKey())
+ assert_that(
+ result,
+ equal_to([(
+ coders_test_common.MyTypedNamedTuple(1, 'a'),
+ [1 for i in range(10000)])]))
+
+ @pytest.mark.it_validatesrunner
+ def test_group_by_key_dynamic_special_types(self):
+ def create_dynamic_named_tuple():
+ return collections.namedtuple('DynamicNamedTuple', ['x', 'y'])
+
+ dynamic_named_tuple = create_dynamic_named_tuple()
+
+ # Standard FastPrimitivesCoder falls back to python PickleCoder which
+ # cannot serialize dynamic types or types defined in __main__. Use
+ # CloudPickleCoder as fallback coder for non-deterministic steps.
+ class FastPrimitivesCoderV2(beam.coders.FastPrimitivesCoder):
+ def __init__(self):
+ super().__init__(fallback_coder=beam.coders.CloudpickleCoder())
+
+ beam.coders.typecoders.registry.register_coder(
+ dynamic_named_tuple, FastPrimitivesCoderV2)
+
+ def generate(_):
+ for _ in range(100):
+ yield (dynamic_named_tuple(1, 'a'), 1)
+
+ pipeline = TestPipeline(is_integration_test=True)
+
+ with pipeline as p:
+ result = (
+ p
+ | 'Create' >> beam.Create([i for i in range(100)])
+ | 'Reshuffle' >> beam.Reshuffle()
+ | 'Generate' >> beam.ParDo(generate).with_output_types(
+ tuple[dynamic_named_tuple, int])
+ | 'GBK' >> beam.GroupByKey()
+ | 'Count Elements' >> beam.Map(lambda x: len(x[1])))
+ assert_that(result, equal_to([10000]))
+
# TODO(https://github.com/apache/beam/issues/20067): Does not work in
# streaming mode on Dataflow.
@pytest.mark.no_sickbay_streaming
diff --git a/sdks/python/apache_beam/transforms/util.py
b/sdks/python/apache_beam/transforms/util.py
index c60ded52df2..2df66aadcc6 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -932,6 +932,15 @@ class _IdentityWindowFn(NonMergingWindowFn):
return self._window_coder
+def is_v1_prior_to_v2(*, v1, v2):
+ if v1 is None:
+ return False
+
+ v1_parts = (v1.split('.') + ['0', '0', '0'])[:3]
+ v2_parts = (v2.split('.') + ['0', '0', '0'])[:3]
+ return tuple(map(int, v1_parts)) < tuple(map(int, v2_parts))
+
+
def is_compat_version_prior_to(options, breaking_change_version):
# This function is used in a branch statement to determine whether we should
# keep the old behavior prior to a breaking change or use the new behavior.
@@ -940,15 +949,8 @@ def is_compat_version_prior_to(options,
breaking_change_version):
update_compatibility_version = options.view_as(
pipeline_options.StreamingOptions).update_compatibility_version
- if update_compatibility_version is None:
- return False
-
- compat_version = tuple(map(int,
update_compatibility_version.split('.')[0:3]))
- change_version = tuple(map(int, breaking_change_version.split('.')[0:3]))
- for i in range(min(len(compat_version), len(change_version))):
- if compat_version[i] < change_version[i]:
- return True
- return False
+ return is_v1_prior_to_v2(
+ v1=update_compatibility_version, v2=breaking_change_version)
def reify_metadata_default_window(
diff --git a/sdks/python/apache_beam/transforms/util_test.py
b/sdks/python/apache_beam/transforms/util_test.py
index ad185ac6a6d..b365d9b2209 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -2193,6 +2193,68 @@ class WaitOnTest(unittest.TestCase):
label='result')
+class CompatCheckTest(unittest.TestCase):
+ def test_is_v1_prior_to_v2(self):
+ test_cases = [
+ # Basic comparison cases
+ ("1.0.0", "2.0.0", True), # v1 < v2 in major
+ ("2.0.0", "1.0.0", False), # v1 > v2 in major
+ ("1.1.0", "1.2.0", True), # v1 < v2 in minor
+ ("1.2.0", "1.1.0", False), # v1 > v2 in minor
+ ("1.0.1", "1.0.2", True), # v1 < v2 in patch
+ ("1.0.2", "1.0.1", False), # v1 > v2 in patch
+
+ # Equal versions
+ ("1.0.0", "1.0.0", False), # Identical
+ ("0.0.0", "0.0.0", False), # Both zero
+
+ # Different lengths - shorter vs longer
+ ("1.0", "1.0.0", False), # Should be equal (1.0 = 1.0.0)
+ ("1.0", "1.0.1", True), # 1.0.0 < 1.0.1
+ ("1.2", "1.2.0", False), # Should be equal (1.2 = 1.2.0)
+ ("1.2", "1.2.3", True), # 1.2.0 < 1.2.3
+ ("2", "2.0.0", False), # Should be equal (2 = 2.0.0)
+ ("2", "2.0.1", True), # 2.0.0 < 2.0.1
+ ("1", "2.0", True), # 1.0.0 < 2.0.0
+
+ # Different lengths - longer vs shorter
+ ("1.0.0", "1.0", False), # Should be equal
+ ("1.0.1", "1.0", False), # 1.0.1 > 1.0.0
+ ("1.2.0", "1.2", False), # Should be equal
+ ("1.2.3", "1.2", False), # 1.2.3 > 1.2.0
+ ("2.0.0", "2", False), # Should be equal
+ ("2.0.1", "2", False), # 2.0.1 > 2.0.0
+ ("2.0", "1", False), # 2.0.0 > 1.0.0
+
+ # Mixed length comparisons
+ ("1.0", "2.0.0", True), # 1.0.0 < 2.0.0
+ ("2.0", "1.0.0", False), # 2.0.0 > 1.0.0
+ ("1", "1.0.1", True), # 1.0.0 < 1.0.1
+ ("1.1", "1.0.9", False), # 1.1.0 > 1.0.9
+
+ # Large numbers
+ ("1.9.9", "2.0.0", True), # 1.9.9 < 2.0.0
+ ("10.0.0", "9.9.9", False), # 10.0.0 > 9.9.9
+ ("1.10.0", "1.9.0", False), # 1.10.0 > 1.9.0
+ ("1.2.10", "1.2.9", False), # 1.2.10 > 1.2.9
+
+ # Sequential versions
+ ("1.0.0", "1.0.1", True),
+ ("1.0.1", "1.0.2", True),
+ ("1.0.9", "1.1.0", True),
+ ("1.9.9", "2.0.0", True),
+
+ # Null/None cases
+ (None, "1.0.0", False), # v1 is None
+ ]
+
+ for v1, v2, expected in test_cases:
+ self.assertEqual(
+ util.is_v1_prior_to_v2(v1=v1, v2=v2),
+ expected,
+ msg=f"Failed {v1} < {v2} == {expected}")
+
+
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()