This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new f7a7bddb3ba Faster default coder for unknown windows. (#33382)
f7a7bddb3ba is described below

commit f7a7bddb3bad80275e59591d8444e68b61ece760
Author: Robert Bradshaw <[email protected]>
AuthorDate: Tue Dec 17 08:43:35 2024 -0800

    Faster default coder for unknown windows. (#33382)
    
    This will get used in a windowed reshuffle, among other places.
---
 sdks/python/apache_beam/coders/coder_impl.pxd      | 12 +++++++
 sdks/python/apache_beam/coders/coder_impl.py       | 31 ++++++++++++++++++
 sdks/python/apache_beam/coders/coders.py           | 38 +++++++++++++++++++++-
 .../apache_beam/coders/coders_test_common.py       |  8 +++++
 4 files changed, 88 insertions(+), 1 deletion(-)

diff --git a/sdks/python/apache_beam/coders/coder_impl.pxd 
b/sdks/python/apache_beam/coders/coder_impl.pxd
index 52889fa2fd9..8a28499555c 100644
--- a/sdks/python/apache_beam/coders/coder_impl.pxd
+++ b/sdks/python/apache_beam/coders/coder_impl.pxd
@@ -219,6 +219,18 @@ cdef libc.stdint.int64_t MIN_TIMESTAMP_micros
 cdef libc.stdint.int64_t MAX_TIMESTAMP_micros
 
 
+# Cython declarations for the tagged-union coder implemented in coder_impl.py.
+cdef class _OrderedUnionCoderImpl(StreamCoderImpl):
+  # Parallel tuples: _coder_impls[i] is the impl paired with _types[i].
+  cdef tuple _types
+  cdef tuple _coder_impls
+  # Impl used when a value's type matches no entry of _types; may be None.
+  cdef CoderImpl _fallback_coder_impl
+
+  # ix (tag index) and c (selected impl) are typed locals of the method body.
+  @cython.locals(ix=int, c=CoderImpl)
+  cpdef encode_to_stream(self, value, OutputStream stream, bint nested)
+
+  # Same typed locals as encode_to_stream.
+  @cython.locals(ix=int, c=CoderImpl)
+  cpdef decode_from_stream(self, InputStream stream, bint nested)
+
+
 cdef class WindowedValueCoderImpl(StreamCoderImpl):
   """A coder for windowed values."""
   cdef CoderImpl _value_coder
diff --git a/sdks/python/apache_beam/coders/coder_impl.py 
b/sdks/python/apache_beam/coders/coder_impl.py
index 5262e6adf8a..5dff3505290 100644
--- a/sdks/python/apache_beam/coders/coder_impl.py
+++ b/sdks/python/apache_beam/coders/coder_impl.py
@@ -1421,6 +1421,37 @@ class PaneInfoCoderImpl(StreamCoderImpl):
     return size
 
 
+class _OrderedUnionCoderImpl(StreamCoderImpl):
+  """Encodes a value with the first coder whose registered type matches it.
+
+  Each value is written as a one-byte tag followed by the payload produced by
+  the selected coder. The tag is the index of the matching ``(type, impl)``
+  pair, or ``0xFF`` when the fallback coder is used. Matching is on the exact
+  type of the value (``type(value) is t``), so subclasses do not match.
+  """
+  def __init__(self, coder_impl_types, fallback_coder_impl):
+    """Initializes the union.
+
+    Args:
+      coder_impl_types: iterable of ``(type, coder_impl)`` pairs in tag order.
+        May be empty, in which case every value is sent to the fallback.
+      fallback_coder_impl: impl used for values of any other type, or None
+        if such values should be rejected.
+
+    Raises:
+      ValueError: if 128 or more pairs are given.
+    """
+    coder_impl_types = list(coder_impl_types)
+    # Raise instead of assert so the limit is still enforced under -O.
+    if len(coder_impl_types) >= 128:
+      raise ValueError(
+          "At most 127 typed coders are supported, got %d." %
+          len(coder_impl_types))
+    if coder_impl_types:
+      self._types, self._coder_impls = zip(*coder_impl_types)
+    else:
+      # zip(*[]) yields nothing, so the two-way unpacking above would fail.
+      self._types, self._coder_impls = (), ()
+    self._fallback_coder_impl = fallback_coder_impl
+
+  def encode_to_stream(self, value, out, nested):
+    """Writes the tag byte for the type of ``value``, then the encoded value.
+
+    Raises:
+      ValueError: if no registered type matches and there is no fallback.
+    """
+    value_t = type(value)
+    for (ix, t) in enumerate(self._types):
+      if value_t is t:
+        out.write_byte(ix)
+        c = self._coder_impls[ix]  # for typing
+        c.encode_to_stream(value, out, nested)
+        break
+    else:
+      # No registered type matched; use the fallback if there is one.
+      if self._fallback_coder_impl is None:
+        raise ValueError(
+            "No fallback coder for value of type %s." % value_t)
+      out.write_byte(0xFF)
+      self._fallback_coder_impl.encode_to_stream(value, out, nested)
+
+  def decode_from_stream(self, in_stream, nested):
+    """Reads the tag byte and decodes the value with the coder it selects.
+
+    Raises:
+      ValueError: if the tag is ``0xFF`` but there is no fallback, or if the
+        tag does not index a registered coder.
+    """
+    ix = in_stream.read_byte()
+    if ix == 0xFF:
+      if self._fallback_coder_impl is None:
+        raise ValueError("No fallback.")
+      return self._fallback_coder_impl.decode_from_stream(in_stream, nested)
+    elif not 0 <= ix < len(self._coder_impls):
+      # Report an unknown tag explicitly instead of a bare IndexError.
+      raise ValueError("Unknown union tag %d." % ix)
+    else:
+      c = self._coder_impls[ix]  # for typing
+      return c.decode_from_stream(in_stream, nested)
+
+
 class WindowedValueCoderImpl(StreamCoderImpl):
   """For internal use only; no backwards-compatibility guarantees.
 
diff --git a/sdks/python/apache_beam/coders/coders.py 
b/sdks/python/apache_beam/coders/coders.py
index 724f268a831..e52c6048a15 100644
--- a/sdks/python/apache_beam/coders/coders.py
+++ b/sdks/python/apache_beam/coders/coders.py
@@ -1350,12 +1350,48 @@ Coder.register_structured_urn(
     common_urns.coders.INTERVAL_WINDOW.urn, IntervalWindowCoder)
 
 
+class _OrderedUnionCoder(FastCoder):
+  """Coder for values drawn from several types, each with its own coder.
+
+  Args:
+    *coder_types: ``(type, coder)`` pairs. Their order is significant: it is
+      passed through unchanged to the implementation, and it takes part in
+      equality and hashing of this coder.
+    fallback_coder: coder for values of any other type, or None for no
+      fallback.
+  """
+  def __init__(
+      self, *coder_types: Tuple[type, Coder], fallback_coder: Optional[Coder]):
+    self._coder_types = coder_types
+    self._fallback_coder = fallback_coder
+
+  def _create_impl(self):
+    """Builds the impl from the impls of the member and fallback coders."""
+    # Test against None (not truthiness) to agree with is_deterministic().
+    fallback_impl = (
+        self._fallback_coder.get_impl()
+        if self._fallback_coder is not None else None)
+    return coder_impl._OrderedUnionCoderImpl(
+        [(t, c.get_impl()) for t, c in self._coder_types],
+        fallback_coder_impl=fallback_impl)
+
+  def is_deterministic(self) -> bool:
+    """Returns True only if every member coder and the fallback are."""
+    # is_deterministic must be called: the bare bound method is always truthy,
+    # which would hide a non-deterministic member coder.
+    return (
+        all(c.is_deterministic() for _, c in self._coder_types) and (
+            self._fallback_coder is None or
+            self._fallback_coder.is_deterministic()))
+
+  def to_type_hint(self):
+    """Returns Any; no narrower type hint is derived from the members."""
+    return Any
+
+  def __eq__(self, other):
+    return (
+        type(self) == type(other) and
+        self._coder_types == other._coder_types and
+        self._fallback_coder == other._fallback_coder)
+
+  def __hash__(self):
+    return hash((type(self), tuple(self._coder_types), self._fallback_coder))
+
+
 class WindowedValueCoder(FastCoder):
   """Coder for windowed values."""
   def __init__(self, wrapped_value_coder, window_coder=None):
     # type: (Coder, Optional[Coder]) -> None
     if not window_coder:
-      window_coder = PickleCoder()
+      # Avoid circular imports.
+      from apache_beam.transforms import window
+      window_coder = _OrderedUnionCoder(
+          (window.GlobalWindow, GlobalWindowCoder()),
+          (window.IntervalWindow, IntervalWindowCoder()),
+          fallback_coder=PickleCoder())
     self.wrapped_value_coder = wrapped_value_coder
     self.timestamp_coder = TimestampCoder()
     self.window_coder = window_coder
diff --git a/sdks/python/apache_beam/coders/coders_test_common.py 
b/sdks/python/apache_beam/coders/coders_test_common.py
index 4bd9698dd57..f3381cdb1d6 100644
--- a/sdks/python/apache_beam/coders/coders_test_common.py
+++ b/sdks/python/apache_beam/coders/coders_test_common.py
@@ -769,6 +769,14 @@ class CodersTest(unittest.TestCase):
           test_encodings[idx],
           base64.b64encode(test_coder.encode(value)).decode().rstrip("="))
 
+  def test_OrderedUnionCoder(self):
+    """Runs check_coder on one value per typed branch and one fallback value."""
+    typed_branches = [
+        (str, coders.StrUtf8Coder()),
+        (int, coders.VarIntCoder()),
+    ]
+    union_coder = coders._OrderedUnionCoder(
+        *typed_branches, fallback_coder=coders.FloatCoder())
+    # 's' hits the str branch, 123 the int branch, 1.5 the float fallback.
+    for sample in ('s', 123, 1.5):
+      self.check_coder(union_coder, sample)
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)

Reply via email to