xintongsong commented on code in PR #126:
URL: https://github.com/apache/flink-agents/pull/126#discussion_r2306516114
##########
python/flink_agents/plan/actions/chat_model_action.py:
##########
@@ -15,69 +15,125 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
+import copy
+from typing import List, Union, cast
+from uuid import UUID
from flink_agents.api.chat_message import ChatMessage, MessageRole
+from flink_agents.api.chat_models.chat_model import BaseChatModelSetup
from flink_agents.api.events.chat_event import ChatRequestEvent,
ChatResponseEvent
from flink_agents.api.events.event import Event
from flink_agents.api.events.tool_event import ToolRequestEvent,
ToolResponseEvent
+from flink_agents.api.memory_object import MemoryObject
from flink_agents.api.resource import ResourceType
from flink_agents.api.runner_context import RunnerContext
from flink_agents.plan.actions.action import Action
from flink_agents.plan.function import PythonFunction
+TOOL_CALL_CONTEXT = "_TOOL_CALL_CONTEXT"
+
+
+def chat(
+ request_id: UUID,
+ model: str,
+ chat_model: BaseChatModelSetup,
+ messages: List[ChatMessage],
+ short_term_memory: MemoryObject,
+) -> Union[ChatResponseEvent, ToolRequestEvent]:
+ """Chat with llm.
+
+ If there is no tool call generated, we return the chat response event
directly,
+ otherwise, we generate tool request event according to the tool calls in
chat model
+ response, and save the request and response messages in tool call context.
+ """
+ # TODO: support async execution of chat.
+ response = chat_model.chat(messages)
+
+ # generate tool request event according tool calls in response
+ if len(response.tool_calls) > 0:
+ # TODO: Because memory doesn't support remove currently, so we use
+ # dict to store tool context in memory and remove the specific
+ # tool context from dict after consuming. This will cause write and
+ # read amplification for we need get the whole dict and overwrite it
+ # to memory each time we update a specific tool context.
+ # After memory supports remove, we can use
"TOOL_CALL_CONTEXT/request_id"
+ # to store and remove the specific tool context directly.
+
+ # get tool call context
+ tool_call_context = short_term_memory.get(TOOL_CALL_CONTEXT)
+ if not tool_call_context:
+ tool_call_context = {}
+ if request_id not in tool_call_context:
+ tool_call_context[request_id] = copy.deepcopy(messages)
+ # append response to tool call context
+ tool_call_context[request_id].append(response)
+ # update tool call context
+ short_term_memory.set(TOOL_CALL_CONTEXT, tool_call_context)
+ return ToolRequestEvent(
+ id=request_id,
+ model=model,
+ tool_calls=response.tool_calls,
+ )
+ # if there is no tool call generated, return chat response directly
+ else:
+ # clear tool call context related to specific request id
+ tool_call_context = short_term_memory.get(TOOL_CALL_CONTEXT)
+ if tool_call_context and request_id in tool_call_context:
+ tool_call_context.pop(request_id)
+ short_term_memory.set(TOOL_CALL_CONTEXT, tool_call_context)
+ return ChatResponseEvent(
+ request=ChatRequestEvent(id=request_id, model=model,
messages=messages),
Review Comment:
We should try to avoid re-constructing the `ChatRequestEvent`.
##########
python/flink_agents/runtime/remote_execution_environment.py:
##########
@@ -65,6 +71,14 @@ def apply(self, agent: Agent) -> "AgentBuilder":
if self.__agent_plan is not None:
err_msg = "RemoteAgentBuilder doesn't support apply multiple
agents yet."
raise RuntimeError(err_msg)
+
+ # inspect refer actions and resources from env to agent.
+ for type, names in agent._resource_names.items():
Review Comment:
What is this `_resource_names`?
##########
python/flink_agents/examples/react_agent_example.py:
##########
@@ -0,0 +1,149 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+import argparse
+import os
+
+from pydantic import BaseModel
+from pyflink.common.typeinfo import BasicTypeInfo, RowTypeInfo
+
+from flink_agents.api.agents.react_agent import ReActAgent
+from flink_agents.api.chat_message import ChatMessage, MessageRole
+from flink_agents.api.execution_environment import AgentsExecutionEnvironment
+from flink_agents.api.prompts.prompt import Prompt
+from flink_agents.integrations.chat_models.ollama_chat_model import (
+ OllamaChatModelConnection,
+ OllamaChatModelSetup,
+)
+
+model = os.environ.get("OLLAMA_CHAT_MODEL", "qwen2.5:7b")
+
+
+class InputData(BaseModel): # noqa: D101
+ a: int
+ b: int
+ c: int
+
+
+class OutputData(BaseModel): # noqa: D101
+ result: int
+
+
+def get_output_schema() -> type[OutputData]:
+ """Get the output schema."""
+ return OutputData
+
+
+def get_output_row_type() -> RowTypeInfo:
+ """Get the output row type."""
+ return RowTypeInfo(
+ [BasicTypeInfo.INT_TYPE_INFO()],
+ ["result"],
+ )
+
+
+def add(a: int, b: int) -> int:
+ """Calculate the sum of a and b.
+
+ Parameters
+ ----------
+ a : int
+ The first operand
+ b : int
+ The second operand
+
+ Returns:
+ -------
+ int:
+ The sum of a and b
+ """
+ return a + b
+
+
+def multiply(a: int, b: int) -> int:
+ """Useful function to multiply two numbers.
+
+ Parameters
+ ----------
+ a : int
+ The first operand
+ b : int
+ The second operand
+
+ Returns:
+ -------
+ int:
+ The product of a and b
+ """
+ return a * b
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--output_row",
+ type=bool,
+ default=False,
+ help="Whether the agent should output row.",
+ )
+ args = parser.parse_args()
+
+ env = AgentsExecutionEnvironment.get_execution_environment()
+
+ # register resource to execution environment
+ (
+ env.add_chat_model_connection(
+ name="ollama", connection=OllamaChatModelConnection, model=model
+ )
+ .add_tool("add", add)
+ .add_tool("multiply", multiply)
+ )
+
+ # prepare prompt
+ prompt = Prompt.from_messages(
+ name="prompt",
+ messages=[
+ ChatMessage(
+ role=MessageRole.SYSTEM,
+ content='An example of output is {"result": 30.32}.',
+ ),
+ ChatMessage(role=MessageRole.USER, content="What is ({a} + {b}) *
{c}"),
+ ],
+ )
+
+ # create ReAct agent.
+ agent = ReActAgent(
+ chat_model_setup=OllamaChatModelSetup,
+ connection="ollama",
+ prompt=prompt,
+ tools=["add", "multiply"],
+ output_schema_provider=get_output_row_type
Review Comment:
Why do we need a provider here, rather than just providing the schema?
##########
python/flink_agents/plan/actions/chat_model_action.py:
##########
@@ -15,69 +15,125 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
+import copy
+from typing import List, Union, cast
+from uuid import UUID
from flink_agents.api.chat_message import ChatMessage, MessageRole
+from flink_agents.api.chat_models.chat_model import BaseChatModelSetup
from flink_agents.api.events.chat_event import ChatRequestEvent,
ChatResponseEvent
from flink_agents.api.events.event import Event
from flink_agents.api.events.tool_event import ToolRequestEvent,
ToolResponseEvent
+from flink_agents.api.memory_object import MemoryObject
from flink_agents.api.resource import ResourceType
from flink_agents.api.runner_context import RunnerContext
from flink_agents.plan.actions.action import Action
from flink_agents.plan.function import PythonFunction
+TOOL_CALL_CONTEXT = "_TOOL_CALL_CONTEXT"
+
+
+def chat(
+ request_id: UUID,
+ model: str,
+ chat_model: BaseChatModelSetup,
+ messages: List[ChatMessage],
+ short_term_memory: MemoryObject,
+) -> Union[ChatResponseEvent, ToolRequestEvent]:
+ """Chat with llm.
+
+ If there is no tool call generated, we return the chat response event
directly,
+ otherwise, we generate tool request event according to the tool calls in
chat model
+ response, and save the request and response messages in tool call context.
+ """
+ # TODO: support async execution of chat.
+ response = chat_model.chat(messages)
+
+ # generate tool request event according tool calls in response
+ if len(response.tool_calls) > 0:
+ # TODO: Because memory doesn't support remove currently, so we use
+ # dict to store tool context in memory and remove the specific
+ # tool context from dict after consuming. This will cause write and
+ # read amplification for we need get the whole dict and overwrite it
+ # to memory each time we update a specific tool context.
+ # After memory supports remove, we can use
"TOOL_CALL_CONTEXT/request_id"
+ # to store and remove the specific tool context directly.
+
+ # get tool call context
+ tool_call_context = short_term_memory.get(TOOL_CALL_CONTEXT)
+ if not tool_call_context:
+ tool_call_context = {}
+ if request_id not in tool_call_context:
+ tool_call_context[request_id] = copy.deepcopy(messages)
+ # append response to tool call context
+ tool_call_context[request_id].append(response)
+ # update tool call context
+ short_term_memory.set(TOOL_CALL_CONTEXT, tool_call_context)
+ return ToolRequestEvent(
+ id=request_id,
+ model=model,
+ tool_calls=response.tool_calls,
+ )
+ # if there is no tool call generated, return chat response directly
+ else:
+ # clear tool call context related to specific request id
+ tool_call_context = short_term_memory.get(TOOL_CALL_CONTEXT)
+ if tool_call_context and request_id in tool_call_context:
+ tool_call_context.pop(request_id)
+ short_term_memory.set(TOOL_CALL_CONTEXT, tool_call_context)
+ return ChatResponseEvent(
+ request=ChatRequestEvent(id=request_id, model=model,
messages=messages),
+ response=response,
+ )
+
def process_chat_request_or_tool_response(event: Event, ctx: RunnerContext) ->
None:
- """Built-in action for processing a chat request or tool response."""
+ """Built-in action for processing a chat request or tool response.
+
+ Internally, this action will use short term memory to save the tool call
context,
+ which is a dict mapping request id to chat messages.
+ """
+ short_term_memory = ctx.get_short_term_memory()
if isinstance(event, ChatRequestEvent):
- chat_model = ctx.get_resource(event.model, ResourceType.CHAT_MODEL)
- # TODO: support async execution of chat.
- response = chat_model.chat(event.messages)
- # call tool
- if len(response.tool_calls) > 0:
- for tool_call in response.tool_calls:
- # store the tool call context in short term memory
- state = ctx.get_short_term_memory()
- # TODO: Because memory doesn't support remove currently, so we
use
- # dict to store tool context in memory and remove the specific
- # tool context from dict after consuming. This will cause some
- # overhead for we need get the whole dict and overwrite it to
memory
- # each time we update a specific tool context.
- # After memory supports remove, we can use
- # "__tool_context/tool_call_id" to store and remove the
specific tool
- # context directly.
- if not state.is_exist("__tool_context"):
- state.set("__tool_context", {})
- tool_context = state.get("__tool_context")
- tool_call_id = tool_call["id"]
- tool_context[tool_call_id] = event
- tool_context[tool_call_id].messages.append(response)
- state.set("__tool_context", tool_context)
- ctx.send_event(
- ToolRequestEvent(
- id=tool_call_id,
- tool=tool_call["function"]["name"],
- kwargs=tool_call["function"]["arguments"],
- )
- )
-
- # send response
- else:
- ctx.send_event(ChatResponseEvent(request=event, response=response))
+ chat_model = cast(
+ "BaseChatModelSetup", ctx.get_resource(event.model,
ResourceType.CHAT_MODEL)
+ )
+
+ event = chat(
+ request_id=event.id,
+ model=event.model,
+ chat_model=chat_model,
+ messages=event.messages,
+ short_term_memory=short_term_memory,
+ )
+
+ ctx.send_event(event)
Review Comment:
Why is the event sent here, outside `chat()`?
##########
python/flink_agents/api/agents/react_agent.py:
##########
@@ -0,0 +1,251 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+import json
+from collections.abc import Callable
+from typing import Any, List, Optional, cast
+
+from pydantic import BaseModel
+from pyflink.common import Row
+from pyflink.common.typeinfo import RowTypeInfo
+
+from flink_agents.api.agent import Agent
+from flink_agents.api.chat_message import ChatMessage, MessageRole
+from flink_agents.api.chat_models.chat_model import BaseChatModelSetup
+from flink_agents.api.decorators import action
+from flink_agents.api.events.chat_event import ChatRequestEvent,
ChatResponseEvent
+from flink_agents.api.events.event import InputEvent, OutputEvent
+from flink_agents.api.prompts.prompt import Prompt
+from flink_agents.api.resource import ResourceType
+from flink_agents.api.runner_context import RunnerContext
+
+DEFAULT_CHAT_MODEL = "_default_chat_model"
+DEFAULT_SCHEMA_PROMPT = "_default_schema_prompt"
+DEFAULT_USER_PROMPT = "_default_user_prompt"
+OUTPUT_PARSER = "_output_parser"
+OUTPUT_SCHEMA = "_output_schema"
+
+
+class ReActAgent(Agent):
+ """Built-in implementation of ReAct agent which is based on the function
+ call ability of llm.
+
+ This implementation is not based on the foundational ReAct paper which uses
+ prompt to force llm output contain <Thought>, <Action> and <Observation>
and
+ extract tool calls by text parsing. For a more robust and feature-rich
+ implementation we use the tool/function call ability of current llm, and
get
+ the tool calls from response directly.
+
+
+ Example:
+ ::
+
+ class OutputData(BaseModel):
+ result: int
+
+ def get_output_schema() -> type[OutputData]:
+ \"\"\"Get the output schema.\"\"\"
+ return OutputData
+
+ env = AgentsExecutionEnvironment.get_execution_environment()
+
+ # register resource to execution environment
+ (
+ env.add_chat_model_connection(
+ name="ollama", connection=OllamaChatModelConnection,
model=model
+ )
+ .add_tool("add", add)
+ .add_tool("multiply", multiply)
+ )
+
+ # prepare prompt
+ prompt = Prompt.from_messages(
+ name="prompt",
+ messages=[
+ ChatMessage(
+ role=MessageRole.SYSTEM,
+ content='An example of output is {"result": 30.32}.',
+ ),
+ ChatMessage(role=MessageRole.USER,
+ content="What is ({a} + {b}) * {c}"),
+ ]
+ )
+
+ # create ReAct agent.
+ agent = ReActAgent(
+ chat_model=OllamaChatModelSetup,
+ connection="ollama",
+ prompt=prompt,
+ tools=["add", "multiply"],
+ output_schema_provider=get_output_schema
+ )
+ """
+
+ def __init__(
+ self,
+ *,
+ chat_model_setup: type[BaseChatModelSetup],
+ connection: str,
+ prompt: Optional[Prompt] = None,
+ tools: Optional[List[str]] = None,
+ output_schema_provider: Optional[Callable] = None,
+ output_parser: Optional[Callable] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Init method of ReActAgent.
+
+ Parameters
+ ----------
+ chat_model_setup : BaseChatModelSetup
+ The type of the chat model setup used in this ReAct agent.
+ connection: str
+ The name of the chat model connection used in chat model setup. The
+ connection should be registered in environment.
+ prompt : Optional[Prompt] = None
+ Prompt to instruct the llm, could include input and output example,
+ task and so on.
+ tools : Optional[List[str]]
+ Tools names can be used in this ReAct agent. The tools should be
registered
+ in environment.
+ output_schema_provider : Optional[Callable]
+ A function to provide output schema, the schema should be
RowTypeInfo or
+ subclass of BaseModel. When user provide output schema, ReAct
agent will
+ add system prompt to instruct response format of llm, and add
output parser
+ according to the schema.
+ output_parser: Optional[Callable]
+ A function to parse response from llm. When user provide output
parser,
+ ReAct agent will use it rather than construct according to output
schema.
+ **kwargs: Any
+ The initialize arguments of chat_model_setup.
+ """
+ super().__init__()
+ settings = {
+ "name": DEFAULT_CHAT_MODEL,
+ "connection": connection,
+ "tools": tools,
+ }
+ settings.update(kwargs)
+ self._resources[ResourceType.CHAT_MODEL][DEFAULT_CHAT_MODEL] = (
+ chat_model_setup,
+ settings,
+ )
+
+ if output_schema_provider:
+ output_schema = output_schema_provider()
+ if issubclass(output_schema, BaseModel):
+ json_schema = output_schema.model_json_schema()
+ elif isinstance(output_schema, RowTypeInfo):
+ json_schema = str(output_schema)
+ else:
+ err_msg = f"Output schema {output_schema.__class__} is not
supported."
+ raise TypeError(err_msg)
+ schema_prompt = f"The final response should be json format, and
match the schema {json_schema}."
+ self._resources[ResourceType.PROMPT][DEFAULT_SCHEMA_PROMPT] = (
+ Prompt.from_text(name="output_schema", text=schema_prompt)
+ )
+
+ self._resources[ResourceType.TOOL][OUTPUT_SCHEMA] =
output_schema_provider
+ if prompt:
+ self._resources[ResourceType.PROMPT][DEFAULT_USER_PROMPT] = prompt
+ if output_parser:
+ self._resources[ResourceType.TOOL][OUTPUT_PARSER] = output_parser
Review Comment:
Why are `output_schema_provider` and `output_parser` tools?
##########
python/flink_agents/runtime/local_execution_environment.py:
##########
@@ -54,6 +54,10 @@ def apply(self, agent: Agent) -> AgentBuilder:
if self.__runner is not None:
err_msg = "LocalAgentBuilder doesn't support apply multiple
agents."
raise RuntimeError(err_msg)
+ # inspect resources from environment to agent instance.
+ registered_resources = self.__env.resources
+ for type, name_to_resource in registered_resources.items():
+ agent.resources[type] = name_to_resource | agent.resources[type]
Review Comment:
We should not do this in `apply()`. What if the resources are registered
afterwards?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]