This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git

commit 551994711be8ddc663aeeb50fc68ffff3ab209f7
Author: WenjinXie <[email protected]>
AuthorDate: Sat Jan 17 02:10:43 2026 +0800

    [runtime] Report token usage metrics from long-term memory compaction.
---
 python/flink_agents/api/memory/long_term_memory.py |  4 ++
 python/flink_agents/api/runner_context.py          |  5 +-
 .../flink_agents/runtime/flink_runner_context.py   | 47 +++++++++------
 python/flink_agents/runtime/local_runner.py        |  2 +-
 .../runtime/memory/compaction_functions.py         | 16 +++--
 .../memory/internal_base_long_term_memory.py       | 36 ++++++++++++
 .../tests/test_vector_store_long_term_memory.py    |  4 ++
 .../memory/vector_store_long_term_memory.py        | 68 +++++++++++++++++++---
 .../runtime/python/utils/PythonActionExecutor.java |  8 +--
 9 files changed, 151 insertions(+), 39 deletions(-)

diff --git a/python/flink_agents/api/memory/long_term_memory.py b/python/flink_agents/api/memory/long_term_memory.py
index da673c81..bedc5741 100644
--- a/python/flink_agents/api/memory/long_term_memory.py
+++ b/python/flink_agents/api/memory/long_term_memory.py
@@ -340,3 +340,7 @@ class BaseLongTermMemory(ABC, BaseModel):
         Returns:
             Related memory items retrieved.
         """
+
+    @abstractmethod
+    def close(self) -> None:
+        """Release any resources held by this memory when the job closes."""
diff --git a/python/flink_agents/api/runner_context.py b/python/flink_agents/api/runner_context.py
index 8b44efeb..4e3bc303 100644
--- a/python/flink_agents/api/runner_context.py
+++ b/python/flink_agents/api/runner_context.py
@@ -90,7 +90,9 @@ class RunnerContext(ABC):
         """
 
     @abstractmethod
-    def get_resource(self, name: str, type: ResourceType) -> Resource:
+    def get_resource(
+        self, name: str, type: ResourceType, metric_group: MetricGroup | None = None
+    ) -> Resource:
         """Get resource from context.
 
         Parameters
@@ -99,6 +101,9 @@ class RunnerContext(ABC):
             The name of the resource.
         type : ResourceType
             The type of the resource.
+        metric_group : MetricGroup, optional
+            The metric group the resource reports its metrics to. If not
+            provided, the metric group of the current action is used.
         """
 
     @property
diff --git a/python/flink_agents/runtime/flink_runner_context.py b/python/flink_agents/runtime/flink_runner_context.py
index dffdaf73..994f095e 100644
--- a/python/flink_agents/runtime/flink_runner_context.py
+++ b/python/flink_agents/runtime/flink_runner_context.py
@@ -32,11 +32,15 @@ from flink_agents.api.memory.long_term_memory import (
     LongTermMemoryOptions,
 )
 from flink_agents.api.memory_object import MemoryType
+from flink_agents.api.metric_group import MetricGroup
 from flink_agents.api.resource import Resource, ResourceType
 from flink_agents.api.runner_context import AsyncExecutionResult, RunnerContext
 from flink_agents.plan.agent_plan import AgentPlan
 from flink_agents.runtime.flink_memory_object import FlinkMemoryObject
 from flink_agents.runtime.flink_metric_group import FlinkMetricGroup
+from flink_agents.runtime.memory.internal_base_long_term_memory import (
+    InternalBaseLongTermMemory,
+)
 from flink_agents.runtime.memory.vector_store_long_term_memory import (
     VectorStoreLongTermMemory,
 )
@@ -174,7 +178,7 @@ class FlinkRunnerContext(RunnerContext):
     """
 
     __agent_plan: AgentPlan | None
-    __ltm: BaseLongTermMemory = None
+    __ltm: InternalBaseLongTermMemory | None = None
 
     def __init__(
         self,
@@ -195,7 +199,7 @@ class FlinkRunnerContext(RunnerContext):
         self.__agent_plan.set_java_resource_adapter(j_resource_adapter)
         self.executor = executor
 
-    def set_long_term_memory(self, ltm: BaseLongTermMemory) -> None:
+    def set_long_term_memory(self, ltm: InternalBaseLongTermMemory) -> None:
         """Set long term memory instance to this context.
 
         Parameters
@@ -224,10 +228,15 @@ class FlinkRunnerContext(RunnerContext):
             raise RuntimeError(err_msg) from e
 
     @override
-    def get_resource(self, name: str, type: ResourceType) -> Resource:
+    def get_resource(
+        self, name: str, type: ResourceType, metric_group: MetricGroup | None = None
+    ) -> Resource:
         resource = self.__agent_plan.get_resource(name, type)
-        # Bind current action's metric group to the resource
-        resource.set_metric_group(self.action_metric_group)
+        # Bind the caller-provided metric group, falling back to the current
+        # action's metric group when none is given.
+        if metric_group is None:
+            metric_group = self.action_metric_group
+        resource.set_metric_group(metric_group)
         return resource
 
     @property
@@ -488,6 +497,9 @@ class FlinkRunnerContext(RunnerContext):
 
     @override
     def close(self) -> None:
+        if self.long_term_memory is not None:
+            self.long_term_memory.close()
+
         if self.__agent_plan is not None:
             try:
                 self.__agent_plan.close()
@@ -500,23 +512,13 @@ def create_flink_runner_context(
     agent_plan_json: str,
     executor: ThreadPoolExecutor,
     j_resource_adapter: Any,
+    job_identifier: str,
 ) -> FlinkRunnerContext:
     """Used to create a FlinkRunnerContext Python object in Pemja environment."""
-    return FlinkRunnerContext(
+    ctx = FlinkRunnerContext(
         j_runner_context, agent_plan_json, executor, j_resource_adapter
     )
 
-
-def flink_runner_context_switch_action_context(
-    ctx: FlinkRunnerContext,
-    job_identifier: str,
-    key: int,
-) -> None:
-    """Switch the context of the flink runner context.
-
-    The ctx is reused across keyed partitions, the context related to
-    specific key should be switched when process new action.
-    """
     backend = ctx.config.get(LongTermMemoryOptions.BACKEND)
     # use external vector store based long term memory
     if backend == LongTermMemoryBackend.EXTERNAL_VECTOR_STORE:
@@ -528,10 +530,25 @@ def flink_runner_context_switch_action_context(
                 ctx=ctx,
                 vector_store=vector_store_name,
                 job_id=job_identifier,
-                key=str(key),
             )
         )
 
+    return ctx
+
+
+def flink_runner_context_switch_action_context(
+    ctx: FlinkRunnerContext,
+    key: int,
+) -> None:
+    """Switch the key-specific context of the flink runner context.
+
+    The ctx is reused across keyed partitions, so the context bound to a
+    specific key must be switched before processing a new action.
+    """
+    if ctx.long_term_memory is not None:
+        ctx.long_term_memory.switch_context(str(key))
+
+
 def close_flink_runner_context(
     ctx: FlinkRunnerContext,
 ) -> None:
diff --git a/python/flink_agents/runtime/local_runner.py b/python/flink_agents/runtime/local_runner.py
index 64ef4cc6..b8eaf3f0 100644
--- a/python/flink_agents/runtime/local_runner.py
+++ b/python/flink_agents/runtime/local_runner.py
@@ -119,7 +119,10 @@ class LocalRunnerContext(RunnerContext):
         self.events.append(event)
 
     @override
-    def get_resource(self, name: str, type: ResourceType) -> Resource:
+    def get_resource(
+        self, name: str, type: ResourceType, metric_group: MetricGroup | None = None
+    ) -> Resource:
+        # metric_group is accepted for interface compatibility but not used here.
         return self.__agent_plan.get_resource(name, type)
 
     @property
diff --git a/python/flink_agents/runtime/memory/compaction_functions.py b/python/flink_agents/runtime/memory/compaction_functions.py
index 2bfeed12..669df4d5 100644
--- a/python/flink_agents/runtime/memory/compaction_functions.py
+++ b/python/flink_agents/runtime/memory/compaction_functions.py
@@ -17,7 +17,7 @@
 #################################################################################
 import json
 import logging
-from typing import TYPE_CHECKING, List, Type, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Type, cast
 
 from flink_agents.api.chat_message import ChatMessage, MessageRole
 from flink_agents.api.memory.long_term_memory import (
@@ -26,6 +26,7 @@ from flink_agents.api.memory.long_term_memory import (
     MemorySetItem,
     SummarizationStrategy,
 )
+from flink_agents.api.metric_group import MetricGroup
 from flink_agents.api.prompts.prompt import Prompt
 from flink_agents.api.resource import ResourceType
 from flink_agents.api.runner_context import RunnerContext
@@ -62,8 +63,9 @@ def summarize(
     ltm: BaseLongTermMemory,
     memory_set: MemorySet,
     ctx: RunnerContext,
     ids: List[str] | None = None,
-) -> None:
+    metric_group: MetricGroup | None = None,
+) -> Dict[str, Any]:
     """Generate summarization of the items in the memory set.
 
     Will add the summarization to memory set, and delete original items involved
@@ -73,6 +75,12 @@ def summarize(
         ltm: The long term memory the memory set belongs to.
         memory_set: The memory set to be summarized.
         ctx: The runner context used to retrieve needed resources.
         ids: The ids of items to be summarized. If not provided, all items will be
         involved in summarization. Optional
+        metric_group: Metric group the summarization resources report metrics
+        to. If not provided, the runner context chooses the metric group. Optional
+
+    Returns:
+        The ``extra_args`` of the chat model response that produced the
+        summarization, e.g. model name and token usage.
     """
@@ -84,7 +92,7 @@ def summarize(
     items: List[MemorySetItem] = ltm.get(memory_set=memory_set, ids=ids)
 
     response: ChatMessage = _generate_summarization(
-        items, memory_set.item_type, strategy, ctx
+        items, memory_set.item_type, strategy, ctx, metric_group
     )
 
     logging.debug(f"Items to be summarized: {items}\nSummarization: {response.content}")
@@ -131,6 +139,8 @@ def summarize(
             },
         )
 
+    return response.extra_args
+
 
 # TODO: Currently, we feed all items to the LLM at once, which may exceed the LLM's
 # context window. We need to support batched summary generation.
@@ -139,6 +149,7 @@ def _generate_summarization(
     item_type: Type,
     strategy: SummarizationStrategy,
     ctx: RunnerContext,
+    metric_group: MetricGroup | None = None,
 ) -> ChatMessage:
     """Generate summarization of the items by llm."""
     # get arguments
@@ -157,7 +168,9 @@ def _generate_summarization(
     # generate summary
     model: BaseChatModelSetup = cast(
         "BaseChatModelSetup",
-        ctx.get_resource(name=model_name, type=ResourceType.CHAT_MODEL),
+        ctx.get_resource(
+            name=model_name, type=ResourceType.CHAT_MODEL, metric_group=metric_group
+        ),
     )
     input_variable = {}
     for msg in msgs:
@@ -167,7 +180,9 @@ def _generate_summarization(
         if isinstance(prompt, str):
             prompt: Prompt = cast(
                 "Prompt",
-                ctx.get_resource(prompt, ResourceType.PROMPT),
+                ctx.get_resource(
+                    prompt, ResourceType.PROMPT, metric_group=metric_group
+                ),
             )
         prompt_messages = prompt.format_messages(
             role=MessageRole.USER, **input_variable
diff --git a/python/flink_agents/runtime/memory/internal_base_long_term_memory.py b/python/flink_agents/runtime/memory/internal_base_long_term_memory.py
new file mode 100644
index 00000000..96ff4558
--- /dev/null
+++ b/python/flink_agents/runtime/memory/internal_base_long_term_memory.py
@@ -0,0 +1,39 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+from abc import ABC, abstractmethod
+
+from flink_agents.api.memory.long_term_memory import BaseLongTermMemory
+
+
+class InternalBaseLongTermMemory(BaseLongTermMemory, ABC):
+    """Runtime-internal extension of BaseLongTermMemory.
+
+    Declares operations the runtime relies on that are intentionally kept out
+    of the user-facing BaseLongTermMemory interface.
+    """
+
+    @abstractmethod
+    def switch_context(self, key: str) -> None:
+        """Switch the context that subsequent memory operations apply to.
+
+        This allows the same memory instance to be reused for different keys
+        by isolating data based on the provided key.
+
+        Args:
+            key: The context key.
+        """
diff --git a/python/flink_agents/runtime/memory/tests/test_vector_store_long_term_memory.py b/python/flink_agents/runtime/memory/tests/test_vector_store_long_term_memory.py
index 9eb3c8b7..8465c2b5 100644
--- a/python/flink_agents/runtime/memory/tests/test_vector_store_long_term_memory.py
+++ b/python/flink_agents/runtime/memory/tests/test_vector_store_long_term_memory.py
@@ -33,6 +33,7 @@ from flink_agents.api.memory.long_term_memory import (
     MemorySet,
     SummarizationStrategy,
 )
+from flink_agents.api.metric_group import MetricGroup
 from flink_agents.api.resource import Resource, ResourceType
 from flink_agents.api.runner_context import RunnerContext
 from flink_agents.integrations.chat_models.ollama_chat_model import (
@@ -118,6 +119,9 @@ def long_term_memory() -> VectorStoreLongTermMemory:  # noqa: D103
 
     mock_runner_context = create_autospec(RunnerContext, instance=True)
     mock_runner_context.get_resource = get_resource
+    mock_runner_context.agent_metric_group.get_sub_group.return_value = create_autospec(
+        MetricGroup, instance=True
+    )
 
     return VectorStoreLongTermMemory(
         ctx=mock_runner_context,
diff --git a/python/flink_agents/runtime/memory/vector_store_long_term_memory.py b/python/flink_agents/runtime/memory/vector_store_long_term_memory.py
index a4cf203d..d43c9b50 100644
--- a/python/flink_agents/runtime/memory/vector_store_long_term_memory.py
+++ b/python/flink_agents/runtime/memory/vector_store_long_term_memory.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 #################################################################################
 import functools
+import queue
 import uuid
 from concurrent.futures import Future
 from datetime import datetime, timezone
@@ -26,7 +27,6 @@ from typing_extensions import override
 
 from flink_agents.api.chat_message import ChatMessage
 from flink_agents.api.memory.long_term_memory import (
-    BaseLongTermMemory,
     CompactionStrategy,
     CompactionStrategyType,
     DatetimeRange,
@@ -35,6 +35,7 @@ from flink_agents.api.memory.long_term_memory import (
     MemorySet,
     MemorySetItem,
 )
+from flink_agents.api.metric_group import MetricGroup
 from flink_agents.api.resource import ResourceType
 from flink_agents.api.runner_context import RunnerContext
 from flink_agents.api.vector_stores.vector_store import (
@@ -44,10 +45,13 @@ from flink_agents.api.vector_stores.vector_store import (
     _maybe_cast_to_list,
 )
 from flink_agents.runtime.memory.compaction_functions import summarize
+from flink_agents.runtime.memory.internal_base_long_term_memory import (
+    InternalBaseLongTermMemory,
+)
 
 
 # TODO: support async execution for operations and compaction
-class VectorStoreLongTermMemory(BaseLongTermMemory):
+class VectorStoreLongTermMemory(InternalBaseLongTermMemory):
     """Long-Term Memory based on ChromaDB."""
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -62,19 +66,30 @@ class VectorStoreLongTermMemory(BaseLongTermMemory):
 
     job_id: str = Field(description="Unique identifier for the job.")
 
-    key: str = Field(description="Unique identifier for the keyed partition.")
+    key: str | None = Field(
+        default=None, description="Unique identifier for the keyed partition."
+    )
 
     async_compaction: bool = Field(
         default=False, description="Whether to execute compact asynchronously."
     )
 
+    metric_group: MetricGroup | None = Field(
+        default=None, description="Metric group for reporting long-term memory metrics."
+    )
+    # default_factory gives each instance its own queue; a ``queue.Queue()``
+    # default is built once and would be shared by every instance.
+    metric_records: queue.Queue = Field(
+        default_factory=queue.Queue,
+        description="A thread safe queue for record metrics.",
+    )
+
     def __init__(
         self,
         *,
         ctx: RunnerContext,
         vector_store: str,
         job_id: str,
-        key: str,
         **kwargs: Any,
     ) -> None:
         """Init method."""
@@ -82,11 +97,16 @@ class VectorStoreLongTermMemory(BaseLongTermMemory):
             ctx=ctx,
             vector_store=vector_store,
             job_id=job_id,
-            key=key,
             async_compaction=ctx.config.get(LongTermMemoryOptions.ASYNC_COMPACTION),
+            metric_group=ctx.agent_metric_group.get_sub_group("long-term-memory"),
             **kwargs,
         )
 
+    @override
+    def switch_context(self, key: str) -> None:
+        """Scope subsequent memory operations to the keyed partition ``key``."""
+        self.key = key
+
     @property
     def store(self) -> CollectionManageableVectorStore:
         """Get backend vector store.
@@ -186,15 +206,27 @@ class VectorStoreLongTermMemory(BaseLongTermMemory):
         if memory_set.size >= memory_set.capacity:
             # trigger compaction
             if self.async_compaction:
-                future = self.ctx.executor.submit(self._compact, memory_set=memory_set)
+                # switch_context() re-keys this shared instance for each action,
+                # so compact on a shallow copy that keeps the current key; the
+                # copy still shares ``metric_records`` with this instance.
+                snapshot = self.model_copy()
+                future = self.ctx.executor.submit(
+                    snapshot._compact,
+                    memory_set=memory_set,
+                    metric_group=self.metric_group,
+                )
                 future.add_done_callback(
                     functools.partial(
                         self._handle_exception, self.job_id, self.key, memory_set
                     )
                 )
             else:
-                self._compact(memory_set=memory_set)
+                self._compact(
+                    memory_set=memory_set,
+                    metric_group=self.metric_group,
+                )
 
+        self._report_token_metrics()
         return ids
 
     @override
@@ -224,16 +256,50 @@ class VectorStoreLongTermMemory(BaseLongTermMemory):
 
         return self._convert_to_items(memory_set=memory_set, documents=result.documents)
 
+    @override
+    def close(self) -> None:
+        """Report token usage still buffered from compaction."""
+        self._report_token_metrics()
+
+    def _report_token_metrics(self) -> None:
+        """Report token usage buffered by compaction to the metric group.
+
+        Compaction may run on the executor thread, so its token usage is
+        buffered in ``metric_records`` and drained here. Records are dropped
+        when no metric group is set, so the queue cannot grow without bound.
+        """
+        while True:
+            try:
+                metric = self.metric_records.get_nowait()
+            except queue.Empty:
+                return
+            if self.metric_group is None:
+                continue
+            model_name = metric.get("model_name")
+            prompt_tokens = metric.get("promptTokens")
+            completion_tokens = metric.get("completionTokens")
+            if not model_name or prompt_tokens is None or completion_tokens is None:
+                continue
+            model_group = self.metric_group.get_sub_group(model_name)
+            model_group.get_counter("promptTokens").inc(prompt_tokens)
+            model_group.get_counter("completionTokens").inc(completion_tokens)
+
     def _name_mangling(self, name: str) -> str:
         """Mangle memory set name to actually name in vector store."""
         return f"{self.job_id}-{self.key}-{name}"
 
-    def _compact(self, memory_set: MemorySet) -> None:
-        """Compact memory set to manage storge."""
+    def _compact(self, memory_set: MemorySet, metric_group: MetricGroup | None) -> None:
+        """Compact memory set to manage storage."""
         compaction_strategy: CompactionStrategy = memory_set.compaction_strategy
         if compaction_strategy.type == CompactionStrategyType.SUMMARIZATION:
             # currently, only support summarize all the items.
-            summarize(ltm=self, memory_set=memory_set, ctx=self.ctx)
+            extra_args = summarize(
+                ltm=self, memory_set=memory_set, ctx=self.ctx, metric_group=metric_group
+            )
+            # Buffer token usage; _report_token_metrics() publishes it from
+            # add() or close(), since this may run on an executor thread.
+            if extra_args:
+                self.metric_records.put(extra_args)
         else:
             msg = f"Unknown compaction strategy: {compaction_strategy.type}"
             raise RuntimeError(msg)
diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
index 3f9ad702..55d38187 100644
--- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
@@ -97,7 +97,8 @@ public class PythonActionExecutor {
                                 runnerContext,
                                 agentPlanJson,
                                 pythonAsyncThreadPool,
-                                javaResourceAdapter);
+                                javaResourceAdapter,
+                                jobIdentifier);
     }
 
     /**
@@ -116,10 +117,7 @@ public class PythonActionExecutor {
         function.setInterpreter(interpreter);
 
         interpreter.invoke(
-                FLINK_RUNNER_CONTEXT_SWITCH_ACTION_CONTEXT,
-                pythonRunnerContext,
-                jobIdentifier,
-                hashOfKey);
+                FLINK_RUNNER_CONTEXT_SWITCH_ACTION_CONTEXT, pythonRunnerContext, hashOfKey);
 
         Object pythonEventObject = interpreter.invoke(CONVERT_TO_PYTHON_OBJECT, event.getEvent());
 

Reply via email to