This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch users/damccorm/inference_args
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 8feee257a048f4bf0f07190ef6de7086e449d442
Author: Danny Mccormick <[email protected]>
AuthorDate: Fri Dec 12 10:23:33 2025 -0500

    Allow inference args to be passed in for most cases
---
 sdks/python/apache_beam/ml/inference/base.py                | 13 +++++--------
 sdks/python/apache_beam/ml/inference/pytorch_inference.py   |  6 ------
 sdks/python/apache_beam/ml/inference/sklearn_inference.py   |  3 ++-
 .../python/apache_beam/ml/inference/tensorflow_inference.py |  6 ------
 sdks/python/apache_beam/ml/inference/tensorrt_inference.py  | 10 ++++++++++
 sdks/python/apache_beam/ml/inference/vertex_ai_inference.py |  3 ---
 6 files changed, 17 insertions(+), 24 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 2e1c4963f11..d79565ee24d 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -213,15 +213,12 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
     return {}
 
   def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
-    """Validates inference_args passed in the inference call.
-
-    Because most frameworks do not need extra arguments in their predict() call,
-    the default behavior is to error out if inference_args are present.
     """
-    if inference_args:
-      raise ValueError(
-          'inference_args were provided, but should be None because this '
-          'framework does not expect extra arguments on inferences.')
+    Gives model handlers a hook to check that any passed-in inference
+    args are valid. Some ModelHandlers raise an error here to disallow
+    inference args altogether.
+    """
+    pass
 
   def update_model_path(self, model_path: Optional[str] = None):
     """
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index f73eeff808c..affbcd977f5 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -342,9 +342,6 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
     """
     return 'BeamML_PyTorch'
 
-  def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
-    pass
-
   def batch_elements_kwargs(self):
     return self._batching_kwargs
 
@@ -590,9 +587,6 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[dict[str, torch.Tensor],
     """
     return 'BeamML_PyTorch'
 
-  def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
-    pass
-
   def batch_elements_kwargs(self):
     return self._batching_kwargs
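
The two PyTorch overrides removed above were permissive no-ops, so the new
base default makes them redundant; PyTorch pipelines keep passing inference
args through RunInference as before, and (as I read pytorch_inference.py) the
default inference fn forwards them to the model's forward() call as keyword
arguments. A sketch with a toy stand-in model (everything below is
illustrative, not from this commit):

    import apache_beam as beam
    import torch
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.pytorch_inference import (
        PytorchModelHandlerTensor)


    class ScaleModel(torch.nn.Module):
      # Toy model whose forward() takes an extra keyword argument, which
      # is what inference_args feed.
      def forward(self, x, scale=1.0):
        return x * scale


    torch.save(ScaleModel().state_dict(), '/tmp/scale_model.pt')

    handler = PytorchModelHandlerTensor(
        state_dict_path='/tmp/scale_model.pt',
        model_class=ScaleModel,
        model_params={})

    with beam.Pipeline() as p:
      _ = (
          p
          | beam.Create([torch.tensor([1.0]), torch.tensor([2.0])])
          | RunInference(handler, inference_args={'scale': 2.0})
          | beam.Map(print))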
 
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index 1e5962ba64c..84947bec3df 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -73,9 +73,10 @@ def _default_numpy_inference_fn(
     model: BaseEstimator,
     batch: Sequence[numpy.ndarray],
     inference_args: Optional[dict[str, Any]] = None) -> Any:
+  inference_args = inference_args or {}
   # vectorize data for better performance
   vectorized_batch = numpy.stack(batch, axis=0)
-  return model.predict(vectorized_batch)
+  return model.predict(vectorized_batch, **inference_args)
 
 
 class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
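
With this change, keyword arguments supplied as inference args now reach the
estimator's predict() call. A small self-contained sketch of the forwarding
behavior (GaussianProcessRegressor is used only because its predict() accepts
a keyword, return_std; it is not part of this commit):

    import numpy
    from sklearn.gaussian_process import GaussianProcessRegressor

    X_train = numpy.array([[0.0], [1.0], [2.0]])
    y_train = numpy.array([0.0, 1.0, 4.0])
    model = GaussianProcessRegressor().fit(X_train, y_train)

    # Mirrors what _default_numpy_inference_fn now does: the batch is
    # vectorized, then inference_args are splatted into predict().
    batch = [numpy.array([0.5]), numpy.array([1.5])]
    vectorized_batch = numpy.stack(batch, axis=0)
    mean, std = model.predict(vectorized_batch, return_std=True)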
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
index d13ea53cf1b..5ce293a06ac 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
@@ -219,9 +219,6 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
     """
     return 'BeamML_TF_Numpy'
 
-  def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
-    pass
-
   def batch_elements_kwargs(self):
     return self._batching_kwargs
 
@@ -360,9 +357,6 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
     """
     return 'BeamML_TF_Tensor'
 
-  def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
-    pass
-
   def batch_elements_kwargs(self):
     return self._batching_kwargs
 
diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
index 1b11bd9f39e..b575dfa849d 100644
--- a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
@@ -341,3 +341,13 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
 
   def model_copies(self) -> int:
     return self._model_copies
+
+  def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
+    """
+    Currently, this model handler does not support inference args, so we
+    raise an error if any are passed in.
+    """
+    if inference_args:
+      raise ValueError(
+          'inference_args were provided, but should be None because this '
+          'framework does not expect extra arguments on inferences.')
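
Since RunInference calls validate_inference_args when the transform is
expanded, this check rejects inference args at pipeline construction time,
before any inference runs. Roughly (the handler construction is elided and
the argument key is a placeholder):

    # Assuming 'handler' is a constructed TensorRTEngineHandlerNumPy:
    handler.validate_inference_args(None)  # fine
    handler.validate_inference_args({})    # fine: an empty dict is falsy
    handler.validate_inference_args({'stream': 0})  # raises ValueError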
diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
index 471f2379cfb..9858b59039c 100644
--- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
+++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
@@ -207,8 +207,5 @@ class VertexAIModelHandlerJSON(RemoteModelHandler[Any,
     return utils._convert_to_result(
         batch, prediction.predictions, prediction.deployed_model_id)
 
-  def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
-    pass
-
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
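
As with the PyTorch and TensorFlow deletions, dropping this override simply
defers to the new permissive base default; inference args continue to flow
through to the Vertex AI prediction request (as the request parameters, as I
understand the handler). A sketch with placeholder project values:

    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.vertex_ai_inference import (
        VertexAIModelHandlerJSON)

    # The endpoint, project, and parameter key below are placeholders.
    handler = VertexAIModelHandlerJSON(
        endpoint_id='1234567890', project='my-project',
        location='us-central1')

    with beam.Pipeline() as p:
      _ = (
          p
          | beam.Create([{'feature': 1.0}])
          | RunInference(
              handler, inference_args={'confidence_threshold': 0.5}))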
