This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 9474216d328be9949855dd70384a1f7974bd1385 Author: WenjinXie <[email protected]> AuthorDate: Fri Jan 16 15:31:33 2026 +0800 [metrics] Report token metrics in the chat action instead of the chat model connection, to avoid a crash caused by reporting from an async thread. --- python/flink_agents/api/chat_models/chat_model.py | 48 +++++----- .../api/chat_models/tests/test_token_metrics.py | 67 +++++++------- .../chat_models/anthropic/anthropic_chat_model.py | 103 +++++++++++++-------- .../integrations/chat_models/ollama_chat_model.py | 6 +- .../chat_models/openai/openai_chat_model.py | 11 +-- .../chat_models/openai/openai_utils.py | 12 ++- .../integrations/chat_models/tongyi_chat_model.py | 12 ++- .../flink_agents/plan/actions/chat_model_action.py | 2 + 8 files changed, 143 insertions(+), 118 deletions(-) diff --git a/python/flink_agents/api/chat_models/chat_model.py b/python/flink_agents/api/chat_models/chat_model.py index fbde4786..2f63566c 100644 --- a/python/flink_agents/api/chat_models/chat_model.py +++ b/python/flink_agents/api/chat_models/chat_model.py @@ -92,28 +92,6 @@ class BaseChatModelConnection(Resource, ABC): cleaned = cleaned.strip() return cleaned, reasoning - def _record_token_metrics( - self, model_name: str, prompt_tokens: int, completion_tokens: int - ) -> None: - """Record token usage metrics for the given model.
- - Parameters - ---------- - model_name : str - The name of the model used - prompt_tokens : int - The number of prompt tokens - completion_tokens : int - The number of completion tokens - """ - metric_group = self.metric_group - if metric_group is None: - return - - model_group = metric_group.get_sub_group(model_name) - model_group.get_counter("promptTokens").inc(prompt_tokens) - model_group.get_counter("completionTokens").inc(completion_tokens) - @abstractmethod def chat( self, @@ -195,10 +173,6 @@ class BaseChatModelSetup(Resource): self.connection, ResourceType.CHAT_MODEL_CONNECTION ) - # Pass metric group to connection for token usage tracking - if self.metric_group is not None: - connection.set_metric_group(self.metric_group) - # Apply prompt template if self.prompt is not None: if isinstance(self.prompt, str): @@ -233,3 +207,25 @@ class BaseChatModelSetup(Resource): merged_kwargs = self.model_kwargs.copy() merged_kwargs.update(kwargs) return connection.chat(messages, tools=tools, **merged_kwargs) + + def _record_token_metrics( + self, model_name: str, prompt_tokens: int, completion_tokens: int + ) -> None: + """Record token usage metrics for the given model. 
+ + Parameters + ---------- + model_name : str + The name of the model used + prompt_tokens : int + The number of prompt tokens + completion_tokens : int + The number of completion tokens + """ + metric_group = self.metric_group + if metric_group is None: + return + + model_group = metric_group.get_sub_group(model_name) + model_group.get_counter("promptTokens").inc(prompt_tokens) + model_group.get_counter("completionTokens").inc(completion_tokens) diff --git a/python/flink_agents/api/chat_models/tests/test_token_metrics.py b/python/flink_agents/api/chat_models/tests/test_token_metrics.py index c51c1e5c..982565ab 100644 --- a/python/flink_agents/api/chat_models/tests/test_token_metrics.py +++ b/python/flink_agents/api/chat_models/tests/test_token_metrics.py @@ -16,30 +16,32 @@ # limitations under the License. ################################################################################# """Test cases for BaseChatModelConnection token metrics functionality.""" -from typing import Any, List, Sequence + +from typing import Any, Dict, Sequence from unittest.mock import MagicMock from flink_agents.api.chat_message import ChatMessage, MessageRole -from flink_agents.api.chat_models.chat_model import BaseChatModelConnection +from flink_agents.api.chat_models.chat_model import ( + BaseChatModelSetup, +) from flink_agents.api.metric_group import Counter, MetricGroup from flink_agents.api.resource import ResourceType -from flink_agents.api.tools.tool import Tool -class TestChatModelConnection(BaseChatModelConnection): +class TestChatModelSetup(BaseChatModelSetup): """Test implementation of BaseChatModelConnection for testing purposes.""" + @property + def model_kwargs(self) -> Dict[str, Any]: + """Return model kwargs.""" + return {} + @classmethod def resource_type(cls) -> ResourceType: """Return resource type of class.""" - return ResourceType.CHAT_MODEL_CONNECTION - - def chat( - self, - messages: Sequence[ChatMessage], - tools: List[Tool] | None = None, - **kwargs: Any, 
- ) -> ChatMessage: + return ResourceType.CHAT_MODEL + + def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: """Simple test implementation.""" return ChatMessage(role=MessageRole.ASSISTANT, content="Test response") @@ -93,19 +95,19 @@ class _MockMetricGroup(MetricGroup): return MagicMock() -class TestBaseChatModelConnectionTokenMetrics: +class TestBaseChatModelTokenMetrics: """Test cases for BaseChatModelConnection token metrics functionality.""" def test_record_token_metrics_with_metric_group(self) -> None: """Test token metrics are recorded when metric group is set.""" - connection = TestChatModelConnection() + chat_model = TestChatModelSetup(connection="mock") mock_metric_group = _MockMetricGroup() # Set the metric group - connection.set_metric_group(mock_metric_group) + chat_model.set_metric_group(mock_metric_group) # Record token metrics - connection.test_record_token_metrics("gpt-4", 100, 50) + chat_model.test_record_token_metrics("gpt-4", 100, 50) # Verify the metrics were recorded model_group = mock_metric_group.get_sub_group("gpt-4") @@ -114,26 +116,26 @@ class TestBaseChatModelConnectionTokenMetrics: def test_record_token_metrics_without_metric_group(self) -> None: """Test token metrics are not recorded when metric group is null.""" - connection = TestChatModelConnection() + chat_model = TestChatModelSetup(connection="mock") # Do not set metric group (should be None by default) # Record token metrics - should not throw - connection.test_record_token_metrics("gpt-4", 100, 50) + chat_model.test_record_token_metrics("gpt-4", 100, 50) # No exception should be raised def test_token_metrics_hierarchy(self) -> None: """Test token metrics hierarchy: actionMetricGroup -> modelName -> counters.""" - connection = TestChatModelConnection() + chat_model = TestChatModelSetup(connection="mock") mock_metric_group = _MockMetricGroup() # Set the metric group - connection.set_metric_group(mock_metric_group) + 
chat_model.set_metric_group(mock_metric_group) # Record for gpt-4 - connection.test_record_token_metrics("gpt-4", 100, 50) + chat_model.test_record_token_metrics("gpt-4", 100, 50) # Record for gpt-3.5-turbo - connection.test_record_token_metrics("gpt-3.5-turbo", 200, 100) + chat_model.test_record_token_metrics("gpt-3.5-turbo", 200, 100) # Verify each model has its own counters gpt4_group = mock_metric_group.get_sub_group("gpt-4") @@ -146,15 +148,15 @@ class TestBaseChatModelConnectionTokenMetrics: def test_token_metrics_accumulation(self) -> None: """Test that token metrics accumulate across multiple calls.""" - connection = TestChatModelConnection() + chat_model = TestChatModelSetup(connection="mock") mock_metric_group = _MockMetricGroup() # Set the metric group - connection.set_metric_group(mock_metric_group) + chat_model.set_metric_group(mock_metric_group) # Record multiple times for the same model - connection.test_record_token_metrics("gpt-4", 100, 50) - connection.test_record_token_metrics("gpt-4", 150, 75) + chat_model.test_record_token_metrics("gpt-4", 100, 50) + chat_model.test_record_token_metrics("gpt-4", 150, 75) # Verify the metrics accumulated model_group = mock_metric_group.get_sub_group("gpt-4") @@ -163,20 +165,19 @@ class TestBaseChatModelConnectionTokenMetrics: def test_resource_type(self) -> None: """Test resource type is CHAT_MODEL_CONNECTION.""" - connection = TestChatModelConnection() - assert connection.resource_type() == ResourceType.CHAT_MODEL_CONNECTION + chat_model = TestChatModelSetup(connection="mock") + assert chat_model.resource_type() == ResourceType.CHAT_MODEL def test_bound_metric_group_property(self) -> None: """Test bound_metric_group property.""" - connection = TestChatModelConnection() + chat_model = TestChatModelSetup(connection="mock") # Initially should be None - assert connection.metric_group is None + assert chat_model.metric_group is None # Set metric group mock_metric_group = _MockMetricGroup() - 
connection.set_metric_group(mock_metric_group) + chat_model.set_metric_group(mock_metric_group) # Now should return the set metric group - assert connection.metric_group is mock_metric_group - + assert chat_model.metric_group is mock_metric_group diff --git a/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py b/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py index e0d1233a..171fba24 100644 --- a/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py +++ b/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py @@ -32,7 +32,9 @@ from flink_agents.api.chat_models.chat_model import ( from flink_agents.api.tools.tool import Tool, ToolMetadata -def to_anthropic_tool(*, metadata: ToolMetadata, skip_length_check: bool = False) -> ToolParam: +def to_anthropic_tool( + *, metadata: ToolMetadata, skip_length_check: bool = False +) -> ToolParam: """Convert to Anthropic tool: https://docs.anthropic.com/en/api/messages#body-tools.""" if not skip_length_check and len(metadata.description) > 1024: msg = ( @@ -43,7 +45,7 @@ def to_anthropic_tool(*, metadata: ToolMetadata, skip_length_check: bool = False return { "name": metadata.name, "description": metadata.description, - "input_schema": metadata.get_parameters_dict() + "input_schema": metadata.get_parameters_dict(), } @@ -63,7 +65,11 @@ def convert_to_anthropic_message(message: ChatMessage) -> MessageParam: elif message.role == MessageRole.ASSISTANT: # Use original Anthropic content blocks if available for context anthropic_content_blocks = message.extra_args.get("anthropic_content_blocks") - content = anthropic_content_blocks if anthropic_content_blocks is not None else message.content + content = ( + anthropic_content_blocks + if anthropic_content_blocks is not None + else message.content + ) return { "role": message.role.value, "content": content, # type: ignore @@ -75,27 +81,32 @@ def convert_to_anthropic_message(message: 
ChatMessage) -> MessageParam: } -def convert_to_anthropic_messages(messages: Sequence[ChatMessage]) -> List[MessageParam]: +def convert_to_anthropic_messages( + messages: Sequence[ChatMessage], +) -> List[MessageParam]: """Convert user/assistant messages to Anthropic input messages. See: https://docs.anthropic.com/en/api/messages#body-messages """ - return [convert_to_anthropic_message(message) for message in messages if - message.role in [MessageRole.USER, MessageRole.ASSISTANT, MessageRole.TOOL]] + return [ + convert_to_anthropic_message(message) + for message in messages + if message.role in [MessageRole.USER, MessageRole.ASSISTANT, MessageRole.TOOL] + ] -def convert_to_anthropic_system_prompts(messages: Sequence[ChatMessage]) -> List[TextBlockParam]: +def convert_to_anthropic_system_prompts( + messages: Sequence[ChatMessage], +) -> List[TextBlockParam]: """Convert system messages to Anthropic system prompts. See: https://docs.anthropic.com/en/api/messages#body-system """ - system_messages = [message for message in messages if message.role == MessageRole.SYSTEM] + system_messages = [ + message for message in messages if message.role == MessageRole.SYSTEM + ] return [ - TextBlockParam( - type="text", - text=message.content - ) - for message in system_messages + TextBlockParam(type="text", text=message.content) for message in system_messages ] @@ -128,11 +139,11 @@ class AnthropicChatModelConnection(BaseChatModelConnection): ) def __init__( - self, - api_key: str | None = None, - max_retries: int = 3, - timeout: float = 60.0, - **kwargs: Any, + self, + api_key: str | None = None, + max_retries: int = 3, + timeout: float = 60.0, + **kwargs: Any, ) -> None: """Initialize the Anthropic chat model connection.""" super().__init__( @@ -148,15 +159,23 @@ class AnthropicChatModelConnection(BaseChatModelConnection): def client(self) -> Anthropic: """Get or create the Anthropic client instance.""" if self._client is None: - self._client = Anthropic(api_key=self.api_key, 
max_retries=self.max_retries, timeout=self.timeout) + self._client = Anthropic( + api_key=self.api_key, max_retries=self.max_retries, timeout=self.timeout + ) return self._client - def chat(self, messages: Sequence[ChatMessage], tools: List[Tool] | None = None, - **kwargs: Any) -> ChatMessage: + def chat( + self, + messages: Sequence[ChatMessage], + tools: List[Tool] | None = None, + **kwargs: Any, + ) -> ChatMessage: """Direct communication with Anthropic model service for chat conversation.""" anthropic_tools = None if tools is not None: - anthropic_tools = [to_anthropic_tool(metadata=tool.metadata) for tool in tools] + anthropic_tools = [ + to_anthropic_tool(metadata=tool.metadata) for tool in tools + ] anthropic_system = convert_to_anthropic_system_prompts(messages) anthropic_messages = convert_to_anthropic_messages(messages) @@ -168,14 +187,13 @@ class AnthropicChatModelConnection(BaseChatModelConnection): **kwargs, ) + extra_args = {} # Record token metrics if model name and usage are available model_name = kwargs.get("model") if model_name and message.usage: - self._record_token_metrics( - model_name, - message.usage.input_tokens, - message.usage.output_tokens, - ) + extra_args["model_name"] = model_name + extra_args["promptTokens"] = message.usage.input_tokens + extra_args["completionTokens"] = message.usage.output_tokens if message.stop_reason == "tool_use": tool_calls = [ @@ -189,17 +207,15 @@ class AnthropicChatModelConnection(BaseChatModelConnection): "original_id": content_block.id, } for content_block in message.content - if content_block.type == 'tool_use' + if content_block.type == "tool_use" ] + extra_args["anthropic_content_blocks"] = message.content return ChatMessage( role=MessageRole(message.role), content=message.content[0].text, tool_calls=tool_calls, - extra_args={ - "anthropic_content_blocks": message.content - } - + extra_args=extra_args, ) else: # TODO: handle other stop_reason values according to Anthropic API: @@ -243,8 +259,9 @@ class 
AnthropicChatModelSetup(BaseChatModelSetup): """ model: str = Field( - default=DEFAULT_ANTHROPIC_MODEL, description="Specifies the Anthropic model to use. Defaults to " - "claude-sonnet-4-20250514." + default=DEFAULT_ANTHROPIC_MODEL, + description="Specifies the Anthropic model to use. Defaults to " + "claude-sonnet-4-20250514.", ) max_tokens: int = Field( default=DEFAULT_MAX_TOKENS, @@ -259,12 +276,12 @@ class AnthropicChatModelSetup(BaseChatModelSetup): ) def __init__( - self, - connection: str, - model: str = DEFAULT_ANTHROPIC_MODEL, - max_tokens: int = DEFAULT_MAX_TOKENS, - temperature: float = DEFAULT_TEMPERATURE, - **kwargs: Any, + self, + connection: str, + model: str = DEFAULT_ANTHROPIC_MODEL, + max_tokens: int = DEFAULT_MAX_TOKENS, + temperature: float = DEFAULT_TEMPERATURE, + **kwargs: Any, ) -> None: """Init method.""" super().__init__( @@ -278,4 +295,8 @@ class AnthropicChatModelSetup(BaseChatModelSetup): @property def model_kwargs(self) -> Dict[str, Any]: """Get model-specific keyword arguments.""" - return {"model": self.model, "max_tokens": self.max_tokens, "temperature": self.temperature} + return { + "model": self.model, + "max_tokens": self.max_tokens, + "temperature": self.temperature, + } diff --git a/python/flink_agents/integrations/chat_models/ollama_chat_model.py b/python/flink_agents/integrations/chat_models/ollama_chat_model.py index 3eeff8ee..4bbc74b6 100644 --- a/python/flink_agents/integrations/chat_models/ollama_chat_model.py +++ b/python/flink_agents/integrations/chat_models/ollama_chat_model.py @@ -135,9 +135,9 @@ class OllamaChatModelConnection(BaseChatModelConnection): and response.prompt_eval_count is not None and response.eval_count is not None ): - self._record_token_metrics( - model_name, response.prompt_eval_count, response.eval_count - ) + extra_args["model_name"] = model_name + extra_args["promptTokens"] = response.prompt_eval_count + extra_args["completionTokens"] = response.eval_count return ChatMessage( 
role=MessageRole(response.message.role), diff --git a/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py b/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py index b5e3297a..95edb298 100644 --- a/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py +++ b/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py @@ -172,18 +172,17 @@ class OpenAIChatModelConnection(BaseChatModelConnection): **kwargs, ) + extra_args = {} # Record token metrics if model name and usage are available model_name = kwargs.get("model") if model_name and response.usage: - self._record_token_metrics( - model_name, - response.usage.prompt_tokens, - response.usage.completion_tokens, - ) + extra_args["model_name"] = model_name + extra_args["promptTokens"] = response.usage.prompt_tokens + extra_args["completionTokens"] = response.usage.completion_tokens message = response.choices[0].message - return convert_from_openai_message(message) + return convert_from_openai_message(message, extra_args) @override def close(self) -> None: diff --git a/python/flink_agents/integrations/chat_models/openai/openai_utils.py b/python/flink_agents/integrations/chat_models/openai/openai_utils.py index 712787ab..16f49e84 100644 --- a/python/flink_agents/integrations/chat_models/openai/openai_utils.py +++ b/python/flink_agents/integrations/chat_models/openai/openai_utils.py @@ -18,7 +18,7 @@ import json import os import uuid -from typing import TYPE_CHECKING, List, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple import openai from openai.types.chat import ( @@ -167,7 +167,8 @@ def convert_to_openai_message(message: ChatMessage) -> ChatCompletionMessagePara } if message.tool_calls: openai_tool_calls = [ - _convert_to_openai_tool_call(tool_call) for tool_call in message.tool_calls + _convert_to_openai_tool_call(tool_call) + for tool_call in message.tool_calls ] assistant_message["tool_calls"] = 
openai_tool_calls @@ -192,7 +193,9 @@ def convert_to_openai_message(message: ChatMessage) -> ChatCompletionMessagePara raise ValueError(msg) -def convert_from_openai_message(message: ChatCompletionMessage) -> ChatMessage: +def convert_from_openai_message( + message: ChatCompletionMessage, extra_args: Dict[str, Any] +) -> ChatMessage: """Convert an OpenAI message to a chat message.""" tool_calls = [] if message.tool_calls: @@ -207,7 +210,7 @@ def convert_from_openai_message(message: ChatCompletionMessage) -> ChatMessage: "name": tool_call.function.name, "arguments": json.loads(tool_call.function.arguments), }, - "original_id": tool_call.id + "original_id": tool_call.id, } for tool_call in message.tool_calls ] @@ -215,4 +218,5 @@ def convert_from_openai_message(message: ChatCompletionMessage) -> ChatMessage: role=MessageRole(message.role), content=message.content or "", tool_calls=tool_calls, + extra_args=extra_args, ) diff --git a/python/flink_agents/integrations/chat_models/tongyi_chat_model.py b/python/flink_agents/integrations/chat_models/tongyi_chat_model.py index 164eb8ad..7d37f3fc 100644 --- a/python/flink_agents/integrations/chat_models/tongyi_chat_model.py +++ b/python/flink_agents/integrations/chat_models/tongyi_chat_model.py @@ -36,7 +36,8 @@ DEFAULT_MODEL = "qwen-plus" def to_dashscope_tool( - metadata: ToolMetadata, skip_length_check: bool = False # noqa:FBT001 + metadata: ToolMetadata, + skip_length_check: bool = False, # noqa:FBT001 ) -> Dict[str, Any]: """To DashScope tool.""" if not skip_length_check and len(metadata.description) > 1024: @@ -128,11 +129,13 @@ class TongyiChatModelConnection(BaseChatModelConnection): msg = f"DashScope call failed: {response.message}" raise RuntimeError(msg) + extra_args: Dict[str, Any] = {} + # Record token metrics if model name and usage are available if model_name and response.usage: - self._record_token_metrics( - model_name, response.usage.input_tokens, response.usage.output_tokens - ) + extra_args["model_name"] = 
model_name + extra_args["promptTokens"] = response.usage.input_tokens + extra_args["completionTokens"] = response.usage.output_tokens choice = response.output["choices"][0] response_message: Dict[str, Any] = choice["message"] @@ -156,7 +159,6 @@ class TongyiChatModelConnection(BaseChatModelConnection): tool_calls.append(tool_call_dict) content = response_message.get("content") or "" - extra_args: Dict[str, Any] = {} reasoning_content = response_message.get("reasoning_content") or "" if extract_reasoning and reasoning_content: diff --git a/python/flink_agents/plan/actions/chat_model_action.py b/python/flink_agents/plan/actions/chat_model_action.py index 7413fa21..d97acaf1 100644 --- a/python/flink_agents/plan/actions/chat_model_action.py +++ b/python/flink_agents/plan/actions/chat_model_action.py @@ -183,6 +183,8 @@ def chat( for attempt in range(num_retries + 1): try: response = chat_model.chat(messages) + if response.extra_args.get("model_name") and response.extra_args.get("promptTokens") and response.extra_args.get("completionTokens"): + chat_model._record_token_metrics(response.extra_args["model_name"], response.extra_args["promptTokens"], response.extra_args["completionTokens"]) if output_schema is not None and len(response.tool_calls) == 0: response = _generate_structured_output(response, output_schema) break
