This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git

commit 56dfe770f2a1ab32c12229b02239486f7053316b
Author: WenjinXie <[email protected]>
AuthorDate: Fri Sep 19 14:58:12 2025 +0800

    [plan] Add built-in chat model action and tool call action in java.
    
    add license.
---
 .../agents/api/chat/messages/ChatMessage.java      |  30 ++--
 .../flink/agents/api/event/ChatRequestEvent.java   |  42 +++++
 .../flink/agents/api/event/ChatResponseEvent.java  |  42 +++++
 .../flink/agents/api/event/ToolRequestEvent.java   |  33 ++--
 .../flink/agents/api/event/ToolResponseEvent.java  |  93 +++++-----
 .../flink/agents/examples/AgentWithResource.java   |  90 ++++++----
 .../agents/examples/AgentWithResourceExample.java  |  15 +-
 .../org/apache/flink/agents/plan/AgentPlan.java    |  16 ++
 .../flink/agents/plan/actions/ChatModelAction.java | 193 +++++++++++++++++++++
 .../flink/agents/plan/actions/ToolCallAction.java  |  86 +++++++++
 .../apache/flink/agents/plan/AgentPlanTest.java    |   8 +-
 .../flink_agents/plan/actions/chat_model_action.py |   4 -
 .../create_python_agent_plan_from_json.py          |   4 +-
 13 files changed, 530 insertions(+), 126 deletions(-)

diff --git 
a/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java 
b/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java
index b00d802..7ebc787 100644
--- 
a/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java
+++ 
b/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java
@@ -18,7 +18,6 @@
 
 package org.apache.flink.agents.api.chat.messages;
 
-import org.apache.flink.agents.api.resource.Resource;
 import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonIgnore;
 
 import java.util.ArrayList;
@@ -43,19 +42,20 @@ public class ChatMessage {
 
     /** Default constructor with SYSTEM role */
     public ChatMessage() {
-        this.role = MessageRole.SYSTEM;
-        this.content = "";
-        this.toolCalls = new ArrayList<>();
-        this.extraArgs = new HashMap<>();
+        this(MessageRole.SYSTEM, null, null, null);
     }
 
     /** Constructor with role and content */
     public ChatMessage(MessageRole role, String content) {
-        this.role = role != null ? role : MessageRole.SYSTEM;
-        this.content = content != null ? content : "";
-        this.toolCalls = new ArrayList<>();
-        this.extraArgs = new HashMap<>();
-        this.extraArgs.put(MESSAGE_TYPE, this.role);
+        this(role, content, null, null);
+    }
+
+    public ChatMessage(MessageRole role, String content, Map<String, Object> 
extraArgs) {
+        this(role, content, null, extraArgs);
+    }
+
+    public ChatMessage(MessageRole role, String content, List<Map<String, 
Object>> toolCalls) {
+        this(role, content, toolCalls, null);
     }
 
     /** Full constructor */
@@ -71,16 +71,6 @@ public class ChatMessage {
         this.extraArgs.put(MESSAGE_TYPE, this.role);
     }
 
-    /** Constructor with resource */
-    public ChatMessage(MessageRole role, Resource resource, Map<String, 
Object> extraArgs) {
-        if (resource == null) throw new IllegalArgumentException("resource 
must not be null");
-        // TODO handle resource content properly
-        this.role = role != null ? role : MessageRole.SYSTEM;
-        this.toolCalls = new ArrayList<>();
-        this.extraArgs = extraArgs != null ? new HashMap<>(extraArgs) : new 
HashMap<>();
-        this.extraArgs.put(MESSAGE_TYPE, this.role);
-    }
-
     public MessageRole getRole() {
         return role;
     }
diff --git 
a/api/src/main/java/org/apache/flink/agents/api/event/ChatRequestEvent.java 
b/api/src/main/java/org/apache/flink/agents/api/event/ChatRequestEvent.java
new file mode 100644
index 0000000..e5277d3
--- /dev/null
+++ b/api/src/main/java/org/apache/flink/agents/api/event/ChatRequestEvent.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.api.event;
+
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Event requesting a chat with the chat model resource named {@code model}.
+ *
+ * <p>The messages are copied into a new mutable list on construction, so the caller's list is
+ * never modified by downstream processing and may itself be immutable.
+ */
+public class ChatRequestEvent extends Event {
+    private final String model;
+    private final List<ChatMessage> messages;
+
+    /**
+     * Creates a chat request.
+     *
+     * @param model name of the chat model resource to chat with; must not be null
+     * @param messages the chat messages to send to the model; must not be null
+     */
+    public ChatRequestEvent(String model, List<ChatMessage> messages) {
+        this.model = Objects.requireNonNull(model, "model cannot be null");
+        this.messages = new ArrayList<>(Objects.requireNonNull(messages, "messages cannot be null"));
+    }
+
+    /** Returns the name of the chat model resource to chat with. */
+    public String getModel() {
+        return model;
+    }
+
+    /** Returns the chat messages of this request. */
+    public List<ChatMessage> getMessages() {
+        return messages;
+    }
+}
diff --git 
a/api/src/main/java/org/apache/flink/agents/api/event/ChatResponseEvent.java 
b/api/src/main/java/org/apache/flink/agents/api/event/ChatResponseEvent.java
new file mode 100644
index 0000000..9041e5a
--- /dev/null
+++ b/api/src/main/java/org/apache/flink/agents/api/event/ChatResponseEvent.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.api.event;
+
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+
+import java.util.UUID;
+
+public class ChatResponseEvent extends Event {
+    private final UUID requestId;
+    private final ChatMessage response;
+
+    public ChatResponseEvent(UUID requestId, ChatMessage response) {
+        this.requestId = requestId;
+        this.response = response;
+    }
+
+    public UUID getRequestId() {
+        return requestId;
+    }
+
+    public ChatMessage getResponse() {
+        return response;
+    }
+}
diff --git 
a/api/src/main/java/org/apache/flink/agents/api/event/ToolRequestEvent.java 
b/api/src/main/java/org/apache/flink/agents/api/event/ToolRequestEvent.java
index e39db7f..889c525 100644
--- a/api/src/main/java/org/apache/flink/agents/api/event/ToolRequestEvent.java
+++ b/api/src/main/java/org/apache/flink/agents/api/event/ToolRequestEvent.java
@@ -18,28 +18,29 @@
 
 package org.apache.flink.agents.api.event;
 
+import org.apache.flink.agents.api.Event;
+
+import java.util.List;
 import java.util.Map;
-import java.util.Objects;
 
 /** Event representing a tool call request */
-public class ToolRequestEvent {
-
-    private final String tool;
-    private final Map<String, Object> kwargs;
+public class ToolRequestEvent extends Event {
+    private final String model;
+    private final List<Map<String, Object>> toolCalls;
     private final long timestamp;
 
-    public ToolRequestEvent(String tool, Map<String, Object> kwargs) {
-        this.tool = Objects.requireNonNull(tool, "tool name cannot be null");
-        this.kwargs = Objects.requireNonNull(kwargs, "kwargs cannot be null");
+    public ToolRequestEvent(String model, List<Map<String, Object>> toolCalls) 
{
+        this.model = model;
+        this.toolCalls = toolCalls;
         this.timestamp = System.currentTimeMillis();
     }
 
-    public String getTool() {
-        return tool;
+    public String getModel() {
+        return model;
     }
 
-    public Map<String, Object> getKwargs() {
-        return kwargs;
+    public List<Map<String, Object>> getToolCalls() {
+        return toolCalls;
     }
 
     public long getTimestamp() {
@@ -49,11 +50,11 @@ public class ToolRequestEvent {
     @Override
     public String toString() {
         return "ToolRequestEvent{"
-                + "tool='"
-                + tool
+                + "model='"
+                + model
                 + '\''
-                + ", kwargs="
-                + kwargs
+                + ", toolCalls="
+                + toolCalls
                 + ", timestamp="
                 + timestamp
                 + '}';
diff --git 
a/api/src/main/java/org/apache/flink/agents/api/event/ToolResponseEvent.java 
b/api/src/main/java/org/apache/flink/agents/api/event/ToolResponseEvent.java
index 5115c69..c4d874d 100644
--- a/api/src/main/java/org/apache/flink/agents/api/event/ToolResponseEvent.java
+++ b/api/src/main/java/org/apache/flink/agents/api/event/ToolResponseEvent.java
@@ -18,46 +18,60 @@
 
 package org.apache.flink.agents.api.event;
 
-import java.util.Objects;
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.api.tools.ToolResponse;
 
-/** Event representing a result from tool call */
-public class ToolResponseEvent {
+import java.util.Map;
+import java.util.UUID;
 
-    private final ToolRequestEvent request;
-    private final Object response;
-    private final boolean success;
-    private final String error;
+/** Event representing a result from tool call */
+public class ToolResponseEvent extends Event {
+    private final UUID requestId;
+    private final Map<String, ToolResponse> responses;
+    private final Map<String, String> externalIds;
+    private final Map<String, Boolean> success;
+    private final Map<String, String> error;
     private final long timestamp;
 
-    public ToolResponseEvent(ToolRequestEvent request, Object response) {
-        this.request = Objects.requireNonNull(request, "request cannot be 
null");
-        this.response = response;
-        this.success = true;
-        this.error = null;
+    public ToolResponseEvent(
+            UUID requestId,
+            Map<String, ToolResponse> responses,
+            Map<String, Boolean> success,
+            Map<String, String> error,
+            Map<String, String> externalIds) {
+        this.requestId = requestId;
+        this.responses = responses;
+        this.success = success;
+        this.error = error;
+        this.externalIds = externalIds;
         this.timestamp = System.currentTimeMillis();
     }
 
-    public ToolResponseEvent(ToolRequestEvent request, String error) {
-        this.request = Objects.requireNonNull(request, "request cannot be 
null");
-        this.response = null;
-        this.success = false;
-        this.error = Objects.requireNonNull(error, "error cannot be null");
-        this.timestamp = System.currentTimeMillis();
+    public ToolResponseEvent(
+            UUID requestId,
+            Map<String, ToolResponse> responses,
+            Map<String, Boolean> success,
+            Map<String, String> error) {
+        this(requestId, responses, success, error, Map.of());
+    }
+
+    public UUID getRequestId() {
+        return requestId;
     }
 
-    public ToolRequestEvent getRequest() {
-        return request;
+    public Map<String, ToolResponse> getResponses() {
+        return responses;
     }
 
-    public Object getResponse() {
-        return response;
+    public Map<String, String> getExternalIds() {
+        return externalIds;
     }
 
-    public boolean isSuccess() {
+    public Map<String, Boolean> getSuccess() {
         return success;
     }
 
-    public String getError() {
+    public Map<String, String> getError() {
         return error;
     }
 
@@ -67,27 +81,19 @@ public class ToolResponseEvent {
 
     @Override
     public String toString() {
-        if (success) {
-            return "ToolResponseEvent{"
-                    + "request="
-                    + request
-                    + ", response="
-                    + response
-                    + ", success=true"
-                    + ", timestamp="
-                    + timestamp
-                    + '}';
-        } else {
-            return "ToolResponseEvent{"
-                    + "request="
-                    + request
-                    + ", error='"
-                    + error
-                    + '\''
-                    + ", success=false"
-                    + ", timestamp="
-                    + timestamp
-                    + '}';
-        }
+        return "ToolResponseEvent{"
+                + "requestId="
+                + requestId
+                + ", responses="
+                + responses
+                + ", success="
+                + success
+                + ", error="
+                + error
+                + ", externalIds="
+                + externalIds
+                + ", timestamp="
+                + timestamp
+                + '}';
     }
 }
diff --git 
a/examples/src/main/java/org/apache/flink/agents/examples/AgentWithResource.java
 
b/examples/src/main/java/org/apache/flink/agents/examples/AgentWithResource.java
index f6d5d97..27ee17b 100644
--- 
a/examples/src/main/java/org/apache/flink/agents/examples/AgentWithResource.java
+++ 
b/examples/src/main/java/org/apache/flink/agents/examples/AgentWithResource.java
@@ -29,18 +29,18 @@ import 
org.apache.flink.agents.api.chat.messages.ChatMessage;
 import org.apache.flink.agents.api.chat.messages.MessageRole;
 import org.apache.flink.agents.api.chat.model.BaseChatModelSetup;
 import org.apache.flink.agents.api.context.RunnerContext;
+import org.apache.flink.agents.api.event.ChatRequestEvent;
+import org.apache.flink.agents.api.event.ChatResponseEvent;
 import org.apache.flink.agents.api.prompt.Prompt;
 import org.apache.flink.agents.api.resource.Resource;
 import org.apache.flink.agents.api.resource.ResourceDescriptor;
 import org.apache.flink.agents.api.resource.ResourceType;
-import org.apache.flink.agents.api.tools.BaseTool;
-import org.apache.flink.agents.api.tools.ToolParameters;
-import org.apache.flink.agents.api.tools.ToolResponse;
 
-import java.util.Collections;
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.StringJoiner;
 import java.util.function.BiFunction;
 
 public class AgentWithResource extends Agent {
@@ -69,30 +69,48 @@ public class AgentWithResource extends Agent {
 
         @Override
         public ChatMessage chat(List<ChatMessage> messages, Map<String, Object> parameters) {
-            Prompt prompt = (Prompt) getResource.apply((String) this.prompt, ResourceType.PROMPT);
-            BaseTool tool = (BaseTool) getResource.apply(this.tools.get(0), ResourceType.TOOL);
-            Map<String, Object> params = new HashMap<>();
-            params.put("a", 1);
-            params.put("b", 2);
-            params.put("operation", "add");
-            ToolParameters toolParameters = new ToolParameters(params);
-            ToolResponse result = tool.call(toolParameters);
-            String output =
-                    String.format(
-                            "Prompt: %s, input: %s, endpoint: %s, topP: %s, topK: %s, tool call result: %s",
-                            prompt.formatString(new HashMap<>()),
-                            messages.get(0).getContent(),
-                            endpoint,
-                            topP,
-                            topK,
-                            result.getResult());
-            return new ChatMessage(MessageRole.ASSISTANT, output);
+            if (messages.size() == 1) {
+                // Plain maps rather than a double-brace HashMap subclass, which would capture
+                // the enclosing instance (it reads `tools`) and be stored in short-term memory
+                // and events by the built-in chat model action.
+                Map<String, Object> function = new HashMap<>();
+                function.put("name", tools.get(0));
+                function.put("arguments", Map.of("a", 1, "b", 2, "operation", "add"));
+                Map<String, Object> toolCall = new HashMap<>();
+                toolCall.put("id", "1");
+                toolCall.put("function", function);
+                return new ChatMessage(
+                        MessageRole.ASSISTANT,
+                        String.format("I will call tool %s", tools.get(0)),
+                        List.of(toolCall));
+            } else {
+                StringJoiner content = new StringJoiner("\n");
+                content.add(
+                        String.format("endpoint: %s, topP: %s, topK: %s", endpoint, topP, topK));
+
+                Map<String, String> arguments = new HashMap<>();
+                for (ChatMessage message : messages) {
+                    for (Map.Entry<String, Object> entry : message.getExtraArgs().entrySet()) {
+                        arguments.put(entry.getKey(), entry.getValue().toString());
+                    }
+                }
+                Prompt prompt =
+                        (Prompt) getResource.apply((String) this.prompt, ResourceType.PROMPT);
+                List<ChatMessage> formatMessages =
+                        prompt.formatMessages(MessageRole.USER, arguments);
+                content.add("Prompt: " + formatMessages.get(0).getContent());
+
+                for (ChatMessage message : messages) {
+                    content.add(message.getContent());
+                }
+                return new ChatMessage(MessageRole.ASSISTANT, content.toString());
+            }
         }
     }
 
     @org.apache.flink.agents.api.annotation.Prompt
     public static Prompt myPrompt() {
-        return new Prompt("This is a test prompt");
+        return new Prompt("What is {a} + {b}?");
     }
 
     @ChatModelSetup
@@ -128,13 +146,22 @@ public class AgentWithResource extends Agent {
 
     @Action(listenEvents = {InputEvent.class})
     public static void process(InputEvent event, RunnerContext ctx) throws Exception {
-        BaseChatModelSetup chatModel =
-                (BaseChatModelSetup) ctx.getResource("myChatModel", ResourceType.CHAT_MODEL);
-        ChatMessage response =
-                chatModel.chat(
-                        Collections.singletonList(
-                                new ChatMessage(MessageRole.USER, (String) event.getInput())),
-                        Collections.emptyMap());
-        ctx.sendEvent(new OutputEvent(response.getContent()));
+        Map<String, Integer> input = (Map<String, Integer>) event.getInput();
+
+        ChatMessage message =
+                new ChatMessage(
+                        MessageRole.USER,
+                        String.format("What is %s + %s?", input.get("a"), input.get("b")),
+                        Map.of("a", input.get("a"), "b", input.get("b")));
+
+        List<ChatMessage> messages = new ArrayList<>();
+        messages.add(message);
+        ctx.sendEvent(new ChatRequestEvent("myChatModel", messages));
+    }
+
+    @Action(listenEvents = {ChatResponseEvent.class})
+    public static void output(ChatResponseEvent event, RunnerContext ctx) throws Exception {
+        String output = event.getResponse().getContent();
+        ctx.sendEvent(new OutputEvent(output));
     }
 }
diff --git 
a/examples/src/main/java/org/apache/flink/agents/examples/AgentWithResourceExample.java
 
b/examples/src/main/java/org/apache/flink/agents/examples/AgentWithResourceExample.java
index 90338bb..8ca3c75 100644
--- 
a/examples/src/main/java/org/apache/flink/agents/examples/AgentWithResourceExample.java
+++ 
b/examples/src/main/java/org/apache/flink/agents/examples/AgentWithResourceExample.java
@@ -20,8 +20,12 @@ package org.apache.flink.agents.examples;
 import org.apache.flink.agents.api.AgentsExecutionEnvironment;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.DataStreamSource;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 
+import java.util.HashMap;
+import java.util.Map;
+
 /**
  * Example to test MemoryObject in a complete Java Flink execution 
environment. This job triggers
  * the {@link MemoryObjectAgent} to test storing and retrieving complex data 
structures.
@@ -33,8 +37,10 @@ public class AgentWithResourceExample {
         StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
         env.setParallelism(1);
 
-        // Use two different keys (1 and 2) to show that memory is isolated 
per key.
-        DataStream<String> inputStream = env.fromElements("This is the test 
input");
+        Map<String, Integer> element = new HashMap<>();
+        element.put("a", 1);
+        element.put("b", 2);
+        DataStreamSource<Map<String, Integer>> inputStream = 
env.fromElements(element);
 
         // Create agents execution environment
         AgentsExecutionEnvironment agentsEnv =
@@ -43,7 +49,10 @@ public class AgentWithResourceExample {
         // Apply agent to the DataStream and use the integer itself as the key
         DataStream<Object> outputStream =
                 agentsEnv
-                        .fromDataStream(inputStream, (KeySelector<String, 
String>) value -> value)
+                        .fromDataStream(
+                                inputStream,
+                                (KeySelector<Map<String, Integer>, Integer>)
+                                        value -> value.get("a"))
                         .apply(new AgentWithResource())
                         .toDataStream();
 
diff --git a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java 
b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java
index ade3347..035b089 100644
--- a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java
+++ b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java
@@ -30,6 +30,8 @@ import org.apache.flink.agents.api.resource.ResourceType;
 import org.apache.flink.agents.api.resource.SerializableResource;
 import org.apache.flink.agents.api.tools.ToolMetadata;
 import org.apache.flink.agents.plan.actions.Action;
+import org.apache.flink.agents.plan.actions.ChatModelAction;
+import org.apache.flink.agents.plan.actions.ToolCallAction;
 import org.apache.flink.agents.plan.resourceprovider.JavaResourceProvider;
 import 
org.apache.flink.agents.plan.resourceprovider.JavaSerializableResourceProvider;
 import org.apache.flink.agents.plan.resourceprovider.ResourceProvider;
@@ -236,7 +238,21 @@ public class AgentPlan implements Serializable {
         }
     }
 
+    private void addBuiltAction(Action action) {
+        // Add to actions map
+        actions.put(action.getName(), action);
+
+        // Add to actionsByEvent map
+        for (String eventTypeName : action.getListenEventTypes()) {
+            actionsByEvent.computeIfAbsent(eventTypeName, k -> new 
ArrayList<>()).add(action);
+        }
+    }
+
     private void extractActionsFromAgent(Agent agent) throws Exception {
+        // Add built-in actions
+        addBuiltAction(ChatModelAction.getChatModelAction());
+        addBuiltAction(ToolCallAction.getToolCallAction());
+
         // Scan the agent class for methods annotated with @Action
         Class<?> agentClass = agent.getClass();
         for (Method method : agentClass.getDeclaredMethods()) {
diff --git 
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
new file mode 100644
index 0000000..f20fdc1
--- /dev/null
+++ 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.plan.actions;
+
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.chat.model.BaseChatModelSetup;
+import org.apache.flink.agents.api.context.MemoryObject;
+import org.apache.flink.agents.api.context.RunnerContext;
+import org.apache.flink.agents.api.event.ChatRequestEvent;
+import org.apache.flink.agents.api.event.ChatResponseEvent;
+import org.apache.flink.agents.api.event.ToolRequestEvent;
+import org.apache.flink.agents.api.event.ToolResponseEvent;
+import org.apache.flink.agents.api.resource.ResourceType;
+import org.apache.flink.agents.api.tools.ToolResponse;
+import org.apache.flink.agents.plan.JavaFunction;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+
+/**
+ * Built-in action for processing chat requests and tool call results.
+ *
+ * <p>While a request is in flight, its conversation (the input messages plus every assistant
+ * tool-call message and tool result) is kept in short-term memory under {@code
+ * _TOOL_CALL_CONTEXT}, keyed by the id of the initial {@link ChatRequestEvent}. Each emitted
+ * {@link ToolRequestEvent} is mapped back to that initial request id and the model name under
+ * {@code _TOOL_REQUEST_EVENT_CONTEXT}.
+ */
+public class ChatModelAction {
+    private static final String TOOL_CALL_CONTEXT = "_TOOL_CALL_CONTEXT";
+    private static final String TOOL_REQUEST_EVENT_CONTEXT = "_TOOL_REQUEST_EVENT_CONTEXT";
+    private static final String INITIAL_REQUEST_ID = "initialRequestId";
+    private static final String MODEL = "model";
+
+    /**
+     * Creates the built-in action that listens to {@link ChatRequestEvent} and {@link
+     * ToolResponseEvent} and dispatches both to {@link #processChatRequestOrToolResponse}.
+     */
+    public static Action getChatModelAction() throws Exception {
+        return new Action(
+                "chat_model_action",
+                new JavaFunction(
+                        ChatModelAction.class,
+                        "processChatRequestOrToolResponse",
+                        new Class[] {Event.class, RunnerContext.class}),
+                List.of(ChatRequestEvent.class.getName(), ToolResponseEvent.class.getName()));
+    }
+
+    /**
+     * Chat with chat model.
+     *
+     * <p>If there are no tool calls in the chat model response, send the chat response event.
+     * Otherwise, generate a tool request event and save the tool call context in memory.
+     *
+     * <p>The first time a conversation is saved, {@code messages} is copied into a new mutable
+     * list, so the caller's list (which may be immutable, e.g. {@code List.of(...)}) is never
+     * modified.
+     *
+     * @param initialRequestId The request id of the initial chat request event.
+     * @param model The name of the chat model resource to chat with.
+     * @param messages The chat messages as llm input.
+     * @param ctx The runner context this function executed in.
+     * @throws IllegalStateException if the chat model returns {@code null}.
+     */
+    @SuppressWarnings("unchecked")
+    public static void chat(
+            UUID initialRequestId, String model, List<ChatMessage> messages, RunnerContext ctx)
+            throws Exception {
+        BaseChatModelSetup chatModel =
+                (BaseChatModelSetup) ctx.getResource(model, ResourceType.CHAT_MODEL);
+
+        ChatMessage response = chatModel.chat(messages, Map.of());
+        if (response == null) {
+            throw new IllegalStateException(
+                    String.format("Chat model %s returned no response", model));
+        }
+        MemoryObject stm = ctx.getShortTermMemory();
+
+        List<Map<String, Object>> toolCalls = response.getToolCalls();
+        if (toolCalls != null && !toolCalls.isEmpty()) {
+            // Save the conversation so far plus the assistant message requesting the tools; the
+            // tool results are appended to this same list when the ToolResponseEvent arrives.
+            Map<UUID, Object> toolCallContext = loadContext(stm, TOOL_CALL_CONTEXT);
+            List<ChatMessage> messageContext =
+                    (List<ChatMessage>)
+                            toolCallContext.computeIfAbsent(
+                                    initialRequestId, id -> new ArrayList<>(messages));
+            messageContext.add(response);
+            stm.set(TOOL_CALL_CONTEXT, toolCallContext);
+
+            ToolRequestEvent toolRequestEvent = new ToolRequestEvent(model, toolCalls);
+
+            // Remember which initial request and model this tool request belongs to.
+            Map<UUID, Object> toolRequestEventContext =
+                    loadContext(stm, TOOL_REQUEST_EVENT_CONTEXT);
+            toolRequestEventContext.put(
+                    toolRequestEvent.getId(),
+                    Map.of(INITIAL_REQUEST_ID, initialRequestId, MODEL, model));
+            stm.set(TOOL_REQUEST_EVENT_CONTEXT, toolRequestEventContext);
+
+            ctx.sendEvent(toolRequestEvent);
+        } else {
+            // The conversation is finished: drop its tool call context, if any.
+            Map<UUID, Object> toolCallContext = loadContext(stm, TOOL_CALL_CONTEXT);
+            if (toolCallContext.remove(initialRequestId) != null) {
+                stm.set(TOOL_CALL_CONTEXT, toolCallContext);
+            }
+
+            ctx.sendEvent(new ChatResponseEvent(initialRequestId, response));
+        }
+    }
+
+    /**
+     * Built-in action for processing chat requests and tool call results.
+     *
+     * <p>This action listens to {@link ChatRequestEvent} and sends {@link ChatResponseEvent}. If
+     * the chat model response contains tool calls, it sends a {@link ToolRequestEvent} and feeds
+     * the corresponding {@link ToolResponseEvent} back to the chat model.
+     *
+     * @param event Event this action listened to, must be {@link ChatRequestEvent} or {@link
+     *     ToolResponseEvent}
+     * @param ctx The runner context this action executed in.
+     * @throws IllegalArgumentException if the event is of any other type.
+     * @throws IllegalStateException if a tool response arrives without a saved request context.
+     */
+    public static void processChatRequestOrToolResponse(Event event, RunnerContext ctx)
+            throws Exception {
+        if (event instanceof ChatRequestEvent) {
+            ChatRequestEvent chatRequestEvent = (ChatRequestEvent) event;
+            chat(
+                    chatRequestEvent.getId(),
+                    chatRequestEvent.getModel(),
+                    chatRequestEvent.getMessages(),
+                    ctx);
+        } else if (event instanceof ToolResponseEvent) {
+            processToolResponse((ToolResponseEvent) event, ctx);
+        } else {
+            throw new IllegalArgumentException(String.format("Unexpected type event %s", event));
+        }
+    }
+
+    /**
+     * Appends the results of a {@link ToolResponseEvent} to the saved conversation of its initial
+     * chat request and asks the chat model again.
+     */
+    @SuppressWarnings("unchecked")
+    private static void processToolResponse(
+            ToolResponseEvent toolResponseEvent, RunnerContext ctx) throws Exception {
+        MemoryObject stm = ctx.getShortTermMemory();
+        UUID toolRequestId = toolResponseEvent.getRequestId();
+
+        // Resolve (and consume) the context saved when the tool request was emitted.
+        Map<UUID, Object> toolRequestEventContext = loadContext(stm, TOOL_REQUEST_EVENT_CONTEXT);
+        Map<String, Object> context =
+                (Map<String, Object>) toolRequestEventContext.remove(toolRequestId);
+        if (context == null) {
+            throw new IllegalStateException(
+                    String.format("No tool request context found for request %s", toolRequestId));
+        }
+        stm.set(TOOL_REQUEST_EVENT_CONTEXT, toolRequestEventContext);
+        UUID initialRequestId = (UUID) context.get(INITIAL_REQUEST_ID);
+        String model = (String) context.get(MODEL);
+
+        Map<UUID, Object> toolCallContext = loadContext(stm, TOOL_CALL_CONTEXT);
+        List<ChatMessage> messages = (List<ChatMessage>) toolCallContext.get(initialRequestId);
+        if (messages == null) {
+            throw new IllegalStateException(
+                    String.format("No tool call context found for request %s", initialRequestId));
+        }
+
+        Map<String, ToolResponse> responses = toolResponseEvent.getResponses();
+        Map<String, String> externalIds = toolResponseEvent.getExternalIds();
+        for (Map.Entry<String, ToolResponse> entry : responses.entrySet()) {
+            String toolCallId = entry.getKey();
+            Map<String, Object> extraArgs = new HashMap<>();
+            if (externalIds != null && externalIds.containsKey(toolCallId)) {
+                extraArgs.put("externalId", externalIds.get(toolCallId));
+            }
+            messages.add(
+                    new ChatMessage(
+                            MessageRole.TOOL,
+                            toolMessageContent(toolResponseEvent, toolCallId, entry.getValue()),
+                            extraArgs));
+        }
+        // Persist the extended conversation before calling the chat model again.
+        stm.set(TOOL_CALL_CONTEXT, toolCallContext);
+
+        chat(initialRequestId, model, messages, ctx);
+    }
+
+    /**
+     * Returns the TOOL message content for one tool call: the tool result if both the event and
+     * the {@link ToolResponse} report success, otherwise the error. A missing success flag counts
+     * as failure (instead of an unboxing {@link NullPointerException}), and a missing response
+     * falls back to the error recorded in the event.
+     */
+    private static String toolMessageContent(
+            ToolResponseEvent event, String toolCallId, ToolResponse response) {
+        Map<String, Boolean> success = event.getSuccess();
+        boolean succeeded = success != null && Boolean.TRUE.equals(success.get(toolCallId));
+        if (response == null) {
+            Map<String, String> errors = event.getError();
+            return String.valueOf(errors == null ? null : errors.get(toolCallId));
+        }
+        return succeeded && response.isSuccess()
+                ? String.valueOf(response.getResult())
+                : String.valueOf(response.getError());
+    }
+
+    /**
+     * Returns the map stored in short-term memory under {@code key}, or a new empty map if the key
+     * does not exist yet.
+     */
+    @SuppressWarnings("unchecked")
+    private static Map<UUID, Object> loadContext(MemoryObject stm, String key) throws Exception {
+        if (stm.isExist(key)) {
+            return (Map<UUID, Object>) stm.get(key).getValue();
+        }
+        return new HashMap<>();
+    }
+}
diff --git 
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ToolCallAction.java 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ToolCallAction.java
new file mode 100644
index 0000000..7eafd3c
--- /dev/null
+++ 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ToolCallAction.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.plan.actions;
+
+import org.apache.flink.agents.api.context.RunnerContext;
+import org.apache.flink.agents.api.event.ToolRequestEvent;
+import org.apache.flink.agents.api.event.ToolResponseEvent;
+import org.apache.flink.agents.api.resource.ResourceType;
+import org.apache.flink.agents.api.tools.BaseTool;
+import org.apache.flink.agents.api.tools.ToolParameters;
+import org.apache.flink.agents.api.tools.ToolResponse;
+import org.apache.flink.agents.plan.JavaFunction;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** Built-in action for processing tool call. */
+public class ToolCallAction {
+    public static Action getToolCallAction() throws Exception {
+        return new Action(
+                "tool_call_action",
+                new JavaFunction(
+                        ToolCallAction.class,
+                        "processToolRequest",
+                        new Class[] {ToolRequestEvent.class, 
RunnerContext.class}),
+                List.of(ToolRequestEvent.class.getName()));
+    }
+
+    @SuppressWarnings("unchecked")
+    public static void processToolRequest(ToolRequestEvent event, 
RunnerContext ctx) {
+        Map<String, Boolean> success = new HashMap<>();
+        Map<String, String> error = new HashMap<>();
+        Map<String, ToolResponse> responses = new HashMap<>();
+        Map<String, String> externalIds = new HashMap<>();
+        for (Map<String, Object> toolCall : event.getToolCalls()) {
+            String id = String.valueOf(toolCall.get("id"));
+            Map<String, Object> function = (Map<String, Object>) 
toolCall.get("function");
+            String name = (String) function.get("name");
+            Map<String, Object> arguments = (Map<String, Object>) 
function.get("arguments");
+
+            if (toolCall.containsKey("original_id")) {
+                externalIds.put(id, (String) toolCall.get("original_id"));
+            }
+
+            BaseTool tool = null;
+            try {
+                tool = (BaseTool) ctx.getResource(name, ResourceType.TOOL);
+            } catch (Exception e) {
+                success.put(id, false);
+                responses.put(
+                        id, ToolResponse.error(String.format("Tool %s does not 
exist.", name)));
+                error.put(id, e.getMessage());
+            }
+
+            if (tool != null) {
+                try {
+                    ToolResponse response = tool.call(new 
ToolParameters(arguments));
+                    success.put(id, true);
+                    responses.put(id, response);
+                } catch (Exception e) {
+                    success.put(id, false);
+                    responses.put(
+                            id, ToolResponse.error(String.format("Tool %s 
execute failed.", name)));
+                    error.put(id, e.getMessage());
+                }
+            }
+        }
+        ctx.sendEvent(new ToolResponseEvent(event.getId(), responses, success, 
error, externalIds));
+    }
+}
diff --git a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java 
b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java
index 2a5d024..e61dbb9 100644
--- a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java
+++ b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java
@@ -137,7 +137,7 @@ public class AgentPlanTest {
         AgentPlan agentPlan = new AgentPlan(agent);
 
         // Verify that actions were collected correctly
-        assertThat(agentPlan.getActions().size()).isEqualTo(2);
+        assertThat(agentPlan.getActions().size()).isEqualTo(4);
         assertThat(agentPlan.getActions()).containsKey("handleInputEvent");
         assertThat(agentPlan.getActions()).containsKey("handleMultipleEvents");
 
@@ -164,7 +164,7 @@ public class AgentPlanTest {
         assertThat(multiAction.getExec()).isInstanceOf(JavaFunction.class);
 
         // Verify actionsByEvent mapping
-        assertThat(agentPlan.getActionsByEvent().size()).isEqualTo(3);
+        assertThat(agentPlan.getActionsByEvent().size()).isEqualTo(6);
 
         // Check InputEvent mapping
         List<Action> inputEventActions =
@@ -199,8 +199,8 @@ public class AgentPlanTest {
         AgentPlan agentPlan = new AgentPlan(emptyAgent);
 
         // Verify that no actions were collected
-        assertThat(agentPlan.getActions().size()).isEqualTo(0);
-        assertThat(agentPlan.getActionsByEvent().size()).isEqualTo(0);
+        assertThat(agentPlan.getActions().size()).isEqualTo(2);
+        assertThat(agentPlan.getActionsByEvent().size()).isEqualTo(3);
     }
 
     @Test
diff --git a/python/flink_agents/plan/actions/chat_model_action.py 
b/python/flink_agents/plan/actions/chat_model_action.py
index 656453d..967d0d9 100644
--- a/python/flink_agents/plan/actions/chat_model_action.py
+++ b/python/flink_agents/plan/actions/chat_model_action.py
@@ -115,10 +115,6 @@ def process_chat_request_or_tool_response(event: Event, 
ctx: RunnerContext) -> N
     """
     short_term_memory = ctx.get_short_term_memory()
     if isinstance(event, ChatRequestEvent):
-        cast(
-            "BaseChatModelSetup", ctx.get_resource(event.model, 
ResourceType.CHAT_MODEL)
-        )
-
         chat(
             initial_request_id=event.id,
             model=event.model,
diff --git 
a/python/flink_agents/plan/tests/compatibility/create_python_agent_plan_from_json.py
 
b/python/flink_agents/plan/tests/compatibility/create_python_agent_plan_from_json.py
index 34274c9..ff940d8 100644
--- 
a/python/flink_agents/plan/tests/compatibility/create_python_agent_plan_from_json.py
+++ 
b/python/flink_agents/plan/tests/compatibility/create_python_agent_plan_from_json.py
@@ -32,7 +32,7 @@ if __name__ == "__main__":
     agent_plan = AgentPlan.model_validate_json(java_plan_json)
     actions = agent_plan.actions
 
-    assert len(actions) == 2
+    assert len(actions) == 4
 
     event = "org.apache.flink.agents.api.Event"
     input_event = "org.apache.flink.agents.api.InputEvent"
@@ -72,7 +72,7 @@ if __name__ == "__main__":
 
     # check actions_by_event
     actions_by_event = agent_plan.actions_by_event
-    assert len(actions_by_event) == 2
+    assert len(actions_by_event) == 5
 
     assert input_event in actions_by_event
     assert sorted(actions_by_event[input_event]) == ["firstAction", 
"secondAction"]


Reply via email to