This is an automated email from the ASF dual-hosted git repository.
sxnan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
The following commit(s) were added to refs/heads/main by this push:
new 6bc5f58 [metrics] Implement token usage metrics tracking for chat
models (#394)
6bc5f58 is described below
commit 6bc5f58063e30c564be01b1349cda0b97fcb794d
Author: Xuannan <[email protected]>
AuthorDate: Fri Dec 26 09:22:43 2025 +0800
[metrics] Implement token usage metrics tracking for chat models (#394)
---
.../api/chat/model/BaseChatModelConnection.java | 19 +++
.../agents/api/chat/model/BaseChatModelSetup.java | 3 +
.../apache/flink/agents/api/resource/Resource.java | 23 +++
.../BaseChatModelConnectionTokenMetricsTest.java | 157 ++++++++++++++++++
docs/content/docs/operations/monitoring.md | 13 +-
.../anthropic/AnthropicChatModelConnection.java | 17 +-
.../azureai/AzureAIChatModelConnection.java | 12 +-
.../ollama/OllamaChatModelConnection.java | 13 +-
.../openai/OpenAIChatModelConnection.java | 18 +-
python/flink_agents/api/chat_models/chat_model.py | 26 +++
.../flink_agents/api/chat_models/tests/__init__.py | 17 ++
.../api/chat_models/tests/test_token_metrics.py | 182 +++++++++++++++++++++
python/flink_agents/api/resource.py | 31 +++-
.../chat_models/anthropic/anthropic_chat_model.py | 9 +
.../integrations/chat_models/ollama_chat_model.py | 13 +-
.../chat_models/openai/openai_chat_model.py | 13 +-
.../chat_models/tests/test_tongyi_chat_model.py | 1 +
.../integrations/chat_models/tongyi_chat_model.py | 13 +-
.../flink_agents/runtime/flink_runner_context.py | 5 +-
.../agents/runtime/context/RunnerContextImpl.java | 5 +-
20 files changed, 574 insertions(+), 16 deletions(-)
diff --git
a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java
b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java
index 7254e37..7b70af0 100644
---
a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java
+++
b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java
@@ -19,6 +19,7 @@
package org.apache.flink.agents.api.chat.model;
import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
import org.apache.flink.agents.api.resource.Resource;
import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.apache.flink.agents.api.resource.ResourceType;
@@ -56,4 +57,22 @@ public abstract class BaseChatModelConnection extends
Resource {
*/
public abstract ChatMessage chat(
List<ChatMessage> messages, List<Tool> tools, Map<String, Object>
arguments);
+
+ /**
+ * Record token usage metrics for the given model.
+ *
+ * @param modelName the name of the model used
+ * @param promptTokens the number of prompt tokens
+ * @param completionTokens the number of completion tokens
+ */
+ protected void recordTokenMetrics(String modelName, long promptTokens,
long completionTokens) {
+ FlinkAgentsMetricGroup metricGroup = getMetricGroup();
+ if (metricGroup == null) {
+ return;
+ }
+
+ FlinkAgentsMetricGroup modelGroup = metricGroup.getSubGroup(modelName);
+ modelGroup.getCounter("promptTokens").inc(promptTokens);
+ modelGroup.getCounter("completionTokens").inc(completionTokens);
+ }
}
diff --git
a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java
b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java
index 9e16c07..15aa953 100644
---
a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java
+++
b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java
@@ -60,6 +60,9 @@ public abstract class BaseChatModelSetup extends Resource {
(BaseChatModelConnection)
this.getResource.apply(this.connection,
ResourceType.CHAT_MODEL_CONNECTION);
+ // Pass metric group to connection for token usage tracking
+ connection.setMetricGroup(getMetricGroup());
+
// Format input messages if set prompt.
if (this.prompt != null) {
if (this.prompt instanceof String) {
diff --git
a/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java
b/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java
index 0f22a5f..78b15ad 100644
--- a/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java
+++ b/api/src/main/java/org/apache/flink/agents/api/resource/Resource.java
@@ -18,6 +18,8 @@
package org.apache.flink.agents.api.resource;
+import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
+
import java.util.function.BiFunction;
/**
@@ -28,6 +30,9 @@ import java.util.function.BiFunction;
public abstract class Resource {
protected BiFunction<String, ResourceType, Resource> getResource;
+ /** The metric group bound to this resource, injected by
RunnerContext.getResource(). */
+ private transient FlinkAgentsMetricGroup metricGroup;
+
protected Resource(
ResourceDescriptor descriptor, BiFunction<String, ResourceType,
Resource> getResource) {
this.getResource = getResource;
@@ -41,4 +46,22 @@ public abstract class Resource {
* @return the resource type
*/
public abstract ResourceType getResourceType();
+
+ /**
+ * Set the metric group for this resource.
+ *
+ * @param metricGroup the metric group to bind
+ */
+ public void setMetricGroup(FlinkAgentsMetricGroup metricGroup) {
+ this.metricGroup = metricGroup;
+ }
+
+ /**
+ * Get the bound metric group.
+ *
+ * @return the bound metric group, or null if not set
+ */
+ protected FlinkAgentsMetricGroup getMetricGroup() {
+ return metricGroup;
+ }
}
diff --git
a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java
b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java
new file mode 100644
index 0000000..53c9bc6
--- /dev/null
+++
b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.api.chat.model;
+
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
+import org.apache.flink.agents.api.resource.Resource;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import org.apache.flink.agents.api.resource.ResourceType;
+import org.apache.flink.agents.api.tools.Tool;
+import org.apache.flink.metrics.Counter;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.function.BiFunction;
+
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.Mockito.*;
+
+/** Test cases for BaseChatModelConnection token metrics functionality. */
+class BaseChatModelConnectionTokenMetricsTest {
+
+ private TestChatModelConnection connection;
+ private FlinkAgentsMetricGroup mockMetricGroup;
+ private FlinkAgentsMetricGroup mockModelGroup;
+ private Counter mockPromptTokensCounter;
+ private Counter mockCompletionTokensCounter;
+
+ /** Test implementation of BaseChatModelConnection for testing purposes. */
+ private static class TestChatModelConnection extends
BaseChatModelConnection {
+
+ public TestChatModelConnection(
+ ResourceDescriptor descriptor,
+ BiFunction<String, ResourceType, Resource> getResource) {
+ super(descriptor, getResource);
+ }
+
+ @Override
+ public ChatMessage chat(
+ List<ChatMessage> messages, List<Tool> tools, Map<String,
Object> arguments) {
+ // Simple test implementation
+ return new ChatMessage(MessageRole.ASSISTANT, "Test response");
+ }
+
+ // Expose protected method for testing
+ public void testRecordTokenMetrics(
+ String modelName, long promptTokens, long completionTokens) {
+ recordTokenMetrics(modelName, promptTokens, completionTokens);
+ }
+ }
+
+ @BeforeEach
+ void setUp() {
+ connection =
+ new TestChatModelConnection(
+ new ResourceDescriptor(
+ TestChatModelConnection.class.getName(),
Collections.emptyMap()),
+ null);
+
+ // Create mock objects
+ mockMetricGroup = mock(FlinkAgentsMetricGroup.class);
+ mockModelGroup = mock(FlinkAgentsMetricGroup.class);
+ mockPromptTokensCounter = mock(Counter.class);
+ mockCompletionTokensCounter = mock(Counter.class);
+
+ // Set up mock behavior
+ when(mockMetricGroup.getSubGroup("gpt-4")).thenReturn(mockModelGroup);
+
when(mockModelGroup.getCounter("promptTokens")).thenReturn(mockPromptTokensCounter);
+
when(mockModelGroup.getCounter("completionTokens")).thenReturn(mockCompletionTokensCounter);
+ }
+
+ @Test
+ @DisplayName("Test token metrics are recorded when metric group is set")
+ void testRecordTokenMetricsWithMetricGroup() {
+ // Set the metric group
+ connection.setMetricGroup(mockMetricGroup);
+
+ // Record token metrics
+ connection.testRecordTokenMetrics("gpt-4", 100, 50);
+
+ // Verify the metrics were recorded
+ verify(mockMetricGroup).getSubGroup("gpt-4");
+ verify(mockModelGroup).getCounter("promptTokens");
+ verify(mockModelGroup).getCounter("completionTokens");
+ verify(mockPromptTokensCounter).inc(100);
+ verify(mockCompletionTokensCounter).inc(50);
+ }
+
+ @Test
+ @DisplayName("Test token metrics are not recorded when metric group is
null")
+ void testRecordTokenMetricsWithoutMetricGroup() {
+ // Do not set metric group (should be null by default)
+
+ // Record token metrics - should not throw
+ assertDoesNotThrow(() -> connection.testRecordTokenMetrics("gpt-4",
100, 50));
+
+ // No metrics should be recorded
+ verifyNoInteractions(mockMetricGroup);
+ }
+
+ @Test
+ @DisplayName("Test token metrics hierarchy: actionMetricGroup -> modelName
-> counters")
+ void testTokenMetricsHierarchy() {
+ // Set the metric group
+ connection.setMetricGroup(mockMetricGroup);
+
+ // Record token metrics for different models
+ FlinkAgentsMetricGroup mockGpt35Group =
mock(FlinkAgentsMetricGroup.class);
+ Counter mockGpt35PromptCounter = mock(Counter.class);
+ Counter mockGpt35CompletionCounter = mock(Counter.class);
+
+
when(mockMetricGroup.getSubGroup("gpt-3.5-turbo")).thenReturn(mockGpt35Group);
+
when(mockGpt35Group.getCounter("promptTokens")).thenReturn(mockGpt35PromptCounter);
+
when(mockGpt35Group.getCounter("completionTokens")).thenReturn(mockGpt35CompletionCounter);
+
+ // Record for gpt-4
+ connection.testRecordTokenMetrics("gpt-4", 100, 50);
+
+ // Record for gpt-3.5-turbo
+ connection.testRecordTokenMetrics("gpt-3.5-turbo", 200, 100);
+
+ // Verify each model has its own counters
+ verify(mockMetricGroup).getSubGroup("gpt-4");
+ verify(mockMetricGroup).getSubGroup("gpt-3.5-turbo");
+ verify(mockPromptTokensCounter).inc(100);
+ verify(mockCompletionTokensCounter).inc(50);
+ verify(mockGpt35PromptCounter).inc(200);
+ verify(mockGpt35CompletionCounter).inc(100);
+ }
+
+ @Test
+ @DisplayName("Test resource type is CHAT_MODEL_CONNECTION")
+ void testResourceType() {
+ assertEquals(ResourceType.CHAT_MODEL_CONNECTION,
connection.getResourceType());
+ }
+}
diff --git a/docs/content/docs/operations/monitoring.md
b/docs/content/docs/operations/monitoring.md
index db1c44a..462a7ba 100644
--- a/docs/content/docs/operations/monitoring.md
+++ b/docs/content/docs/operations/monitoring.md
@@ -26,7 +26,9 @@ under the License.
### Built-in Metrics
-We offer data monitoring for built-in metrics, which includes events and
actions.
+We offer data monitoring for built-in metrics, which includes events, actions,
and token usage.
+
+#### Event and Action Metrics
| Scope | Metrics | Description
| Type |
|-------------|--------------------------------------------------|----------------------------------------------------------------------------------|-------|
@@ -37,7 +39,14 @@ We offer data monitoring for built-in metrics, which
includes events and actions
| **Action** | <action_name>.numOfActionsExecuted | The total number of
actions this operator has executed for a specific action name. | Count |
| **Action** | <action_name>.numOfActionsExecutedPerSec | The number of
actions this operator has executed per second for a specific action name. |
Meter |
-####
+#### Token Usage Metrics
+
+Token usage metrics are automatically recorded when chat models are invoked
through `ChatModelConnection`. These metrics help track LLM API usage and costs.
+
+| Scope | Metrics | Description
| Type |
+|-----------|---------------------------------------------|--------------------------------------------------------------------------------|-------|
+| **Model** | <action_name>.<model_name>.promptTokens | The total number
of prompt tokens consumed by the model within an action. | Count |
+| **Model** | <action_name>.<model_name>.completionTokens | The total number
of completion tokens generated by the model within an action. | Count |
### How to add custom metrics
diff --git
a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java
b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java
index ab30c97..49fbef3 100644
---
a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java
+++
b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java
@@ -128,7 +128,22 @@ public class AnthropicChatModelConnection extends
BaseChatModelConnection {
MessageCreateParams params = buildRequest(messages, tools,
arguments);
Message response = client.messages().create(params);
- return convertResponse(response, jsonPrefillApplied);
+ ChatMessage result = convertResponse(response, jsonPrefillApplied);
+
+ // Record token metrics
+ String modelName = null;
+ if (arguments != null && arguments.get("model") != null) {
+ modelName = arguments.get("model").toString();
+ }
+ if (modelName == null || modelName.isBlank()) {
+ modelName = this.defaultModel;
+ }
+ if (modelName != null && !modelName.isBlank()) {
+ recordTokenMetrics(
+ modelName, response.usage().inputTokens(),
response.usage().outputTokens());
+ }
+
+ return result;
} catch (Exception e) {
throw new RuntimeException("Failed to call Anthropic messages
API.", e);
}
diff --git
a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java
b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java
index 182c4c1..f223ee5 100644
---
a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java
+++
b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java
@@ -171,9 +171,10 @@ public class AzureAIChatModelConnection extends
BaseChatModelConnection {
.map(this::convertToChatRequestMessage)
.collect(Collectors.toList());
+ final String modelName = (String) arguments.get("model");
ChatCompletionsOptions options =
new ChatCompletionsOptions(chatMessages)
- .setModel((String) arguments.get("model"))
+ .setModel(modelName)
.setTools(azureTools);
ChatCompletions completions = client.complete(options);
@@ -188,6 +189,15 @@ public class AzureAIChatModelConnection extends
BaseChatModelConnection {
chatMessage.setToolCalls(convertedToolCalls);
}
+ // Record token metrics if model name is available
+ if (modelName != null && !modelName.isBlank()) {
+ CompletionsUsage usage = completions.getUsage();
+ if (usage != null) {
+ recordTokenMetrics(
+ modelName, usage.getPromptTokens(),
usage.getCompletionTokens());
+ }
+ }
+
return chatMessage;
} catch (Exception e) {
throw new RuntimeException(e);
diff --git
a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
index 5faa3ed..b591e0e 100644
---
a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
+++
b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
@@ -189,10 +189,11 @@ public class OllamaChatModelConnection extends
BaseChatModelConnection {
.map(this::convertToOllamaChatMessages)
.collect(Collectors.toList());
+ final String modelName = (String) arguments.get("model");
final OllamaChatRequest chatRequest =
OllamaChatRequest.builder()
.withMessages(ollamaChatMessages)
- .withModel((String) arguments.get("model"))
+ .withModel(modelName)
.withThinking(extractReasoning ? ThinkMode.ENABLED
: ThinkMode.DISABLED)
.withUseTools(false)
.build();
@@ -216,6 +217,16 @@ public class OllamaChatModelConnection extends
BaseChatModelConnection {
chatMessage.setToolCalls(toolCalls);
}
+ // Record token metrics if model name is available
+ if (modelName != null && !modelName.isBlank()) {
+ Integer promptTokens = ollamaChatResponse.getPromptEvalCount();
+ Integer completionTokens = ollamaChatResponse.getEvalCount();
+ if (promptTokens != null && completionTokens != null) {
+ recordTokenMetrics(
+ modelName, promptTokens.longValue(),
completionTokens.longValue());
+ }
+ }
+
return chatMessage;
} catch (Exception e) {
throw new RuntimeException(e);
diff --git
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java
index ff15cf6..2675d42 100644
---
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java
+++
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java
@@ -142,7 +142,23 @@ public class OpenAIChatModelConnection extends
BaseChatModelConnection {
try {
ChatCompletionCreateParams params = buildRequest(messages, tools,
arguments);
ChatCompletion completion =
client.chat().completions().create(params);
- return convertResponse(completion);
+ ChatMessage response = convertResponse(completion);
+
+ // Record token metrics
+ if (completion.usage().isPresent()) {
+ String modelName = arguments != null ? (String)
arguments.get("model") : null;
+ if (modelName == null || modelName.isBlank()) {
+ modelName = this.defaultModel;
+ }
+ if (modelName != null && !modelName.isBlank()) {
+ recordTokenMetrics(
+ modelName,
+ completion.usage().get().promptTokens(),
+ completion.usage().get().completionTokens());
+ }
+ }
+
+ return response;
} catch (Exception e) {
throw new RuntimeException("Failed to call OpenAI chat completions
API.", e);
}
diff --git a/python/flink_agents/api/chat_models/chat_model.py
b/python/flink_agents/api/chat_models/chat_model.py
index 9614edd..fbde478 100644
--- a/python/flink_agents/api/chat_models/chat_model.py
+++ b/python/flink_agents/api/chat_models/chat_model.py
@@ -92,6 +92,28 @@ class BaseChatModelConnection(Resource, ABC):
cleaned = cleaned.strip()
return cleaned, reasoning
+ def _record_token_metrics(
+ self, model_name: str, prompt_tokens: int, completion_tokens: int
+ ) -> None:
+ """Record token usage metrics for the given model.
+
+ Parameters
+ ----------
+ model_name : str
+ The name of the model used
+ prompt_tokens : int
+ The number of prompt tokens
+ completion_tokens : int
+ The number of completion tokens
+ """
+ metric_group = self.metric_group
+ if metric_group is None:
+ return
+
+ model_group = metric_group.get_sub_group(model_name)
+ model_group.get_counter("promptTokens").inc(prompt_tokens)
+ model_group.get_counter("completionTokens").inc(completion_tokens)
+
@abstractmethod
def chat(
self,
@@ -173,6 +195,10 @@ class BaseChatModelSetup(Resource):
self.connection, ResourceType.CHAT_MODEL_CONNECTION
)
+ # Pass metric group to connection for token usage tracking
+ if self.metric_group is not None:
+ connection.set_metric_group(self.metric_group)
+
# Apply prompt template
if self.prompt is not None:
if isinstance(self.prompt, str):
diff --git a/python/flink_agents/api/chat_models/tests/__init__.py
b/python/flink_agents/api/chat_models/tests/__init__.py
new file mode 100644
index 0000000..e154fad
--- /dev/null
+++ b/python/flink_agents/api/chat_models/tests/__init__.py
@@ -0,0 +1,17 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
diff --git a/python/flink_agents/api/chat_models/tests/test_token_metrics.py
b/python/flink_agents/api/chat_models/tests/test_token_metrics.py
new file mode 100644
index 0000000..c51c1e5
--- /dev/null
+++ b/python/flink_agents/api/chat_models/tests/test_token_metrics.py
@@ -0,0 +1,182 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+"""Test cases for BaseChatModelConnection token metrics functionality."""
+from typing import Any, List, Sequence
+from unittest.mock import MagicMock
+
+from flink_agents.api.chat_message import ChatMessage, MessageRole
+from flink_agents.api.chat_models.chat_model import BaseChatModelConnection
+from flink_agents.api.metric_group import Counter, MetricGroup
+from flink_agents.api.resource import ResourceType
+from flink_agents.api.tools.tool import Tool
+
+
+class TestChatModelConnection(BaseChatModelConnection):
+ """Test implementation of BaseChatModelConnection for testing purposes."""
+
+ @classmethod
+ def resource_type(cls) -> ResourceType:
+ """Return resource type of class."""
+ return ResourceType.CHAT_MODEL_CONNECTION
+
+ def chat(
+ self,
+ messages: Sequence[ChatMessage],
+ tools: List[Tool] | None = None,
+ **kwargs: Any,
+ ) -> ChatMessage:
+ """Simple test implementation."""
+ return ChatMessage(role=MessageRole.ASSISTANT, content="Test response")
+
+ def test_record_token_metrics(
+ self, model_name: str, prompt_tokens: int, completion_tokens: int
+ ) -> None:
+ """Expose protected method for testing."""
+ self._record_token_metrics(model_name, prompt_tokens,
completion_tokens)
+
+
+class _MockCounter(Counter):
+ """Mock implementation of Counter for testing."""
+
+ def __init__(self) -> None:
+ self._count = 0
+
+ def inc(self, n: int = 1) -> None:
+ self._count += n
+
+ def dec(self, n: int = 1) -> None:
+ self._count -= n
+
+ def get_count(self) -> int:
+ return self._count
+
+
+class _MockMetricGroup(MetricGroup):
+ """Mock implementation of MetricGroup for testing."""
+
+ def __init__(self) -> None:
+ self._sub_groups: dict[str, _MockMetricGroup] = {}
+ self._counters: dict[str, _MockCounter] = {}
+
+ def get_sub_group(self, name: str) -> "_MockMetricGroup":
+ if name not in self._sub_groups:
+ self._sub_groups[name] = _MockMetricGroup()
+ return self._sub_groups[name]
+
+ def get_counter(self, name: str) -> _MockCounter:
+ if name not in self._counters:
+ self._counters[name] = _MockCounter()
+ return self._counters[name]
+
+ def get_meter(self, name: str) -> Any:
+ return MagicMock()
+
+ def get_gauge(self, name: str) -> Any:
+ return MagicMock()
+
+ def get_histogram(self, name: str, window_size: int = 100) -> Any:
+ return MagicMock()
+
+
+class TestBaseChatModelConnectionTokenMetrics:
+ """Test cases for BaseChatModelConnection token metrics functionality."""
+
+ def test_record_token_metrics_with_metric_group(self) -> None:
+ """Test token metrics are recorded when metric group is set."""
+ connection = TestChatModelConnection()
+ mock_metric_group = _MockMetricGroup()
+
+ # Set the metric group
+ connection.set_metric_group(mock_metric_group)
+
+ # Record token metrics
+ connection.test_record_token_metrics("gpt-4", 100, 50)
+
+ # Verify the metrics were recorded
+ model_group = mock_metric_group.get_sub_group("gpt-4")
+ assert model_group.get_counter("promptTokens").get_count() == 100
+ assert model_group.get_counter("completionTokens").get_count() == 50
+
+ def test_record_token_metrics_without_metric_group(self) -> None:
+ """Test token metrics are not recorded when metric group is null."""
+ connection = TestChatModelConnection()
+
+ # Do not set metric group (should be None by default)
+ # Record token metrics - should not throw
+ connection.test_record_token_metrics("gpt-4", 100, 50)
+ # No exception should be raised
+
+ def test_token_metrics_hierarchy(self) -> None:
+ """Test token metrics hierarchy: actionMetricGroup -> modelName ->
counters."""
+ connection = TestChatModelConnection()
+ mock_metric_group = _MockMetricGroup()
+
+ # Set the metric group
+ connection.set_metric_group(mock_metric_group)
+
+ # Record for gpt-4
+ connection.test_record_token_metrics("gpt-4", 100, 50)
+
+ # Record for gpt-3.5-turbo
+ connection.test_record_token_metrics("gpt-3.5-turbo", 200, 100)
+
+ # Verify each model has its own counters
+ gpt4_group = mock_metric_group.get_sub_group("gpt-4")
+ gpt35_group = mock_metric_group.get_sub_group("gpt-3.5-turbo")
+
+ assert gpt4_group.get_counter("promptTokens").get_count() == 100
+ assert gpt4_group.get_counter("completionTokens").get_count() == 50
+ assert gpt35_group.get_counter("promptTokens").get_count() == 200
+ assert gpt35_group.get_counter("completionTokens").get_count() == 100
+
+ def test_token_metrics_accumulation(self) -> None:
+ """Test that token metrics accumulate across multiple calls."""
+ connection = TestChatModelConnection()
+ mock_metric_group = _MockMetricGroup()
+
+ # Set the metric group
+ connection.set_metric_group(mock_metric_group)
+
+ # Record multiple times for the same model
+ connection.test_record_token_metrics("gpt-4", 100, 50)
+ connection.test_record_token_metrics("gpt-4", 150, 75)
+
+ # Verify the metrics accumulated
+ model_group = mock_metric_group.get_sub_group("gpt-4")
+ assert model_group.get_counter("promptTokens").get_count() == 250
+ assert model_group.get_counter("completionTokens").get_count() == 125
+
+ def test_resource_type(self) -> None:
+ """Test resource type is CHAT_MODEL_CONNECTION."""
+ connection = TestChatModelConnection()
+ assert connection.resource_type() == ResourceType.CHAT_MODEL_CONNECTION
+
+ def test_bound_metric_group_property(self) -> None:
+ """Test bound_metric_group property."""
+ connection = TestChatModelConnection()
+
+ # Initially should be None
+ assert connection.metric_group is None
+
+ # Set metric group
+ mock_metric_group = _MockMetricGroup()
+ connection.set_metric_group(mock_metric_group)
+
+ # Now should return the set metric group
+ assert connection.metric_group is mock_metric_group
+
diff --git a/python/flink_agents/api/resource.py
b/python/flink_agents/api/resource.py
index 056a850..090690e 100644
--- a/python/flink_agents/api/resource.py
+++ b/python/flink_agents/api/resource.py
@@ -18,9 +18,12 @@
import importlib
from abc import ABC, abstractmethod
from enum import Enum
-from typing import Any, Callable, Dict, Type
+from typing import TYPE_CHECKING, Any, Callable, Dict, Type
-from pydantic import BaseModel, Field, model_serializer, model_validator
+from pydantic import BaseModel, Field, PrivateAttr, model_serializer,
model_validator
+
+if TYPE_CHECKING:
+ from flink_agents.api.metric_group import MetricGroup
class ResourceType(Enum):
@@ -58,11 +61,35 @@ class Resource(BaseModel, ABC):
exclude=True, default=None
)
+ # The metric group bound to this resource, injected in
RunnerContext#get_resource
+ _metric_group: "MetricGroup | None" = PrivateAttr(default=None)
+
@classmethod
@abstractmethod
def resource_type(cls) -> ResourceType:
"""Return resource type of class."""
+ def set_metric_group(self, metric_group: "MetricGroup") -> None:
+ """Set the metric group for this resource.
+
+ Parameters
+ ----------
+ metric_group : MetricGroup
+ The metric group to bind.
+ """
+ self._metric_group = metric_group
+
+ @property
+ def metric_group(self) -> "MetricGroup | None":
+ """Get the bound metric group.
+
+ Returns:
+ -------
+ MetricGroup | None
+ The bound metric group, or None if not set.
+ """
+ return self._metric_group
+
class SerializableResource(Resource, ABC):
"""Resource which is serializable."""
diff --git
a/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py
b/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py
index bef1043..297fb5e 100644
---
a/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py
+++
b/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py
@@ -167,6 +167,15 @@ class
AnthropicChatModelConnection(BaseChatModelConnection):
**kwargs,
)
+ # Record token metrics if model name and usage are available
+ model_name = kwargs.get("model")
+ if model_name and message.usage:
+ self._record_token_metrics(
+ model_name,
+ message.usage.input_tokens,
+ message.usage.output_tokens,
+ )
+
if message.stop_reason == "tool_use":
tool_calls = [
{
diff --git a/python/flink_agents/integrations/chat_models/ollama_chat_model.py
b/python/flink_agents/integrations/chat_models/ollama_chat_model.py
index 8549ff3..3eeff8e 100644
--- a/python/flink_agents/integrations/chat_models/ollama_chat_model.py
+++ b/python/flink_agents/integrations/chat_models/ollama_chat_model.py
@@ -95,8 +95,9 @@ class OllamaChatModelConnection(BaseChatModelConnection):
if tools is not None:
ollama_tools = [to_openai_tool(metadata=tool.metadata) for tool in tools]
+ model_name = kwargs.pop("model")
response = self.client.chat(
- model=kwargs.pop("model"),
+ model=model_name,
messages=ollama_messages,
stream=False,
tools=ollama_tools,
@@ -128,6 +129,16 @@ class OllamaChatModelConnection(BaseChatModelConnection):
if reasoning:
extra_args["reasoning"] = reasoning
+ # Record token metrics if model name and usage are available
+ if (
+ model_name
+ and response.prompt_eval_count is not None
+ and response.eval_count is not None
+ ):
+ self._record_token_metrics(
+ model_name, response.prompt_eval_count, response.eval_count
+ )
+
return ChatMessage(
role=MessageRole(response.message.role),
content=content,
diff --git a/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py b/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py
index f12dd39..a00bbae 100644
--- a/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py
+++ b/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py
@@ -171,9 +171,18 @@ class OpenAIChatModelConnection(BaseChatModelConnection):
**kwargs,
)
- response = response.choices[0].message
+ # Record token metrics if model name and usage are available
+ model_name = kwargs.get("model")
+ if model_name and response.usage:
+ self._record_token_metrics(
+ model_name,
+ response.usage.prompt_tokens,
+ response.usage.completion_tokens,
+ )
- return convert_from_openai_message(response)
+ message = response.choices[0].message
+
+ return convert_from_openai_message(message)
DEFAULT_TEMPERATURE = 0.1
diff --git a/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py b/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py
index 3c3f13d..3278f1e 100644
--- a/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py
+++ b/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py
@@ -125,6 +125,7 @@ def test_tongyi_chat_with_extract_reasoning(monkeypatch: pytest.MonkeyPatch) ->
}
]
},
+ usage=SimpleNamespace(input_tokens=100, output_tokens=50),
)
mock_call = MagicMock(return_value=mocked_response)
diff --git a/python/flink_agents/integrations/chat_models/tongyi_chat_model.py b/python/flink_agents/integrations/chat_models/tongyi_chat_model.py
index 0148603..164eb8a 100644
--- a/python/flink_agents/integrations/chat_models/tongyi_chat_model.py
+++ b/python/flink_agents/integrations/chat_models/tongyi_chat_model.py
@@ -113,8 +113,9 @@ class TongyiChatModelConnection(BaseChatModelConnection):
req_api_key = kwargs.pop("api_key", self.api_key)
+ model_name = kwargs.pop("model", DEFAULT_MODEL)
response = Generation.call(
- model=kwargs.pop("model", DEFAULT_MODEL),
+ model=model_name,
messages=tongyi_messages,
tools=tongyi_tools,
result_format="message",
@@ -123,10 +124,16 @@ class TongyiChatModelConnection(BaseChatModelConnection):
**kwargs,
)
- if getattr(response, "status_code", 200) != 200:
- msg = f"DashScope call failed: {getattr(response, 'message', 'unknown error')}"
+ if response.status_code != 200:
+ msg = f"DashScope call failed: {response.message}"
raise RuntimeError(msg)
+ # Record token metrics if model name and usage are available
+ if model_name and response.usage:
+ self._record_token_metrics(
+ model_name, response.usage.input_tokens, response.usage.output_tokens
+ )
+
choice = response.output["choices"][0]
response_message: Dict[str, Any] = choice["message"]
diff --git a/python/flink_agents/runtime/flink_runner_context.py b/python/flink_agents/runtime/flink_runner_context.py
index 75ba9c2..c63fc7c 100644
--- a/python/flink_agents/runtime/flink_runner_context.py
+++ b/python/flink_agents/runtime/flink_runner_context.py
@@ -98,7 +98,10 @@ class FlinkRunnerContext(RunnerContext):
@override
def get_resource(self, name: str, type: ResourceType) -> Resource:
- return self.__agent_plan.get_resource(name, type)
+ resource = self.__agent_plan.get_resource(name, type)
+ # Bind current action's metric group to the resource
+ resource.set_metric_group(self.action_metric_group)
+ return resource
@property
@override
diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
index 743b4aa..4d534bd 100644
--- a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
@@ -181,7 +181,10 @@ public class RunnerContextImpl implements RunnerContext {
if (agentPlan == null) {
throw new IllegalStateException("AgentPlan is not available in this context");
}
- return agentPlan.getResource(name, type);
+ Resource resource = agentPlan.getResource(name, type);
+ // Set current action's metric group to the resource
+ resource.setMetricGroup(getActionMetricGroup());
+ return resource;
}
@Override