This is an automated email from the ASF dual-hosted git repository.

xintongsong pushed a commit to branch release-0.2
in repository https://gitbox.apache.org/repos/asf/flink-agents.git


The following commit(s) were added to refs/heads/release-0.2 by this push:
     new 6b3e3e77 [api][plan][integrations] Record built-in chat token metrics outside the async call boundary (#712) (#725)
6b3e3e77 is described below

commit 6b3e3e773d006ba02e0d70d523cb37ff11136334
Author: Xintong Song <[email protected]>
AuthorDate: Tue Jun 2 09:30:38 2026 +0800

    [api][plan][integrations] Record built-in chat token metrics outside the async call boundary (#712) (#725)
    
    Backport of #712 to release-0.2 with scope narrowed to the chat-model
    connections that exist on this branch. The Bedrock, AzureOpenAI and
    OpenAIResponses connection variants are not present on release-0.2 and
    are intentionally excluded.
    
    Move token-metric recording out of the durable callable passed to
    durableExecute/durableExecuteAsync (which runs on a separate thread pool
    and therefore crossed the operator/mailbox thread boundary) and onto the
    action thread:
    
    - BaseChatModelSetup gains public recordTokenMetrics(String, long, long).
    - BaseChatModelConnection.recordTokenMetrics(...) and the
      connection.setMetricGroup(...) forwarding in BaseChatModelSetup are
      removed.
    - On release-0.2, each connection's chat() (Ollama, Anthropic, AzureAI,
      OpenAI) stashes model_name, promptTokens and completionTokens into the
      returned ChatMessage's extraArgs.
    - ChatModelAction records the metrics via the setup after
      durableExecute(Async) returns, before the response may be replaced by
      the structured output.
    - RunnerContext.getAgentMetricGroup/getActionMetricGroup javadoc notes
      that the returned group must only be accessed from the operator
      thread, not inside a durable callable.
    
    Emitted metric paths and counter names are unchanged. Recording is gated
    the same way as in the Python implementation: the model name must be
    non-empty and both token counts must be greater than zero. Integer and
    Long token values are both accepted via Number#longValue().
    
    Tests:
    - BaseChatModelConnectionTokenMetricsTest renamed to
      BaseChatModelSetupTokenMetricsTest and rewritten to target the setup
      instead of the connection.
    - New ChatModelActionTest covers recordChatTokenMetrics: it records when
      all keys are present and positive, still records Integer-typed values,
      and skips on a missing key, a non-numeric value, a zero token count, or
      an empty model name.
    
    Co-authored-by: Weiqing Yang <[email protected]>
---
 .../api/chat/model/BaseChatModelConnection.java    |  19 ----
 .../agents/api/chat/model/BaseChatModelSetup.java  |  21 +++-
 .../flink/agents/api/context/RunnerContext.java    |   8 ++
 ...ava => BaseChatModelSetupTokenMetricsTest.java} |  75 +++++---------
 .../anthropic/AnthropicChatModelConnection.java    |   7 +-
 .../azureai/AzureAIChatModelConnection.java        |   9 +-
 .../ollama/OllamaChatModelConnection.java          |   7 +-
 .../openai/OpenAIChatModelConnection.java          |  11 ++-
 .../flink/agents/plan/actions/ChatModelAction.java |  18 ++++
 .../agents/plan/actions/ChatModelActionTest.java   | 110 +++++++++++++++++++++
 10 files changed, 200 insertions(+), 85 deletions(-)

diff --git 
a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java
 
b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java
index 7b70af0c..7254e37f 100644
--- 
a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java
+++ 
b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java
@@ -19,7 +19,6 @@
 package org.apache.flink.agents.api.chat.model;
 
 import org.apache.flink.agents.api.chat.messages.ChatMessage;
-import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
 import org.apache.flink.agents.api.resource.Resource;
 import org.apache.flink.agents.api.resource.ResourceDescriptor;
 import org.apache.flink.agents.api.resource.ResourceType;
@@ -57,22 +56,4 @@ public abstract class BaseChatModelConnection extends 
Resource {
      */
     public abstract ChatMessage chat(
             List<ChatMessage> messages, List<Tool> tools, Map<String, Object> 
arguments);
-
-    /**
-     * Record token usage metrics for the given model.
-     *
-     * @param modelName the name of the model used
-     * @param promptTokens the number of prompt tokens
-     * @param completionTokens the number of completion tokens
-     */
-    protected void recordTokenMetrics(String modelName, long promptTokens, 
long completionTokens) {
-        FlinkAgentsMetricGroup metricGroup = getMetricGroup();
-        if (metricGroup == null) {
-            return;
-        }
-
-        FlinkAgentsMetricGroup modelGroup = metricGroup.getSubGroup(modelName);
-        modelGroup.getCounter("promptTokens").inc(promptTokens);
-        modelGroup.getCounter("completionTokens").inc(completionTokens);
-    }
 }
diff --git 
a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java
 
b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java
index 15aa953e..343dd5a0 100644
--- 
a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java
+++ 
b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java
@@ -20,6 +20,7 @@ package org.apache.flink.agents.api.chat.model;
 
 import org.apache.flink.agents.api.chat.messages.ChatMessage;
 import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
 import org.apache.flink.agents.api.prompt.Prompt;
 import org.apache.flink.agents.api.resource.Resource;
 import org.apache.flink.agents.api.resource.ResourceDescriptor;
@@ -51,6 +52,23 @@ public abstract class BaseChatModelSetup extends Resource {
 
     public abstract Map<String, Object> getParameters();
 
+    /**
+     * Record token usage metrics for the given model on this setup's bound 
metric group.
+     *
+     * @param modelName the name of the model used
+     * @param promptTokens the number of prompt tokens
+     * @param completionTokens the number of completion tokens
+     */
+    public void recordTokenMetrics(String modelName, long promptTokens, long 
completionTokens) {
+        FlinkAgentsMetricGroup metricGroup = getMetricGroup();
+        if (metricGroup == null) {
+            return;
+        }
+        FlinkAgentsMetricGroup modelGroup = metricGroup.getSubGroup(modelName);
+        modelGroup.getCounter("promptTokens").inc(promptTokens);
+        modelGroup.getCounter("completionTokens").inc(completionTokens);
+    }
+
     public ChatMessage chat(List<ChatMessage> messages) {
         return this.chat(messages, Collections.emptyMap());
     }
@@ -60,9 +78,6 @@ public abstract class BaseChatModelSetup extends Resource {
                 (BaseChatModelConnection)
                         this.getResource.apply(this.connection, 
ResourceType.CHAT_MODEL_CONNECTION);
 
-        // Pass metric group to connection for token usage tracking
-        connection.setMetricGroup(getMetricGroup());
-
         // Format input messages if set prompt.
         if (this.prompt != null) {
             if (this.prompt instanceof String) {
diff --git 
a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java 
b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
index 0810752a..65f84b7a 100644
--- a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
+++ b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
@@ -67,6 +67,10 @@ public interface RunnerContext {
     /**
      * Gets the metric group for Flink Agents.
      *
+     * <p>The returned group must only be accessed from the operator/mailbox 
(action) thread, not
+     * from inside a {@link #durableExecute} or {@link #durableExecuteAsync} 
callable, which runs on
+     * a separate thread pool.
+     *
      * @return the metric group shared across all actions.
      */
     FlinkAgentsMetricGroup getAgentMetricGroup();
@@ -74,6 +78,10 @@ public interface RunnerContext {
     /**
      * Gets the individual metric group dedicated for each action.
      *
+     * <p>The returned group must only be accessed from the operator/mailbox 
(action) thread, not
+     * from inside a {@link #durableExecute} or {@link #durableExecuteAsync} 
callable, which runs on
+     * a separate thread pool.
+     *
      * @return the individual metric group specific to the current action.
      */
     FlinkAgentsMetricGroup getActionMetricGroup();
diff --git 
a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java
 
b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java
similarity index 61%
rename from 
api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java
rename to 
api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java
index 53c9bc6c..8d98febf 100644
--- 
a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java
+++ 
b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java
@@ -18,73 +18,63 @@
 
 package org.apache.flink.agents.api.chat.model;
 
-import org.apache.flink.agents.api.chat.messages.ChatMessage;
-import org.apache.flink.agents.api.chat.messages.MessageRole;
 import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
 import org.apache.flink.agents.api.resource.Resource;
 import org.apache.flink.agents.api.resource.ResourceDescriptor;
 import org.apache.flink.agents.api.resource.ResourceType;
-import org.apache.flink.agents.api.tools.Tool;
 import org.apache.flink.metrics.Counter;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.DisplayName;
 import org.junit.jupiter.api.Test;
 
 import java.util.Collections;
-import java.util.List;
 import java.util.Map;
 import java.util.function.BiFunction;
 
-import static org.junit.jupiter.api.Assertions.*;
-import static org.mockito.Mockito.*;
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
 
-/** Test cases for BaseChatModelConnection token metrics functionality. */
-class BaseChatModelConnectionTokenMetricsTest {
+/** Test cases for BaseChatModelSetup token metrics functionality. */
+class BaseChatModelSetupTokenMetricsTest {
 
-    private TestChatModelConnection connection;
+    private TestChatModelSetup setup;
     private FlinkAgentsMetricGroup mockMetricGroup;
     private FlinkAgentsMetricGroup mockModelGroup;
     private Counter mockPromptTokensCounter;
     private Counter mockCompletionTokensCounter;
 
-    /** Test implementation of BaseChatModelConnection for testing purposes. */
-    private static class TestChatModelConnection extends 
BaseChatModelConnection {
+    /** Test implementation of BaseChatModelSetup for testing purposes. */
+    private static class TestChatModelSetup extends BaseChatModelSetup {
 
-        public TestChatModelConnection(
+        public TestChatModelSetup(
                 ResourceDescriptor descriptor,
                 BiFunction<String, ResourceType, Resource> getResource) {
             super(descriptor, getResource);
         }
 
         @Override
-        public ChatMessage chat(
-                List<ChatMessage> messages, List<Tool> tools, Map<String, 
Object> arguments) {
-            // Simple test implementation
-            return new ChatMessage(MessageRole.ASSISTANT, "Test response");
-        }
-
-        // Expose protected method for testing
-        public void testRecordTokenMetrics(
-                String modelName, long promptTokens, long completionTokens) {
-            recordTokenMetrics(modelName, promptTokens, completionTokens);
+        public Map<String, Object> getParameters() {
+            return Collections.emptyMap();
         }
     }
 
     @BeforeEach
     void setUp() {
-        connection =
-                new TestChatModelConnection(
+        setup =
+                new TestChatModelSetup(
                         new ResourceDescriptor(
-                                TestChatModelConnection.class.getName(), 
Collections.emptyMap()),
+                                TestChatModelSetup.class.getName(), 
Collections.emptyMap()),
                         null);
 
-        // Create mock objects
         mockMetricGroup = mock(FlinkAgentsMetricGroup.class);
         mockModelGroup = mock(FlinkAgentsMetricGroup.class);
         mockPromptTokensCounter = mock(Counter.class);
         mockCompletionTokensCounter = mock(Counter.class);
 
-        // Set up mock behavior
         when(mockMetricGroup.getSubGroup("gpt-4")).thenReturn(mockModelGroup);
         
when(mockModelGroup.getCounter("promptTokens")).thenReturn(mockPromptTokensCounter);
         
when(mockModelGroup.getCounter("completionTokens")).thenReturn(mockCompletionTokensCounter);
@@ -93,13 +83,10 @@ class BaseChatModelConnectionTokenMetricsTest {
     @Test
     @DisplayName("Test token metrics are recorded when metric group is set")
     void testRecordTokenMetricsWithMetricGroup() {
-        // Set the metric group
-        connection.setMetricGroup(mockMetricGroup);
+        setup.setMetricGroup(mockMetricGroup);
 
-        // Record token metrics
-        connection.testRecordTokenMetrics("gpt-4", 100, 50);
+        setup.recordTokenMetrics("gpt-4", 100, 50);
 
-        // Verify the metrics were recorded
         verify(mockMetricGroup).getSubGroup("gpt-4");
         verify(mockModelGroup).getCounter("promptTokens");
         verify(mockModelGroup).getCounter("completionTokens");
@@ -110,22 +97,16 @@ class BaseChatModelConnectionTokenMetricsTest {
     @Test
     @DisplayName("Test token metrics are not recorded when metric group is 
null")
     void testRecordTokenMetricsWithoutMetricGroup() {
-        // Do not set metric group (should be null by default)
+        assertDoesNotThrow(() -> setup.recordTokenMetrics("gpt-4", 100, 50));
 
-        // Record token metrics - should not throw
-        assertDoesNotThrow(() -> connection.testRecordTokenMetrics("gpt-4", 
100, 50));
-
-        // No metrics should be recorded
         verifyNoInteractions(mockMetricGroup);
     }
 
     @Test
-    @DisplayName("Test token metrics hierarchy: actionMetricGroup -> modelName 
-> counters")
+    @DisplayName("Test token metrics hierarchy: metricGroup -> modelName -> 
counters")
     void testTokenMetricsHierarchy() {
-        // Set the metric group
-        connection.setMetricGroup(mockMetricGroup);
+        setup.setMetricGroup(mockMetricGroup);
 
-        // Record token metrics for different models
         FlinkAgentsMetricGroup mockGpt35Group = 
mock(FlinkAgentsMetricGroup.class);
         Counter mockGpt35PromptCounter = mock(Counter.class);
         Counter mockGpt35CompletionCounter = mock(Counter.class);
@@ -134,13 +115,9 @@ class BaseChatModelConnectionTokenMetricsTest {
         
when(mockGpt35Group.getCounter("promptTokens")).thenReturn(mockGpt35PromptCounter);
         
when(mockGpt35Group.getCounter("completionTokens")).thenReturn(mockGpt35CompletionCounter);
 
-        // Record for gpt-4
-        connection.testRecordTokenMetrics("gpt-4", 100, 50);
-
-        // Record for gpt-3.5-turbo
-        connection.testRecordTokenMetrics("gpt-3.5-turbo", 200, 100);
+        setup.recordTokenMetrics("gpt-4", 100, 50);
+        setup.recordTokenMetrics("gpt-3.5-turbo", 200, 100);
 
-        // Verify each model has its own counters
         verify(mockMetricGroup).getSubGroup("gpt-4");
         verify(mockMetricGroup).getSubGroup("gpt-3.5-turbo");
         verify(mockPromptTokensCounter).inc(100);
@@ -150,8 +127,8 @@ class BaseChatModelConnectionTokenMetricsTest {
     }
 
     @Test
-    @DisplayName("Test resource type is CHAT_MODEL_CONNECTION")
+    @DisplayName("Test resource type is CHAT_MODEL")
     void testResourceType() {
-        assertEquals(ResourceType.CHAT_MODEL_CONNECTION, 
connection.getResourceType());
+        assertEquals(ResourceType.CHAT_MODEL, setup.getResourceType());
     }
 }
diff --git 
a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java
 
b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java
index 6dded957..f713654e 100644
--- 
a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java
+++ 
b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java
@@ -135,7 +135,7 @@ public class AnthropicChatModelConnection extends 
BaseChatModelConnection {
             Message response = client.messages().create(params);
             ChatMessage result = convertResponse(response, jsonPrefillApplied);
 
-            // Record token metrics
+            // Stash token usage
             String modelName = null;
             if (arguments != null && arguments.get("model") != null) {
                 modelName = arguments.get("model").toString();
@@ -144,8 +144,9 @@ public class AnthropicChatModelConnection extends 
BaseChatModelConnection {
                 modelName = this.defaultModel;
             }
             if (modelName != null && !modelName.isBlank()) {
-                recordTokenMetrics(
-                        modelName, response.usage().inputTokens(), 
response.usage().outputTokens());
+                result.getExtraArgs().put("model_name", modelName);
+                result.getExtraArgs().put("promptTokens", 
response.usage().inputTokens());
+                result.getExtraArgs().put("completionTokens", 
response.usage().outputTokens());
             }
 
             return result;
diff --git 
a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java
 
b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java
index f223ee55..7afa6805 100644
--- 
a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java
+++ 
b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java
@@ -189,12 +189,15 @@ public class AzureAIChatModelConnection extends 
BaseChatModelConnection {
                 chatMessage.setToolCalls(convertedToolCalls);
             }
 
-            // Record token metrics if model name is available
+            // Stash token usage if model name is available
             if (modelName != null && !modelName.isBlank()) {
                 CompletionsUsage usage = completions.getUsage();
                 if (usage != null) {
-                    recordTokenMetrics(
-                            modelName, usage.getPromptTokens(), 
usage.getCompletionTokens());
+                    chatMessage.getExtraArgs().put("model_name", modelName);
+                    chatMessage.getExtraArgs().put("promptTokens", (long) 
usage.getPromptTokens());
+                    chatMessage
+                            .getExtraArgs()
+                            .put("completionTokens", (long) 
usage.getCompletionTokens());
                 }
             }
 
diff --git 
a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
 
b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
index 773069f4..4071b51c 100644
--- 
a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
+++ 
b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
@@ -227,13 +227,14 @@ public class OllamaChatModelConnection extends 
BaseChatModelConnection {
                 chatMessage.setToolCalls(toolCalls);
             }
 
-            // Record token metrics if model name is available
+            // Stash token usage if model name is available
             if (modelName != null && !modelName.isBlank()) {
                 Integer promptTokens = ollamaChatResponse.getPromptEvalCount();
                 Integer completionTokens = ollamaChatResponse.getEvalCount();
                 if (promptTokens != null && completionTokens != null) {
-                    recordTokenMetrics(
-                            modelName, promptTokens.longValue(), 
completionTokens.longValue());
+                    extraArgs.put("model_name", modelName);
+                    extraArgs.put("promptTokens", promptTokens.longValue());
+                    extraArgs.put("completionTokens", 
completionTokens.longValue());
                 }
             }
 
diff --git 
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java
 
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java
index b04cd2b2..039963d3 100644
--- 
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java
+++ 
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java
@@ -144,17 +144,18 @@ public class OpenAIChatModelConnection extends 
BaseChatModelConnection {
             ChatCompletion completion = 
client.chat().completions().create(params);
             ChatMessage response = convertResponse(completion);
 
-            // Record token metrics
+            // Stash token usage
             if (completion.usage().isPresent()) {
                 String modelName = arguments != null ? (String) 
arguments.get("model") : null;
                 if (modelName == null || modelName.isBlank()) {
                     modelName = this.defaultModel;
                 }
                 if (modelName != null && !modelName.isBlank()) {
-                    recordTokenMetrics(
-                            modelName,
-                            completion.usage().get().promptTokens(),
-                            completion.usage().get().completionTokens());
+                    response.getExtraArgs().put("model_name", modelName);
+                    response.getExtraArgs()
+                            .put("promptTokens", 
completion.usage().get().promptTokens());
+                    response.getExtraArgs()
+                            .put("completionTokens", 
completion.usage().get().completionTokens());
                 }
             }
 
diff --git 
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
index cc73f2c6..276c232f 100644
--- 
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
+++ 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
@@ -129,6 +129,23 @@ public class ChatModelAction {
         return (Map<String, Object>) toolRequestEventContext.remove(requestId);
     }
 
+    static void recordChatTokenMetrics(BaseChatModelSetup chatModel, 
ChatMessage response) {
+        Map<String, Object> extraArgs = response.getExtraArgs();
+        Object modelName = extraArgs.get("model_name");
+        Object promptTokens = extraArgs.get("promptTokens");
+        Object completionTokens = extraArgs.get("completionTokens");
+        if (modelName != null
+                && !modelName.toString().isEmpty()
+                && promptTokens instanceof Number
+                && completionTokens instanceof Number) {
+            long prompt = ((Number) promptTokens).longValue();
+            long completion = ((Number) completionTokens).longValue();
+            if (prompt > 0 && completion > 0) {
+                chatModel.recordTokenMetrics(modelName.toString(), prompt, 
completion);
+            }
+        }
+    }
+
     private static void handleToolCalls(
             ChatMessage response,
             UUID initialRequestId,
@@ -239,6 +256,7 @@ public class ChatModelAction {
                         chatAsync
                                 ? ctx.durableExecuteAsync(callable)
                                 : ctx.durableExecute(callable);
+                recordChatTokenMetrics(chatModel, response);
                 // only generate structured output for final response.
                 if (outputSchema != null && response.getToolCalls().isEmpty()) 
{
                     response = generateStructuredOutput(response, 
outputSchema);
diff --git 
a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java
 
b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java
new file mode 100644
index 00000000..b4c7c23e
--- /dev/null
+++ 
b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.plan.actions;
+
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.chat.model.BaseChatModelSetup;
+import org.junit.jupiter.api.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+
+/** Tests for {@link ChatModelAction}. */
+class ChatModelActionTest {
+
+    private static ChatMessage responseWith(Map<String, Object> extraArgs) {
+        return new ChatMessage(MessageRole.ASSISTANT, "response", extraArgs);
+    }
+
+    @Test
+    void testRecordChatTokenMetricsRecordsWhenAllKeysPresent() {
+        BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+        Map<String, Object> extraArgs = new HashMap<>();
+        extraArgs.put("model_name", "m");
+        extraArgs.put("promptTokens", 100L);
+        extraArgs.put("completionTokens", 50L);
+
+        ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
+
+        verify(setup).recordTokenMetrics("m", 100L, 50L);
+    }
+
+    @Test
+    void testRecordChatTokenMetricsHandlesIntegerTokenValues() {
+        BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+        Map<String, Object> extraArgs = new HashMap<>();
+        extraArgs.put("model_name", "m");
+        extraArgs.put("promptTokens", 100);
+        extraArgs.put("completionTokens", 50);
+
+        ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
+
+        verify(setup).recordTokenMetrics("m", 100L, 50L);
+    }
+
+    @Test
+    void testRecordChatTokenMetricsSkipsWhenTokenValueNonNumeric() {
+        BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+        Map<String, Object> extraArgs = new HashMap<>();
+        extraArgs.put("model_name", "m");
+        extraArgs.put("promptTokens", "100");
+        extraArgs.put("completionTokens", 50L);
+
+        ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
+
+        verify(setup, never()).recordTokenMetrics(anyString(), anyLong(), 
anyLong());
+    }
+
+    @Test
+    void testRecordChatTokenMetricsSkipsWhenKeyMissing() {
+        BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+        Map<String, Object> extraArgs = new HashMap<>();
+        extraArgs.put("model_name", "m");
+        extraArgs.put("completionTokens", 50L);
+
+        ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
+
+        verify(setup, never()).recordTokenMetrics(anyString(), anyLong(), 
anyLong());
+    }
+
+    @Test
+    void testRecordChatTokenMetricsSkipsZeroTokensOrEmptyModel() {
+        BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+
+        Map<String, Object> zeroPrompt = new HashMap<>();
+        zeroPrompt.put("model_name", "m");
+        zeroPrompt.put("promptTokens", 0L);
+        zeroPrompt.put("completionTokens", 50L);
+        ChatModelAction.recordChatTokenMetrics(setup, 
responseWith(zeroPrompt));
+
+        Map<String, Object> emptyModel = new HashMap<>();
+        emptyModel.put("model_name", "");
+        emptyModel.put("promptTokens", 100L);
+        emptyModel.put("completionTokens", 50L);
+        ChatModelAction.recordChatTokenMetrics(setup, 
responseWith(emptyModel));
+
+        verify(setup, never()).recordTokenMetrics(anyString(), anyLong(), 
anyLong());
+    }
+}

Reply via email to