This is an automated email from the ASF dual-hosted git repository.

sxnan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git


The following commit(s) were added to refs/heads/main by this push:
     new 281844ef [plan] Support retry interval for chat action (#602)
281844ef is described below

commit 281844ef84d481e6c6d958d416e4df15f7004d7a
Author: Eugene <[email protected]>
AuthorDate: Thu Apr 2 17:22:25 2026 +0800

    [plan] Support retry interval for chat action (#602)
---
 .../agents/api/agents/AgentExecutionOptions.java   |   3 +
 .../flink/agents/api/event/ChatResponseEvent.java  |  17 ++
 docs/content/docs/operations/configuration.md      |   1 +
 .../flink/agents/plan/actions/ChatModelAction.java |  83 +++++-
 .../plan/actions/ChatModelActionRetryTest.java     | 314 +++++++++++++++++++++
 python/flink_agents/api/core_options.py            |   6 +
 python/flink_agents/api/events/chat_event.py       |   6 +
 python/flink_agents/api/runner_context.py          |  10 +-
 .../flink_agents/plan/actions/chat_model_action.py |  79 +++++-
 python/flink_agents/plan/tests/actions/__init__.py |  17 ++
 .../tests/actions/test_chat_model_action_retry.py  | 250 ++++++++++++++++
 python/flink_agents/runtime/local_runner.py        |  10 +-
 12 files changed, 783 insertions(+), 13 deletions(-)

diff --git 
a/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java
 
b/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java
index 69ba2a02..2b8751d4 100644
--- 
a/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java
+++ 
b/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java
@@ -30,6 +30,9 @@ public class AgentExecutionOptions {
     public static final ConfigOption<Integer> MAX_RETRIES =
             new ConfigOption<>("max-retries", Integer.class, 3);
 
+    public static final ConfigOption<Integer> RETRY_WAIT_INTERVAL =
+            new ConfigOption<>("retry-wait-interval", Integer.class, 1);
+
     public static final ConfigOption<Integer> NUM_ASYNC_THREADS =
             new ConfigOption<>(
                     "num-async-threads",
diff --git 
a/api/src/main/java/org/apache/flink/agents/api/event/ChatResponseEvent.java 
b/api/src/main/java/org/apache/flink/agents/api/event/ChatResponseEvent.java
index 9041e5a4..89de3453 100644
--- a/api/src/main/java/org/apache/flink/agents/api/event/ChatResponseEvent.java
+++ b/api/src/main/java/org/apache/flink/agents/api/event/ChatResponseEvent.java
@@ -26,10 +26,19 @@ import java.util.UUID;
 public class ChatResponseEvent extends Event {
     private final UUID requestId;
     private final ChatMessage response;
+    private final int retryCount;
+    private final int totalRetryWaitSec;
 
     public ChatResponseEvent(UUID requestId, ChatMessage response) {
+        this(requestId, response, 0, 0);
+    }
+
+    public ChatResponseEvent(
+            UUID requestId, ChatMessage response, int retryCount, int 
totalRetryWaitSec) {
         this.requestId = requestId;
         this.response = response;
+        this.retryCount = retryCount;
+        this.totalRetryWaitSec = totalRetryWaitSec;
     }
 
     public UUID getRequestId() {
@@ -39,4 +48,12 @@ public class ChatResponseEvent extends Event {
     public ChatMessage getResponse() {
         return response;
     }
+
+    public int getRetryCount() {
+        return retryCount;
+    }
+
+    public int getTotalRetryWaitSec() {
+        return totalRetryWaitSec;
+    }
 }
diff --git a/docs/content/docs/operations/configuration.md 
b/docs/content/docs/operations/configuration.md
index 4460b548..67fed224 100644
--- a/docs/content/docs/operations/configuration.md
+++ b/docs/content/docs/operations/configuration.md
@@ -130,6 +130,7 @@ Here is the list of all built-in core configuration options.
 | `prettyPrint`             | false                      | boolean             
  | Whether to enable pretty-printed JSON format for event logs. When set to 
`true`, each event is written as formatted multi-line JSON instead of JSONL 
(JSON Lines) format. {{< hint info >}}Note: enabling this option makes the log 
file no longer valid JSONL format.  {{< /hint >}} |
 | `error-handling-strategy` | ErrorHandlingStrategy.FAIL | 
ErrorHandlingStrategy | Strategy for handling errors during model requests, 
include timeout and unexpected output schema. <br/>The option value could 
be:<br/> <ul><li>`ErrorHandlingStrategy.FAIL`</li> 
<li>`ErrorHandlingStrategy.RETRY`</li> <li>`ErrorHandlingStrategy.IGNORE`</li> |
 | `max-retries`             | 3                          | int                 
  | Number of retries when using `ErrorHandlingStrategy.RETRY`.                 
                                                                                
                                                                                
                    |
+| `retry-wait-interval`     | 1                          | int                 | Base wait interval, in seconds, between retries when using `ErrorHandlingStrategy.RETRY`. Waits use exponential backoff: the wait before the Nth retry is `retry-wait-interval * 2^(N-1)` seconds, so with the default of 1 the waits are 1s, 2s, 4s, and so on. A value of 0 or less disables waiting. The retry count and total wait time, summed over all tool-call rounds of a request, are reported in `ChatResponseEvent` and, when at least one retry occurred, recorded as the metrics `retryCount` and `retryWaitSec` under the chat model's connection name. |
 | `chat.async`              | true                       | boolean             
  | Whether chat asynchronously for built-in chat action.                       
                                                                                
                                                                                
                    |
 | `tool-call.async`         | true                       | boolean             
  | Whether process tool call for built-in tool call action.                    
                                                                                
                                                                                
                    |
 | `rag.async`               | true                       | boolean             
  | Whether retrieve context asynchronously for built-in context retrieval 
action.                                                                         
                                                                                
                         |
diff --git 
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
index cc73f2c6..c3dc74e1 100644
--- 
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
+++ 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
@@ -34,6 +34,7 @@ import org.apache.flink.agents.api.event.ChatRequestEvent;
 import org.apache.flink.agents.api.event.ChatResponseEvent;
 import org.apache.flink.agents.api.event.ToolRequestEvent;
 import org.apache.flink.agents.api.event.ToolResponseEvent;
+import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
 import org.apache.flink.agents.api.resource.ResourceType;
 import org.apache.flink.agents.api.tools.ToolResponse;
 import org.apache.flink.agents.plan.JavaFunction;
@@ -57,6 +58,9 @@ public class ChatModelAction {
     private static final String INITIAL_REQUEST_ID = "initialRequestId";
     private static final String MODEL = "model";
     private static final String OUTPUT_SCHEMA = "outputSchema";
+    private static final String RETRY_STATS_CONTEXT = "_RETRY_STATS_CONTEXT";
+    private static final String TOTAL_RETRY_COUNT = "totalRetryCount";
+    private static final String TOTAL_RETRY_WAIT_SEC = "totalRetryWaitSec";
 
     private static final ObjectMapper mapper = new ObjectMapper();
 
@@ -129,6 +133,50 @@ public class ChatModelAction {
         return (Map<String, Object>) toolRequestEventContext.remove(requestId);
     }
 
+    @SuppressWarnings("unchecked")
+    private static void accumulateRetryStats(
+            MemoryObject sensoryMem, UUID initialRequestId, int retryCount, 
int retryWaitSec)
+            throws Exception {
+        Map<UUID, Map<String, Long>> retryStatsContext;
+        if (sensoryMem.isExist(RETRY_STATS_CONTEXT)) {
+            retryStatsContext =
+                    (Map<UUID, Map<String, Long>>) 
sensoryMem.get(RETRY_STATS_CONTEXT).getValue();
+        } else {
+            retryStatsContext = new HashMap<>();
+        }
+        Map<String, Long> stats = 
retryStatsContext.getOrDefault(initialRequestId, new HashMap<>());
+        stats.put(TOTAL_RETRY_COUNT, stats.getOrDefault(TOTAL_RETRY_COUNT, 0L) 
+ retryCount);
+        stats.put(
+                TOTAL_RETRY_WAIT_SEC, stats.getOrDefault(TOTAL_RETRY_WAIT_SEC, 
0L) + retryWaitSec);
+        retryStatsContext.put(initialRequestId, stats);
+        sensoryMem.set(RETRY_STATS_CONTEXT, retryStatsContext);
+    }
+
+    @SuppressWarnings("unchecked")
+    private static Map<String, Long> getRetryStats(MemoryObject sensoryMem, 
UUID initialRequestId)
+            throws Exception {
+        if (!sensoryMem.isExist(RETRY_STATS_CONTEXT)) {
+            return Map.of(TOTAL_RETRY_COUNT, 0L, TOTAL_RETRY_WAIT_SEC, 0L);
+        }
+        Map<UUID, Map<String, Long>> retryStatsContext =
+                (Map<UUID, Map<String, Long>>) 
sensoryMem.get(RETRY_STATS_CONTEXT).getValue();
+        return retryStatsContext.getOrDefault(
+                initialRequestId, Map.of(TOTAL_RETRY_COUNT, 0L, 
TOTAL_RETRY_WAIT_SEC, 0L));
+    }
+
+    private static void recordRetryMetrics(
+            RunnerContext ctx, String model, int retryCount, int 
totalRetryWaitSec) {
+        if (retryCount <= 0) {
+            return;
+        }
+        FlinkAgentsMetricGroup metricGroup = ctx.getActionMetricGroup();
+        if (metricGroup != null) {
+            FlinkAgentsMetricGroup modelGroup = metricGroup.getSubGroup(model);
+            modelGroup.getCounter("retryCount").inc(retryCount);
+            modelGroup.getCounter("retryWaitSec").inc(totalRetryWaitSec);
+        }
+    }
+
     private static void handleToolCalls(
             ChatMessage response,
             UUID initialRequestId,
@@ -206,14 +254,21 @@ public class ChatModelAction {
         Agent.ErrorHandlingStrategy strategy =
                 
ctx.getConfig().get(AgentExecutionOptions.ERROR_HANDLING_STRATEGY);
         int numRetries = 0;
+        int retryWaitIntervalSec = 0;
         if (strategy == Agent.ErrorHandlingStrategy.RETRY) {
             numRetries =
                     ctx.getConfig().get(AgentExecutionOptions.MAX_RETRIES) > 0
                             ? 
ctx.getConfig().get(AgentExecutionOptions.MAX_RETRIES)
                             : 0;
+            retryWaitIntervalSec =
+                    
ctx.getConfig().get(AgentExecutionOptions.RETRY_WAIT_INTERVAL) > 0
+                            ? 
ctx.getConfig().get(AgentExecutionOptions.RETRY_WAIT_INTERVAL)
+                            : 0;
         }
 
         ChatMessage response = null;
+        int actualRetryCount = 0;
+        int totalWaitTimeSec = 0;
 
         DurableCallable<ChatMessage> callable =
                 new DurableCallable<>() {
@@ -253,12 +308,19 @@ public class ChatModelAction {
                     if (attempt == numRetries) {
                         throw e;
                     }
+                    actualRetryCount = attempt + 1;
+                    int currentWaitSec = retryWaitIntervalSec * (1 << 
(actualRetryCount - 1));
                     LOG.warn(
-                            "Chat request {} failed with error: {}, retrying 
{} / {}.",
+                            "Chat request {} failed with error: {}, retrying 
{} / {}, waiting {} s.",
                             initialRequestId,
                             e,
-                            attempt,
-                            numRetries);
+                            actualRetryCount,
+                            numRetries,
+                            currentWaitSec);
+                    if (currentWaitSec > 0) {
+                        Thread.sleep(currentWaitSec * 1000L);
+                        totalWaitTimeSec += currentWaitSec;
+                    }
                 } else {
                     LOG.debug(
                             "Chat request {} failed, the input chat messages 
are {}.",
@@ -269,10 +331,23 @@ public class ChatModelAction {
             }
         }
 
+        if (actualRetryCount > 0) {
+            accumulateRetryStats(
+                    ctx.getSensoryMemory(), initialRequestId, 
actualRetryCount, totalWaitTimeSec);
+        }
+
         if (!Objects.requireNonNull(response).getToolCalls().isEmpty()) {
             handleToolCalls(response, initialRequestId, model, messages, 
outputSchema, ctx);
         } else {
-            ctx.sendEvent(new ChatResponseEvent(initialRequestId, response));
+            Map<String, Long> retryStats = 
getRetryStats(ctx.getSensoryMemory(), initialRequestId);
+            int totalRetryCount = retryStats.get(TOTAL_RETRY_COUNT).intValue();
+            int totalRetryWaitSec = 
retryStats.get(TOTAL_RETRY_WAIT_SEC).intValue();
+
+            recordRetryMetrics(ctx, chatModel.getConnection(), 
totalRetryCount, totalRetryWaitSec);
+
+            ctx.sendEvent(
+                    new ChatResponseEvent(
+                            initialRequestId, response, totalRetryCount, 
totalRetryWaitSec));
         }
     }
 
diff --git 
a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionRetryTest.java
 
b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionRetryTest.java
new file mode 100644
index 00000000..f9101e48
--- /dev/null
+++ 
b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionRetryTest.java
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.plan.actions;
+
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.api.agents.Agent;
+import org.apache.flink.agents.api.agents.AgentExecutionOptions;
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.chat.model.BaseChatModelSetup;
+import org.apache.flink.agents.api.context.DurableCallable;
+import org.apache.flink.agents.api.context.MemoryObject;
+import org.apache.flink.agents.api.context.RunnerContext;
+import org.apache.flink.agents.api.event.ChatResponseEvent;
+import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
+import org.apache.flink.agents.api.resource.ResourceType;
+import org.apache.flink.metrics.Counter;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.*;
+
+/** Tests for retry behavior in {@link ChatModelAction}. */
+class ChatModelActionRetryTest {
+
+    @Mock private RunnerContext mockCtx;
+
+    @Mock private BaseChatModelSetup mockChatModel;
+
+    @Mock private FlinkAgentsMetricGroup mockActionMetricGroup;
+
+    @Mock private FlinkAgentsMetricGroup mockModelMetricGroup;
+
+    @Mock private Counter mockRetryCountCounter;
+
+    @Mock private Counter mockRetryWaitSecCounter;
+
+    private MemoryObject sensoryMemory;
+    private List<Event> sentEvents;
+    private AutoCloseable mocks;
+
+    @BeforeEach
+    void setUp() throws Exception {
+        mocks = MockitoAnnotations.openMocks(this);
+        sentEvents = new ArrayList<>();
+        sensoryMemory = createStatefulMemoryObject();
+
+        // Wire up ChatModel
+        when(mockChatModel.getConnection()).thenReturn("test-connection");
+
+        // Wire up RunnerContext
+        when(mockCtx.getResource(anyString(), eq(ResourceType.CHAT_MODEL)))
+                .thenReturn(mockChatModel);
+        when(mockCtx.getSensoryMemory()).thenReturn(sensoryMemory);
+        when(mockCtx.getActionMetricGroup()).thenReturn(mockActionMetricGroup);
+        doAnswer(inv -> 
sentEvents.add(inv.getArgument(0))).when(mockCtx).sendEvent(any());
+        when(mockCtx.<ChatMessage>durableExecute(any()))
+                .thenAnswer(inv -> 
inv.<DurableCallable<ChatMessage>>getArgument(0).call());
+
+        // Wire up metric group chain
+        
when(mockActionMetricGroup.getSubGroup(anyString())).thenReturn(mockModelMetricGroup);
+        
when(mockModelMetricGroup.getCounter("retryCount")).thenReturn(mockRetryCountCounter);
+        
when(mockModelMetricGroup.getCounter("retryWaitSec")).thenReturn(mockRetryWaitSecCounter);
+    }
+
+    @AfterEach
+    void tearDown() throws Exception {
+        if (mocks != null) {
+            mocks.close();
+        }
+    }
+
+    @Test
+    void chatSucceedsWithoutRetry_retryCountIsZero() throws Exception {
+        configureRetryStrategy(3, 1);
+        when(mockChatModel.chat(any(), any()))
+                .thenReturn(new ChatMessage(MessageRole.ASSISTANT, "hello"));
+
+        UUID requestId = UUID.randomUUID();
+        ChatModelAction.chat(
+                requestId,
+                "test-model",
+                List.of(new ChatMessage(MessageRole.USER, "hi")),
+                null,
+                mockCtx);
+
+        assertThat(sentEvents).hasSize(1);
+        ChatResponseEvent responseEvent = (ChatResponseEvent) 
sentEvents.get(0);
+        assertThat(responseEvent.getRetryCount()).isEqualTo(0);
+        assertThat(responseEvent.getTotalRetryWaitSec()).isEqualTo(0);
+
+        // No retry metrics should be recorded
+        verify(mockActionMetricGroup, never()).getSubGroup(anyString());
+    }
+
+    @Test
+    void chatRetriesWithExponentialBackoff() throws Exception {
+        // 1 second base interval; fail once then succeed -> wait 1s (1 * 2^0)
+        configureRetryStrategy(3, 1);
+
+        AtomicInteger callCount = new AtomicInteger(0);
+        when(mockChatModel.chat(any(), any()))
+                .thenAnswer(
+                        inv -> {
+                            int count = callCount.incrementAndGet();
+                            if (count <= 1) {
+                                throw new RuntimeException("transient error");
+                            }
+                            return new ChatMessage(MessageRole.ASSISTANT, 
"success");
+                        });
+
+        UUID requestId = UUID.randomUUID();
+
+        long startTime = System.currentTimeMillis();
+        ChatModelAction.chat(
+                requestId,
+                "test-model",
+                List.of(new ChatMessage(MessageRole.USER, "hi")),
+                null,
+                mockCtx);
+        long elapsed = System.currentTimeMillis() - startTime;
+
+        assertThat(sentEvents).hasSize(1);
+        ChatResponseEvent responseEvent = (ChatResponseEvent) 
sentEvents.get(0);
+        assertThat(responseEvent.getRetryCount()).isEqualTo(1);
+        // Exponential backoff: 1000ms (1s * 2^0) total
+        // 1 retry with 1s interval = 1s total
+        assertThat(responseEvent.getTotalRetryWaitSec()).isEqualTo(1);
+        assertThat(elapsed).isGreaterThanOrEqualTo(1000L);
+
+        // Verify metrics recorded under connection name
+        
verify(mockActionMetricGroup).getSubGroup(mockChatModel.getConnection());
+        verify(mockRetryCountCounter).inc(1);
+        verify(mockRetryWaitSecCounter).inc(1);
+    }
+
+    @Test
+    void chatExhaustsRetriesAndThrows() {
+        configureRetryStrategy(2, 0);
+
+        when(mockChatModel.chat(any(), any())).thenThrow(new 
RuntimeException("persistent error"));
+
+        UUID requestId = UUID.randomUUID();
+
+        assertThatThrownBy(
+                        () ->
+                                ChatModelAction.chat(
+                                        requestId,
+                                        "test-model",
+                                        List.of(new 
ChatMessage(MessageRole.USER, "hi")),
+                                        null,
+                                        mockCtx))
+                .isInstanceOf(RuntimeException.class)
+                .hasMessage("persistent error");
+
+        assertThat(sentEvents).isEmpty();
+    }
+
+    @Test
+    void chatResponseEventDefaultConstructorHasZeroRetryInfo() {
+        UUID requestId = UUID.randomUUID();
+        ChatMessage msg = new ChatMessage(MessageRole.ASSISTANT, "test");
+        ChatResponseEvent event = new ChatResponseEvent(requestId, msg);
+
+        assertThat(event.getRetryCount()).isEqualTo(0);
+        assertThat(event.getTotalRetryWaitSec()).isEqualTo(0);
+        assertThat(event.getRequestId()).isEqualTo(requestId);
+    }
+
+    @Test
+    void chatResponseEventFullConstructorCarriesRetryInfo() {
+        UUID requestId = UUID.randomUUID();
+        ChatMessage msg = new ChatMessage(MessageRole.ASSISTANT, "test");
+        ChatResponseEvent event = new ChatResponseEvent(requestId, msg, 5, 31);
+
+        assertThat(event.getRetryCount()).isEqualTo(5);
+        assertThat(event.getTotalRetryWaitSec()).isEqualTo(31);
+    }
+
+    @Test
+    void retryWaitIntervalDefaultValue() {
+        
assertThat(AgentExecutionOptions.RETRY_WAIT_INTERVAL.getDefaultValue()).isEqualTo(1);
+    }
+
+    // --- Helper methods ---
+
+    private void configureRetryStrategy(int maxRetries, int waitIntervalSec) {
+        when(mockCtx.getConfig())
+                .thenAnswer(
+                        inv -> {
+                            // Return a mock ReadableConfiguration
+                            return new 
org.apache.flink.agents.api.configuration
+                                    .ReadableConfiguration() {
+                                @Override
+                                @SuppressWarnings("unchecked")
+                                public <T> T get(
+                                        
org.apache.flink.agents.api.configuration.ConfigOption<T>
+                                                option) {
+                                    if (option == 
AgentExecutionOptions.ERROR_HANDLING_STRATEGY) {
+                                        return (T) 
Agent.ErrorHandlingStrategy.RETRY;
+                                    }
+                                    if (option == 
AgentExecutionOptions.MAX_RETRIES) {
+                                        return (T) Integer.valueOf(maxRetries);
+                                    }
+                                    if (option == 
AgentExecutionOptions.RETRY_WAIT_INTERVAL) {
+                                        return (T) 
Integer.valueOf(waitIntervalSec);
+                                    }
+                                    if (option == 
AgentExecutionOptions.CHAT_ASYNC) {
+                                        return (T) Boolean.FALSE;
+                                    }
+                                    return option.getDefaultValue();
+                                }
+
+                                @Override
+                                public Integer getInt(String key, Integer 
defaultValue) {
+                                    return defaultValue;
+                                }
+
+                                @Override
+                                public Long getLong(String key, Long 
defaultValue) {
+                                    return defaultValue;
+                                }
+
+                                @Override
+                                public Float getFloat(String key, Float 
defaultValue) {
+                                    return defaultValue;
+                                }
+
+                                @Override
+                                public Double getDouble(String key, Double 
defaultValue) {
+                                    return defaultValue;
+                                }
+
+                                @Override
+                                public Boolean getBool(String key, Boolean 
defaultValue) {
+                                    return defaultValue;
+                                }
+
+                                @Override
+                                public String getStr(String key, String 
defaultValue) {
+                                    return defaultValue;
+                                }
+                            };
+                        });
+    }
+
+    /**
+     * Creates a stateful MemoryObject backed by a HashMap, supporting 
isExist/get/set operations
+     * needed by the retry stats accumulation logic.
+     */
+    private static MemoryObject createStatefulMemoryObject() {
+        Map<String, Object> store = new HashMap<>();
+
+        MemoryObject memoryObject = mock(MemoryObject.class);
+
+        when(memoryObject.isExist(anyString()))
+                .thenAnswer(inv -> 
store.containsKey(inv.<String>getArgument(0)));
+
+        try {
+            when(memoryObject.get(anyString()))
+                    .thenAnswer(
+                            inv -> {
+                                String path = inv.getArgument(0);
+                                Object value = store.get(path);
+                                if (value == null) {
+                                    throw new Exception("Path not found: " + 
path);
+                                }
+                                MemoryObject valueObj = 
mock(MemoryObject.class);
+                                when(valueObj.getValue()).thenReturn(value);
+                                return valueObj;
+                            });
+
+            when(memoryObject.set(anyString(), any()))
+                    .thenAnswer(
+                            inv -> {
+                                store.put(inv.getArgument(0), 
inv.getArgument(1));
+                                return null;
+                            });
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+
+        return memoryObject;
+    }
+}
diff --git a/python/flink_agents/api/core_options.py 
b/python/flink_agents/api/core_options.py
index eb503de2..9245e78c 100644
--- a/python/flink_agents/api/core_options.py
+++ b/python/flink_agents/api/core_options.py
@@ -107,6 +107,12 @@ class AgentExecutionOptions:
         default=3,
     )
 
+    RETRY_WAIT_INTERVAL = ConfigOption(
+        key="retry-wait-interval",
+        config_type=int,
+        default=1,
+    )
+
     NUM_ASYNC_THREADS = ConfigOption(
         key="num-async-threads",
         config_type=int,
diff --git a/python/flink_agents/api/events/chat_event.py 
b/python/flink_agents/api/events/chat_event.py
index 5e1237f5..063548d0 100644
--- a/python/flink_agents/api/events/chat_event.py
+++ b/python/flink_agents/api/events/chat_event.py
@@ -50,7 +50,13 @@ class ChatResponseEvent(Event):
         The id of the request event.
     response : ChatMessage
         The response from the chat model.
+    retry_count : int
+        The total number of retries across all tool call rounds.
+    total_retry_wait_sec : int
+        The total time spent waiting during retries in seconds.
     """
 
     request_id: UUID
     response: ChatMessage
+    retry_count: int = 0
+    total_retry_wait_sec: int = 0
diff --git a/python/flink_agents/api/runner_context.py 
b/python/flink_agents/api/runner_context.py
index 4e3bc303..7dfd5437 100644
--- a/python/flink_agents/api/runner_context.py
+++ b/python/flink_agents/api/runner_context.py
@@ -169,24 +169,26 @@ class RunnerContext(ABC):
 
     @property
     @abstractmethod
-    def agent_metric_group(self) -> MetricGroup:
+    def agent_metric_group(self) -> MetricGroup | None:
         """Get the metric group for flink agents.
 
         Returns:
         -------
-        MetricGroup
+        MetricGroup | None
             The metric group shared across all actions.
+            May return None when not running on Flink.
         """
 
     @property
     @abstractmethod
-    def action_metric_group(self) -> MetricGroup:
+    def action_metric_group(self) -> MetricGroup | None:
         """Get the individual metric group dedicated for each action.
 
         Returns:
         -------
-        MetricGroup
+        MetricGroup | None
             The individual metric group specific to the current action.
+            May return None when not running on Flink.
         """
 
     @abstractmethod
diff --git a/python/flink_agents/plan/actions/chat_model_action.py 
b/python/flink_agents/plan/actions/chat_model_action.py
index 02e92ed8..210289db 100644
--- a/python/flink_agents/plan/actions/chat_model_action.py
+++ b/python/flink_agents/plan/actions/chat_model_action.py
@@ -18,6 +18,7 @@
 import copy
 import json
 import logging
+import time
 from typing import TYPE_CHECKING, Dict, List, cast
 from uuid import UUID
 
@@ -47,6 +48,7 @@ if TYPE_CHECKING:
 
 _TOOL_CALL_CONTEXT = "_TOOL_CALL_CONTEXT"
 _TOOL_REQUEST_EVENT_CONTEXT = "_TOOL_REQUEST_EVENT_CONTEXT"
+_RETRY_STATS_CONTEXT = "_RETRY_STATS_CONTEXT"
 
 _logger = logging.getLogger(__name__)
 
@@ -111,6 +113,49 @@ def _get_tool_request_event_context(
     return removed_context
 
 
+def _accumulate_retry_stats(
+    sensory_memory: MemoryObject,
+    initial_request_id: UUID,
+    retry_count: int,
+    retry_wait_sec: int,
+) -> None:
+    """Accumulate retry stats for a given initial request across tool call 
rounds."""
+    retry_stats_context = sensory_memory.get(_RETRY_STATS_CONTEXT) or {}
+    stats = retry_stats_context.get(initial_request_id, {
+        "total_retry_count": 0,
+        "total_retry_wait_sec": 0,
+    })
+    stats["total_retry_count"] += retry_count
+    stats["total_retry_wait_sec"] += retry_wait_sec
+    retry_stats_context[initial_request_id] = stats
+    sensory_memory.set(_RETRY_STATS_CONTEXT, retry_stats_context)
+
+
+def _get_retry_stats(
+    sensory_memory: MemoryObject,
+    initial_request_id: UUID,
+) -> dict:
+    """Get accumulated retry stats for a given initial request."""
+    retry_stats_context = sensory_memory.get(_RETRY_STATS_CONTEXT) or {}
+    return retry_stats_context.get(initial_request_id, {
+        "total_retry_count": 0,
+        "total_retry_wait_sec": 0,
+    })
+
+
+def _record_retry_metrics(
+    ctx: RunnerContext, model: str, retry_count: int, total_retry_wait_sec: int
+) -> None:
+    """Record retry metrics under the connection name if retries occurred."""
+    if retry_count <= 0:
+        return
+    metric_group = ctx.action_metric_group
+    if metric_group is not None:
+        model_group = metric_group.get_sub_group(model)
+        model_group.get_counter("retryCount").inc(retry_count)
+        model_group.get_counter("retryWaitSec").inc(total_retry_wait_sec)
+
+
 def _handle_tool_calls(
     response: ChatMessage,
     initial_request_id: UUID,
@@ -185,10 +230,17 @@ async def chat(
 
     error_handling_strategy = ctx.config.get(AgentExecutionOptions.ERROR_HANDLING_STRATEGY)
     num_retries = 0
+    retry_wait_interval_sec = 0
     if error_handling_strategy == ErrorHandlingStrategy.RETRY:
         num_retries = max(0, ctx.config.get(AgentExecutionOptions.MAX_RETRIES))
+        configured_wait_sec = ctx.config.get(AgentExecutionOptions.RETRY_WAIT_INTERVAL)
+        # A falsy (e.g. unset) or negative interval disables the wait.
+        retry_wait_interval_sec = max(0, configured_wait_sec or 0)
 
     response = None
+    actual_retry_count = 0
+    total_wait_time_sec = 0
+
     for attempt in range(num_retries + 1):
         try:
             if chat_async:
@@ -210,15 +265,35 @@ async def chat(
             elif error_handling_strategy == ErrorHandlingStrategy.RETRY:
                 if attempt == num_retries:
                     raise
+                actual_retry_count = attempt + 1
+                # Exponential backoff: the n-th retry waits interval * 2^(n - 1).
+                current_wait_sec = retry_wait_interval_sec * (
+                    1 << (actual_retry_count - 1)
+                )
                 _logger.warning(
-                    f"Chat request {initial_request_id} failed with error: {e}, retrying {attempt} / {num_retries}."
+                    f"Chat request {initial_request_id} failed with error: {e}, "
+                    f"retrying {actual_retry_count} / {num_retries}, "
+                    f"waiting {current_wait_sec} s."
                 )
+                if current_wait_sec > 0:
+                    # NOTE(review): time.sleep blocks the thread driving this
+                    # coroutine; asyncio.sleep only works if the runtime runs an
+                    # asyncio event loop -- confirm before switching.
+                    time.sleep(current_wait_sec)
+                    total_wait_time_sec += current_wait_sec
             else:
                 _logger.debug(
                     f"Chat request {initial_request_id} failed, the input chat messages are {messages}."
                 )
                 raise
 
+    # Persist this round's retries so the final ChatResponseEvent can report
+    # totals across the tool-call rounds of the same request.
+    if actual_retry_count > 0:
+        _accumulate_retry_stats(
+            ctx.sensory_memory, initial_request_id, actual_retry_count, total_wait_time_sec
+        )
+
     if (
         len(response.tool_calls) > 0
     ):  # generate tool request event according tool calls in response
@@ -226,10 +295,22 @@ async def chat(
             response, initial_request_id, model, messages, output_schema, ctx
         )
     else:  # if there is no tool call generated, return chat response directly
+        # Report retry totals accumulated across all chat rounds of this request.
+        retry_stats = _get_retry_stats(ctx.sensory_memory, initial_request_id)
+        total_retry_count = retry_stats["total_retry_count"]
+        total_retry_wait_sec = retry_stats["total_retry_wait_sec"]
+
+        # Metrics are keyed by the chat model's connection name.
+        _record_retry_metrics(
+            ctx, chat_model.connection, total_retry_count, total_retry_wait_sec
+        )
+
         ctx.send_event(
             ChatResponseEvent(
                 request_id=initial_request_id,
                 response=response,
+                retry_count=total_retry_count,
+                total_retry_wait_sec=total_retry_wait_sec,
             )
         )
 
diff --git a/python/flink_agents/plan/tests/actions/__init__.py b/python/flink_agents/plan/tests/actions/__init__.py
new file mode 100644
index 00000000..e154fadd
--- /dev/null
+++ b/python/flink_agents/plan/tests/actions/__init__.py
@@ -0,0 +1,17 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
diff --git a/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py b/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py
new file mode 100644
index 00000000..a6014323
--- /dev/null
+++ b/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py
@@ -0,0 +1,250 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+"""Tests for retry behavior in chat_model_action."""
+
+import asyncio
+import time
+from typing import Any, Sequence
+from unittest.mock import MagicMock
+from uuid import uuid4
+
+import pytest
+
+from flink_agents.api.chat_message import ChatMessage, MessageRole
+from flink_agents.api.core_options import (
+    AgentExecutionOptions,
+    ErrorHandlingStrategy,
+)
+from flink_agents.api.events.chat_event import ChatResponseEvent
+from flink_agents.api.metric_group import Counter, MetricGroup
+from flink_agents.plan.actions.chat_model_action import chat
+
+# ============================================================================
+# Mock infrastructure
+# ============================================================================
+
+
+class _MockCounter(Counter):
+    """In-memory Counter stub whose running total can be inspected by tests."""
+
+    def __init__(self) -> None:
+        self._value = 0
+
+    def inc(self, n: int = 1) -> None:
+        self._value = self._value + n
+
+    def dec(self, n: int = 1) -> None:
+        self.inc(-n)
+
+    def get_count(self) -> int:
+        return self._value
+
+
+class _MockMetricGroup(MetricGroup):
+    """Metric group stub that memoizes sub-groups and counters by name."""
+
+    def __init__(self) -> None:
+        self._sub_groups: dict[str, _MockMetricGroup] = {}
+        self._counters: dict[str, _MockCounter] = {}
+
+    def get_sub_group(self, name: str) -> "_MockMetricGroup":
+        # Same name -> same instance, so tests can read counters back later.
+        group = self._sub_groups.setdefault(name, _MockMetricGroup())
+        return group
+
+    def get_counter(self, name: str) -> _MockCounter:
+        # Created on first lookup, then reused for every later lookup.
+        counter = self._counters.setdefault(name, _MockCounter())
+        return counter
+
+    def get_meter(self, name: str) -> Any:
+        return MagicMock()
+
+    def get_gauge(self, name: str) -> Any:
+        return MagicMock()
+
+    def get_histogram(self, name: str, window_size: int = 100) -> Any:
+        return MagicMock()
+
+
+class _MockMemoryObject:
+    """Dict-backed stand-in for a memory object supporting get/set by path."""
+
+    def __init__(self) -> None:
+        self._store: dict[str, Any] = {}
+
+    def get(self, path: str) -> Any:
+        return self._store[path] if path in self._store else None
+
+    def set(self, path: str, value: Any) -> None:
+        self._store.update({path: value})
+
+
+def _create_mock_runner_context(
+    chat_model: Any,
+    max_retries: int = 3,
+    retry_wait_interval_sec: int = 1,
+) -> tuple[MagicMock, list, _MockMetricGroup, _MockMemoryObject]:
+    """Create a mock RunnerContext with configurable retry settings.
+
+    Returns (ctx, sent_events, action_metric_group, sensory_memory).
+    """
+    sent_events: list[Any] = []
+    metric_group = _MockMetricGroup()
+    sensory_memory = _MockMemoryObject()
+
+    config = MagicMock()
+    option_values = {
+        id(AgentExecutionOptions.ERROR_HANDLING_STRATEGY): ErrorHandlingStrategy.RETRY,
+        id(AgentExecutionOptions.MAX_RETRIES): max_retries,
+        id(AgentExecutionOptions.RETRY_WAIT_INTERVAL): retry_wait_interval_sec,
+        id(AgentExecutionOptions.CHAT_ASYNC): False,
+    }
+    config.get = MagicMock(
+        side_effect=lambda option: option_values.get(id(option), option.get_default_value())
+    )
+
+    ctx = MagicMock()
+    ctx.config = config
+    ctx.sensory_memory = sensory_memory
+    ctx.action_metric_group = metric_group
+    ctx.send_event = MagicMock(side_effect=sent_events.append)
+    ctx.get_resource = MagicMock(return_value=chat_model)
+    ctx.durable_execute = MagicMock(side_effect=lambda fn, *args, **kwargs: fn(*args, **kwargs))
+
+    return ctx, sent_events, metric_group, sensory_memory
+
+
+# ============================================================================
+# Tests
+# ============================================================================
+
+
+class TestChatModelActionRetry:
+    """Tests for retry behavior in chat()."""
+
+    def test_chat_succeeds_without_retry(self) -> None:
+        """No retry needed: retry_count=0, total_retry_wait_sec=0, no metrics."""
+        chat_model = MagicMock()
+        chat_model.chat = MagicMock(
+            return_value=ChatMessage(role=MessageRole.ASSISTANT, content="hello")
+        )
+
+        ctx, sent_events, metric_group, _ = _create_mock_runner_context(chat_model)
+        request_id = uuid4()
+
+        asyncio.run(
+            chat(request_id, "test-model", [ChatMessage(role=MessageRole.USER, content="hi")], None, ctx)
+        )
+
+        assert len(sent_events) == 1
+        event = sent_events[0]
+        assert isinstance(event, ChatResponseEvent)
+        assert event.retry_count == 0
+        assert event.total_retry_wait_sec == 0
+
+        # No retries happened, so no metric sub-group should have been created.
+        assert len(metric_group._sub_groups) == 0
+
+    def test_chat_retries_with_exponential_backoff(self) -> None:
+        """Fail once then succeed: 1s interval, 1 retry -> wait 1s (1 * 2^0)."""
+        call_count = 0
+
+        def mock_chat(messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage:
+            nonlocal call_count
+            call_count += 1
+            if call_count <= 1:
+                err_msg = "transient error"
+                raise RuntimeError(err_msg)
+            return ChatMessage(role=MessageRole.ASSISTANT, content="success")
+
+        chat_model = MagicMock()
+        chat_model.chat = mock_chat
+
+        ctx, sent_events, metric_group, _ = _create_mock_runner_context(
+            chat_model, max_retries=3, retry_wait_interval_sec=1
+        )
+        request_id = uuid4()
+
+        start = time.monotonic()
+        asyncio.run(
+            chat(request_id, "test-model", [ChatMessage(role=MessageRole.USER, content="hi")], None, ctx)
+        )
+        elapsed = time.monotonic() - start
+
+        assert len(sent_events) == 1
+        event = sent_events[0]
+        assert isinstance(event, ChatResponseEvent)
+        assert event.retry_count == 1
+        # 1s config. Exponential: 1s (2^0) = 1s total
+        assert event.total_retry_wait_sec == 1
+        assert elapsed >= 1.0
+
+        # Verify metrics recorded under the chat model's connection name
+        model_group = metric_group.get_sub_group(chat_model.connection)
+        assert model_group.get_counter("retryCount").get_count() == 1
+        assert model_group.get_counter("retryWaitSec").get_count() == 1
+
+    def test_chat_exhausts_retries_and_raises(self) -> None:
+        """All retries exhausted: the last error propagates, no event is sent."""
+        chat_model = MagicMock()
+        chat_model.chat = MagicMock(side_effect=RuntimeError("persistent error"))
+
+        ctx, sent_events, _, _ = _create_mock_runner_context(
+            chat_model, max_retries=2, retry_wait_interval_sec=0
+        )
+        request_id = uuid4()
+
+        with pytest.raises(RuntimeError, match="persistent error"):
+            asyncio.run(
+                chat(request_id, "test-model", [ChatMessage(role=MessageRole.USER, content="hi")], None, ctx)
+            )
+
+        assert len(sent_events) == 0
+
+
+class TestChatResponseEventRetryFields:
+    """Checks the retry bookkeeping fields carried by ChatResponseEvent."""
+
+    @staticmethod
+    def _reply() -> ChatMessage:
+        return ChatMessage(role=MessageRole.ASSISTANT, content="test")
+
+    def test_default_retry_fields(self) -> None:
+        """Omitting the retry fields leaves both of them at zero."""
+        event = ChatResponseEvent(request_id=uuid4(), response=self._reply())
+        assert event.retry_count == 0
+        assert event.total_retry_wait_sec == 0
+
+    def test_with_retry_fields(self) -> None:
+        """Explicitly supplied retry fields are stored unchanged."""
+        event = ChatResponseEvent(
+            request_id=uuid4(),
+            response=self._reply(),
+            retry_count=5,
+            total_retry_wait_sec=31,
+        )
+        assert (event.retry_count, event.total_retry_wait_sec) == (5, 31)
+
+
+class TestRetryWaitIntervalConfig:
+    """Tests for RETRY_WAIT_INTERVAL configuration."""
+
+    def test_default_value(self) -> None:
+        """Default wait interval is 1 (seconds)."""
+        assert AgentExecutionOptions.RETRY_WAIT_INTERVAL.get_default_value() == 1
diff --git a/python/flink_agents/runtime/local_runner.py b/python/flink_agents/runtime/local_runner.py
index b8eaf3f0..e4124439 100644
--- a/python/flink_agents/runtime/local_runner.py
+++ b/python/flink_agents/runtime/local_runner.py
@@ -167,17 +167,21 @@ class LocalRunnerContext(RunnerContext):
 
     @property
     @override
-    def agent_metric_group(self) -> MetricGroup:
+    def agent_metric_group(self) -> None:
         # TODO: Support metric mechanism for local agent execution.
         err_msg = "Metric mechanism is not supported for local agent execution yet."
-        raise NotImplementedError(err_msg)
+        # Return None instead of raising so callers can skip metric recording.
+        logger.warning(err_msg)
+        return
 
     @property
     @override
-    def action_metric_group(self) -> MetricGroup:
+    def action_metric_group(self) -> None:
         # TODO: Support metric mechanism for local agent execution.
         err_msg = "Metric mechanism is not supported for local agent execution yet."
-        raise NotImplementedError(err_msg)
+        # Return None instead of raising so callers can skip metric recording.
+        logger.warning(err_msg)
+        return
 
     @override
     def durable_execute(


Reply via email to