This is an automated email from the ASF dual-hosted git repository.

tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new e8fab2669a5 Pass update compat through as_deterministic_coder and use cloudpickle for deterministic special types. (#35725)
e8fab2669a5 is described below

commit e8fab2669a53b1f4ee18b202983bb4a9f8a32fec
Author: claudevdm <[email protected]>
AuthorDate: Sat Aug 30 03:51:26 2025 -0400

    Pass update compat through as_deterministic_coder and use cloudpickle for deterministic special types. (#35725)
    
    * Pass update compat through as_deterministic_coder.
    
    * Coder changes.
    
    * Pass update compat through pipeline options.
    
    * Add tests.
    
    * Intermediate work-in-progress commit (no descriptive message given).
    
    * Fix test.
    
    * Trigger tests.
    
    * Re-enable previously disabled tests.
    
    * Fix test.
    
    * Fix tests and lint.
    
    * Refactor.
    
    * Rebase and comments.
    
    * Lint fix.
    
    * Update the base message to include possible lower values for the flag.
    
    * Update sdks/python/apache_beam/coders/coders.py
    
    ---------
    
    Co-authored-by: Claude <[email protected]>
    Co-authored-by: tvalentyn <[email protected]>
---
 ...PostCommit_Python_ValidatesRunner_Dataflow.json |   3 +-
 sdks/python/apache_beam/coders/coder_impl.pxd      |   2 +
 sdks/python/apache_beam/coders/coder_impl.py       |  61 ++++++++-
 sdks/python/apache_beam/coders/coders.py           |  72 +++++++++-
 .../apache_beam/coders/coders_test_common.py       | 146 ++++++++++++---------
 sdks/python/apache_beam/coders/typecoders.py       |   1 +
 .../apache_beam/internal/cloudpickle_pickler.py    |   2 +
 sdks/python/apache_beam/pipeline.py                |   5 +
 .../apache_beam/transforms/ptransform_test.py      |  67 +++++++++-
 sdks/python/apache_beam/transforms/util.py         |  20 +--
 sdks/python/apache_beam/transforms/util_test.py    |  62 +++++++++
 11 files changed, 363 insertions(+), 78 deletions(-)

diff --git 
a/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json 
b/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json
index 19ebbfb9ad9..e3d6056a5de 100644
--- a/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json
+++ b/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json
@@ -1,3 +1,4 @@
 {
-  "https://github.com/apache/beam/pull/35951": "triggering sideinput test"
+  "comment": "Modify this file in a trivial way to cause this test suite to 
run",
+  "modification": 1
 }
diff --git a/sdks/python/apache_beam/coders/coder_impl.pxd 
b/sdks/python/apache_beam/coders/coder_impl.pxd
index 27cffe7b62d..2db7b963d15 100644
--- a/sdks/python/apache_beam/coders/coder_impl.pxd
+++ b/sdks/python/apache_beam/coders/coder_impl.pxd
@@ -81,6 +81,7 @@ cdef class FastPrimitivesCoderImpl(StreamCoderImpl):
   cdef CoderImpl iterable_coder_impl
   cdef object requires_deterministic_step_label
   cdef bint warn_deterministic_fallback
+  cdef bint force_use_dill
 
   @cython.locals(dict_value=dict, int_value=libc.stdint.int64_t,
                  unicode_value=unicode)
@@ -88,6 +89,7 @@ cdef class FastPrimitivesCoderImpl(StreamCoderImpl):
   @cython.locals(t=int)
   cpdef decode_from_stream(self, InputStream stream, bint nested)
   cdef encode_special_deterministic(self, value, OutputStream stream)
+  cdef encode_type_2_67_0(self, t, OutputStream stream)
   cdef encode_type(self, t, OutputStream stream)
   cdef decode_type(self, InputStream stream)
 
diff --git a/sdks/python/apache_beam/coders/coder_impl.py 
b/sdks/python/apache_beam/coders/coder_impl.py
index 807d083d8a3..4f28fb3c916 100644
--- a/sdks/python/apache_beam/coders/coder_impl.py
+++ b/sdks/python/apache_beam/coders/coder_impl.py
@@ -50,7 +50,6 @@ from typing import Set
 from typing import Tuple
 from typing import Type
 
-import dill
 import numpy as np
 from fastavro import parse_schema
 from fastavro import schemaless_reader
@@ -58,6 +57,7 @@ from fastavro import schemaless_writer
 
 from apache_beam.coders import observable
 from apache_beam.coders.avro_record import AvroRecord
+from apache_beam.internal import cloudpickle_pickler
 from apache_beam.typehints.schemas import named_tuple_from_schema
 from apache_beam.utils import proto_utils
 from apache_beam.utils import windowed_value
@@ -71,6 +71,11 @@ try:
 except ImportError:
   dataclasses = None  # type: ignore
 
+try:
+  import dill
+except ImportError:
+  dill = None
+
 if TYPE_CHECKING:
   import proto
   from apache_beam.transforms import userstate
@@ -354,14 +359,30 @@ NESTED_STATE_TYPE = 104
 _ITERABLE_LIKE_TYPES = set()  # type: Set[Type]
 
 
+def _verify_dill_compat():
+  base_error = (
+      "This pipeline runs with the pipeline option "
+      "--update_compatibility_version=2.67.0 or earlier. "
+      "When running with this option on SDKs 2.68.0 or "
+      "later, you must ensure dill==0.3.1.1 is installed.")
+  if not dill:
+    raise RuntimeError(base_error + ". Dill is not installed.")
+  if dill.__version__ != "0.3.1.1":
+    raise RuntimeError(base_error + f". Found dill version 
'{dill.__version__}")
+
+
 class FastPrimitivesCoderImpl(StreamCoderImpl):
   """For internal use only; no backwards-compatibility guarantees."""
   def __init__(
-      self, fallback_coder_impl, requires_deterministic_step_label=None):
+      self,
+      fallback_coder_impl,
+      requires_deterministic_step_label=None,
+      force_use_dill=False):
     self.fallback_coder_impl = fallback_coder_impl
     self.iterable_coder_impl = IterableCoderImpl(self)
     self.requires_deterministic_step_label = requires_deterministic_step_label
     self.warn_deterministic_fallback = True
+    self.force_use_dill = force_use_dill
 
   @staticmethod
   def register_iterable_like_type(t):
@@ -525,10 +546,23 @@ class FastPrimitivesCoderImpl(StreamCoderImpl):
         "please provide a type hint for the input of '%s'" %
         (value, type(value), self.requires_deterministic_step_label))
 
-  def encode_type(self, t, stream):
+  def encode_type_2_67_0(self, t, stream):
+    """
+    Encode special type with <=2.67.0 compatibility.
+    """
+    _verify_dill_compat()
     stream.write(dill.dumps(t), True)
 
+  def encode_type(self, t, stream):
+    if self.force_use_dill:
+      return self.encode_type_2_67_0(t, stream)
+    bs = cloudpickle_pickler.dumps(
+        t, config=cloudpickle_pickler.NO_DYNAMIC_CLASS_TRACKING_CONFIG)
+    stream.write(bs, True)
+
   def decode_type(self, stream):
+    if self.force_use_dill:
+      return _unpickle_type_2_67_0(stream.read_all(True))
     return _unpickle_type(stream.read_all(True))
 
   def decode_from_stream(self, stream, nested):
@@ -589,19 +623,35 @@ class FastPrimitivesCoderImpl(StreamCoderImpl):
 _unpickled_types = {}  # type: Dict[bytes, type]
 
 
-def _unpickle_type(bs):
+def _unpickle_type_2_67_0(bs):
+  """
+  Decode special type with <=2.67.0 compatibility.
+  """
   t = _unpickled_types.get(bs, None)
   if t is None:
+    _verify_dill_compat()
     t = _unpickled_types[bs] = dill.loads(bs)
     # Fix unpicklable anonymous named tuples for Python 3.6.
     if t.__base__ is tuple and hasattr(t, '_fields'):
       try:
         pickle.loads(pickle.dumps(t))
       except pickle.PicklingError:
-        t.__reduce__ = lambda self: (_unpickle_named_tuple, (bs, tuple(self)))
+        t.__reduce__ = lambda self: (
+            _unpickle_named_tuple_2_67_0, (bs, tuple(self)))
   return t
 
 
+def _unpickle_named_tuple_2_67_0(bs, items):
+  return _unpickle_type_2_67_0(bs)(*items)
+
+
+def _unpickle_type(bs):
+  if not _unpickled_types.get(bs, None):
+    _unpickled_types[bs] = cloudpickle_pickler.loads(bs)
+
+  return _unpickled_types[bs]
+
+
 def _unpickle_named_tuple(bs, items):
   return _unpickle_type(bs)(*items)
 
@@ -837,6 +887,7 @@ class IntervalWindowCoderImpl(StreamCoderImpl):
       if IntervalWindow is None:
         from apache_beam.transforms.window import IntervalWindow
     # instantiating with None is not part of the public interface
+    # pylint: disable=too-many-function-args
     typed_value = IntervalWindow(None, None)  # type: ignore[arg-type]
     typed_value._end_micros = (
         1000 * self._to_normal_time(in_.read_bigendian_uint64()))
diff --git a/sdks/python/apache_beam/coders/coders.py 
b/sdks/python/apache_beam/coders/coders.py
index 2691857bf0a..e527185bd57 100644
--- a/sdks/python/apache_beam/coders/coders.py
+++ b/sdks/python/apache_beam/coders/coders.py
@@ -911,6 +911,44 @@ class CloudpickleCoder(_PickleCoderBase):
         cloudpickle_pickler.dumps, cloudpickle_pickler.loads)
 
 
+class DeterministicFastPrimitivesCoderV2(FastCoder):
+  """Throws runtime errors when encoding non-deterministic values."""
+  def __init__(self, coder, step_label):
+    self._underlying_coder = coder
+    self._step_label = step_label
+
+  def _create_impl(self):
+
+    return coder_impl.FastPrimitivesCoderImpl(
+        self._underlying_coder.get_impl(),
+        requires_deterministic_step_label=self._step_label,
+        force_use_dill=False)
+
+  def is_deterministic(self):
+    # type: () -> bool
+    return True
+
+  def is_kv_coder(self):
+    # type: () -> bool
+    return True
+
+  def key_coder(self):
+    return self
+
+  def value_coder(self):
+    return self
+
+  def to_type_hint(self):
+    return Any
+
+  def to_runner_api_parameter(self, context):
+    # type: (Optional[PipelineContext]) -> Tuple[str, Any, Sequence[Coder]]
+    return (
+        python_urns.PICKLED_CODER,
+        google.protobuf.wrappers_pb2.BytesValue(value=serialize_coder(self)),
+        ())
+
+
 class DeterministicFastPrimitivesCoder(FastCoder):
   """Throws runtime errors when encoding non-deterministic values."""
   def __init__(self, coder, step_label):
@@ -920,7 +958,8 @@ class DeterministicFastPrimitivesCoder(FastCoder):
   def _create_impl(self):
     return coder_impl.FastPrimitivesCoderImpl(
         self._underlying_coder.get_impl(),
-        requires_deterministic_step_label=self._step_label)
+        requires_deterministic_step_label=self._step_label,
+        force_use_dill=True)
 
   def is_deterministic(self):
     # type: () -> bool
@@ -940,6 +979,34 @@ class DeterministicFastPrimitivesCoder(FastCoder):
     return Any
 
 
+def _should_force_use_dill():
+  from apache_beam.coders import typecoders
+  from apache_beam.transforms.util import is_v1_prior_to_v2
+  update_compat_version = typecoders.registry.update_compatibility_version
+
+  if not update_compat_version:
+    return False
+
+  if not is_v1_prior_to_v2(v1=update_compat_version, v2="2.68.0"):
+    return False
+
+  try:
+    import dill
+    assert dill.__version__ == "0.3.1.1"
+  except Exception as e:
+    raise RuntimeError("This pipeline runs with the pipeline option " \
+    "--update_compatibility_version=2.67.0 or earlier. When running with " \
+    "this option on SDKs 2.68.0 or later, you must ensure dill==0.3.1.1 " \
+    f"is installed. Error {e}")
+  return True
+
+
+def _update_compatible_deterministic_fast_primitives_coder(coder, step_label):
+  if _should_force_use_dill():
+    return DeterministicFastPrimitivesCoder(coder, step_label)
+  return DeterministicFastPrimitivesCoderV2(coder, step_label)
+
+
 class FastPrimitivesCoder(FastCoder):
   """Encodes simple primitives (e.g. str, int) efficiently.
 
@@ -960,7 +1027,8 @@ class FastPrimitivesCoder(FastCoder):
     if self.is_deterministic():
       return self
     else:
-      return DeterministicFastPrimitivesCoder(self, step_label)
+      return _update_compatible_deterministic_fast_primitives_coder(
+          self, step_label)
 
   def to_type_hint(self):
     return Any
diff --git a/sdks/python/apache_beam/coders/coders_test_common.py 
b/sdks/python/apache_beam/coders/coders_test_common.py
index dbd0a301bb0..587e5d87522 100644
--- a/sdks/python/apache_beam/coders/coders_test_common.py
+++ b/sdks/python/apache_beam/coders/coders_test_common.py
@@ -34,6 +34,8 @@ from typing import List
 from typing import NamedTuple
 
 import pytest
+from parameterized import param
+from parameterized import parameterized
 
 from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
 from apache_beam.coders import coders
@@ -58,6 +60,7 @@ except ImportError:
   dataclasses = None  # type: ignore
 
 MyNamedTuple = collections.namedtuple('A', ['x', 'y'])  # type: 
ignore[name-match]
+AnotherNamedTuple = collections.namedtuple('AnotherNamedTuple', ['x', 'y'])
 MyTypedNamedTuple = NamedTuple('MyTypedNamedTuple', [('f1', int), ('f2', str)])
 
 
@@ -175,6 +178,9 @@ class CodersTest(unittest.TestCase):
     assert not standard - cls.seen, str(standard - cls.seen)
     assert not cls.seen_nested - standard, str(cls.seen_nested - standard)
 
+  def tearDown(self):
+    typecoders.registry.update_compatibility_version = None
+
   @classmethod
   def _observe(cls, coder):
     cls.seen.add(type(coder))
@@ -230,9 +236,15 @@ class CodersTest(unittest.TestCase):
     coder = coders._MemoizingPickleCoder()
     self.check_coder(coder, *self.test_values)
 
-  def test_deterministic_coder(self):
+  @parameterized.expand([
+      param(compat_version=None),
+      param(compat_version="2.67.0"),
+  ])
+  def test_deterministic_coder(self, compat_version):
+    typecoders.registry.update_compatibility_version = compat_version
     coder = coders.FastPrimitivesCoder()
-    deterministic_coder = coders.DeterministicFastPrimitivesCoder(coder, 
'step')
+    deterministic_coder = coder.as_deterministic_coder(step_label="step")
+
     self.check_coder(deterministic_coder, *self.test_values_deterministic)
     for v in self.test_values_deterministic:
       self.check_coder(coders.TupleCoder((deterministic_coder, )), (v, ))
@@ -254,8 +266,16 @@ class CodersTest(unittest.TestCase):
 
     self.check_coder(deterministic_coder, 
test_message.MessageA(field1='value'))
 
+    # Skip this check under cloudpickle. Dill monkey-patches the __reduce__
+    # method of anonymous named tuples (MyNamedTuple), which are not picklable.
+    # Since the test is parameterized, the type gets clobbered.
+    if compat_version:
+      self.check_coder(
+          deterministic_coder, [MyNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')])
+
     self.check_coder(
-        deterministic_coder, [MyNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')])
+        deterministic_coder,
+        [AnotherNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')])
 
     if dataclasses is not None:
       self.check_coder(deterministic_coder, FrozenDataClass(1, 2))
@@ -265,9 +285,10 @@ class CodersTest(unittest.TestCase):
       with self.assertRaises(TypeError):
         self.check_coder(
             deterministic_coder, FrozenDataClass(UnFrozenDataClass(1, 2), 3))
-      with self.assertRaises(TypeError):
-        self.check_coder(
-            deterministic_coder, MyNamedTuple(UnFrozenDataClass(1, 2), 3))
+        with self.assertRaises(TypeError):
+          self.check_coder(
+              deterministic_coder,
+              AnotherNamedTuple(UnFrozenDataClass(1, 2), 3))
 
     self.check_coder(deterministic_coder, list(MyEnum))
     self.check_coder(deterministic_coder, list(MyIntEnum))
@@ -286,6 +307,29 @@ class CodersTest(unittest.TestCase):
               1: 'x', 'y': 2
           }))
 
+  @parameterized.expand([
+      param(compat_version=None),
+      param(compat_version="2.67.0"),
+  ])
+  def test_deterministic_map_coder_is_update_compatible(self, compat_version):
+    typecoders.registry.update_compatibility_version = compat_version
+    values = [{
+        MyTypedNamedTuple(i, 'a'): MyTypedNamedTuple('a', i)
+        for i in range(10)
+    }]
+
+    coder = coders.MapCoder(
+        coders.FastPrimitivesCoder(), coders.FastPrimitivesCoder())
+
+    deterministic_coder = coder.as_deterministic_coder(step_label="step")
+
+    assert isinstance(
+        deterministic_coder._key_coder,
+        coders.DeterministicFastPrimitivesCoderV2
+        if not compat_version else coders.DeterministicFastPrimitivesCoder)
+
+    self.check_coder(deterministic_coder, *values)
+
   def test_dill_coder(self):
     cell_value = (lambda x: lambda: x)(0).__closure__[0]
     self.check_coder(coders.DillCoder(), 'a', 1, cell_value)
@@ -610,15 +654,21 @@ class CodersTest(unittest.TestCase):
                 1, (window.IntervalWindow(11, 21), ),
                 PaneInfo(True, False, 1, 2, 3))))
 
-  def test_cross_process_encoding_of_special_types_is_deterministic(self):
+  @parameterized.expand([
+      param(compat_version=None),
+      param(compat_version="2.67.0"),
+  ])
+  def test_cross_process_encoding_of_special_types_is_deterministic(
+      self, compat_version):
     """Test cross-process determinism for all special deterministic types"""
 
     if sys.executable is None:
       self.skipTest('No Python interpreter found')
+    typecoders.registry.update_compatibility_version = compat_version
 
     # pylint: disable=line-too-long
     script = textwrap.dedent(
-        '''\
+        f'''\
         import pickle
         import sys
         import collections
@@ -626,13 +676,19 @@ class CodersTest(unittest.TestCase):
         import logging
 
         from apache_beam.coders import coders
-        from apache_beam.coders import proto2_coder_test_messages_pb2 as 
test_message
-        from typing import NamedTuple
+        from apache_beam.coders import typecoders
+        from apache_beam.coders.coders_test_common import MyNamedTuple
+        from apache_beam.coders.coders_test_common import MyTypedNamedTuple
+        from apache_beam.coders.coders_test_common import MyEnum
+        from apache_beam.coders.coders_test_common import MyIntEnum
+        from apache_beam.coders.coders_test_common import MyIntFlag
+        from apache_beam.coders.coders_test_common import MyFlag
+        from apache_beam.coders.coders_test_common import DefinesGetState
+        from apache_beam.coders.coders_test_common import DefinesGetAndSetState
+        from apache_beam.coders.coders_test_common import FrozenDataClass
 
-        try:
-            import dataclasses
-        except ImportError:
-            dataclasses = None
+
+        from apache_beam.coders import proto2_coder_test_messages_pb2 as 
test_message
 
         logging.basicConfig(
             level=logging.INFO,
@@ -640,38 +696,6 @@ class CodersTest(unittest.TestCase):
             stream=sys.stderr,
             force=True
         )
-
-        # Define all the special types that encode_special_deterministic 
handles
-        MyNamedTuple = collections.namedtuple('A', ['x', 'y'])
-        MyTypedNamedTuple = NamedTuple('MyTypedNamedTuple', [('f1', int), 
('f2', str)])
-
-        class MyEnum(enum.Enum):
-            E1 = 5
-            E2 = enum.auto()
-            E3 = 'abc'
-
-        MyIntEnum = enum.IntEnum('MyIntEnum', 'I1 I2 I3')
-        MyIntFlag = enum.IntFlag('MyIntFlag', 'F1 F2 F3')
-        MyFlag = enum.Flag('MyFlag', 'F1 F2 F3')
-
-        if dataclasses is not None:
-            @dataclasses.dataclass(frozen=True)
-            class FrozenDataClass:
-                a: int
-                b: int
-
-        class DefinesGetAndSetState:
-            def __init__(self, value):
-                self.value = value
-
-            def __getstate__(self):
-                return self.value
-
-            def __setstate__(self, value):
-                self.value = value
-
-            def __eq__(self, other):
-                return type(other) is type(self) and other.value == self.value
         
         # Test cases for all special deterministic types
         # NOTE: When this script run in a subprocess the module is considered
@@ -683,26 +707,28 @@ class CodersTest(unittest.TestCase):
             ("named_tuple_simple", MyNamedTuple(1, 2)),
             ("typed_named_tuple", MyTypedNamedTuple(1, 'a')),
             ("named_tuple_list", [MyNamedTuple(1, 2), MyTypedNamedTuple(1, 
'a')]),
-            # ("enum_single", MyEnum.E1),
-            # ("enum_list", list(MyEnum)),
-            # ("int_enum_list", list(MyIntEnum)),
-            # ("int_flag_list", list(MyIntFlag)),
-            # ("flag_list", list(MyFlag)),
+            ("enum_single", MyEnum.E1),
+            ("enum_list", list(MyEnum)),
+            ("int_enum_list", list(MyIntEnum)),
+            ("int_flag_list", list(MyIntFlag)),
+            ("flag_list", list(MyFlag)),
             ("getstate_setstate_simple", DefinesGetAndSetState(1)),
             ("getstate_setstate_complex", DefinesGetAndSetState((1, 2, 3))),
             ("getstate_setstate_list", [DefinesGetAndSetState(1), 
DefinesGetAndSetState((1, 2, 3))]),
         ]
 
-        if dataclasses is not None:
-            test_cases.extend([
-                ("frozen_dataclass", FrozenDataClass(1, 2)),
-                ("frozen_dataclass_list", [FrozenDataClass(1, 2), 
FrozenDataClass(3, 4)]),
-            ])
+        
+        test_cases.extend([
+            ("frozen_dataclass", FrozenDataClass(1, 2)),
+            ("frozen_dataclass_list", [FrozenDataClass(1, 2), 
FrozenDataClass(3, 4)]),
+        ])
 
+        compat_version = {'"'+ compat_version +'"' if compat_version else None}
+        typecoders.registry.update_compatibility_version = compat_version
         coder = coders.FastPrimitivesCoder()
-        deterministic_coder = coders.DeterministicFastPrimitivesCoder(coder, 
'step')
+        deterministic_coder = coder.as_deterministic_coder("step")
         
-        results = {}
+        results = dict()
         for test_name, value in test_cases:
             try:
                 encoded = deterministic_coder.encode(value)
@@ -730,7 +756,7 @@ class CodersTest(unittest.TestCase):
     results2 = run_subprocess()
 
     coder = coders.FastPrimitivesCoder()
-    deterministic_coder = coders.DeterministicFastPrimitivesCoder(coder, 
'step')
+    deterministic_coder = coder.as_deterministic_coder("step")
 
     for test_name in results1:
       data1 = results1[test_name]
@@ -861,7 +887,7 @@ class CodersTest(unittest.TestCase):
         {
             i: str(i)
             for i in range(5000)
-        }
+        },
     ]
     map_coder = coders.MapCoder(coders.VarIntCoder(), coders.StrUtf8Coder())
     self.check_coder(map_coder, *values)
diff --git a/sdks/python/apache_beam/coders/typecoders.py 
b/sdks/python/apache_beam/coders/typecoders.py
index 19300c67559..779c65dc772 100644
--- a/sdks/python/apache_beam/coders/typecoders.py
+++ b/sdks/python/apache_beam/coders/typecoders.py
@@ -84,6 +84,7 @@ class CoderRegistry(object):
     self._coders: Dict[Any, Type[coders.Coder]] = {}
     self.custom_types: List[Any] = []
     self.register_standard_coders(fallback_coder)
+    self.update_compatibility_version = None
 
   def register_standard_coders(self, fallback_coder):
     """Register coders for all basic and composite types."""
diff --git a/sdks/python/apache_beam/internal/cloudpickle_pickler.py 
b/sdks/python/apache_beam/internal/cloudpickle_pickler.py
index 63038e770f2..e55818bfb22 100644
--- a/sdks/python/apache_beam/internal/cloudpickle_pickler.py
+++ b/sdks/python/apache_beam/internal/cloudpickle_pickler.py
@@ -39,6 +39,8 @@ from apache_beam.internal.cloudpickle import cloudpickle
 
 DEFAULT_CONFIG = cloudpickle.CloudPickleConfig(
     skip_reset_dynamic_type_state=True)
+NO_DYNAMIC_CLASS_TRACKING_CONFIG = cloudpickle.CloudPickleConfig(
+    id_generator=None, skip_reset_dynamic_type_state=True)
 
 try:
   from absl import flags
diff --git a/sdks/python/apache_beam/pipeline.py 
b/sdks/python/apache_beam/pipeline.py
index 83a0bee8145..0ed5a435e78 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -76,6 +76,7 @@ from typing import Union
 from google.protobuf import message
 
 from apache_beam import pvalue
+from apache_beam.coders import typecoders
 from apache_beam.internal import pickler
 from apache_beam.io.filesystems import FileSystems
 from apache_beam.options.pipeline_options import CrossLanguageOptions
@@ -83,6 +84,7 @@ from apache_beam.options.pipeline_options import DebugOptions
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import SetupOptions
 from apache_beam.options.pipeline_options import StandardOptions
+from apache_beam.options.pipeline_options import StreamingOptions
 from apache_beam.options.pipeline_options import TypeOptions
 from apache_beam.options.pipeline_options_validator import 
PipelineOptionsValidator
 from apache_beam.portability import common_urns
@@ -229,6 +231,9 @@ class Pipeline(HasDisplayData):
       raise ValueError(
           'Pipeline has validations errors: \n' + '\n'.join(errors))
 
+    typecoders.registry.update_compatibility_version = self._options.view_as(
+        StreamingOptions).update_compatibility_version
+
     # set default experiments for portable runners
     # (needs to occur prior to pipeline construction)
     if runner.is_fnapi_compatible():
diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py 
b/sdks/python/apache_beam/transforms/ptransform_test.py
index e1c84c7dc9a..39d216c4b3b 100644
--- a/sdks/python/apache_beam/transforms/ptransform_test.py
+++ b/sdks/python/apache_beam/transforms/ptransform_test.py
@@ -34,16 +34,20 @@ from unittest.mock import patch
 import hamcrest as hc
 import numpy as np
 import pytest
+from parameterized import param
+from parameterized import parameterized
 from parameterized import parameterized_class
 
 import apache_beam as beam
 import apache_beam.transforms.combiners as combine
 from apache_beam import pvalue
 from apache_beam import typehints
+from apache_beam.coders import coders_test_common
 from apache_beam.io.iobase import Read
 from apache_beam.metrics import Metrics
 from apache_beam.metrics.metric import MetricsFilter
 from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import StreamingOptions
 from apache_beam.options.pipeline_options import TypeOptions
 from apache_beam.portability import common_urns
 from apache_beam.testing.test_pipeline import TestPipeline
@@ -572,7 +576,7 @@ class PTransformTest(unittest.TestCase):
       def decode(self, encoded):
         return MyObject(pickle.loads(encoded)[0])
 
-      def as_deterministic_coder(self, *args):
+      def as_deterministic_coder(self, *args, **kwargs):
         return MydeterministicObjectCoder()
 
       def to_type_hint(self):
@@ -719,6 +723,67 @@ class PTransformTest(unittest.TestCase):
       result = (pcoll, ) | 'Single Flatten' >> beam.Flatten()
       assert_that(result, equal_to(input))
 
+  @parameterized.expand([
+      param(compat_version=None),
+      param(compat_version="2.66.0"),
+  ])
+  @pytest.mark.it_validatesrunner
+  def test_group_by_key_importable_special_types(self, compat_version):
+    def generate(_):
+      for _ in range(100):
+        yield (coders_test_common.MyTypedNamedTuple(1, 'a'), 1)
+
+    pipeline = TestPipeline(is_integration_test=True)
+    if compat_version:
+      pipeline.get_pipeline_options().view_as(
+          StreamingOptions).update_compatibility_version = compat_version
+    with pipeline as p:
+      result = (
+          p
+          | 'Create' >> beam.Create([i for i in range(100)])
+          | 'Generate' >> beam.ParDo(generate)
+          | 'Reshuffle' >> beam.Reshuffle()
+          | 'GBK' >> beam.GroupByKey())
+      assert_that(
+          result,
+          equal_to([(
+              coders_test_common.MyTypedNamedTuple(1, 'a'),
+              [1 for i in range(10000)])]))
+
+  @pytest.mark.it_validatesrunner
+  def test_group_by_key_dynamic_special_types(self):
+    def create_dynamic_named_tuple():
+      return collections.namedtuple('DynamicNamedTuple', ['x', 'y'])
+
+    dynamic_named_tuple = create_dynamic_named_tuple()
+
+    # Standard FastPrimitivesCoder falls back to python PickleCoder which
+    # cannot serialize dynamic types or types defined in __main__. Use
+    # CloudPickleCoder as fallback coder for non-deterministic steps.
+    class FastPrimitivesCoderV2(beam.coders.FastPrimitivesCoder):
+      def __init__(self):
+        super().__init__(fallback_coder=beam.coders.CloudpickleCoder())
+
+    beam.coders.typecoders.registry.register_coder(
+        dynamic_named_tuple, FastPrimitivesCoderV2)
+
+    def generate(_):
+      for _ in range(100):
+        yield (dynamic_named_tuple(1, 'a'), 1)
+
+    pipeline = TestPipeline(is_integration_test=True)
+
+    with pipeline as p:
+      result = (
+          p
+          | 'Create' >> beam.Create([i for i in range(100)])
+          | 'Reshuffle' >> beam.Reshuffle()
+          | 'Generate' >> beam.ParDo(generate).with_output_types(
+              tuple[dynamic_named_tuple, int])
+          | 'GBK' >> beam.GroupByKey()
+          | 'Count Elements' >> beam.Map(lambda x: len(x[1])))
+      assert_that(result, equal_to([10000]))
+
   # TODO(https://github.com/apache/beam/issues/20067): Does not work in
   # streaming mode on Dataflow.
   @pytest.mark.no_sickbay_streaming
diff --git a/sdks/python/apache_beam/transforms/util.py 
b/sdks/python/apache_beam/transforms/util.py
index c60ded52df2..2df66aadcc6 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -932,6 +932,15 @@ class _IdentityWindowFn(NonMergingWindowFn):
     return self._window_coder
 
 
+def is_v1_prior_to_v2(*, v1, v2):
+  if v1 is None:
+    return False
+
+  v1_parts = (v1.split('.') + ['0', '0', '0'])[:3]
+  v2_parts = (v2.split('.') + ['0', '0', '0'])[:3]
+  return tuple(map(int, v1_parts)) < tuple(map(int, v2_parts))
+
+
 def is_compat_version_prior_to(options, breaking_change_version):
   # This function is used in a branch statement to determine whether we should
   # keep the old behavior prior to a breaking change or use the new behavior.
@@ -940,15 +949,8 @@ def is_compat_version_prior_to(options, 
breaking_change_version):
   update_compatibility_version = options.view_as(
       pipeline_options.StreamingOptions).update_compatibility_version
 
-  if update_compatibility_version is None:
-    return False
-
-  compat_version = tuple(map(int, 
update_compatibility_version.split('.')[0:3]))
-  change_version = tuple(map(int, breaking_change_version.split('.')[0:3]))
-  for i in range(min(len(compat_version), len(change_version))):
-    if compat_version[i] < change_version[i]:
-      return True
-  return False
+  return is_v1_prior_to_v2(
+      v1=update_compatibility_version, v2=breaking_change_version)
 
 
 def reify_metadata_default_window(
diff --git a/sdks/python/apache_beam/transforms/util_test.py 
b/sdks/python/apache_beam/transforms/util_test.py
index ad185ac6a6d..b365d9b2209 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -2193,6 +2193,68 @@ class WaitOnTest(unittest.TestCase):
           label='result')
 
 
+class CompatCheckTest(unittest.TestCase):
+  def test_is_v1_prior_to_v2(self):
+    test_cases = [
+        # Basic comparison cases
+        ("1.0.0", "2.0.0", True),  # v1 < v2 in major
+        ("2.0.0", "1.0.0", False),  # v1 > v2 in major
+        ("1.1.0", "1.2.0", True),  # v1 < v2 in minor
+        ("1.2.0", "1.1.0", False),  # v1 > v2 in minor
+        ("1.0.1", "1.0.2", True),  # v1 < v2 in patch
+        ("1.0.2", "1.0.1", False),  # v1 > v2 in patch
+
+        # Equal versions
+        ("1.0.0", "1.0.0", False),  # Identical
+        ("0.0.0", "0.0.0", False),  # Both zero
+
+        # Different lengths - shorter vs longer
+        ("1.0", "1.0.0", False),  # Should be equal (1.0 = 1.0.0)
+        ("1.0", "1.0.1", True),  # 1.0.0 < 1.0.1
+        ("1.2", "1.2.0", False),  # Should be equal (1.2 = 1.2.0)
+        ("1.2", "1.2.3", True),  # 1.2.0 < 1.2.3
+        ("2", "2.0.0", False),  # Should be equal (2 = 2.0.0)
+        ("2", "2.0.1", True),  # 2.0.0 < 2.0.1
+        ("1", "2.0", True),  # 1.0.0 < 2.0.0
+
+        # Different lengths - longer vs shorter
+        ("1.0.0", "1.0", False),  # Should be equal
+        ("1.0.1", "1.0", False),  # 1.0.1 > 1.0.0
+        ("1.2.0", "1.2", False),  # Should be equal
+        ("1.2.3", "1.2", False),  # 1.2.3 > 1.2.0
+        ("2.0.0", "2", False),  # Should be equal
+        ("2.0.1", "2", False),  # 2.0.1 > 2.0.0
+        ("2.0", "1", False),  # 2.0.0 > 1.0.0
+
+        # Mixed length comparisons
+        ("1.0", "2.0.0", True),  # 1.0.0 < 2.0.0
+        ("2.0", "1.0.0", False),  # 2.0.0 > 1.0.0
+        ("1", "1.0.1", True),  # 1.0.0 < 1.0.1
+        ("1.1", "1.0.9", False),  # 1.1.0 > 1.0.9
+
+        # Large numbers
+        ("1.9.9", "2.0.0", True),  # 1.9.9 < 2.0.0
+        ("10.0.0", "9.9.9", False),  # 10.0.0 > 9.9.9
+        ("1.10.0", "1.9.0", False),  # 1.10.0 > 1.9.0
+        ("1.2.10", "1.2.9", False),  # 1.2.10 > 1.2.9
+
+        # Sequential versions
+        ("1.0.0", "1.0.1", True),
+        ("1.0.1", "1.0.2", True),
+        ("1.0.9", "1.1.0", True),
+        ("1.9.9", "2.0.0", True),
+
+        # Null/None cases
+        (None, "1.0.0", False),  # v1 is None
+    ]
+
+    for v1, v2, expected in test_cases:
+      self.assertEqual(
+          util.is_v1_prior_to_v2(v1=v1, v2=v2),
+          expected,
+          msg=f"Failed {v1} < {v2} == {expected}")
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()

Reply via email to