This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 4289ea20a8d Allow stateful exception handling (#35965)
4289ea20a8d is described below
commit 4289ea20a8dc962dab0ead17ec8216a63655c48d
Author: Danny McCormick <[email protected]>
AuthorDate: Tue Sep 2 15:41:50 2025 -0400
Allow stateful exception handling (#35965)
* [WIP] Allow stateful exception handling
* Fix state
* A bit more conservative
* Linting
* lint
---
sdks/python/apache_beam/transforms/core.py | 64 +++++++++++++++++++++---
sdks/python/apache_beam/transforms/core_test.py | 65 +++++++++++++++++++++++++
2 files changed, 123 insertions(+), 6 deletions(-)
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 67f232b9ffa..e7d88e8884e 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -30,6 +30,7 @@ import traceback
import types
import typing
from collections import defaultdict
+from functools import wraps
from itertools import dropwhile
from apache_beam import coders
@@ -1596,7 +1597,8 @@ class ParDo(PTransformWithSideInputs):
timeout=None,
error_handler=None,
on_failure_callback: typing.Optional[typing.Callable[
- [Exception, typing.Any], None]] = None):
+ [Exception, typing.Any], None]] = None,
+ allow_unsafe_userstate_in_process=False):
"""Automatically provides a dead letter output for saving bad inputs.
This can allow a pipeline to continue successfully rather than fail or
continuously throw errors on retry when bad elements are encountered.
@@ -1653,6 +1655,13 @@ class ParDo(PTransformWithSideInputs):
the exception will be of type `TimeoutError`. Be careful with this
callback - if you set a timeout, it will not apply to the callback,
and if the callback fails it will not be retried.
+ allow_unsafe_userstate_in_process: If False, user state will not be
+ permitted in the DoFn's process method. This is disabled by default
+ because user state is potentially unsafe with exception handling
+ since it can be successfully stored or cleared even if the associated
+ element fails and is routed to a dead letter queue. Semantics around
+ state in this kind of failure scenario are not well defined and are
+ subject to change.
"""
args, kwargs = self.raw_side_inputs
return self.label >> _ExceptionHandlingWrapper(
@@ -1668,7 +1677,8 @@ class ParDo(PTransformWithSideInputs):
threshold_windowing,
timeout,
error_handler,
- on_failure_callback)
+ on_failure_callback,
+ allow_unsafe_userstate_in_process)
def with_error_handler(self, error_handler, **exception_handling_kwargs):
"""An alias for `with_exception_handling(error_handler=error_handler, ...)`
@@ -2273,7 +2283,8 @@ class _ExceptionHandlingWrapper(ptransform.PTransform):
threshold_windowing,
timeout,
error_handler,
- on_failure_callback):
+ on_failure_callback,
+ allow_unsafe_userstate_in_process):
if partial and use_subprocess:
raise ValueError('partial and use_subprocess are mutually incompatible.')
self._fn = fn
@@ -2289,8 +2300,17 @@ class _ExceptionHandlingWrapper(ptransform.PTransform):
self._timeout = timeout
self._error_handler = error_handler
self._on_failure_callback = on_failure_callback
+ self._allow_unsafe_userstate_in_process = allow_unsafe_userstate_in_process
def expand(self, pcoll):
+ if self._allow_unsafe_userstate_in_process:
+ if self._use_subprocess or self._timeout:
+ # TODO(https://github.com/apache/beam/issues/35976): Implement this
+ raise Exception(
+ 'allow_unsafe_userstate_in_process is incompatible with ' +
+ 'exception handling done with subprocesses or timeouts. If you ' +
+ 'need this feature, comment in ' +
+ 'https://github.com/apache/beam/issues/35976')
if self._use_subprocess:
wrapped_fn = _SubprocessDoFn(self._fn, timeout=self._timeout)
elif self._timeout:
@@ -2303,7 +2323,8 @@ class _ExceptionHandlingWrapper(ptransform.PTransform):
self._dead_letter_tag,
self._exc_class,
self._partial,
- self._on_failure_callback),
+ self._on_failure_callback,
+ self._allow_unsafe_userstate_in_process),
*self._args,
**self._kwargs).with_outputs(
self._dead_letter_tag, main=self._main_tag,
allow_unknown_tags=True)
@@ -2347,13 +2368,44 @@ class _ExceptionHandlingWrapper(ptransform.PTransform):
class _ExceptionHandlingWrapperDoFn(DoFn):
def __init__(
- self, fn, dead_letter_tag, exc_class, partial, on_failure_callback):
+ self,
+ fn,
+ dead_letter_tag,
+ exc_class,
+ partial,
+ on_failure_callback,
+ allow_unsafe_userstate_in_process):
self._fn = fn
self._dead_letter_tag = dead_letter_tag
self._exc_class = exc_class
self._partial = partial
self._on_failure_callback = on_failure_callback
+ # Wrap process and expose any top level state params so that process can
+ # handle state and timers. The chosen callable is stored in the instance
+ # __dict__, which __getattribute__ checks before delegating to self._fn.
+ if allow_unsafe_userstate_in_process:
+
+ # functools.wraps copies __wrapped__ (and other metadata) from the
+ # wrapped fn's process, so inspect.signature() on the wrapper reports
+ # the original parameters, including StateParam/TimerParam defaults.
+ @wraps(self._fn.process)
+ def process_wrapper(self, *args, **kwargs):
+ return self.exception_handling_wrapper_do_fn_custom_process(
+ *args, **kwargs)
+
+ self.process = types.MethodType(process_wrapper, self)
+ else:
+ self.process = self.exception_handling_wrapper_do_fn_custom_process
+ # In this mode process only exposes (*args, **kwargs), so state and
+ # timer parameters of the wrapped fn are not surfaced; warn about each.
+ process_sig = inspect.signature(self._fn.process)
+ for name, param in process_sig.parameters.items():
+ if isinstance(param.default, (DoFn.StateParam, DoFn.TimerParam)):
+ # logging formats lazy args with %, so the placeholders must be
+ # %s. The wrapped DoFn is stored as self._fn; self.fn would be
+ # delegated to the wrapped fn, where it typically does not exist.
+ logging.warning(
+ 'State or timer parameter %s detected in process method of '
+ '%s. State and timers are unsupported when using '
+ 'with_exception_handling and may lead to errors. To enable '
+ 'state and timers with limited consistency guarantees, pass '
+ 'allow_unsafe_userstate_in_process=True to the '
+ 'with_exception_handling method.',
+ name,
+ self._fn)
+
def __getattribute__(self, name):
if (name.startswith('__') or name in self.__dict__ or
name in _ExceptionHandlingWrapperDoFn.__dict__):
@@ -2361,7 +2413,7 @@ class _ExceptionHandlingWrapperDoFn(DoFn):
else:
return getattr(self._fn, name)
- def process(self, *args, **kwargs):
+ def exception_handling_wrapper_do_fn_custom_process(self, *args, **kwargs):
try:
result = self._fn.process(*args, **kwargs)
if not self._partial:
diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py
index 0d6be3fc7ae..a4fa3b52817 100644
--- a/sdks/python/apache_beam/transforms/core_test.py
+++ b/sdks/python/apache_beam/transforms/core_test.py
@@ -27,8 +27,13 @@ from typing import TypeVar
import pytest
import apache_beam as beam
+from apache_beam.coders import coders
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
+from apache_beam.transforms.userstate import BagStateSpec
+from apache_beam.transforms.userstate import ReadModifyWriteStateSpec
+from apache_beam.transforms.userstate import TimerSpec
+from apache_beam.transforms.userstate import on_timer
from apache_beam.transforms.window import FixedWindows
from apache_beam.typehints import TypeCheckError
from apache_beam.typehints import row_type
@@ -120,6 +125,42 @@ class TestDoFn12(beam.DoFn):
return
+class TestDoFnStateful(beam.DoFn):
+ """Test DoFn that counts successfully processed elements in user state.
+
+ Elements whose payload (element[1]) is longer than 3 characters raise a
+ ValueError before state is read or written.
+ """
+ STATE_SPEC = ReadModifyWriteStateSpec('num_elements', coders.VarIntCoder())
+
+ def process(self, element, state=beam.DoFn.StateParam(STATE_SPEC)):
+ if len(element[1]) > 3:
+ raise ValueError('Not allowed to have long elements')
+ # Start at 1 when nothing has been stored yet (read() is falsy).
+ current_value = state.read() or 1
+ state.write(current_value + 1)
+ yield current_value
+
+
+class TestDoFnWithTimer(beam.DoFn):
+ """Test DoFn that buffers values in bag state and emits them from a timer.
+
+ Elements whose value (element[1]) is greater than 3 raise a ValueError
+ before any state or timer is touched.
+ """
+ ALL_ELEMENTS = BagStateSpec('buffer', coders.VarIntCoder())
+ TIMER = TimerSpec('timer', beam.TimeDomain.WATERMARK)
+
+ def process(
+ self,
+ element,
+ t=beam.DoFn.TimestampParam,
+ state=beam.DoFn.StateParam(ALL_ELEMENTS),
+ timer=beam.DoFn.TimerParam(TIMER)):
+ if element[1] > 3:
+ raise ValueError('Not allowed to have large numbers')
+ state.add(element[1])
+ # Watermark-domain timer set at this element's timestamp.
+ timer.set(t)
+
+ return []
+
+ @on_timer(TIMER)
+ def expiry_callback(self, state=beam.DoFn.StateParam(ALL_ELEMENTS)):
+ # Emit everything buffered in the bag (a bag does not de-duplicate),
+ # then clear it.
+ buffered_elements = list(state.read())
+ state.clear()
+
+ return buffered_elements
+
+
class CreateTest(unittest.TestCase):
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
@@ -296,6 +337,30 @@ class ExceptionHandlingTest(unittest.TestCase):
assert_that(bad_elements, equal_to([]), 'bad')
self.assertFalse(os.path.isfile(tmp_path))
+ def test_stateful_exception_handling(self):
+ # Two of the five payloads are longer than 3 characters and must land in
+ # the dead letter output; the other three advance the per-key counter.
+ inputs = [(1, 'abc'), (1, 'long_word'), (1, 'foo'), (1, 'bar'),
+ (1, 'foobar')]
+ expected_failures = [(1, 'long_word'), (1, 'foobar')]
+ with beam.Pipeline() as p:
+ handled = (
+ p
+ | beam.Create(inputs)
+ | beam.ParDo(TestDoFnStateful()).with_exception_handling(
+ allow_unsafe_userstate_in_process=True))
+ good, bad = handled
+ failed_inputs = bad | beam.Keys()
+ assert_that(good, equal_to([1, 2, 3]), 'good')
+ assert_that(failed_inputs, equal_to(expected_failures), 'bad')
+
+ def test_timer_exception_handling(self):
+ # Values above 3 raise before touching state, so only 0, 1 and 2 are
+ # buffered and emitted by the timer; 5 and 10 go to the dead letter output.
+ inputs = [(1, 0), (1, 1), (1, 2), (1, 5), (1, 10)]
+ expected_failures = [(1, 5), (1, 10)]
+ with beam.Pipeline() as p:
+ handled = (
+ p
+ | beam.Create(inputs)
+ | beam.ParDo(TestDoFnWithTimer()).with_exception_handling(
+ allow_unsafe_userstate_in_process=True))
+ good, bad = handled
+ failed_inputs = bad | beam.Keys()
+ assert_that(good, equal_to([0, 1, 2]), 'good')
+ assert_that(failed_inputs, equal_to(expected_failures), 'bad')
+
def test_callablewrapper_typehint():
T = TypeVar("T")