This is an automated email from the ASF dual-hosted git repository.
sxnan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
The following commit(s) were added to refs/heads/main by this push:
new 281844ef [plan] Support retry interval for chat action (#602)
281844ef is described below
commit 281844ef84d481e6c6d958d416e4df15f7004d7a
Author: Eugene <[email protected]>
AuthorDate: Thu Apr 2 17:22:25 2026 +0800
[plan] Support retry interval for chat action (#602)
---
.../agents/api/agents/AgentExecutionOptions.java | 3 +
.../flink/agents/api/event/ChatResponseEvent.java | 17 ++
docs/content/docs/operations/configuration.md | 1 +
.../flink/agents/plan/actions/ChatModelAction.java | 83 +++++-
.../plan/actions/ChatModelActionRetryTest.java | 314 +++++++++++++++++++++
python/flink_agents/api/core_options.py | 6 +
python/flink_agents/api/events/chat_event.py | 6 +
python/flink_agents/api/runner_context.py | 10 +-
.../flink_agents/plan/actions/chat_model_action.py | 79 +++++-
python/flink_agents/plan/tests/actions/__init__.py | 17 ++
.../tests/actions/test_chat_model_action_retry.py | 250 ++++++++++++++++
python/flink_agents/runtime/local_runner.py | 10 +-
12 files changed, 783 insertions(+), 13 deletions(-)
diff --git
a/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java
b/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java
index 69ba2a02..2b8751d4 100644
---
a/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java
+++
b/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java
@@ -30,6 +30,9 @@ public class AgentExecutionOptions {
public static final ConfigOption<Integer> MAX_RETRIES =
new ConfigOption<>("max-retries", Integer.class, 3);
+ /**
+ * Base wait interval, in seconds, between chat-model retries when the error handling strategy
+ * is {@code RETRY}. The Nth retry waits {@code interval * 2^(N-1)} seconds; a non-positive
+ * value disables waiting between retries.
+ */
+ public static final ConfigOption<Integer> RETRY_WAIT_INTERVAL =
+ new ConfigOption<>("retry-wait-interval", Integer.class, 1);
+
public static final ConfigOption<Integer> NUM_ASYNC_THREADS =
new ConfigOption<>(
"num-async-threads",
diff --git
a/api/src/main/java/org/apache/flink/agents/api/event/ChatResponseEvent.java
b/api/src/main/java/org/apache/flink/agents/api/event/ChatResponseEvent.java
index 9041e5a4..89de3453 100644
--- a/api/src/main/java/org/apache/flink/agents/api/event/ChatResponseEvent.java
+++ b/api/src/main/java/org/apache/flink/agents/api/event/ChatResponseEvent.java
@@ -26,10 +26,19 @@ import java.util.UUID;
public class ChatResponseEvent extends Event {
private final UUID requestId;
private final ChatMessage response;
+ /** Number of retries reported for this response ({@code 0} unless supplied). */
+ private final int retryCount;
+ /** Total seconds spent waiting between those retries ({@code 0} unless supplied). */
+ private final int totalRetryWaitSec;
+ /** Creates a response event carrying no retry information (both retry fields are 0). */
public ChatResponseEvent(UUID requestId, ChatMessage response) {
+ this(requestId, response, 0, 0);
+ }
+
+ /**
+ * Creates a response event that also reports retry statistics.
+ *
+ * @param requestId id of the originating chat request
+ * @param response the response message
+ * @param retryCount number of retries performed
+ * @param totalRetryWaitSec total seconds waited between retries
+ */
+ public ChatResponseEvent(
+ UUID requestId, ChatMessage response, int retryCount, int
totalRetryWaitSec) {
this.requestId = requestId;
this.response = response;
+ this.retryCount = retryCount;
+ this.totalRetryWaitSec = totalRetryWaitSec;
}
public UUID getRequestId() {
@@ -39,4 +48,12 @@ public class ChatResponseEvent extends Event {
public ChatMessage getResponse() {
return response;
}
+
+    /**
+     * Returns the retry count carried by this event.
+     *
+     * @return the number of retries, or {@code 0} when built with the two-argument constructor
+     */
+    public int getRetryCount() {
+        return this.retryCount;
+    }
+
+    /**
+     * Returns the total time spent waiting between retries.
+     *
+     * @return wait time in seconds, or {@code 0} when built with the two-argument constructor
+     */
+    public int getTotalRetryWaitSec() {
+        return this.totalRetryWaitSec;
+    }
}
diff --git a/docs/content/docs/operations/configuration.md
b/docs/content/docs/operations/configuration.md
index 4460b548..67fed224 100644
--- a/docs/content/docs/operations/configuration.md
+++ b/docs/content/docs/operations/configuration.md
@@ -130,6 +130,7 @@ Here is the list of all built-in core configuration options.
| `prettyPrint` | false | boolean
| Whether to enable pretty-printed JSON format for event logs. When set to
`true`, each event is written as formatted multi-line JSON instead of JSONL
(JSON Lines) format. {{< hint info >}}Note: enabling this option makes the log
file no longer valid JSONL format. {{< /hint >}} |
| `error-handling-strategy` | ErrorHandlingStrategy.FAIL |
ErrorHandlingStrategy | Strategy for handling errors during model requests,
include timeout and unexpected output schema. <br/>The option value could
be:<br/> <ul><li>`ErrorHandlingStrategy.FAIL`</li>
<li>`ErrorHandlingStrategy.RETRY`</li> <li>`ErrorHandlingStrategy.IGNORE`</li> |
| `max-retries` | 3 | int
| Number of retries when using `ErrorHandlingStrategy.RETRY`.
|
+| `retry-wait-interval` | 1 | int | Base wait interval, in seconds, between retries when using `ErrorHandlingStrategy.RETRY`. Uses exponential backoff: the wait before the Nth retry is `retry-wait-interval * 2^(N-1)` seconds. For example, with the default of 1, the waits are 1s, 2s, 4s, and so on. The backoff is not capped, so large `max-retries` values can lead to long waits; a value of 0 or less disables waiting. The retry count and total wait time, summed across all tool-call rounds of a request, are reported in `ChatResponseEvent` and recorded as the metrics `retryCount` and `retryWaitSec` under a sub-group named after the chat model connection. |
| `chat.async` | true | boolean
| Whether chat asynchronously for built-in chat action.
|
| `tool-call.async` | true | boolean
| Whether process tool call for built-in tool call action.
|
| `rag.async` | true | boolean
| Whether retrieve context asynchronously for built-in context retrieval
action.
|
diff --git
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
index cc73f2c6..c3dc74e1 100644
---
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
+++
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
@@ -34,6 +34,7 @@ import org.apache.flink.agents.api.event.ChatRequestEvent;
import org.apache.flink.agents.api.event.ChatResponseEvent;
import org.apache.flink.agents.api.event.ToolRequestEvent;
import org.apache.flink.agents.api.event.ToolResponseEvent;
+import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
import org.apache.flink.agents.api.resource.ResourceType;
import org.apache.flink.agents.api.tools.ToolResponse;
import org.apache.flink.agents.plan.JavaFunction;
@@ -57,6 +58,9 @@ public class ChatModelAction {
private static final String INITIAL_REQUEST_ID = "initialRequestId";
private static final String MODEL = "model";
private static final String OUTPUT_SCHEMA = "outputSchema";
+ // Sensory-memory key and per-request stat names used to sum retry counts and wait time
+ // across all tool-call rounds that belong to the same initial chat request.
+ private static final String RETRY_STATS_CONTEXT = "_RETRY_STATS_CONTEXT";
+ private static final String TOTAL_RETRY_COUNT = "totalRetryCount";
+ private static final String TOTAL_RETRY_WAIT_SEC = "totalRetryWaitSec";
private static final ObjectMapper mapper = new ObjectMapper();
@@ -129,6 +133,50 @@ public class ChatModelAction {
return (Map<String, Object>) toolRequestEventContext.remove(requestId);
}
+    /**
+     * Adds the retries performed by one chat-model call to the running totals kept in sensory
+     * memory for {@code initialRequestId}, so retries from every tool-call round of the same
+     * request are reported together in the final {@link ChatResponseEvent}.
+     *
+     * @param sensoryMem sensory memory holding the {@code RETRY_STATS_CONTEXT} map
+     * @param initialRequestId id of the originating chat request
+     * @param retryCount number of retries performed by this call
+     * @param retryWaitSec seconds slept between those retries
+     * @throws Exception if reading or writing sensory memory fails
+     */
+    @SuppressWarnings("unchecked")
+    private static void accumulateRetryStats(
+            MemoryObject sensoryMem, UUID initialRequestId, int retryCount, int retryWaitSec)
+            throws Exception {
+        Map<UUID, Map<String, Long>> retryStatsContext;
+        if (sensoryMem.isExist(RETRY_STATS_CONTEXT)) {
+            retryStatsContext =
+                    (Map<UUID, Map<String, Long>>) sensoryMem.get(RETRY_STATS_CONTEXT).getValue();
+        } else {
+            retryStatsContext = new HashMap<>();
+        }
+        // computeIfAbsent/merge also cope with entries that are present but null, where the
+        // previous getOrDefault-based arithmetic would have thrown a NullPointerException.
+        Map<String, Long> stats =
+                retryStatsContext.computeIfAbsent(initialRequestId, id -> new HashMap<>());
+        stats.merge(TOTAL_RETRY_COUNT, (long) retryCount, Long::sum);
+        stats.merge(TOTAL_RETRY_WAIT_SEC, (long) retryWaitSec, Long::sum);
+        // Write back so the updated totals are persisted in sensory memory.
+        sensoryMem.set(RETRY_STATS_CONTEXT, retryStatsContext);
+    }
+
+    /**
+     * Looks up the retry totals accumulated so far for a request.
+     *
+     * @param sensoryMem sensory memory that may hold the {@code RETRY_STATS_CONTEXT} map
+     * @param initialRequestId id of the originating chat request
+     * @return a map with {@code TOTAL_RETRY_COUNT} and {@code TOTAL_RETRY_WAIT_SEC}; both are
+     *     {@code 0L} when nothing has been recorded for the request
+     * @throws Exception if reading sensory memory fails
+     */
+    @SuppressWarnings("unchecked")
+    private static Map<String, Long> getRetryStats(MemoryObject sensoryMem, UUID initialRequestId)
+            throws Exception {
+        Map<String, Long> noRetries = Map.of(TOTAL_RETRY_COUNT, 0L, TOTAL_RETRY_WAIT_SEC, 0L);
+        if (sensoryMem.isExist(RETRY_STATS_CONTEXT)) {
+            Map<UUID, Map<String, Long>> statsByRequest =
+                    (Map<UUID, Map<String, Long>>) sensoryMem.get(RETRY_STATS_CONTEXT).getValue();
+            return statsByRequest.getOrDefault(initialRequestId, noRetries);
+        }
+        return noRetries;
+    }
+
+    /**
+     * Increments the {@code retryCount} and {@code retryWaitSec} counters in a sub-group of the
+     * action metric group. Nothing is recorded when no retries happened or when the context has
+     * no action metric group.
+     *
+     * @param ctx runner context providing the action metric group
+     * @param model name of the metric sub-group (the chat action passes the connection name)
+     * @param retryCount total retries to add
+     * @param totalRetryWaitSec total wait seconds to add
+     */
+    private static void recordRetryMetrics(
+            RunnerContext ctx, String model, int retryCount, int totalRetryWaitSec) {
+        if (retryCount > 0) {
+            FlinkAgentsMetricGroup actionGroup = ctx.getActionMetricGroup();
+            if (actionGroup != null) {
+                FlinkAgentsMetricGroup subGroup = actionGroup.getSubGroup(model);
+                subGroup.getCounter("retryCount").inc(retryCount);
+                subGroup.getCounter("retryWaitSec").inc(totalRetryWaitSec);
+            }
+        }
+    }
+
private static void handleToolCalls(
ChatMessage response,
UUID initialRequestId,
@@ -206,14 +254,21 @@ public class ChatModelAction {
Agent.ErrorHandlingStrategy strategy =
ctx.getConfig().get(AgentExecutionOptions.ERROR_HANDLING_STRATEGY);
int numRetries = 0;
+ int retryWaitIntervalSec = 0;
if (strategy == Agent.ErrorHandlingStrategy.RETRY) {
numRetries =
ctx.getConfig().get(AgentExecutionOptions.MAX_RETRIES) > 0
?
ctx.getConfig().get(AgentExecutionOptions.MAX_RETRIES)
: 0;
+ retryWaitIntervalSec =
+
ctx.getConfig().get(AgentExecutionOptions.RETRY_WAIT_INTERVAL) > 0
+ ?
ctx.getConfig().get(AgentExecutionOptions.RETRY_WAIT_INTERVAL)
+ : 0;
}
ChatMessage response = null;
+ int actualRetryCount = 0;
+ int totalWaitTimeSec = 0;
DurableCallable<ChatMessage> callable =
new DurableCallable<>() {
@@ -253,12 +308,19 @@ public class ChatModelAction {
if (attempt == numRetries) {
throw e;
}
+ actualRetryCount = attempt + 1;
+ int currentWaitSec = retryWaitIntervalSec * (1 <<
(actualRetryCount - 1));
LOG.warn(
- "Chat request {} failed with error: {}, retrying
{} / {}.",
+ "Chat request {} failed with error: {}, retrying
{} / {}, waiting {} s.",
initialRequestId,
e,
- attempt,
- numRetries);
+ actualRetryCount,
+ numRetries,
+ currentWaitSec);
+ if (currentWaitSec > 0) {
+ Thread.sleep(currentWaitSec * 1000L);
+ totalWaitTimeSec += currentWaitSec;
+ }
} else {
LOG.debug(
"Chat request {} failed, the input chat messages
are {}.",
@@ -269,10 +331,23 @@ public class ChatModelAction {
}
}
+ if (actualRetryCount > 0) {
+ accumulateRetryStats(
+ ctx.getSensoryMemory(), initialRequestId,
actualRetryCount, totalWaitTimeSec);
+ }
+
if (!Objects.requireNonNull(response).getToolCalls().isEmpty()) {
handleToolCalls(response, initialRequestId, model, messages,
outputSchema, ctx);
} else {
- ctx.sendEvent(new ChatResponseEvent(initialRequestId, response));
+ Map<String, Long> retryStats =
getRetryStats(ctx.getSensoryMemory(), initialRequestId);
+ int totalRetryCount = retryStats.get(TOTAL_RETRY_COUNT).intValue();
+ int totalRetryWaitSec =
retryStats.get(TOTAL_RETRY_WAIT_SEC).intValue();
+
+ recordRetryMetrics(ctx, chatModel.getConnection(),
totalRetryCount, totalRetryWaitSec);
+
+ ctx.sendEvent(
+ new ChatResponseEvent(
+ initialRequestId, response, totalRetryCount,
totalRetryWaitSec));
}
}
diff --git
a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionRetryTest.java
b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionRetryTest.java
new file mode 100644
index 00000000..f9101e48
--- /dev/null
+++
b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionRetryTest.java
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.plan.actions;
+
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.api.agents.Agent;
+import org.apache.flink.agents.api.agents.AgentExecutionOptions;
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.chat.model.BaseChatModelSetup;
+import org.apache.flink.agents.api.context.DurableCallable;
+import org.apache.flink.agents.api.context.MemoryObject;
+import org.apache.flink.agents.api.context.RunnerContext;
+import org.apache.flink.agents.api.event.ChatResponseEvent;
+import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
+import org.apache.flink.agents.api.resource.ResourceType;
+import org.apache.flink.metrics.Counter;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.*;
+
+/** Tests for retry behavior in {@link ChatModelAction}. */
+class ChatModelActionRetryTest {
+
+    /** Connection name of the mocked chat model; retry metrics are grouped under it. */
+    private static final String CONNECTION_NAME = "test-connection";
+
+    @Mock private RunnerContext mockCtx;
+
+    @Mock private BaseChatModelSetup mockChatModel;
+
+    @Mock private FlinkAgentsMetricGroup mockActionMetricGroup;
+
+    @Mock private FlinkAgentsMetricGroup mockModelMetricGroup;
+
+    @Mock private Counter mockRetryCountCounter;
+
+    @Mock private Counter mockRetryWaitSecCounter;
+
+    private MemoryObject sensoryMemory;
+    private List<Event> sentEvents;
+    private AutoCloseable mocks;
+
+    @BeforeEach
+    void setUp() throws Exception {
+        mocks = MockitoAnnotations.openMocks(this);
+        sentEvents = new ArrayList<>();
+        sensoryMemory = createStatefulMemoryObject();
+
+        // Wire up ChatModel
+        when(mockChatModel.getConnection()).thenReturn(CONNECTION_NAME);
+
+        // Wire up RunnerContext
+        when(mockCtx.getResource(anyString(), eq(ResourceType.CHAT_MODEL)))
+                .thenReturn(mockChatModel);
+        when(mockCtx.getSensoryMemory()).thenReturn(sensoryMemory);
+        when(mockCtx.getActionMetricGroup()).thenReturn(mockActionMetricGroup);
+        doAnswer(inv -> sentEvents.add(inv.getArgument(0))).when(mockCtx).sendEvent(any());
+        // Run durable calls inline so failures propagate straight into the retry loop.
+        when(mockCtx.<ChatMessage>durableExecute(any()))
+                .thenAnswer(inv -> inv.<DurableCallable<ChatMessage>>getArgument(0).call());
+
+        // Wire up metric group chain
+        when(mockActionMetricGroup.getSubGroup(anyString())).thenReturn(mockModelMetricGroup);
+        when(mockModelMetricGroup.getCounter("retryCount")).thenReturn(mockRetryCountCounter);
+        when(mockModelMetricGroup.getCounter("retryWaitSec")).thenReturn(mockRetryWaitSecCounter);
+    }
+
+    @AfterEach
+    void tearDown() throws Exception {
+        if (mocks != null) {
+            mocks.close();
+        }
+    }
+
+    @Test
+    void chatSucceedsWithoutRetry_retryCountIsZero() throws Exception {
+        configureRetryStrategy(3, 1);
+        when(mockChatModel.chat(any(), any()))
+                .thenReturn(new ChatMessage(MessageRole.ASSISTANT, "hello"));
+
+        UUID requestId = UUID.randomUUID();
+        ChatModelAction.chat(
+                requestId,
+                "test-model",
+                List.of(new ChatMessage(MessageRole.USER, "hi")),
+                null,
+                mockCtx);
+
+        assertThat(sentEvents).hasSize(1);
+        ChatResponseEvent responseEvent = (ChatResponseEvent) sentEvents.get(0);
+        assertThat(responseEvent.getRetryCount()).isEqualTo(0);
+        assertThat(responseEvent.getTotalRetryWaitSec()).isEqualTo(0);
+
+        // No retry metrics should be recorded
+        verify(mockActionMetricGroup, never()).getSubGroup(anyString());
+    }
+
+    @Test
+    void chatRetriesWithExponentialBackoff() throws Exception {
+        // 1 second base interval; fail once then succeed -> a single wait of 1s (1 * 2^0).
+        configureRetryStrategy(3, 1);
+
+        AtomicInteger callCount = new AtomicInteger(0);
+        when(mockChatModel.chat(any(), any()))
+                .thenAnswer(
+                        inv -> {
+                            int count = callCount.incrementAndGet();
+                            if (count <= 1) {
+                                throw new RuntimeException("transient error");
+                            }
+                            return new ChatMessage(MessageRole.ASSISTANT, "success");
+                        });
+
+        UUID requestId = UUID.randomUUID();
+
+        long startTime = System.currentTimeMillis();
+        ChatModelAction.chat(
+                requestId,
+                "test-model",
+                List.of(new ChatMessage(MessageRole.USER, "hi")),
+                null,
+                mockCtx);
+        long elapsed = System.currentTimeMillis() - startTime;
+
+        assertThat(sentEvents).hasSize(1);
+        ChatResponseEvent responseEvent = (ChatResponseEvent) sentEvents.get(0);
+        assertThat(responseEvent.getRetryCount()).isEqualTo(1);
+        // One retry with a 1s base interval waits 1s in total.
+        assertThat(responseEvent.getTotalRetryWaitSec()).isEqualTo(1);
+        assertThat(elapsed).isGreaterThanOrEqualTo(1000L);
+
+        // Verify metrics recorded under connection name
+        verify(mockActionMetricGroup).getSubGroup(CONNECTION_NAME);
+        verify(mockRetryCountCounter).inc(1);
+        verify(mockRetryWaitSecCounter).inc(1);
+    }
+
+    @Test
+    void chatAccumulatesMultipleRetriesWithoutWaitingWhenIntervalIsZero() throws Exception {
+        // A zero interval disables sleeping, so this covers multi-retry accumulation quickly.
+        configureRetryStrategy(3, 0);
+
+        AtomicInteger callCount = new AtomicInteger(0);
+        when(mockChatModel.chat(any(), any()))
+                .thenAnswer(
+                        inv -> {
+                            if (callCount.incrementAndGet() <= 2) {
+                                throw new RuntimeException("transient error");
+                            }
+                            return new ChatMessage(MessageRole.ASSISTANT, "success");
+                        });
+
+        ChatModelAction.chat(
+                UUID.randomUUID(),
+                "test-model",
+                List.of(new ChatMessage(MessageRole.USER, "hi")),
+                null,
+                mockCtx);
+
+        assertThat(sentEvents).hasSize(1);
+        ChatResponseEvent responseEvent = (ChatResponseEvent) sentEvents.get(0);
+        assertThat(responseEvent.getRetryCount()).isEqualTo(2);
+        assertThat(responseEvent.getTotalRetryWaitSec()).isEqualTo(0);
+
+        verify(mockActionMetricGroup).getSubGroup(CONNECTION_NAME);
+        verify(mockRetryCountCounter).inc(2);
+        verify(mockRetryWaitSecCounter).inc(0);
+    }
+
+    @Test
+    void chatExhaustsRetriesAndThrows() {
+        configureRetryStrategy(2, 0);
+
+        when(mockChatModel.chat(any(), any()))
+                .thenThrow(new RuntimeException("persistent error"));
+
+        UUID requestId = UUID.randomUUID();
+
+        assertThatThrownBy(
+                        () ->
+                                ChatModelAction.chat(
+                                        requestId,
+                                        "test-model",
+                                        List.of(new ChatMessage(MessageRole.USER, "hi")),
+                                        null,
+                                        mockCtx))
+                .isInstanceOf(RuntimeException.class)
+                .hasMessage("persistent error");
+
+        assertThat(sentEvents).isEmpty();
+        // The failure escapes before any response is produced, so no retry metrics are recorded.
+        verify(mockActionMetricGroup, never()).getSubGroup(anyString());
+    }
+
+    @Test
+    void chatResponseEventDefaultConstructorHasZeroRetryInfo() {
+        UUID requestId = UUID.randomUUID();
+        ChatMessage msg = new ChatMessage(MessageRole.ASSISTANT, "test");
+        ChatResponseEvent event = new ChatResponseEvent(requestId, msg);
+
+        assertThat(event.getRetryCount()).isEqualTo(0);
+        assertThat(event.getTotalRetryWaitSec()).isEqualTo(0);
+        assertThat(event.getRequestId()).isEqualTo(requestId);
+    }
+
+    @Test
+    void chatResponseEventFullConstructorCarriesRetryInfo() {
+        UUID requestId = UUID.randomUUID();
+        ChatMessage msg = new ChatMessage(MessageRole.ASSISTANT, "test");
+        ChatResponseEvent event = new ChatResponseEvent(requestId, msg, 5, 31);
+
+        assertThat(event.getRetryCount()).isEqualTo(5);
+        assertThat(event.getTotalRetryWaitSec()).isEqualTo(31);
+    }
+
+    @Test
+    void retryWaitIntervalDefaultValue() {
+        assertThat(AgentExecutionOptions.RETRY_WAIT_INTERVAL.getDefaultValue()).isEqualTo(1);
+    }
+
+    // --- Helper methods ---
+
+    /**
+     * Makes {@code mockCtx.getConfig()} report the RETRY strategy with the given limits and
+     * synchronous chat; every other option falls back to its default value.
+     */
+    private void configureRetryStrategy(int maxRetries, int waitIntervalSec) {
+        when(mockCtx.getConfig())
+                .thenAnswer(
+                        inv -> {
+                            // Return a stub ReadableConfiguration
+                            return new org.apache.flink.agents.api.configuration
+                                    .ReadableConfiguration() {
+                                @Override
+                                @SuppressWarnings("unchecked")
+                                public <T> T get(
+                                        org.apache.flink.agents.api.configuration.ConfigOption<T>
+                                                option) {
+                                    if (option == AgentExecutionOptions.ERROR_HANDLING_STRATEGY) {
+                                        return (T) Agent.ErrorHandlingStrategy.RETRY;
+                                    }
+                                    if (option == AgentExecutionOptions.MAX_RETRIES) {
+                                        return (T) Integer.valueOf(maxRetries);
+                                    }
+                                    if (option == AgentExecutionOptions.RETRY_WAIT_INTERVAL) {
+                                        return (T) Integer.valueOf(waitIntervalSec);
+                                    }
+                                    if (option == AgentExecutionOptions.CHAT_ASYNC) {
+                                        return (T) Boolean.FALSE;
+                                    }
+                                    return option.getDefaultValue();
+                                }
+
+                                @Override
+                                public Integer getInt(String key, Integer defaultValue) {
+                                    return defaultValue;
+                                }
+
+                                @Override
+                                public Long getLong(String key, Long defaultValue) {
+                                    return defaultValue;
+                                }
+
+                                @Override
+                                public Float getFloat(String key, Float defaultValue) {
+                                    return defaultValue;
+                                }
+
+                                @Override
+                                public Double getDouble(String key, Double defaultValue) {
+                                    return defaultValue;
+                                }
+
+                                @Override
+                                public Boolean getBool(String key, Boolean defaultValue) {
+                                    return defaultValue;
+                                }
+
+                                @Override
+                                public String getStr(String key, String defaultValue) {
+                                    return defaultValue;
+                                }
+                            };
+                        });
+    }
+
+    /**
+     * Creates a stateful MemoryObject backed by a HashMap, supporting isExist/get/set operations
+     * needed by the retry stats accumulation logic.
+     */
+    private static MemoryObject createStatefulMemoryObject() {
+        Map<String, Object> store = new HashMap<>();
+
+        MemoryObject memoryObject = mock(MemoryObject.class);
+
+        when(memoryObject.isExist(anyString()))
+                .thenAnswer(inv -> store.containsKey(inv.<String>getArgument(0)));
+
+        try {
+            when(memoryObject.get(anyString()))
+                    .thenAnswer(
+                            inv -> {
+                                String path = inv.getArgument(0);
+                                Object value = store.get(path);
+                                if (value == null) {
+                                    throw new Exception("Path not found: " + path);
+                                }
+                                MemoryObject valueObj = mock(MemoryObject.class);
+                                when(valueObj.getValue()).thenReturn(value);
+                                return valueObj;
+                            });
+
+            when(memoryObject.set(anyString(), any()))
+                    .thenAnswer(
+                            inv -> {
+                                store.put(inv.getArgument(0), inv.getArgument(1));
+                                return null;
+                            });
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+
+        return memoryObject;
+    }
+}
diff --git a/python/flink_agents/api/core_options.py
b/python/flink_agents/api/core_options.py
index eb503de2..9245e78c 100644
--- a/python/flink_agents/api/core_options.py
+++ b/python/flink_agents/api/core_options.py
@@ -107,6 +107,12 @@ class AgentExecutionOptions:
default=3,
)
+ # Base wait in seconds between chat retries under ErrorHandlingStrategy.RETRY.
+ # The Nth retry waits RETRY_WAIT_INTERVAL * 2^(N-1) seconds; a non-positive
+ # (or unset) value disables waiting between retries.
+ RETRY_WAIT_INTERVAL = ConfigOption(
+ key="retry-wait-interval",
+ config_type=int,
+ default=1,
+ )
+
NUM_ASYNC_THREADS = ConfigOption(
key="num-async-threads",
config_type=int,
diff --git a/python/flink_agents/api/events/chat_event.py
b/python/flink_agents/api/events/chat_event.py
index 5e1237f5..063548d0 100644
--- a/python/flink_agents/api/events/chat_event.py
+++ b/python/flink_agents/api/events/chat_event.py
@@ -50,7 +50,13 @@ class ChatResponseEvent(Event):
The id of the request event.
response : ChatMessage
The response from the chat model.
+ retry_count : int
+ The total number of retries across all tool call rounds. Defaults to 0.
+ total_retry_wait_sec : int
+ The total time, in seconds, spent waiting between retries across all
+ tool call rounds. Defaults to 0.
"""
request_id: UUID
response: ChatMessage
+ retry_count: int = 0
+ total_retry_wait_sec: int = 0
diff --git a/python/flink_agents/api/runner_context.py
b/python/flink_agents/api/runner_context.py
index 4e3bc303..7dfd5437 100644
--- a/python/flink_agents/api/runner_context.py
+++ b/python/flink_agents/api/runner_context.py
@@ -169,24 +169,26 @@ class RunnerContext(ABC):
@property
@abstractmethod
- def agent_metric_group(self) -> MetricGroup:
+ def agent_metric_group(self) -> MetricGroup | None:
"""Get the metric group for flink agents.
Returns:
-------
- MetricGroup
+ MetricGroup | None
The metric group shared across all actions.
+ May return None when not running on Flink.
"""
@property
@abstractmethod
- def action_metric_group(self) -> MetricGroup:
+ def action_metric_group(self) -> MetricGroup | None:
"""Get the individual metric group dedicated for each action.
Returns:
-------
- MetricGroup
+ MetricGroup | None
The individual metric group specific to the current action.
+ May return None when not running on Flink.
"""
@abstractmethod
diff --git a/python/flink_agents/plan/actions/chat_model_action.py
b/python/flink_agents/plan/actions/chat_model_action.py
index 02e92ed8..210289db 100644
--- a/python/flink_agents/plan/actions/chat_model_action.py
+++ b/python/flink_agents/plan/actions/chat_model_action.py
@@ -18,6 +18,7 @@
import copy
import json
import logging
+import time
from typing import TYPE_CHECKING, Dict, List, cast
from uuid import UUID
@@ -47,6 +48,7 @@ if TYPE_CHECKING:
_TOOL_CALL_CONTEXT = "_TOOL_CALL_CONTEXT"
_TOOL_REQUEST_EVENT_CONTEXT = "_TOOL_REQUEST_EVENT_CONTEXT"
+# Sensory-memory key holding per-request retry totals, summed across tool call rounds.
+_RETRY_STATS_CONTEXT = "_RETRY_STATS_CONTEXT"
_logger = logging.getLogger(__name__)
@@ -111,6 +113,49 @@ def _get_tool_request_event_context(
return removed_context
+def _accumulate_retry_stats(
+ sensory_memory: MemoryObject,
+ initial_request_id: UUID,
+ retry_count: int,
+ retry_wait_sec: int,
+) -> None:
+ """Accumulate retry stats for a given initial request across tool call
rounds."""
+ retry_stats_context = sensory_memory.get(_RETRY_STATS_CONTEXT) or {}
+ stats = retry_stats_context.get(initial_request_id, {
+ "total_retry_count": 0,
+ "total_retry_wait_sec": 0,
+ })
+ stats["total_retry_count"] += retry_count
+ stats["total_retry_wait_sec"] += retry_wait_sec
+ retry_stats_context[initial_request_id] = stats
+ sensory_memory.set(_RETRY_STATS_CONTEXT, retry_stats_context)
+
+
+def _get_retry_stats(
+ sensory_memory: MemoryObject,
+ initial_request_id: UUID,
+) -> dict:
+ """Get accumulated retry stats for a given initial request."""
+ retry_stats_context = sensory_memory.get(_RETRY_STATS_CONTEXT) or {}
+ return retry_stats_context.get(initial_request_id, {
+ "total_retry_count": 0,
+ "total_retry_wait_sec": 0,
+ })
+
+
+def _record_retry_metrics(
+ ctx: RunnerContext, model: str, retry_count: int, total_retry_wait_sec: int
+) -> None:
+ """Record retry metrics under the connection name if retries occurred."""
+ if retry_count <= 0:
+ return
+ metric_group = ctx.action_metric_group
+ if metric_group is not None:
+ model_group = metric_group.get_sub_group(model)
+ model_group.get_counter("retryCount").inc(retry_count)
+ model_group.get_counter("retryWaitSec").inc(total_retry_wait_sec)
+
+
def _handle_tool_calls(
response: ChatMessage,
initial_request_id: UUID,
@@ -185,10 +230,20 @@ async def chat(
error_handling_strategy =
ctx.config.get(AgentExecutionOptions.ERROR_HANDLING_STRATEGY)
num_retries = 0
+ retry_wait_interval_sec = 0
if error_handling_strategy == ErrorHandlingStrategy.RETRY:
num_retries = max(0, ctx.config.get(AgentExecutionOptions.MAX_RETRIES))
+ retry_wait_interval_config = ctx.config.get(
+ AgentExecutionOptions.RETRY_WAIT_INTERVAL
+ )
+ retry_wait_interval_sec = (
+ max(0, retry_wait_interval_config) if retry_wait_interval_config
else 0
+ )
response = None
+ actual_retry_count = 0
+ total_wait_time_sec = 0
+
for attempt in range(num_retries + 1):
try:
if chat_async:
@@ -210,15 +265,29 @@ async def chat(
elif error_handling_strategy == ErrorHandlingStrategy.RETRY:
if attempt == num_retries:
raise
+ actual_retry_count = attempt + 1
+ current_wait_sec = retry_wait_interval_sec * (
+ 1 << (actual_retry_count - 1)
+ )
_logger.warning(
- f"Chat request {initial_request_id} failed with error:
{e}, retrying {attempt} / {num_retries}."
+ f"Chat request {initial_request_id} failed with error:
{e}, "
+ f"retrying {actual_retry_count} / {num_retries}, "
+ f"waiting {current_wait_sec} s."
)
+ if current_wait_sec > 0:
+ time.sleep(current_wait_sec)
+ total_wait_time_sec += current_wait_sec
else:
_logger.debug(
f"Chat request {initial_request_id} failed, the input chat
messages are {messages}."
)
raise
+ if actual_retry_count > 0:
+ _accumulate_retry_stats(
+ ctx.sensory_memory, initial_request_id, actual_retry_count,
total_wait_time_sec
+ )
+
if (
len(response.tool_calls) > 0
): # generate tool request event according tool calls in response
@@ -226,10 +295,18 @@ async def chat(
response, initial_request_id, model, messages, output_schema, ctx
)
else: # if there is no tool call generated, return chat response directly
+ retry_stats = _get_retry_stats(ctx.sensory_memory, initial_request_id)
+ total_retry_count = retry_stats["total_retry_count"]
+ total_retry_wait_sec = retry_stats["total_retry_wait_sec"]
+
+ _record_retry_metrics(ctx, chat_model.connection, total_retry_count,
total_retry_wait_sec)
+
ctx.send_event(
ChatResponseEvent(
request_id=initial_request_id,
response=response,
+ retry_count=total_retry_count,
+ total_retry_wait_sec=total_retry_wait_sec,
)
)
diff --git a/python/flink_agents/plan/tests/actions/__init__.py
b/python/flink_agents/plan/tests/actions/__init__.py
new file mode 100644
index 00000000..e154fadd
--- /dev/null
+++ b/python/flink_agents/plan/tests/actions/__init__.py
@@ -0,0 +1,17 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
diff --git
a/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py
b/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py
new file mode 100644
index 00000000..a6014323
--- /dev/null
+++ b/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py
@@ -0,0 +1,250 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+"""Tests for retry behavior in chat_model_action."""
+
+import asyncio
+import time
+from typing import Any, Sequence
+from unittest.mock import MagicMock
+from uuid import uuid4
+
+import pytest
+
+from flink_agents.api.chat_message import ChatMessage, MessageRole
+from flink_agents.api.core_options import (
+ AgentExecutionOptions,
+ ErrorHandlingStrategy,
+)
+from flink_agents.api.events.chat_event import ChatResponseEvent
+from flink_agents.api.metric_group import Counter, MetricGroup
+from flink_agents.plan.actions.chat_model_action import chat
+
+# ============================================================================
+# Mock infrastructure
+# ============================================================================
+
+
+class _MockCounter(Counter):
+    """In-memory counter keeping the running total of inc/dec calls."""
+
+    def __init__(self) -> None:
+        self._count: int = 0
+
+    def inc(self, n: int = 1) -> None:
+        self._count = self._count + n
+
+    def dec(self, n: int = 1) -> None:
+        self._count = self._count - n
+
+    def get_count(self) -> int:
+        current = self._count
+        return current
+
+
+class _MockMetricGroup(MetricGroup):
+    """MetricGroup stand-in that memoizes sub-groups and counters by name."""
+
+    def __init__(self) -> None:
+        self._sub_groups: dict[str, _MockMetricGroup] = {}
+        self._counters: dict[str, _MockCounter] = {}
+
+    def get_sub_group(self, name: str) -> "_MockMetricGroup":
+        group = self._sub_groups.get(name)
+        if group is None:
+            group = _MockMetricGroup()
+            self._sub_groups[name] = group
+        return group
+
+    def get_counter(self, name: str) -> _MockCounter:
+        counter = self._counters.get(name)
+        if counter is None:
+            counter = _MockCounter()
+            self._counters[name] = counter
+        return counter
+
+    # Meters, gauges and histograms are not inspected by these tests, so each
+    # getter hands back a fresh, unremembered MagicMock.
+    def get_meter(self, name: str) -> Any:
+        return MagicMock()
+
+    def get_gauge(self, name: str) -> Any:
+        return MagicMock()
+
+    def get_histogram(self, name: str, window_size: int = 100) -> Any:
+        return MagicMock()
+
+
+class _MockMemoryObject:
+    """Memory object backed by a plain dict; unknown paths read as None."""
+
+    def __init__(self) -> None:
+        self._store: dict[str, Any] = {}
+
+    def get(self, path: str) -> Any:
+        return self._store[path] if path in self._store else None
+
+    def set(self, path: str, value: Any) -> None:
+        self._store.update({path: value})
+
+
+def _create_mock_runner_context(
+    chat_model: Any,
+    max_retries: int = 3,
+    retry_wait_interval_sec: int = 1,
+) -> tuple[MagicMock, list, _MockMetricGroup, _MockMemoryObject]:
+    """Build a MagicMock RunnerContext wired for the retry tests.
+
+    The context uses the RETRY error-handling strategy with CHAT_ASYNC
+    disabled, returns ``chat_model`` for every resource lookup, and runs
+    ``durable_execute`` callables inline.
+
+    Returns (ctx, sent_events, action_metric_group, sensory_memory), where
+    ``sent_events`` collects every event passed to ``ctx.send_event``.
+    """
+    overrides = {
+        id(AgentExecutionOptions.ERROR_HANDLING_STRATEGY): ErrorHandlingStrategy.RETRY,
+        id(AgentExecutionOptions.MAX_RETRIES): max_retries,
+        id(AgentExecutionOptions.RETRY_WAIT_INTERVAL): retry_wait_interval_sec,
+        id(AgentExecutionOptions.CHAT_ASYNC): False,
+    }
+
+    def _read_option(option: Any) -> Any:
+        # Options are matched by identity; anything not overridden falls back
+        # to the option's declared default.
+        return overrides.get(id(option), option.get_default_value())
+
+    def _run_inline(fn: Any, *args: Any, **kwargs: Any) -> Any:
+        return fn(*args, **kwargs)
+
+    sent_events: list = []
+    metric_group = _MockMetricGroup()
+    sensory_memory = _MockMemoryObject()
+
+    config = MagicMock()
+    config.get = MagicMock(side_effect=_read_option)
+
+    ctx = MagicMock()
+    ctx.config = config
+    ctx.sensory_memory = sensory_memory
+    ctx.action_metric_group = metric_group
+    ctx.send_event = MagicMock(side_effect=sent_events.append)
+    ctx.get_resource = MagicMock(return_value=chat_model)
+    ctx.durable_execute = MagicMock(side_effect=_run_inline)
+
+    return ctx, sent_events, metric_group, sensory_memory
+
+
+# ============================================================================
+# Tests
+# ============================================================================
+
+
+class TestChatModelActionRetry:
+    """Exercises the retry loop inside chat()."""
+
+    @staticmethod
+    def _user_prompt() -> list[ChatMessage]:
+        """Single-message conversation shared by the tests below."""
+        return [ChatMessage(role=MessageRole.USER, content="hi")]
+
+    def test_chat_succeeds_without_retry(self) -> None:
+        """First call succeeds: retry_count=0, total_retry_wait_sec=0, no metrics."""
+        reply = ChatMessage(role=MessageRole.ASSISTANT, content="hello")
+        chat_model = MagicMock()
+        chat_model.chat = MagicMock(return_value=reply)
+        ctx, sent_events, metric_group, _ = _create_mock_runner_context(chat_model)
+
+        asyncio.run(
+            chat(uuid4(), chat_model.connection, self._user_prompt(), None, ctx)
+        )
+
+        assert len(sent_events) == 1
+        (event,) = sent_events
+        assert isinstance(event, ChatResponseEvent)
+        assert event.retry_count == 0
+        assert event.total_retry_wait_sec == 0
+
+        # Nothing was retried, so no per-connection metric group was created.
+        assert not metric_group._sub_groups
+
+    def test_chat_retries_with_exponential_backoff(self) -> None:
+        """One transient failure, 1s interval: one retry after 1s (1 * 2^0)."""
+        attempts = 0
+
+        def flaky_chat(messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage:
+            nonlocal attempts
+            attempts += 1
+            if attempts == 1:
+                err_msg = "transient error"
+                raise RuntimeError(err_msg)
+            return ChatMessage(role=MessageRole.ASSISTANT, content="success")
+
+        chat_model = MagicMock()
+        chat_model.chat = flaky_chat
+        ctx, sent_events, metric_group, _ = _create_mock_runner_context(
+            chat_model, max_retries=3, retry_wait_interval_sec=1
+        )
+
+        started = time.monotonic()
+        asyncio.run(chat(uuid4(), "test-model", self._user_prompt(), None, ctx))
+        elapsed = time.monotonic() - started
+
+        assert len(sent_events) == 1
+        (event,) = sent_events
+        assert isinstance(event, ChatResponseEvent)
+        assert event.retry_count == 1
+        # 1s interval with exponential backoff: 1 * 2^0 = 1s of total waiting.
+        assert event.total_retry_wait_sec == 1
+        assert elapsed >= 1.0
+
+        # Retry metrics are grouped under the chat model's connection name.
+        connection_group = metric_group.get_sub_group(chat_model.connection)
+        assert connection_group.get_counter("retryCount").get_count() == 1
+        assert connection_group.get_counter("retryWaitSec").get_count() == 1
+
+    def test_chat_exhausts_retries_and_raises(self) -> None:
+        """Every attempt fails: the error propagates and no event is sent."""
+        chat_model = MagicMock()
+        chat_model.chat = MagicMock(side_effect=RuntimeError("persistent error"))
+        ctx, sent_events, _, _ = _create_mock_runner_context(
+            chat_model, max_retries=2, retry_wait_interval_sec=0
+        )
+
+        with pytest.raises(RuntimeError, match="persistent error"):
+            asyncio.run(chat(uuid4(), "test-model", self._user_prompt(), None, ctx))
+
+        assert not sent_events
+
+
+class TestChatResponseEventRetryFields:
+    """Checks the retry bookkeeping fields carried by ChatResponseEvent."""
+
+    @staticmethod
+    def _assistant_reply() -> ChatMessage:
+        return ChatMessage(role=MessageRole.ASSISTANT, content="test")
+
+    def test_default_retry_fields(self) -> None:
+        """Omitting the retry fields yields retry_count=0, total_retry_wait_sec=0."""
+        event = ChatResponseEvent(request_id=uuid4(), response=self._assistant_reply())
+        assert (event.retry_count, event.total_retry_wait_sec) == (0, 0)
+
+    def test_with_retry_fields(self) -> None:
+        """Explicitly supplied retry fields are preserved on the event."""
+        event = ChatResponseEvent(
+            request_id=uuid4(),
+            response=self._assistant_reply(),
+            retry_count=5,
+            total_retry_wait_sec=31,
+        )
+        assert (event.retry_count, event.total_retry_wait_sec) == (5, 31)
+
+
+class TestRetryWaitIntervalConfig:
+    """Checks the declared default of the RETRY_WAIT_INTERVAL option."""
+
+    def test_default_value(self) -> None:
+        """RETRY_WAIT_INTERVAL defaults to 1 second."""
+        default = AgentExecutionOptions.RETRY_WAIT_INTERVAL.get_default_value()
+        assert default == 1
diff --git a/python/flink_agents/runtime/local_runner.py b/python/flink_agents/runtime/local_runner.py
index b8eaf3f0..e4124439 100644
--- a/python/flink_agents/runtime/local_runner.py
+++ b/python/flink_agents/runtime/local_runner.py
@@ -167,17 +167,19 @@ class LocalRunnerContext(RunnerContext):
@property
@override
-    def agent_metric_group(self) -> MetricGroup:
+    def agent_metric_group(self) -> "MetricGroup | None":
+        """Return None; logs a warning on every access (no local metric support)."""
         # TODO: Support metric mechanism for local agent execution.
         err_msg = "Metric mechanism is not supported for local agent execution yet."
-        raise NotImplementedError(err_msg)
+        logger.warning(err_msg)
+        return None
@property
@override
-    def action_metric_group(self) -> MetricGroup:
+    def action_metric_group(self) -> "MetricGroup | None":
+        """Return None; logs a warning on every access (no local metric support)."""
         # TODO: Support metric mechanism for local agent execution.
         err_msg = "Metric mechanism is not supported for local agent execution yet."
-        raise NotImplementedError(err_msg)
+        logger.warning(err_msg)
+        return None
@override
def durable_execute(