This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 05ed99104eae9e09a48e733ff21d3944630d3805 Author: WenjinXie <[email protected]> AuthorDate: Tue Mar 31 16:51:10 2026 +0800 [runtime] Support async execution for cross language resources. --- .../agents/api/chat/model/BaseChatModelSetup.java | 66 ++++++++++----- .../chat/model/python/PythonChatModelSetup.java | 6 ++ .../embedding/model/BaseEmbeddingModelSetup.java | 43 ++++++---- .../model/python/PythonEmbeddingModelSetup.java | 6 ++ .../apache/flink/agents/api/resource/Resource.java | 3 + .../agents/api/vectorstores/BaseVectorStore.java | 41 ++++++--- .../api/vectorstores/python/PythonVectorStore.java | 6 ++ .../bedrock/BedrockEmbeddingModelSetup.java | 5 -- integrations/embedding-models/ollama/pom.xml | 6 ++ .../ollama/OllamaEmbeddingModelSetup.java | 5 -- .../ollama/OllamaEmbeddingModelConnectionTest.java | 6 +- .../opensearch/OpenSearchVectorStore.java | 6 +- .../s3vectors/S3VectorsVectorStore.java | 6 +- .../flink/agents/plan/actions/ChatModelAction.java | 11 ++- .../plan/actions/ContextRetrievalAction.java | 8 +- .../apache/flink/agents/plan/actions/Utils.java | 97 ++++++++++++++++++++++ .../agents/plan/AgentPlanDeclareChatModelTest.java | 10 ++- .../plan/actions/ChatModelActionRetryTest.java | 4 +- python/flink_agents/api/chat_models/chat_model.py | 82 ++++++++++++------ .../api/embedding_models/embedding_model.py | 32 +++++-- python/flink_agents/api/resource.py | 37 +++++---- .../flink_agents/api/vector_stores/vector_store.py | 33 ++++---- .../built_in_action_async_execution_test.py | 3 + .../e2e_tests_integration/long_term_memory_test.py | 6 ++ .../chat_models/tests/test_ollama_chat_model.py | 5 ++ .../chat_models/tests/test_tongyi_chat_model.py | 4 + .../chroma/tests/test_chroma_vector_store.py | 18 +++- .../flink_agents/plan/actions/chat_model_action.py | 23 +++-- .../plan/actions/context_retrieval_action.py | 9 +- python/flink_agents/plan/actions/utils.py | 58 +++++++++++++ python/flink_agents/plan/resource_provider.py | 1 - python/flink_agents/plan/tests/test_agent_plan.py | 3 + .../flink_agents/runtime/flink_runner_context.py | 1 + .../flink_agents/runtime/java/java_chat_model.py | 5 ++ .../runtime/java/java_embedding_model.py | 7 ++ .../flink_agents/runtime/java/java_vector_store.py | 4 + .../tests/test_vector_store_long_term_memory.py | 24 +++--- python/flink_agents/runtime/resource_cache.py | 1 + .../runtime/tests/test_built_in_actions.py | 3 + .../runtime/tests/test_get_resource_in_action.py | 3 + .../apache/flink/agents/runtime/ResourceCache.java | 2 +- .../agents/runtime/context/RunnerContextImpl.java | 1 + .../python/context/PythonRunnerContextImpl.java | 5 ++ 43 files changed, 534 insertions(+), 171 deletions(-) diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java index 15aa953e..bcfb0e50 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java @@ -26,6 +26,9 @@ import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.tools.Tool; import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.util.Preconditions; + +import javax.annotation.Nullable; import java.util.ArrayList; import java.util.Collections; @@ -35,18 +38,44 @@ import java.util.Map; import java.util.function.BiFunction; public abstract class BaseChatModelSetup extends Resource { - protected final String connection; + protected final String connectionName; protected String model; protected Object prompt; - protected List<String> tools; + protected List<String> toolNames; + + @Nullable protected BaseChatModelConnection connection; + protected final List<Tool> tools = new ArrayList<>(); public BaseChatModelSetup( ResourceDescriptor descriptor, BiFunction<String, ResourceType, Resource> getResource) { super(descriptor, getResource); - this.connection = descriptor.getArgument("connection"); + this.connectionName = descriptor.getArgument("connection"); this.model = descriptor.getArgument("model"); this.prompt = descriptor.getArgument("prompt"); - this.tools = descriptor.getArgument("tools"); + this.toolNames = descriptor.getArgument("tools"); + } + + /** + * Trigger construction for resource objects. + * + * <p>Currently, in cross-language invocation scenarios, constructing resource object within an + * async thread may encounter issues. We resolved this issue by moving the construction of the + * resources object out of the method to be async executed and invoking it in the main thread. + */ + @Override + public void open() throws Exception { + this.connection = + (BaseChatModelConnection) + this.getResource.apply( + this.connectionName, ResourceType.CHAT_MODEL_CONNECTION); + if (this.prompt != null && this.prompt instanceof String) { + this.prompt = this.getResource.apply((String) this.prompt, ResourceType.PROMPT); + } + if (this.toolNames != null) { + for (String name : this.toolNames) { + this.tools.add((Tool) this.getResource.apply(name, ResourceType.TOOL)); + } + } } public abstract Map<String, Object> getParameters(); @@ -56,18 +85,17 @@ public abstract class BaseChatModelSetup extends Resource { } public ChatMessage chat(List<ChatMessage> messages, Map<String, Object> parameters) { - BaseChatModelConnection connection = - (BaseChatModelConnection) - this.getResource.apply(this.connection, ResourceType.CHAT_MODEL_CONNECTION); - + Preconditions.checkNotNull( + connection, + "Connection is not initialized. Ensure open() is called before chat()."); // Pass metric group to connection for token usage tracking connection.setMetricGroup(getMetricGroup()); // Format input messages if set prompt. if (this.prompt != null) { - if (this.prompt instanceof String) { - this.prompt = this.getResource.apply((String) this.prompt, ResourceType.PROMPT); - } + Preconditions.checkState( + prompt instanceof Prompt, + "Prompt is not initialized. Ensure open() is called before chat()."); Prompt prompt = (Prompt) this.prompt; Map<String, String> arguments = new HashMap<>(); for (ChatMessage message : messages) { @@ -87,14 +115,6 @@ public abstract class BaseChatModelSetup extends Resource { messages = promptMessages; } - // Get tools can be used. - List<Tool> tools = new ArrayList<>(); - if (this.tools != null) { - for (String name : this.tools) { - tools.add((Tool) this.getResource.apply(name, ResourceType.TOOL)); - } - } - Map<String, Object> params = this.getParameters(); params.putAll(parameters); return connection.chat(messages, tools, params); @@ -106,8 +126,8 @@ public abstract class BaseChatModelSetup extends Resource { } @VisibleForTesting - public String getConnection() { - return connection; + public String getConnectionName() { + return this.connectionName; } @VisibleForTesting @@ -121,7 +141,7 @@ public abstract class BaseChatModelSetup extends Resource { } @VisibleForTesting - public List<String> getTools() { - return tools; + public List<String> getToolNames() { + return toolNames; } } diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java index 7a5cf436..ab985cd6 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java @@ -27,6 +27,7 @@ import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; import pemja.core.object.PyObject; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -57,6 +58,11 @@ public class PythonChatModelSetup extends BaseChatModelSetup implements PythonRe this.adapter = adapter; } + @Override + public void open() { + this.adapter.callMethod(chatModelSetup, "open", Collections.emptyMap()); + } + @Override public ChatMessage chat(List<ChatMessage> messages, Map<String, Object> parameters) { checkState( diff --git a/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetup.java index 205cfed3..a7b4cf95 100644 --- a/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetup.java +++ b/api/src/main/java/org/apache/flink/agents/api/embedding/model/BaseEmbeddingModelSetup.java @@ -21,6 +21,10 @@ package org.apache.flink.agents.api.embedding.model; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.util.Preconditions; + +import javax.annotation.Nullable; import java.util.Collections; import java.util.List; @@ -34,16 +38,32 @@ import java.util.function.BiFunction; * management and model configuration. */ public abstract class BaseEmbeddingModelSetup extends Resource { - protected final String connection; + protected final String connectionName; protected String model; + @Nullable protected BaseEmbeddingModelConnection connection; + public BaseEmbeddingModelSetup( ResourceDescriptor descriptor, BiFunction<String, ResourceType, Resource> getResource) { super(descriptor, getResource); - this.connection = descriptor.getArgument("connection"); + this.connectionName = descriptor.getArgument("connection"); this.model = descriptor.getArgument("model"); } + /** + * Trigger construction for resource objects. + * + * <p>Currently, in cross-language invocation scenarios, constructing resource object within an + * async thread may encounter issues. We resolved this issue by moving the construction of the + * resources object out of the method to be async executed and invoking it in the main thread. + */ + @Override + public void open() { + this.connection = + (BaseEmbeddingModelConnection) + getResource.apply(connectionName, ResourceType.EMBEDDING_MODEL_CONNECTION); + } + public abstract Map<String, Object> getParameters(); @Override @@ -56,9 +76,12 @@ public abstract class BaseEmbeddingModelSetup extends Resource { * * @return The embedding model connection instance */ + @VisibleForTesting public BaseEmbeddingModelConnection getConnection() { - return (BaseEmbeddingModelConnection) - getResource.apply(connection, ResourceType.EMBEDDING_MODEL_CONNECTION); + Preconditions.checkNotNull( + connection, + "Connection is not initialized. Ensure open() is called before embed()."); + return connection; } /** @@ -81,13 +104,9 @@ public abstract class BaseEmbeddingModelSetup extends Resource { } public float[] embed(String text, Map<String, Object> parameters) { - BaseEmbeddingModelConnection connection = getConnection(); - Map<String, Object> params = this.getParameters(); params.putAll(parameters); - - // params are propagated to the connection - return connection.embed(text, params); + return getConnection().embed(text, params); } /** @@ -102,12 +121,8 @@ public abstract class BaseEmbeddingModelSetup extends Resource { } public List<float[]> embed(List<String> texts, Map<String, Object> parameters) { - BaseEmbeddingModelConnection connection = getConnection(); - Map<String, Object> params = this.getParameters(); params.putAll(parameters); - - // params are propagated to the connection - return connection.embed(texts, params); + return getConnection().embed(texts, params); } } diff --git a/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetup.java index 33c4410f..d0eb4979 100644 --- a/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetup.java +++ b/api/src/main/java/org/apache/flink/agents/api/embedding/model/python/PythonEmbeddingModelSetup.java @@ -27,6 +27,7 @@ import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; import pemja.core.object.PyObject; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -64,6 +65,11 @@ public class PythonEmbeddingModelSetup extends BaseEmbeddingModelSetup this.adapter = adapter; } + @Override + public void open() { + this.adapter.callMethod(embeddingModelSetup, "open", Collections.emptyMap()); + } + @Override public float[] embed(String text, Map<String, Object> parameters) { checkState( diff --git a/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java b/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java index d386ca6d..de2e588b 100644 --- a/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java +++ b/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java @@ -65,6 +65,9 @@ public abstract class Resource { return metricGroup; } + /** Open the resource. */ + public void open() throws Exception {} + /** Close the resource. */ public void close() throws Exception {} } diff --git a/api/src/main/java/org/apache/flink/agents/api/vectorstores/BaseVectorStore.java b/api/src/main/java/org/apache/flink/agents/api/vectorstores/BaseVectorStore.java index 64cc7455..89b93872 100644 --- a/api/src/main/java/org/apache/flink/agents/api/vectorstores/BaseVectorStore.java +++ b/api/src/main/java/org/apache/flink/agents/api/vectorstores/BaseVectorStore.java @@ -22,6 +22,7 @@ import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.util.Preconditions; import javax.annotation.Nullable; @@ -38,12 +39,29 @@ import java.util.function.BiFunction; public abstract class BaseVectorStore extends Resource { /** Name of the embedding model resource to use. */ - protected final String embeddingModel; + protected final String embeddingModelName; + + @Nullable protected BaseEmbeddingModelSetup embeddingModel; public BaseVectorStore( ResourceDescriptor descriptor, BiFunction<String, ResourceType, Resource> getResource) { super(descriptor, getResource); - this.embeddingModel = descriptor.getArgument("embedding_model"); + this.embeddingModelName = descriptor.getArgument("embedding_model"); + } + + /** + * Trigger construction for resource objects. + * + * <p>Currently, in cross-language invocation scenarios, constructing resource object within an + * async thread may encounter issues. We resolved this issue by moving the construction of the + * resources object out of the method to be async executed and invoking it in the main thread. + */ + @Override + public void open() { + this.embeddingModel = + (BaseEmbeddingModelSetup) + this.getResource.apply( + this.embeddingModelName, ResourceType.EMBEDDING_MODEL); } @Override @@ -71,13 +89,9 @@ public abstract class BaseVectorStore extends Resource { public List<String> add( List<Document> documents, @Nullable String collection, Map<String, Object> extraArgs) throws IOException { - final BaseEmbeddingModelSetup embeddingModel = - (BaseEmbeddingModelSetup) - this.getResource.apply(this.embeddingModel, ResourceType.EMBEDDING_MODEL); - for (Document doc : documents) { if (doc.getEmbedding() == null) { - doc.setEmbedding(embeddingModel.embed(doc.getContent())); + doc.setEmbedding(getEmbeddingModel().embed(doc.getContent())); } } @@ -95,11 +109,7 @@ public abstract class BaseVectorStore extends Resource { * @return VectorStoreQueryResult containing the retrieved documents */ public VectorStoreQueryResult query(VectorStoreQuery query) { - final BaseEmbeddingModelSetup embeddingModel = - (BaseEmbeddingModelSetup) - this.getResource.apply(this.embeddingModel, ResourceType.EMBEDDING_MODEL); - - final float[] queryEmbedding = embeddingModel.embed(query.getQueryText()); + final float[] queryEmbedding = getEmbeddingModel().embed(query.getQueryText()); final Map<String, Object> storeKwargs = this.getStoreKwargs(); storeKwargs.putAll(query.getExtraArgs()); @@ -170,4 +180,11 @@ public abstract class BaseVectorStore extends Resource { protected abstract List<String> addEmbedding( List<Document> documents, @Nullable String collection, Map<String, Object> extraArgs) throws IOException; + + private BaseEmbeddingModelSetup getEmbeddingModel() { + Preconditions.checkNotNull( + embeddingModel, + "Embedding model is not initialized. Ensure open() is called before add()."); + return embeddingModel; + } } diff --git a/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonVectorStore.java b/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonVectorStore.java index c570ad59..6718cf7a 100644 --- a/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonVectorStore.java +++ b/api/src/main/java/org/apache/flink/agents/api/vectorstores/python/PythonVectorStore.java @@ -32,6 +32,7 @@ import pemja.core.object.PyObject; import javax.annotation.Nullable; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -69,6 +70,11 @@ public class PythonVectorStore extends BaseVectorStore implements PythonResource this.adapter = adapter; } + @Override + public void open() { + adapter.callMethod(vectorStore, "open", Collections.emptyMap()); + } + @Override @SuppressWarnings("unchecked") public List<String> add( diff --git a/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelSetup.java b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelSetup.java index 90dc1934..e3e6d6dd 100644 --- a/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelSetup.java +++ b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelSetup.java @@ -72,9 +72,4 @@ public class BedrockEmbeddingModelSetup extends BaseEmbeddingModelSetup { } return params; } - - @Override - public BedrockEmbeddingModelConnection getConnection() { - return (BedrockEmbeddingModelConnection) super.getConnection(); - } } diff --git a/integrations/embedding-models/ollama/pom.xml b/integrations/embedding-models/ollama/pom.xml index 863d4ac6..285d33e8 100644 --- a/integrations/embedding-models/ollama/pom.xml +++ b/integrations/embedding-models/ollama/pom.xml @@ -43,6 +43,12 @@ under the License. <artifactId>ollama4j</artifactId> <version>${ollama4j.version}</version> </dependency> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-core</artifactId> + <version>${flink.version}</version> + <scope>test</scope> + </dependency> </dependencies> </project> diff --git a/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelSetup.java b/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelSetup.java index 39f38e0b..3880dfb2 100644 --- a/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelSetup.java +++ b/integrations/embedding-models/ollama/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelSetup.java @@ -46,9 +46,4 @@ public class OllamaEmbeddingModelSetup extends BaseEmbeddingModelSetup { return parameters; } - - @Override - public OllamaEmbeddingModelConnection getConnection() { - return (OllamaEmbeddingModelConnection) super.getConnection(); - } } diff --git a/integrations/embedding-models/ollama/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnectionTest.java b/integrations/embedding-models/ollama/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnectionTest.java index 47c9dc6c..2b908e20 100644 --- a/integrations/embedding-models/ollama/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnectionTest.java +++ b/integrations/embedding-models/ollama/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/ollama/OllamaEmbeddingModelConnectionTest.java @@ -19,7 +19,6 @@ package org.apache.flink.agents.integrations.embeddingmodels.ollama; import org.apache.flink.agents.api.annotation.EmbeddingModelConnection; -import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelConnection; import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceDescriptor; @@ -79,11 +78,12 @@ class OllamaEmbeddingModelConnectionTest { } @Override - public BaseEmbeddingModelConnection getConnection() { - return new OllamaEmbeddingModelConnection(buildDescriptor(), dummyResource); + public void open() { + connection = new OllamaEmbeddingModelConnection(buildDescriptor(), dummyResource); } } DummySetup setup = new DummySetup(buildDescriptor(), dummyResource); + setup.open(); assertNotNull(setup.getConnection()); } } diff --git a/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java b/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java index a6fd2a0d..1f06a0f8 100644 --- a/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java +++ b/integrations/vector-stores/opensearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/opensearch/OpenSearchVectorStore.java @@ -23,7 +23,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; import org.apache.flink.agents.api.RetryExecutor; -import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; @@ -200,9 +199,6 @@ public class OpenSearchVectorStore extends BaseVectorStore public List<String> add( List<Document> documents, @Nullable String collection, Map<String, Object> extraArgs) throws IOException { - BaseEmbeddingModelSetup emb = - (BaseEmbeddingModelSetup) - this.getResource.apply(this.embeddingModel, ResourceType.EMBEDDING_MODEL); List<String> texts = new ArrayList<>(); List<Integer> needsEmbedding = new ArrayList<>(); for (int i = 0; i < documents.size(); i++) { @@ -212,7 +208,7 @@ public class OpenSearchVectorStore extends BaseVectorStore } } if (!texts.isEmpty()) { - List<float[]> embeddings = emb.embed(texts); + List<float[]> embeddings = this.embeddingModel.embed(texts); for (int j = 0; j < needsEmbedding.size(); j++) { documents.get(needsEmbedding.get(j)).setEmbedding(embeddings.get(j)); } diff --git a/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java b/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java index 833010c9..f18535d9 100644 --- a/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java +++ b/integrations/vector-stores/s3vectors/src/main/java/org/apache/flink/agents/integrations/vectorstores/s3vectors/S3VectorsVectorStore.java @@ -19,7 +19,6 @@ package org.apache.flink.agents.integrations.vectorstores.s3vectors; import org.apache.flink.agents.api.RetryExecutor; -import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; @@ -133,9 +132,6 @@ public class S3VectorsVectorStore extends BaseVectorStore { public List<String> add( List<Document> documents, @Nullable String collection, Map<String, Object> extraArgs) throws IOException { - BaseEmbeddingModelSetup emb = - (BaseEmbeddingModelSetup) - this.getResource.apply(this.embeddingModel, ResourceType.EMBEDDING_MODEL); List<String> texts = new ArrayList<>(); List<Integer> needsEmbedding = new ArrayList<>(); for (int i = 0; i < documents.size(); i++) { @@ -145,7 +141,7 @@ public class S3VectorsVectorStore extends BaseVectorStore { } } if (!texts.isEmpty()) { - List<float[]> embeddings = emb.embed(texts); + List<float[]> embeddings = this.embeddingModel.embed(texts); for (int j = 0; j < needsEmbedding.size(); j++) { documents.get(needsEmbedding.get(j)).setEmbedding(embeddings.get(j)); } diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java index c3dc74e1..1abd392d 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java @@ -48,6 +48,7 @@ import javax.annotation.Nullable; import java.util.*; import static org.apache.flink.agents.api.agents.Agent.STRUCTURED_OUTPUT; +import static org.apache.flink.agents.plan.actions.Utils.supportAsync; /** Built-in action for processing chat request and tool call result. */ public class ChatModelAction { @@ -247,9 +248,10 @@ public class ChatModelAction { (BaseChatModelSetup) ctx.getResource(model, ResourceType.CHAT_MODEL); boolean chatAsync = ctx.getConfig().get(AgentExecutionOptions.CHAT_ASYNC); - // TODO: python chat model doesn't support async execution yet, see - // https://github.com/apache/flink-agents/issues/448 for details. - chatAsync = chatAsync && !(chatModel instanceof PythonChatModelSetup); + + if ((chatModel instanceof PythonChatModelSetup) && !supportAsync()) { + chatAsync = false; + } Agent.ErrorHandlingStrategy strategy = ctx.getConfig().get(AgentExecutionOptions.ERROR_HANDLING_STRATEGY); @@ -343,7 +345,8 @@ public class ChatModelAction { int totalRetryCount = retryStats.get(TOTAL_RETRY_COUNT).intValue(); int totalRetryWaitSec = retryStats.get(TOTAL_RETRY_WAIT_SEC).intValue(); - recordRetryMetrics(ctx, chatModel.getConnection(), totalRetryCount, totalRetryWaitSec); + recordRetryMetrics( + ctx, chatModel.getConnectionName(), totalRetryCount, totalRetryWaitSec); ctx.sendEvent( new ChatResponseEvent( diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ContextRetrievalAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ContextRetrievalAction.java index 420ceafa..c66cda31 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/actions/ContextRetrievalAction.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ContextRetrievalAction.java @@ -33,6 +33,8 @@ import org.apache.flink.agents.plan.JavaFunction; import java.util.List; +import static org.apache.flink.agents.plan.actions.Utils.supportAsync; + /** Built-in action for processing context retrieval requests. */ public class ContextRetrievalAction { @@ -60,9 +62,9 @@ public class ContextRetrievalAction { contextRetrievalRequestEvent.getVectorStore(), ResourceType.VECTOR_STORE); - // TODO: python vector store doesn't support async execution yet, see - // https://github.com/apache/flink-agents/issues/448 for details. - ragAsync = ragAsync && !(vectorStore instanceof PythonVectorStore); + if ((vectorStore instanceof PythonVectorStore) && !supportAsync()) { + ragAsync = false; + } final VectorStoreQuery vectorStoreQuery = new VectorStoreQuery( diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/Utils.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/Utils.java new file mode 100644 index 00000000..8789b292 --- /dev/null +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/Utils.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.plan.actions; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Properties; +import java.util.StringTokenizer; + +public final class Utils { + private static final Logger LOG = LoggerFactory.getLogger(Utils.class); + private static final String DEFAULT_VALUE = "<unknown>"; + + static final Versions INSTANCE = new Versions(); + + private Utils() {} + + /** + * Check whether the current Flink version supports the async execution for cross-language + * resource. + * + * <p>The async execution for java resource is supported only on flink with the pemja 0.6.2 + * dependency. See <a href="https://github.com/apache/flink-agents/pull/571">flink-agents</a> + * for details. + */ + public static boolean supportAsync() { + String version = INSTANCE.projectVersion; + + if (DEFAULT_VALUE.equals(version)) { + return false; + } + + try { + StringTokenizer st = new StringTokenizer(version, "."); + int major = Integer.parseInt(st.nextToken()); + int minor = Integer.parseInt(st.nextToken()); + int micro = Integer.parseInt(st.nextToken()); + + if ((major == 1 && (minor < 20 || (minor == 20 && micro <= 3))) + || (major == 2 && minor == 0 && micro <= 1) + || (major == 2 && minor == 1 && micro <= 1) + || (major == 2 && minor == 2 && micro <= 0)) { + LOG.debug( + "Flink {} doesn't support async execution for java resource, will fallback to sync execution.", + version); + return false; + } + return true; + } catch (Exception e) { + LOG.debug("Can't decide flink version, will fallback to sync execution.", e); + return false; + } + } + + private static class Versions { + private String projectVersion = DEFAULT_VALUE; + + public Versions() { + ClassLoader classLoader = Utils.class.getClassLoader(); + try (InputStream propFile = + classLoader.getResourceAsStream(".flink-runtime.version.properties")) { + if (propFile != null) { + Properties properties = new Properties(); + properties.load(propFile); + this.projectVersion = getProperty(properties, "project.version", DEFAULT_VALUE); + } + } catch (IOException ioe) { + LOG.info( + "Cannot determine code revision: Unable to read version property file.: {}", + ioe.getMessage()); + } + } + + private String getProperty(Properties properties, String key, String DEFAULT_VALUE) { + String value = properties.getProperty(key); + return value != null && value.charAt(0) != '$' ? value : DEFAULT_VALUE; + } + } +} diff --git a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareChatModelTest.java b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareChatModelTest.java index 6f1ecb60..c7ab5a93 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareChatModelTest.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareChatModelTest.java @@ -58,6 +58,11 @@ class AgentPlanDeclareChatModelTest { super(descriptor, getResource); } + @Override + public void open() { + // do nothing + } + @Override public Map<String, Object> getParameters() { return Map.of(); @@ -179,9 +184,10 @@ class AgentPlanDeclareChatModelTest { BaseChatModelSetup expectedChatModel = (BaseChatModelSetup) resolveResource("testChatModel", ResourceType.CHAT_MODEL); Assertions.assertEquals(expectedChatModel.getClass(), actualChatModel.getClass()); - Assertions.assertEquals(expectedChatModel.getConnection(), actualChatModel.getConnection()); + Assertions.assertEquals( + expectedChatModel.getConnectionName(), actualChatModel.getConnectionName()); Assertions.assertEquals(expectedChatModel.getModel(), actualChatModel.getModel()); Assertions.assertEquals(expectedChatModel.getPrompt(), actualChatModel.getPrompt()); - Assertions.assertEquals(expectedChatModel.getTools(), actualChatModel.getTools()); + Assertions.assertEquals(expectedChatModel.getToolNames(), actualChatModel.getToolNames()); } } diff --git a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionRetryTest.java b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionRetryTest.java index f9101e48..62680316 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionRetryTest.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionRetryTest.java @@ -75,7 +75,7 @@ class ChatModelActionRetryTest { sensoryMemory = createStatefulMemoryObject(); // Wire up ChatModel - when(mockChatModel.getConnection()).thenReturn("test-connection"); + when(mockChatModel.getConnectionName()).thenReturn("test-connection"); // Wire up RunnerContext when(mockCtx.getResource(anyString(), eq(ResourceType.CHAT_MODEL))) @@ -158,7 +158,7 @@ class ChatModelActionRetryTest { assertThat(elapsed).isGreaterThanOrEqualTo(1000L); // Verify metrics recorded under connection name - verify(mockActionMetricGroup).getSubGroup(mockChatModel.getConnection()); + verify(mockActionMetricGroup).getSubGroup(mockChatModel.getConnectionName()); verify(mockRetryCountCounter).inc(1); verify(mockRetryWaitSecCounter).inc(1); } diff --git a/python/flink_agents/api/chat_models/chat_model.py b/python/flink_agents/api/chat_models/chat_model.py index 2f63566c..b3d8c5c8 100644 --- a/python/flink_agents/api/chat_models/chat_model.py +++ b/python/flink_agents/api/chat_models/chat_model.py @@ -17,7 +17,7 @@ ################################################################################# import re from abc import ABC, abstractmethod -from typing import Any, ClassVar, Dict, List, Sequence, Tuple +from typing import Any, ClassVar, Dict, List, Sequence, Tuple, cast from pydantic import Field from typing_extensions import override @@ -49,12 +49,16 @@ class BaseChatModelConnection(Resource, ABC): """Return resource type of class.""" return ResourceType.CHAT_MODEL_CONNECTION - DEFAULT_REASONING_PATTERNS: ClassVar[Tuple[re.Pattern[str],...]] = ( + DEFAULT_REASONING_PATTERNS: ClassVar[Tuple[re.Pattern[str], ...]] = ( re.compile(r"<think>(.*?)</think>", re.DOTALL | re.IGNORECASE), re.compile(r"<analysis>(.*?)</analysis>", re.DOTALL | re.IGNORECASE), re.compile(r"<reasoning>(.*?)</reasoning>", re.DOTALL | re.IGNORECASE), - re.compile(r"```(?:think|reasoning|thought)\s*\n(.*?)\n```", re.DOTALL | re.IGNORECASE), - re.compile(r"(?:^|\n)Reasoning:\s*(.*?)(?:\n{2,}|$)", re.DOTALL | re.IGNORECASE), + re.compile( + r"```(?:think|reasoning|thought)\s*\n(.*?)\n```", re.DOTALL | re.IGNORECASE + ), + re.compile( + r"(?:^|\n)Reasoning:\s*(.*?)(?:\n{2,}|$)", re.DOTALL | re.IGNORECASE + ), ) @staticmethod @@ -133,9 +137,11 @@ class BaseChatModelSetup(Resource): different chat configurations. """ - connection: str = Field(description="Name of the referenced connection.") + connection: str | BaseChatModelConnection = Field( + description="The referenced connection." + ) prompt: Prompt | str | None = None - tools: List[str] | None = None + tools: List[str] | List[Tool] = Field(default_factory=list) @property @abstractmethod @@ -148,6 +154,24 @@ class BaseChatModelSetup(Resource): """Return resource type of class.""" return ResourceType.CHAT_MODEL + @override + def open(self) -> None: + self.connection = cast( + "BaseChatModelConnection", + self.get_resource(self.connection, ResourceType.CHAT_MODEL_CONNECTION), + ) + if self.prompt is not None: + if isinstance(self.prompt, str): + # Get prompt resource if it's a string + self.prompt = cast( + "Prompt", self.get_resource(self.prompt, ResourceType.PROMPT) + ) + if self.tools is not None: + self.tools = [ + cast("Tool", self.get_resource(tool_name, ResourceType.TOOL)) + for tool_name in self.tools + ] + def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: """Execute chat conversation. @@ -168,19 +192,8 @@ class BaseChatModelSetup(Resource): ChatMessage Model response message """ - # Get model connection - connection = self.get_resource( - self.connection, ResourceType.CHAT_MODEL_CONNECTION - ) - # Apply prompt template if self.prompt is not None: - if isinstance(self.prompt, str): - # Get prompt resource if it's a string - prompt = self.get_resource(self.prompt, ResourceType.PROMPT) - else: - prompt = self.prompt - input_variable = {} # fill the prompt template @@ -188,25 +201,20 @@ class BaseChatModelSetup(Resource): # Convert Any values to str to match format_messages signature str_extra_args = {k: str(v) for k, v in msg.extra_args.items()} input_variable.update(str_extra_args) - prompt_messages = prompt.format_messages(**input_variable) + prompt_messages = self._get_prompt().format_messages(**input_variable) # append meaningful messages for msg in messages: - if (msg.content is not None and msg.content != "") or msg.role == MessageRole.ASSISTANT: + if ( + msg.content is not None and msg.content != "" + ) or msg.role == MessageRole.ASSISTANT: prompt_messages.append(msg) messages = prompt_messages - # Bind tools - tools = None - if self.tools is not None: - tools = [ - self.get_resource(tool_name, ResourceType.TOOL) - for tool_name in self.tools - ] # Call chat model connection to execute chat merged_kwargs = self.model_kwargs.copy() merged_kwargs.update(kwargs) - return connection.chat(messages, tools=tools, **merged_kwargs) + return self._get_connection().chat(messages, tools=self._get_tools(), **merged_kwargs) def _record_token_metrics( self, model_name: str, prompt_tokens: int, completion_tokens: int @@ -229,3 +237,23 @@ class BaseChatModelSetup(Resource): model_group = metric_group.get_sub_group(model_name) model_group.get_counter("promptTokens").inc(prompt_tokens) model_group.get_counter("completionTokens").inc(completion_tokens) + + def _get_connection(self) -> BaseChatModelConnection: + if not isinstance(self.connection, BaseChatModelConnection): + err_msg = f"Expect BaseChatModelConnection, but is {self.connection.__class__.__name__}" + raise TypeError(err_msg) + return self.connection + + def _get_prompt(self) -> Prompt: + if not isinstance(self.prompt, Prompt): + err_msg = f"Expect Prompt, but is {self.prompt.__class__.__name__}" + raise TypeError(err_msg) + return self.prompt + + def _get_tools(self) -> List[Tool]: + for tool in self.tools: + if not isinstance(tool, Tool): + err_msg = f"Expect Tool, but is {tool.__class__.__name__}" + raise TypeError(err_msg) + return self.tools + diff --git a/python/flink_agents/api/embedding_models/embedding_model.py b/python/flink_agents/api/embedding_models/embedding_model.py index be06913f..b7088982 100644 --- a/python/flink_agents/api/embedding_models/embedding_model.py +++ b/python/flink_agents/api/embedding_models/embedding_model.py @@ -16,7 +16,7 @@ # limitations under the License. ################################################################################# from abc import ABC, abstractmethod -from typing import Any, Dict, Sequence +from typing import Any, Dict, Sequence, cast from pydantic import Field from typing_extensions import override @@ -45,7 +45,9 @@ class BaseEmbeddingModelConnection(Resource, ABC): return ResourceType.EMBEDDING_MODEL_CONNECTION @abstractmethod - def embed(self, text: str | Sequence[str], **kwargs: Any) -> list[float] | list[list[float]]: + def embed( + self, text: str | Sequence[str], **kwargs: Any + ) -> list[float] | list[list[float]]: """Generate embedding vector for a single text input. Converts the input text into a high-dimensional vector representation @@ -71,7 +73,9 @@ class BaseEmbeddingModelSetup(Resource, ABC): Provides the basic embedding interface for generating embeddings from text inputs. """ - connection: str = Field(description="Name of the referenced connection.") + connection: str | BaseEmbeddingModelConnection = Field( + description="The referenced connection." + ) model: str = Field(description="Name of the embedding model to use.") @classmethod @@ -85,7 +89,22 @@ class BaseEmbeddingModelSetup(Resource, ABC): def model_kwargs(self) -> Dict[str, Any]: """Return embedding model settings.""" - def embed(self, text: str | Sequence[str], **kwargs: Any) -> list[float] | list[list[float]]: + @override + def open(self) -> None: + self.connection = cast( + "BaseEmbeddingModelConnection", + self.get_resource(self.connection, ResourceType.EMBEDDING_MODEL_CONNECTION), + ) + + def _get_connection(self) -> BaseEmbeddingModelConnection: + if not isinstance(self.connection, BaseEmbeddingModelConnection): + err_msg = f"Expect BaseEmbeddingModelConnection, but is {self.connection.__class__.__name__}" + raise TypeError(err_msg) + return self.connection + + def embed( + self, text: str | Sequence[str], **kwargs: Any + ) -> list[float] | list[list[float]]: """Generate embedding vector for a single text query. Converts the input text into a high-dimensional vector representation @@ -99,9 +118,6 @@ class BaseEmbeddingModelSetup(Resource, ABC): A list of floating-point numbers representing the embedding vector. The dimension of the vector depends on the specific embedding model used. """ - connection = self.get_resource( - self.connection, ResourceType.EMBEDDING_MODEL_CONNECTION - ) merged_kwargs = self.model_kwargs.copy() merged_kwargs.update(kwargs) - return connection.embed(text, **merged_kwargs) + return self._get_connection().embed(text, **merged_kwargs) diff --git a/python/flink_agents/api/resource.py b/python/flink_agents/api/resource.py index a5b1dd86..57d50667 100644 --- a/python/flink_agents/api/resource.py +++ b/python/flink_agents/api/resource.py @@ -90,6 +90,9 @@ class Resource(BaseModel, ABC): """ return self._metric_group + def open(self) -> None: + """Open the resource.""" + def close(self) -> None: """Close the resource.""" @@ -120,14 +123,14 @@ class ResourceDescriptor(BaseModel): arguments: Dict[str, Any] def __init__( - self, - /, - *, - clazz: str | None = None, - target_module: str | None = None, - target_clazz: str | None = None, - arguments: Dict[str, Any] | None = None, - **kwargs: Any, + self, + /, + *, + clazz: str | None = None, + target_module: str | None = None, + target_clazz: str | None = None, + arguments: Dict[str, Any] | None = None, + **kwargs: Any, ) -> None: """Initialize ResourceDescriptor. @@ -182,9 +185,9 @@ class ResourceDescriptor(BaseModel): if not isinstance(other, ResourceDescriptor): return False return ( - self.target_module == other.target_module - and self.target_clazz == other.target_clazz - and self.arguments == other.arguments + self.target_module == other.target_module + and self.target_clazz == other.target_clazz + and self.arguments == other.arguments ) def __hash__(self) -> int: @@ -253,8 +256,12 @@ class ResourceName: TONGYI_SETUP = "flink_agents.integrations.chat_models.tongyi_chat_model.TongyiChatModelSetup" # Java Wrapper - JAVA_WRAPPER_CONNECTION = "flink_agents.api.chat_models.java_chat_model.JavaChatModelConnection" - JAVA_WRAPPER_SETUP = "flink_agents.api.chat_models.java_chat_model.JavaChatModelSetup" + JAVA_WRAPPER_CONNECTION = ( + "flink_agents.api.chat_models.java_chat_model.JavaChatModelConnection" + ) + JAVA_WRAPPER_SETUP = ( + "flink_agents.api.chat_models.java_chat_model.JavaChatModelSetup" + ) class Java: """Java implementations of ChatModel.""" @@ -307,7 +314,9 @@ class ResourceName: CHROMA_VECTOR_STORE = "flink_agents.integrations.vector_stores.chroma.chroma_vector_store.ChromaVectorStore" # Java Wrapper - JAVA_WRAPPER_VECTOR_STORE = "flink_agents.api.vector_stores.java_vector_store.JavaVectorStore" + JAVA_WRAPPER_VECTOR_STORE = ( + "flink_agents.api.vector_stores.java_vector_store.JavaVectorStore" + ) JAVA_WRAPPER_COLLECTION_MANAGEABLE_VECTOR_STORE = "flink_agents.api.vector_stores.java_vector_store.JavaCollectionManageableVectorStore" class Java: diff --git a/python/flink_agents/api/vector_stores/vector_store.py b/python/flink_agents/api/vector_stores/vector_store.py index edc778ca..a6fcc225 100644 --- a/python/flink_agents/api/vector_stores/vector_store.py +++ b/python/flink_agents/api/vector_stores/vector_store.py @@ -17,11 +17,12 @@ ################################################################################ from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Dict, List +from typing import Any, Dict, List, cast from pydantic import BaseModel, Field from typing_extensions import override +from flink_agents.api.embedding_models.embedding_model import BaseEmbeddingModelSetup from flink_agents.api.resource import Resource, ResourceType @@ -141,8 +142,8 @@ class BaseVectorStore(Resource, ABC): embedding generation internally. """ - embedding_model: str = Field( - description="Name of the embedding model resource to use." + embedding_model: str | BaseEmbeddingModelSetup = Field( + description="The embedding model to use." ) @classmethod @@ -160,6 +161,19 @@ class BaseVectorStore(Resource, ABC): when performing vector search operations. """ + @override + def open(self) -> None: + self.embedding_model = cast( + "BaseEmbeddingModelSetup", + self.get_resource(self.embedding_model, ResourceType.EMBEDDING_MODEL), + ) + + def _get_embedding_model(self) -> BaseEmbeddingModelSetup: + if not isinstance(self.embedding_model, BaseEmbeddingModelSetup): + err_msg = f"Expect BaseEmbeddingModelSetup, but is {self.embedding_model.__class__.__name__}" + raise TypeError(err_msg) + return self.embedding_model + def add( self, documents: Document | List[Document], @@ -181,16 +195,10 @@ class BaseVectorStore(Resource, ABC): """ # Normalize to list documents = _maybe_cast_to_list(documents) - - # Generate embeddings for all documents - embedding_model = self.get_resource( - self.embedding_model, ResourceType.EMBEDDING_MODEL - ) - # Generate embeddings for each document for doc in documents: if doc.embedding is None: - doc.embedding = embedding_model.embed(doc.content) + doc.embedding = self._get_embedding_model().embed(doc.content) # Merge setup kwargs with add-specific args merged_kwargs = self.store_kwargs.copy() @@ -213,10 +221,7 @@ class BaseVectorStore(Resource, ABC): VectorStoreQueryResult containing the retrieved documents """ # Generate embedding from the query text - embedding_model = self.get_resource( - self.embedding_model, ResourceType.EMBEDDING_MODEL - ) - query_embedding = embedding_model.embed(query.query_text) + query_embedding = self._get_embedding_model().embed(query.query_text) # Merge setup kwargs with query-specific args merged_kwargs = self.store_kwargs.copy() diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py index 5ea09478..25f4555f 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py @@ -37,6 +37,9 @@ from flink_agents.api.tools.tool import ToolType class SlowMockChatModel(BaseChatModelSetup): """Mock ChatModel with slow connection.""" + def open(self) -> None: + """Do nothing.""" + @property def model_kwargs(self) -> Dict[str, Any]: # noqa: D102 return {} diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/long_term_memory_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/long_term_memory_test.py index 46ed4427..3dd49cf5 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/long_term_memory_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/long_term_memory_test.py @@ -23,6 +23,7 @@ from importlib import resources from pathlib import Path from typing import Any, List +import pytest from pydantic import BaseModel from pyflink.common import Encoder, Types, WatermarkStrategy from pyflink.datastream import ( @@ -212,6 +213,11 @@ class LongTermMemoryAgent(Agent): ctx.send_event(OutputEvent(output=record)) [email protected]( + "During compaction, VectorStoreLongTermMemory need getting chat model. This will cause exception because" + "flink-agent doesn't allow get resource in async thread. We will deprecate VectorStoreLongTermMemory in 0.3.0," + "so we will not fix this issue for now." +) def test_long_term_memory_async_execution_in_action(tmp_path: Path) -> None: # noqa: D103 env = StreamExecutionEnvironment.get_execution_environment() env.set_runtime_mode(RuntimeExecutionMode.STREAMING) diff --git a/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py b/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py index 321b499c..5fb080a2 100644 --- a/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py +++ b/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py @@ -111,6 +111,9 @@ def test_ollama_chat_with_tools() -> None: # noqa :D103 tools=["add"], get_resource=get_resource, ) + + llm.open() + response = llm.chat( [ ChatMessage( @@ -175,6 +178,8 @@ def test_ollama_chat_with_extract_reasoning() -> None: get_resource=get_resource, ) + llm.open() + # Replace the real client with our mock client connection._OllamaChatModelConnection__client = mock_client diff --git a/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py b/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py index 3278f1ed..39fb1578 100644 --- a/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py +++ b/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py @@ -88,6 +88,8 @@ def test_tongyi_chat_with_tools() -> None: get_resource=get_resource, ) + llm.open() + response = llm.chat( [ ChatMessage( @@ -151,6 +153,8 @@ def test_tongyi_chat_with_extract_reasoning(monkeypatch: pytest.MonkeyPatch) -> get_resource=get_resource, ) + llm.open() + response = llm.chat( [ChatMessage(role=MessageRole.USER, content="What's the meaning of life?")] ) diff --git a/python/flink_agents/integrations/vector_stores/chroma/tests/test_chroma_vector_store.py b/python/flink_agents/integrations/vector_stores/chroma/tests/test_chroma_vector_store.py index 7fd3b585..438d02ea 100644 --- a/python/flink_agents/integrations/vector_stores/chroma/tests/test_chroma_vector_store.py +++ b/python/flink_agents/integrations/vector_stores/chroma/tests/test_chroma_vector_store.py @@ -28,6 +28,7 @@ try: except ImportError: chromadb_available = False +from flink_agents.api.embedding_models.embedding_model import BaseEmbeddingModelSetup from flink_agents.api.resource import Resource, ResourceType from flink_agents.api.vector_stores.vector_store import ( Document, @@ -42,12 +43,13 @@ tenant = os.environ.get("TEST_TENANT") database = os.environ.get("TEST_DATABASE") -class MockEmbeddingModel(Resource): # noqa: D101 +class MockEmbeddingModel(BaseEmbeddingModelSetup): # noqa: D101 name: str + connection: str = "mock" + model: str = "mock" - @classmethod - def resource_type(cls) -> ResourceType: # noqa: D102 - return ResourceType.EMBEDDING_MODEL + def open(self) -> None: # noqa: D102 + pass @property def model_kwargs(self) -> Dict[str, Any]: # noqa: D102 @@ -100,6 +102,8 @@ def test_local_chroma_vector_store() -> None: get_resource=get_resource, ) + vector_store.open() + _populate_test_data(vector_store, "test_collection") query = VectorStoreQuery(query_text="What is Flink Agent?", limit=1) @@ -129,6 +133,8 @@ def test_collection_management() -> None: get_resource=get_resource, ) + vector_store.open() + vector_store.get_or_create_collection( name="collection_management", metadata={"key1": "value1", "key2": "value2"} ) @@ -165,6 +171,8 @@ def test_document_management() -> None: get_resource=get_resource, ) + vector_store.open() + vector_store.get_or_create_collection( name="document_management", metadata={"key1": "value1", "key2": "value2"} ) @@ -221,6 +229,8 @@ def test_cloud_chroma_vector_store() -> None: get_resource=get_resource, ) + vector_store.open() + _populate_test_data(vector_store) query = VectorStoreQuery(query_text="What is Flink Agent?", limit=1) diff --git a/python/flink_agents/plan/actions/chat_model_action.py b/python/flink_agents/plan/actions/chat_model_action.py index 210289db..0a0c780e 100644 --- a/python/flink_agents/plan/actions/chat_model_action.py +++ b/python/flink_agents/plan/actions/chat_model_action.py @@ -41,6 +41,7 @@ from flink_agents.api.memory_object import MemoryObject from flink_agents.api.resource import ResourceType from flink_agents.api.runner_context import RunnerContext from flink_agents.plan.actions.action import Action +from flink_agents.plan.actions.utils import support_async from flink_agents.plan.function import PythonFunction if TYPE_CHECKING: @@ -224,11 +225,13 @@ async def chat( ) chat_async = ctx.config.get(AgentExecutionOptions.CHAT_ASYNC) - # java chat model doesn't support async execution, - # see https://github.com/apache/flink-agents/issues/448 for details. - chat_async = chat_async and not isinstance(chat_model, JavaChatModelSetup) - error_handling_strategy = ctx.config.get(AgentExecutionOptions.ERROR_HANDLING_STRATEGY) + if isinstance(chat_model, JavaChatModelSetup) and not support_async(): + chat_async = False + + error_handling_strategy = ctx.config.get( + AgentExecutionOptions.ERROR_HANDLING_STRATEGY + ) num_retries = 0 retry_wait_interval_sec = 0 if error_handling_strategy == ErrorHandlingStrategy.RETRY: @@ -251,8 +254,16 @@ async def chat( else: response = ctx.durable_execute(chat_model.chat, messages) - if response.extra_args.get("model_name") and response.extra_args.get("promptTokens") and response.extra_args.get("completionTokens"): - chat_model._record_token_metrics(response.extra_args["model_name"], response.extra_args["promptTokens"], response.extra_args["completionTokens"]) + if ( + response.extra_args.get("model_name") + and response.extra_args.get("promptTokens") + and response.extra_args.get("completionTokens") + ): + chat_model._record_token_metrics( + response.extra_args["model_name"], + response.extra_args["promptTokens"], + response.extra_args["completionTokens"], + ) if output_schema is not None and len(response.tool_calls) == 0: response = _generate_structured_output(response, output_schema) break diff --git a/python/flink_agents/plan/actions/context_retrieval_action.py b/python/flink_agents/plan/actions/context_retrieval_action.py index d484e782..e56134dc 100644 --- a/python/flink_agents/plan/actions/context_retrieval_action.py +++ b/python/flink_agents/plan/actions/context_retrieval_action.py @@ -28,10 +28,12 @@ from flink_agents.api.runner_context import RunnerContext from flink_agents.api.vector_stores.java_vector_store import JavaVectorStore from flink_agents.api.vector_stores.vector_store import VectorStoreQuery from flink_agents.plan.actions.action import Action +from flink_agents.plan.actions.utils import support_async from flink_agents.plan.function import PythonFunction _logger = logging.getLogger(__name__) + async def process_context_retrieval_request(event: Event, ctx: RunnerContext) -> None: """Built-in action for processing context retrieval requests.""" if isinstance(event, ContextRetrievalRequestEvent): @@ -40,9 +42,10 @@ async def process_context_retrieval_request(event: Event, ctx: RunnerContext) -> query = VectorStoreQuery(query_text=event.query, limit=event.max_results) rag_async = ctx.config.get(AgentExecutionOptions.RAG_ASYNC) - # java vector store doesn't support async execution - # see https://github.com/apache/flink-agents/issues/448 for details. - rag_async = rag_async and not isinstance(vector_store, JavaVectorStore) + + if isinstance(vector_store, JavaVectorStore) and not support_async(): + rag_async = False + if rag_async: # To avoid https://github.com/alibaba/pemja/issues/88, # we log a message here. diff --git a/python/flink_agents/plan/actions/utils.py b/python/flink_agents/plan/actions/utils.py new file mode 100644 index 00000000..f8326044 --- /dev/null +++ b/python/flink_agents/plan/actions/utils.py @@ -0,0 +1,58 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import functools +import logging +from importlib.metadata import version +from typing import List, Tuple + +from packaging import version as pkg_version + +# For flink versions depend on pemja without fix +# pr: https://github.com/alibaba/pemja/pull/95. +# The async execution for cross language resource +# is not supported. +UNSUPPORTED_RANGES: List[Tuple[str, str]] = [ + ("1.0.0", "1.20.3"), + ("2.0.0", "2.0.1"), + ("2.1.0", "2.1.1"), + ("2.2.0", "2.2.0"), +] + + [email protected]_cache(maxsize=1) +def support_async() -> bool: + """Check whether the current Flink version supports the async execution for + cross-language resource. + + The async execution for java resource is supported only on flink with + the pemja 0.6.2 dependency. See https://github.com/apache/flink-agents/pull/571 + for details. + """ + try: + current = pkg_version.parse(version("apache-flink")) + + for min_ver, max_ver in UNSUPPORTED_RANGES: + if pkg_version.parse(min_ver) <= current <= pkg_version.parse(max_ver): + logging.debug( + f"Flink {current} doesn't support async execution for java resource, will fallback to sync execution." + ) + return False + except Exception: + return False + else: + return True diff --git a/python/flink_agents/plan/resource_provider.py b/python/flink_agents/plan/resource_provider.py index 4894b998..714a01c2 100644 --- a/python/flink_agents/plan/resource_provider.py +++ b/python/flink_agents/plan/resource_provider.py @@ -194,7 +194,6 @@ class JavaResourceProvider(ResourceProvider): if not self._j_resource_adapter: err_msg = "java resource adapter is not set" raise RuntimeError(err_msg) - j_resource = self._j_resource_adapter.getResource(self.name, self.type.value) class_path = JAVA_RESOURCE_MAPPING.get(self.type) diff --git a/python/flink_agents/plan/tests/test_agent_plan.py b/python/flink_agents/plan/tests/test_agent_plan.py index 2b022095..583940cb 100644 --- a/python/flink_agents/plan/tests/test_agent_plan.py +++ b/python/flink_agents/plan/tests/test_agent_plan.py @@ -93,6 +93,9 @@ class MockChatModelImpl(BaseChatModelSetup): # noqa: D101 host: str desc: str + def open(self) -> None: + """Do nothing.""" + @property def model_kwargs(self) -> Dict[str, Any]: # noqa: D102 return {} diff --git a/python/flink_agents/runtime/flink_runner_context.py b/python/flink_agents/runtime/flink_runner_context.py index 2c593c0f..631267fa 100644 --- a/python/flink_agents/runtime/flink_runner_context.py +++ b/python/flink_agents/runtime/flink_runner_context.py @@ -241,6 +241,7 @@ class FlinkRunnerContext(RunnerContext): @override def get_resource(self, name: str, type: ResourceType, metric_group: MetricGroup = None) -> Resource: + self._j_runner_context.checkMailboxThread() resource = self.__resource_cache.get_resource(name, type) # Bind metric group to the resource resource.set_metric_group(metric_group or self.action_metric_group) diff --git a/python/flink_agents/runtime/java/java_chat_model.py b/python/flink_agents/runtime/java/java_chat_model.py index 2263b68e..f22d079c 100644 --- a/python/flink_agents/runtime/java/java_chat_model.py +++ b/python/flink_agents/runtime/java/java_chat_model.py @@ -123,6 +123,11 @@ class JavaChatModelSetupImpl(JavaChatModelSetup): """ return {} + @override + def open(self) -> None: + """Open the java resource.""" + self._j_resource.open() + @override def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: """Execute chat conversation by delegating to Java implementation. diff --git a/python/flink_agents/runtime/java/java_embedding_model.py b/python/flink_agents/runtime/java/java_embedding_model.py index c53a2a85..a2dbc99a 100644 --- a/python/flink_agents/runtime/java/java_embedding_model.py +++ b/python/flink_agents/runtime/java/java_embedding_model.py @@ -17,6 +17,8 @@ ################################################################################# from typing import Any, Dict, Sequence +from typing_extensions import override + from flink_agents.api.embedding_models.java_embedding_model import ( JavaEmbeddingModelConnection, JavaEmbeddingModelSetup, @@ -96,6 +98,11 @@ class JavaEmbeddingModelSetupImpl(JavaEmbeddingModelSetup): """ return {} + @override + def open(self) -> None: + """Open the java resource.""" + self._j_resource.open() + def embed(self, text: str | Sequence[str], **kwargs: Any) -> list[float] | list[list[float]]: """Generate embedding vector for a single text query. Converts the input text into a high-dimensional vector representation diff --git a/python/flink_agents/runtime/java/java_vector_store.py b/python/flink_agents/runtime/java/java_vector_store.py index 77f3167f..1cb4abe9 100644 --- a/python/flink_agents/runtime/java/java_vector_store.py +++ b/python/flink_agents/runtime/java/java_vector_store.py @@ -67,6 +67,10 @@ class JavaVectorStoreImpl(JavaCollectionManageableVectorStore): def store_kwargs(self) -> Dict[str, Any]: return {} + @override + def open(self) -> None: + self._j_resource.open() + @override def add( self, diff --git a/python/flink_agents/runtime/memory/tests/test_vector_store_long_term_memory.py b/python/flink_agents/runtime/memory/tests/test_vector_store_long_term_memory.py index 37ca8b09..390d54e0 100644 --- a/python/flink_agents/runtime/memory/tests/test_vector_store_long_term_memory.py +++ b/python/flink_agents/runtime/memory/tests/test_vector_store_long_term_memory.py @@ -27,6 +27,7 @@ from chromadb.utils import embedding_functions from pydantic import ConfigDict from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.embedding_models.embedding_model import BaseEmbeddingModelSetup from flink_agents.api.memory.long_term_memory import ( CompactionConfig, DatetimeRange, @@ -58,13 +59,14 @@ if TYPE_CHECKING: current_dir = Path(__file__).parent -class MockEmbeddingModel(Resource): # noqa: D101 +class MockEmbeddingModel(BaseEmbeddingModelSetup): # noqa: D101 model_config = ConfigDict(arbitrary_types_allowed=True) ef: EmbeddingFunction = embedding_functions.DefaultEmbeddingFunction() + connection: str = "mock" + model: str = "mock" - @classmethod - def resource_type(cls) -> ResourceType: # noqa: D102 - return ResourceType.EMBEDDING_MODEL + def open(self) -> None: # noqa: D102 + pass @property def model_kwargs(self) -> Dict[str, Any]: # noqa: D102 @@ -84,18 +86,20 @@ def long_term_memory() -> VectorStoreLongTermMemory: # noqa: D103 def get_resource(name: str, type: ResourceType) -> Resource: if type == ResourceType.CHAT_MODEL: - return chat_model + resource = chat_model elif type == ResourceType.CHAT_MODEL_CONNECTION: - return chat_model_connection + resource = chat_model_connection elif type == ResourceType.EMBEDDING_MODEL: if use_ollama: - return embedding_model + resource = embedding_model else: - return MockEmbeddingModel() + resource = MockEmbeddingModel() elif type == ResourceType.EMBEDDING_MODEL_CONNECTION: - return embedding_model_connection + resource = embedding_model_connection else: - return vector_store + resource = vector_store + resource.open() + return resource chat_model = OllamaChatModelSetup( get_resource=get_resource, connection="chat_model_connection", model="qwen3:8b" diff --git a/python/flink_agents/runtime/resource_cache.py b/python/flink_agents/runtime/resource_cache.py index fa0ca617..cc4d32a7 100644 --- a/python/flink_agents/runtime/resource_cache.py +++ b/python/flink_agents/runtime/resource_cache.py @@ -79,6 +79,7 @@ class ResourceCache: resource = resource_provider.provide( get_resource=self.get_resource, config=self._config ) + resource.open() self._cache.setdefault(type, {})[name] = resource return resource diff --git a/python/flink_agents/runtime/tests/test_built_in_actions.py b/python/flink_agents/runtime/tests/test_built_in_actions.py index 618b0599..8959b12b 100644 --- a/python/flink_agents/runtime/tests/test_built_in_actions.py +++ b/python/flink_agents/runtime/tests/test_built_in_actions.py @@ -73,6 +73,9 @@ class MockChatModelConnection(BaseChatModelConnection): class MockChatModel(BaseChatModelSetup): """Mock ChatModel for testing integrating prompt and tool.""" + def open(self) -> None: + """Do nothing.""" + @property def model_kwargs(self) -> Dict[str, Any]: """Return model kwargs.""" diff --git a/python/flink_agents/runtime/tests/test_get_resource_in_action.py b/python/flink_agents/runtime/tests/test_get_resource_in_action.py index bd02ca35..f6028003 100644 --- a/python/flink_agents/runtime/tests/test_get_resource_in_action.py +++ b/python/flink_agents/runtime/tests/test_get_resource_in_action.py @@ -31,6 +31,9 @@ class MockChatModelImpl(BaseChatModelSetup): # noqa: D101 host: str desc: str + def open(self) -> None: + """Do nothing.""" + @property def model_kwargs(self) -> Dict[str, Any]: # noqa: D102 return {} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/ResourceCache.java b/runtime/src/main/java/org/apache/flink/agents/runtime/ResourceCache.java index 3537a5a7..604348b6 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/ResourceCache.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/ResourceCache.java @@ -94,7 +94,7 @@ public class ResourceCache implements AutoCloseable { throw new RuntimeException(e); } }); - + resource.open(); cache.computeIfAbsent(type, k -> new ConcurrentHashMap<>()).put(name, resource); return resource; } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java index 8c7b9fc3..dcba1555 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java @@ -225,6 +225,7 @@ public class RunnerContextImpl implements RunnerContext { @Override public Resource getResource(String name, ResourceType type) throws Exception { + mailboxThreadChecker.run(); if (resourceCache == null) { throw new IllegalStateException("ResourceCache is not available in this context"); } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java index e9741baf..61aa8854 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java @@ -60,6 +60,11 @@ public class PythonRunnerContextImpl extends RunnerContextImpl { sendEvent(new PythonEvent(event, type, eventJsonStr)); } + public void checkMailboxThread() { + // this method will be invoked by PythonActionExecutor's python interpreter. + this.mailboxThreadChecker.run(); + } + public String getPythonAwaitableRef() { return pythonAwaitableRef; }
