claudevdm commented on code in PR #35433:
URL: https://github.com/apache/beam/pull/35433#discussion_r2167531174


##########
sdks/python/apache_beam/coders/coders_test_common.py:
##########
@@ -606,6 +608,131 @@ def test_param_windowed_value_coder(self):
                 1, (window.IntervalWindow(11, 21), ),
                 PaneInfo(True, False, 1, 2, 3))))
 
+  def test_cross_process_deterministic_special_types(self):
+    """Test cross-process determinism for all special deterministic types"""
+    # pylint: disable=line-too-long
+    script = textwrap.dedent(
+        '''\
+        import pickle
+        import sys
+        import collections
+        import enum
+        import logging
+
+        from apache_beam.coders import coders
+        from apache_beam.coders import proto2_coder_test_messages_pb2 as 
test_message
+        from typing import NamedTuple
+
+        try:
+            import dataclasses
+        except ImportError:
+            dataclasses = None
+
+        logging.basicConfig(
+            level=logging.INFO,
+            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+            stream=sys.stderr,
+            force=True
+        )
+
+        # Define all the special types that encode_special_deterministic 
handles
+        MyNamedTuple = collections.namedtuple('A', ['x', 'y'])
+        MyTypedNamedTuple = NamedTuple('MyTypedNamedTuple', [('f1', int), 
('f2', str)])
+
+        class MyEnum(enum.Enum):
+            E1 = 5
+            E2 = enum.auto()
+            E3 = 'abc'
+
+        MyIntEnum = enum.IntEnum('MyIntEnum', 'I1 I2 I3')
+        MyIntFlag = enum.IntFlag('MyIntFlag', 'F1 F2 F3')
+        MyFlag = enum.Flag('MyFlag', 'F1 F2 F3')
+
+        if dataclasses is not None:
+            @dataclasses.dataclass(frozen=True)
+            class FrozenDataClass:
+                a: int
+                b: int
+
+        class DefinesGetAndSetState:
+            def __init__(self, value):
+                self.value = value
+
+            def __getstate__(self):
+                return self.value
+
+            def __setstate__(self, value):
+                self.value = value
+
+            def __eq__(self, other):
+                return type(other) is type(self) and other.value == self.value
+        
+        # Test cases for all special deterministic types
+        # NOTE: When this script run in a subprocess the module is considered
+        #  __main__. Dill cannot pickle enums in __main__ because it
+        # needs to define a way to create the type if it does not exist
+        # in the session, and reaches recursion depth limits.
+        test_cases = [
+            ("proto_message", test_message.MessageA(field1='value')),
+            ("named_tuple_simple", MyNamedTuple(1, 2)),
+            ("typed_named_tuple", MyTypedNamedTuple(1, 'a')),
+            ("named_tuple_list", [MyNamedTuple(1, 2), MyTypedNamedTuple(1, 
'a')]),
+            # ("enum_single", MyEnum.E1),
+            # ("enum_list", list(MyEnum)),
+            # ("int_enum_list", list(MyIntEnum)),
+            # ("int_flag_list", list(MyIntFlag)),
+            # ("flag_list", list(MyFlag)),
+            ("getstate_setstate_simple", DefinesGetAndSetState(1)),
+            ("getstate_setstate_complex", DefinesGetAndSetState((1, 2, 3))),
+            ("getstate_setstate_list", [DefinesGetAndSetState(1), 
DefinesGetAndSetState((1, 2, 3))]),
+        ]
+
+        if dataclasses is not None:
+            test_cases.extend([
+                ("frozen_dataclass", FrozenDataClass(1, 2)),
+                ("frozen_dataclass_list", [FrozenDataClass(1, 2), 
FrozenDataClass(3, 4)]),
+            ])
+
+        coder = coders.FastPrimitivesCoder()
+        deterministic_coder = coders.DeterministicFastPrimitivesCoder(coder, 
'step')
+        
+        results = {}
+        for test_name, value in test_cases:
+            try:
+                encoded = deterministic_coder.encode(value)
+                results[test_name] = encoded
+            except Exception as e:
+              logging.warning("Encoding failed with %s", e)
+              sys.exit(1)
+        
+        sys.stdout.buffer.write(pickle.dumps(results))
+                
+        
+    ''')
+
+    def run_subprocess():
+      import subprocess
+      import sys
+
+      result = subprocess.run([sys.executable, '-c', script],
+                              capture_output=True,
+                              timeout=30,
+                              check=False)
+
+      self.assertEqual(
+          0, result.returncode, f"Subprocess failed: {result.stderr}")
+      return pickle.loads(result.stdout)
+
+    results1 = run_subprocess()
+    results2 = run_subprocess()
+
+    for test_name in results1:
+      data1 = results1[test_name]
+      data2 = results2[test_name]
+
+      self.assertEqual(

Review Comment:
   I added more assertions that check equality and instance semantics of the 
decoded types. I can add further validation, but wouldn't decoding already fail 
if the encoded input is not fully consumed (i.e. there are leftover bytes)?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to