This is an automated email from the ASF dual-hosted git repository.
xintongsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
The following commit(s) were added to refs/heads/main by this push:
new e98af528 [api][plan][integrations] Record built-in chat token metrics
outside the async call boundary (#712)
e98af528 is described below
commit e98af52851a580459da0f52c1123b942779c9182
Author: Weiqing Yang <[email protected]>
AuthorDate: Sun May 31 01:16:27 2026 -0700
[api][plan][integrations] Record built-in chat token metrics outside the
async call boundary (#712)
---
.../api/chat/model/BaseChatModelConnection.java | 19 -----
.../agents/api/chat/model/BaseChatModelSetup.java | 20 ++++-
.../flink/agents/api/context/RunnerContext.java | 8 ++
...ava => BaseChatModelSetupTokenMetricsTest.java} | 76 +++++++------------
.../anthropic/AnthropicChatModelConnection.java | 7 +-
.../azureai/AzureAIChatModelConnection.java | 9 ++-
.../bedrock/BedrockChatModelConnection.java | 10 ++-
.../ollama/OllamaChatModelConnection.java | 7 +-
.../openai/AzureOpenAIChatModelConnection.java | 9 ++-
.../openai/OpenAICompletionsConnection.java | 11 +--
.../openai/OpenAIResponsesModelConnection.java | 8 +-
.../flink/agents/plan/actions/ChatModelAction.java | 18 +++++
.../agents/plan/actions/ChatModelActionTest.java | 85 ++++++++++++++++++++++
13 files changed, 190 insertions(+), 97 deletions(-)
diff --git
a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java
b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java
index a6dccc44..7ce69b6d 100644
---
a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java
+++
b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java
@@ -19,7 +19,6 @@
package org.apache.flink.agents.api.chat.model;
import org.apache.flink.agents.api.chat.messages.ChatMessage;
-import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
import org.apache.flink.agents.api.resource.Resource;
import org.apache.flink.agents.api.resource.ResourceContext;
import org.apache.flink.agents.api.resource.ResourceDescriptor;
@@ -56,22 +55,4 @@ public abstract class BaseChatModelConnection extends
Resource {
*/
public abstract ChatMessage chat(
List<ChatMessage> messages, List<Tool> tools, Map<String, Object>
arguments);
-
- /**
- * Record token usage metrics for the given model.
- *
- * @param modelName the name of the model used
- * @param promptTokens the number of prompt tokens
- * @param completionTokens the number of completion tokens
- */
- protected void recordTokenMetrics(String modelName, long promptTokens,
long completionTokens) {
- FlinkAgentsMetricGroup metricGroup = getMetricGroup();
- if (metricGroup == null) {
- return;
- }
-
- FlinkAgentsMetricGroup modelGroup = metricGroup.getSubGroup(modelName);
- modelGroup.getCounter("promptTokens").inc(promptTokens);
- modelGroup.getCounter("completionTokens").inc(completionTokens);
- }
}
diff --git
a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java
b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java
index 858cd069..3a9c7b2d 100644
---
a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java
+++
b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java
@@ -20,6 +20,7 @@ package org.apache.flink.agents.api.chat.model;
import org.apache.flink.agents.api.chat.messages.ChatMessage;
import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
import org.apache.flink.agents.api.prompt.Prompt;
import org.apache.flink.agents.api.resource.Resource;
import org.apache.flink.agents.api.resource.ResourceContext;
@@ -107,6 +108,23 @@ public abstract class BaseChatModelSetup extends Resource {
public abstract Map<String, Object> getParameters();
+ /**
+ * Record token usage metrics for the given model on this setup's bound
metric group.
+ *
+ * @param modelName the name of the model used
+ * @param promptTokens the number of prompt tokens
+ * @param completionTokens the number of completion tokens
+ */
+ public void recordTokenMetrics(String modelName, long promptTokens, long
completionTokens) {
+ FlinkAgentsMetricGroup metricGroup = getMetricGroup();
+ if (metricGroup == null) {
+ return;
+ }
+ FlinkAgentsMetricGroup modelGroup = metricGroup.getSubGroup(modelName);
+ modelGroup.getCounter("promptTokens").inc(promptTokens);
+ modelGroup.getCounter("completionTokens").inc(completionTokens);
+ }
+
public ChatMessage chat(List<ChatMessage> messages) {
return this.chat(messages, Collections.emptyMap(),
Collections.emptyMap());
}
@@ -118,8 +136,6 @@ public abstract class BaseChatModelSetup extends Resource {
Preconditions.checkNotNull(
connection,
"Connection is not initialized. Ensure open() is called before
chat().");
- // Pass metric group to connection for token usage tracking
- connection.setMetricGroup(getMetricGroup());
// Format input messages if set prompt.
if (this.prompt != null) {
diff --git
a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
index 06cd2b38..c3e5d19b 100644
--- a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
+++ b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
@@ -67,6 +67,10 @@ public interface RunnerContext {
/**
* Gets the metric group for Flink Agents.
*
+ * <p>The returned group must only be accessed from the operator/mailbox
(action) thread, not
+ * from inside a {@link #durableExecute} or {@link #durableExecuteAsync}
callable, which runs on
+ * a separate thread pool.
+ *
* @return the metric group shared across all actions.
*/
FlinkAgentsMetricGroup getAgentMetricGroup();
@@ -74,6 +78,10 @@ public interface RunnerContext {
/**
* Gets the individual metric group dedicated for each action.
*
+ * <p>The returned group must only be accessed from the operator/mailbox
(action) thread, not
+ * from inside a {@link #durableExecute} or {@link #durableExecuteAsync}
callable, which runs on
+ * a separate thread pool.
+ *
* @return the individual metric group specific to the current action.
*/
FlinkAgentsMetricGroup getActionMetricGroup();
diff --git
a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java
b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java
similarity index 59%
rename from
api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java
rename to
api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java
index 43654944..cde9f683 100644
---
a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java
+++
b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java
@@ -18,71 +18,60 @@
package org.apache.flink.agents.api.chat.model;
-import org.apache.flink.agents.api.chat.messages.ChatMessage;
-import org.apache.flink.agents.api.chat.messages.MessageRole;
import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
import org.apache.flink.agents.api.resource.ResourceContext;
import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.apache.flink.agents.api.resource.ResourceType;
-import org.apache.flink.agents.api.tools.Tool;
import org.apache.flink.metrics.Counter;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import java.util.Collections;
-import java.util.List;
import java.util.Map;
-import static org.junit.jupiter.api.Assertions.*;
-import static org.mockito.Mockito.*;
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
-/** Test cases for BaseChatModelConnection token metrics functionality. */
-class BaseChatModelConnectionTokenMetricsTest {
+/** Test cases for BaseChatModelSetup token metrics functionality. */
+class BaseChatModelSetupTokenMetricsTest {
- private TestChatModelConnection connection;
+ private TestChatModelSetup setup;
private FlinkAgentsMetricGroup mockMetricGroup;
private FlinkAgentsMetricGroup mockModelGroup;
private Counter mockPromptTokensCounter;
private Counter mockCompletionTokensCounter;
- /** Test implementation of BaseChatModelConnection for testing purposes. */
- private static class TestChatModelConnection extends
BaseChatModelConnection {
+ /** Test implementation of BaseChatModelSetup for testing purposes. */
+ private static class TestChatModelSetup extends BaseChatModelSetup {
- public TestChatModelConnection(
- ResourceDescriptor descriptor, ResourceContext
resourceContext) {
+ public TestChatModelSetup(ResourceDescriptor descriptor,
ResourceContext resourceContext) {
super(descriptor, resourceContext);
}
@Override
- public ChatMessage chat(
- List<ChatMessage> messages, List<Tool> tools, Map<String,
Object> arguments) {
- // Simple test implementation
- return new ChatMessage(MessageRole.ASSISTANT, "Test response");
- }
-
- // Expose protected method for testing
- public void testRecordTokenMetrics(
- String modelName, long promptTokens, long completionTokens) {
- recordTokenMetrics(modelName, promptTokens, completionTokens);
+ public Map<String, Object> getParameters() {
+ return Collections.emptyMap();
}
}
@BeforeEach
void setUp() {
- connection =
- new TestChatModelConnection(
+ setup =
+ new TestChatModelSetup(
new ResourceDescriptor(
- TestChatModelConnection.class.getName(),
Collections.emptyMap()),
+ TestChatModelSetup.class.getName(),
Collections.emptyMap()),
null);
- // Create mock objects
mockMetricGroup = mock(FlinkAgentsMetricGroup.class);
mockModelGroup = mock(FlinkAgentsMetricGroup.class);
mockPromptTokensCounter = mock(Counter.class);
mockCompletionTokensCounter = mock(Counter.class);
- // Set up mock behavior
when(mockMetricGroup.getSubGroup("gpt-4")).thenReturn(mockModelGroup);
when(mockModelGroup.getCounter("promptTokens")).thenReturn(mockPromptTokensCounter);
when(mockModelGroup.getCounter("completionTokens")).thenReturn(mockCompletionTokensCounter);
@@ -91,13 +80,10 @@ class BaseChatModelConnectionTokenMetricsTest {
@Test
@DisplayName("Test token metrics are recorded when metric group is set")
void testRecordTokenMetricsWithMetricGroup() {
- // Set the metric group
- connection.setMetricGroup(mockMetricGroup);
+ setup.setMetricGroup(mockMetricGroup);
- // Record token metrics
- connection.testRecordTokenMetrics("gpt-4", 100, 50);
+ setup.recordTokenMetrics("gpt-4", 100, 50);
- // Verify the metrics were recorded
verify(mockMetricGroup).getSubGroup("gpt-4");
verify(mockModelGroup).getCounter("promptTokens");
verify(mockModelGroup).getCounter("completionTokens");
@@ -108,22 +94,16 @@ class BaseChatModelConnectionTokenMetricsTest {
@Test
@DisplayName("Test token metrics are not recorded when metric group is
null")
void testRecordTokenMetricsWithoutMetricGroup() {
- // Do not set metric group (should be null by default)
+ assertDoesNotThrow(() -> setup.recordTokenMetrics("gpt-4", 100, 50));
- // Record token metrics - should not throw
- assertDoesNotThrow(() -> connection.testRecordTokenMetrics("gpt-4",
100, 50));
-
- // No metrics should be recorded
verifyNoInteractions(mockMetricGroup);
}
@Test
- @DisplayName("Test token metrics hierarchy: actionMetricGroup -> modelName
-> counters")
+ @DisplayName("Test token metrics hierarchy: metricGroup -> modelName ->
counters")
void testTokenMetricsHierarchy() {
- // Set the metric group
- connection.setMetricGroup(mockMetricGroup);
+ setup.setMetricGroup(mockMetricGroup);
- // Record token metrics for different models
FlinkAgentsMetricGroup mockGpt35Group =
mock(FlinkAgentsMetricGroup.class);
Counter mockGpt35PromptCounter = mock(Counter.class);
Counter mockGpt35CompletionCounter = mock(Counter.class);
@@ -132,13 +112,9 @@ class BaseChatModelConnectionTokenMetricsTest {
when(mockGpt35Group.getCounter("promptTokens")).thenReturn(mockGpt35PromptCounter);
when(mockGpt35Group.getCounter("completionTokens")).thenReturn(mockGpt35CompletionCounter);
- // Record for gpt-4
- connection.testRecordTokenMetrics("gpt-4", 100, 50);
-
- // Record for gpt-3.5-turbo
- connection.testRecordTokenMetrics("gpt-3.5-turbo", 200, 100);
+ setup.recordTokenMetrics("gpt-4", 100, 50);
+ setup.recordTokenMetrics("gpt-3.5-turbo", 200, 100);
- // Verify each model has its own counters
verify(mockMetricGroup).getSubGroup("gpt-4");
verify(mockMetricGroup).getSubGroup("gpt-3.5-turbo");
verify(mockPromptTokensCounter).inc(100);
@@ -148,8 +124,8 @@ class BaseChatModelConnectionTokenMetricsTest {
}
@Test
- @DisplayName("Test resource type is CHAT_MODEL_CONNECTION")
+ @DisplayName("Test resource type is CHAT_MODEL")
void testResourceType() {
- assertEquals(ResourceType.CHAT_MODEL_CONNECTION,
connection.getResourceType());
+ assertEquals(ResourceType.CHAT_MODEL, setup.getResourceType());
}
}
diff --git
a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java
b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java
index 248b464e..93691d3f 100644
---
a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java
+++
b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java
@@ -133,7 +133,7 @@ public class AnthropicChatModelConnection extends
BaseChatModelConnection {
Message response = client.messages().create(params);
ChatMessage result = convertResponse(response, jsonPrefillApplied);
- // Record token metrics
+ // Stash token usage
String modelName = null;
if (arguments != null && arguments.get("model") != null) {
modelName = arguments.get("model").toString();
@@ -142,8 +142,9 @@ public class AnthropicChatModelConnection extends
BaseChatModelConnection {
modelName = this.defaultModel;
}
if (modelName != null && !modelName.isBlank()) {
- recordTokenMetrics(
- modelName, response.usage().inputTokens(),
response.usage().outputTokens());
+ result.getExtraArgs().put("model_name", modelName);
+ result.getExtraArgs().put("promptTokens",
response.usage().inputTokens());
+ result.getExtraArgs().put("completionTokens",
response.usage().outputTokens());
}
return result;
diff --git
a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java
b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java
index 3051ecf4..318b5457 100644
---
a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java
+++
b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java
@@ -187,12 +187,15 @@ public class AzureAIChatModelConnection extends
BaseChatModelConnection {
chatMessage.setToolCalls(convertedToolCalls);
}
- // Record token metrics if model name is available
+ // Stash token usage if model name is available
if (modelName != null && !modelName.isBlank()) {
CompletionsUsage usage = completions.getUsage();
if (usage != null) {
- recordTokenMetrics(
- modelName, usage.getPromptTokens(),
usage.getCompletionTokens());
+ chatMessage.getExtraArgs().put("model_name", modelName);
+ chatMessage.getExtraArgs().put("promptTokens", (long)
usage.getPromptTokens());
+ chatMessage
+ .getExtraArgs()
+ .put("completionTokens", (long)
usage.getCompletionTokens());
}
}
diff --git
a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java
b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java
index 8327795a..58d23508 100644
---
a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java
+++
b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java
@@ -178,12 +178,14 @@ public class BedrockChatModelConnection extends
BaseChatModelConnection {
ConverseResponse response =
retryExecutor.execute(() -> client.converse(request),
"BedrockConverse");
+ ChatMessage result = convertResponse(response);
if (response.usage() != null) {
- recordTokenMetrics(
- modelId, response.usage().inputTokens(),
response.usage().outputTokens());
+ result.getExtraArgs().put("model_name", modelId);
+ result.getExtraArgs().put("promptTokens",
response.usage().inputTokens().longValue());
+ result.getExtraArgs()
+ .put("completionTokens",
response.usage().outputTokens().longValue());
}
-
- return convertResponse(response);
+ return result;
}
private static boolean isRetryable(Exception e) {
diff --git
a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
index 2cda1ea4..4c617455 100644
---
a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
+++
b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
@@ -224,13 +224,14 @@ public class OllamaChatModelConnection extends
BaseChatModelConnection {
chatMessage.setToolCalls(toolCalls);
}
- // Record token metrics if model name is available
+ // Stash token usage if model name is available
if (modelName != null && !modelName.isBlank()) {
Integer promptTokens = ollamaChatResponse.getPromptEvalCount();
Integer completionTokens = ollamaChatResponse.getEvalCount();
if (promptTokens != null && completionTokens != null) {
- recordTokenMetrics(
- modelName, promptTokens.longValue(),
completionTokens.longValue());
+ extraArgs.put("model_name", modelName);
+ extraArgs.put("promptTokens", promptTokens.longValue());
+ extraArgs.put("completionTokens",
completionTokens.longValue());
}
}
diff --git
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnection.java
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnection.java
index 7d6b5c2c..6567bd2b 100644
---
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnection.java
+++
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnection.java
@@ -216,10 +216,11 @@ public class AzureOpenAIChatModelConnection extends
BaseChatModelConnection {
if (modelOfAzureDeployment != null
&& !modelOfAzureDeployment.isBlank()
&& completion.usage().isPresent()) {
- recordTokenMetrics(
- modelOfAzureDeployment,
- completion.usage().get().promptTokens(),
- completion.usage().get().completionTokens());
+ response.getExtraArgs().put("model_name",
modelOfAzureDeployment);
+ response.getExtraArgs()
+ .put("promptTokens",
completion.usage().get().promptTokens());
+ response.getExtraArgs()
+ .put("completionTokens",
completion.usage().get().completionTokens());
}
return response;
diff --git
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java
index e4947e8f..2a0b78fe 100644
---
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java
+++
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java
@@ -129,17 +129,18 @@ public class OpenAICompletionsConnection extends
BaseChatModelConnection {
OpenAIChatCompletionsUtils.convertFromOpenAIMessage(
completion.choices().get(0).message());
- // Record token metrics
+ // Stash token usage
if (completion.usage().isPresent()) {
String modelName = arguments != null ? (String)
arguments.get("model") : null;
if (modelName == null || modelName.isBlank()) {
modelName = this.defaultModel;
}
if (modelName != null && !modelName.isBlank()) {
- recordTokenMetrics(
- modelName,
- completion.usage().get().promptTokens(),
- completion.usage().get().completionTokens());
+ response.getExtraArgs().put("model_name", modelName);
+ response.getExtraArgs()
+ .put("promptTokens",
completion.usage().get().promptTokens());
+ response.getExtraArgs()
+ .put("completionTokens",
completion.usage().get().completionTokens());
}
}
diff --git
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java
index 9b0d143e..00b5f9b6 100644
---
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java
+++
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java
@@ -140,10 +140,10 @@ public class OpenAIResponsesModelConnection extends
BaseChatModelConnection {
modelName = this.defaultModel;
}
if (modelName != null && !modelName.isBlank()) {
- recordTokenMetrics(
- modelName,
- response.usage().get().inputTokens(),
- response.usage().get().outputTokens());
+ result.getExtraArgs().put("model_name", modelName);
+ result.getExtraArgs().put("promptTokens",
response.usage().get().inputTokens());
+ result.getExtraArgs()
+ .put("completionTokens",
response.usage().get().outputTokens());
}
}
diff --git
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
index 5805eceb..504b4fb9 100644
---
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
+++
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
@@ -182,6 +182,23 @@ public class ChatModelAction {
}
}
+ static void recordChatTokenMetrics(BaseChatModelSetup chatModel,
ChatMessage response) {
+ Map<String, Object> extraArgs = response.getExtraArgs();
+ Object modelName = extraArgs.get("model_name");
+ Object promptTokens = extraArgs.get("promptTokens");
+ Object completionTokens = extraArgs.get("completionTokens");
+ if (modelName != null
+ && !modelName.toString().isEmpty()
+ && promptTokens instanceof Number
+ && completionTokens instanceof Number) {
+ long prompt = ((Number) promptTokens).longValue();
+ long completion = ((Number) completionTokens).longValue();
+ if (prompt > 0 && completion > 0) {
+ chatModel.recordTokenMetrics(modelName.toString(), prompt,
completion);
+ }
+ }
+ }
+
private static void handleToolCalls(
ChatMessage response,
UUID initialRequestId,
@@ -355,6 +372,7 @@ public class ChatModelAction {
chatAsync
? ctx.durableExecuteAsync(callable)
: ctx.durableExecute(callable);
+ recordChatTokenMetrics(chatModel, response);
// only generate structured output for final response.
if (outputSchema != null && response.getToolCalls().isEmpty())
{
response = generateStructuredOutput(response,
outputSchema);
diff --git
a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java
b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java
index d7f11785..85c263a6 100644
---
a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java
+++
b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java
@@ -17,13 +17,98 @@
*/
package org.apache.flink.agents.plan.actions;
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.chat.model.BaseChatModelSetup;
import org.junit.jupiter.api.Test;
+import java.util.HashMap;
+import java.util.Map;
+
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
/** Tests for {@link ChatModelAction}. */
class ChatModelActionTest {
+ private static ChatMessage responseWith(Map<String, Object> extraArgs) {
+ return new ChatMessage(MessageRole.ASSISTANT, "response", extraArgs);
+ }
+
+ @Test
+ void testRecordChatTokenMetricsRecordsWhenAllKeysPresent() {
+ BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+ Map<String, Object> extraArgs = new HashMap<>();
+ extraArgs.put("model_name", "m");
+ extraArgs.put("promptTokens", 100L);
+ extraArgs.put("completionTokens", 50L);
+
+ ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
+
+ verify(setup).recordTokenMetrics("m", 100L, 50L);
+ }
+
+ @Test
+ void testRecordChatTokenMetricsHandlesIntegerTokenValues() {
+ BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+ Map<String, Object> extraArgs = new HashMap<>();
+ extraArgs.put("model_name", "m");
+ extraArgs.put("promptTokens", 100);
+ extraArgs.put("completionTokens", 50);
+
+ ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
+
+ verify(setup).recordTokenMetrics("m", 100L, 50L);
+ }
+
+ @Test
+ void testRecordChatTokenMetricsSkipsWhenTokenValueNonNumeric() {
+ BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+ Map<String, Object> extraArgs = new HashMap<>();
+ extraArgs.put("model_name", "m");
+ extraArgs.put("promptTokens", "100");
+ extraArgs.put("completionTokens", 50L);
+
+ ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
+
+ verify(setup, never()).recordTokenMetrics(anyString(), anyLong(),
anyLong());
+ }
+
+ @Test
+ void testRecordChatTokenMetricsSkipsWhenKeyMissing() {
+ BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+ Map<String, Object> extraArgs = new HashMap<>();
+ extraArgs.put("model_name", "m");
+ extraArgs.put("completionTokens", 50L);
+
+ ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
+
+ verify(setup, never()).recordTokenMetrics(anyString(), anyLong(),
anyLong());
+ }
+
+ @Test
+ void testRecordChatTokenMetricsSkipsZeroTokensOrEmptyModel() {
+ BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+
+ Map<String, Object> zeroPrompt = new HashMap<>();
+ zeroPrompt.put("model_name", "m");
+ zeroPrompt.put("promptTokens", 0L);
+ zeroPrompt.put("completionTokens", 50L);
+ ChatModelAction.recordChatTokenMetrics(setup,
responseWith(zeroPrompt));
+
+ Map<String, Object> emptyModel = new HashMap<>();
+ emptyModel.put("model_name", "");
+ emptyModel.put("promptTokens", 100L);
+ emptyModel.put("completionTokens", 50L);
+ ChatModelAction.recordChatTokenMetrics(setup,
responseWith(emptyModel));
+
+ verify(setup, never()).recordTokenMetrics(anyString(), anyLong(),
anyLong());
+ }
+
@Test
void testCleanLlmResponseWithJsonBlock() {
String input = "```json\n{\"key\": \"value\"}\n```";