This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 4289ea20a8d Allow stateful exception handling (#35965)
4289ea20a8d is described below

commit 4289ea20a8dc962dab0ead17ec8216a63655c48d
Author: Danny McCormick <[email protected]>
AuthorDate: Tue Sep 2 15:41:50 2025 -0400

    Allow stateful exception handling (#35965)
    
    * [WIP] Allow stateful exception handling
    
    * Fix state
    
    * A bit more conservative
    
    * Linting
    
    * lint
---
 sdks/python/apache_beam/transforms/core.py      | 64 +++++++++++++++++++++---
 sdks/python/apache_beam/transforms/core_test.py | 65 +++++++++++++++++++++++++
 2 files changed, 123 insertions(+), 6 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 67f232b9ffa..e7d88e8884e 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -30,6 +30,7 @@ import traceback
 import types
 import typing
 from collections import defaultdict
+from functools import wraps
 from itertools import dropwhile
 
 from apache_beam import coders
@@ -1596,7 +1597,8 @@ class ParDo(PTransformWithSideInputs):
       timeout=None,
       error_handler=None,
       on_failure_callback: typing.Optional[typing.Callable[
-          [Exception, typing.Any], None]] = None):
+          [Exception, typing.Any], None]] = None,
+      allow_unsafe_userstate_in_process=False):
     """Automatically provides a dead letter output for saving bad inputs.
     This can allow a pipeline to continue successfully rather than fail or
     continuously throw errors on retry when bad elements are encountered.
@@ -1653,6 +1655,13 @@ class ParDo(PTransformWithSideInputs):
           the exception will be of type `TimeoutError`. Be careful with this
           callback - if you set a timeout, it will not apply to the callback,
           and if the callback fails it will not be retried.
+      allow_unsafe_userstate_in_process: If False, user state will not be
+          permitted in the DoFn's process method. This is disabled by default
+          because user state is potentially unsafe with exception handling
+          since it can be successfully stored or cleared even if the associated
+          element fails and is routed to a dead letter queue. Semantics around
+          state in this kind of failure scenario are not well defined and are
+          subject to change.
     """
     args, kwargs = self.raw_side_inputs
     return self.label >> _ExceptionHandlingWrapper(
@@ -1668,7 +1677,8 @@ class ParDo(PTransformWithSideInputs):
         threshold_windowing,
         timeout,
         error_handler,
-        on_failure_callback)
+        on_failure_callback,
+        allow_unsafe_userstate_in_process)
 
   def with_error_handler(self, error_handler, **exception_handling_kwargs):
     """An alias for `with_exception_handling(error_handler=error_handler, ...)`
@@ -2273,7 +2283,8 @@ class _ExceptionHandlingWrapper(ptransform.PTransform):
       threshold_windowing,
       timeout,
       error_handler,
-      on_failure_callback):
+      on_failure_callback,
+      allow_unsafe_userstate_in_process):
     if partial and use_subprocess:
       raise ValueError('partial and use_subprocess are mutually incompatible.')
     self._fn = fn
@@ -2289,8 +2300,17 @@ class _ExceptionHandlingWrapper(ptransform.PTransform):
     self._timeout = timeout
     self._error_handler = error_handler
     self._on_failure_callback = on_failure_callback
+    self._allow_unsafe_userstate_in_process = allow_unsafe_userstate_in_process
 
   def expand(self, pcoll):
+    if self._allow_unsafe_userstate_in_process:
+      if self._use_subprocess or self._timeout:
+        # TODO(https://github.com/apache/beam/issues/35976): Implement this
+        raise Exception(
+            'allow_unsafe_userstate_in_process is incompatible with ' +
+            'exception handling done with subprocesses or timeouts. If you ' +
+            'need this feature, comment in ' +
+            'https://github.com/apache/beam/issues/35976')
     if self._use_subprocess:
       wrapped_fn = _SubprocessDoFn(self._fn, timeout=self._timeout)
     elif self._timeout:
@@ -2303,7 +2323,8 @@ class _ExceptionHandlingWrapper(ptransform.PTransform):
             self._dead_letter_tag,
             self._exc_class,
             self._partial,
-            self._on_failure_callback),
+            self._on_failure_callback,
+            self._allow_unsafe_userstate_in_process),
         *self._args,
         **self._kwargs).with_outputs(
             self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
@@ -2347,13 +2368,44 @@ class _ExceptionHandlingWrapper(ptransform.PTransform):
 
 class _ExceptionHandlingWrapperDoFn(DoFn):
   def __init__(
-      self, fn, dead_letter_tag, exc_class, partial, on_failure_callback):
+      self,
+      fn,
+      dead_letter_tag,
+      exc_class,
+      partial,
+      on_failure_callback,
+      allow_unsafe_userstate_in_process):
     self._fn = fn
     self._dead_letter_tag = dead_letter_tag
     self._exc_class = exc_class
     self._partial = partial
     self._on_failure_callback = on_failure_callback
 
+    # Wrap process and expose any top level state params so that process can
+    # handle state and timers.
+    if allow_unsafe_userstate_in_process:
+
+      @wraps(self._fn.process)
+      def process_wrapper(self, *args, **kwargs):
+        return self.exception_handling_wrapper_do_fn_custom_process(
+            *args, **kwargs)
+
+      self.process = types.MethodType(process_wrapper, self)
+    else:
+      self.process = self.exception_handling_wrapper_do_fn_custom_process
+      process_sig = inspect.signature(self._fn.process)
+      for name, param in process_sig.parameters.items():
+        if isinstance(param.default, (DoFn.StateParam, DoFn.TimerParam)):
+          logging.warning(
+              'State or timer parameter {} detected in process method of ' +
+              '{}. State and timers are unsupported when using ' +
+              'with_exception_handling and may lead to errors. To enable ' +
+              'state and timers with limited consistency guarantees, pass ' +
+              'in the allow_unsafe_userstate_in_process parameters to the ' +
+              'with_exception_handling method.',
+              name,
+              self.fn)
+
   def __getattribute__(self, name):
     if (name.startswith('__') or name in self.__dict__ or
         name in _ExceptionHandlingWrapperDoFn.__dict__):
@@ -2361,7 +2413,7 @@ class _ExceptionHandlingWrapperDoFn(DoFn):
     else:
       return getattr(self._fn, name)
 
-  def process(self, *args, **kwargs):
+  def exception_handling_wrapper_do_fn_custom_process(self, *args, **kwargs):
     try:
       result = self._fn.process(*args, **kwargs)
       if not self._partial:
diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py
index 0d6be3fc7ae..a4fa3b52817 100644
--- a/sdks/python/apache_beam/transforms/core_test.py
+++ b/sdks/python/apache_beam/transforms/core_test.py
@@ -27,8 +27,13 @@ from typing import TypeVar
 import pytest
 
 import apache_beam as beam
+from apache_beam.coders import coders
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
+from apache_beam.transforms.userstate import BagStateSpec
+from apache_beam.transforms.userstate import ReadModifyWriteStateSpec
+from apache_beam.transforms.userstate import TimerSpec
+from apache_beam.transforms.userstate import on_timer
 from apache_beam.transforms.window import FixedWindows
 from apache_beam.typehints import TypeCheckError
 from apache_beam.typehints import row_type
@@ -120,6 +125,42 @@ class TestDoFn12(beam.DoFn):
     return
 
 
+class TestDoFnStateful(beam.DoFn):
+  STATE_SPEC = ReadModifyWriteStateSpec('num_elements', coders.VarIntCoder())
+  """test process with a stateful dofn"""
+  def process(self, element, state=beam.DoFn.StateParam(STATE_SPEC)):
+    if len(element[1]) > 3:
+      raise ValueError('Not allowed to have long elements')
+    current_value = state.read() or 1
+    state.write(current_value + 1)
+    yield current_value
+
+
+class TestDoFnWithTimer(beam.DoFn):
+  ALL_ELEMENTS = BagStateSpec('buffer', coders.VarIntCoder())
+  TIMER = TimerSpec('timer', beam.TimeDomain.WATERMARK)
+  """test process with a stateful dofn"""
+  def process(
+      self,
+      element,
+      t=beam.DoFn.TimestampParam,
+      state=beam.DoFn.StateParam(ALL_ELEMENTS),
+      timer=beam.DoFn.TimerParam(TIMER)):
+    if element[1] > 3:
+      raise ValueError('Not allowed to have large numbers')
+    state.add(element[1])
+    timer.set(t)
+
+    return []
+
+  @on_timer(TIMER)
+  def expiry_callback(self, state=beam.DoFn.StateParam(ALL_ELEMENTS)):
+    unique_elements = list(state.read())
+    state.clear()
+
+    return unique_elements
+
+
 class CreateTest(unittest.TestCase):
   @pytest.fixture(autouse=True)
   def inject_fixtures(self, caplog):
@@ -296,6 +337,30 @@ class ExceptionHandlingTest(unittest.TestCase):
         assert_that(bad_elements, equal_to([]), 'bad')
       self.assertFalse(os.path.isfile(tmp_path))
 
+  def test_stateful_exception_handling(self):
+    with beam.Pipeline() as pipeline:
+      good, bad = (
+        pipeline | beam.Create([(1, 'abc'), (1, 'long_word'),
+                                (1, 'foo'), (1, 'bar'), (1, 'foobar')])
+        | beam.ParDo(TestDoFnStateful()).with_exception_handling(
+          allow_unsafe_userstate_in_process=True)
+      )
+      bad_elements = bad | beam.Keys()
+      assert_that(good, equal_to([1, 2, 3]), 'good')
+      assert_that(
+          bad_elements, equal_to([(1, 'long_word'), (1, 'foobar')]), 'bad')
+
+  def test_timer_exception_handling(self):
+    with beam.Pipeline() as pipeline:
+      good, bad = (
+        pipeline | beam.Create([(1, 0), (1, 1), (1, 2), (1, 5), (1, 10)])
+        | beam.ParDo(TestDoFnWithTimer()).with_exception_handling(
+          allow_unsafe_userstate_in_process=True)
+      )
+      bad_elements = bad | beam.Keys()
+      assert_that(good, equal_to([0, 1, 2]), 'good')
+      assert_that(bad_elements, equal_to([(1, 5), (1, 10)]), 'bad')
+
 
 def test_callablewrapper_typehint():
   T = TypeVar("T")

Reply via email to