This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 56dfe770f2a1ab32c12229b02239486f7053316b Author: WenjinXie <[email protected]> AuthorDate: Fri Sep 19 14:58:12 2025 +0800 [plan] Add built-in chat model action and tool call action in java. add license. --- .../agents/api/chat/messages/ChatMessage.java | 30 ++-- .../flink/agents/api/event/ChatRequestEvent.java | 42 +++++ .../flink/agents/api/event/ChatResponseEvent.java | 42 +++++ .../flink/agents/api/event/ToolRequestEvent.java | 33 ++-- .../flink/agents/api/event/ToolResponseEvent.java | 93 +++++----- .../flink/agents/examples/AgentWithResource.java | 90 ++++++---- .../agents/examples/AgentWithResourceExample.java | 15 +- .../org/apache/flink/agents/plan/AgentPlan.java | 16 ++ .../flink/agents/plan/actions/ChatModelAction.java | 193 +++++++++++++++++++++ .../flink/agents/plan/actions/ToolCallAction.java | 86 +++++++++ .../apache/flink/agents/plan/AgentPlanTest.java | 8 +- .../flink_agents/plan/actions/chat_model_action.py | 4 - .../create_python_agent_plan_from_json.py | 4 +- 13 files changed, 530 insertions(+), 126 deletions(-) diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java b/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java index b00d802..7ebc787 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java @@ -18,7 +18,6 @@ package org.apache.flink.agents.api.chat.messages; -import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonIgnore; import java.util.ArrayList; @@ -43,19 +42,20 @@ public class ChatMessage { /** Default constructor with SYSTEM role */ public ChatMessage() { - this.role = MessageRole.SYSTEM; - this.content = ""; - this.toolCalls = new ArrayList<>(); - this.extraArgs = new HashMap<>(); + this(MessageRole.SYSTEM, null, null, null); } /** Constructor with role and content */ public ChatMessage(MessageRole role, String content) { - this.role = role != null ? role : MessageRole.SYSTEM; - this.content = content != null ? content : ""; - this.toolCalls = new ArrayList<>(); - this.extraArgs = new HashMap<>(); - this.extraArgs.put(MESSAGE_TYPE, this.role); + this(role, content, null, null); + } + + public ChatMessage(MessageRole role, String content, Map<String, Object> extraArgs) { + this(role, content, null, extraArgs); + } + + public ChatMessage(MessageRole role, String content, List<Map<String, Object>> toolCalls) { + this(role, content, toolCalls, null); } /** Full constructor */ @@ -71,16 +71,6 @@ public class ChatMessage { this.extraArgs.put(MESSAGE_TYPE, this.role); } - /** Constructor with resource */ - public ChatMessage(MessageRole role, Resource resource, Map<String, Object> extraArgs) { - if (resource == null) throw new IllegalArgumentException("resource must not be null"); - // TODO handle resource content properly - this.role = role != null ? role : MessageRole.SYSTEM; - this.toolCalls = new ArrayList<>(); - this.extraArgs = extraArgs != null ? new HashMap<>(extraArgs) : new HashMap<>(); - this.extraArgs.put(MESSAGE_TYPE, this.role); - } - public MessageRole getRole() { return role; } diff --git a/api/src/main/java/org/apache/flink/agents/api/event/ChatRequestEvent.java b/api/src/main/java/org/apache/flink/agents/api/event/ChatRequestEvent.java new file mode 100644 index 0000000..e5277d3 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/event/ChatRequestEvent.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.event; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.chat.messages.ChatMessage; + +import java.util.List; + +public class ChatRequestEvent extends Event { + private final String model; + private final List<ChatMessage> messages; + + public ChatRequestEvent(String model, List<ChatMessage> messages) { + this.model = model; + this.messages = messages; + } + + public String getModel() { + return model; + } + + public List<ChatMessage> getMessages() { + return messages; + } +} diff --git a/api/src/main/java/org/apache/flink/agents/api/event/ChatResponseEvent.java b/api/src/main/java/org/apache/flink/agents/api/event/ChatResponseEvent.java new file mode 100644 index 0000000..9041e5a --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/event/ChatResponseEvent.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.event; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.chat.messages.ChatMessage; + +import java.util.UUID; + +public class ChatResponseEvent extends Event { + private final UUID requestId; + private final ChatMessage response; + + public ChatResponseEvent(UUID requestId, ChatMessage response) { + this.requestId = requestId; + this.response = response; + } + + public UUID getRequestId() { + return requestId; + } + + public ChatMessage getResponse() { + return response; + } +} diff --git a/api/src/main/java/org/apache/flink/agents/api/event/ToolRequestEvent.java b/api/src/main/java/org/apache/flink/agents/api/event/ToolRequestEvent.java index e39db7f..889c525 100644 --- a/api/src/main/java/org/apache/flink/agents/api/event/ToolRequestEvent.java +++ b/api/src/main/java/org/apache/flink/agents/api/event/ToolRequestEvent.java @@ -18,28 +18,29 @@ package org.apache.flink.agents.api.event; +import org.apache.flink.agents.api.Event; + +import java.util.List; import java.util.Map; -import java.util.Objects; /** Event representing a tool call request */ -public class ToolRequestEvent { - - private final String tool; - private final Map<String, Object> kwargs; +public class ToolRequestEvent extends Event { + private final String model; + private final List<Map<String, Object>> toolCalls; private final long timestamp; - public ToolRequestEvent(String tool, Map<String, Object> kwargs) { - this.tool = Objects.requireNonNull(tool, "tool name cannot be null"); - this.kwargs = Objects.requireNonNull(kwargs, "kwargs cannot be null"); + public ToolRequestEvent(String model, List<Map<String, Object>> toolCalls) { + this.model = model; + this.toolCalls = toolCalls; this.timestamp = System.currentTimeMillis(); } - public String getTool() { - return tool; + public String getModel() { + return model; } - public Map<String, Object> getKwargs() { - return kwargs; + public List<Map<String, Object>> getToolCalls() { + return toolCalls; } public long getTimestamp() { @@ -49,11 +50,11 @@ public class ToolRequestEvent { @Override public String toString() { return "ToolRequestEvent{" - + "tool='" - + tool + + "model='" + + model + '\'' - + ", kwargs=" - + kwargs + + ", toolCalls=" + + toolCalls + ", timestamp=" + timestamp + '}'; diff --git a/api/src/main/java/org/apache/flink/agents/api/event/ToolResponseEvent.java b/api/src/main/java/org/apache/flink/agents/api/event/ToolResponseEvent.java index 5115c69..c4d874d 100644 --- a/api/src/main/java/org/apache/flink/agents/api/event/ToolResponseEvent.java +++ b/api/src/main/java/org/apache/flink/agents/api/event/ToolResponseEvent.java @@ -18,46 +18,60 @@ package org.apache.flink.agents.api.event; -import java.util.Objects; +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.tools.ToolResponse; -/** Event representing a result from tool call */ -public class ToolResponseEvent { +import java.util.Map; +import java.util.UUID; - private final ToolRequestEvent request; - private final Object response; - private final boolean success; - private final String error; +/** Event representing a result from tool call */ +public class ToolResponseEvent extends Event { + private final UUID requestId; + private final Map<String, ToolResponse> responses; + private final Map<String, String> externalIds; + private final Map<String, Boolean> success; + private final Map<String, String> error; private final long timestamp; - public ToolResponseEvent(ToolRequestEvent request, Object response) { - this.request = Objects.requireNonNull(request, "request cannot be null"); - this.response = response; - this.success = true; - this.error = null; + public ToolResponseEvent( + UUID requestId, + Map<String, ToolResponse> responses, + Map<String, Boolean> success, + Map<String, String> error, + Map<String, String> externalIds) { + this.requestId = requestId; + this.responses = responses; + this.success = success; + this.error = error; + this.externalIds = externalIds; this.timestamp = System.currentTimeMillis(); } - public ToolResponseEvent(ToolRequestEvent request, String error) { - this.request = Objects.requireNonNull(request, "request cannot be null"); - this.response = null; - this.success = false; - this.error = Objects.requireNonNull(error, "error cannot be null"); - this.timestamp = System.currentTimeMillis(); + public ToolResponseEvent( + UUID requestId, + Map<String, ToolResponse> responses, + Map<String, Boolean> success, + Map<String, String> error) { + this(requestId, responses, success, error, Map.of()); + } + + public UUID getRequestId() { + return requestId; } - public ToolRequestEvent getRequest() { - return request; + public Map<String, ToolResponse> getResponses() { + return responses; } - public Object getResponse() { - return response; + public Map<String, String> getExternalIds() { + return externalIds; } - public boolean isSuccess() { + public Map<String, Boolean> getSuccess() { return success; } - public String getError() { + public Map<String, String> getError() { return error; } @@ -67,27 +81,14 @@ public class ToolResponseEvent { @Override public String toString() { - if (success) { - return "ToolResponseEvent{" - + "request=" - + request - + ", response=" - + response - + ", success=true" - + ", timestamp=" - + timestamp - + '}'; - } else { - return "ToolResponseEvent{" - + "request=" - + request - + ", error='" - + error - + '\'' - + ", success=false" - + ", timestamp=" - + timestamp - + '}'; - } + return "ToolResponseEvent{" + + "requestId=" + + requestId + + ", response=" + + responses + + ", success=true" + + ", timestamp=" + + timestamp + + '}'; } } diff --git a/examples/src/main/java/org/apache/flink/agents/examples/AgentWithResource.java b/examples/src/main/java/org/apache/flink/agents/examples/AgentWithResource.java index f6d5d97..27ee17b 100644 --- a/examples/src/main/java/org/apache/flink/agents/examples/AgentWithResource.java +++ b/examples/src/main/java/org/apache/flink/agents/examples/AgentWithResource.java @@ -29,18 +29,18 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.event.ChatRequestEvent; +import org.apache.flink.agents.api.event.ChatResponseEvent; import org.apache.flink.agents.api.prompt.Prompt; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; -import org.apache.flink.agents.api.tools.BaseTool; -import org.apache.flink.agents.api.tools.ToolParameters; -import org.apache.flink.agents.api.tools.ToolResponse; -import java.util.Collections; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.StringJoiner; import java.util.function.BiFunction; public class AgentWithResource extends Agent { @@ -69,30 +69,49 @@ public class AgentWithResource extends Agent { @Override public ChatMessage chat(List<ChatMessage> messages, Map<String, Object> parameters) { - Prompt prompt = (Prompt) getResource.apply((String) this.prompt, ResourceType.PROMPT); - BaseTool tool = (BaseTool) getResource.apply(this.tools.get(0), ResourceType.TOOL); - Map<String, Object> params = new HashMap<>(); - params.put("a", 1); - params.put("b", 2); - params.put("operation", "add"); - ToolParameters toolParameters = new ToolParameters(params); - ToolResponse result = tool.call(toolParameters); - String output = - String.format( - "Prompt: %s, input: %s, endpoint: %s, topP: %s, topK: %s, tool call result: %s", - prompt.formatString(new HashMap<>()), - messages.get(0).getContent(), - endpoint, - topP, - topK, - result.getResult()); - return new ChatMessage(MessageRole.ASSISTANT, output); + if (messages.size() == 1) { + Map<String, Object> toolCall = new HashMap<>(); + toolCall.put("id", "1"); + toolCall.put( + "function", + new HashMap<String, Object>() { + { + put("name", tools.get(0)); + put("arguments", Map.of("a", 1, "b", 2, "operation", "add")); + } + }); + return new ChatMessage( + MessageRole.ASSISTANT, + String.format("I will call tool %s", tools.get(0)), + List.of(toolCall)); + } else { + StringJoiner content = new StringJoiner("\n"); + content.add( + String.format("endpoint: %s, topP: %s, topK: %s", endpoint, topP, topK)); + + Map<String, String> arguments = new HashMap<>(); + for (ChatMessage message : messages) { + for (Map.Entry<String, Object> entry : message.getExtraArgs().entrySet()) { + arguments.put(entry.getKey(), entry.getValue().toString()); + } + } + Prompt prompt = + (Prompt) getResource.apply((String) this.prompt, ResourceType.PROMPT); + List<ChatMessage> formatMessages = + prompt.formatMessages(MessageRole.USER, arguments); + content.add("Prompt: " + formatMessages.get(0).getContent()); + + for (ChatMessage message : messages) { + content.add(message.getContent()); + } + return new ChatMessage(MessageRole.ASSISTANT, content.toString()); + } } } @org.apache.flink.agents.api.annotation.Prompt public static Prompt myPrompt() { - return new Prompt("This is a test prompt"); + return new Prompt("What is {a} + {b}?"); } @ChatModelSetup @@ -128,13 +147,22 @@ public class AgentWithResource extends Agent { @Action(listenEvents = {InputEvent.class}) public static void process(InputEvent event, RunnerContext ctx) throws Exception { - BaseChatModelSetup chatModel = - (BaseChatModelSetup) ctx.getResource("myChatModel", ResourceType.CHAT_MODEL); - ChatMessage response = - chatModel.chat( - Collections.singletonList( - new ChatMessage(MessageRole.USER, (String) event.getInput())), - Collections.emptyMap()); - ctx.sendEvent(new OutputEvent(response.getContent())); + Map<String, Integer> input = (Map<String, Integer>) event.getInput(); + + ChatMessage message = + new ChatMessage( + MessageRole.USER, + String.format("What is %s + %s?", input.get("a"), input.get("b")), + Map.of("a", input.get("a"), "b", input.get("b"))); + + List<ChatMessage> messages = new ArrayList<>(); + messages.add(message); + ctx.sendEvent(new ChatRequestEvent("myChatModel", messages)); + } + + @Action(listenEvents = {ChatResponseEvent.class}) + public static void output(ChatResponseEvent event, RunnerContext ctx) throws Exception { + String output = event.getResponse().getContent(); + ctx.sendEvent(new OutputEvent(output)); } } diff --git a/examples/src/main/java/org/apache/flink/agents/examples/AgentWithResourceExample.java b/examples/src/main/java/org/apache/flink/agents/examples/AgentWithResourceExample.java index 90338bb..8ca3c75 100644 --- a/examples/src/main/java/org/apache/flink/agents/examples/AgentWithResourceExample.java +++ b/examples/src/main/java/org/apache/flink/agents/examples/AgentWithResourceExample.java @@ -20,8 +20,12 @@ package org.apache.flink.agents.examples; import org.apache.flink.agents.api.AgentsExecutionEnvironment; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.DataStreamSource; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import java.util.HashMap; +import java.util.Map; + /** * Example to test MemoryObject in a complete Java Flink execution environment. This job triggers * the {@link MemoryObjectAgent} to test storing and retrieving complex data structures. @@ -33,8 +37,10 @@ public class AgentWithResourceExample { StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(1); - // Use two different keys (1 and 2) to show that memory is isolated per key. - DataStream<String> inputStream = env.fromElements("This is the test input"); + Map<String, Integer> element = new HashMap<>(); + element.put("a", 1); + element.put("b", 2); + DataStreamSource<Map<String, Integer>> inputStream = env.fromElements(element); // Create agents execution environment AgentsExecutionEnvironment agentsEnv = @@ -43,7 +49,10 @@ public class AgentWithResourceExample { // Apply agent to the DataStream and use the integer itself as the key DataStream<Object> outputStream = agentsEnv - .fromDataStream(inputStream, (KeySelector<String, String>) value -> value) + .fromDataStream( + inputStream, + (KeySelector<Map<String, Integer>, Integer>) + value -> value.get("a")) .apply(new AgentWithResource()) .toDataStream(); diff --git a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java index ade3347..035b089 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java @@ -30,6 +30,8 @@ import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.SerializableResource; import org.apache.flink.agents.api.tools.ToolMetadata; import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.agents.plan.actions.ChatModelAction; +import org.apache.flink.agents.plan.actions.ToolCallAction; import org.apache.flink.agents.plan.resourceprovider.JavaResourceProvider; import org.apache.flink.agents.plan.resourceprovider.JavaSerializableResourceProvider; import org.apache.flink.agents.plan.resourceprovider.ResourceProvider; @@ -236,7 +238,21 @@ public class AgentPlan implements Serializable { } } + private void addBuiltAction(Action action) { + // Add to actions map + actions.put(action.getName(), action); + + // Add to actionsByEvent map + for (String eventTypeName : action.getListenEventTypes()) { + actionsByEvent.computeIfAbsent(eventTypeName, k -> new ArrayList<>()).add(action); + } + } + private void extractActionsFromAgent(Agent agent) throws Exception { + // Add built-in actions + addBuiltAction(ChatModelAction.getChatModelAction()); + addBuiltAction(ToolCallAction.getToolCallAction()); + // Scan the agent class for methods annotated with @Action Class<?> agentClass = agent.getClass(); for (Method method : agentClass.getDeclaredMethods()) { diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java new file mode 100644 index 0000000..f20fdc1 --- /dev/null +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.plan.actions; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.context.MemoryObject; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.event.ChatRequestEvent; +import org.apache.flink.agents.api.event.ChatResponseEvent; +import org.apache.flink.agents.api.event.ToolRequestEvent; +import org.apache.flink.agents.api.event.ToolResponseEvent; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.tools.ToolResponse; +import org.apache.flink.agents.plan.JavaFunction; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +/** Built-in action for processing chat request and tool call result. */ +public class ChatModelAction { + private static final String TOOL_CALL_CONTEXT = "_TOOL_CALL_CONTEXT"; + private static final String TOOL_REQUEST_EVENT_CONTEXT = "_TOOL_REQUEST_EVENT_CONTEXT"; + private static final String INITIAL_REQUEST_ID = "initialRequestId"; + private static final String MODEL = "model"; + + public static Action getChatModelAction() throws Exception { + return new Action( + "chat_model_action", + new JavaFunction( + ChatModelAction.class, + "processChatRequestOrToolResponse", + new Class[] {Event.class, RunnerContext.class}), + List.of(ChatRequestEvent.class.getName(), ToolResponseEvent.class.getName())); + } + + /** + * Chat with chat model. + * + * <p>If there is no tool calls in chat model response, send the chat response event. Otherwise, + * generate tool request event and save the tool call context in memory. + * + * @param initialRequestId The request id of the initial chat request event. + * @param messages The chat messages as llm input. + * @param ctx The runner context this function executed in. + */ + public static void chat( + UUID initialRequestId, String model, List<ChatMessage> messages, RunnerContext ctx) + throws Exception { + BaseChatModelSetup chatModel = + (BaseChatModelSetup) ctx.getResource(model, ResourceType.CHAT_MODEL); + + ChatMessage response = chatModel.chat(messages, Map.of()); + MemoryObject stm = ctx.getShortTermMemory(); + + if (!response.getToolCalls().isEmpty()) { + Map<UUID, Object> toolCallContext; + if (stm.isExist(TOOL_CALL_CONTEXT)) { + toolCallContext = (Map<UUID, Object>) stm.get(TOOL_CALL_CONTEXT).getValue(); + } else { + toolCallContext = new HashMap<>(); + } + if (!toolCallContext.containsKey(initialRequestId)) { + toolCallContext.put(initialRequestId, messages); + } + List<ChatMessage> messageContext = + (List<ChatMessage>) toolCallContext.get(initialRequestId); + messageContext.add(response); + stm.set(TOOL_CALL_CONTEXT, toolCallContext); + + ToolRequestEvent toolRequestEvent = + new ToolRequestEvent(model, response.getToolCalls()); + + Map<UUID, Object> toolRequestEventContext; + if (stm.isExist(TOOL_REQUEST_EVENT_CONTEXT)) { + toolRequestEventContext = + (Map<UUID, Object>) stm.get(TOOL_REQUEST_EVENT_CONTEXT).getValue(); + } else { + toolRequestEventContext = new HashMap<>(); + } + toolRequestEventContext.put( + toolRequestEvent.getId(), + Map.of(INITIAL_REQUEST_ID, initialRequestId, MODEL, model)); + stm.set(TOOL_REQUEST_EVENT_CONTEXT, toolRequestEventContext); + + ctx.sendEvent(toolRequestEvent); + } else { + // clean tool call context + if (stm.isExist(TOOL_CALL_CONTEXT)) { + Map<UUID, Object> toolCallContext = + (Map<UUID, Object>) stm.get(TOOL_CALL_CONTEXT).getValue(); + if (toolCallContext.containsKey(initialRequestId)) { + toolCallContext.remove(initialRequestId); + stm.set(TOOL_CALL_CONTEXT, toolCallContext); + } + } + + ctx.sendEvent(new ChatResponseEvent(initialRequestId, response)); + } + } + + /** + * Built-in action for processing chat request and tool call result. + * + * <p>This action will listen {@link ChatRequestEvent} and send {@link ChatResponseEvent}. If + * there are tool calls in chat model response, it will send {@link ToolRequestEvent} and + * feedback the correspond {@link ToolResponseEvent} to chat model. + * + * @param event Event this action listened, must be {@link ChatRequestEvent} or {@link + * ToolResponseEvent} + * @param ctx The runner context this action executed in. + */ + @SuppressWarnings("unchecked") + public static void processChatRequestOrToolResponse(Event event, RunnerContext ctx) + throws Exception { + MemoryObject stm = ctx.getShortTermMemory(); + if (event instanceof ChatRequestEvent) { + ChatRequestEvent chatRequestEvent = (ChatRequestEvent) event; + chat( + chatRequestEvent.getId(), + chatRequestEvent.getModel(), + chatRequestEvent.getMessages(), + ctx); + } else if (event instanceof ToolResponseEvent) { + ToolResponseEvent toolResponseEvent = (ToolResponseEvent) event; + UUID toolRequestId = toolResponseEvent.getRequestId(); + // get tool request context from memory + Map<UUID, Object> toolRequestEventContext = + (Map<UUID, Object>) stm.get(TOOL_REQUEST_EVENT_CONTEXT).getValue(); + Map<String, Object> context = + (Map<String, Object>) toolRequestEventContext.get(toolRequestId); + UUID initialRequestId = (UUID) context.get(INITIAL_REQUEST_ID); + String model = (String) context.get(MODEL); + toolRequestEventContext.remove(toolRequestId); + stm.set(TOOL_REQUEST_EVENT_CONTEXT, toolRequestEventContext); + Map<String, ToolResponse> responses = toolResponseEvent.getResponses(); + Map<String, Boolean> success = toolResponseEvent.getSuccess(); + + // get tool call context + Map<UUID, Object> toolCallContext = + (Map<UUID, Object>) stm.get(TOOL_CALL_CONTEXT).getValue(); + // update tool call context + List<ChatMessage> messages = (List<ChatMessage>) toolCallContext.get(initialRequestId); + for (Map.Entry<String, ToolResponse> entry : responses.entrySet()) { + Map<String, Object> extraArgs = new HashMap<>(); + String toolCallId = entry.getKey(); + if (toolResponseEvent.getExternalIds().containsKey(toolCallId)) { + extraArgs.put("externalId", toolResponseEvent.getExternalIds().get(toolCallId)); + } + + ToolResponse response = entry.getValue(); + if (success.get(toolCallId) && response.isSuccess()) { + messages.add( + new ChatMessage( + MessageRole.TOOL, + String.valueOf(response.getResult()), + extraArgs)); + } else { + messages.add( + new ChatMessage( + MessageRole.TOOL, + String.valueOf(response.getError()), + extraArgs)); + } + } + // overwrite tool call context + stm.set(TOOL_CALL_CONTEXT, toolCallContext); + + chat(initialRequestId, model, messages, ctx); + } else { + throw new RuntimeException(String.format("Unexpected type event %s", event)); + } + } +} diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ToolCallAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ToolCallAction.java new file mode 100644 index 0000000..7eafd3c --- /dev/null +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ToolCallAction.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.plan.actions; + +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.event.ToolRequestEvent; +import org.apache.flink.agents.api.event.ToolResponseEvent; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.tools.BaseTool; +import org.apache.flink.agents.api.tools.ToolParameters; +import org.apache.flink.agents.api.tools.ToolResponse; +import org.apache.flink.agents.plan.JavaFunction; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** Built-in action for processing tool call. */ +public class ToolCallAction { + public static Action getToolCallAction() throws Exception { + return new Action( + "tool_call_action", + new JavaFunction( + ToolCallAction.class, + "processToolRequest", + new Class[] {ToolRequestEvent.class, RunnerContext.class}), + List.of(ToolRequestEvent.class.getName())); + } + + @SuppressWarnings("unchecked") + public static void processToolRequest(ToolRequestEvent event, RunnerContext ctx) { + Map<String, Boolean> success = new HashMap<>(); + Map<String, String> error = new HashMap<>(); + Map<String, ToolResponse> responses = new HashMap<>(); + Map<String, String> externalIds = new HashMap<>(); + for (Map<String, Object> toolCall : event.getToolCalls()) { + String id = String.valueOf(toolCall.get("id")); + Map<String, Object> function = (Map<String, Object>) toolCall.get("function"); + String name = (String) function.get("name"); + Map<String, Object> arguments = (Map<String, Object>) function.get("arguments"); + + if (toolCall.containsKey("original_id")) { + externalIds.put(id, (String) toolCall.get("original_id")); + } + + BaseTool tool = null; + try { + tool = (BaseTool) ctx.getResource(name, ResourceType.TOOL); + } catch (Exception e) { + success.put(id, false); + responses.put( + id, ToolResponse.error(String.format("Tool %s does not exist.", name))); + error.put(id, e.getMessage()); + } + + if (tool != null) { + try { + ToolResponse response = tool.call(new ToolParameters(arguments)); + success.put(id, true); + responses.put(id, response); + } catch (Exception e) { + success.put(id, false); + responses.put( + id, ToolResponse.error(String.format("Tool %s execute failed.", name))); + error.put(id, e.getMessage()); + } + } + } + ctx.sendEvent(new ToolResponseEvent(event.getId(), responses, success, error, externalIds)); + } +} diff --git a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java index 2a5d024..e61dbb9 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java @@ -137,7 +137,7 @@ public class AgentPlanTest { AgentPlan agentPlan = new AgentPlan(agent); // Verify that actions were collected correctly - assertThat(agentPlan.getActions().size()).isEqualTo(2); + assertThat(agentPlan.getActions().size()).isEqualTo(4); assertThat(agentPlan.getActions()).containsKey("handleInputEvent"); assertThat(agentPlan.getActions()).containsKey("handleMultipleEvents"); @@ -164,7 +164,7 @@ public class AgentPlanTest { assertThat(multiAction.getExec()).isInstanceOf(JavaFunction.class); // Verify actionsByEvent mapping - assertThat(agentPlan.getActionsByEvent().size()).isEqualTo(3); + assertThat(agentPlan.getActionsByEvent().size()).isEqualTo(6); // Check InputEvent mapping List<Action> inputEventActions = @@ -199,8 +199,8 @@ public class AgentPlanTest { AgentPlan agentPlan = new AgentPlan(emptyAgent); // Verify that no actions were collected - assertThat(agentPlan.getActions().size()).isEqualTo(0); - assertThat(agentPlan.getActionsByEvent().size()).isEqualTo(0); + assertThat(agentPlan.getActions().size()).isEqualTo(2); + assertThat(agentPlan.getActionsByEvent().size()).isEqualTo(3); } @Test diff --git a/python/flink_agents/plan/actions/chat_model_action.py b/python/flink_agents/plan/actions/chat_model_action.py index 656453d..967d0d9 100644 --- a/python/flink_agents/plan/actions/chat_model_action.py +++ b/python/flink_agents/plan/actions/chat_model_action.py @@ -115,10 +115,6 @@ def process_chat_request_or_tool_response(event: Event, ctx: RunnerContext) -> N """ short_term_memory = ctx.get_short_term_memory() if isinstance(event, ChatRequestEvent): - cast( - "BaseChatModelSetup", ctx.get_resource(event.model, ResourceType.CHAT_MODEL) - ) - chat( initial_request_id=event.id, model=event.model, diff --git a/python/flink_agents/plan/tests/compatibility/create_python_agent_plan_from_json.py b/python/flink_agents/plan/tests/compatibility/create_python_agent_plan_from_json.py index 34274c9..ff940d8 100644 --- a/python/flink_agents/plan/tests/compatibility/create_python_agent_plan_from_json.py +++ b/python/flink_agents/plan/tests/compatibility/create_python_agent_plan_from_json.py @@ -32,7 +32,7 @@ if __name__ == "__main__": agent_plan = AgentPlan.model_validate_json(java_plan_json) actions = agent_plan.actions - assert len(actions) == 2 + assert len(actions) == 4 event = "org.apache.flink.agents.api.Event" input_event = "org.apache.flink.agents.api.InputEvent" @@ -72,7 +72,7 @@ if __name__ == "__main__": # check actions_by_event actions_by_event = agent_plan.actions_by_event - assert len(actions_by_event) == 2 + assert len(actions_by_event) == 5 assert input_event in actions_by_event assert sorted(actions_by_event[input_event]) == ["firstAction", "secondAction"]
