This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 1a51ba5edeb Add support for pre/post processing in RunInference
(#26309)
1a51ba5edeb is described below
commit 1a51ba5edebef9622c016331071764830adffc94
Author: Danny McCormick <[email protected]>
AuthorDate: Tue Apr 25 14:17:54 2023 -0400
Add support for pre/post processing in RunInference (#26309)
* Add support for pre/post processing in RunInference
* Changes
* DLQ
* Docs
* Named tuple for DLQ
* Update to use with_preprocessing_fn approach
* Pydoc
* typos + refactor repeated code
---
CHANGES.md | 1 +
.../apache_beam/examples/inference/README.md | 9 +-
.../inference/pytorch_image_classification.py | 32 ++-
.../inference/tensorflow_imagenet_segmentation.py | 8 +-
sdks/python/apache_beam/ml/inference/base.py | 256 ++++++++++++++++++++-
sdks/python/apache_beam/ml/inference/base_test.py | 200 +++++++++++++++-
6 files changed, 468 insertions(+), 38 deletions(-)
diff --git a/CHANGES.md b/CHANGES.md
index 3e490a27021..871f24bf9da 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -67,6 +67,7 @@
## New Features / Improvements
* Dead letter queue support added to RunInference in Python
([#24209](https://github.com/apache/beam/issues/24209)).
+* Support added for defining pre/postprocessing operations on the RunInference
transform ([#26308](https://github.com/apache/beam/issues/26308))
## Breaking Changes
diff --git a/sdks/python/apache_beam/examples/inference/README.md
b/sdks/python/apache_beam/examples/inference/README.md
index d195dbf04c0..3f1d1bfe4ab 100644
--- a/sdks/python/apache_beam/examples/inference/README.md
+++ b/sdks/python/apache_beam/examples/inference/README.md
@@ -26,10 +26,11 @@ Some examples are also used in [our
benchmarks](http://s.apache.org/beam-communi
## Prerequisites
-You must have `apache-beam>=2.40.0` or greater installed in order to run these
pipelines,
-because the `apache_beam.examples.inference` module was added in that release.
+You must have the latest (possibly unreleased) `apache-beam` SDK
installed from the Beam repo in order to run these pipelines,
+because some examples rely on the latest features that are actively in
development. To install Beam, run the following from the `sdks/python`
directory:
```
-pip install apache-beam==2.40.0
+pip install -r build-requirements.txt
+pip install -e .[gcp]
```
### Tensorflow dependencies
@@ -38,7 +39,7 @@ The following installation requirement is for the Tensorflow
model handler examp
The RunInference API supports the Tensorflow framework. To use Tensorflow
locally, first install `tensorflow`.
```
-pip install tensorflow==2.11.0
+pip install tensorflow==2.12.0
```
### PyTorch dependencies
diff --git
a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
index 4153c3d0a35..d627001bcb8 100644
--- a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
+++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py
@@ -21,7 +21,6 @@ import argparse
import io
import logging
import os
-from typing import Iterable
from typing import Iterator
from typing import Optional
from typing import Tuple
@@ -69,13 +68,6 @@ def filter_empty_lines(text: str) -> Iterator[str]:
yield text
-class PostProcessor(beam.DoFn):
- def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]:
- filename, prediction_result = element
- prediction = torch.argmax(prediction_result.inference, dim=0)
- yield filename + ',' + str(prediction.item())
-
-
def parse_known_args(argv):
"""Parses args for the workflow."""
parser = argparse.ArgumentParser()
@@ -130,6 +122,17 @@ def run(
model_class = models.mobilenet_v2
model_params = {'num_classes': 1000}
+ def preprocess(image_name: str) -> Tuple[str, torch.Tensor]:
+ image_name, image = read_image(
+ image_file_name=image_name,
+ path_to_dir=known_args.images_dir)
+ return (image_name, preprocess_image(image))
+
+ def postprocess(element: Tuple[str, PredictionResult]) -> str:
+ filename, prediction_result = element
+ prediction = torch.argmax(prediction_result.inference, dim=0)
+ return filename + ',' + str(prediction.item())
+
# In this example we pass keyed inputs to RunInference transform.
# Therefore, we use KeyedModelHandler wrapper over PytorchModelHandler.
model_handler = KeyedModelHandler(
@@ -139,7 +142,8 @@ def run(
model_params=model_params,
device=device,
min_batch_size=10,
- max_batch_size=100))
+ max_batch_size=100)).with_preprocess_fn(
+ preprocess).with_postprocess_fn(postprocess)
pipeline = test_pipeline
if not test_pipeline:
@@ -148,16 +152,10 @@ def run(
filename_value_pair = (
pipeline
| 'ReadImageNames' >> beam.io.ReadFromText(known_args.input)
- | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines)
- | 'ReadImageData' >> beam.Map(
- lambda image_name: read_image(
- image_file_name=image_name, path_to_dir=known_args.images_dir))
- | 'PreprocessImages' >> beam.MapTuple(
- lambda file_name, data: (file_name, preprocess_image(data))))
+ | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines))
predictions = (
filename_value_pair
- | 'PyTorchRunInference' >> RunInference(model_handler)
- | 'ProcessOutput' >> beam.ParDo(PostProcessor()))
+ | 'PyTorchRunInference' >> RunInference(model_handler))
predictions | "WriteOutputToGCS" >> beam.io.WriteToText( # pylint:
disable=expression-not-assigned
known_args.output,
diff --git
a/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py
b/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py
index b7638ddd3c9..a0f249dcfbf 100644
---
a/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py
+++
b/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py
@@ -99,7 +99,9 @@ def run(
pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
# In this example we will use the TensorflowHub model URL.
- model_loader = TFModelHandlerTensor(model_uri=known_args.model_path)
+ model_loader = TFModelHandlerTensor(
+ model_uri=known_args.model_path).with_preprocess_fn(
+ lambda image_name: read_image(image_name, known_args.image_dir))
pipeline = test_pipeline
if not test_pipeline:
@@ -108,9 +110,7 @@ def run(
image = (
pipeline
| 'ReadImageNames' >> beam.io.ReadFromText(known_args.input)
- | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines)
- | "PreProcessInputs" >>
- beam.Map(lambda image_name: read_image(image_name,
known_args.image_dir)))
+ | 'FilterEmptyLines' >> beam.ParDo(filter_empty_lines))
predictions = (
image
diff --git a/sdks/python/apache_beam/ml/inference/base.py
b/sdks/python/apache_beam/ml/inference/base.py
index 5bc3bb31557..7c7eb02705a 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -33,6 +33,7 @@ import sys
import threading
import time
from typing import Any
+from typing import Callable
from typing import Dict
from typing import Generic
from typing import Iterable
@@ -58,7 +59,9 @@ _NANOSECOND_TO_MICROSECOND = 1_000
ModelT = TypeVar('ModelT')
ExampleT = TypeVar('ExampleT')
+PreProcessT = TypeVar('PreProcessT')
PredictionT = TypeVar('PredictionT')
+PostProcessT = TypeVar('PostProcessT')
_INPUT_TYPE = TypeVar('_INPUT_TYPE')
_OUTPUT_TYPE = TypeVar('_OUTPUT_TYPE')
KeyT = TypeVar('KeyT')
@@ -91,6 +94,12 @@ class ModelMetadata(NamedTuple):
model_name: str
+class RunInferenceDLQ(NamedTuple):
+ failed_inferences: beam.PCollection
+ failed_preprocessing: Sequence[beam.PCollection]
+ failed_postprocessing: Sequence[beam.PCollection]
+
+
ModelMetadata.model_id.__doc__ = """Unique identifier for the model. This can
be
a file path or a URL where the model can be accessed. It is used to load
the model for inference."""
@@ -174,6 +183,38 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
"""Update the model paths produced by side inputs."""
pass
+ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+ """Gets all preprocessing functions to be run before batching/inference.
+ Functions are in order that they should be applied."""
+ return []
+
+ def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+ """Gets all postprocessing functions to be run after inference.
+ Functions are in order that they should be applied."""
+ return []
+
+ def with_preprocess_fn(
+ self, fn: Callable[[PreProcessT], ExampleT]
+ ) -> 'ModelHandler[PreProcessT, PredictionT, ModelT]':
+ """Returns a new ModelHandler with a preprocessing function
+ associated with it. The preprocessing function will be run
+ before batching/inference and should map your input PCollection
+ to the base ModelHandler's input type. If you apply multiple
+ preprocessing functions, they will be run on your original
+ PCollection in order from last applied to first applied."""
+ return _PreProcessingModelHandler(self, fn)
+
+ def with_postprocess_fn(
+ self, fn: Callable[[PredictionT], PostProcessT]
+ ) -> 'ModelHandler[ExampleT, PostProcessT, ModelT]':
+ """Returns a new ModelHandler with a postprocessing function
+ associated with it. The postprocessing function will be run
+ after inference and should map the base ModelHandler's output
+ type to your desired output type. If you apply multiple
+ postprocessing functions, they will be run on your original
+ inference result in order from first applied to last applied."""
+ return _PostProcessingModelHandler(self, fn)
+
class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
ModelHandler[Tuple[KeyT, ExampleT],
@@ -190,6 +231,12 @@ class KeyedModelHandler(Generic[KeyT, ExampleT,
PredictionT, ModelT],
Args:
unkeyed: An implementation of ModelHandler that does not require keys.
"""
+ if len(unkeyed.get_preprocess_fns()) or len(unkeyed.get_postprocess_fns()):
+ raise Exception(
+ 'Cannot make an unkeyed model handler with pre or '
+ 'postprocessing functions defined into a keyed model handler. All '
+ 'pre/postprocessing functions must be defined on the outer model '
+ 'handler.')
self._unkeyed = unkeyed
def load_model(self) -> ModelT:
@@ -224,6 +271,12 @@ class KeyedModelHandler(Generic[KeyT, ExampleT,
PredictionT, ModelT],
def update_model_path(self, model_path: Optional[str] = None):
return self._unkeyed.update_model_path(model_path=model_path)
+ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+ return self._unkeyed.get_preprocess_fns()
+
+ def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+ return self._unkeyed.get_postprocess_fns()
+
class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
ModelHandler[Union[ExampleT, Tuple[KeyT,
@@ -248,6 +301,12 @@ class MaybeKeyedModelHandler(Generic[KeyT, ExampleT,
PredictionT, ModelT],
Args:
unkeyed: An implementation of ModelHandler that does not require keys.
"""
+ if len(unkeyed.get_preprocess_fns()) or len(unkeyed.get_postprocess_fns()):
+ raise Exception(
+ 'Cannot make an unkeyed model handler with pre or '
+ 'postprocessing functions defined into a keyed model handler. All '
+ 'pre/postprocessing functions must be defined on the outer model '
+ 'handler.')
self._unkeyed = unkeyed
def load_model(self) -> ModelT:
@@ -300,6 +359,123 @@ class MaybeKeyedModelHandler(Generic[KeyT, ExampleT,
PredictionT, ModelT],
def update_model_path(self, model_path: Optional[str] = None):
return self._unkeyed.update_model_path(model_path=model_path)
+ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+ return self._unkeyed.get_preprocess_fns()
+
+ def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+ return self._unkeyed.get_postprocess_fns()
+
+
+class _PreProcessingModelHandler(Generic[ExampleT,
+ PredictionT,
+ ModelT,
+ PreProcessT],
+ ModelHandler[PreProcessT, PredictionT,
+ ModelT]):
+ def __init__(
+ self,
+ base: ModelHandler[ExampleT, PredictionT, ModelT],
+ preprocess_fn: Callable[[PreProcessT], ExampleT]):
+ """A ModelHandler that has a preprocessing function associated with it.
+
+ Args:
+ base: An implementation of the underlying model handler.
+ preprocess_fn: the preprocessing function to use.
+ """
+ self._base = base
+ self._preprocess_fn = preprocess_fn
+
+ def load_model(self) -> ModelT:
+ return self._base.load_model()
+
+ def run_inference(
+ self,
+ batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
+ model: ModelT,
+ inference_args: Optional[Dict[str, Any]] = None
+ ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
+ return self._base.run_inference(batch, model, inference_args)
+
+ def get_num_bytes(
+ self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
+ return self._base.get_num_bytes(batch)
+
+ def get_metrics_namespace(self) -> str:
+ return self._base.get_metrics_namespace()
+
+ def get_resource_hints(self):
+ return self._base.get_resource_hints()
+
+ def batch_elements_kwargs(self):
+ return self._base.batch_elements_kwargs()
+
+ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+ return self._base.validate_inference_args(inference_args)
+
+ def update_model_path(self, model_path: Optional[str] = None):
+ return self._base.update_model_path(model_path=model_path)
+
+ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+ return [self._preprocess_fn] + self._base.get_preprocess_fns()
+
+ def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+ return self._base.get_postprocess_fns()
+
+
+class _PostProcessingModelHandler(Generic[ExampleT,
+ PredictionT,
+ ModelT,
+ PostProcessT],
+ ModelHandler[ExampleT, PostProcessT,
ModelT]):
+ def __init__(
+ self,
+ base: ModelHandler[ExampleT, PredictionT, ModelT],
+ postprocess_fn: Callable[[PredictionT], PostProcessT]):
+ """A ModelHandler that has a postprocessing function associated with it.
+
+ Args:
+ base: An implementation of the underlying model handler.
+ postprocess_fn: the postprocessing function to use.
+ """
+ self._base = base
+ self._postprocess_fn = postprocess_fn
+
+ def load_model(self) -> ModelT:
+ return self._base.load_model()
+
+ def run_inference(
+ self,
+ batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
+ model: ModelT,
+ inference_args: Optional[Dict[str, Any]] = None
+ ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
+ return self._base.run_inference(batch, model, inference_args)
+
+ def get_num_bytes(
+ self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
+ return self._base.get_num_bytes(batch)
+
+ def get_metrics_namespace(self) -> str:
+ return self._base.get_metrics_namespace()
+
+ def get_resource_hints(self):
+ return self._base.get_resource_hints()
+
+ def batch_elements_kwargs(self):
+ return self._base.batch_elements_kwargs()
+
+ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+ return self._base.validate_inference_args(inference_args)
+
+ def update_model_path(self, model_path: Optional[str] = None):
+ return self._base.update_model_path(model_path=model_path)
+
+ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+ return self._base.get_preprocess_fns()
+
+ def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+ return self._base.get_postprocess_fns() + [self._postprocess_fn]
+
class RunInference(beam.PTransform[beam.PCollection[ExampleT],
beam.PCollection[PredictionT]]):
@@ -356,13 +532,45 @@ class
RunInference(beam.PTransform[beam.PCollection[ExampleT],
"""
return cls(model_handler_provider(**kwargs))
+ def _apply_fns(
+ self,
+ pcoll: beam.PCollection,
+ fns: Iterable[Callable[[Any], Any]],
+ step_prefix: str) -> Tuple[beam.PCollection, Iterable[beam.PCollection]]:
+ bad_preprocessed = []
+ for idx in range(len(fns)):
+ fn = fns[idx]
+ if self._with_exception_handling:
+ pcoll, bad = (pcoll
+ | f"{step_prefix}-{idx}" >> beam.Map(
+ fn).with_exception_handling(
+ exc_class=self._exc_class,
+ use_subprocess=self._use_subprocess,
+ threshold=self._threshold))
+ bad_preprocessed.append(bad)
+ else:
+ pcoll = pcoll | f"{step_prefix}-{idx}" >> beam.Map(fn)
+
+ return pcoll, bad_preprocessed
+
# TODO(https://github.com/apache/beam/issues/21447): Add batch_size back off
# in the case there are functional reasons large batch sizes cannot be
# handled.
def expand(
self, pcoll: beam.PCollection[ExampleT]) ->
beam.PCollection[PredictionT]:
self._model_handler.validate_inference_args(self._inference_args)
+ # DLQ pcollections
+ bad_preprocessed = []
+ bad_inference = None
+ bad_postprocessed = []
+ preprocess_fns = self._model_handler.get_preprocess_fns()
+ postprocess_fns = self._model_handler.get_postprocess_fns()
+
+ pcoll, bad_preprocessed = self._apply_fns(
+ pcoll, preprocess_fns, 'BeamML_RunInference_Preprocess')
+
resource_hints = self._model_handler.get_resource_hints()
+
batched_elements_pcoll = (
pcoll
# TODO(https://github.com/apache/beam/issues/21440): Hook into the
@@ -382,14 +590,26 @@ class
RunInference(beam.PTransform[beam.PCollection[ExampleT],
**resource_hints)
if self._with_exception_handling:
- run_inference_pardo = run_inference_pardo.with_exception_handling(
+ results, bad_inference = (
+ batched_elements_pcoll
+ | 'BeamML_RunInference' >>
+ run_inference_pardo.with_exception_handling(
exc_class=self._exc_class,
use_subprocess=self._use_subprocess,
- threshold=self._threshold)
+ threshold=self._threshold))
+ else:
+ results = (
+ batched_elements_pcoll
+ | 'BeamML_RunInference' >> run_inference_pardo)
+
+ results, bad_postprocessed = self._apply_fns(
+ results, postprocess_fns, 'BeamML_RunInference_Postprocess')
+
+ if self._with_exception_handling:
+ dlq = RunInferenceDLQ(bad_inference, bad_preprocessed, bad_postprocessed)
+ return results, dlq
- return (
- batched_elements_pcoll
- | 'BeamML_RunInference' >> run_inference_pardo)
+ return results
def with_exception_handling(
self, *, exc_class=Exception, use_subprocess=False, threshold=1):
@@ -404,13 +624,31 @@ class
RunInference(beam.PTransform[beam.PCollection[ExampleT],
For example, one would write::
- good, bad = RunInference(
+ main, other = RunInference(
maybe_error_raising_model_handler
).with_exception_handling()
- and `good` will be a PCollection of PredictionResults and `bad` will
- contain a tuple of all batches that raised exceptions, along with their
- corresponding exception.
+ and `main` will be a PCollection of PredictionResults and `other` will
+ contain a `RunInferenceDLQ` object with PCollections containing failed
+ records for each failed inference, preprocess operation, or postprocess
+ operation. To access each collection of failed records, one would write:
+
+ failed_inferences = other.failed_inferences
+ failed_preprocessing = other.failed_preprocessing
+ failed_postprocessing = other.failed_postprocessing
+
+ failed_inferences is in the form
+ PCollection[Tuple[failed batch, exception]].
+
+ failed_preprocessing is in the form
+ List[PCollection[Tuple[failed record, exception]]], where each element of
+ the list corresponds to a preprocess function. These PCollections are
+ in the same order that the preprocess functions are applied.
+
+ failed_postprocessing is in the form
+ List[PCollection[Tuple[failed record, exception]]], where each element of
+ the list corresponds to a postprocess function. These PCollections are
+ in the same order that the postprocess functions are applied.
Args:
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py
b/sdks/python/apache_beam/ml/inference/base_test.py
index da82095a4e9..893c864c90e 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -179,21 +179,213 @@ class RunInferenceBaseTest(unittest.TestCase):
model_handler)
assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
+ def test_run_inference_preprocessing(self):
+ def mult_two(example: str) -> int:
+ return int(example) * 2
+
+ with TestPipeline() as pipeline:
+ examples = ["1", "5", "3", "10"]
+ expected = [int(example) * 2 + 1 for example in examples]
+ pcoll = pipeline | 'start' >> beam.Create(examples)
+ actual = pcoll | base.RunInference(
+ FakeModelHandler().with_preprocess_fn(mult_two))
+ assert_that(actual, equal_to(expected), label='assert:inferences')
+
+ def test_run_inference_preprocessing_multiple_fns(self):
+ def add_one(example: str) -> int:
+ return int(example) + 1
+
+ def mult_two(example: int) -> int:
+ return example * 2
+
+ with TestPipeline() as pipeline:
+ examples = ["1", "5", "3", "10"]
+ expected = [(int(example) + 1) * 2 + 1 for example in examples]
+ pcoll = pipeline | 'start' >> beam.Create(examples)
+ actual = pcoll | base.RunInference(
+ FakeModelHandler().with_preprocess_fn(mult_two).with_preprocess_fn(
+ add_one))
+ assert_that(actual, equal_to(expected), label='assert:inferences')
+
+ def test_run_inference_postprocessing(self):
+ def mult_two(example: int) -> str:
+ return str(example * 2)
+
+ with TestPipeline() as pipeline:
+ examples = [1, 5, 3, 10]
+ expected = [str((example + 1) * 2) for example in examples]
+ pcoll = pipeline | 'start' >> beam.Create(examples)
+ actual = pcoll | base.RunInference(
+ FakeModelHandler().with_postprocess_fn(mult_two))
+ assert_that(actual, equal_to(expected), label='assert:inferences')
+
+ def test_run_inference_postprocessing_multiple_fns(self):
+ def add_one(example: int) -> str:
+ return str(int(example) + 1)
+
+ def mult_two(example: int) -> int:
+ return example * 2
+
+ with TestPipeline() as pipeline:
+ examples = [1, 5, 3, 10]
+ expected = [str(((example + 1) * 2) + 1) for example in examples]
+ pcoll = pipeline | 'start' >> beam.Create(examples)
+ actual = pcoll | base.RunInference(
+ FakeModelHandler().with_postprocess_fn(mult_two).with_postprocess_fn(
+ add_one))
+ assert_that(actual, equal_to(expected), label='assert:inferences')
+
+ def test_run_inference_preprocessing_dlq(self):
+ def mult_two(example: str) -> int:
+ if example == "5":
+ raise Exception("TEST")
+ return int(example) * 2
+
+ with TestPipeline() as pipeline:
+ examples = ["1", "5", "3", "10"]
+ expected = [3, 7, 21]
+ expected_bad = ["5"]
+ pcoll = pipeline | 'start' >> beam.Create(examples)
+ main, other = pcoll | base.RunInference(
+ FakeModelHandler().with_preprocess_fn(mult_two)
+ ).with_exception_handling()
+ assert_that(main, equal_to(expected), label='assert:inferences')
+ assert_that(
+ other.failed_inferences, equal_to([]), label='assert:bad_infer')
+
+ # bad will be in form [element, error]. Just pull out bad element.
+ bad_without_error = other.failed_preprocessing[0] | beam.Map(
+ lambda x: x[0])
+ assert_that(
+ bad_without_error, equal_to(expected_bad), label='assert:failures')
+
+ def test_run_inference_postprocessing_dlq(self):
+ def mult_two(example: int) -> str:
+ if example == 6:
+ raise Exception("TEST")
+ return str(example * 2)
+
+ with TestPipeline() as pipeline:
+ examples = [1, 5, 3, 10]
+ expected = ["4", "8", "22"]
+ expected_bad = [6]
+ pcoll = pipeline | 'start' >> beam.Create(examples)
+ main, other = pcoll | base.RunInference(
+ FakeModelHandler().with_postprocess_fn(mult_two)
+ ).with_exception_handling()
+ assert_that(main, equal_to(expected), label='assert:inferences')
+ assert_that(
+ other.failed_inferences, equal_to([]), label='assert:bad_infer')
+
+ # bad will be in form [element, error]. Just pull out bad element.
+ bad_without_error = other.failed_postprocessing[0] | beam.Map(
+ lambda x: x[0])
+ assert_that(
+ bad_without_error, equal_to(expected_bad), label='assert:failures')
+
+ def test_run_inference_pre_and_post_processing_dlq(self):
+ def mult_two_pre(example: str) -> int:
+ if example == "5":
+ raise Exception("TEST")
+ return int(example) * 2
+
+ def mult_two_post(example: int) -> str:
+ if example == 7:
+ raise Exception("TEST")
+ return str(example * 2)
+
+ with TestPipeline() as pipeline:
+ examples = ["1", "5", "3", "10"]
+ expected = ["6", "42"]
+ expected_bad_pre = ["5"]
+ expected_bad_post = [7]
+ pcoll = pipeline | 'start' >> beam.Create(examples)
+ main, other = pcoll | base.RunInference(
+ FakeModelHandler().with_preprocess_fn(
+ mult_two_pre
+ ).with_postprocess_fn(
+ mult_two_post
+ )).with_exception_handling()
+ assert_that(main, equal_to(expected), label='assert:inferences')
+ assert_that(
+ other.failed_inferences, equal_to([]), label='assert:bad_infer')
+
+ # bad will be in form [elements, error]. Just pull out bad element.
+ bad_without_error_pre = other.failed_preprocessing[0] | beam.Map(
+ lambda x: x[0])
+ assert_that(
+ bad_without_error_pre,
+ equal_to(expected_bad_pre),
+ label='assert:failures_pre')
+
+ # bad will be in form [elements, error]. Just pull out bad element.
+ bad_without_error_post = other.failed_postprocessing[0] | beam.Map(
+ lambda x: x[0])
+ assert_that(
+ bad_without_error_post,
+ equal_to(expected_bad_post),
+ label='assert:failures_post')
+
+ def test_run_inference_keyed_pre_and_post_processing(self):
+ def mult_two(element):
+ return (element[0], element[1] * 2)
+
+ with TestPipeline() as pipeline:
+ examples = [1, 5, 3, 10]
+ keyed_examples = [(i, example) for i, example in enumerate(examples)]
+ expected = [
+ (i, ((example * 2) + 1) * 2) for i, example in enumerate(examples)
+ ]
+ pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
+ actual = pcoll | base.RunInference(
+ base.KeyedModelHandler(FakeModelHandler()).with_preprocess_fn(
+ mult_two).with_postprocess_fn(mult_two))
+ assert_that(actual, equal_to(expected), label='assert:inferences')
+
+ def test_run_inference_maybe_keyed_pre_and_post_processing(self):
+ def mult_two(element):
+ return element * 2
+
+ def mult_two_keyed(element):
+ return (element[0], element[1] * 2)
+
+ with TestPipeline() as pipeline:
+ examples = [1, 5, 3, 10]
+ keyed_examples = [(i, example) for i, example in enumerate(examples)]
+ expected = [((2 * example) + 1) * 2 for example in examples]
+ keyed_expected = [
+ (i, ((2 * example) + 1) * 2) for i, example in enumerate(examples)
+ ]
+ model_handler = base.MaybeKeyedModelHandler(FakeModelHandler())
+
+ pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
+ actual = pcoll | 'RunUnkeyed' >> base.RunInference(
+ model_handler.with_preprocess_fn(mult_two).with_postprocess_fn(
+ mult_two))
+ assert_that(actual, equal_to(expected), label='CheckUnkeyed')
+
+ keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
+ keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
+ model_handler.with_preprocess_fn(mult_two_keyed).with_postprocess_fn(
+ mult_two_keyed))
+ assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
+
def test_run_inference_impl_dlq(self):
with TestPipeline() as pipeline:
examples = [1, 'TEST', 3, 10, 'TEST2']
expected_good = [2, 4, 11]
expected_bad = ['TEST', 'TEST2']
pcoll = pipeline | 'start' >> beam.Create(examples)
- good, bad = pcoll | base.RunInference(
+ main, other = pcoll | base.RunInference(
FakeModelHandler(
min_batch_size=1,
max_batch_size=1
)).with_exception_handling()
- assert_that(good, equal_to(expected_good), label='assert:inferences')
+ assert_that(main, equal_to(expected_good), label='assert:inferences')
- # bad will be in form [batch[elements], error]. Just pull out bad
element.
- bad_without_error = bad | beam.Map(lambda x: x[0][0])
+ # bad.failed_inferences will be in form [batch[elements], error].
+ # Just pull out bad element.
+ bad_without_error = other.failed_inferences | beam.Map(lambda x: x[0][0])
assert_that(
bad_without_error, equal_to(expected_bad), label='assert:failures')