This is an automated email from the ASF dual-hosted git repository.
jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new fe188e3635b [#34236] Add Vertex AI Multi-Modal embedding handler (#35677)
fe188e3635b is described below
commit fe188e3635b894fef4f6d2b2f7eda0c09608556f
Author: Jack McCluskey <[email protected]>
AuthorDate: Mon Sep 8 16:51:33 2025 -0400
[#34236] Add Vertex AI Multi-Modal embedding handler (#35677)
* Prototype Vertex MultiModal embedding handler
* remove unused types
* change temp file and artifact paths to use dedicated directories
* formatting
* quick unit tests for the base multimodal embedding handler
* Migrate to input adapter, add testing for video
* linting
* isort
* made segment configuration per-video instance
* fix corrected video input type
* speed up video test by passing a GCS URI instead of loading the video
* formatting
* move to wrapped inputs
* clarify types in dict_input_fn
* linting
* fix chunk construction
* update main input to use wrappers
---
sdks/python/apache_beam/ml/transforms/base.py | 39 ++++
sdks/python/apache_beam/ml/transforms/base_test.py | 117 ++++++++++
.../ml/transforms/embeddings/vertex_ai.py | 241 ++++++++++++++++++++-
.../ml/transforms/embeddings/vertex_ai_test.py | 107 +++++++++
4 files changed, 501 insertions(+), 3 deletions(-)
diff --git a/sdks/python/apache_beam/ml/transforms/base.py b/sdks/python/apache_beam/ml/transforms/base.py
index 3b95ed719e5..4031777ce15 100644
--- a/sdks/python/apache_beam/ml/transforms/base.py
+++ b/sdks/python/apache_beam/ml/transforms/base.py
@@ -810,3 +810,42 @@ class _ImageEmbeddingHandler(_EmbeddingHandler):
return (
self._underlying.get_metrics_namespace() or
'BeamML_ImageEmbeddingHandler')
+
+
+class _MultiModalEmbeddingHandler(_EmbeddingHandler):
+ """
+  A ModelHandler intended to work on
+ list[dict[str, TypedDict(Image, Video, str)]] inputs.
+
+ The inputs to the model handler are expected to be a list of dicts.
+
+  For example, if the original model is used with RunInference to take a
+ PCollection[E] to a PCollection[P], this ModelHandler would take a
+ PCollection[dict[str, E]] to a PCollection[dict[str, P]].
+
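+  For instance (illustrative only; ``ExampleInput`` stands in for whatever
+  dataclass the underlying embedding model expects), an input batch of
+
+    [{'img': ExampleInput(image=image_one)},
+     {'img': ExampleInput(image=image_two)}]
+
+  would be mapped to an output batch of
+
+    [{'img': embedding_one}, {'img': embedding_two}]
+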
+ _MultiModalEmbeddingHandler will accept an EmbeddingsManager instance, which
+ contains the details of the model to be loaded and the inference_fn to be
+  used. The purpose of _MultiModalEmbeddingHandler is to generate embeddings
+ for image, video, and text inputs using the EmbeddingsManager instance.
+
+  If an input column contains primitive types rather than the expected
+  dataclass representation, a TypeError will be raised.
+
+ This is an internal class and offers no backwards compatibility guarantees.
+
+ Args:
+ embeddings_manager: An EmbeddingsManager instance.
+ """
+ def _validate_column_data(self, batch):
+ # Don't want to require framework-specific imports
+    # here, so just catch columns of primitives for now.
+ if isinstance(batch[0], (int, str, float, bool)):
+ raise TypeError(
+ 'Embeddings can only be generated on '
+          'dict[str, dataclass] types. '
+ f'Got dict[str, {type(batch[0])}] instead.')
+
+ def get_metrics_namespace(self) -> str:
+ return (
+ self._underlying.get_metrics_namespace() or
+ 'BeamML_MultiModalEmbeddingHandler')
diff --git a/sdks/python/apache_beam/ml/transforms/base_test.py b/sdks/python/apache_beam/ml/transforms/base_test.py
index 309c085f08f..190381cc2f3 100644
--- a/sdks/python/apache_beam/ml/transforms/base_test.py
+++ b/sdks/python/apache_beam/ml/transforms/base_test.py
@@ -23,6 +23,7 @@ import tempfile
import time
import unittest
from collections.abc import Sequence
+from dataclasses import dataclass
from typing import Any
from typing import Optional
@@ -629,6 +630,122 @@ class TestImageEmbeddingHandler(unittest.TestCase):
)
+@dataclass
+class FakeMultiModalInput:
+ image: Optional[PIL_Image] = None
+ video: Optional[Any] = None
+ text: Optional[str] = None
+
+
+class FakeMultiModalModel:
+ def __call__(self,
+               example: list[FakeMultiModalInput]) -> list[FakeMultiModalInput]:
+ for i in range(len(example)):
+ if not isinstance(example[i], FakeMultiModalInput):
+ raise TypeError('Input must be a MultiModalInput')
+ return example
+
+
+class FakeMultiModalModelHandler(ModelHandler):
+ def run_inference(
+ self,
+ batch: Sequence[FakeMultiModalInput],
+ model: Any,
+ inference_args: Optional[dict[str, Any]] = None):
+ return model(batch)
+
+ def load_model(self):
+ return FakeMultiModalModel()
+
+
+class FakeMultiModalEmbeddingsManager(base.EmbeddingsManager):
+ def __init__(self, columns, **kwargs):
+ super().__init__(columns=columns, **kwargs)
+
+ def get_model_handler(self) -> ModelHandler:
+    FakeModelHandler.__repr__ = lambda x: 'FakeMultiModalEmbeddingsManager' # type: ignore[method-assign]
+ return FakeMultiModalModelHandler()
+
+ def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
+ return (RunInference(model_handler=base._MultiModalEmbeddingHandler(self)))
+
+ def __repr__(self):
+ return 'FakeMultiModalEmbeddingsManager'
+
+
+class TestMultiModalEmbeddingHandler(unittest.TestCase):
+ def setUp(self) -> None:
+ self.embedding_config = FakeMultiModalEmbeddingsManager(columns=['x'])
+ self.artifact_location = tempfile.mkdtemp()
+
+ def tearDown(self) -> None:
+ shutil.rmtree(self.artifact_location)
+
+ @unittest.skipIf(PIL is None, 'PIL module is not installed.')
+ def test_handler_with_non_dict_datatype(self):
+ image_handler = base._MultiModalEmbeddingHandler(
+ embeddings_manager=self.embedding_config)
+ data = [
+ ('x', 'hi there'),
+ ('x', 'not an image'),
+ ('x', 'image_path.jpg'),
+ ]
+ with self.assertRaises(TypeError):
+ image_handler.run_inference(data, None, None)
+
+ @unittest.skipIf(PIL is None, 'PIL module is not installed.')
+ def test_handler_with_incorrect_datatype(self):
+ image_handler = base._MultiModalEmbeddingHandler(
+ embeddings_manager=self.embedding_config)
+ data = [
+ {
+ 'x': 'hi there'
+ },
+ {
+ 'x': 'not an image'
+ },
+ {
+ 'x': 'image_path.jpg'
+ },
+ ]
+ with self.assertRaises(TypeError):
+ image_handler.run_inference(data, None, None)
+
+ @unittest.skipIf(PIL is None, 'PIL module is not installed.')
+ def test_handler_with_dict_inputs(self):
+ input_one = FakeMultiModalInput(
+ image=PIL.Image.new(mode='RGB', size=(1, 1)), text="test image one")
+ input_two = FakeMultiModalInput(
+ image=PIL.Image.new(mode='RGB', size=(1, 1)), text="test image two")
+ input_three = FakeMultiModalInput(
+ image=PIL.Image.new(mode='RGB', size=(1, 1)),
+ video=bytes.fromhex('2Ef0 F1f2 '),
+ text="test image three with video")
+ data = [
+ {
+ 'x': input_one
+ },
+ {
+ 'x': input_two
+ },
+ {
+ 'x': input_three
+ },
+ ]
+ expected_data = [{key: value for key, value in d.items()} for d in data]
+ with beam.Pipeline() as p:
+ result = (
+ p
+ | beam.Create(data)
+ | base.MLTransform(
+ write_artifact_location=self.artifact_location).with_transform(
+ self.embedding_config))
+ assert_that(
+ result,
+ equal_to(expected_data),
+ )
+
+
class TestUtilFunctions(unittest.TestCase):
def test_dict_input_fn_normal(self):
input_list = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]
diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py
index a645ce32e2a..c7c46d246b9 100644
--- a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py
+++ b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py
@@ -19,10 +19,14 @@
 # Follow https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk # pylint: disable=line-too-long
# to install Vertex AI Python SDK.
+import functools
import logging
+from collections.abc import Callable
from collections.abc import Sequence
+from dataclasses import dataclass
from typing import Any
from typing import Optional
+from typing import cast
from google.api_core.exceptions import ServerError
from google.api_core.exceptions import TooManyRequests
@@ -33,15 +37,28 @@ import vertexai
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import RemoteModelHandler
from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.rag.types import Chunk
+from apache_beam.ml.rag.types import Embedding
from apache_beam.ml.transforms.base import EmbeddingsManager
+from apache_beam.ml.transforms.base import EmbeddingTypeAdapter
from apache_beam.ml.transforms.base import _ImageEmbeddingHandler
+from apache_beam.ml.transforms.base import _MultiModalEmbeddingHandler
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
from vertexai.language_models import TextEmbeddingInput
from vertexai.language_models import TextEmbeddingModel
from vertexai.vision_models import Image
from vertexai.vision_models import MultiModalEmbeddingModel
-
-__all__ = ["VertexAITextEmbeddings", "VertexAIImageEmbeddings"]
+from vertexai.vision_models import MultiModalEmbeddingResponse
+from vertexai.vision_models import Video
+from vertexai.vision_models import VideoEmbedding
+from vertexai.vision_models import VideoSegmentConfig
+
+__all__ = [
+ "VertexAITextEmbeddings",
+ "VertexAIImageEmbeddings",
+ "VertexAIMultiModalEmbeddings",
+ "VertexAIMultiModalInput",
+]
DEFAULT_TASK_TYPE = "RETRIEVAL_DOCUMENT"
# TODO: https://github.com/apache/beam/issues/29356
@@ -54,7 +71,6 @@ TASK_TYPE_INPUTS = [
"CLUSTERING"
]
_BATCH_SIZE = 5 # Vertex AI limits requests to 5 at a time.
-_MSEC_TO_SEC = 1000
LOGGER = logging.getLogger("VertexAIEmbeddings")
@@ -281,3 +297,222 @@ class VertexAIImageEmbeddings(EmbeddingsManager):
return RunInference(
model_handler=_ImageEmbeddingHandler(self),
inference_args=self.inference_args)
+
+
+@dataclass
+class VertexImage:
+ image_content: Image
+ embedding: Optional[list[float]] = None
+
+
+@dataclass
+class VertexVideo:
+ video_content: Video
+ config: VideoSegmentConfig
+ embeddings: Optional[list[VideoEmbedding]] = None
+
+
+@dataclass
+class VertexAIMultiModalInput:
+ image: Optional[VertexImage] = None
+ video: Optional[VertexVideo] = None
+ contextual_text: Optional[Chunk] = None
+
+
+class _VertexAIMultiModalEmbeddingHandler(RemoteModelHandler):
+ def __init__(
+ self,
+ model_name: str,
+ dimension: Optional[int] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[Credentials] = None,
+ **kwargs):
+ vertexai.init(project=project, location=location, credentials=credentials)
+ self.model_name = model_name
+ self.dimension = dimension
+
+ super().__init__(
+        namespace='VertexAIMultiModalEmbeddingHandler',
+ retry_filter=_retry_on_appropriate_gcp_error,
+ **kwargs)
+
+ def request(
+ self,
+ batch: Sequence[VertexAIMultiModalInput],
+ model: MultiModalEmbeddingModel,
+ inference_args: Optional[dict[str, Any]] = None):
+ embeddings = []
+ # Max request size for multi-modal embedding models is 1
+ for input in batch:
+ image_content: Optional[Image] = None
+ video_content: Optional[Video] = None
+ text_content: Optional[str] = None
+ video_config: Optional[VideoSegmentConfig] = None
+
+ if input.image:
+ image_content = input.image.image_content
+ if input.video:
+ video_content = input.video.video_content
+ video_config = input.video.config
+ if input.contextual_text:
+ text_content = input.contextual_text.content.text
+
+ prediction = model.get_embeddings(
+ image=image_content,
+ video=video_content,
+ contextual_text=text_content,
+ dimension=self.dimension,
+ video_segment_config=video_config)
+ embeddings.append(prediction)
+ return embeddings
+
+ def create_client(self) -> MultiModalEmbeddingModel:
+ model = MultiModalEmbeddingModel.from_pretrained(self.model_name)
+ return model
+
+ def __repr__(self):
+ # ModelHandler is internal to the user and is not exposed.
+ # Hence we need to override the __repr__ method to expose
+ # the name of the class.
+ return 'VertexAIMultiModalEmbeddings'
+
+
+def _multimodal_dict_input_fn(
+ image_column: Optional[str],
+ video_column: Optional[str],
+ text_column: Optional[str],
+ batch: Sequence[dict[str, Any]]) -> list[VertexAIMultiModalInput]:
+ multimodal_inputs: list[VertexAIMultiModalInput] = []
+ for item in batch:
+ img: Optional[VertexImage] = None
+ vid: Optional[VertexVideo] = None
+ text: Optional[Chunk] = None
+ if image_column:
+ img = item[image_column]
+ if video_column:
+ vid = item[video_column]
+ if text_column:
+ text = item[text_column]
+ multimodal_inputs.append(
+ VertexAIMultiModalInput(image=img, video=vid, contextual_text=text))
+ return multimodal_inputs
+
+
+def _multimodal_dict_output_fn(
+ image_column: Optional[str],
+ video_column: Optional[str],
+ text_column: Optional[str],
+ batch: Sequence[dict[str, Any]],
+ embeddings: Sequence[MultiModalEmbeddingResponse]) -> list[dict[str, Any]]:
+ results = []
+ for batch_idx, item in enumerate(batch):
+ mm_embedding = embeddings[batch_idx]
+ if image_column:
+ item[image_column].embedding = mm_embedding.image_embedding
+ if video_column:
+ item[video_column].embeddings = mm_embedding.video_embeddings
+ if text_column:
+ item[text_column].embedding = Embedding(
+ dense_embedding=mm_embedding.text_embedding)
+ results.append(item)
+ return results
+
+
+def _create_multimodal_dict_adapter(
+ image_column: Optional[str],
+ video_column: Optional[str],
+ text_column: Optional[str]
+) -> EmbeddingTypeAdapter[dict[str, Any], dict[str, Any]]:
+ return EmbeddingTypeAdapter[dict[str, Any], dict[str, Any]](
+ input_fn=cast(
+ Callable[[Sequence[dict[str, Any]]], list[str]],
+ functools.partial(
+ _multimodal_dict_input_fn,
+ image_column,
+ video_column,
+ text_column)),
+ output_fn=cast(
+ Callable[[Sequence[dict[str, Any]], Sequence[Any]],
+ list[dict[str, Any]]],
+ functools.partial(
+ _multimodal_dict_output_fn,
+ image_column,
+ video_column,
+ text_column)))
+
+
+class VertexAIMultiModalEmbeddings(EmbeddingsManager):
+ def __init__(
+ self,
+ model_name: str,
+ image_column: Optional[str] = None,
+ video_column: Optional[str] = None,
+ text_column: Optional[str] = None,
+ dimension: Optional[int] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[Credentials] = None,
+ **kwargs):
+ """
+ Embedding Config for Vertex AI Multi-Modal Embedding models following
+    https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-multimodal-embeddings # pylint: disable=line-too-long
+ Multi-Modal Embeddings are generated for a batch of image, video, and
+ string groupings using the Vertex AI API. Embeddings are returned in a list
+    for each input in the batch as MultiModalEmbeddingResponses. This
+ transform makes remote calls to the Vertex AI service and may incur costs
+ for use.
+
+ Args:
+ model_name: The name of the Vertex AI Multi-Modal Embedding model.
+ image_column: The column containing image data to be embedded. This data
+ is expected to be formatted as VertexImage objects, containing a Vertex
+ Image object.
+ video_column: The column containing video data to be embedded. This data
+ is expected to be formatted as VertexVideo objects, containing a Vertex
+      Video object and a VideoSegmentConfig object.
+ text_column: The column containing text data to be embedded. This data is
+ expected to be formatted as Chunk objects, containing the string to be
+ embedded in the Chunk's content field.
+ dimension: The length of the embedding vector to generate. Must be one of
+ 128, 256, 512, or 1408. If not set, Vertex AI's default value is 1408.
+      If submitting video content, dimension *must* be 1408.
+ project: The default GCP project for API calls.
+ location: The default location for API calls.
+ credentials: Custom credentials for API calls.
+ Defaults to environment credentials.
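+
+    A minimal sketch of expected usage (illustrative only: the column names,
+    project, location, and artifact path below are placeholders, and
+    ``Image``, ``Chunk``, ``Content``, and ``MLTransform`` are imported from
+    ``vertexai.vision_models``, ``apache_beam.ml.rag.types``, and
+    ``apache_beam.ml.transforms.base`` respectively)::
+
+      embedding_config = VertexAIMultiModalEmbeddings(
+          model_name='multimodalembedding',
+          image_column='img',
+          text_column='caption',
+          dimension=128,
+          project='my-project',
+          location='us-central1')
+      with beam.Pipeline() as p:
+        _ = (
+            p
+            | beam.Create([{
+                'img': VertexImage(
+                    image_content=Image(gcs_uri='gs://my-bucket/img.jpg')),
+                'caption': Chunk(content=Content(text='a caption')),
+            }])
+            | MLTransform(write_artifact_location='/tmp/artifacts')
+            .with_transform(embedding_config))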
+ """
+ self.model_name = model_name
+ self.project = project
+ self.location = location
+ self.credentials = credentials
+ self.kwargs = kwargs
+ if dimension is not None and dimension not in (128, 256, 512, 1408):
+ raise ValueError(
+ "dimension argument must be one of 128, 256, 512, or 1408")
+ self.dimension = dimension
+ if not image_column and not video_column and not text_column:
+ raise ValueError("at least one input column must be specified")
+ if video_column is not None and dimension != 1408:
+ raise ValueError(
+          "Vertex AI does not support custom dimensions for video input, "
+          "want dimension = 1408, got ",
+ dimension)
+ self.type_adapter = _create_multimodal_dict_adapter(
+ image_column=image_column,
+ video_column=video_column,
+ text_column=text_column)
+ super().__init__(type_adapter=self.type_adapter, **kwargs)
+
+ def get_model_handler(self) -> ModelHandler:
+ return _VertexAIMultiModalEmbeddingHandler(
+ model_name=self.model_name,
+ dimension=self.dimension,
+ project=self.project,
+ location=self.location,
+ credentials=self.credentials,
+ **self.kwargs)
+
+ def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
+ return RunInference(
+ model_handler=_MultiModalEmbeddingHandler(self),
+ inference_args=self.inference_args)
diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py
index 1a47f81b665..ba43ea32508 100644
--- a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py
+++ b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai_test.py
@@ -26,10 +26,18 @@ from apache_beam.ml.transforms import base
from apache_beam.ml.transforms.base import MLTransform
try:
+ from apache_beam.ml.rag.types import Chunk
+ from apache_beam.ml.rag.types import Content
+  from apache_beam.ml.transforms.embeddings.vertex_ai import VertexAIMultiModalEmbeddings
  from apache_beam.ml.transforms.embeddings.vertex_ai import VertexAITextEmbeddings
  from apache_beam.ml.transforms.embeddings.vertex_ai import VertexAIImageEmbeddings
+ from apache_beam.ml.transforms.embeddings.vertex_ai import VertexImage
+ from apache_beam.ml.transforms.embeddings.vertex_ai import VertexVideo
from vertexai.vision_models import Image
+ from vertexai.vision_models import Video
+ from vertexai.vision_models import VideoSegmentConfig
except ImportError:
+ VertexAIMultiModalEmbeddings = None # type: ignore
VertexAITextEmbeddings = None # type: ignore
VertexAIImageEmbeddings = None # type: ignore
@@ -286,5 +294,104 @@ class VertexAIImageEmbeddingsTest(unittest.TestCase):
dimension=127)
+image_feature_column: str = "img_feature"
+text_feature_column: str = "txt_feature"
+video_feature_column: str = "vid_feature"
+
+
+def _make_text_chunk(input: str) -> Chunk:
+ return Chunk(content=Content(text=input))
+
+
[email protected](
+ VertexAIMultiModalEmbeddings is None,
+ 'Vertex AI Python SDK is not installed.')
+class VertexAIMultiModalEmbeddingsTest(unittest.TestCase):
+ def setUp(self) -> None:
+ self.artifact_location = tempfile.mkdtemp(
+ prefix='_vertex_ai_multi_modal_test')
+ self.gcs_artifact_location = os.path.join(
+ 'gs://temp-storage-for-perf-tests/vertex_ai_multi_modal',
+ uuid.uuid4().hex)
+ self.model_name = "multimodalembedding"
+    self.image_path = "gs://apache-beam-ml/testing/inputs/vertex_images/sunflowers/1008566138_6927679c8a.jpg" # pylint: disable=line-too-long
+    self.video_path = "gs://cloud-samples-data/vertex-ai-vision/highway_vehicles.mp4" # pylint: disable=line-too-long
+ self.video_segment_config = VideoSegmentConfig(end_offset_sec=1)
+
+ def tearDown(self) -> None:
+ shutil.rmtree(self.artifact_location)
+
+ def test_vertex_ai_multimodal_embedding_img_and_text(self):
+ embedding_config = VertexAIMultiModalEmbeddings(
+ model_name=self.model_name,
+ image_column=image_feature_column,
+ text_column=text_feature_column,
+ dimension=128,
+ project="apache-beam-testing",
+ location="us-central1")
+ with beam.Pipeline() as pipeline:
+ transformed_pcoll = (
+ pipeline | "CreateData" >> beam.Create([{
+ image_feature_column: VertexImage(
+ image_content=Image(gcs_uri=self.image_path)),
+ text_feature_column: _make_text_chunk("an image of sunflowers"),
+ }])
+ | "MLTransform" >> MLTransform(
+ write_artifact_location=self.artifact_location).with_transform(
+ embedding_config))
+
+ def assert_element(element):
+ assert len(element[image_feature_column].embedding) == 128
+ assert len(
+ element[text_feature_column].embedding.dense_embedding) == 128
+
+ _ = (transformed_pcoll | beam.Map(assert_element))
+
+ def test_vertex_ai_multimodal_embedding_video(self):
+ embedding_config = VertexAIMultiModalEmbeddings(
+ model_name=self.model_name,
+ video_column=video_feature_column,
+ dimension=1408,
+ project="apache-beam-testing",
+ location="us-central1")
+ with beam.Pipeline() as pipeline:
+ transformed_pcoll = (
+ pipeline | "CreateData" >> beam.Create([{
+ video_feature_column: VertexVideo(
+ video_content=Video(gcs_uri=self.video_path),
+ config=self.video_segment_config)
+ }])
+ | "MLTransform" >> MLTransform(
+ write_artifact_location=self.artifact_location).with_transform(
+ embedding_config))
+
+ def assert_element(element):
+ # Videos are returned in VideoEmbedding objects, must unroll
+ # for each segment.
+ for segment in element[video_feature_column].embeddings:
+ assert len(segment.embedding) == 1408
+
+ _ = (transformed_pcoll | beam.Map(assert_element))
+
+ def test_improper_dimension(self):
+ with self.assertRaises(ValueError):
+ _ = VertexAIMultiModalEmbeddings(
+ model_name=self.model_name,
+ image_column="fake_img_column",
+ dimension=127)
+
+ def test_missing_columns(self):
+ with self.assertRaises(ValueError):
+ _ = VertexAIMultiModalEmbeddings(
+ model_name=self.model_name, dimension=128)
+
+ def test_improper_video_dimension(self):
+ with self.assertRaises(ValueError):
+ _ = VertexAIMultiModalEmbeddings(
+ model_name=self.model_name,
+ video_column=video_feature_column,
+ dimension=128)
+
+
if __name__ == '__main__':
unittest.main()