This is an automated email from the ASF dual-hosted git repository.

riteshghorse pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 243c795a843 Add timeout parameter to with_exception_handling. (#26526)
243c795a843 is described below

commit 243c795a843d96fe1da5731c9d1120caa604a998
Author: Robert Bradshaw <rober...@gmail.com>
AuthorDate: Wed May 3 13:36:20 2023 -0700

    Add timeout parameter to with_exception_handling. (#26526)
---
 sdks/python/apache_beam/transforms/core.py         | 90 ++++++++++++++++++----
 .../apache_beam/transforms/ptransform_test.py      | 23 ++++++
 2 files changed, 100 insertions(+), 13 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 6260975b32c..22f6170c546 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -25,6 +25,7 @@ import inspect
 import logging
 import random
 import sys
+import time
 import traceback
 import types
 import typing
@@ -1502,7 +1503,8 @@ class ParDo(PTransformWithSideInputs):
       partial=False,
       use_subprocess=False,
       threshold=1,
-      threshold_windowing=None):
+      threshold_windowing=None,
+      timeout=None):
     """Automatically provides a dead letter output for skipping bad records.
     This can allow a pipeline to continue successfully rather than fail or
     continuously throw errors on retry when bad elements are encountered.
@@ -1548,6 +1550,8 @@ class ParDo(PTransformWithSideInputs):
           up to 100% of records can be bad and the pipeline will still succeed).
       threshold_windowing: Event-time windowing to use for threshold. Optional,
           defaults to the windowing of the input.
+      timeout: If the element has not finished processing in timeout seconds,
+          raise a TimeoutError.  Defaults to None, meaning no time limit.
     """
     args, kwargs = self.raw_side_inputs
     return self.label >> _ExceptionHandlingWrapper(
@@ -1560,7 +1564,8 @@ class ParDo(PTransformWithSideInputs):
         partial,
         use_subprocess,
         threshold,
-        threshold_windowing)
+        threshold_windowing,
+        timeout)
 
   def default_type_hints(self):
     return self.fn.get_type_hints()
@@ -2152,7 +2157,8 @@ class _ExceptionHandlingWrapper(ptransform.PTransform):
       partial,
       use_subprocess,
       threshold,
-      threshold_windowing):
+      threshold_windowing,
+      timeout):
     if partial and use_subprocess:
       raise ValueError('partial and use_subprocess are mutually incompatible.')
     self._fn = fn
@@ -2165,14 +2171,18 @@ class _ExceptionHandlingWrapper(ptransform.PTransform):
     self._use_subprocess = use_subprocess
     self._threshold = threshold
     self._threshold_windowing = threshold_windowing
+    self._timeout = timeout
 
   def expand(self, pcoll):
+    if self._use_subprocess:
+      wrapped_fn = _SubprocessDoFn(self._fn, timeout=self._timeout)
+    elif self._timeout:
+      wrapped_fn = _TimeoutDoFn(self._fn, timeout=self._timeout)
+    else:
+      wrapped_fn = self._fn
     result = pcoll | ParDo(
         _ExceptionHandlingWrapperDoFn(
-            _SubprocessDoFn(self._fn) if self._use_subprocess else self._fn,
-            self._dead_letter_tag,
-            self._exc_class,
-            self._partial),
+            wrapped_fn, self._dead_letter_tag, self._exc_class, self._partial),
         *self._args,
         **self._kwargs).with_outputs(
             self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
@@ -2241,9 +2251,10 @@ class _ExceptionHandlingWrapperDoFn(DoFn):
 class _SubprocessDoFn(DoFn):
   """Process method run in a subprocess, turning hard crashes into exceptions.
   """
-  def __init__(self, fn):
+  def __init__(self, fn, timeout=None):
     self._fn = fn
     self._serialized_fn = pickler.dumps(fn)
+    self._timeout = timeout
 
   def __getattribute__(self, name):
     if (name.startswith('__') or name in self.__dict__ or
@@ -2268,19 +2279,35 @@ class _SubprocessDoFn(DoFn):
 
   def teardown(self):
     self._call_remote(self._remote_teardown)
-    self._pool.shutdown()
-    self._pool = None
+    self._terminate_pool()
 
   def _call_remote(self, method, *args, **kwargs):
     if self._pool is None:
       self._pool = concurrent.futures.ProcessPoolExecutor(1)
       self._pool.submit(self._remote_init, self._serialized_fn).result()
     try:
-      return self._pool.submit(method, *args, **kwargs).result()
-    except concurrent.futures.process.BrokenProcessPool:
-      self._pool = None
+      return self._pool.submit(method, *args, **kwargs).result(
+          self._timeout if method == self._remote_process else None)
+    except (concurrent.futures.process.BrokenProcessPool,
+            TimeoutError,
+            concurrent.futures._base.TimeoutError):
+      self._terminate_pool()
       raise
 
+  def _terminate_pool(self):
+    """Forcibly terminate the pool, not leaving any live subprocesses."""
+    pool = self._pool
+    self._pool = None
+    processes = list(pool._processes.values())
+    pool.shutdown(wait=False)
+    for p in processes:
+      if p.is_alive():
+        p.kill()
+    time.sleep(1)
+    for p in processes:
+      if p.is_alive():
+        p.terminate()
+
   # These are classmethods to avoid picking the state of self.
   # They should only be called in an isolated process, so there's no concern
   # about sharing state or thread safety.
@@ -2321,6 +2348,43 @@ class _SubprocessDoFn(DoFn):
     cls._fn = None
 
 
+class _TimeoutDoFn(DoFn):
+  """Process method run in a separate thread allowing timeouts.
+  """
+  def __init__(self, fn, timeout=None):
+    self._fn = fn
+    self._timeout = timeout
+    self._pool = None
+
+  def __getattribute__(self, name):
+    if (name.startswith('__') or name in self.__dict__ or
+        name in type(self).__dict__):
+      return object.__getattribute__(self, name)
+    else:
+      return getattr(self._fn, name)
+
+  def process(self, *args, **kwargs):
+    if self._pool is None:
+      self._pool = concurrent.futures.ThreadPoolExecutor(10)
+    # Ensure we iterate over the entire output list in the given amount of time.
+    try:
+      return self._pool.submit(
+          lambda: list(self._fn.process(*args, **kwargs))).result(
+              self._timeout)
+    except TimeoutError:
+      self._pool.shutdown(wait=False)
+      self._pool = None
+      raise
+
+  def teardown(self):
+    try:
+      self._fn.teardown()
+    finally:
+      if self._pool is not None:
+        self._pool.shutdown(wait=False)
+        self._pool = None
+
+
 def Filter(fn, *args, **kwargs):  # pylint: disable=invalid-name
   """:func:`Filter` is a :func:`FlatMap` with its callable filtering out
   elements.
diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py
index 0547dbc06a8..3134659549e 100644
--- a/sdks/python/apache_beam/transforms/ptransform_test.py
+++ b/sdks/python/apache_beam/transforms/ptransform_test.py
@@ -2735,6 +2735,29 @@ class DeadLettersTest(unittest.TestCase):
               use_subprocess=self.use_subprocess))
       assert_that(good, equal_to([(0, Timestamp(0)), (1, Timestamp(1))]))
 
+  def test_timeout(self):
+    import time
+    timeout = 1 if self.use_subprocess else .1
+
+    with TestPipeline() as p:
+      good, bad = (
+          p
+          | beam.Create('records starting with lowercase S are slow'.split())
+          | beam.Map(
+              lambda x: time.sleep(2.5 * timeout) if x.startswith('s') else x)
+          .with_exception_handling(
+              use_subprocess=self.use_subprocess, timeout=timeout))
+      assert_that(
+          good,
+          equal_to(['records', 'with', 'lowercase', 'S', 'are']),
+          label='CheckGood')
+      assert_that(
+          bad |
+          beam.MapTuple(lambda e, exc_info: (e, exc_info[1].replace(',', ''))),
+          equal_to([('starting', 'TimeoutError()'),
+                    ('slow', 'TimeoutError()')]),
+          label='CheckBad')
+
   def test_lifecycle(self):
     die = type(self).die
 

Reply via email to