This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit c63db6c2ce0669e4d04164c8dd5a67f74e7babdc Author: youjin <[email protected]> AuthorDate: Thu Jan 15 11:07:19 2026 +0800 [Feature][runtime] Support the use of Java VectorStore in Python --- dist/pom.xml | 5 + .../api/vector_stores/java_vector_store.py | 35 ++++ .../vector_store_cross_language_agent.py | 207 +++++++++++++++++++++ .../vector_store_cross_language_test.py | 108 +++++++++++ python/flink_agents/plan/resource_provider.py | 1 + .../flink_agents/runtime/java/java_vector_store.py | 148 +++++++++++++++ python/flink_agents/runtime/python_java_utils.py | 19 ++ .../runtime/python/utils/JavaResourceAdapter.java | 29 +++ 8 files changed, 552 insertions(+) diff --git a/dist/pom.xml b/dist/pom.xml index b8803d0b..f730e1d6 100644 --- a/dist/pom.xml +++ b/dist/pom.xml @@ -69,6 +69,11 @@ under the License. <artifactId>flink-agents-integrations-embedding-models-ollama</artifactId> <version>${project.version}</version> </dependency> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-agents-integrations-vector-stores-elasticsearch</artifactId> + <version>${project.version}</version> + </dependency> </dependencies> <build> diff --git a/python/flink_agents/api/vector_stores/java_vector_store.py b/python/flink_agents/api/vector_stores/java_vector_store.py new file mode 100644 index 00000000..f3f7e8d0 --- /dev/null +++ b/python/flink_agents/api/vector_stores/java_vector_store.py @@ -0,0 +1,35 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +from flink_agents.api.decorators import java_resource +from flink_agents.api.vector_stores.vector_store import ( + BaseVectorStore, + CollectionManageableVectorStore, +) + + +@java_resource +class JavaVectorStore(BaseVectorStore): + """Java-based implementation of VectorStore that wraps a Java vector store.""" + + java_class_name: str="" + +@java_resource +class JavaCollectionManageableVectorStore(JavaVectorStore, CollectionManageableVectorStore): + """Java-based implementation of VectorStore with collection management capabilities + that bridges Python and Java vector store functionality. + """ diff --git a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_agent.py b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_agent.py new file mode 100644 index 00000000..831945b7 --- /dev/null +++ b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_agent.py @@ -0,0 +1,207 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import os +import time + +import pytest + +from flink_agents.api.agents.agent import Agent +from flink_agents.api.decorators import ( + action, + embedding_model_connection, + embedding_model_setup, + vector_store, +) +from flink_agents.api.embedding_models.java_embedding_model import ( + JavaEmbeddingModelConnection, + JavaEmbeddingModelSetup, +) +from flink_agents.api.events.context_retrieval_event import ( + ContextRetrievalRequestEvent, + ContextRetrievalResponseEvent, +) +from flink_agents.api.events.event import InputEvent, OutputEvent +from flink_agents.api.resource import ResourceDescriptor, ResourceType +from flink_agents.api.runner_context import RunnerContext +from flink_agents.api.vector_stores.java_vector_store import ( + JavaCollectionManageableVectorStore, +) +from flink_agents.api.vector_stores.vector_store import ( + CollectionManageableVectorStore, + Document, +) +from flink_agents.integrations.embedding_models.local.ollama_embedding_model import ( + OllamaEmbeddingModelConnection, + OllamaEmbeddingModelSetup, +) + +TEST_COLLECTION = "test_collection" +MAX_RETRIES_TIMES = 10 + +class VectorStoreCrossLanguageAgent(Agent): + """Example agent demonstrating cross-language embedding model testing.""" + + @embedding_model_connection + @staticmethod + def embedding_model_connection() -> ResourceDescriptor: + """EmbeddingModelConnection responsible for ollama model service connection.""" + if os.environ.get("EMBEDDING_TYPE") == "JAVA": + return ResourceDescriptor( + clazz=JavaEmbeddingModelConnection, + java_clazz="org.apache.flink.agents.integrations.embeddingmodels.ollama.OllamaEmbeddingModelConnection", + host="http://localhost:11434", + ) + return ResourceDescriptor( + clazz=OllamaEmbeddingModelConnection, + host="http://localhost:11434", + ) + + @embedding_model_setup + @staticmethod + def embedding_model() -> ResourceDescriptor: + """EmbeddingModel which focus on math, and reuse ChatModelConnection.""" + if os.environ.get("EMBEDDING_TYPE") == "JAVA": + return ResourceDescriptor( + clazz=JavaEmbeddingModelSetup, + java_clazz="org.apache.flink.agents.integrations.embeddingmodels.ollama.OllamaEmbeddingModelSetup", + connection="embedding_model_connection", + model=os.environ.get( + "OLLAMA_EMBEDDING_MODEL", "nomic-embed-text:latest" + ), + ) + return ResourceDescriptor( + clazz=OllamaEmbeddingModelSetup, + connection="embedding_model_connection", + model=os.environ.get("OLLAMA_EMBEDDING_MODEL", "nomic-embed-text:latest"), + ) + + @vector_store + @staticmethod + def vector_store() -> ResourceDescriptor: + """Vector store setup for knowledge base.""" + return ResourceDescriptor( + clazz=JavaCollectionManageableVectorStore, + java_clazz="org.apache.flink.agents.integrations.vectorstores.elasticsearch.ElasticsearchVectorStore", + embedding_model="embedding_model", + host=os.environ.get("ES_HOST"), + index="my_documents", + dims=768, + ) + + @action(InputEvent) + @staticmethod + def process_input(event: InputEvent, ctx: RunnerContext) -> None: + """User defined action for processing input. + + In this action, we will test Vector store Collection Management and + Document Management. + """ + input_text = event.input + + stm = ctx.short_term_memory + + is_initialized = stm.get("is_initialized") or False + + if not is_initialized: + print("[TEST] Initializing vector store...") + + vector_store = ctx.get_resource("vector_store", ResourceType.VECTOR_STORE) + if isinstance(vector_store, CollectionManageableVectorStore): + vector_store.get_or_create_collection(TEST_COLLECTION , metadata={"key1": "value1", "key2": "value2"}) + + collection = vector_store.get_collection(name=TEST_COLLECTION) + + assert collection is not None + assert collection.name == TEST_COLLECTION + assert collection.metadata == {"key1": "value1", "key2": "value2"} + + vector_store.delete_collection(name=TEST_COLLECTION) + + with pytest.raises(RuntimeError): + vector_store.get_collection(name=TEST_COLLECTION) + + print("[TEST] Vector store Collection Management PASSED") + + documents = [ + Document( + id="doc1", + content="The sum of 1 and 2 equals 3.", + metadata={"category": "calculate", "source": "test"}, + ), + Document( + id="doc2", + content="Why did the cat sit on the computer? Because it wanted to keep an eye on the mouse.", + metadata={"category": "ai-agent", "source": "test"}, + ), + Document( + id="doc3", + content="This is a test document used to verify the delete functionality.", + metadata={"category": "utility", "source": "test"}, + ), + ] + vector_store.add(documents=documents) + + # Test size + assert vector_store.size() == 3 + + # Test delete + vector_store.delete(ids="doc3") + + # Wait for vector store to delete doc3 + retry_time = 0 + while vector_store.size() > 2 and retry_time < MAX_RETRIES_TIMES: + retry_time += 1 + time.sleep(2) + print(f"[TEST] Retrying to delete doc3, retry_time={retry_time}") + + assert vector_store.size() == 2 + + # Test get + doc = vector_store.get(ids="doc2") + assert doc is not None + assert doc[0].id == "doc2" + assert doc[0].content == "Why did the cat sit on the computer? Because it wanted to keep an eye on the mouse." + + print("[TEST] Vector store Document Management PASSED") + + stm.set("is_initialized", True) + + + ctx.send_event(ContextRetrievalRequestEvent(query=input_text, vector_store="vector_store")) + + @action(ContextRetrievalResponseEvent) + @staticmethod + def contextRetrievalResponseEvent(event: ContextRetrievalResponseEvent, ctx: RunnerContext) -> None: + """User defined action for processing context retrieval response. + + In this action, we will test Vector store Context Retrieval. + """ + documents = event.documents + + assert documents is not None + assert len(documents) > 0 + + for document in documents: + assert document is not None + assert document.id is not None + assert document.content is not None + + test_result = f"[PASS] retrieved_count={len(documents)}, first_doc_id={documents[0].id}, first_doc_preview={documents[0].content[:50]}" + print(f"[TEST] Vector store Context Retrieval PASSED, first_doc_id={documents[0].id}, first_doc_preview={documents[0].content[:50]}") + + ctx.send_event(OutputEvent(output=test_result)) diff --git a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_test.py b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_test.py new file mode 100644 index 00000000..f47f4f5d --- /dev/null +++ b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/vector_store_cross_language_test.py @@ -0,0 +1,108 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import os +import sysconfig +from pathlib import Path + +import pytest +from pyflink.common import Encoder, WatermarkStrategy +from pyflink.common.typeinfo import Types +from pyflink.datastream import ( + RuntimeExecutionMode, + StreamExecutionEnvironment, +) +from pyflink.datastream.connectors.file_system import ( + FileSource, + StreamFormat, + StreamingFileSink, +) + +from flink_agents.api.execution_environment import AgentsExecutionEnvironment +from flink_agents.e2e_tests.e2e_tests_resource_cross_language.vector_store_cross_language_agent import ( + VectorStoreCrossLanguageAgent, +) +from flink_agents.e2e_tests.test_utils import pull_model + +current_dir = Path(__file__).parent + +OLLAMA_MODEL = os.environ.get("OLLAMA_EMBEDDING_MODEL", "nomic-embed-text:latest") +os.environ["OLLAMA_EMBEDDING_MODEL"] = OLLAMA_MODEL + +ES_HOST = os.environ.get("ES_HOST") + +client = pull_model(OLLAMA_MODEL) + +os.environ["PYTHONPATH"] = sysconfig.get_paths()["purelib"] + [email protected](client is None or ES_HOST is None, reason="Ollama client or Elasticsearch host is missing.") [email protected]("embedding_type", ["JAVA", "PYTHON"]) +def test_java_embedding_model_integration(tmp_path: Path, embedding_type: str) -> None: # noqa: D103 + os.environ["EMBEDDING_TYPE"] = embedding_type + + env = StreamExecutionEnvironment.get_execution_environment() + env.set_runtime_mode(RuntimeExecutionMode.STREAMING) + env.set_parallelism(1) + + # currently, bounded source is not supported due to runtime implementation, so + # we use continuous file source here. + input_datastream = env.from_source( + source=FileSource.for_record_stream_format( + StreamFormat.text_line_format(), f"file:///{current_dir}/../resources/java_chat_module_input" + ).build(), + watermark_strategy=WatermarkStrategy.no_watermarks(), + source_name="streaming_agent_example", + ) + + deserialize_datastream = input_datastream.map( + lambda x: str(x) + ) + + agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env) + output_datastream = ( + agents_env.from_datastream( + input=deserialize_datastream, key_selector= lambda x: "orderKey" + ) + .apply(VectorStoreCrossLanguageAgent()) + .to_datastream() + ) + + result_dir = tmp_path / "results" + result_dir.mkdir(parents=True, exist_ok=True) + + (output_datastream.map(lambda x: str(x).replace('\n', '') + .replace('\r', ''), Types.STRING()).add_sink( + StreamingFileSink.for_row_format( + base_path=str(result_dir.absolute()), + encoder=Encoder.simple_string_encoder(), + ).build() + )) + + agents_env.execute() + + actual_result = [] + for file in result_dir.iterdir(): + if file.is_dir(): + for child in file.iterdir(): + with child.open() as f: + actual_result.extend(f.readlines()) + if file.is_file(): + with file.open() as f: + actual_result.extend(f.readlines()) + + assert "PASS" in actual_result[0] + assert "PASS" in actual_result[1] diff --git a/python/flink_agents/plan/resource_provider.py b/python/flink_agents/plan/resource_provider.py index c7fc13ce..0bd45de1 100644 --- a/python/flink_agents/plan/resource_provider.py +++ b/python/flink_agents/plan/resource_provider.py @@ -159,6 +159,7 @@ JAVA_RESOURCE_MAPPING: dict[ResourceType, str] = { ResourceType.CHAT_MODEL_CONNECTION: "flink_agents.runtime.java.java_chat_model.JavaChatModelConnectionImpl", ResourceType.EMBEDDING_MODEL: "flink_agents.runtime.java.java_embedding_model.JavaEmbeddingModelSetupImpl", ResourceType.EMBEDDING_MODEL_CONNECTION: "flink_agents.runtime.java.java_embedding_model.JavaEmbeddingModelConnectionImpl", + ResourceType.VECTOR_STORE: "flink_agents.runtime.java.java_vector_store.JavaVectorStoreImpl", } class JavaResourceProvider(ResourceProvider): diff --git a/python/flink_agents/runtime/java/java_vector_store.py b/python/flink_agents/runtime/java/java_vector_store.py new file mode 100644 index 00000000..77f3167f --- /dev/null +++ b/python/flink_agents/runtime/java/java_vector_store.py @@ -0,0 +1,148 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# + +from typing import Any, Dict, List + +from typing_extensions import override + +from flink_agents.api.vector_stores.java_vector_store import ( + JavaCollectionManageableVectorStore, +) +from flink_agents.api.vector_stores.vector_store import ( + Collection, + Document, + VectorStoreQuery, + VectorStoreQueryResult, + _maybe_cast_to_list, +) +from flink_agents.runtime.python_java_utils import ( + from_java_collection, + from_java_document, + from_java_vector_store_query_result, +) + + +class JavaVectorStoreImpl(JavaCollectionManageableVectorStore): + """Java-based implementation of EmbeddingModelSetup that wraps a Java embedding + model object. + This class serves as a bridge between Python and Java embedding model environments, + but unlike JavaEmbeddingModelConnection, it does not provide direct embedding + functionality in Python. + """ + _j_resource: Any + _j_resource_adapter: Any + + def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs: Any) -> None: + """Creates a new JavaEmbeddingModelSetup. + + Args: + j_resource: The Java resource object + j_resource_adapter: The Java resource adapter for method invocation + **kwargs: Additional keyword arguments + """ + # embedding_model are required parameters for BaseVectorStore + embedding_model = kwargs.pop("embedding_model", "") + super().__init__(embedding_model = embedding_model, **kwargs) + + self._j_resource=j_resource + self._j_resource_adapter=j_resource_adapter + + @override + @property + def store_kwargs(self) -> Dict[str, Any]: + return {} + + @override + def add( + self, + documents: Document | List[Document], + collection_name: str | None = None, + **kwargs: Any, + ) -> List[str]: + + documents = _maybe_cast_to_list(documents) + j_documents = [ + self._j_resource_adapter.fromPythonDocument(document) + for document in documents + ] + + return self._j_resource.add(j_documents, collection_name, kwargs) + + @override + def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult: + j_query = self._j_resource_adapter.fromPythonVectorStoreQuery(query) + j_query_result = self._j_resource.query(j_query) + return from_java_vector_store_query_result(j_query_result) + + @override + def size(self, collection_name: str | None = None) -> int: + return self._j_resource.size(collection_name) + + @override + def get( + self, + ids: str | List[str] | None = None, + collection_name: str | None = None, + **kwargs: Any, + ) -> List[Document]: + ids = _maybe_cast_to_list(ids) + j_documents = self._j_resource.get(ids, collection_name, kwargs) + return [from_java_document(j_document) for j_document in j_documents] + + @override + def delete( + self, + ids: str | List[str] | None = None, + collection_name: str | None = None, + **kwargs: Any, + ) -> List[str]: + ids = _maybe_cast_to_list(ids) + return self._j_resource.delete(ids, collection_name, kwargs) + + @override + def get_or_create_collection( + self, name: str, metadata: Dict[str, Any] | None = None + ) -> Collection: + j_collection = self._j_resource.getOrCreateCollection(name, metadata) + return from_java_collection(j_collection) + + @override + def get_collection(self, name: str) -> Collection: + j_collection = self._j_resource.getCollection(name) + return from_java_collection(j_collection) + + @override + def delete_collection(self, name: str) -> Collection: + j_collection = self._j_resource.deleteCollection(name) + return from_java_collection(j_collection) + + @override + def _add_embedding( + self, + *, + documents: List[Document], + collection_name: str | None = None, + **kwargs: Any, + ) -> List[str]: + """Private functions should never be called for the Resource Wrapper.""" + + @override + def _query_embedding( + self, embedding: list[float], limit: int = 10, **kwargs: Any + ) -> list[Document]: + """Private functions should never be called for the Resource Wrapper.""" diff --git a/python/flink_agents/runtime/python_java_utils.py b/python/flink_agents/runtime/python_java_utils.py index de223ddf..509d0c51 100644 --- a/python/flink_agents/runtime/python_java_utils.py +++ b/python/flink_agents/runtime/python_java_utils.py @@ -27,9 +27,11 @@ from flink_agents.api.resource import Resource, ResourceType, get_resource_class from flink_agents.api.tools.tool import ToolMetadata from flink_agents.api.tools.utils import create_model_from_java_tool_schema_str from flink_agents.api.vector_stores.vector_store import ( + Collection, Document, VectorStoreQuery, VectorStoreQueryMode, + VectorStoreQueryResult, ) from flink_agents.plan.resource_provider import JAVA_RESOURCE_MAPPING from flink_agents.runtime.java.java_resource_wrapper import ( @@ -218,6 +220,23 @@ def from_java_vector_store_query(j_query: Any) -> VectorStoreQuery: extra_args=j_query.getExtraArgs() ) +def from_java_vector_store_query_result(j_query: Any) -> VectorStoreQueryResult: + """Convert a Java vector store query result to a Python query result.""" + return VectorStoreQueryResult( + documents=[from_java_document(j_document) for j_document in j_query.getDocuments()], + ) + +def from_java_collection(j_collection: Any) -> Collection: + """Convert a Java collection to a Python collection.""" + return Collection( + name=j_collection.getName(), + metadata=j_collection.getMetadata(), + ) + +def get_mode_value(query: VectorStoreQuery) -> str: + """Get the mode value of a VectorStoreQuery.""" + return query.mode.value + def call_method(obj: Any, method_name: str, kwargs: Dict[str, Any]) -> Any: """Calls a method on `obj` by name and passes in positional and keyword arguments. diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java index 7b6c009c..b2886266 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java @@ -21,8 +21,13 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.vectorstores.Document; +import org.apache.flink.agents.api.vectorstores.VectorStoreQuery; +import org.apache.flink.agents.api.vectorstores.VectorStoreQueryMode; import pemja.core.PythonInterpreter; +import pemja.core.object.PyObject; +import java.util.Map; import java.util.function.BiFunction; /** Adapter for managing Java resources and facilitating Python-Java interoperability. */ @@ -72,4 +77,28 @@ public class JavaResourceAdapter { chatMessage.setRole(MessageRole.fromValue(roleValue)); return chatMessage; } + + @SuppressWarnings("unchecked") + public Document fromPythonDocument(PyObject pythonDocument) { + // TODO: Delete this method after the pemja findClass method is fixed. + return new Document( + pythonDocument.getAttr("content").toString(), + (Map<String, Object>) pythonDocument.getAttr("metadata", Map.class), + pythonDocument.getAttr("id").toString()); + } + + @SuppressWarnings("unchecked") + public VectorStoreQuery fromPythonVectorStoreQuery(PyObject pythonVectorStoreQuery) { + // TODO: Delete this method after the pemja findClass method is fixed. + String modeValue = + (String) + interpreter.invoke( + "python_java_utils.get_mode_value", pythonVectorStoreQuery); + return new VectorStoreQuery( + VectorStoreQueryMode.fromValue(modeValue), + (String) pythonVectorStoreQuery.getAttr("query_text"), + pythonVectorStoreQuery.getAttr("limit", Integer.class), + (String) pythonVectorStoreQuery.getAttr("collection_name"), + (Map<String, Object>) pythonVectorStoreQuery.getAttr("extra_args", Map.class)); + } }
