This is an automated email from the ASF dual-hosted git repository. riteshghorse pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 243c795a843 Add timeout parameter to with_exception_handling. (#26526) 243c795a843 is described below commit 243c795a843d96fe1da5731c9d1120caa604a998 Author: Robert Bradshaw <rober...@gmail.com> AuthorDate: Wed May 3 13:36:20 2023 -0700 Add timeout parameter to with_exception_handling. (#26526) --- sdks/python/apache_beam/transforms/core.py | 90 ++++++++++++++++++---- .../apache_beam/transforms/ptransform_test.py | 23 ++++++ 2 files changed, 100 insertions(+), 13 deletions(-) diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 6260975b32c..22f6170c546 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -25,6 +25,7 @@ import inspect import logging import random import sys +import time import traceback import types import typing @@ -1502,7 +1503,8 @@ class ParDo(PTransformWithSideInputs): partial=False, use_subprocess=False, threshold=1, - threshold_windowing=None): + threshold_windowing=None, + timeout=None): """Automatically provides a dead letter output for skipping bad records. This can allow a pipeline to continue successfully rather than fail or continuously throw errors on retry when bad elements are encountered. @@ -1548,6 +1550,8 @@ class ParDo(PTransformWithSideInputs): up to 100% of records can be bad and the pipeline will still succeed). threshold_windowing: Event-time windowing to use for threshold. Optional, defaults to the windowing of the input. + timeout: If the element has not finished processing in timeout seconds, + raise a TimeoutError. Defaults to None, meaning no time limit. """ args, kwargs = self.raw_side_inputs return self.label >> _ExceptionHandlingWrapper( @@ -1560,7 +1564,8 @@ class ParDo(PTransformWithSideInputs): partial, use_subprocess, threshold, - threshold_windowing) + threshold_windowing, + timeout) def default_type_hints(self): return self.fn.get_type_hints() @@ -2152,7 +2157,8 @@ class _ExceptionHandlingWrapper(ptransform.PTransform): partial, use_subprocess, threshold, - threshold_windowing): + threshold_windowing, + timeout): if partial and use_subprocess: raise ValueError('partial and use_subprocess are mutually incompatible.') self._fn = fn @@ -2165,14 +2171,18 @@ class _ExceptionHandlingWrapper(ptransform.PTransform): self._use_subprocess = use_subprocess self._threshold = threshold self._threshold_windowing = threshold_windowing + self._timeout = timeout def expand(self, pcoll): + if self._use_subprocess: + wrapped_fn = _SubprocessDoFn(self._fn, timeout=self._timeout) + elif self._timeout: + wrapped_fn = _TimeoutDoFn(self._fn, timeout=self._timeout) + else: + wrapped_fn = self._fn result = pcoll | ParDo( _ExceptionHandlingWrapperDoFn( - _SubprocessDoFn(self._fn) if self._use_subprocess else self._fn, - self._dead_letter_tag, - self._exc_class, - self._partial), + wrapped_fn, self._dead_letter_tag, self._exc_class, self._partial), *self._args, **self._kwargs).with_outputs( self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True) @@ -2241,9 +2251,10 @@ class _ExceptionHandlingWrapperDoFn(DoFn): class _SubprocessDoFn(DoFn): """Process method run in a subprocess, turning hard crashes into exceptions. """ - def __init__(self, fn): + def __init__(self, fn, timeout=None): self._fn = fn self._serialized_fn = pickler.dumps(fn) + self._timeout = timeout def __getattribute__(self, name): if (name.startswith('__') or name in self.__dict__ or @@ -2268,19 +2279,35 @@ class _SubprocessDoFn(DoFn): def teardown(self): self._call_remote(self._remote_teardown) - self._pool.shutdown() - self._pool = None + self._terminate_pool() def _call_remote(self, method, *args, **kwargs): if self._pool is None: self._pool = concurrent.futures.ProcessPoolExecutor(1) self._pool.submit(self._remote_init, self._serialized_fn).result() try: - return self._pool.submit(method, *args, **kwargs).result() - except concurrent.futures.process.BrokenProcessPool: - self._pool = None + return self._pool.submit(method, *args, **kwargs).result( + self._timeout if method == self._remote_process else None) + except (concurrent.futures.process.BrokenProcessPool, + TimeoutError, + concurrent.futures._base.TimeoutError): + self._terminate_pool() raise + def _terminate_pool(self): + """Forcibly terminate the pool, not leaving any live subprocesses.""" + pool = self._pool + self._pool = None + processes = list(pool._processes.values()) + pool.shutdown(wait=False) + for p in processes: + if p.is_alive(): + p.kill() + time.sleep(1) + for p in processes: + if p.is_alive(): + p.terminate() + # These are classmethods to avoid picking the state of self. # They should only be called in an isolated process, so there's no concern # about sharing state or thread safety. @@ -2321,6 +2348,43 @@ class _SubprocessDoFn(DoFn): cls._fn = None +class _TimeoutDoFn(DoFn): + """Process method run in a separate thread allowing timeouts. + """ + def __init__(self, fn, timeout=None): + self._fn = fn + self._timeout = timeout + self._pool = None + + def __getattribute__(self, name): + if (name.startswith('__') or name in self.__dict__ or + name in type(self).__dict__): + return object.__getattribute__(self, name) + else: + return getattr(self._fn, name) + + def process(self, *args, **kwargs): + if self._pool is None: + self._pool = concurrent.futures.ThreadPoolExecutor(10) + # Ensure we iterate over the entire output list in the given amount of time. + try: + return self._pool.submit( + lambda: list(self._fn.process(*args, **kwargs))).result( + self._timeout) + except TimeoutError: + self._pool.shutdown(wait=False) + self._pool = None + raise + + def teardown(self): + try: + self._fn.teardown() + finally: + if self._pool is not None: + self._pool.shutdown(wait=False) + self._pool = None + + def Filter(fn, *args, **kwargs): # pylint: disable=invalid-name """:func:`Filter` is a :func:`FlatMap` with its callable filtering out elements. diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index 0547dbc06a8..3134659549e 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -2735,6 +2735,29 @@ class DeadLettersTest(unittest.TestCase): use_subprocess=self.use_subprocess)) assert_that(good, equal_to([(0, Timestamp(0)), (1, Timestamp(1))])) + def test_timeout(self): + import time + timeout = 1 if self.use_subprocess else .1 + + with TestPipeline() as p: + good, bad = ( + p + | beam.Create('records starting with lowercase S are slow'.split()) + | beam.Map( + lambda x: time.sleep(2.5 * timeout) if x.startswith('s') else x) + .with_exception_handling( + use_subprocess=self.use_subprocess, timeout=timeout)) + assert_that( + good, + equal_to(['records', 'with', 'lowercase', 'S', 'are']), + label='CheckGood') + assert_that( + bad | + beam.MapTuple(lambda e, exc_info: (e, exc_info[1].replace(',', ''))), + equal_to([('starting', 'TimeoutError()'), + ('slow', 'TimeoutError()')]), + label='CheckBad') + def test_lifecycle(self): die = type(self).die