This is an automated email from the ASF dual-hosted git repository. damccorm pushed a commit to branch users/damccorm/inference_args in repository https://gitbox.apache.org/repos/asf/beam.git
commit 8feee257a048f4bf0f07190ef6de7086e449d442 Author: Danny Mccormick <[email protected]> AuthorDate: Fri Dec 12 10:23:33 2025 -0500 Allow inference args to be passed in for most cases --- sdks/python/apache_beam/ml/inference/base.py | 13 +++++-------- sdks/python/apache_beam/ml/inference/pytorch_inference.py | 6 ------ sdks/python/apache_beam/ml/inference/sklearn_inference.py | 3 ++- .../python/apache_beam/ml/inference/tensorflow_inference.py | 6 ------ sdks/python/apache_beam/ml/inference/tensorrt_inference.py | 10 ++++++++++ sdks/python/apache_beam/ml/inference/vertex_ai_inference.py | 3 --- 6 files changed, 17 insertions(+), 24 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 2e1c4963f11..d79565ee24d 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -213,15 +213,12 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]): return {} def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - """Validates inference_args passed in the inference call. - - Because most frameworks do not need extra arguments in their predict() call, - the default behavior is to error out if inference_args are present. """ - if inference_args: - raise ValueError( - 'inference_args were provided, but should be None because this ' - 'framework does not expect extra arguments on inferences.') + Allows model handlers to provide some validation to make sure passed in + inference args are valid. Some ModelHandlers throw here to disallow + inference args altogether. + """ + pass def update_model_path(self, model_path: Optional[str] = None): """ diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index f73eeff808c..affbcd977f5 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -342,9 +342,6 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor, """ return 'BeamML_PyTorch' - def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - pass - def batch_elements_kwargs(self): return self._batching_kwargs @@ -590,9 +587,6 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[dict[str, torch.Tensor], """ return 'BeamML_PyTorch' - def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - pass - def batch_elements_kwargs(self): return self._batching_kwargs diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py index 1e5962ba64c..84947bec3df 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py @@ -73,9 +73,10 @@ def _default_numpy_inference_fn( model: BaseEstimator, batch: Sequence[numpy.ndarray], inference_args: Optional[dict[str, Any]] = None) -> Any: + inference_args = {} if not inference_args else inference_args # vectorize data for better performance vectorized_batch = numpy.stack(batch, axis=0) - return model.predict(vectorized_batch) + return model.predict(vectorized_batch, **inference_args) class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray, diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py index d13ea53cf1b..5ce293a06ac 100644 --- a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py +++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py @@ -219,9 +219,6 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray, """ return 'BeamML_TF_Numpy' - def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - pass - def batch_elements_kwargs(self): return self._batching_kwargs @@ -360,9 +357,6 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult, """ return 'BeamML_TF_Tensor' - def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - pass - def batch_elements_kwargs(self): return self._batching_kwargs diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py index 1b11bd9f39e..b575dfa849d 100644 --- a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py +++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py @@ -341,3 +341,13 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray, def model_copies(self) -> int: return self._model_copies + + def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): + """ + Currently, this model handler does not support inference args. Given that, + we will throw if any are passed in. + """ + if inference_args: + raise ValueError( + 'inference_args were provided, but should be None because this ' + 'framework does not expect extra arguments on inferences.') diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py index 471f2379cfb..9858b59039c 100644 --- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py +++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py @@ -207,8 +207,5 @@ class VertexAIModelHandlerJSON(RemoteModelHandler[Any, return utils._convert_to_result( batch, prediction.predictions, prediction.deployed_model_id) - def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - pass - def batch_elements_kwargs(self) -> Mapping[str, Any]: return self._batching_kwargs
