This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit c9e03cbd750c4210cb60e4683a7d27f1de86aa4a Author: youjin <[email protected]> AuthorDate: Tue Jan 13 07:26:30 2026 +0800 [Feature][runtime] Support the use of Java EembeddingModel in Python --- dist/pom.xml | 5 + .../api/embedding_models/embedding_model.py | 2 +- .../api/embedding_models/java_embedding_model.py | 47 ++++++++ python/flink_agents/api/events/chat_event.py | 2 +- .../embedding_model_cross_language_agent.py | 123 +++++++++++++++++++++ .../embedding_model_cross_language_test.py | 103 +++++++++++++++++ python/flink_agents/plan/resource_provider.py | 2 + .../runtime/java/java_embedding_model.py | 107 ++++++++++++++++++ 8 files changed, 389 insertions(+), 2 deletions(-) diff --git a/dist/pom.xml b/dist/pom.xml index 225c368..b8803d0 100644 --- a/dist/pom.xml +++ b/dist/pom.xml @@ -64,6 +64,11 @@ under the License. <artifactId>flink-agents-integrations-chat-models-ollama</artifactId> <version>${project.version}</version> </dependency> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-agents-integrations-embedding-models-ollama</artifactId> + <version>${project.version}</version> + </dependency> </dependencies> <build> diff --git a/python/flink_agents/api/embedding_models/embedding_model.py b/python/flink_agents/api/embedding_models/embedding_model.py index 2ec369a..be06913 100644 --- a/python/flink_agents/api/embedding_models/embedding_model.py +++ b/python/flink_agents/api/embedding_models/embedding_model.py @@ -85,7 +85,7 @@ class BaseEmbeddingModelSetup(Resource, ABC): def model_kwargs(self) -> Dict[str, Any]: """Return embedding model settings.""" - def embed(self, text: str, **kwargs: Any) -> list[float]: + def embed(self, text: str | Sequence[str], **kwargs: Any) -> list[float] | list[list[float]]: """Generate embedding vector for a single text query.
Converts the input text into a high-dimensional vector representation diff --git a/python/flink_agents/api/embedding_models/java_embedding_model.py b/python/flink_agents/api/embedding_models/java_embedding_model.py new file mode 100644 index 0000000..18d2f9a --- /dev/null +++ b/python/flink_agents/api/embedding_models/java_embedding_model.py @@ -0,0 +1,47 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +from flink_agents.api.decorators import java_resource +from flink_agents.api.embedding_models.embedding_model import ( + BaseEmbeddingModelConnection, + BaseEmbeddingModelSetup, +) + + +@java_resource +class JavaEmbeddingModelConnection(BaseEmbeddingModelConnection): + """Java-based implementation of EmbeddingModelConnection that wraps a Java embedding + model object. + + This class serves as a bridge between Python and Java embedding model environments; + at runtime it is backed by JavaEmbeddingModelConnectionImpl, which forwards + embed() calls to the wrapped Java connection object.
+ """ + + java_class_name: str="" # presumably the wrapped Java class name (confirm); "" by default + +@java_resource +class JavaEmbeddingModelSetup(BaseEmbeddingModelSetup): + """Java-based implementation of EmbeddingModelSetup that bridges Python and Java + embedding model functionality. + + This class wraps a Java embedding model setup object and provides Python interface + compatibility while delegating actual embedding operations to the underlying Java + implementation. + """ + + java_class_name: str="" # presumably the wrapped Java class name (confirm); "" by default diff --git a/python/flink_agents/api/events/chat_event.py b/python/flink_agents/api/events/chat_event.py index 2fb4266..5e1237f 100644 --- a/python/flink_agents/api/events/chat_event.py +++ b/python/flink_agents/api/events/chat_event.py @@ -18,7 +18,7 @@ from typing import List from uuid import UUID -from flink_agents.api.agents.react_agent import OutputSchema +from flink_agents.api.agents.types import OutputSchema from flink_agents.api.chat_message import ChatMessage from flink_agents.api.events.event import Event diff --git a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/embedding_model_cross_language_agent.py b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/embedding_model_cross_language_agent.py new file mode 100644 index 0000000..ad5c0eb --- /dev/null +++ b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/embedding_model_cross_language_agent.py @@ -0,0 +1,123 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import os + +from flink_agents.api.agents.agent import Agent +from flink_agents.api.decorators import ( + action, + embedding_model_connection, + embedding_model_setup, +) +from flink_agents.api.embedding_models.java_embedding_model import ( + JavaEmbeddingModelConnection, + JavaEmbeddingModelSetup, +) +from flink_agents.api.events.event import InputEvent, OutputEvent +from flink_agents.api.resource import ResourceDescriptor, ResourceType +from flink_agents.api.runner_context import RunnerContext + + +class EmbeddingModelCrossLanguageAgent(Agent): + """Example agent demonstrating cross-language embedding model testing.""" + + @embedding_model_connection + @staticmethod + def embedding_model_connection() -> ResourceDescriptor: + """Java Ollama EmbeddingModelConnection for the server at localhost:11434.""" + return ResourceDescriptor( + clazz=JavaEmbeddingModelConnection, + java_clazz="org.apache.flink.agents.integrations.embeddingmodels.ollama.OllamaEmbeddingModelConnection", + host="http://localhost:11434", + ) + + @embedding_model_setup + @staticmethod + def embedding_model() -> ResourceDescriptor: + """Java Ollama EmbeddingModelSetup that reuses embedding_model_connection.""" + return ResourceDescriptor( + clazz=JavaEmbeddingModelSetup, + java_clazz="org.apache.flink.agents.integrations.embeddingmodels.ollama.OllamaEmbeddingModelSetup", + connection="embedding_model_connection", + model=os.environ.get("OLLAMA_EMBEDDING_MODEL", "nomic-embed-text:latest"), + ) +
@action(InputEvent) + @staticmethod + def process_input(event: InputEvent, ctx: RunnerContext) -> None: + """User defined action for processing input. + + Embeds the input singly and as a one-item batch, then emits a PASS/FAIL OutputEvent. + """ + input_text = event.input + + short_doc = f"{input_text[:5]}..." + + print(f"[TEST] Starting embedding generation test for: '{input_text[:10]}...'") + + try: + # Get embedding model + embedding_model = ctx.get_resource("embedding_model", ResourceType.EMBEDDING_MODEL) + + # Test single text embedding + embedding = embedding_model.embed(input_text) + print(f"[TEST] Generated embedding with dimension: {len(embedding)}") + + # Validate single embedding result + if embedding is None or not isinstance(embedding, list) or len(embedding) == 0: + err_msg = "Embedding cannot be null or empty" + raise AssertionError(err_msg) # noqa: TRY301 + + if not all(isinstance(x, float) for x in embedding): # NOTE(review): assumes the Java bridge yields Python floats - confirm + err_msg = "All embedding values must be floats" + raise AssertionError(err_msg) # noqa: TRY301 + + print(f"[TEST] Validated single embedding: Dimension={len(embedding)}, Text='{input_text[:30]}...'") + + # Test batch embedding + embeddings = embedding_model.embed([input_text]) + print(f"[TEST] Generated batch embeddings: count={len(embeddings)}") + + # Validate batch embedding results + if embeddings is None or not isinstance(embeddings, list) or len(embeddings) == 0: + err_msg = "Batch embeddings cannot be null or empty" + raise AssertionError(err_msg) # noqa: TRY301 + + if len(embeddings) != 1: + err_msg = f"Expected 1 embedding but got {len(embeddings)}" + raise AssertionError(err_msg) # noqa: TRY301 + + for i, emb in enumerate(embeddings): + if not isinstance(emb, list) or len(emb) == 0: + err_msg = f"Embedding at index {i} is invalid" + raise AssertionError(err_msg) # noqa: TRY301 + print(f"[TEST] Validated batch embedding {i}: Dimension={len(emb)}") + + # Create test result as a single string + test_result = f"[PASS] Text={short_doc},
Dimension={len(embedding)}, BatchCount={len(embeddings)}" + + ctx.send_event(OutputEvent(output=test_result)) + + print(f"[TEST] Embedding generation test PASSED for: '{input_text[:50]}...'") + + except Exception as e: + # Any failure above (including the AssertionError checks) is reported as a [FAIL] output, not raised + test_result = f"[FAIL] Text={short_doc}, Error={e!s}" + + ctx.send_event(OutputEvent(output=test_result)) + + print(f"[TEST] Embedding generation test FAILED: {e!s}") diff --git a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/embedding_model_cross_language_test.py b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/embedding_model_cross_language_test.py new file mode 100644 index 0000000..6ba762a --- /dev/null +++ b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/embedding_model_cross_language_test.py @@ -0,0 +1,103 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+################################################################################# +import os +import sysconfig +from pathlib import Path + +import pytest +from pyflink.common import Encoder, WatermarkStrategy +from pyflink.common.typeinfo import Types +from pyflink.datastream import ( + RuntimeExecutionMode, + StreamExecutionEnvironment, +) +from pyflink.datastream.connectors.file_system import ( + FileSource, + StreamFormat, + StreamingFileSink, +) + +from flink_agents.api.execution_environment import AgentsExecutionEnvironment +from flink_agents.e2e_tests.e2e_tests_resource_cross_language.embedding_model_cross_language_agent import ( + EmbeddingModelCrossLanguageAgent, +) +from flink_agents.e2e_tests.test_utils import pull_model + +current_dir = Path(__file__).parent + +OLLAMA_MODEL = os.environ.get("OLLAMA_EMBEDDING_MODEL", "nomic-embed-text:latest") +os.environ["OLLAMA_EMBEDDING_MODEL"] = OLLAMA_MODEL + +client = pull_model(OLLAMA_MODEL) + +os.environ["PYTHONPATH"] = sysconfig.get_paths()["purelib"] + +@pytest.mark.skipif(client is None, reason="Ollama client is not available or test model is missing.") +def test_java_embedding_model_integration(tmp_path: Path) -> None: # noqa: D103 + env = StreamExecutionEnvironment.get_execution_environment() + env.set_runtime_mode(RuntimeExecutionMode.STREAMING) + env.set_parallelism(1) + + # currently, bounded source is not supported due to runtime implementation, so + # we use continuous file source here.
+ input_datastream = env.from_source( + source=FileSource.for_record_stream_format( + StreamFormat.text_line_format(), f"file:///{current_dir}/../resources/java_chat_module_input" # NOTE(review): reuses the chat-model test input file + ).build(), + watermark_strategy=WatermarkStrategy.no_watermarks(), + source_name="streaming_agent_example", + ) + + deserialize_datastream = input_datastream.map( + lambda x: str(x) + ) + + agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env) + output_datastream = ( + agents_env.from_datastream( + input=deserialize_datastream, key_selector= lambda x: "orderKey" + ) + .apply(EmbeddingModelCrossLanguageAgent()) + .to_datastream() + ) + + result_dir = tmp_path / "results" + result_dir.mkdir(parents=True, exist_ok=True) + + (output_datastream.map(lambda x: str(x).replace('\n', '') + .replace('\r', ''), Types.STRING()).add_sink( + StreamingFileSink.for_row_format( + base_path=str(result_dir.absolute()), + encoder=Encoder.simple_string_encoder(), + ).build() + )) + + agents_env.execute() + + actual_result = [] + for file in result_dir.iterdir(): + if file.is_dir(): + for child in file.iterdir(): + with child.open() as f: + actual_result.extend(f.readlines()) + if file.is_file(): + with file.open() as f: + actual_result.extend(f.readlines()) + + assert len(actual_result) >= 2, actual_result + assert all("PASS" in line for line in actual_result[:2]), actual_result diff --git a/python/flink_agents/plan/resource_provider.py b/python/flink_agents/plan/resource_provider.py index 1263cc2..c7fc13c 100644 --- a/python/flink_agents/plan/resource_provider.py +++ b/python/flink_agents/plan/resource_provider.py @@ -157,6 +157,8 @@ class PythonSerializableResourceProvider(SerializableResourceProvider): JAVA_RESOURCE_MAPPING: dict[ResourceType, str] = { ResourceType.CHAT_MODEL: "flink_agents.runtime.java.java_chat_model.JavaChatModelSetupImpl", ResourceType.CHAT_MODEL_CONNECTION: "flink_agents.runtime.java.java_chat_model.JavaChatModelConnectionImpl", + ResourceType.EMBEDDING_MODEL:
"flink_agents.runtime.java.java_embedding_model.JavaEmbeddingModelSetupImpl", + ResourceType.EMBEDDING_MODEL_CONNECTION: "flink_agents.runtime.java.java_embedding_model.JavaEmbeddingModelConnectionImpl", } class JavaResourceProvider(ResourceProvider): diff --git a/python/flink_agents/runtime/java/java_embedding_model.py b/python/flink_agents/runtime/java/java_embedding_model.py new file mode 100644 index 0000000..7cfe93d --- /dev/null +++ b/python/flink_agents/runtime/java/java_embedding_model.py @@ -0,0 +1,107 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +from typing import Any, Dict, Sequence + +from flink_agents.api.embedding_models.java_embedding_model import ( + JavaEmbeddingModelConnection, + JavaEmbeddingModelSetup, +) + + +class JavaEmbeddingModelConnectionImpl(JavaEmbeddingModelConnection): + """Java-based implementation of EmbeddingModelConnection that wraps a Java embedding + model object.
+ This class serves as a bridge between Python and Java embedding model environments: + embed() forwards the text(s) and kwargs to the wrapped Java connection and + converts the returned vector(s) into Python lists. + """ + + _j_resource: Any + _j_resource_adapter: Any + + def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> None: + """Creates a new JavaEmbeddingModelConnection. + + Args: + j_resource: The Java resource object + j_resource_adapter: The Java resource adapter for method invocation + **kwargs: Additional keyword arguments + """ + super().__init__(**kwargs) + self._j_resource = j_resource + self._j_resource_adapter = j_resource_adapter + + def embed(self, text: str | Sequence[str], **kwargs: Any) -> list[float] | list[list[float]]: + """Generate embedding vector(s) for a single text or a batch of texts. + A str yields one vector; any other sequence of str yields one vector per item. + The call is delegated to the wrapped Java resource. + + Args: + text: A text string, or a sequence of text strings, to embed. + **kwargs: Additional parameters, passed to the Java side as a dict. + """ + result = self._j_resource.embed( + text if isinstance(text, str) else list(text), kwargs + ) + return list(result) if isinstance(text, str) else [list(emb) for emb in result] + + +class JavaEmbeddingModelSetupImpl(JavaEmbeddingModelSetup): + """Java-based implementation of EmbeddingModelSetup that wraps a Java embedding + model object. + This class serves as a bridge between Python and Java embedding model environments: + embed() forwards the text(s) and kwargs to the wrapped Java setup and + converts the returned vector(s) into Python lists. + """ + _j_resource: Any + _j_resource_adapter: Any + + def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> None: + """Creates a new JavaEmbeddingModelSetup.
+ + Args: + j_resource: The Java resource object + j_resource_adapter: The Java resource adapter for method invocation + **kwargs: Additional keyword arguments + """ + super().__init__(**kwargs) + self._j_resource = j_resource + self._j_resource_adapter = j_resource_adapter + + @property + def model_kwargs(self) -> Dict[str, Any]: + """Return embedding model settings. + + Returns: + Empty dictionary as parameters are managed by Java side + """ + return {} + + def embed(self, text: str | Sequence[str], **kwargs: Any) -> list[float] | list[list[float]]: + """Generate embedding vector(s) for a single text or a batch of texts. + A str yields one vector; any other sequence of str yields one vector per item. + The call is delegated to the wrapped Java resource. + + Args: + text: A text string, or a sequence of text strings, to embed. + **kwargs: Additional parameters, passed to the Java side as a dict. + """ + result = self._j_resource.embed( + text if isinstance(text, str) else list(text), kwargs + ) + return list(result) if isinstance(text, str) else [list(emb) for emb in result]
