This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git

commit 4012835456310317273ce4c4c61fdc49af213ca9
Author: WenjinXie <[email protected]>
AuthorDate: Fri Dec 26 11:12:45 2025 +0800

    [plan] Support error handling strategy for chat action.
    
    fix
    
    [plan] Support error handling strategy for chat action in java.
    
    modify default max retries
    
    not manually clear sensory memory.
    
    fix
    
    fix
---
 .../org/apache/flink/agents/api/agents/Agent.java  |  18 +++
 ...nfigOptions.java => AgentExecutionOptions.java} |  13 +-
 .../flink/agents/api/agents/OutputSchema.java      | 134 ++++++++++++++++
 .../apache/flink/agents/api/agents/ReActAgent.java | 169 ++-------------------
 .../flink/agents/api/event/ChatRequestEvent.java   |  17 ++-
 .../flink/agents/api/agents/ReActAgentTest.java    |   5 +-
 .../agents/integration/test/ReActAgentTest.java    |  10 +-
 plan/pom.xml                                       |  16 ++
 .../flink/agents/plan/actions/ChatModelAction.java | 148 ++++++++++++++----
 python/flink_agents/api/agents/agent.py            |   2 +
 python/flink_agents/api/agents/react_agent.py      | 121 ++-------------
 python/flink_agents/api/agents/types.py            |  67 ++++++++
 python/flink_agents/api/core_options.py            |  24 +++
 python/flink_agents/api/events/chat_event.py       |   4 +
 .../e2e_tests_integration/react_agent_test.py      |  10 +-
 .../flink_agents/plan/actions/chat_model_action.py |  94 +++++++++---
 .../flink_agents/plan/tests/resources/action.json  |   2 +-
 17 files changed, 519 insertions(+), 335 deletions(-)

diff --git a/api/src/main/java/org/apache/flink/agents/api/agents/Agent.java 
b/api/src/main/java/org/apache/flink/agents/api/agents/Agent.java
index ae046b0..b05dc09 100644
--- a/api/src/main/java/org/apache/flink/agents/api/agents/Agent.java
+++ b/api/src/main/java/org/apache/flink/agents/api/agents/Agent.java
@@ -109,4 +109,22 @@ public class Agent {
         }
         return this;
     }
+
+    public enum ErrorHandlingStrategy {
+        FAIL("fail"),
+        RETRY("retry"),
+        IGNORE("ignore");
+
+        private final String value;
+
+        ErrorHandlingStrategy(String value) {
+            this.value = value;
+        }
+
+        public String getValue() {
+            return value;
+        }
+    }
+
+    public static String STRUCTURED_OUTPUT = "structured_output";
 }
diff --git 
a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgentConfigOptions.java
 
b/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java
similarity index 72%
rename from 
api/src/main/java/org/apache/flink/agents/api/agents/ReActAgentConfigOptions.java
rename to 
api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java
index d3edee7..64880a5 100644
--- 
a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgentConfigOptions.java
+++ 
b/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java
@@ -20,12 +20,13 @@ package org.apache.flink.agents.api.agents;
 
 import org.apache.flink.agents.api.configuration.ConfigOption;
 
-/** Config Options for {@link ReActAgent}. */
-public class ReActAgentConfigOptions {
-    /** The option specifies the error handling strategy for react agent. */
-    public static final ConfigOption<ReActAgent.ErrorHandlingStrategy> 
ERROR_HANDLING_STRATEGY =
+public class AgentExecutionOptions {
+    public static final ConfigOption<Agent.ErrorHandlingStrategy> 
ERROR_HANDLING_STRATEGY =
             new ConfigOption<>(
                     "error-handling-strategy",
-                    ReActAgent.ErrorHandlingStrategy.class,
-                    ReActAgent.ErrorHandlingStrategy.FAIL);
+                    Agent.ErrorHandlingStrategy.class,
+                    Agent.ErrorHandlingStrategy.FAIL);
+
+    public static final ConfigOption<Integer> MAX_RETRIES =
+            new ConfigOption<>("max-retries", Integer.class, 3);
 }
diff --git 
a/api/src/main/java/org/apache/flink/agents/api/agents/OutputSchema.java 
b/api/src/main/java/org/apache/flink/agents/api/agents/OutputSchema.java
new file mode 100644
index 0000000..54fbcc3
--- /dev/null
+++ b/api/src/main/java/org/apache/flink/agents/api/agents/OutputSchema.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.api.agents;
+
+import com.fasterxml.jackson.core.JacksonException;
+import com.fasterxml.jackson.core.JsonGenerator;
+import com.fasterxml.jackson.core.JsonParser;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.DeserializationContext;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.SerializerProvider;
+import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
+import com.fasterxml.jackson.databind.annotation.JsonSerialize;
+import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
+import com.fasterxml.jackson.databind.ser.std.StdSerializer;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Helper class for {@link RowTypeInfo} serialization.
+ *
+ * <p>Currently, only rows whose fields are all basic types are supported.
+ */
+@VisibleForTesting
+@JsonSerialize(using = OutputSchema.OutputSchemaJsonSerializer.class)
+@JsonDeserialize(using = OutputSchema.OutputSchemaJsonDeserializer.class)
+public class OutputSchema {
+    private final RowTypeInfo schema;
+
+    public OutputSchema(RowTypeInfo schema) {
+        this.schema = schema;
+        for (TypeInformation<?> info : schema.getFieldTypes()) {
+            if (!info.isBasicType()) {
+                throw new IllegalArgumentException(
+                        "Currently, output schema only support row contains 
basic type.");
+            }
+        }
+    }
+
+    public RowTypeInfo getSchema() {
+        return schema;
+    }
+
+    public static class OutputSchemaJsonSerializer extends 
StdSerializer<OutputSchema> {
+
+        protected OutputSchemaJsonSerializer() {
+            super(OutputSchema.class);
+        }
+
+        @Override
+        public void serialize(
+                OutputSchema schema,
+                JsonGenerator jsonGenerator,
+                SerializerProvider serializerProvider)
+                throws IOException {
+            RowTypeInfo typeInfo = schema.getSchema();
+            jsonGenerator.writeStartObject();
+
+            jsonGenerator.writeFieldName("fieldNames");
+            jsonGenerator.writeStartArray();
+            for (String name : typeInfo.getFieldNames()) {
+                jsonGenerator.writeString(name);
+            }
+            jsonGenerator.writeEndArray();
+
+            // TODO: support type information which is not basic.
+            jsonGenerator.writeFieldName("types");
+            jsonGenerator.writeStartArray();
+            for (TypeInformation<?> info : typeInfo.getFieldTypes()) {
+                jsonGenerator.writeObject(info.getTypeClass());
+            }
+            jsonGenerator.writeEndArray();
+
+            jsonGenerator.writeEndObject();
+        }
+    }
+
+    public static class OutputSchemaJsonDeserializer extends 
StdDeserializer<OutputSchema> {
+        private static final ObjectMapper mapper = new ObjectMapper();
+
+        protected OutputSchemaJsonDeserializer() {
+            super(OutputSchema.class);
+        }
+
+        @Override
+        public OutputSchema deserialize(
+                JsonParser jsonParser, DeserializationContext 
deserializationContext)
+                throws IOException, JacksonException {
+            JsonNode node = jsonParser.getCodec().readTree(jsonParser);
+            List<String> fieldNames = new ArrayList<>();
+            node.get("fieldNames").forEach(fieldNameNode -> 
fieldNames.add(fieldNameNode.asText()));
+            List<TypeInformation<?>> types = new ArrayList<>();
+            node.get("types")
+                    .forEach(
+                            typeNode -> {
+                                try {
+                                    types.add(
+                                            BasicTypeInfo.getInfoFor(
+                                                    
mapper.treeToValue(typeNode, Class.class)));
+                                } catch (JsonProcessingException e) {
+                                    throw new RuntimeException(e);
+                                }
+                            });
+
+            return new OutputSchema(
+                    new RowTypeInfo(
+                            types.toArray(new TypeInformation[0]),
+                            fieldNames.toArray(new String[0])));
+        }
+    }
+}
diff --git 
a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java 
b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java
index c073f54..278e356 100644
--- a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java
+++ b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java
@@ -18,19 +18,9 @@
 
 package org.apache.flink.agents.api.agents;
 
-import com.fasterxml.jackson.core.JacksonException;
-import com.fasterxml.jackson.core.JsonGenerator;
-import com.fasterxml.jackson.core.JsonParser;
 import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.DeserializationContext;
 import com.fasterxml.jackson.databind.JsonMappingException;
-import com.fasterxml.jackson.databind.JsonNode;
 import com.fasterxml.jackson.databind.ObjectMapper;
-import com.fasterxml.jackson.databind.SerializerProvider;
-import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
-import com.fasterxml.jackson.databind.annotation.JsonSerialize;
-import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
-import com.fasterxml.jackson.databind.ser.std.StdSerializer;
 import org.apache.commons.lang3.ClassUtils;
 import org.apache.flink.agents.api.InputEvent;
 import org.apache.flink.agents.api.OutputEvent;
@@ -43,9 +33,6 @@ import org.apache.flink.agents.api.event.ChatResponseEvent;
 import org.apache.flink.agents.api.prompt.Prompt;
 import org.apache.flink.agents.api.resource.ResourceDescriptor;
 import org.apache.flink.agents.api.resource.ResourceType;
-import org.apache.flink.annotation.VisibleForTesting;
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.types.Row;
 import org.slf4j.Logger;
@@ -53,7 +40,6 @@ import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nullable;
 
-import java.io.IOException;
 import java.lang.reflect.Method;
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -106,16 +92,14 @@ public class ReActAgent extends Agent {
 
         try {
             Method method =
-                    this.getClass()
-                            .getMethod("stopAction", ChatResponseEvent.class, 
RunnerContext.class);
-            this.addAction(new Class[] {ChatResponseEvent.class}, method, 
actionConfig);
+                    this.getClass().getMethod("startAction", InputEvent.class, 
RunnerContext.class);
+            this.addAction(new Class[] {InputEvent.class}, method, 
actionConfig);
         } catch (NoSuchMethodException e) {
             throw new IllegalStateException(
                     "Can't find the method stopAction, this must be a bug.");
         }
     }
 
-    @Action(listenEvents = {InputEvent.class})
     public static void startAction(InputEvent event, RunnerContext ctx) {
         Object input = event.getInput();
 
@@ -177,151 +161,22 @@ public class ReActAgent extends Agent {
             inputMessages.addAll(0, instruct);
         }
 
-        ctx.sendEvent(new ChatRequestEvent(DEFAULT_CHAT_MODEL, inputMessages));
-    }
-
-    public static void stopAction(ChatResponseEvent event, RunnerContext ctx)
-            throws JsonProcessingException {
-        Object output = String.valueOf(event.getResponse().getContent());
-
         Object outputSchema = ctx.getActionConfigValue("output_schema");
 
-        if (outputSchema != null) {
-            ErrorHandlingStrategy strategy =
-                    
ctx.getConfig().get(ReActAgentConfigOptions.ERROR_HANDLING_STRATEGY);
-            try {
-                if (outputSchema instanceof Class) {
-                    output = mapper.readValue(String.valueOf(output), 
(Class<?>) outputSchema);
-                } else if (outputSchema instanceof OutputSchema) {
-                    RowTypeInfo info = ((OutputSchema) 
outputSchema).getSchema();
-                    Map<String, Object> fields =
-                            mapper.readValue(String.valueOf(output), 
Map.class);
-                    output = Row.withNames();
-                    for (String name : info.getFieldNames()) {
-                        ((Row) output).setField(name, fields.get(name));
-                    }
-                }
-            } catch (Exception e) {
-                if (strategy == ErrorHandlingStrategy.FAIL) {
-                    throw e;
-                } else if (strategy == ErrorHandlingStrategy.IGNORE) {
-                    LOG.warn(
-                            "The response of llm {} doesn't match schema 
constraint, ignoring.",
-                            output);
-                    return;
-                }
-            }
-        }
-
-        ctx.sendEvent(new OutputEvent(output));
-    }
-
-    /**
-     * Helper class for {@link RowTypeInfo} serialization.
-     *
-     * <p>Currently, only support row contains basic type.
-     */
-    @VisibleForTesting
-    @JsonSerialize(using = OutputSchemaJsonSerializer.class)
-    @JsonDeserialize(using = OutputSchemaJsonDeserializer.class)
-    public static class OutputSchema {
-        private final RowTypeInfo schema;
-
-        public OutputSchema(RowTypeInfo schema) {
-            this.schema = schema;
-            for (TypeInformation<?> info : schema.getFieldTypes()) {
-                if (!info.isBasicType()) {
-                    throw new IllegalArgumentException(
-                            "Currently, output schema only support row 
contains basic type.");
-                }
-            }
-        }
-
-        public RowTypeInfo getSchema() {
-            return schema;
-        }
-    }
-
-    public static class OutputSchemaJsonSerializer extends 
StdSerializer<OutputSchema> {
-
-        protected OutputSchemaJsonSerializer() {
-            super(OutputSchema.class);
-        }
-
-        @Override
-        public void serialize(
-                OutputSchema schema,
-                JsonGenerator jsonGenerator,
-                SerializerProvider serializerProvider)
-                throws IOException {
-            RowTypeInfo typeInfo = schema.getSchema();
-            jsonGenerator.writeStartObject();
-
-            jsonGenerator.writeFieldName("fieldNames");
-            jsonGenerator.writeStartArray();
-            for (String name : typeInfo.getFieldNames()) {
-                jsonGenerator.writeString(name);
-            }
-            jsonGenerator.writeEndArray();
-
-            // TODO: support type information which is not basic.
-            jsonGenerator.writeFieldName("types");
-            jsonGenerator.writeStartArray();
-            for (TypeInformation<?> info : typeInfo.getFieldTypes()) {
-                jsonGenerator.writeObject(info.getTypeClass());
-            }
-            jsonGenerator.writeEndArray();
-
-            jsonGenerator.writeEndObject();
-        }
-    }
-
-    public static class OutputSchemaJsonDeserializer extends 
StdDeserializer<OutputSchema> {
-        private static final ObjectMapper mapper = new ObjectMapper();
-
-        protected OutputSchemaJsonDeserializer() {
-            super(OutputSchema.class);
-        }
-
-        @Override
-        public OutputSchema deserialize(
-                JsonParser jsonParser, DeserializationContext 
deserializationContext)
-                throws IOException, JacksonException {
-            JsonNode node = jsonParser.getCodec().readTree(jsonParser);
-            List<String> fieldNames = new ArrayList<>();
-            node.get("fieldNames").forEach(fieldNameNode -> 
fieldNames.add(fieldNameNode.asText()));
-            List<TypeInformation<?>> types = new ArrayList<>();
-            node.get("types")
-                    .forEach(
-                            typeNode -> {
-                                try {
-                                    types.add(
-                                            BasicTypeInfo.getInfoFor(
-                                                    
mapper.treeToValue(typeNode, Class.class)));
-                                } catch (JsonProcessingException e) {
-                                    throw new RuntimeException(e);
-                                }
-                            });
-
-            return new OutputSchema(
-                    new RowTypeInfo(
-                            types.toArray(new TypeInformation[0]),
-                            fieldNames.toArray(new String[0])));
-        }
+        ctx.sendEvent(new ChatRequestEvent(DEFAULT_CHAT_MODEL, inputMessages, 
outputSchema));
     }
 
-    public enum ErrorHandlingStrategy {
-        FAIL("fail"),
-        IGNORE("ignore");
-
-        private final String value;
+    @Action(listenEvents = {ChatResponseEvent.class})
+    public static void stopAction(ChatResponseEvent event, RunnerContext ctx) {
+        ChatMessage response = event.getResponse();
 
-        ErrorHandlingStrategy(String value) {
-            this.value = value;
+        Object output;
+        if (response.getExtraArgs().containsKey(STRUCTURED_OUTPUT)) {
+            output = response.getExtraArgs().get(STRUCTURED_OUTPUT);
+        } else {
+            output = String.valueOf(response.getContent());
         }
 
-        public String getValue() {
-            return value;
-        }
+        ctx.sendEvent(new OutputEvent(output));
     }
 }
diff --git 
a/api/src/main/java/org/apache/flink/agents/api/event/ChatRequestEvent.java 
b/api/src/main/java/org/apache/flink/agents/api/event/ChatRequestEvent.java
index e5277d3..cba7b52 100644
--- a/api/src/main/java/org/apache/flink/agents/api/event/ChatRequestEvent.java
+++ b/api/src/main/java/org/apache/flink/agents/api/event/ChatRequestEvent.java
@@ -21,15 +21,25 @@ package org.apache.flink.agents.api.event;
 import org.apache.flink.agents.api.Event;
 import org.apache.flink.agents.api.chat.messages.ChatMessage;
 
+import javax.annotation.Nullable;
+
 import java.util.List;
 
+/** Event representing a request for chat. */
 public class ChatRequestEvent extends Event {
     private final String model;
     private final List<ChatMessage> messages;
+    private final @Nullable Object outputSchema;
 
-    public ChatRequestEvent(String model, List<ChatMessage> messages) {
+    public ChatRequestEvent(
+            String model, List<ChatMessage> messages, @Nullable Object 
outputSchema) {
         this.model = model;
         this.messages = messages;
+        this.outputSchema = outputSchema;
+    }
+
+    public ChatRequestEvent(String model, List<ChatMessage> messages) {
+        this(model, messages, null);
     }
 
     public String getModel() {
@@ -39,4 +49,9 @@ public class ChatRequestEvent extends Event {
     public List<ChatMessage> getMessages() {
         return messages;
     }
+
+    @Nullable
+    public Object getOutputSchema() {
+        return outputSchema;
+    }
 }
diff --git 
a/api/src/test/java/org/apache/flink/agents/api/agents/ReActAgentTest.java 
b/api/src/test/java/org/apache/flink/agents/api/agents/ReActAgentTest.java
index f5851e1..db237e7 100644
--- a/api/src/test/java/org/apache/flink/agents/api/agents/ReActAgentTest.java
+++ b/api/src/test/java/org/apache/flink/agents/api/agents/ReActAgentTest.java
@@ -36,10 +36,9 @@ public class ReActAgentTest {
                             BasicTypeInfo.INT_TYPE_INFO, 
BasicTypeInfo.STRING_TYPE_INFO
                         },
                         new String[] {"a", "b"});
-        ReActAgent.OutputSchema schema = new ReActAgent.OutputSchema(typeInfo);
+        OutputSchema schema = new OutputSchema(typeInfo);
         String json = mapper.writeValueAsString(schema);
-        ReActAgent.OutputSchema deserialized =
-                mapper.readValue(json, ReActAgent.OutputSchema.class);
+        OutputSchema deserialized = mapper.readValue(json, OutputSchema.class);
         Assertions.assertEquals(typeInfo, deserialized.getSchema());
     }
 }
diff --git 
a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java
 
b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java
index 648b1b1..b4adee8 100644
--- 
a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java
+++ 
b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java
@@ -21,7 +21,6 @@ package org.apache.flink.agents.integration.test;
 import org.apache.flink.agents.api.AgentsExecutionEnvironment;
 import org.apache.flink.agents.api.agents.Agent;
 import org.apache.flink.agents.api.agents.ReActAgent;
-import org.apache.flink.agents.api.agents.ReActAgentConfigOptions;
 import org.apache.flink.agents.api.annotation.ToolParam;
 import org.apache.flink.agents.api.chat.messages.ChatMessage;
 import org.apache.flink.agents.api.chat.messages.MessageRole;
@@ -50,6 +49,8 @@ import org.junit.jupiter.api.Test;
 import java.io.IOException;
 import java.util.List;
 
+import static 
org.apache.flink.agents.api.agents.AgentExecutionOptions.ERROR_HANDLING_STRATEGY;
+import static 
org.apache.flink.agents.api.agents.AgentExecutionOptions.MAX_RETRIES;
 import static 
org.apache.flink.agents.integration.test.OllamaPreparationUtils.pullModel;
 
 public class ReActAgentTest {
@@ -110,11 +111,8 @@ public class ReActAgentTest {
                                 ReActAgentTest.class.getMethod(
                                         "multiply", Double.class, 
Double.class)));
 
-        agentsEnv
-                .getConfig()
-                .set(
-                        ReActAgentConfigOptions.ERROR_HANDLING_STRATEGY,
-                        ReActAgent.ErrorHandlingStrategy.IGNORE);
+        agentsEnv.getConfig().set(ERROR_HANDLING_STRATEGY, 
ReActAgent.ErrorHandlingStrategy.RETRY);
+        agentsEnv.getConfig().set(MAX_RETRIES, 3);
 
         // Declare the ReAct agent.
         Agent agent = getAgent();
diff --git a/plan/pom.xml b/plan/pom.xml
index 2b77646..f6b5df5 100644
--- a/plan/pom.xml
+++ b/plan/pom.xml
@@ -68,6 +68,22 @@ under the License.
             <version>2.0.2</version>
             <scope>test</scope>
         </dependency>
+        <!-- LOG -->
+        <dependency>
+            <groupId>org.slf4j</groupId>
+            <artifactId>slf4j-api</artifactId>
+            <version>${slf4j.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.logging.log4j</groupId>
+            <artifactId>log4j-core</artifactId>
+            <version>${log4j2.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.logging.log4j</groupId>
+            <artifactId>log4j-slf4j-impl</artifactId>
+            <version>${log4j2.version}</version>
+        </dependency>
     </dependencies>
 
     <profiles>
diff --git 
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
index 0fb96e3..aa51fa4 100644
--- 
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
+++ 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
@@ -17,7 +17,12 @@
  */
 package org.apache.flink.agents.plan.actions;
 
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
 import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.api.agents.Agent;
+import org.apache.flink.agents.api.agents.AgentExecutionOptions;
+import org.apache.flink.agents.api.agents.OutputSchema;
 import org.apache.flink.agents.api.chat.messages.ChatMessage;
 import org.apache.flink.agents.api.chat.messages.MessageRole;
 import org.apache.flink.agents.api.chat.model.BaseChatModelSetup;
@@ -30,15 +35,28 @@ import org.apache.flink.agents.api.event.ToolResponseEvent;
 import org.apache.flink.agents.api.resource.ResourceType;
 import org.apache.flink.agents.api.tools.ToolResponse;
 import org.apache.flink.agents.plan.JavaFunction;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.types.Row;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nullable;
 
 import java.util.*;
 
+import static org.apache.flink.agents.api.agents.Agent.STRUCTURED_OUTPUT;
+
 /** Built-in action for processing chat request and tool call result. */
 public class ChatModelAction {
+    private static final Logger LOG = 
LoggerFactory.getLogger(ChatModelAction.class);
+
     private static final String TOOL_CALL_CONTEXT = "_TOOL_CALL_CONTEXT";
     private static final String TOOL_REQUEST_EVENT_CONTEXT = 
"_TOOL_REQUEST_EVENT_CONTEXT";
     private static final String INITIAL_REQUEST_ID = "initialRequestId";
     private static final String MODEL = "model";
+    private static final String OUTPUT_SCHEMA = "outputSchema";
+
+    private static final ObjectMapper mapper = new ObjectMapper();
 
     public static Action getChatModelAction() throws Exception {
         return new Action(
@@ -76,22 +94,13 @@ public class ChatModelAction {
         return messageContext;
     }
 
-    @SuppressWarnings("unchecked")
-    private static void clearToolCallContext(MemoryObject sensoryMem, UUID 
initialRequestId)
-            throws Exception {
-        if (sensoryMem.isExist(TOOL_CALL_CONTEXT)) {
-            Map<UUID, Object> toolCallContext =
-                    (Map<UUID, Object>) 
sensoryMem.get(TOOL_CALL_CONTEXT).getValue();
-            if (toolCallContext.containsKey(initialRequestId)) {
-                toolCallContext.remove(initialRequestId);
-                sensoryMem.set(TOOL_CALL_CONTEXT, toolCallContext);
-            }
-        }
-    }
-
     @SuppressWarnings("unchecked")
     private static void saveToolRequestEventContext(
-            MemoryObject sensoryMem, UUID toolRequestEventId, UUID 
initialRequestId, String model)
+            MemoryObject sensoryMem,
+            UUID toolRequestEventId,
+            UUID initialRequestId,
+            String model,
+            Object outputSchema)
             throws Exception {
         Map<UUID, Object> toolRequestEventContext;
         if (sensoryMem.isExist(TOOL_REQUEST_EVENT_CONTEXT)) {
@@ -100,20 +109,22 @@ public class ChatModelAction {
         } else {
             toolRequestEventContext = new HashMap<>();
         }
-        toolRequestEventContext.put(
-                toolRequestEventId, Map.of(INITIAL_REQUEST_ID, 
initialRequestId, MODEL, model));
+        Map<String, Object> context = new HashMap<>();
+        context.put(INITIAL_REQUEST_ID, initialRequestId);
+        context.put(MODEL, model);
+        if (outputSchema != null) {
+            context.put(OUTPUT_SCHEMA, outputSchema);
+        }
+        toolRequestEventContext.put(toolRequestEventId, context);
         sensoryMem.set(TOOL_REQUEST_EVENT_CONTEXT, toolRequestEventContext);
     }
 
     @SuppressWarnings("unchecked")
-    private static Map<String, Object> removeToolRequestEventContext(
+    private static Map<String, Object> getToolRequestEventContext(
             MemoryObject sensoryMem, UUID requestId) throws Exception {
         Map<UUID, Object> toolRequestEventContext =
                 (Map<UUID, Object>) 
sensoryMem.get(TOOL_REQUEST_EVENT_CONTEXT).getValue();
-        Map<String, Object> context =
-                (Map<String, Object>) 
toolRequestEventContext.remove(requestId);
-        sensoryMem.set(TOOL_REQUEST_EVENT_CONTEXT, toolRequestEventContext);
-        return context;
+        return (Map<String, Object>) toolRequestEventContext.remove(requestId);
     }
 
     private static void handleToolCalls(
@@ -121,6 +132,7 @@ public class ChatModelAction {
             UUID initialRequestId,
             String model,
             List<ChatMessage> messages,
+            Object outputSchema,
             RunnerContext ctx)
             throws Exception {
         updateToolCallContext(
@@ -132,11 +144,38 @@ public class ChatModelAction {
         ToolRequestEvent toolRequestEvent = new ToolRequestEvent(model, 
response.getToolCalls());
 
         saveToolRequestEventContext(
-                ctx.getSensoryMemory(), toolRequestEvent.getId(), 
initialRequestId, model);
+                ctx.getSensoryMemory(),
+                toolRequestEvent.getId(),
+                initialRequestId,
+                model,
+                outputSchema);
 
         ctx.sendEvent(toolRequestEvent);
     }
 
+    @SuppressWarnings("unchecked")
+    private static ChatMessage generateStructuredOutput(ChatMessage response, 
Object outputSchema)
+            throws JsonProcessingException {
+        String output = response.getContent();
+        Object structuredOutput;
+        if (outputSchema instanceof Class) {
+            structuredOutput = mapper.readValue(String.valueOf(output), 
(Class<?>) outputSchema);
+        } else if (outputSchema instanceof OutputSchema) {
+            RowTypeInfo info = ((OutputSchema) outputSchema).getSchema();
+            Map<String, Object> fields = 
mapper.readValue(String.valueOf(output), Map.class);
+            structuredOutput = Row.withNames();
+            for (String name : info.getFieldNames()) {
+                ((Row) structuredOutput).setField(name, fields.get(name));
+            }
+        } else {
+            throw new RuntimeException(
+                    String.format("Unsupported output schema %s.", 
outputSchema));
+        }
+        Map<String, Object> extraArgs = new HashMap<>();
+        extraArgs.put(STRUCTURED_OUTPUT, structuredOutput);
+        return new ChatMessage(response.getRole(), output, extraArgs);
+    }
+
     /**
      * Chat with chat model.
      *
@@ -148,26 +187,69 @@ public class ChatModelAction {
      * @param ctx The runner context this function executed in.
      */
     public static void chat(
-            UUID initialRequestId, String model, List<ChatMessage> messages, RunnerContext ctx)
+            UUID initialRequestId,
+            String model,
+            List<ChatMessage> messages,
+            @Nullable Object outputSchema,
+            RunnerContext ctx)
             throws Exception {
         BaseChatModelSetup chatModel =
                 (BaseChatModelSetup) ctx.getResource(model, ResourceType.CHAT_MODEL);
 
-        ChatMessage response = chatModel.chat(messages, Map.of());
+        Agent.ErrorHandlingStrategy strategy =
+                ctx.getConfig().get(AgentExecutionOptions.ERROR_HANDLING_STRATEGY);
+        // Retries only apply to the RETRY strategy; a negative setting is clamped to zero.
+        int numRetries = 0;
+        if (strategy == Agent.ErrorHandlingStrategy.RETRY) {
+            numRetries = Math.max(0, ctx.getConfig().get(AgentExecutionOptions.MAX_RETRIES));
+        }
 
-        if (!response.getToolCalls().isEmpty()) {
-            handleToolCalls(response, initialRequestId, model, messages, ctx);
-        } else {
-            // clean tool call context
-            clearToolCallContext(ctx.getSensoryMemory(), initialRequestId);
+        ChatMessage response = null;
 
+        // One initial attempt plus at most numRetries retries. "<=" (rather than
+        // "< numRetries + 1") keeps the bound free of integer overflow.
+        for (int attempt = 0; attempt <= numRetries; attempt++) {
+            try {
+                response = chatModel.chat(messages, Map.of());
+                // only generate structured output for final response.
+                if (outputSchema != null && response.getToolCalls().isEmpty()) {
+                    response = generateStructuredOutput(response, outputSchema);
+                }
+                // The attempt succeeded: stop here, otherwise the model would be called
+                // again for every remaining retry even though nothing failed.
+                break;
+            } catch (Exception e) {
+                if (strategy == Agent.ErrorHandlingStrategy.IGNORE) {
+                    // Swallow the failure: no response event is sent for this request.
+                    LOG.warn(
+                            "Chat request {} failed with error: {}, ignored.", initialRequestId, e);
+                    return;
+                } else if (strategy == Agent.ErrorHandlingStrategy.RETRY) {
+                    // Retries exhausted: surface the last failure to the caller.
+                    if (attempt == numRetries) {
+                        throw e;
+                    }
+                    // attempt is 0-based; report the number of the retry about to start.
+                    LOG.warn(
+                            "Chat request {} failed with error: {}, retrying {} / {}.",
+                            initialRequestId,
+                            e,
+                            attempt + 1,
+                            numRetries);
+                } else {
+                    LOG.debug(
+                            "Chat request {} failed, the input chat messages are {}.",
+                            initialRequestId,
+                            messages);
+                    throw e;
+                }
+            }
+        }
+
+        // Every path that leaves the loop without returning or throwing has set response.
+        if (!Objects.requireNonNull(response).getToolCalls().isEmpty()) {
+            handleToolCalls(response, initialRequestId, model, messages, outputSchema, ctx);
+        } else {
             ctx.sendEvent(new ChatResponseEvent(initialRequestId, response));
         }
     }
 
     private static void processChatRequest(ChatRequestEvent event, RunnerContext ctx)
             throws Exception {
-        chat(event.getId(), event.getModel(), event.getMessages(), ctx);
+        // The id of the request event itself is passed as the initial request id.
+        chat(
+                event.getId(),
+                event.getModel(),
+                event.getMessages(),
+                event.getOutputSchema(),
+                ctx);
     }
 
     private static void processToolResponse(ToolResponseEvent event, 
RunnerContext ctx)
@@ -175,11 +257,11 @@ public class ChatModelAction {
         MemoryObject sensoryMem = ctx.getSensoryMemory();
 
         // get tool request context from memory
-        Map<String, Object> context =
-                removeToolRequestEventContext(sensoryMem, 
event.getRequestId());
+        Map<String, Object> context = getToolRequestEventContext(sensoryMem, 
event.getRequestId());
 
         UUID initialRequestId = (UUID) context.get(INITIAL_REQUEST_ID);
         String model = (String) context.get(MODEL);
+        Object outputSchema = context.get(OUTPUT_SCHEMA);
 
         Map<String, ToolResponse> responses = event.getResponses();
         Map<String, Boolean> success = event.getSuccess();
@@ -212,7 +294,7 @@ public class ChatModelAction {
                         Collections.emptyList(),
                         toolResponseMessages);
 
-        chat(initialRequestId, model, messages, ctx);
+        chat(initialRequestId, model, messages, outputSchema, ctx);
     }
 
     /**
diff --git a/python/flink_agents/api/agents/agent.py 
b/python/flink_agents/api/agents/agent.py
index 0324dac..3619338 100644
--- a/python/flink_agents/api/agents/agent.py
+++ b/python/flink_agents/api/agents/agent.py
@@ -26,6 +26,8 @@ from flink_agents.api.resource import (
 )
 from flink_agents.api.tools.mcp import MCPServer
 
+STRUCTURED_OUTPUT = "structured_output"
+
 
 class Agent(ABC):
     """Base class for defining agent logic.
diff --git a/python/flink_agents/api/agents/react_agent.py 
b/python/flink_agents/api/agents/react_agent.py
index 63d319e..f540953 100644
--- a/python/flink_agents/api/agents/react_agent.py
+++ b/python/flink_agents/api/agents/react_agent.py
@@ -15,24 +15,17 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 
#################################################################################
-import importlib
-import json
-import logging
-from enum import Enum
-from typing import Any, cast
+from typing import cast
 
 from pydantic import (
     BaseModel,
-    ConfigDict,
-    model_serializer,
-    model_validator,
 )
 from pyflink.common import Row
-from pyflink.common.typeinfo import BasicType, BasicTypeInfo, RowTypeInfo
+from pyflink.common.typeinfo import RowTypeInfo
 
-from flink_agents.api.agents.agent import Agent
+from flink_agents.api.agents.agent import STRUCTURED_OUTPUT, Agent
+from flink_agents.api.agents.types import OutputSchema
 from flink_agents.api.chat_message import ChatMessage, MessageRole
-from flink_agents.api.configuration import ConfigOption
 from flink_agents.api.decorators import action
 from flink_agents.api.events.chat_event import ChatRequestEvent, 
ChatResponseEvent
 from flink_agents.api.events.event import InputEvent, OutputEvent
@@ -46,68 +39,6 @@ _DEFAULT_USER_PROMPT = "_default_user_prompt"
 _OUTPUT_SCHEMA = "_output_schema"
 
 
-class ErrorHandlingStrategy(Enum):
-    """Error handling strategy for ReActAgent."""
-
-    FAIL = "fail"
-    IGNORE = "ignore"
-
-
-class ReActAgentOptions:
-    """Config options for ReActAgent."""
-
-    ERROR_HANDLING_STRATEGY = ConfigOption(
-        key="error-handling-strategy",
-        config_type=ErrorHandlingStrategy,
-        default=ErrorHandlingStrategy.FAIL,
-    )
-
-
-class OutputSchema(BaseModel):
-    """Util class to help serialize and deserialize output schema json."""
-
-    model_config = ConfigDict(arbitrary_types_allowed=True)
-    output_schema: type[BaseModel] | RowTypeInfo
-
-    @model_serializer
-    def __custom_serializer(self) -> dict[str, Any]:
-        if isinstance(self.output_schema, RowTypeInfo):
-            data = {
-                "output_schema": {
-                    "names": self.output_schema.get_field_names(),
-                    "types": [
-                        type._basic_type.value
-                        for type in self.output_schema.get_field_types()
-                    ],
-                },
-            }
-        else:
-            data = {
-                "output_schema": {
-                    "module": self.output_schema.__module__,
-                    "class": self.output_schema.__name__,
-                }
-            }
-        return data
-
-    @model_validator(mode="before")
-    def __custom_deserialize(self) -> "OutputSchema":
-        output_schema = self["output_schema"]
-        if isinstance(output_schema, dict):
-            if "names" in output_schema:
-                self["output_schema"] = RowTypeInfo(
-                    field_types=[
-                        BasicTypeInfo(BasicType(type))
-                        for type in output_schema["types"]
-                    ],
-                    field_names=output_schema["names"],
-                )
-            else:
-                module = importlib.import_module(output_schema["module"])
-                self["output_schema"] = getattr(module, output_schema["class"])
-        return self
-
-
 class ReActAgent(Agent):
     """Built-in implementation of ReAct agent which is based on the function
     call ability of llm.
@@ -204,13 +135,12 @@ class ReActAgent(Agent):
             self._resources[ResourceType.PROMPT][_DEFAULT_USER_PROMPT] = prompt
 
         self.add_action(
-            name="stop_action",
-            events=[ChatResponseEvent],
-            func=self.stop_action,
+            name="start_action",
+            events=[InputEvent],
+            func=self.start_action,
             output_schema=OutputSchema(output_schema=output_schema),
         )
 
-    @action(InputEvent)
     @staticmethod
     def start_action(event: InputEvent, ctx: RunnerContext) -> None:
         """Start action to format user input and send chat request event."""
@@ -257,44 +187,25 @@ class ReActAgent(Agent):
             instruct = schema_prompt.format_messages()
             usr_msgs = instruct + usr_msgs
 
+        output_schema = ctx.get_action_config_value(key="output_schema")
+
         ctx.send_event(
             ChatRequestEvent(
                 model=_DEFAULT_CHAT_MODEL,
                 messages=usr_msgs,
+                output_schema=output_schema,
             )
         )
 
+    @action(ChatResponseEvent)
     @staticmethod
     def stop_action(event: ChatResponseEvent, ctx: RunnerContext) -> None:
         """Stop action to output result."""
-        output = event.response.content
-        # parse llm response to target schema.
-        output_schema = ctx.get_action_config_value(key="output_schema")
+        response = event.response
 
-        error_handling_strategy = ctx.config.get(
-            ReActAgentOptions.ERROR_HANDLING_STRATEGY
-        )
-        try:
-            if output_schema:
-                output_schema = output_schema.output_schema
-                output = json.loads(output.strip())
-                if isinstance(output_schema, type) and issubclass(
-                    output_schema, BaseModel
-                ):
-                    output = output_schema.model_validate(output)
-                elif isinstance(output_schema, RowTypeInfo):
-                    field_names = output_schema.get_field_names()
-                    values = {}
-                    for field_name in field_names:
-                        values[field_name] = output[field_name]
-                    output = Row(**values)
-        except Exception:
-            if error_handling_strategy == ErrorHandlingStrategy.IGNORE:
-                logging.warning(
-                    f"The response of llm {output} doesn't match schema 
constraint, ignoring."
-                )
-                return
-            elif error_handling_strategy == ErrorHandlingStrategy.FAIL:
-                raise
+        # Emit the structured output attached to the response when present;
+        # otherwise fall back to the raw text content.
+        if STRUCTURED_OUTPUT in response.extra_args:
+            output = response.extra_args[STRUCTURED_OUTPUT]
+        else:
+            output = response.content
 
         ctx.send_event(OutputEvent(output=output))
diff --git a/python/flink_agents/api/agents/types.py 
b/python/flink_agents/api/agents/types.py
new file mode 100644
index 0000000..2fd10b8
--- /dev/null
+++ b/python/flink_agents/api/agents/types.py
@@ -0,0 +1,67 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+import importlib
+from typing import Any
+
+from pydantic import BaseModel, ConfigDict, model_serializer, model_validator
+from pyflink.common.typeinfo import BasicType, BasicTypeInfo, RowTypeInfo
+
+
+class OutputSchema(BaseModel):
+    """Util class to help serialize and deserialize output schema json.
+
+    Wraps either a pydantic model class or a ``RowTypeInfo`` and converts it
+    to and from a plain dict representation.
+    """
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    # Either a BaseModel subclass or a RowTypeInfo.
+    output_schema: type[BaseModel] | RowTypeInfo
+
+    @model_serializer
+    def __custom_serializer(self) -> dict[str, Any]:
+        """Serialize the wrapped schema to a dict.
+
+        A ``RowTypeInfo`` is written as its field names plus the
+        ``_basic_type.value`` of each field type; any other schema is written
+        as the module and class name it can be re-imported from.
+        """
+        if isinstance(self.output_schema, RowTypeInfo):
+            data = {
+                "output_schema": {
+                    "names": self.output_schema.get_field_names(),
+                    "types": [
+                        type._basic_type.value
+                        for type in self.output_schema.get_field_types()
+                    ],
+                },
+            }
+        else:
+            data = {
+                "output_schema": {
+                    "module": self.output_schema.__module__,
+                    "class": self.output_schema.__name__,
+                }
+            }
+        return data
+
+    @model_validator(mode="before")
+    def __custom_deserialize(self) -> "OutputSchema":
+        """Rebuild the schema object from its serialized dict form.
+
+        NOTE(review): ``self`` is indexed like a mapping here, so it appears
+        to be the raw input data rather than an instance - confirm.
+        """
+        output_schema = self["output_schema"]
+        # Values that are not dicts are passed through unchanged.
+        if isinstance(output_schema, dict):
+            if "names" in output_schema:
+                # Serialized RowTypeInfo: restore field types and names.
+                self["output_schema"] = RowTypeInfo(
+                    field_types=[
+                        BasicTypeInfo(BasicType(type))
+                        for type in output_schema["types"]
+                    ],
+                    field_names=output_schema["names"],
+                )
+            else:
+                # Serialized class reference: import the module and look the class up.
+                module = importlib.import_module(output_schema["module"])
+                self["output_schema"] = getattr(module, output_schema["class"])
+        return self
diff --git a/python/flink_agents/api/core_options.py 
b/python/flink_agents/api/core_options.py
index d9ee456..9d8f7ea 100644
--- a/python/flink_agents/api/core_options.py
+++ b/python/flink_agents/api/core_options.py
@@ -15,6 +15,7 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 
#################################################################################
+from enum import Enum
 from typing import Any
 
 from pyflink.java_gateway import get_gateway
@@ -69,6 +70,17 @@ class AgentConfigOptionsMeta(type):
         return python_option
 
 
+class ErrorHandlingStrategy(Enum):
+    """Error handling strategy for Agent.
+
+    Currently, only works for chat action.
+    """
+
+    # Retry the failed chat up to the configured max retries, then re-raise.
+    RETRY = "retry"
+    # Re-raise the error.
+    FAIL = "fail"
+    # Log a warning and return without raising.
+    IGNORE = "ignore"
+
+
 class AgentConfigOptions(metaclass=AgentConfigOptionsMeta):
     """CoreOptions to manage core configuration parameters for Flink Agents."""
 
@@ -77,3 +89,15 @@ class AgentConfigOptions(metaclass=AgentConfigOptionsMeta):
         config_type=str,
         default=None,
     )
+
+    ERROR_HANDLING_STRATEGY = ConfigOption(
+        key="error-handling-strategy",
+        config_type=ErrorHandlingStrategy,
+        default=ErrorHandlingStrategy.FAIL,
+    )
+
+    MAX_RETRIES = ConfigOption(
+        key="max-retries",
+        config_type=int,
+        default=3,
+    )
diff --git a/python/flink_agents/api/events/chat_event.py 
b/python/flink_agents/api/events/chat_event.py
index 2f4dfb1..2fb4266 100644
--- a/python/flink_agents/api/events/chat_event.py
+++ b/python/flink_agents/api/events/chat_event.py
@@ -18,6 +18,7 @@
 from typing import List
 from uuid import UUID
 
+from flink_agents.api.agents.react_agent import OutputSchema
 from flink_agents.api.chat_message import ChatMessage
 from flink_agents.api.events.event import Event
 
@@ -31,10 +32,13 @@ class ChatRequestEvent(Event):
         The name of the chat model to be chatted with.
     messages : List[ChatMessage]
         The input to the chat model.
+    output_schema: OutputSchema | None
+        The expected output schema of the chat model final response. Optional.
     """
 
     model: str
     messages: List[ChatMessage]
+    output_schema: OutputSchema | None = None
 
 
 class ChatResponseEvent(Event):
diff --git 
a/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py 
b/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py
index 630a7f3..7e688b5 100644
--- a/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py
+++ b/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py
@@ -28,11 +28,10 @@ from pyflink.datastream import KeySelector, 
StreamExecutionEnvironment
 from pyflink.table import DataTypes, Schema, StreamTableEnvironment, 
TableDescriptor
 
 from flink_agents.api.agents.react_agent import (
-    ErrorHandlingStrategy,
     ReActAgent,
-    ReActAgentOptions,
 )
 from flink_agents.api.chat_message import ChatMessage, MessageRole
+from flink_agents.api.core_options import AgentConfigOptions, 
ErrorHandlingStrategy
 from flink_agents.api.execution_environment import AgentsExecutionEnvironment
 from flink_agents.api.prompts.prompt import Prompt
 from flink_agents.api.resource import ResourceDescriptor
@@ -79,8 +78,9 @@ client = pull_model(OLLAMA_MODEL)
 def test_react_agent_on_local_runner() -> None:  # noqa: D103
     env = AgentsExecutionEnvironment.get_execution_environment()
     env.get_config().set(
-        ReActAgentOptions.ERROR_HANDLING_STRATEGY, ErrorHandlingStrategy.IGNORE
+        AgentConfigOptions.ERROR_HANDLING_STRATEGY, ErrorHandlingStrategy.RETRY
     )
+    env.get_config().set(AgentConfigOptions.MAX_RETRIES, 3)
 
     # register resource to execution environment
     (
@@ -155,9 +155,11 @@ def test_react_agent_on_remote_runner(tmp_path: Path) -> 
None:  # noqa: D103
     )
 
     env.get_config().set(
-        ReActAgentOptions.ERROR_HANDLING_STRATEGY, ErrorHandlingStrategy.IGNORE
+        AgentConfigOptions.ERROR_HANDLING_STRATEGY, ErrorHandlingStrategy.RETRY
     )
 
+    env.get_config().set(AgentConfigOptions.MAX_RETRIES, 3)
+
     # register resource to execution environment
     (
         env.add_resource(
diff --git a/python/flink_agents/plan/actions/chat_model_action.py 
b/python/flink_agents/plan/actions/chat_model_action.py
index 7299f13..0db7e5a 100644
--- a/python/flink_agents/plan/actions/chat_model_action.py
+++ b/python/flink_agents/plan/actions/chat_model_action.py
@@ -16,10 +16,19 @@
 # limitations under the License.
 
#################################################################################
 import copy
+import json
+import logging
 from typing import TYPE_CHECKING, Dict, List, cast
 from uuid import UUID
 
+from pydantic import BaseModel
+from pyflink.common import Row
+from pyflink.common.typeinfo import RowTypeInfo
+
+from flink_agents.api.agents.agent import STRUCTURED_OUTPUT
+from flink_agents.api.agents.react_agent import OutputSchema
 from flink_agents.api.chat_message import ChatMessage, MessageRole
+from flink_agents.api.core_options import AgentConfigOptions, 
ErrorHandlingStrategy
 from flink_agents.api.events.chat_event import ChatRequestEvent, 
ChatResponseEvent
 from flink_agents.api.events.event import Event
 from flink_agents.api.events.tool_event import ToolRequestEvent, 
ToolResponseEvent
@@ -35,6 +44,8 @@ if TYPE_CHECKING:
 _TOOL_CALL_CONTEXT = "_TOOL_CALL_CONTEXT"
 _TOOL_REQUEST_EVENT_CONTEXT = "_TOOL_REQUEST_EVENT_CONTEXT"
 
+_logger = logging.getLogger(__name__)
+
 
 # ============================================================================
 # Helper Functions for Tool Call Context Management
@@ -69,39 +80,29 @@ def _update_tool_call_context(
     sensory_memory.set(_TOOL_CALL_CONTEXT, tool_call_context)
     return tool_call_context[initial_request_id]
 
-
-def _clear_tool_call_context(
-    sensory_memory: MemoryObject, initial_request_id: UUID
-) -> None:
-    """Clear tool call context for a specific request ID."""
-    context = sensory_memory.get(_TOOL_CALL_CONTEXT) or {}
-    if initial_request_id in context:
-        context.pop(initial_request_id)
-        sensory_memory.set(_TOOL_CALL_CONTEXT, context)
-
-
 def _save_tool_request_event_context(
     sensory_memory: MemoryObject,
     tool_request_event_id: UUID,
     initial_request_id: UUID,
     model: str,
+    output_schema: OutputSchema | None,
 ) -> None:
     """Save the context for a specific tool request event."""
     context = sensory_memory.get(_TOOL_REQUEST_EVENT_CONTEXT) or {}
+    # One entry per tool request event id; output_schema is stored alongside
+    # the initial request id and model so it can be read back with them.
     context[tool_request_event_id] = {
         "initial_request_id": initial_request_id,
         "model": model,
+        "output_schema": output_schema,
     }
     sensory_memory.set(_TOOL_REQUEST_EVENT_CONTEXT, context)
 
 
-def _remove_tool_request_event_context(
+def _get_tool_request_event_context(
     sensory_memory: MemoryObject, request_id: UUID
 ) -> Dict:
-    """Get and remove the context for a specific tool request event."""
+    """Get the context saved for a specific tool request event.
+
+    The stored mapping is left untouched. An empty dict is returned when no
+    context was saved for ``request_id``.
+    """
     context = sensory_memory.get(_TOOL_REQUEST_EVENT_CONTEXT) or {}
-    removed_context = context.pop(request_id, {})
-    sensory_memory.set(_TOOL_REQUEST_EVENT_CONTEXT, removed_context)
-    return removed_context
+    # Plain lookup instead of pop(): this helper no longer removes the entry,
+    # so it must not mutate the mapping obtained from sensory memory.
+    return context.get(request_id, {})
 
 
@@ -110,6 +111,7 @@ def _handle_tool_calls(
     initial_request_id: UUID,
     model: str,
     messages: List[ChatMessage],
+    output_schema: OutputSchema | None,
     ctx: RunnerContext,
 ) -> None:
     """Handle tool calls in chat response."""
@@ -124,16 +126,41 @@ def _handle_tool_calls(
 
     # save tool request event context
     _save_tool_request_event_context(
-        ctx.sensory_memory, tool_request_event.id, initial_request_id, model
+        ctx.sensory_memory,
+        tool_request_event.id,
+        initial_request_id,
+        model,
+        output_schema,
     )
 
     ctx.send_event(tool_request_event)
 
 
+def _generate_structured_output(
+    response: ChatMessage, output_schema: OutputSchema
+) -> ChatMessage:
+    """Parse the response content as JSON and attach it as structured output.
+
+    The parsed value is validated into the wrapped pydantic model, or turned
+    into a ``Row`` holding one value per field name of a ``RowTypeInfo``. The
+    result is stored in ``response.extra_args`` under ``STRUCTURED_OUTPUT``
+    and the same ``response`` object is returned.
+    """
+    schema = output_schema.output_schema
+    parsed = json.loads(response.content.strip())
+
+    if isinstance(schema, RowTypeInfo):
+        structured = Row(**{name: parsed[name] for name in schema.get_field_names()})
+    elif isinstance(schema, type) and issubclass(schema, BaseModel):
+        structured = schema.model_validate(parsed)
+    else:
+        # Neither supported schema kind: keep the parsed JSON value as is.
+        structured = parsed
+
+    response.extra_args[STRUCTURED_OUTPUT] = structured
+    return response
+
+
 def chat(
     initial_request_id: UUID,
     model: str,
     messages: List[ChatMessage],
+    output_schema: OutputSchema | None,
     ctx: RunnerContext,
 ) -> None:
     """Chat with llm.
@@ -146,16 +173,43 @@ def chat(
         "BaseChatModelSetup", ctx.get_resource(model, ResourceType.CHAT_MODEL)
     )
 
+    error_handling_strategy = 
ctx.config.get(AgentConfigOptions.ERROR_HANDLING_STRATEGY)
+    num_retries = 0
+    if error_handling_strategy == ErrorHandlingStrategy.RETRY:
+        num_retries = max(0, ctx.config.get(AgentConfigOptions.MAX_RETRIES))
+
     # TODO: support async execution of chat.
-    response = chat_model.chat(messages)
+    response = None
+    for attempt in range(num_retries + 1):
+        try:
+            response = chat_model.chat(messages)
+            if output_schema is not None and len(response.tool_calls) == 0:
+                response = _generate_structured_output(response, output_schema)
+        except Exception as e:  # noqa: PERF203
+            if error_handling_strategy == ErrorHandlingStrategy.IGNORE:
+                _logger.warning(
+                    f"Chat request {initial_request_id} failed with error: 
{e}, ignored."
+                )
+                return
+            elif error_handling_strategy == ErrorHandlingStrategy.RETRY:
+                if attempt == num_retries:
+                    raise
+                _logger.warning(
+                    f"Chat request {initial_request_id} failed with error: 
{e}, retrying {attempt} / {num_retries}."
+                )
+            else:
+                _logger.debug(
+                    f"Chat request {initial_request_id} failed, the input chat 
messages are {messages}."
+                )
+                raise
 
     if (
         len(response.tool_calls) > 0
     ):  # generate tool request event according tool calls in response
-        _handle_tool_calls(response, initial_request_id, model, messages, ctx)
+        _handle_tool_calls(
+            response, initial_request_id, model, messages, output_schema, ctx
+        )
     else:  # if there is no tool call generated, return chat response directly
-        _clear_tool_call_context(ctx.sensory_memory, initial_request_id)
-
         ctx.send_event(
             ChatResponseEvent(
                 request_id=initial_request_id,
@@ -170,6 +224,7 @@ def _process_chat_request(event: ChatRequestEvent, ctx: 
RunnerContext) -> None:
         initial_request_id=event.id,
         model=event.model,
         messages=event.messages,
+        output_schema=event.output_schema,
         ctx=ctx,
     )
 
@@ -180,7 +235,7 @@ def _process_tool_response(event: ToolResponseEvent, ctx: 
RunnerContext) -> None
     request_id = event.request_id
 
     # get correspond tool request event context
-    tool_request_event_context = _remove_tool_request_event_context(
+    tool_request_event_context = _get_tool_request_event_context(
         sensory_memory, request_id
     )
     initial_request_id = tool_request_event_context["initial_request_id"]
@@ -206,6 +261,7 @@ def _process_tool_response(event: ToolResponseEvent, ctx: 
RunnerContext) -> None
         initial_request_id=initial_request_id,
         model=tool_request_event_context["model"],
         messages=messages,
+        output_schema=tool_request_event_context["output_schema"],
         ctx=ctx,
     )
 
diff --git a/python/flink_agents/plan/tests/resources/action.json 
b/python/flink_agents/plan/tests/resources/action.json
index 2c52d83..e6de190 100644
--- a/python/flink_agents/plan/tests/resources/action.json
+++ b/python/flink_agents/plan/tests/resources/action.json
@@ -11,7 +11,7 @@
     "config": {
         "__config_type__": "python",
         "output_schema": [
-            "flink_agents.api.agents.react_agent",
+            "flink_agents.api.agents.types",
             "OutputSchema",
             {
                 "output_schema": {

Reply via email to