This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit bcf8f0b42bbc0b4bbc90b65dcb5497d3a8bf14df Author: WenjinXie <[email protected]> AuthorDate: Thu Aug 28 12:01:38 2025 +0800 [runtime][python] Refactor built-in actions to execute tool calls from one response one time. --- python/flink_agents/api/events/chat_event.py | 7 +- python/flink_agents/api/events/tool_event.py | 35 ++-- .../flink_agents/plan/actions/chat_model_action.py | 181 +++++++++++++++------ .../flink_agents/plan/actions/tool_call_action.py | 25 ++- 4 files changed, 171 insertions(+), 77 deletions(-) diff --git a/python/flink_agents/api/events/chat_event.py b/python/flink_agents/api/events/chat_event.py index b1f53db..2f4dfb1 100644 --- a/python/flink_agents/api/events/chat_event.py +++ b/python/flink_agents/api/events/chat_event.py @@ -16,6 +16,7 @@ # limitations under the License. ################################################################################# from typing import List +from uuid import UUID from flink_agents.api.chat_message import ChatMessage from flink_agents.api.events.event import Event @@ -41,11 +42,11 @@ class ChatResponseEvent(Event): Attributes: ---------- - request : ChatRequestEvent - The correspond request of the response. + request_id : UUID + The id of the request event. response : ChatMessage The response from the chat model. """ - request: ChatRequestEvent + request_id: UUID response: ChatMessage diff --git a/python/flink_agents/api/events/tool_event.py b/python/flink_agents/api/events/tool_event.py index 6bda919..3468ba6 100644 --- a/python/flink_agents/api/events/tool_event.py +++ b/python/flink_agents/api/events/tool_event.py @@ -15,7 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# -from typing import Any, Optional +from typing import Any, Dict, List, Optional +from uuid import UUID from flink_agents.api.events.event import Event @@ -25,18 +26,14 @@ class ToolRequestEvent(Event): Attributes: ---------- - tool : str - The name of the tool to be called. - kwargs : dict - The arguments passed to the tool. - external_id : Optional[str] - Optional identifier for storing original tool call IDs from external systems - (e.g., Anthropic tool_use_id). + model: str + name of the model that generated the tool request. + tool_calls : List[Dict[str, Any]] + tool calls that should be executed in batch. """ - tool: str - kwargs: dict - external_id: Optional[str] = None + model: str + tool_calls: List[Dict[str, Any]] class ToolResponseEvent(Event): @@ -44,11 +41,15 @@ class ToolResponseEvent(Event): Attributes: ---------- - request : ToolRequestEvent - The correspond request of the response. - response : Any - The response from the tool. + request_id : UUID + The id of the request event. + responses : Dict[UUID, Any] + The dict maps tool call id to result. + external_ids : Dict[UUID, str] + Optional identifier for storing original tool call IDs from external systems + (e.g., Anthropic tool_use_id). """ - request: ToolRequestEvent - response: Any + request_id: UUID + responses: Dict[UUID, Any] + external_ids: Dict[UUID, Optional[str]] diff --git a/python/flink_agents/plan/actions/chat_model_action.py b/python/flink_agents/plan/actions/chat_model_action.py index f51a0fb..656453d 100644 --- a/python/flink_agents/plan/actions/chat_model_action.py +++ b/python/flink_agents/plan/actions/chat_model_action.py @@ -15,6 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# +import copy +from typing import TYPE_CHECKING, List, cast +from uuid import UUID from flink_agents.api.chat_message import ChatMessage, MessageRole from flink_agents.api.events.chat_event import ChatRequestEvent, ChatResponseEvent @@ -25,64 +28,138 @@ from flink_agents.api.runner_context import RunnerContext from flink_agents.plan.actions.action import Action from flink_agents.plan.function import PythonFunction +if TYPE_CHECKING: + from flink_agents.api.chat_models.chat_model import BaseChatModelSetup + +_TOOL_CALL_CONTEXT = "_TOOL_CALL_CONTEXT" +_TOOL_REQUEST_EVENT_CONTEXT = "_TOOL_REQUEST_EVENT_CONTEXT" + + +def chat( + initial_request_id: UUID, + model: str, + messages: List[ChatMessage], + ctx: RunnerContext, +) -> None: + """Chat with llm. + + If there is no tool call generated, we return the chat response event directly, + otherwise, we generate tool request event according to the tool calls in chat model + response, and save the request and response messages in tool call context. + """ + chat_model = cast( + "BaseChatModelSetup", ctx.get_resource(model, ResourceType.CHAT_MODEL) + ) + + # TODO: support async execution of chat. + response = chat_model.chat(messages) + short_term_memory = ctx.get_short_term_memory() + + # generate tool request event according tool calls in response + if len(response.tool_calls) > 0: + # TODO: Because memory doesn't support remove currently, so we use + # dict to store tool context in memory and remove the specific + # tool context from dict after consuming. This will cause write and + # read amplification for we need get the whole dict and overwrite it + # to memory each time we update a specific tool context. + # After memory supports remove, we can use "TOOL_CALL_CONTEXT/request_id" + # to store and remove the specific tool context directly. + + # save tool call context + tool_call_context = short_term_memory.get(_TOOL_CALL_CONTEXT) + if not tool_call_context: + tool_call_context = {} + if initial_request_id not in tool_call_context: + tool_call_context[initial_request_id] = copy.deepcopy(messages) + # append response to tool call context + tool_call_context[initial_request_id].append(response) + # update tool call context + short_term_memory.set(_TOOL_CALL_CONTEXT, tool_call_context) + + tool_request_event = ToolRequestEvent( + model=model, + tool_calls=response.tool_calls, + ) + + # save tool request event context + tool_request_event_context = tool_call_context.get(_TOOL_REQUEST_EVENT_CONTEXT) + if not tool_request_event_context: + tool_request_event_context = {} + tool_request_event_context[tool_request_event.id] = { + "initial_request_id": initial_request_id, + "model": model, + } + short_term_memory.set(_TOOL_REQUEST_EVENT_CONTEXT, tool_request_event_context) + + ctx.send_event(tool_request_event) + # if there is no tool call generated, return chat response directly + else: + # clear tool call context related to specific request id + tool_call_context = short_term_memory.get(_TOOL_CALL_CONTEXT) + if tool_call_context and initial_request_id in tool_call_context: + tool_call_context.pop(initial_request_id) + short_term_memory.set(_TOOL_CALL_CONTEXT, tool_call_context) + ctx.send_event( + ChatResponseEvent( + request_id=initial_request_id, + response=response, + ) + ) + def process_chat_request_or_tool_response(event: Event, ctx: RunnerContext) -> None: - """Built-in action for processing a chat request or tool response.""" + """Built-in action for processing a chat request or tool response. + + Internally, this action will use short term memory to save the tool call context, + which is a dict mapping request id to chat messages. + """ + short_term_memory = ctx.get_short_term_memory() if isinstance(event, ChatRequestEvent): - chat_model = ctx.get_resource(event.model, ResourceType.CHAT_MODEL) - # TODO: support async execution of chat. - response = chat_model.chat(event.messages) - # call tool - if len(response.tool_calls) > 0: - for tool_call in response.tool_calls: - # store the tool call context in short term memory - state = ctx.get_short_term_memory() - # TODO: Because memory doesn't support remove currently, so we use - # dict to store tool context in memory and remove the specific - # tool context from dict after consuming. This will cause some - # overhead for we need get the whole dict and overwrite it to memory - # each time we update a specific tool context. - # After memory supports remove, we can use - # "__tool_context/tool_call_id" to store and remove the specific tool - # context directly. - if not state.is_exist("__tool_context"): - state.set("__tool_context", {}) - tool_context = state.get("__tool_context") - tool_call_id = tool_call["id"] - tool_context[tool_call_id] = event - tool_context[tool_call_id].messages.append(response) - state.set("__tool_context", tool_context) - ctx.send_event( - ToolRequestEvent( - id=tool_call_id, - tool=tool_call["function"]["name"], - kwargs=tool_call["function"]["arguments"], - external_id=tool_call.get("original_id"), - ) - ) + cast( + "BaseChatModelSetup", ctx.get_resource(event.model, ResourceType.CHAT_MODEL) + ) + + chat( + initial_request_id=event.id, + model=event.model, + messages=event.messages, + ctx=ctx, + ) - # send response - else: - ctx.send_event(ChatResponseEvent(request=event, response=response)) elif isinstance(event, ToolResponseEvent): - state = ctx.get_short_term_memory() - - if state.is_exist("__tool_context"): - tool_context = state.get("__tool_context") - tool_call_id = event.request.id - if tool_context is not None and tool_call_id in tool_context: - # get the specific tool call context from short term memory - specific_tool_ctx = tool_context.pop(tool_call_id) - specific_tool_ctx.messages.append( - ChatMessage( - role=MessageRole.TOOL, - content=str(event.response), - extra_args={"external_id": event.request.external_id} if event.request.external_id else {} - ) + request_id = event.request_id + + # get correspond tool request event context + tool_request_event_context = short_term_memory.get(_TOOL_REQUEST_EVENT_CONTEXT) + initial_request_id = tool_request_event_context[request_id][ + "initial_request_id" + ] + model = tool_request_event_context[request_id]["model"] + # clear tool request event context + tool_request_event_context.pop(request_id) + short_term_memory.set(_TOOL_REQUEST_EVENT_CONTEXT, tool_request_event_context) + + responses = event.responses + # update tool call context + tool_call_context = short_term_memory.get(_TOOL_CALL_CONTEXT) + for id, response in responses.items(): + tool_call_context[initial_request_id].append( + ChatMessage( + role=MessageRole.TOOL, + content=str(response), + extra_args={"external_id": event.external_ids[id]} + if event.external_ids[id] + else {}, ) - ctx.send_event(specific_tool_ctx) - # update short term memory to remove the specific tool call context - state.set("__tool_context", tool_context) + ) + short_term_memory.set(_TOOL_CALL_CONTEXT, tool_call_context) + + chat( + initial_request_id=initial_request_id, + model=model, + messages=tool_call_context[initial_request_id], + ctx=ctx, + ) CHAT_MODEL_ACTION = Action( diff --git a/python/flink_agents/plan/actions/tool_call_action.py b/python/flink_agents/plan/actions/tool_call_action.py index 70788ed..4d56e36 100644 --- a/python/flink_agents/plan/actions/tool_call_action.py +++ b/python/flink_agents/plan/actions/tool_call_action.py @@ -23,11 +23,26 @@ from flink_agents.plan.function import PythonFunction def process_tool_request(event: ToolRequestEvent, ctx: RunnerContext) -> None: - """Built-in action for processing a tool call request.""" - tool = ctx.get_resource(event.tool, ResourceType.TOOL) - # TODO: support async execution of tool call. - response = tool.call(**event.kwargs) - ctx.send_event(ToolResponseEvent(request=event, response=response)) + """Built-in action for processing tool call requests.""" + responses = {} + external_ids = {} + for tool_call in event.tool_calls: + id = tool_call["id"] + name = tool_call["function"]["name"] + kwargs = tool_call["function"]["arguments"] + tool = ctx.get_resource(name, ResourceType.TOOL) + external_id = tool_call.get("original_id") + if not tool: + response = f"Tool `{name}` does not exist." + else: + response = tool.call(**kwargs) + responses[id] = response + external_ids[id] = external_id + ctx.send_event( + ToolResponseEvent( + request_id=event.id, responses=responses, external_ids=external_ids + ) + ) TOOL_CALL_ACTION = Action(
