This is an automated email from the ASF dual-hosted git repository.

jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 697d2992aeb feat: Add support for custom prediction routes in Vertex AI inference (#37155)
697d2992aeb is described below

commit 697d2992aeb975b871107a62e2b196a08d541a04
Author: liferoad <[email protected]>
AuthorDate: Tue Jan 13 10:09:28 2026 -0500

    feat: Add support for custom prediction routes in Vertex AI inference (#37155)
    
    * feat: Add support for custom prediction routes in Vertex AI inference using the `invoke_route` parameter and custom response parsing.
    
    * lint
    
    * lint 2
    
    * fix: ensure invoke response is bytes and add type hint for request_body
    
    * test: mock `aiplatform.init` in VertexAI inference tests to prevent global state pollution.
    
    * lint
    
    * added the IT
    
    * added license
    
    * lint
    
    * updated endpoint
    
    * trigger postcommit
    
    * lint
    
    * lint
    
    * lint
    
    * fixed the response
---
 .github/trigger_files/beam_PostCommit_Python.json  |   5 +-
 .../vertex_ai_custom_prediction/Dockerfile         |  21 +++++
 .../vertex_ai_custom_prediction/README.md          | 103 +++++++++++++++++++++
 .../vertex_ai_custom_prediction/echo_server.py     |  43 +++++++++
 .../ml/inference/vertex_ai_inference.py            |  72 +++++++++++++-
 .../ml/inference/vertex_ai_inference_it_test.py    |  47 ++++++++++
 .../ml/inference/vertex_ai_inference_test.py       |  70 ++++++++++++++
 sdks/python/apache_beam/yaml/yaml_ml.py            |   9 ++
 8 files changed, 364 insertions(+), 6 deletions(-)

diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json
index 47e479f18a9..e43868bf4f2 100644
--- a/.github/trigger_files/beam_PostCommit_Python.json
+++ b/.github/trigger_files/beam_PostCommit_Python.json
@@ -1,6 +1,5 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to 
run.",
   "pr": "36271",
-  "modification": 36
-}
-
+  "modification": 37
+}
\ No newline at end of file
diff --git a/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/Dockerfile b/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/Dockerfile
new file mode 100644
index 00000000000..a62b9edd406
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/Dockerfile
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+FROM python:3.10-slim
+WORKDIR /app
+RUN pip install flask gunicorn
+COPY echo_server.py main.py
+CMD ["gunicorn", "--bind", "0.0.0.0:8080", "main:app"]
diff --git a/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/README.md b/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/README.md
new file mode 100644
index 00000000000..834a27be7f7
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/README.md
@@ -0,0 +1,103 @@
+<!--
+    Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+-->
+
+# Vertex AI Custom Prediction Route Test Setup
+
+To run the `test_vertex_ai_custom_prediction_route` in [vertex_ai_inference_it_test.py](../../vertex_ai_inference_it_test.py), you need a dedicated Vertex AI endpoint with an invoke-enabled model deployed.
+
+## Resource Setup Steps
+
+Run these commands in the `apache-beam-testing` project (or your own test project).
+
+### 1. Build and Push Container
+
+From this directory:
+
+```bash
+# on Linux
+export PROJECT_ID="apache-beam-testing"  # Or your project
+export IMAGE_URI="gcr.io/${PROJECT_ID}/beam-ml/beam-invoke-echo-model:latest"
+
+docker build -t ${IMAGE_URI} .
+docker push ${IMAGE_URI}
+```
+
+### 2. Upload Model and Deploy Endpoint
+
+Use the Python SDK to deploy (easier than gcloud for the invoke-specific flags).
+
+```python
+from google.cloud import aiplatform
+
+PROJECT_ID = "apache-beam-testing"
+REGION = "us-central1"
+IMAGE_URI = f"gcr.io/{PROJECT_ID}/beam-ml/beam-invoke-echo-model:latest"
+
+aiplatform.init(project=PROJECT_ID, location=REGION)
+
+# 1. Upload Model with invoke route enabled
+model = aiplatform.Model.upload(
+    display_name="beam-invoke-echo-model",
+    serving_container_image_uri=IMAGE_URI,
+    serving_container_invoke_route_prefix="/*",  # <--- Critical for custom routes
+    serving_container_health_route="/health",
+    sync=True,
+)
+
+# 2. Create Dedicated Endpoint (required for invoke)
+endpoint = aiplatform.Endpoint.create(
+    display_name="beam-invoke-test-endpoint",
+    dedicated_endpoint_enabled=True,
+    sync=True,
+)
+
+# 3. Deploy Model
+# NOTE: Set min_replica_count=0 to save costs when not testing
+endpoint.deploy(
+    model=model,
+    traffic_percentage=100,
+    machine_type="n1-standard-2",
+    min_replica_count=0,
+    max_replica_count=1,
+    sync=True,
+)
+
+print(f"Deployment Complete!")
+print(f"Endpoint ID: {endpoint.name}")
+```
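+
+If the Endpoint ID is lost after deployment, it can be recovered with the same SDK. This is a small sketch assuming `aiplatform.init` was called as above and the display name was not changed.
+
+```python
+# List endpoints matching the display name used during creation and
+# print their numeric IDs (endpoint.name is the ID, not the display name).
+for ep in aiplatform.Endpoint.list(
+    filter='display_name="beam-invoke-test-endpoint"'):
+  print(ep.display_name, ep.name)
+```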
+
+### 3. Update Test Configuration
+
+1. Copy the **Endpoint ID** printed above (e.g., `1234567890`).
+2. Update `_INVOKE_ENDPOINT_ID` in `apache_beam/ml/inference/vertex_ai_inference_it_test.py`.
+
+## Cleanup
+
+To avoid costs, undeploy and delete resources when finished:
+
+```bash
+# Undeploy model from endpoint
+gcloud ai endpoints undeploy-model <ENDPOINT_ID> --deployed-model-id <DEPLOYED_MODEL_ID> --region=us-central1
+
+# Delete endpoint
+gcloud ai endpoints delete <ENDPOINT_ID> --region=us-central1
+
+# Delete model
+gcloud ai models delete <MODEL_ID> --region=us-central1
+```
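+
+## Using the Endpoint from a Beam Pipeline
+
+The sketch below mirrors the integration test and shows how the deployed endpoint is reached through the custom route via the `invoke_route` parameter; the endpoint ID is a placeholder for the value recorded above.
+
+```python
+import apache_beam as beam
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON
+
+model_handler = VertexAIModelHandlerJSON(
+    endpoint_id="<ENDPOINT_ID>",  # placeholder: the ID printed during deployment
+    project="apache-beam-testing",
+    location="us-central1",
+    invoke_route="/predict")  # routes requests through Endpoint.invoke()
+
+with beam.Pipeline() as p:
+  _ = (
+      p
+      | beam.Create([{"value": 1}, {"value": 2}])
+      | RunInference(model_handler)
+      | beam.Map(print))
+```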
diff --git a/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/echo_server.py b/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/echo_server.py
new file mode 100644
index 00000000000..6e48e62a2a7
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/echo_server.py
@@ -0,0 +1,43 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from flask import Flask
+from flask import jsonify
+from flask import request
+
+app = Flask(__name__)
+
+
+@app.route('/predict', methods=['POST'])
+def predict():
+  data = request.get_json()
+  # Echo back the instances
+  return jsonify({
+      "predictions": [{
+          "echo": inst
+      } for inst in data.get('instances', [])],
+      "deployedModelId": "echo-model"
+  })
+
+
+@app.route('/health', methods=['GET'])
+def health():
+  return 'OK', 200
+
+
+if __name__ == '__main__':
+  app.run(host='0.0.0.0', port=8080)
diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
index 9858b59039c..cd3d0beb593 100644
--- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
+++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+import json
 import logging
 from collections.abc import Iterable
 from collections.abc import Mapping
@@ -63,6 +64,7 @@ class VertexAIModelHandlerJSON(RemoteModelHandler[Any,
       experiment: Optional[str] = None,
       network: Optional[str] = None,
       private: bool = False,
+      invoke_route: Optional[str] = None,
       *,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
@@ -95,6 +97,12 @@ class VertexAIModelHandlerJSON(RemoteModelHandler[Any,
       private: optional. if the deployed Vertex AI endpoint is
         private, set to true. Requires a network to be provided
         as well.
+      invoke_route: optional. the custom route path to use when invoking
+        endpoints with arbitrary prediction routes. When specified, uses
+        `Endpoint.invoke()` instead of `Endpoint.predict()`. The route
+        should start with a forward slash, e.g., "/predict/v1".
+        See https://cloud.google.com/vertex-ai/docs/predictions/use-arbitrary-custom-routes
+        for more information.
       min_batch_size: optional. the minimum batch size to use when batching
         inputs.
       max_batch_size: optional. the maximum batch size to use when batching
@@ -104,6 +112,7 @@ class VertexAIModelHandlerJSON(RemoteModelHandler[Any,
     """
     self._batching_kwargs = {}
     self._env_vars = kwargs.get('env_vars', {})
+    self._invoke_route = invoke_route
     if min_batch_size is not None:
       self._batching_kwargs["min_batch_size"] = min_batch_size
     if max_batch_size is not None:
@@ -203,9 +212,66 @@ class VertexAIModelHandlerJSON(RemoteModelHandler[Any,
     Returns:
       An iterable of Predictions.
     """
-    prediction = model.predict(instances=list(batch), parameters=inference_args)
-    return utils._convert_to_result(
-        batch, prediction.predictions, prediction.deployed_model_id)
+    if self._invoke_route:
+      # Use invoke() for endpoints with custom prediction routes
+      request_body: dict[str, Any] = {"instances": list(batch)}
+      if inference_args:
+        request_body["parameters"] = inference_args
+      response = model.invoke(
+          request_path=self._invoke_route,
+          body=json.dumps(request_body).encode("utf-8"),
+          headers={"Content-Type": "application/json"})
+      if hasattr(response, "content"):
+        return self._parse_invoke_response(batch, response.content)
+      return self._parse_invoke_response(batch, bytes(response))
+    else:
+      prediction = model.predict(
+          instances=list(batch), parameters=inference_args)
+      return utils._convert_to_result(
+          batch, prediction.predictions, prediction.deployed_model_id)
+
+  def _parse_invoke_response(self, batch: Sequence[Any],
+                             response: bytes) -> Iterable[PredictionResult]:
+    """Parses the response from Endpoint.invoke() into PredictionResults.
+
+    Args:
+      batch: the original batch of inputs.
+      response: the raw bytes response from invoke().
+
+    Returns:
+      An iterable of PredictionResults.
+    """
+    try:
+      response_json = json.loads(response.decode("utf-8"))
+    except (json.JSONDecodeError, UnicodeDecodeError) as e:
+      LOGGER.warning(
+          "Failed to decode invoke response as JSON, returning raw bytes: %s",
+          e)
+      # Return raw response for each batch item
+      return [
+          PredictionResult(example=example, inference=response)
+          for example in batch
+      ]
+
+    # Handle standard Vertex AI response format with "predictions" key
+    if isinstance(response_json, dict) and "predictions" in response_json:
+      predictions = response_json["predictions"]
+      model_id = response_json.get("deployedModelId")
+      return utils._convert_to_result(batch, predictions, model_id)
+
+    # Handle response as a list of predictions (one per input)
+    if isinstance(response_json, list) and len(response_json) == len(batch):
+      return utils._convert_to_result(batch, response_json, None)
+
+    # Handle single prediction response
+    if len(batch) == 1:
+      return [PredictionResult(example=batch[0], inference=response_json)]
+
+    # Fallback: return the full response for each batch item
+    return [
+        PredictionResult(example=example, inference=response_json)
+        for example in batch
+    ]
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference_it_test.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference_it_test.py
index c6d62eb3e3e..11643992c39 100644
--- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference_it_test.py
+++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference_it_test.py
@@ -23,12 +23,15 @@ import uuid
 
 import pytest
 
+import apache_beam as beam
 from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.base import RunInference
 from apache_beam.testing.test_pipeline import TestPipeline
 
 # pylint: disable=ungrouped-imports
 try:
   from apache_beam.examples.inference import vertex_ai_image_classification
+  from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON
 except ImportError as e:
   raise unittest.SkipTest(
       "Vertex AI model handler dependencies are not installed")
@@ -42,6 +45,13 @@ _ENDPOINT_NETWORK = "projects/844138762903/global/networks/beam-test-vpc"
 # pylint: disable=line-too-long
 _SUBNETWORK = "https://www.googleapis.com/compute/v1/projects/apache-beam-testing/regions/us-central1/subnetworks/beam-test-vpc"
 
+# Constants for custom prediction routes (invoke) test
+# Follow beam/sdks/python/apache_beam/ml/inference/test_resources/vertex_ai_custom_prediction/README.md
+# to get endpoint ID after deploying invoke-enabled model
+_INVOKE_ENDPOINT_ID = "6890840581900075008"
+_INVOKE_ROUTE = "/predict"
+_INVOKE_OUTPUT_DIR = "gs://apache-beam-ml/testing/outputs/vertex_invoke"
+
 
 class VertexAIInference(unittest.TestCase):
   @pytest.mark.vertex_ai_postcommit
@@ -63,6 +73,43 @@ class VertexAIInference(unittest.TestCase):
         test_pipeline.get_full_options_as_args(**extra_opts))
     self.assertEqual(FileSystems().exists(output_file), True)
 
+  @pytest.mark.vertex_ai_postcommit
+  @unittest.skipIf(
+      not _INVOKE_ENDPOINT_ID,
+      "Invoke endpoint not configured. Set _INVOKE_ENDPOINT_ID.")
+  def test_vertex_ai_custom_prediction_route(self):
+    """Test custom prediction routes using invoke_route parameter.
+
+    This test verifies that VertexAIModelHandlerJSON correctly uses
+    Endpoint.invoke() instead of Endpoint.predict() when invoke_route
+    is specified, enabling custom prediction routes.
+    """
+    output_file = '/'.join(
+        [_INVOKE_OUTPUT_DIR, str(uuid.uuid4()), 'output.txt'])
+
+    test_pipeline = TestPipeline(is_integration_test=True)
+
+    model_handler = VertexAIModelHandlerJSON(
+        endpoint_id=_INVOKE_ENDPOINT_ID,
+        project=_ENDPOINT_PROJECT,
+        location=_ENDPOINT_REGION,
+        invoke_route=_INVOKE_ROUTE)
+
+    # Test inputs - simple data to echo back
+    test_inputs = [{"value": 1}, {"value": 2}, {"value": 3}]
+
+    with test_pipeline as p:
+      results = (
+          p
+          | "CreateInputs" >> beam.Create(test_inputs)
+          | "RunInference" >> RunInference(model_handler)
+          | "ExtractResults" >>
+          beam.Map(lambda result: f"{result.example}:{result.inference}"))
+      _ = results | "WriteOutput" >> beam.io.WriteToText(
+          output_file, shard_name_template='')
+
+    self.assertTrue(FileSystems().exists(output_file))
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.DEBUG)
diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py
index 91a3b82cf76..8aa638ebe7c 100644
--- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py
@@ -48,5 +48,75 @@ class ModelHandlerArgConditions(unittest.TestCase):
         private=True)
 
 
+class ParseInvokeResponseTest(unittest.TestCase):
+  """Tests for _parse_invoke_response method."""
+  def _create_handler_with_invoke_route(self, invoke_route="/test"):
+    """Creates a mock handler with invoke_route for testing."""
+    import unittest.mock as mock
+
+    # Mock both _retrieve_endpoint and aiplatform.init to prevent test
+    # pollution of global aiplatform state
+    with mock.patch.object(VertexAIModelHandlerJSON,
+                           '_retrieve_endpoint',
+                           return_value=None):
+      with mock.patch('google.cloud.aiplatform.init'):
+        handler = VertexAIModelHandlerJSON(
+            endpoint_id="1",
+            project="testproject",
+            location="us-central1",
+            invoke_route=invoke_route)
+    return handler
+
+  def test_parse_invoke_response_with_predictions_key(self):
+    """Test parsing response with standard 'predictions' key."""
+    handler = self._create_handler_with_invoke_route()
+    batch = [{"input": "test1"}, {"input": "test2"}]
+    response = (
+        b'{"predictions": ["result1", "result2"], '
+        b'"deployedModelId": "model123"}')
+
+    results = list(handler._parse_invoke_response(batch, response))
+
+    self.assertEqual(len(results), 2)
+    self.assertEqual(results[0].example, {"input": "test1"})
+    self.assertEqual(results[0].inference, "result1")
+    self.assertEqual(results[1].example, {"input": "test2"})
+    self.assertEqual(results[1].inference, "result2")
+
+  def test_parse_invoke_response_list_format(self):
+    """Test parsing response as a list of predictions."""
+    handler = self._create_handler_with_invoke_route()
+    batch = [{"input": "test1"}, {"input": "test2"}]
+    response = b'["result1", "result2"]'
+
+    results = list(handler._parse_invoke_response(batch, response))
+
+    self.assertEqual(len(results), 2)
+    self.assertEqual(results[0].inference, "result1")
+    self.assertEqual(results[1].inference, "result2")
+
+  def test_parse_invoke_response_single_prediction(self):
+    """Test parsing response with a single prediction."""
+    handler = self._create_handler_with_invoke_route()
+    batch = [{"input": "test1"}]
+    response = b'{"output": "single result"}'
+
+    results = list(handler._parse_invoke_response(batch, response))
+
+    self.assertEqual(len(results), 1)
+    self.assertEqual(results[0].inference, {"output": "single result"})
+
+  def test_parse_invoke_response_non_json(self):
+    """Test handling non-JSON response."""
+    handler = self._create_handler_with_invoke_route()
+    batch = [{"input": "test1"}]
+    response = b'not valid json'
+
+    results = list(handler._parse_invoke_response(batch, response))
+
+    self.assertEqual(len(results), 1)
+    self.assertEqual(results[0].inference, response)
+
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/yaml/yaml_ml.py b/sdks/python/apache_beam/yaml/yaml_ml.py
index e5a88f54eba..4e750b79ce3 100644
--- a/sdks/python/apache_beam/yaml/yaml_ml.py
+++ b/sdks/python/apache_beam/yaml/yaml_ml.py
@@ -168,6 +168,7 @@ class VertexAIModelHandlerJSONProvider(ModelHandlerProvider):
       experiment: Optional[str] = None,
       network: Optional[str] = None,
       private: bool = False,
+      invoke_route: Optional[str] = None,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None):
@@ -236,6 +237,13 @@ class VertexAIModelHandlerJSONProvider(ModelHandlerProvider):
       private: If the deployed Vertex AI endpoint is
         private, set to true. Requires a network to be provided
         as well.
+      invoke_route: The custom route path to use when invoking
+        endpoints with arbitrary prediction routes. When specified, uses
+        `Endpoint.invoke()` instead of `Endpoint.predict()`. The route
+        should start with a forward slash, e.g., "/predict/v1".
+        See
+        https://cloud.google.com/vertex-ai/docs/predictions/use-arbitrary-custom-routes
+        for more information.
       min_batch_size: The minimum batch size to use when batching
         inputs.
       max_batch_size: The maximum batch size to use when batching
@@ -258,6 +266,7 @@ class VertexAIModelHandlerJSONProvider(ModelHandlerProvider):
         experiment=experiment,
         network=network,
         private=private,
+        invoke_route=invoke_route,
         min_batch_size=min_batch_size,
         max_batch_size=max_batch_size,
         max_batch_duration_secs=max_batch_duration_secs)
