This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 763e591b1e0a593e9cd5a2f530a23c3f94d4a0a7 Author: WenjinXie <[email protected]> AuthorDate: Fri Jan 16 22:02:23 2026 +0800 [plan][java] Built-in actions support async execution. --- .../agents/api/agents/AgentExecutionOptions.java | 10 ++++++++ .../integration/test/ChatModelIntegrationTest.java | 26 +++++++++++-------- .../resource/test/ChatModelCrossLanguageTest.java | 26 +++++++++++-------- .../flink/agents/plan/actions/ChatModelAction.java | 26 ++++++++++++++++++- .../plan/actions/ContextRetrievalAction.java | 28 ++++++++++++++++++++- .../flink/agents/plan/actions/ToolCallAction.java | 29 +++++++++++++++++++++- 6 files changed, 122 insertions(+), 23 deletions(-) diff --git a/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java b/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java index 64880a5a..26991b60 100644 --- a/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java +++ b/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java @@ -29,4 +29,14 @@ public class AgentExecutionOptions { public static final ConfigOption<Integer> MAX_RETRIES = new ConfigOption<>("max-retries", Integer.class, 3); + + // Async execution is supported on jdk >= 21, so set default false here. + public static final ConfigOption<Boolean> CHAT_ASYNC = + new ConfigOption<>("chat.async", Boolean.class, true); + + public static final ConfigOption<Boolean> TOOL_CALL_ASYNC = + new ConfigOption<>("tool-call.async", Boolean.class, true); + + public static final ConfigOption<Boolean> RAG_ASYNC = + new ConfigOption<>("rag.async", Boolean.class, true); } diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java index cad885ac..c843b970 100644 --- a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java +++ b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java @@ -31,6 +31,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import static org.apache.flink.agents.integration.test.ChatModelIntegrationAgent.OLLAMA_MODEL; @@ -103,18 +104,23 @@ public class ChatModelIntegrationTest extends OllamaPreparationUtils { public void checkResult(CloseableIterator<Object> results) { List<String> expectedWords = List.of("77", "37", "89", "23", "68", "22", "26", "22", "23", ""); + List<String> responses = new ArrayList<>(); + while (results.hasNext()) { + responses.add((String) results.next()); + } + + Assertions.assertEquals( + expectedWords.size(), + responses.size(), + String.format( + "LLM response count is mismatch," + "the responses are %s", responses)); + + String text = String.join("\n", responses); for (String expected : expectedWords) { Assertions.assertTrue( - results.hasNext(), "Output messages count %s is less than expected."); - String res = (String) results.next(); - if (res.contains("error") || res.contains("parameters")) { - LOG.warn(res); - } else { - Assertions.assertTrue( - res.contains(expected), - String.format( - "Groud truth %s is not contained in answer {%s}", expected, res)); - } + text.contains(expected), + String.format( + "Groud truth %s is not contained in answer {%s}", expected, text)); } } } diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageTest.java b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageTest.java index 7d7ec1d0..d62a9726 100644 --- a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageTest.java +++ b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageTest.java @@ -30,6 +30,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import static org.apache.flink.agents.resource.test.ChatModelCrossLanguageAgent.OLLAMA_MODEL; @@ -82,18 +83,23 @@ public class ChatModelCrossLanguageTest { public void checkResult(CloseableIterator<Object> results) { List<String> expectedWords = List.of("77", "22", ""); + List<String> responses = new ArrayList<>(); + while (results.hasNext()) { + responses.add((String) results.next()); + } + + Assertions.assertEquals( + expectedWords.size(), + responses.size(), + String.format( + "LLM response count is mismatch," + "the responses are %s", responses)); + + String text = String.join("\n", responses); for (String expected : expectedWords) { Assertions.assertTrue( - results.hasNext(), "Output messages count %s is less than expected."); - String res = (String) results.next(); - if (res.contains("error") || res.contains("parameters")) { - LOG.warn(res); - } else { - Assertions.assertTrue( - res.contains(expected), - String.format( - "Groud truth %s is not contained in answer {%s}", expected, res)); - } + text.contains(expected), + String.format( + "Groud truth %s is not contained in answer {%s}", expected, text)); } } } diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java index 7d5b34c9..9bde346a 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java @@ -26,6 +26,7 @@ import org.apache.flink.agents.api.agents.OutputSchema; import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.context.DurableCallable; import org.apache.flink.agents.api.context.MemoryObject; import org.apache.flink.agents.api.context.RunnerContext; import org.apache.flink.agents.api.event.ChatRequestEvent; @@ -196,6 +197,8 @@ public class ChatModelAction { BaseChatModelSetup chatModel = (BaseChatModelSetup) ctx.getResource(model, ResourceType.CHAT_MODEL); + boolean chatAsync = ctx.getConfig().get(AgentExecutionOptions.CHAT_ASYNC); + Agent.ErrorHandlingStrategy strategy = ctx.getConfig().get(AgentExecutionOptions.ERROR_HANDLING_STRATEGY); int numRetries = 0; @@ -210,7 +213,28 @@ public class ChatModelAction { for (int attempt = 0; attempt < numRetries + 1; attempt++) { try { - response = chatModel.chat(messages, Map.of()); + if (chatAsync) { + response = + ctx.durableExecuteAsync( + new DurableCallable<>() { + @Override + public String getId() { + return "chat-async"; + } + + @Override + public Class<ChatMessage> getResultClass() { + return ChatMessage.class; + } + + @Override + public ChatMessage call() throws Exception { + return chatModel.chat(messages, Map.of()); + } + }); + } else { + response = chatModel.chat(messages, Map.of()); + } // only generate structured output for final response. if (outputSchema != null && response.getToolCalls().isEmpty()) { response = generateStructuredOutput(response, outputSchema); diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ContextRetrievalAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ContextRetrievalAction.java index 504011d3..72463010 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/actions/ContextRetrievalAction.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ContextRetrievalAction.java @@ -19,6 +19,8 @@ package org.apache.flink.agents.plan.actions; import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.agents.AgentExecutionOptions; +import org.apache.flink.agents.api.context.DurableCallable; import org.apache.flink.agents.api.context.RunnerContext; import org.apache.flink.agents.api.event.ContextRetrievalRequestEvent; import org.apache.flink.agents.api.event.ContextRetrievalResponseEvent; @@ -46,6 +48,8 @@ public class ContextRetrievalAction { public static void processContextRetrievalRequest(Event event, RunnerContext ctx) throws Exception { if (event instanceof ContextRetrievalRequestEvent) { + boolean ragAsync = ctx.getConfig().get(AgentExecutionOptions.RAG_ASYNC); + final ContextRetrievalRequestEvent contextRetrievalRequestEvent = (ContextRetrievalRequestEvent) event; @@ -60,7 +64,29 @@ public class ContextRetrievalAction { contextRetrievalRequestEvent.getQuery(), contextRetrievalRequestEvent.getMaxResults()); - final VectorStoreQueryResult result = vectorStore.query(vectorStoreQuery); + VectorStoreQueryResult result; + if (ragAsync) { + result = + ctx.durableExecuteAsync( + new DurableCallable<VectorStoreQueryResult>() { + @Override + public String getId() { + return "rag-async"; + } + + @Override + public Class<VectorStoreQueryResult> getResultClass() { + return VectorStoreQueryResult.class; + } + + @Override + public VectorStoreQueryResult call() throws Exception { + return vectorStore.query(vectorStoreQuery); + } + }); + } else { + result = vectorStore.query(vectorStoreQuery); + } ctx.sendEvent( new ContextRetrievalResponseEvent( diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ToolCallAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ToolCallAction.java index 592c8214..f5fba2da 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/actions/ToolCallAction.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ToolCallAction.java @@ -17,6 +17,8 @@ */ package org.apache.flink.agents.plan.actions; +import org.apache.flink.agents.api.agents.AgentExecutionOptions; +import org.apache.flink.agents.api.context.DurableCallable; import org.apache.flink.agents.api.context.RunnerContext; import org.apache.flink.agents.api.event.ToolRequestEvent; import org.apache.flink.agents.api.event.ToolResponseEvent; @@ -44,6 +46,8 @@ public class ToolCallAction { @SuppressWarnings("unchecked") public static void processToolRequest(ToolRequestEvent event, RunnerContext ctx) { + boolean toolCallAsync = ctx.getConfig().get(AgentExecutionOptions.TOOL_CALL_ASYNC); + Map<String, Boolean> success = new HashMap<>(); Map<String, String> error = new HashMap<>(); Map<String, ToolResponse> responses = new HashMap<>(); @@ -70,7 +74,30 @@ public class ToolCallAction { if (tool != null) { try { - ToolResponse response = tool.call(new ToolParameters(arguments)); + ToolResponse response; + if (toolCallAsync) { + final Tool toolRef = tool; + response = + ctx.durableExecuteAsync( + new DurableCallable<>() { + @Override + public String getId() { + return "tool-call-async"; + } + + @Override + public Class<ToolResponse> getResultClass() { + return ToolResponse.class; + } + + @Override + public ToolResponse call() throws Exception { + return toolRef.call(new ToolParameters(arguments)); + } + }); + } else { + response = tool.call(new ToolParameters(arguments)); + } success.put(id, true); responses.put(id, response); } catch (Exception e) {
