This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 004cb46f1f44d84255f75c672f3c783220faa63c Author: WenjinXie <[email protected]> AuthorDate: Thu Jan 15 19:35:24 2026 +0800 [plan][python] Built-in actions support async execution. Co-authored-by: Shekharrajak <[email protected]> fix fix --- python/flink_agents/api/core_options.py | 22 ++++ .../built_in_action_async_execution_test.py | 133 +++++++++++++++++++++ .../flink_agents/plan/actions/chat_model_action.py | 39 ++++-- .../plan/actions/context_retrieval_action.py | 38 +++--- .../flink_agents/plan/actions/tool_call_action.py | 17 ++- 5 files changed, 221 insertions(+), 28 deletions(-) diff --git a/python/flink_agents/api/core_options.py b/python/flink_agents/api/core_options.py index 9d8f7eaf..42343017 100644 --- a/python/flink_agents/api/core_options.py +++ b/python/flink_agents/api/core_options.py @@ -101,3 +101,25 @@ class AgentConfigOptions(metaclass=AgentConfigOptionsMeta): config_type=int, default=3, ) + + +class AgentExecutionOptions(metaclass=AgentConfigOptionsMeta): + """Execution options for Flink Agents.""" + + CHAT_ASYNC = ConfigOption( + key="chat.async", + config_type=bool, + default=True, + ) + + TOOL_CALL_ASYNC = ConfigOption( + key="tool-call.async", + config_type=bool, + default=True, + ) + + RAG_ASYNC = ConfigOption( + key="rag.async", + config_type=bool, + default=True, + ) diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py new file mode 100644 index 00000000..5ea09478 --- /dev/null +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py @@ -0,0 +1,133 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import time +import uuid +from typing import Any, Dict, Sequence + +from pyflink.datastream import StreamExecutionEnvironment +from typing_extensions import override + +from flink_agents.api.agents.agent import Agent +from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.chat_models.chat_model import BaseChatModelSetup +from flink_agents.api.decorators import action, chat_model_setup, tool +from flink_agents.api.events.chat_event import ChatRequestEvent, ChatResponseEvent +from flink_agents.api.events.event import InputEvent, OutputEvent +from flink_agents.api.execution_environment import AgentsExecutionEnvironment +from flink_agents.api.resource import ResourceDescriptor +from flink_agents.api.runner_context import RunnerContext +from flink_agents.api.tools.tool import ToolType + + +class SlowMockChatModel(BaseChatModelSetup): + """Mock ChatModel with slow connection.""" + + @property + def model_kwargs(self) -> Dict[str, Any]: # noqa: D102 + return {} + + @override + def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: + time.sleep(5) # Simulate network delay + if "sum" in messages[-1].content: + input = messages[-1].content + function = {"name": "add", "arguments": {"a": 1, "b": 2}} + tool_call = { + "id": uuid.uuid4(), + "type": ToolType.FUNCTION, + "function": function, + } + return ChatMessage( + role=MessageRole.ASSISTANT, content=input, tool_calls=[tool_call] + ) + else: + content = "\n".join([message.content for message in messages]) + return ChatMessage(role=MessageRole.ASSISTANT, content=content) + + +class AsyncTestAgent(Agent): + """Agent for testing async execution.""" + + @chat_model_setup + @staticmethod + def slow_chat_model() -> ResourceDescriptor: # noqa: D102 + return ResourceDescriptor( + clazz=f"{SlowMockChatModel.__module__}.{SlowMockChatModel.__name__}", + connection="placement", + tools=["add"], + ) + + @tool + @staticmethod + def add(a: int, b: int) -> int: + """Calculate the sum of a and b.""" + time.sleep(5) # Simulate slow tool execution + return a + b + + @action(InputEvent) + @staticmethod + def process_input(event: InputEvent, ctx: RunnerContext) -> None: # noqa: D102 + input = event.input + ctx.send_event( + ChatRequestEvent( + model="slow_chat_model", + messages=[ + ChatMessage( + role=MessageRole.USER, content=input, extra_args={"task": input} + ) + ], + ) + ) + + @action(ChatResponseEvent) + @staticmethod + def process_chat_response(event: ChatResponseEvent, ctx: RunnerContext) -> None: # noqa: D102 + input = event.response + ctx.send_event(OutputEvent(output=input.content)) + + +def test_built_in_actions_async_execution() -> None: + """Test that built-in actions use async execution correctly. + + This test verifies that chat_model_action and tool_call_action work + correctly with async execution, ensuring backward compatibility. + """ + env = StreamExecutionEnvironment.get_execution_environment() + env.set_parallelism(1) + + input_stream = env.from_collection( + ["calculate the sum of 1 and 2" for _ in range(10)], + ) + + agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env) + output_datastream = ( + agents_env.from_datastream( + input=input_stream, key_selector=lambda x: uuid.uuid4() + ) + .apply(AsyncTestAgent()) + .to_datastream() + ) + + output_datastream.print() + + # Measure execution time to verify async doesn't block + start_time = time.time() + agents_env.execute() + execution_time = time.time() - start_time + + assert execution_time < 50 diff --git a/python/flink_agents/plan/actions/chat_model_action.py b/python/flink_agents/plan/actions/chat_model_action.py index d97acaf1..88806a2e 100644 --- a/python/flink_agents/plan/actions/chat_model_action.py +++ b/python/flink_agents/plan/actions/chat_model_action.py @@ -28,7 +28,12 @@ from pyflink.common.typeinfo import RowTypeInfo from flink_agents.api.agents.agent import STRUCTURED_OUTPUT from flink_agents.api.agents.react_agent import OutputSchema from flink_agents.api.chat_message import ChatMessage, MessageRole -from flink_agents.api.core_options import AgentConfigOptions, ErrorHandlingStrategy +from flink_agents.api.chat_models.java_chat_model import JavaChatModelSetup +from flink_agents.api.core_options import ( + AgentConfigOptions, + AgentExecutionOptions, + ErrorHandlingStrategy, +) from flink_agents.api.events.chat_event import ChatRequestEvent, ChatResponseEvent from flink_agents.api.events.event import Event from flink_agents.api.events.tool_event import ToolRequestEvent, ToolResponseEvent @@ -80,6 +85,7 @@ def _update_tool_call_context( sensory_memory.set(_TOOL_CALL_CONTEXT, tool_call_context) return tool_call_context[initial_request_id] + def _save_tool_request_event_context( sensory_memory: MemoryObject, tool_request_event_id: UUID, @@ -156,7 +162,7 @@ def _generate_structured_output( return response -def chat( +async def chat( initial_request_id: UUID, model: str, messages: List[ChatMessage], @@ -173,16 +179,23 @@ def chat( "BaseChatModelSetup", ctx.get_resource(model, ResourceType.CHAT_MODEL) ) + chat_async = ctx.config.get(AgentExecutionOptions.CHAT_ASYNC) + # java chat model doesn't support async execution. + chat_async = chat_async and not isinstance(chat_model, JavaChatModelSetup) + error_handling_strategy = ctx.config.get(AgentConfigOptions.ERROR_HANDLING_STRATEGY) num_retries = 0 if error_handling_strategy == ErrorHandlingStrategy.RETRY: num_retries = max(0, ctx.config.get(AgentConfigOptions.MAX_RETRIES)) - # TODO: support async execution of chat. response = None for attempt in range(num_retries + 1): try: - response = chat_model.chat(messages) + if chat_async: + response = await ctx.durable_execute_async(chat_model.chat, messages) + else: + response = chat_model.chat(messages) + if response.extra_args.get("model_name") and response.extra_args.get("promptTokens") and response.extra_args.get("completionTokens"): chat_model._record_token_metrics(response.extra_args["model_name"], response.extra_args["promptTokens"], response.extra_args["completionTokens"]) if output_schema is not None and len(response.tool_calls) == 0: @@ -221,9 +234,9 @@ def chat( ) -def _process_chat_request(event: ChatRequestEvent, ctx: RunnerContext) -> None: +async def _process_chat_request(event: ChatRequestEvent, ctx: RunnerContext) -> None: """Process chat request event.""" - chat( + await chat( initial_request_id=event.id, model=event.model, messages=event.messages, @@ -232,7 +245,7 @@ def _process_chat_request(event: ChatRequestEvent, ctx: RunnerContext) -> None: ) -def _process_tool_response(event: ToolResponseEvent, ctx: RunnerContext) -> None: +async def _process_tool_response(event: ToolResponseEvent, ctx: RunnerContext) -> None: """Organize the tool call context and return it to the LLM.""" sensory_memory = ctx.sensory_memory request_id = event.request_id @@ -260,7 +273,7 @@ def _process_tool_response(event: ToolResponseEvent, ctx: RunnerContext) -> None ], ) - chat( + await chat( initial_request_id=initial_request_id, model=tool_request_event_context["model"], messages=messages, @@ -269,17 +282,21 @@ def _process_tool_response(event: ToolResponseEvent, ctx: RunnerContext) -> None ) -def process_chat_request_or_tool_response(event: Event, ctx: RunnerContext) -> None: +async def process_chat_request_or_tool_response( + event: Event, ctx: RunnerContext +) -> None: """Built-in action for processing a chat request or tool response. This action listens to ChatRequestEvent and ToolResponseEvent, and handles the complete chat flow including tool calls. It uses sensory memory to save the tool call context, which is a dict mapping request id to chat messages. """ + # To avoid https://github.com/alibaba/pemja/issues/88, we log a message here. + logging.debug("Processing chat request asynchronously.") if isinstance(event, ChatRequestEvent): - _process_chat_request(event, ctx) + await _process_chat_request(event, ctx) elif isinstance(event, ToolResponseEvent): - _process_tool_response(event, ctx) + await _process_tool_response(event, ctx) CHAT_MODEL_ACTION = Action( diff --git a/python/flink_agents/plan/actions/context_retrieval_action.py b/python/flink_agents/plan/actions/context_retrieval_action.py index f0a87157..c862996e 100644 --- a/python/flink_agents/plan/actions/context_retrieval_action.py +++ b/python/flink_agents/plan/actions/context_retrieval_action.py @@ -15,6 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# +import logging + +from flink_agents.api.core_options import AgentExecutionOptions from flink_agents.api.events.context_retrieval_event import ( ContextRetrievalRequestEvent, ContextRetrievalResponseEvent, @@ -22,31 +25,36 @@ from flink_agents.api.events.context_retrieval_event import ( from flink_agents.api.events.event import Event from flink_agents.api.resource import ResourceType from flink_agents.api.runner_context import RunnerContext +from flink_agents.api.vector_stores.java_vector_store import JavaVectorStore from flink_agents.api.vector_stores.vector_store import VectorStoreQuery from flink_agents.plan.actions.action import Action from flink_agents.plan.function import PythonFunction +_logger = logging.getLogger(__name__) -def process_context_retrieval_request(event: Event, ctx: RunnerContext) -> None: +async def process_context_retrieval_request(event: Event, ctx: RunnerContext) -> None: """Built-in action for processing context retrieval requests.""" if isinstance(event, ContextRetrievalRequestEvent): - vector_store = ctx.get_resource( - event.vector_store, - ResourceType.VECTOR_STORE - ) + vector_store = ctx.get_resource(event.vector_store, ResourceType.VECTOR_STORE) - query = VectorStoreQuery( - query_text=event.query, - limit=event.max_results - ) + query = VectorStoreQuery(query_text=event.query, limit=event.max_results) - result = vector_store.query(query) + rag_async = ctx.config.get(AgentExecutionOptions.RAG_ASYNC) + # java vector store doesn't support async execution. + rag_async = rag_async and not isinstance(vector_store, JavaVectorStore) + if rag_async: + # To avoid https://github.com/alibaba/pemja/issues/88, + # we log a message here. + _logger.debug("Processing context retrieval asynchronously.") + result = await ctx.durable_execute_async(vector_store.query, query) + else: + result = vector_store.query(query) - ctx.send_event(ContextRetrievalResponseEvent( - request_id=event.id, - query=event.query, - documents=result.documents - )) + ctx.send_event( + ContextRetrievalResponseEvent( + request_id=event.id, query=event.query, documents=result.documents + ) + ) CONTEXT_RETRIEVAL_ACTION = Action( diff --git a/python/flink_agents/plan/actions/tool_call_action.py b/python/flink_agents/plan/actions/tool_call_action.py index 4d56e36a..fe69f3cd 100644 --- a/python/flink_agents/plan/actions/tool_call_action.py +++ b/python/flink_agents/plan/actions/tool_call_action.py @@ -15,15 +15,25 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# +import logging + +from flink_agents.api.core_options import AgentExecutionOptions from flink_agents.api.events.tool_event import ToolRequestEvent, ToolResponseEvent from flink_agents.api.resource import ResourceType from flink_agents.api.runner_context import RunnerContext from flink_agents.plan.actions.action import Action from flink_agents.plan.function import PythonFunction +_logger = logging.getLogger(__name__) -def process_tool_request(event: ToolRequestEvent, ctx: RunnerContext) -> None: +async def process_tool_request(event: ToolRequestEvent, ctx: RunnerContext) -> None: """Built-in action for processing tool call requests.""" + tool_call_async = ctx.config.get(AgentExecutionOptions.TOOL_CALL_ASYNC) + + if tool_call_async: + # To avoid https://github.com/alibaba/pemja/issues/88, we log a message here. + _logger.debug("Processing tool call asynchronously.") + responses = {} external_ids = {} for tool_call in event.tool_calls: @@ -35,7 +45,10 @@ def process_tool_request(event: ToolRequestEvent, ctx: RunnerContext) -> None: if not tool: response = f"Tool `{name}` does not exist." else: - response = tool.call(**kwargs) + if tool_call_async: + response = await ctx.durable_execute_async(tool.call, **kwargs) + else: + response = tool.call(**kwargs) responses[id] = response external_ids[id] = external_id ctx.send_event(
