This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 1a51ba5edeb Add support for pre/post processing in RunInference 
(#26309)
1a51ba5edeb is described below

commit 1a51ba5edebef9622c016331071764830adffc94
Author: Danny McCormick <[email protected]>
AuthorDate: Tue Apr 25 14:17:54 2023 -0400

    Add support for pre/post processing in RunInference (#26309)
    
    * Add support for pre/post processing in RunInference
    
    * Changes
    
    * DLQ
    
    * Docs
    
    * Named tuple for DLQ
    
    * Update to use with_preprocess_fn approach
    
    * Pydoc
    
    * typos + refactor repeated code
---
 CHANGES.md                                         |   1 +
 .../apache_beam/examples/inference/README.md       |   9 +-
 .../inference/pytorch_image_classification.py      |  32 ++-
 .../inference/tensorflow_imagenet_segmentation.py  |   8 +-
 sdks/python/apache_beam/ml/inference/base.py       | 256 ++++++++++++++++++++-
 sdks/python/apache_beam/ml/inference/base_test.py  | 200 +++++++++++++++-
 6 files changed, 468 insertions(+), 38 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 3e490a27021..871f24bf9da 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -67,6 +67,7 @@
 ## New Features / Improvements
 
 * Dead letter queue support added to RunInference in Python 
([#24209](https://github.com/apache/beam/issues/24209)).
+* Support added for defining pre/postprocessing operations on the RunInference 
transform ([#26308](https://github.com/apache/beam/issues/26308)).
 
 ## Breaking Changes
 
diff --git a/sdks/python/apache_beam/examples/inference/README.md 
b/sdks/python/apache_beam/examples/inference/README.md
index d195dbf04c0..3f1d1bfe4ab 100644
--- a/sdks/python/apache_beam/examples/inference/README.md
+++ b/sdks/python/apache_beam/examples/inference/README.md
@@ -26,10 +26,11 @@ Some examples are also used in [our 
benchmarks](http://s.apache.org/beam-communi
 
 ## Prerequisites
 
-You must have `apache-beam>=2.40.0` or greater installed in order to run these 
pipelines,
-because the `apache_beam.examples.inference` module was added in that release.
+You must have the latest (possibly unreleased) version of `apache-beam` 
installed from the Beam repo in order to run these pipelines,
+because some examples rely on the latest features that are actively in 
development. To install Beam, run the following from the `sdks/python` 
directory:
 ```
-pip install apache-beam==2.40.0
+pip install -r build-requirements.txt
+pip install -e .[gcp]
 ```
 
 ### Tensorflow dependencies
@@ -38,7 +39,7 @@ The following installation requirement is for the Tensorflow 
model handler examp
 
 The RunInference API supports the Tensorflow framework. To use Tensorflow 
locally, first install `tensorflow`.
 ```
-pip install tensorflow==2.11.0
+pip install tensorflow==2.12.0
 ```
 
 ### PyTorch dependencies
diff --git 
a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py 
b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
index 4153c3d0a35..d627001bcb8 100644
--- a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
+++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
@@ -21,7 +21,6 @@ import argparse
 import io
 import logging
 import os
-from typing import Iterable
 from typing import Iterator
 from typing import Optional
 from typing import Tuple
@@ -69,13 +68,6 @@ def filter_empty_lines(text: str) -> Iterator[str]:
     yield text
 
 
-class PostProcessor(beam.DoFn):
-  def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]:
-    filename, prediction_result = element
-    prediction = torch.argmax(prediction_result.inference, dim=0)
-    yield filename + ',' + str(prediction.item())
-
-
 def parse_known_args(argv):
   """Parses args for the workflow."""
   parser = argparse.ArgumentParser()
@@ -130,6 +122,17 @@ def run(
     model_class = models.mobilenet_v2
     model_params = {'num_classes': 1000}
 
+  def preprocess(image_name: str) -> Tuple[str, torch.Tensor]:
+    image_name, image = read_image(
+      image_file_name=image_name,
+      path_to_dir=known_args.images_dir)  # -> (image name, image)
+    return (image_name, preprocess_image(image))  # keyed by image name
+
+  def postprocess(element: Tuple[str, PredictionResult]) -> str:
+    filename, prediction_result = element
+    prediction = torch.argmax(prediction_result.inference, dim=0)
+    return filename + ',' + str(prediction.item())  # 'name,argmax_idx'
+
   # In this example we pass keyed inputs to RunInference transform.
   # Therefore, we use KeyedModelHandler wrapper over PytorchModelHandler.
   model_handler = KeyedModelHandler(
@@ -139,7 +142,8 @@ def run(
           model_params=model_params,
           device=device,
           min_batch_size=10,
-          max_batch_size=100))
+          max_batch_size=100)).with_preprocess_fn(  # fns run in RunInference
+              preprocess).with_postprocess_fn(postprocess)

   pipeline = test_pipeline
   if not test_pipeline:
@@ -148,16 +152,10 @@ def run(
   filename_value_pair = (
       pipeline
       | 'ReadImageNames' >> beam.io.ReadFromText(known_args.input)
-      | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines)
-      | 'ReadImageData' >> beam.Map(
-          lambda image_name: read_image(
-              image_file_name=image_name, path_to_dir=known_args.images_dir))
-      | 'PreprocessImages' >> beam.MapTuple(
-          lambda file_name, data: (file_name, preprocess_image(data))))
+      | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines))
   predictions = (
       filename_value_pair
-      | 'PyTorchRunInference' >> RunInference(model_handler)
-      | 'ProcessOutput' >> beam.ParDo(PostProcessor()))
+      | 'PyTorchRunInference' >> RunInference(model_handler))
 
   predictions | "WriteOutputToGCS" >> beam.io.WriteToText( # pylint: 
disable=expression-not-assigned
     known_args.output,
diff --git 
a/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py
 
b/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py
index b7638ddd3c9..a0f249dcfbf 100644
--- 
a/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py
+++ 
b/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py
@@ -99,7 +99,9 @@ def run(
   pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
 
   # In this example we will use the TensorflowHub model URL.
-  model_loader = TFModelHandlerTensor(model_uri=known_args.model_path)
+  model_loader = TFModelHandlerTensor(  # images are read by a preprocess fn
+      model_uri=known_args.model_path).with_preprocess_fn(
+          lambda image_name: read_image(image_name, known_args.image_dir))

   pipeline = test_pipeline
   if not test_pipeline:
@@ -108,9 +110,7 @@ def run(
   image = (
       pipeline
       | 'ReadImageNames' >> beam.io.ReadFromText(known_args.input)
-      | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines)
-      | "PreProcessInputs" >>
-      beam.Map(lambda image_name: read_image(image_name, 
known_args.image_dir)))
+      | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines))
 
   predictions = (
       image
diff --git a/sdks/python/apache_beam/ml/inference/base.py 
b/sdks/python/apache_beam/ml/inference/base.py
index 5bc3bb31557..7c7eb02705a 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -33,6 +33,7 @@ import sys
 import threading
 import time
 from typing import Any
+from typing import Callable
 from typing import Dict
 from typing import Generic
 from typing import Iterable
@@ -58,7 +59,9 @@ _NANOSECOND_TO_MICROSECOND = 1_000
 
 ModelT = TypeVar('ModelT')
 ExampleT = TypeVar('ExampleT')
+PreProcessT = TypeVar('PreProcessT')  # input type of a preprocess fn
 PredictionT = TypeVar('PredictionT')
+PostProcessT = TypeVar('PostProcessT')  # output type of a postprocess fn
 _INPUT_TYPE = TypeVar('_INPUT_TYPE')
 _OUTPUT_TYPE = TypeVar('_OUTPUT_TYPE')
 KeyT = TypeVar('KeyT')
@@ -91,6 +94,12 @@ class ModelMetadata(NamedTuple):
   model_name: str


+class RunInferenceDLQ(NamedTuple):
+  failed_inferences: beam.PCollection  # (failed batch, exception) pairs
+  failed_preprocessing: Sequence[beam.PCollection]  # one per preprocess fn
+  failed_postprocessing: Sequence[beam.PCollection]  # one per postprocess fn
+
+
 ModelMetadata.model_id.__doc__ = """Unique identifier for the model. This can 
be
     a file path or a URL where the model can be accessed. It is used to load
     the model for inference."""
@@ -174,6 +183,38 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
     """Update the model paths produced by side inputs."""
     pass
 
+  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+    """Gets all preprocessing functions to be run before batching/inference.
+    Functions are listed in the order in which they should be applied."""
+    return []
+
+  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+    """Gets all postprocessing functions to be run after inference.
+    Functions are listed in the order in which they should be applied."""
+    return []
+
+  def with_preprocess_fn(
+      self, fn: Callable[[PreProcessT], ExampleT]
+  ) -> 'ModelHandler[PreProcessT, PredictionT, ModelT]':
+    """Returns a new ModelHandler with a preprocessing function
+    associated with it. The preprocessing function will be run
+    before batching/inference and should map your input PCollection
+    to the base ModelHandler's input type. If you apply multiple
+    preprocessing functions, the most recently applied one runs first,
+    i.e. they run in order from last applied to first applied."""
+    return _PreProcessingModelHandler(self, fn)
+
+  def with_postprocess_fn(
+      self, fn: Callable[[PredictionT], PostProcessT]
+  ) -> 'ModelHandler[ExampleT, PostProcessT, ModelT]':
+    """Returns a new ModelHandler with a postprocessing function
+    associated with it. The postprocessing function will be run
+    after inference and should map the base ModelHandler's output
+    type to your desired output type. If you apply multiple
+    postprocessing functions, the first applied one runs first on the
+    inference result, i.e. in order from first applied to last applied."""
+    return _PostProcessingModelHandler(self, fn)
+
 
 class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                         ModelHandler[Tuple[KeyT, ExampleT],
@@ -190,6 +231,12 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, 
PredictionT, ModelT],
     Args:
       unkeyed: An implementation of ModelHandler that does not require keys.
     """
+    if len(unkeyed.get_preprocess_fns()) or len(unkeyed.get_postprocess_fns()):
+      raise ValueError(
+          'Cannot make an unkeyed model handler with pre or '
+          'postprocessing functions defined into a keyed model handler. All '
+          'pre/postprocessing functions must be defined on the outer model '
+          'handler.')
     self._unkeyed = unkeyed
 
   def load_model(self) -> ModelT:
@@ -224,6 +271,12 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, 
PredictionT, ModelT],
   def update_model_path(self, model_path: Optional[str] = None):
     return self._unkeyed.update_model_path(model_path=model_path)
 
+  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+    return self._unkeyed.get_preprocess_fns()  # [] per the __init__ check
+
+  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+    return self._unkeyed.get_postprocess_fns()  # [] per the __init__ check
+
 
 class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                              ModelHandler[Union[ExampleT, Tuple[KeyT,
@@ -248,6 +301,12 @@ class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, 
PredictionT, ModelT],
     Args:
       unkeyed: An implementation of ModelHandler that does not require keys.
     """
+    if len(unkeyed.get_preprocess_fns()) or len(unkeyed.get_postprocess_fns()):
+      raise ValueError(
+          'Cannot make an unkeyed model handler with pre or '
+          'postprocessing functions defined into a keyed model handler. All '
+          'pre/postprocessing functions must be defined on the outer model '
+          'handler.')
     self._unkeyed = unkeyed
 
   def load_model(self) -> ModelT:
@@ -300,6 +359,123 @@ class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, 
PredictionT, ModelT],
   def update_model_path(self, model_path: Optional[str] = None):
     return self._unkeyed.update_model_path(model_path=model_path)
 
+  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+    return self._unkeyed.get_preprocess_fns()  # [] per the __init__ check
+
+  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+    return self._unkeyed.get_postprocess_fns()  # [] per the __init__ check
+
+
+class _PreProcessingModelHandler(Generic[ExampleT,
+                                         PredictionT,
+                                         ModelT,
+                                         PreProcessT],
+                                 ModelHandler[PreProcessT, PredictionT,
+                                              ModelT]):
+  def __init__(
+      self,
+      base: ModelHandler[ExampleT, PredictionT, ModelT],
+      preprocess_fn: Callable[[PreProcessT], ExampleT]):
+    """A ModelHandler that has a preprocessing function associated with it.
+
+    Args:
+      base: An implementation of the underlying model handler.
+      preprocess_fn: applied before any preprocessing fns already on base.
+    """
+    self._base = base
+    self._preprocess_fn = preprocess_fn
+
+  def load_model(self) -> ModelT:
+    return self._base.load_model()
+
+  def run_inference(
+      self,
+      batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
+      model: ModelT,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
+    return self._base.run_inference(batch, model, inference_args)
+
+  def get_num_bytes(
+      self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
+    return self._base.get_num_bytes(batch)
+
+  def get_metrics_namespace(self) -> str:
+    return self._base.get_metrics_namespace()
+
+  def get_resource_hints(self):
+    return self._base.get_resource_hints()
+
+  def batch_elements_kwargs(self):
+    return self._base.batch_elements_kwargs()
+
+  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+    return self._base.validate_inference_args(inference_args)
+
+  def update_model_path(self, model_path: Optional[str] = None):
+    return self._base.update_model_path(model_path=model_path)
+
+  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+    return [self._preprocess_fn, *self._base.get_preprocess_fns()]
+
+  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+    return self._base.get_postprocess_fns()
+
+
+class _PostProcessingModelHandler(Generic[ExampleT,
+                                          PredictionT,
+                                          ModelT,
+                                          PostProcessT],
+                                  ModelHandler[ExampleT, PostProcessT, 
ModelT]):
+  def __init__(
+      self,
+      base: ModelHandler[ExampleT, PredictionT, ModelT],
+      postprocess_fn: Callable[[PredictionT], PostProcessT]):
+    """A ModelHandler that has a postprocessing function associated with it.
+
+    Args:
+      base: An implementation of the underlying model handler.
+      postprocess_fn: applied after any postprocessing fns already on base.
+    """
+    self._base = base
+    self._postprocess_fn = postprocess_fn
+
+  def load_model(self) -> ModelT:
+    return self._base.load_model()
+
+  def run_inference(
+      self,
+      batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
+      model: ModelT,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
+    return self._base.run_inference(batch, model, inference_args)
+
+  def get_num_bytes(
+      self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
+    return self._base.get_num_bytes(batch)
+
+  def get_metrics_namespace(self) -> str:
+    return self._base.get_metrics_namespace()
+
+  def get_resource_hints(self):
+    return self._base.get_resource_hints()
+
+  def batch_elements_kwargs(self):
+    return self._base.batch_elements_kwargs()
+
+  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+    return self._base.validate_inference_args(inference_args)
+
+  def update_model_path(self, model_path: Optional[str] = None):
+    return self._base.update_model_path(model_path=model_path)
+
+  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+    return self._base.get_preprocess_fns()
+
+  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+    return [*self._base.get_postprocess_fns(), self._postprocess_fn]
 
 class RunInference(beam.PTransform[beam.PCollection[ExampleT],
                                    beam.PCollection[PredictionT]]):
@@ -356,13 +532,45 @@ class 
RunInference(beam.PTransform[beam.PCollection[ExampleT],
     """
     return cls(model_handler_provider(**kwargs))
 
+  def _apply_fns(
+      self,
+      pcoll: beam.PCollection,
+      fns: Iterable[Callable[[Any], Any]],
+      step_prefix: str) -> Tuple[beam.PCollection, Sequence[beam.PCollection]]:
+    """Applies each fn as its own Map step; returns (output, per-fn DLQs)."""
+    failed_records = []
+    for idx, fn in enumerate(fns):
+      if self._with_exception_handling:
+        pcoll, bad = (
+            pcoll
+            | f"{step_prefix}-{idx}" >> beam.Map(fn).with_exception_handling(
+                exc_class=self._exc_class,
+                use_subprocess=self._use_subprocess,
+                threshold=self._threshold))
+        failed_records.append(bad)  # one DLQ PCollection per fn, in order
+      else:
+        pcoll = pcoll | f"{step_prefix}-{idx}" >> beam.Map(fn)
+
+    return pcoll, failed_records
+
   # TODO(https://github.com/apache/beam/issues/21447): Add batch_size back off
   # in the case there are functional reasons large batch sizes cannot be
   # handled.
   def expand(
       self, pcoll: beam.PCollection[ExampleT]) -> 
beam.PCollection[PredictionT]:
     self._model_handler.validate_inference_args(self._inference_args)
+    # DLQ pcollections; only returned when exception handling is enabled.
+    bad_preprocessed = []
+    bad_inference = None  # set only when exception handling is enabled
+    bad_postprocessed = []
+    preprocess_fns = self._model_handler.get_preprocess_fns()
+    postprocess_fns = self._model_handler.get_postprocess_fns()
+
+    pcoll, bad_preprocessed = self._apply_fns(  # before batching
+      pcoll, preprocess_fns, 'BeamML_RunInference_Preprocess')
+
     resource_hints = self._model_handler.get_resource_hints()
+
     batched_elements_pcoll = (
         pcoll
         # TODO(https://github.com/apache/beam/issues/21440): Hook into the
@@ -382,14 +590,26 @@ class 
RunInference(beam.PTransform[beam.PCollection[ExampleT],
             **resource_hints)
 
     if self._with_exception_handling:
-      run_inference_pardo = run_inference_pardo.with_exception_handling(
+      results, bad_inference = (
+          batched_elements_pcoll
+          | 'BeamML_RunInference' >>
+          run_inference_pardo.with_exception_handling(
           exc_class=self._exc_class,
           use_subprocess=self._use_subprocess,
-          threshold=self._threshold)
+          threshold=self._threshold))
+    else:
+      results = (
+          batched_elements_pcoll
+          | 'BeamML_RunInference' >> run_inference_pardo)
+
+    results, bad_postprocessed = self._apply_fns(  # after inference
+      results, postprocess_fns, 'BeamML_RunInference_Postprocess')
+
+    if self._with_exception_handling:  # -> (main, RunInferenceDLQ)
+      dlq = RunInferenceDLQ(bad_inference, bad_preprocessed, bad_postprocessed)
+      return results, dlq
 
-    return (
-        batched_elements_pcoll
-        | 'BeamML_RunInference' >> run_inference_pardo)
+    return results
 
   def with_exception_handling(
       self, *, exc_class=Exception, use_subprocess=False, threshold=1):
@@ -404,13 +624,31 @@ class 
RunInference(beam.PTransform[beam.PCollection[ExampleT],
 
     For example, one would write::
 
-        good, bad = RunInference(
+        main, other = RunInference(
           maybe_error_raising_model_handler
         ).with_exception_handling()
 
-    and `good` will be a PCollection of PredictionResults and `bad` will
-    contain a tuple of all batches that raised exceptions, along with their
-    corresponding exception.
+    and `main` will be a PCollection of PredictionResults (or of postprocess
+    outputs) and `other` will contain a `RunInferenceDLQ` object with
+    PCollections of failed records for each failed inference, preprocess
+    operation, or postprocess operation. To access them, one would write::
+
+        failed_inferences = other.failed_inferences
+        failed_preprocessing = other.failed_preprocessing
+        failed_postprocessing = other.failed_postprocessing
+
+    failed_inferences is in the form
+    PCollection[Tuple[failed batch, exception]].
+
+    failed_preprocessing is in the form
+    List[PCollection[Tuple[failed record, exception]]], where each element of
+    the list corresponds to a preprocess function. These PCollections are
+    in the same order that the preprocess functions are applied.
+
+    failed_postprocessing is in the form
+    List[PCollection[Tuple[failed record, exception]]], where each element of
+    the list corresponds to a postprocess function. These PCollections are
+    in the same order that the postprocess functions are applied.
 
 
     Args:
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py 
b/sdks/python/apache_beam/ml/inference/base_test.py
index da82095a4e9..893c864c90e 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -179,21 +179,213 @@ class RunInferenceBaseTest(unittest.TestCase):
           model_handler)
       assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
 
+  def test_run_inference_preprocessing(self):
+    def mult_two(example: str) -> int:
+      return int(example) * 2
+
+    with TestPipeline() as pipeline:
+      examples = ["1", "5", "3", "10"]
+      expected = [int(example) * 2 + 1 for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(
+          FakeModelHandler().with_preprocess_fn(mult_two))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_run_inference_preprocessing_multiple_fns(self):
+    def add_one(example: str) -> int:
+      return int(example) + 1
+
+    def mult_two(example: int) -> int:
+      return example * 2
+
+    with TestPipeline() as pipeline:
+      examples = ["1", "5", "3", "10"]
+      expected = [(int(example) + 1) * 2 + 1 for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(
+          FakeModelHandler().with_preprocess_fn(mult_two).with_preprocess_fn(
+              add_one))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_run_inference_postprocessing(self):
+    def mult_two(example: int) -> str:
+      return str(example * 2)
+
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = [str((example + 1) * 2) for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(
+          FakeModelHandler().with_postprocess_fn(mult_two))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_run_inference_postprocessing_multiple_fns(self):
+    def add_one(example: int) -> str:
+      return str(int(example) + 1)
+
+    def mult_two(example: int) -> int:
+      return example * 2
+
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = [str(((example + 1) * 2) + 1) for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(
+          FakeModelHandler().with_postprocess_fn(mult_two).with_postprocess_fn(
+              add_one))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_run_inference_preprocessing_dlq(self):
+    def mult_two(example: str) -> int:
+      if example == "5":
+        raise Exception("TEST")
+      return int(example) * 2
+
+    with TestPipeline() as pipeline:
+      examples = ["1", "5", "3", "10"]
+      expected = [3, 7, 21]
+      expected_bad = ["5"]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      main, other = pcoll | base.RunInference(
+          FakeModelHandler().with_preprocess_fn(mult_two)
+          ).with_exception_handling()
+      assert_that(main, equal_to(expected), label='assert:inferences')
+      assert_that(
+          other.failed_inferences, equal_to([]), label='assert:bad_infer')
+
+      # Failed records are in form [element, error]. Just pull out element.
+      bad_without_error = other.failed_preprocessing[0] | beam.Map(
+          lambda x: x[0])
+      assert_that(
+          bad_without_error, equal_to(expected_bad), label='assert:failures')
+
+  def test_run_inference_postprocessing_dlq(self):
+    def mult_two(example: int) -> str:
+      if example == 6:
+        raise Exception("TEST")
+      return str(example * 2)
+
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = ["4", "8", "22"]
+      expected_bad = [6]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      main, other = pcoll | base.RunInference(
+          FakeModelHandler().with_postprocess_fn(mult_two)
+          ).with_exception_handling()
+      assert_that(main, equal_to(expected), label='assert:inferences')
+      assert_that(
+          other.failed_inferences, equal_to([]), label='assert:bad_infer')
+
+      # Failed records are in form [element, error]. Just pull out element.
+      bad_without_error = other.failed_postprocessing[0] | beam.Map(
+          lambda x: x[0])
+      assert_that(
+          bad_without_error, equal_to(expected_bad), label='assert:failures')
+
+  def test_run_inference_pre_and_post_processing_dlq(self):
+    def mult_two_pre(example: str) -> int:
+      if example == "5":
+        raise Exception("TEST")
+      return int(example) * 2
+
+    def mult_two_post(example: int) -> str:
+      if example == 7:
+        raise Exception("TEST")
+      return str(example * 2)
+
+    with TestPipeline() as pipeline:
+      examples = ["1", "5", "3", "10"]
+      expected = ["6", "42"]
+      expected_bad_pre = ["5"]
+      expected_bad_post = [7]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      main, other = pcoll | base.RunInference(
+          FakeModelHandler().with_preprocess_fn(
+            mult_two_pre
+            ).with_postprocess_fn(
+              mult_two_post
+              )).with_exception_handling()
+      assert_that(main, equal_to(expected), label='assert:inferences')
+      assert_that(
+          other.failed_inferences, equal_to([]), label='assert:bad_infer')
+
+      # Failed records are in form [element, error]. Just pull out element.
+      bad_without_error_pre = other.failed_preprocessing[0] | beam.Map(
+          lambda x: x[0])
+      assert_that(
+          bad_without_error_pre,
+          equal_to(expected_bad_pre),
+          label='assert:failures_pre')
+
+      # Failed records are in form [element, error]. Just pull out element.
+      bad_without_error_post = other.failed_postprocessing[0] | beam.Map(
+          lambda x: x[0])
+      assert_that(
+          bad_without_error_post,
+          equal_to(expected_bad_post),
+          label='assert:failures_post')
+
+  def test_run_inference_keyed_pre_and_post_processing(self):
+    def mult_two(element):
+      return (element[0], element[1] * 2)
+
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      keyed_examples = [(i, example) for i, example in enumerate(examples)]
+      expected = [
+          (i, ((example * 2) + 1) * 2) for i, example in enumerate(examples)
+      ]
+      pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
+      actual = pcoll | base.RunInference(
+          base.KeyedModelHandler(FakeModelHandler()).with_preprocess_fn(
+              mult_two).with_postprocess_fn(mult_two))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_run_inference_maybe_keyed_pre_and_post_processing(self):
+    def mult_two(element):
+      return element * 2
+
+    def mult_two_keyed(element):
+      return (element[0], element[1] * 2)
+
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      keyed_examples = [(i, example) for i, example in enumerate(examples)]
+      expected = [((2 * example) + 1) * 2 for example in examples]
+      keyed_expected = [
+          (i, ((2 * example) + 1) * 2) for i, example in enumerate(examples)
+      ]
+      model_handler = base.MaybeKeyedModelHandler(FakeModelHandler())
+
+      pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
+      actual = pcoll | 'RunUnkeyed' >> base.RunInference(
+          model_handler.with_preprocess_fn(mult_two).with_postprocess_fn(
+              mult_two))
+      assert_that(actual, equal_to(expected), label='CheckUnkeyed')
+
+      keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
+      keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
+          model_handler.with_preprocess_fn(mult_two_keyed).with_postprocess_fn(
+              mult_two_keyed))
+      assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
+
   def test_run_inference_impl_dlq(self):
     with TestPipeline() as pipeline:
       examples = [1, 'TEST', 3, 10, 'TEST2']
       expected_good = [2, 4, 11]
       expected_bad = ['TEST', 'TEST2']
       pcoll = pipeline | 'start' >> beam.Create(examples)
-      good, bad = pcoll | base.RunInference(
+      main, other = pcoll | base.RunInference(
           FakeModelHandler(
             min_batch_size=1,
             max_batch_size=1
           )).with_exception_handling()
-      assert_that(good, equal_to(expected_good), label='assert:inferences')
+      assert_that(main, equal_to(expected_good), label='assert:inferences')
 
-      # bad will be in form [batch[elements], error]. Just pull out bad 
element.
-      bad_without_error = bad | beam.Map(lambda x: x[0][0])
+      # other.failed_inferences will be in form [batch[elements], error].
+      # Batches have a single element here, so just pull out that element.
+      bad_without_error = other.failed_inferences | beam.Map(lambda x: x[0][0])
       assert_that(
           bad_without_error, equal_to(expected_bad), label='assert:failures')
 

Reply via email to