This is an automated email from the ASF dual-hosted git repository.
xintongsong pushed a commit to branch release-0.2
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
The following commit(s) were added to refs/heads/release-0.2 by this push:
new 6b3e3e77 [api][plan][integrations] Record built-in chat token metrics
outside the async call boundary (#712) (#725)
6b3e3e77 is described below
commit 6b3e3e773d006ba02e0d70d523cb37ff11136334
Author: Xintong Song <[email protected]>
AuthorDate: Tue Jun 2 09:30:38 2026 +0800
[api][plan][integrations] Record built-in chat token metrics outside the
async call boundary (#712) (#725)
Backport of #712 to release-0.2 with scope narrowed to the chat-model
connections that exist on this branch. The Bedrock, AzureOpenAI and
OpenAIResponses connection variants are not present on release-0.2 and
are intentionally excluded.
Move token-metric recording from the durable async callable (where it
crossed the operator/mailbox thread boundary) to the action thread:
- BaseChatModelSetup gains public recordTokenMetrics(String, long, long).
- BaseChatModelConnection.recordTokenMetrics(...) and the
connection.setMetricGroup(...) forwarding in BaseChatModelSetup are
removed.
- Each connection's chat() stashes model_name / promptTokens /
completionTokens into ChatMessage.extraArgs (Ollama, Anthropic, AzureAI,
OpenAI on release-0.2).
- ChatModelAction records via the setup after durableExecute(Async)
returns, before structured-output reassignment.
- RunnerContext.getAgentMetricGroup/getActionMetricGroup javadoc notes
that the returned group must only be accessed from the operator
thread, not inside a durable callable.
Emitted metric paths and counter names are unchanged. Recording is
gated the same way as in the Python implementation: it requires a
non-empty model name and both token counts greater than zero;
Integer/Long token values are accepted via Number#longValue().
Tests:
- BaseChatModelConnectionTokenMetricsTest renamed and rewritten to
BaseChatModelSetupTokenMetricsTest (target moved from connection to
setup).
- New ChatModelActionTest covers recordChatTokenMetrics: records when
all keys present and positive; Integer-typed values still recorded;
skips on missing key, non-numeric value, zero token, or empty model
name.
Co-authored-by: Weiqing Yang <[email protected]>
---
.../api/chat/model/BaseChatModelConnection.java | 19 ----
.../agents/api/chat/model/BaseChatModelSetup.java | 21 +++-
.../flink/agents/api/context/RunnerContext.java | 8 ++
...ava => BaseChatModelSetupTokenMetricsTest.java} | 75 +++++---------
.../anthropic/AnthropicChatModelConnection.java | 7 +-
.../azureai/AzureAIChatModelConnection.java | 9 +-
.../ollama/OllamaChatModelConnection.java | 7 +-
.../openai/OpenAIChatModelConnection.java | 11 ++-
.../flink/agents/plan/actions/ChatModelAction.java | 18 ++++
.../agents/plan/actions/ChatModelActionTest.java | 110 +++++++++++++++++++++
10 files changed, 200 insertions(+), 85 deletions(-)
diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java
index 7b70af0c..7254e37f 100644
--- a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java
+++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java
@@ -19,7 +19,6 @@
package org.apache.flink.agents.api.chat.model;
import org.apache.flink.agents.api.chat.messages.ChatMessage;
-import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
import org.apache.flink.agents.api.resource.Resource;
import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.apache.flink.agents.api.resource.ResourceType;
@@ -57,22 +56,4 @@ public abstract class BaseChatModelConnection extends Resource {
*/
public abstract ChatMessage chat(
List<ChatMessage> messages, List<Tool> tools, Map<String, Object>
arguments);
-
- /**
- * Record token usage metrics for the given model.
- *
- * @param modelName the name of the model used
- * @param promptTokens the number of prompt tokens
- * @param completionTokens the number of completion tokens
- */
- protected void recordTokenMetrics(String modelName, long promptTokens,
long completionTokens) {
- FlinkAgentsMetricGroup metricGroup = getMetricGroup();
- if (metricGroup == null) {
- return;
- }
-
- FlinkAgentsMetricGroup modelGroup = metricGroup.getSubGroup(modelName);
- modelGroup.getCounter("promptTokens").inc(promptTokens);
- modelGroup.getCounter("completionTokens").inc(completionTokens);
- }
}
diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java
index 15aa953e..343dd5a0 100644
--- a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java
+++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java
@@ -20,6 +20,7 @@ package org.apache.flink.agents.api.chat.model;
import org.apache.flink.agents.api.chat.messages.ChatMessage;
import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
import org.apache.flink.agents.api.prompt.Prompt;
import org.apache.flink.agents.api.resource.Resource;
import org.apache.flink.agents.api.resource.ResourceDescriptor;
@@ -51,6 +52,23 @@ public abstract class BaseChatModelSetup extends Resource {
public abstract Map<String, Object> getParameters();
+ /**
+ * Record token usage metrics for the given model on this setup's bound
metric group.
+ *
+ * @param modelName the name of the model used
+ * @param promptTokens the number of prompt tokens
+ * @param completionTokens the number of completion tokens
+ */
+ public void recordTokenMetrics(String modelName, long promptTokens, long
completionTokens) {
+ FlinkAgentsMetricGroup metricGroup = getMetricGroup();
+ if (metricGroup == null) {
+ return;
+ }
+ FlinkAgentsMetricGroup modelGroup = metricGroup.getSubGroup(modelName);
+ modelGroup.getCounter("promptTokens").inc(promptTokens);
+ modelGroup.getCounter("completionTokens").inc(completionTokens);
+ }
+
public ChatMessage chat(List<ChatMessage> messages) {
return this.chat(messages, Collections.emptyMap());
}
@@ -60,9 +78,6 @@ public abstract class BaseChatModelSetup extends Resource {
(BaseChatModelConnection)
this.getResource.apply(this.connection,
ResourceType.CHAT_MODEL_CONNECTION);
- // Pass metric group to connection for token usage tracking
- connection.setMetricGroup(getMetricGroup());
-
// Format input messages if set prompt.
if (this.prompt != null) {
if (this.prompt instanceof String) {
diff --git a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
index 0810752a..65f84b7a 100644
--- a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
+++ b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
@@ -67,6 +67,10 @@ public interface RunnerContext {
/**
* Gets the metric group for Flink Agents.
*
+ * <p>The returned group must only be accessed from the operator/mailbox
(action) thread, not
+ * from inside a {@link #durableExecute} or {@link #durableExecuteAsync}
callable, which runs on
+ * a separate thread pool.
+ *
* @return the metric group shared across all actions.
*/
FlinkAgentsMetricGroup getAgentMetricGroup();
@@ -74,6 +78,10 @@ public interface RunnerContext {
/**
* Gets the individual metric group dedicated for each action.
*
+ * <p>The returned group must only be accessed from the operator/mailbox
(action) thread, not
+ * from inside a {@link #durableExecute} or {@link #durableExecuteAsync}
callable, which runs on
+ * a separate thread pool.
+ *
* @return the individual metric group specific to the current action.
*/
FlinkAgentsMetricGroup getActionMetricGroup();
diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java
similarity index 61%
rename from api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java
rename to api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java
index 53c9bc6c..8d98febf 100644
--- a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java
+++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java
@@ -18,73 +18,63 @@
package org.apache.flink.agents.api.chat.model;
-import org.apache.flink.agents.api.chat.messages.ChatMessage;
-import org.apache.flink.agents.api.chat.messages.MessageRole;
import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
import org.apache.flink.agents.api.resource.Resource;
import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.apache.flink.agents.api.resource.ResourceType;
-import org.apache.flink.agents.api.tools.Tool;
import org.apache.flink.metrics.Counter;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import java.util.Collections;
-import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
-import static org.junit.jupiter.api.Assertions.*;
-import static org.mockito.Mockito.*;
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
-/** Test cases for BaseChatModelConnection token metrics functionality. */
-class BaseChatModelConnectionTokenMetricsTest {
+/** Test cases for BaseChatModelSetup token metrics functionality. */
+class BaseChatModelSetupTokenMetricsTest {
- private TestChatModelConnection connection;
+ private TestChatModelSetup setup;
private FlinkAgentsMetricGroup mockMetricGroup;
private FlinkAgentsMetricGroup mockModelGroup;
private Counter mockPromptTokensCounter;
private Counter mockCompletionTokensCounter;
- /** Test implementation of BaseChatModelConnection for testing purposes. */
- private static class TestChatModelConnection extends
BaseChatModelConnection {
+ /** Test implementation of BaseChatModelSetup for testing purposes. */
+ private static class TestChatModelSetup extends BaseChatModelSetup {
- public TestChatModelConnection(
+ public TestChatModelSetup(
ResourceDescriptor descriptor,
BiFunction<String, ResourceType, Resource> getResource) {
super(descriptor, getResource);
}
@Override
- public ChatMessage chat(
- List<ChatMessage> messages, List<Tool> tools, Map<String,
Object> arguments) {
- // Simple test implementation
- return new ChatMessage(MessageRole.ASSISTANT, "Test response");
- }
-
- // Expose protected method for testing
- public void testRecordTokenMetrics(
- String modelName, long promptTokens, long completionTokens) {
- recordTokenMetrics(modelName, promptTokens, completionTokens);
+ public Map<String, Object> getParameters() {
+ return Collections.emptyMap();
}
}
@BeforeEach
void setUp() {
- connection =
- new TestChatModelConnection(
+ setup =
+ new TestChatModelSetup(
new ResourceDescriptor(
- TestChatModelConnection.class.getName(),
Collections.emptyMap()),
+ TestChatModelSetup.class.getName(),
Collections.emptyMap()),
null);
- // Create mock objects
mockMetricGroup = mock(FlinkAgentsMetricGroup.class);
mockModelGroup = mock(FlinkAgentsMetricGroup.class);
mockPromptTokensCounter = mock(Counter.class);
mockCompletionTokensCounter = mock(Counter.class);
- // Set up mock behavior
when(mockMetricGroup.getSubGroup("gpt-4")).thenReturn(mockModelGroup);
when(mockModelGroup.getCounter("promptTokens")).thenReturn(mockPromptTokensCounter);
when(mockModelGroup.getCounter("completionTokens")).thenReturn(mockCompletionTokensCounter);
@@ -93,13 +83,10 @@ class BaseChatModelConnectionTokenMetricsTest {
@Test
@DisplayName("Test token metrics are recorded when metric group is set")
void testRecordTokenMetricsWithMetricGroup() {
- // Set the metric group
- connection.setMetricGroup(mockMetricGroup);
+ setup.setMetricGroup(mockMetricGroup);
- // Record token metrics
- connection.testRecordTokenMetrics("gpt-4", 100, 50);
+ setup.recordTokenMetrics("gpt-4", 100, 50);
- // Verify the metrics were recorded
verify(mockMetricGroup).getSubGroup("gpt-4");
verify(mockModelGroup).getCounter("promptTokens");
verify(mockModelGroup).getCounter("completionTokens");
@@ -110,22 +97,16 @@ class BaseChatModelConnectionTokenMetricsTest {
@Test
@DisplayName("Test token metrics are not recorded when metric group is
null")
void testRecordTokenMetricsWithoutMetricGroup() {
- // Do not set metric group (should be null by default)
+ assertDoesNotThrow(() -> setup.recordTokenMetrics("gpt-4", 100, 50));
- // Record token metrics - should not throw
- assertDoesNotThrow(() -> connection.testRecordTokenMetrics("gpt-4",
100, 50));
-
- // No metrics should be recorded
verifyNoInteractions(mockMetricGroup);
}
@Test
- @DisplayName("Test token metrics hierarchy: actionMetricGroup -> modelName
-> counters")
+ @DisplayName("Test token metrics hierarchy: metricGroup -> modelName ->
counters")
void testTokenMetricsHierarchy() {
- // Set the metric group
- connection.setMetricGroup(mockMetricGroup);
+ setup.setMetricGroup(mockMetricGroup);
- // Record token metrics for different models
FlinkAgentsMetricGroup mockGpt35Group =
mock(FlinkAgentsMetricGroup.class);
Counter mockGpt35PromptCounter = mock(Counter.class);
Counter mockGpt35CompletionCounter = mock(Counter.class);
@@ -134,13 +115,9 @@ class BaseChatModelConnectionTokenMetricsTest {
when(mockGpt35Group.getCounter("promptTokens")).thenReturn(mockGpt35PromptCounter);
when(mockGpt35Group.getCounter("completionTokens")).thenReturn(mockGpt35CompletionCounter);
- // Record for gpt-4
- connection.testRecordTokenMetrics("gpt-4", 100, 50);
-
- // Record for gpt-3.5-turbo
- connection.testRecordTokenMetrics("gpt-3.5-turbo", 200, 100);
+ setup.recordTokenMetrics("gpt-4", 100, 50);
+ setup.recordTokenMetrics("gpt-3.5-turbo", 200, 100);
- // Verify each model has its own counters
verify(mockMetricGroup).getSubGroup("gpt-4");
verify(mockMetricGroup).getSubGroup("gpt-3.5-turbo");
verify(mockPromptTokensCounter).inc(100);
@@ -150,8 +127,8 @@ class BaseChatModelConnectionTokenMetricsTest {
}
@Test
- @DisplayName("Test resource type is CHAT_MODEL_CONNECTION")
+ @DisplayName("Test resource type is CHAT_MODEL")
void testResourceType() {
- assertEquals(ResourceType.CHAT_MODEL_CONNECTION,
connection.getResourceType());
+ assertEquals(ResourceType.CHAT_MODEL, setup.getResourceType());
}
}
diff --git a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java
index 6dded957..f713654e 100644
--- a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java
+++ b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java
@@ -135,7 +135,7 @@ public class AnthropicChatModelConnection extends BaseChatModelConnection {
Message response = client.messages().create(params);
ChatMessage result = convertResponse(response, jsonPrefillApplied);
- // Record token metrics
+ // Stash token usage
String modelName = null;
if (arguments != null && arguments.get("model") != null) {
modelName = arguments.get("model").toString();
@@ -144,8 +144,9 @@ public class AnthropicChatModelConnection extends BaseChatModelConnection {
modelName = this.defaultModel;
}
if (modelName != null && !modelName.isBlank()) {
- recordTokenMetrics(
- modelName, response.usage().inputTokens(),
response.usage().outputTokens());
+ result.getExtraArgs().put("model_name", modelName);
+ result.getExtraArgs().put("promptTokens",
response.usage().inputTokens());
+ result.getExtraArgs().put("completionTokens",
response.usage().outputTokens());
}
return result;
diff --git a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java
index f223ee55..7afa6805 100644
--- a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java
+++ b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java
@@ -189,12 +189,15 @@ public class AzureAIChatModelConnection extends BaseChatModelConnection {
chatMessage.setToolCalls(convertedToolCalls);
}
- // Record token metrics if model name is available
+ // Stash token usage if model name is available
if (modelName != null && !modelName.isBlank()) {
CompletionsUsage usage = completions.getUsage();
if (usage != null) {
- recordTokenMetrics(
- modelName, usage.getPromptTokens(),
usage.getCompletionTokens());
+ chatMessage.getExtraArgs().put("model_name", modelName);
+ chatMessage.getExtraArgs().put("promptTokens", (long)
usage.getPromptTokens());
+ chatMessage
+ .getExtraArgs()
+ .put("completionTokens", (long)
usage.getCompletionTokens());
}
}
diff --git a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
index 773069f4..4071b51c 100644
--- a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
+++ b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
@@ -227,13 +227,14 @@ public class OllamaChatModelConnection extends BaseChatModelConnection {
chatMessage.setToolCalls(toolCalls);
}
- // Record token metrics if model name is available
+ // Stash token usage if model name is available
if (modelName != null && !modelName.isBlank()) {
Integer promptTokens = ollamaChatResponse.getPromptEvalCount();
Integer completionTokens = ollamaChatResponse.getEvalCount();
if (promptTokens != null && completionTokens != null) {
- recordTokenMetrics(
- modelName, promptTokens.longValue(),
completionTokens.longValue());
+ extraArgs.put("model_name", modelName);
+ extraArgs.put("promptTokens", promptTokens.longValue());
+ extraArgs.put("completionTokens",
completionTokens.longValue());
}
}
diff --git a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java
index b04cd2b2..039963d3 100644
--- a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java
+++ b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java
@@ -144,17 +144,18 @@ public class OpenAIChatModelConnection extends BaseChatModelConnection {
ChatCompletion completion =
client.chat().completions().create(params);
ChatMessage response = convertResponse(completion);
- // Record token metrics
+ // Stash token usage
if (completion.usage().isPresent()) {
String modelName = arguments != null ? (String)
arguments.get("model") : null;
if (modelName == null || modelName.isBlank()) {
modelName = this.defaultModel;
}
if (modelName != null && !modelName.isBlank()) {
- recordTokenMetrics(
- modelName,
- completion.usage().get().promptTokens(),
- completion.usage().get().completionTokens());
+ response.getExtraArgs().put("model_name", modelName);
+ response.getExtraArgs()
+ .put("promptTokens",
completion.usage().get().promptTokens());
+ response.getExtraArgs()
+ .put("completionTokens",
completion.usage().get().completionTokens());
}
}
diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
index cc73f2c6..276c232f 100644
--- a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
+++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
@@ -129,6 +129,23 @@ public class ChatModelAction {
return (Map<String, Object>) toolRequestEventContext.remove(requestId);
}
+ static void recordChatTokenMetrics(BaseChatModelSetup chatModel,
ChatMessage response) {
+ Map<String, Object> extraArgs = response.getExtraArgs();
+ Object modelName = extraArgs.get("model_name");
+ Object promptTokens = extraArgs.get("promptTokens");
+ Object completionTokens = extraArgs.get("completionTokens");
+ if (modelName != null
+ && !modelName.toString().isEmpty()
+ && promptTokens instanceof Number
+ && completionTokens instanceof Number) {
+ long prompt = ((Number) promptTokens).longValue();
+ long completion = ((Number) completionTokens).longValue();
+ if (prompt > 0 && completion > 0) {
+ chatModel.recordTokenMetrics(modelName.toString(), prompt,
completion);
+ }
+ }
+ }
+
private static void handleToolCalls(
ChatMessage response,
UUID initialRequestId,
@@ -239,6 +256,7 @@ public class ChatModelAction {
chatAsync
? ctx.durableExecuteAsync(callable)
: ctx.durableExecute(callable);
+ recordChatTokenMetrics(chatModel, response);
// only generate structured output for final response.
if (outputSchema != null && response.getToolCalls().isEmpty())
{
response = generateStructuredOutput(response,
outputSchema);
diff --git a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java
new file mode 100644
index 00000000..b4c7c23e
--- /dev/null
+++ b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.plan.actions;
+
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.chat.model.BaseChatModelSetup;
+import org.junit.jupiter.api.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+
+/** Tests for {@link ChatModelAction}. */
+class ChatModelActionTest {
+
+ private static ChatMessage responseWith(Map<String, Object> extraArgs) {
+ return new ChatMessage(MessageRole.ASSISTANT, "response", extraArgs);
+ }
+
+ @Test
+ void testRecordChatTokenMetricsRecordsWhenAllKeysPresent() {
+ BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+ Map<String, Object> extraArgs = new HashMap<>();
+ extraArgs.put("model_name", "m");
+ extraArgs.put("promptTokens", 100L);
+ extraArgs.put("completionTokens", 50L);
+
+ ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
+
+ verify(setup).recordTokenMetrics("m", 100L, 50L);
+ }
+
+ @Test
+ void testRecordChatTokenMetricsHandlesIntegerTokenValues() {
+ BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+ Map<String, Object> extraArgs = new HashMap<>();
+ extraArgs.put("model_name", "m");
+ extraArgs.put("promptTokens", 100);
+ extraArgs.put("completionTokens", 50);
+
+ ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
+
+ verify(setup).recordTokenMetrics("m", 100L, 50L);
+ }
+
+ @Test
+ void testRecordChatTokenMetricsSkipsWhenTokenValueNonNumeric() {
+ BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+ Map<String, Object> extraArgs = new HashMap<>();
+ extraArgs.put("model_name", "m");
+ extraArgs.put("promptTokens", "100");
+ extraArgs.put("completionTokens", 50L);
+
+ ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
+
+ verify(setup, never()).recordTokenMetrics(anyString(), anyLong(),
anyLong());
+ }
+
+ @Test
+ void testRecordChatTokenMetricsSkipsWhenKeyMissing() {
+ BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+ Map<String, Object> extraArgs = new HashMap<>();
+ extraArgs.put("model_name", "m");
+ extraArgs.put("completionTokens", 50L);
+
+ ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
+
+ verify(setup, never()).recordTokenMetrics(anyString(), anyLong(),
anyLong());
+ }
+
+ @Test
+ void testRecordChatTokenMetricsSkipsZeroTokensOrEmptyModel() {
+ BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+
+ Map<String, Object> zeroPrompt = new HashMap<>();
+ zeroPrompt.put("model_name", "m");
+ zeroPrompt.put("promptTokens", 0L);
+ zeroPrompt.put("completionTokens", 50L);
+ ChatModelAction.recordChatTokenMetrics(setup,
responseWith(zeroPrompt));
+
+ Map<String, Object> emptyModel = new HashMap<>();
+ emptyModel.put("model_name", "");
+ emptyModel.put("promptTokens", 100L);
+ emptyModel.put("completionTokens", 50L);
+ ChatModelAction.recordChatTokenMetrics(setup,
responseWith(emptyModel));
+
+ verify(setup, never()).recordTokenMetrics(anyString(), anyLong(),
anyLong());
+ }
+}