This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new b9f27f9da2e DLQ support in RunInference (#26261)
b9f27f9da2e is described below

commit b9f27f9da2e63b564feecaeb593d7b12783192b0
Author: Danny McCormick <dannymccorm...@google.com>
AuthorDate: Fri Apr 14 17:23:30 2023 -0400

    DLQ support in RunInference (#26261)
    
    * DLQ support in RunInference
    
    * Doc example
    
    * Comment
    
    * CHANGES.md
---
 CHANGES.md                                        |  2 +-
 sdks/python/apache_beam/ml/inference/base.py      | 76 +++++++++++++++++++----
 sdks/python/apache_beam/ml/inference/base_test.py | 28 ++++++++-
 3 files changed, 92 insertions(+), 14 deletions(-)
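
Not part of this commit, but as a rough, runnable sketch of how the new dead
letter queue support reads end to end: the ToyModel/ToyModelHandler names below
are illustrative stand-ins modeled on the FakeModelHandler used in the test
diff further down; any ModelHandler whose run_inference may raise would behave
the same way.

    import apache_beam as beam
    from apache_beam.ml.inference import base

    class ToyModel:
      def predict(self, example: int) -> int:
        # Raises TypeError for non-integer input, exercising the DLQ path.
        return example + 1

    class ToyModelHandler(base.ModelHandler[int, int, ToyModel]):
      def load_model(self) -> ToyModel:
        return ToyModel()

      def run_inference(self, batch, model, inference_args=None):
        return [model.predict(example) for example in batch]

      def batch_elements_kwargs(self):
        # Keep batches to a single element so one bad input only fails its
        # own batch, not its neighbors.
        return {'max_batch_size': 1}

    with beam.Pipeline() as p:
      good, bad = (
          p
          | beam.Create([1, 'oops', 3])
          | base.RunInference(ToyModelHandler()).with_exception_handling())
      # good: successful inference results (here, example + 1 per input).
      # bad: (failed batch, error) pairs for the batches that raised.
      _ = bad | beam.Map(print)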

diff --git a/CHANGES.md b/CHANGES.md
index 3113c875005..7be9752517d 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -64,7 +64,7 @@
 
 ## New Features / Improvements
 
-* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
+* Dead letter queue support added to RunInference in Python ([#24209](https://github.com/apache/beam/issues/24209)).
 
 ## Breaking Changes
 
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 50056107702..5bc3bb31557 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -339,6 +339,7 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
     self._metrics_namespace = metrics_namespace
     self._model_metadata_pcoll = model_metadata_pcoll
     self._enable_side_input_loading = self._model_metadata_pcoll is not None
+    self._with_exception_handling = False
 
   # TODO(BEAM-14046): Add and link to help documentation.
   @classmethod
@@ -368,20 +369,71 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
         # batching DoFn APIs.
         | beam.BatchElements(**self._model_handler.batch_elements_kwargs()))
 
+    run_inference_pardo = beam.ParDo(
+        _RunInferenceDoFn(
+            self._model_handler,
+            self._clock,
+            self._metrics_namespace,
+            self._enable_side_input_loading),
+        self._inference_args,
+        beam.pvalue.AsSingleton(
+            self._model_metadata_pcoll,
+        ) if self._enable_side_input_loading else None).with_resource_hints(
+            **resource_hints)
+
+    if self._with_exception_handling:
+      run_inference_pardo = run_inference_pardo.with_exception_handling(
+          exc_class=self._exc_class,
+          use_subprocess=self._use_subprocess,
+          threshold=self._threshold)
+
     return (
         batched_elements_pcoll
-        | 'BeamML_RunInference' >> (
-            beam.ParDo(
-                _RunInferenceDoFn(
-                    self._model_handler,
-                    self._clock,
-                    self._metrics_namespace,
-                    self._enable_side_input_loading),
-                self._inference_args,
-                beam.pvalue.AsSingleton(
-                    self._model_metadata_pcoll,
-                ) if self._enable_side_input_loading else
-                None).with_resource_hints(**resource_hints)))
+        | 'BeamML_RunInference' >> run_inference_pardo)
+
+  def with_exception_handling(
+      self, *, exc_class=Exception, use_subprocess=False, threshold=1):
+    """Automatically provides a dead letter output for skipping bad records.
+    This can allow a pipeline to continue successfully rather than fail or
+    continuously throw errors on retry when bad elements are encountered.
+
+    This returns a tagged output with two PCollections, the first being the
+    results of successfully processing the input PCollection, and the second
+    being the set of bad batches of records (those which threw exceptions
+    during processing) along with information about the errors raised.
+
+    For example, one would write::
+
+        good, bad = RunInference(
+          maybe_error_raising_model_handler
+        ).with_exception_handling()
+
+    and `good` will be a PCollection of PredictionResults and `bad` will
+    contain a tuple of all batches that raised exceptions, along with their
+    corresponding exception.
+
+
+    Args:
+      exc_class: An exception class, or tuple of exception classes, to catch.
+          Optional, defaults to 'Exception'.
+      use_subprocess: Whether to execute the DoFn logic in a subprocess. This
+          allows one to recover from errors that can crash the calling process
+          (e.g. from an underlying library causing a segfault), but is
+          slower as elements and results must cross a process boundary.  Note
+          that this starts up a long-running process that is used to handle
+          all the elements (until hard failure, which should be rare) rather
+          than a new process per element, so the overhead should be minimal
+          (and can be amortized if there's any per-process or per-bundle
+          initialization that needs to be done). Optional, defaults to False.
+      threshold: An upper bound on the ratio of records that can be bad before
+          aborting the entire pipeline. Optional, defaults to 1.0 (meaning
+          up to 100% of records can be bad and the pipeline will still succeed).
+    """
+    self._with_exception_handling = True
+    self._exc_class = exc_class
+    self._use_subprocess = use_subprocess
+    self._threshold = threshold
+    return self
 
 
 class _MetricsCollector:
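
Continuing from the docstring's `good, bad = RunInference(...).with_exception_handling()`
example above, a hedged sketch of wiring the failure output into a simple dead
letter sink; the format_failure helper and output path are illustrative, and
the element shape (failed batch, error) follows the test change below.

    import apache_beam as beam

    def format_failure(failed):
      # Each bad record pairs the failed batch of examples with the error
      # raised while processing it (see the test comment below).
      batch, error = failed
      return f'failed batch {batch!r}: {error!r}'

    _ = (
        bad  # the second output of with_exception_handling(), as above
        | 'FormatFailures' >> beam.Map(format_failure)
        | 'WriteFailures' >> beam.io.WriteToText('/tmp/run_inference_failures'))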
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index dad18c7b9e1..da82095a4e9 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -48,8 +48,10 @@ class FakeModel:
 
 
 class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
-  def __init__(self, clock=None):
+  def __init__(self, clock=None, min_batch_size=1, max_batch_size=9999):
     self._fake_clock = clock
+    self._min_batch_size = min_batch_size
+    self._max_batch_size = max_batch_size
 
   def load_model(self):
     if self._fake_clock:
@@ -69,6 +71,12 @@ class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
   def update_model_path(self, model_path: Optional[str] = None):
     pass
 
+  def batch_elements_kwargs(self):
+    return {
+        'min_batch_size': self._min_batch_size,
+        'max_batch_size': self._max_batch_size
+    }
+
 
 class FakeModelHandlerReturnsPredictionResult(
     base.ModelHandler[int, base.PredictionResult, FakeModel]):
@@ -171,6 +179,24 @@ class RunInferenceBaseTest(unittest.TestCase):
           model_handler)
       assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
 
+  def test_run_inference_impl_dlq(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 'TEST', 3, 10, 'TEST2']
+      expected_good = [2, 4, 11]
+      expected_bad = ['TEST', 'TEST2']
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      good, bad = pcoll | base.RunInference(
+          FakeModelHandler(
+            min_batch_size=1,
+            max_batch_size=1
+          )).with_exception_handling()
+      assert_that(good, equal_to(expected_good), label='assert:inferences')
+
+      # bad will be in form [batch[elements], error]. Just pull out bad element.
+      bad_without_error = bad | beam.Map(lambda x: x[0][0])
+      assert_that(
+          bad_without_error, equal_to(expected_bad), label='assert:failures')
+
   def test_run_inference_impl_inference_args(self):
     with TestPipeline() as pipeline:
       examples = [1, 5, 3, 10]
