This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit dd65ab07254c9b35000d00dc3c84ae5705db873f Author: WenjinXie <[email protected]> AuthorDate: Fri Aug 29 17:12:26 2025 +0800 [api][plan][runtime] Support passing additional parameters to actions. Fix support for primitive types --- .../java/org/apache/flink/agents/plan/Action.java | 16 +++++++- .../plan/serializer/ActionJsonDeserializer.java | 29 ++++++++++++++- .../plan/serializer/ActionJsonSerializer.java | 20 ++++++++++ python/flink_agents/api/agent.py | 10 +++-- python/flink_agents/api/runner_context.py | 26 +++++++++++++ python/flink_agents/plan/actions/action.py | 43 ++++++++++++++++++++-- python/flink_agents/plan/agent_plan.py | 37 ++++++++++++++++++- .../flink_agents/plan/tests/resources/action.json | 23 ++++++++++-- .../plan/tests/resources/agent_plan.json | 12 ++++-- python/flink_agents/plan/tests/test_action.py | 12 +++++- .../flink_agents/runtime/flink_runner_context.py | 14 +++++++ python/flink_agents/runtime/local_runner.py | 18 ++++++++- .../agents/runtime/context/RunnerContextImpl.java | 4 ++ 13 files changed, 243 insertions(+), 21 deletions(-) diff --git a/plan/src/main/java/org/apache/flink/agents/plan/Action.java b/plan/src/main/java/org/apache/flink/agents/plan/Action.java index 6f00b2e..63c7923 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/Action.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/Action.java @@ -26,6 +26,7 @@ import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.annotatio import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.annotation.JsonSerialize; import java.util.List; +import java.util.Map; /** * Representation of an agent action with event listening and function execution.
@@ -40,13 +41,22 @@ public class Action { private final Function exec; private final List<String> listenEventTypes; - public Action(String name, Function exec, List<String> listenEventTypes) throws Exception { + private final Map<String, Object> config; + + public Action( + String name, Function exec, List<String> listenEventTypes, Map<String, Object> config) + throws Exception { this.name = name; this.exec = exec; this.listenEventTypes = listenEventTypes; + this.config = config; exec.checkSignature(new Class[] {Event.class, RunnerContext.class}); } + public Action(String name, Function exec, List<String> listenEventTypes) throws Exception { + this(name, exec, listenEventTypes, null); + } + public String getName() { return name; } @@ -58,4 +68,8 @@ public class Action { public List<String> getListenEventTypes() { return listenEventTypes; } + + public Map<String, Object> getConfig() { + return config; + } } diff --git a/plan/src/main/java/org/apache/flink/agents/plan/serializer/ActionJsonDeserializer.java b/plan/src/main/java/org/apache/flink/agents/plan/serializer/ActionJsonDeserializer.java index 28a8112..3a3ca19 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/serializer/ActionJsonDeserializer.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/serializer/ActionJsonDeserializer.java @@ -29,7 +29,9 @@ import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.deser.std import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; /** * Custom deserializer for {@link Action} that handles the deserialization of the function and event @@ -64,8 +66,15 @@ public class ActionJsonDeserializer extends StdDeserializer<Action> { node.get("listen_event_types") .forEach(eventTypeNode -> listenEventTypes.add(eventTypeNode.asText())); + // Deserialize params + JsonNode configNode = node.get("config"); + Map<String, Object> config = new HashMap<>(); + if (configNode != null && 
configNode.isObject()) { + config = (Map<String, Object>) parseJsonNode(configNode); + } + try { - return new Action(name, func, listenEventTypes); + return new Action(name, func, listenEventTypes, config); } catch (Exception e) { throw new RuntimeException( String.format("Failed to create Action with name \"%s\"", name), e); @@ -100,4 +109,22 @@ public class ActionJsonDeserializer extends StdDeserializer<Action> { e); } } + + private Object parseJsonNode(JsonNode node) { + if (node.isObject()) { + Map<String, Object> map = new HashMap<>(); + node.fields() + .forEachRemaining( + entry -> map.put(entry.getKey(), parseJsonNode(entry.getValue()))); + return map; + } else if (node.isArray()) { + List<Object> list = new ArrayList<>(); + node.forEach(element -> list.add(parseJsonNode(element))); + return list; + } else if (node.isValueNode()) { + return node.asText(); + } else { + throw new UnsupportedOperationException("Unsupported node type: " + node.getNodeType()); + } + } } diff --git a/plan/src/main/java/org/apache/flink/agents/plan/serializer/ActionJsonSerializer.java b/plan/src/main/java/org/apache/flink/agents/plan/serializer/ActionJsonSerializer.java index 0ae784d..3ef09af 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/serializer/ActionJsonSerializer.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/serializer/ActionJsonSerializer.java @@ -26,6 +26,7 @@ import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.Serialize import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ser.std.StdSerializer; import java.io.IOException; +import java.util.Map; /** * Custom serializer for {@link Action} that handles the serialization of the function and event @@ -68,6 +69,25 @@ public class ActionJsonSerializer extends StdSerializer<Action> { } jsonGenerator.writeEndArray(); + // Write config field + Map<String, Object> config = action.getConfig(); + if (config == null) { + jsonGenerator.writeObjectField("config", null); + } 
else { + jsonGenerator.writeFieldName("config"); + jsonGenerator.writeStartObject(); + action.getConfig() + .forEach( + (name, value) -> { + try { + jsonGenerator.writeObjectField(name, value); + } catch (IOException e) { + throw new RuntimeException("Error writing action: " + name, e); + } + }); + jsonGenerator.writeEndObject(); + } + jsonGenerator.writeEndObject(); } diff --git a/python/flink_agents/api/agent.py b/python/flink_agents/api/agent.py index 5a06542..dd54999 100644 --- a/python/flink_agents/api/agent.py +++ b/python/flink_agents/api/agent.py @@ -71,7 +71,7 @@ class Agent(ABC): connection="my_connection") """ - _actions: Dict[str, Tuple[List[Type[Event]], Callable]] + _actions: Dict[str, Tuple[List[Type[Event]], Callable, Dict[str, Any]]] _resources: Dict[ResourceType, Dict[str, Any]] def __init__(self) -> None: @@ -82,7 +82,7 @@ class Agent(ABC): self._resources[type] = {} @property - def actions(self) -> Dict[str, Tuple[List[Type[Event]], Callable]]: + def actions(self) -> Dict[str, Tuple[List[Type[Event]], Callable, Dict[str, Any]]]: """Get added actions.""" return self._actions @@ -92,7 +92,7 @@ class Agent(ABC): return self._resources def add_action( - self, name: str, events: List[Type[Event]], func: Callable + self, name: str, events: List[Type[Event]], func: Callable, **config: Any ) -> "Agent": """Add action to agent. @@ -104,6 +104,8 @@ class Agent(ABC): The type of events listened by this action. func: Callable The function to be executed when receive listened events. + **config: Any + Key named arguments can be used by this action in runtime. 
Returns: ------- @@ -113,7 +115,7 @@ class Agent(ABC): if name in self._actions: msg = f"Action {name} already defined" raise ValueError(msg) - self._actions[name] = (events, func) + self._actions[name] = (events, func, config if config else None) return self def add_prompt(self, name: str, prompt: Prompt) -> "Agent": diff --git a/python/flink_agents/api/runner_context.py b/python/flink_agents/api/runner_context.py index c822a49..9555ce1 100644 --- a/python/flink_agents/api/runner_context.py +++ b/python/flink_agents/api/runner_context.py @@ -26,6 +26,7 @@ from flink_agents.api.resource import Resource, ResourceType if TYPE_CHECKING: from flink_agents.api.memory_object import MemoryObject + class RunnerContext(ABC): """Abstract base class providing context for agent execution. @@ -54,6 +55,31 @@ class RunnerContext(ABC): The type of the resource. """ + @abstractmethod + def get_action_config(self) -> Dict[str, Any]: + """Get config of the action. + + Returns: + ------- + Dict[str, Any] + The configuration of the action executed. + """ + + @abstractmethod + def get_action_config_value(self, key: str) -> Any: + """Get config option value of the action. + + Parameters + ---------- + key: str + The key of the config option. + + Returns: + ------- + Any + The config option value. + """ + @abstractmethod def get_short_term_memory(self) -> "MemoryObject": """Get the short-term memory. diff --git a/python/flink_agents/plan/actions/action.py b/python/flink_agents/plan/actions/action.py index b58645a..aa2bec8 100644 --- a/python/flink_agents/plan/actions/action.py +++ b/python/flink_agents/plan/actions/action.py @@ -15,9 +15,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
################################################################################# -from typing import List, Union +import importlib +import inspect +from typing import Any, Dict, List, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, field_serializer, model_validator from flink_agents.api.events.event import Event from flink_agents.api.runner_context import RunnerContext @@ -44,14 +46,49 @@ class Action(BaseModel): # TODO: Raise a warning when the action has a return value, as it will be ignored. exec: Union[PythonFunction, JavaFunction] listen_event_types: List[str] + config: Optional[Dict[str, Any]] = None + + @field_serializer("config") + def __serialize_config(self, config: Dict[str, Any]) -> Union[Dict[str, Any], None]: + if config is None: + return config + data = {} + data["config_type"] = "python" + for name, value in config.items(): + if isinstance(value, BaseModel): + data[name] = ( + inspect.getmodule(value).__name__, + value.__class__.__name__, + value, + ) + else: + data[name] = value + return data + + @model_validator(mode="before") + def __custom_deserialize(self) -> "Action": + config = self["config"] + if config is not None and "config_type" in config: + self["config"].pop("config_type") + for name, value in config.items(): + try: + module = importlib.import_module(value[0]) + clazz = getattr(module, value[1]) + self["config"][name] = clazz.model_validate(value[2]) + except Exception: # noqa : PERF203 + self["config"][name] = value + return self def __init__( self, name: str, exec: Function, listen_event_types: List[str], + config: Optional[Dict[str, Any]] = None, ) -> None: """Action will check function signature when init.""" - super().__init__(name=name, exec=exec, listen_event_types=listen_event_types) + super().__init__( + name=name, exec=exec, listen_event_types=listen_event_types, config=config + ) # TODO: Update expected signature after import State and Context. 
self.exec.check_signature(Event, RunnerContext) diff --git a/python/flink_agents/plan/agent_plan.py b/python/flink_agents/plan/agent_plan.py index de411e4..0cb1a81 100644 --- a/python/flink_agents/plan/agent_plan.py +++ b/python/flink_agents/plan/agent_plan.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from pydantic import BaseModel, field_serializer, model_validator @@ -162,6 +162,38 @@ class AgentPlan(BaseModel): """ return [self.actions[name] for name in self.actions_by_event[event_type]] + def get_action_config(self, action_name: str) -> Dict[str, Any]: + """Get config of the action. + + Parameters + ---------- + action_name : str + The name of the action. + + Returns: + ------- + Dict[str, Any] + The config of action. + """ + return self.actions[action_name].config + + def get_action_config_value(self, action_name: str, key: str) -> Any: + """Get config of the action. + + Parameters + ---------- + action_name : str + The name of the action. + key : str + The name of the option. + + Returns: + ------- + Dict[str, Any] + The option value of the action config. + """ + return self.actions[action_name].config.get(key, None) + def get_resource(self, name: str, type: ResourceType) -> Resource: """Get resource from agent plan. 
@@ -218,7 +250,7 @@ def _get_actions(agent: Agent) -> List[Action]: ], ) ) - for name, action in agent._actions.items(): + for name, action in agent.actions.items(): actions.append( Action( name=name, @@ -227,6 +259,7 @@ def _get_actions(agent: Agent) -> List[Action]: f"{event_type.__module__}.{event_type.__name__}" for event_type in action[0] ], + config=action[2], ) ) return actions diff --git a/python/flink_agents/plan/tests/resources/action.json b/python/flink_agents/plan/tests/resources/action.json index da2cf91..dde2daf 100644 --- a/python/flink_agents/plan/tests/resources/action.json +++ b/python/flink_agents/plan/tests/resources/action.json @@ -1,11 +1,28 @@ { "name": "legal", "exec": { + "func_type": "PythonFunction", "module": "flink_agents.plan.tests.test_action", - "qualname": "legal_signature", - "func_type": "PythonFunction" + "qualname": "legal_signature" }, "listen_event_types": [ "flink_agents.api.events.event.InputEvent" - ] + ], + "config": { + "config_type": "python", + "output_schema": [ + "flink_agents.api.agents.react_agent", + "OutputSchema", + { + "output_schema": { + "names": [ + "result" + ], + "types": [ + "Integer" + ] + } + } + ] + } } \ No newline at end of file diff --git a/python/flink_agents/plan/tests/resources/agent_plan.json b/python/flink_agents/plan/tests/resources/agent_plan.json index dd2a382..14ff3d5 100644 --- a/python/flink_agents/plan/tests/resources/agent_plan.json +++ b/python/flink_agents/plan/tests/resources/agent_plan.json @@ -9,7 +9,8 @@ }, "listen_event_types": [ "flink_agents.api.events.event.InputEvent" - ] + ], + "config": null }, "second_action": { "name": "second_action", @@ -21,7 +22,8 @@ "listen_event_types": [ "flink_agents.api.events.event.InputEvent", "flink_agents.plan.tests.test_agent_plan.MyEvent" - ] + ], + "config": null }, "chat_model_action": { "name": "chat_model_action", @@ -33,7 +35,8 @@ "listen_event_types": [ "flink_agents.api.events.chat_event.ChatRequestEvent", 
"flink_agents.api.events.tool_event.ToolResponseEvent" - ] + ], + "config": null }, "tool_call_action": { "name": "tool_call_action", @@ -44,7 +47,8 @@ }, "listen_event_types": [ "flink_agents.api.events.tool_event.ToolRequestEvent" - ] + ], + "config": null } }, "actions_by_event": { diff --git a/python/flink_agents/plan/tests/test_action.py b/python/flink_agents/plan/tests/test_action.py index 8853d4d..4410dd4 100644 --- a/python/flink_agents/plan/tests/test_action.py +++ b/python/flink_agents/plan/tests/test_action.py @@ -19,7 +19,9 @@ import json from pathlib import Path import pytest +from pyflink.common.typeinfo import BasicTypeInfo, RowTypeInfo +from flink_agents.api.agents.react_agent import OutputSchema from flink_agents.api.events.event import InputEvent from flink_agents.api.runner_context import RunnerContext from flink_agents.plan.actions.action import Action @@ -58,6 +60,14 @@ def action() -> Action: # noqa: D103 name="legal", exec=func, listen_event_types=[f"{InputEvent.__module__}.{InputEvent.__qualname__}"], + config={ + "output_schema": OutputSchema( + output_schema=RowTypeInfo( + [BasicTypeInfo.INT_TYPE_INFO()], + ["result"], + ) + ) + }, ) @@ -65,7 +75,7 @@ current_dir = Path(__file__).parent def test_action_serialize(action: Action) -> None: # noqa: D103 - json_value = action.model_dump_json(serialize_as_any=True) + json_value = action.model_dump_json(serialize_as_any=True, indent=4) with Path.open(Path(f"{current_dir}/resources/action.json")) as f: expected_json = f.read() actual = json.loads(json_value) diff --git a/python/flink_agents/runtime/flink_runner_context.py b/python/flink_agents/runtime/flink_runner_context.py index f15d9d9..e520949 100644 --- a/python/flink_agents/runtime/flink_runner_context.py +++ b/python/flink_agents/runtime/flink_runner_context.py @@ -73,6 +73,20 @@ class FlinkRunnerContext(RunnerContext): def get_resource(self, name: str, type: ResourceType) -> Resource: return self.__agent_plan.get_resource(name, type) + 
@override + def get_action_config(self) -> Dict[str, Any]: + """Get config of the action.""" + return self.__agent_plan.get_action_config( + self._j_runner_context.getActionName() + ) + + @override + def get_action_config_value(self, key: str) -> Any: + """Get config of the action.""" + return self.__agent_plan.get_action_config_value( + action_name=self._j_runner_context.getActionName(), key=key + ) + @override def get_short_term_memory(self) -> FlinkMemoryObject: """Get the short-term memory object associated with this context. diff --git a/python/flink_agents/runtime/local_runner.py b/python/flink_agents/runtime/local_runner.py index a9ada75..6f36221 100644 --- a/python/flink_agents/runtime/local_runner.py +++ b/python/flink_agents/runtime/local_runner.py @@ -50,13 +50,14 @@ class LocalRunnerContext(RunnerContext): Unique identifier for the context, correspond to the key in flink KeyedStream. events : deque[Event] Queue of events to be processed in this context. - outputs : deque[Any] - Queue of outputs generated by agent execution. + action_name: str + Name of the action being executed. """ __agent_plan: AgentPlan __key: Any events: deque[Event] + action_name: str _store: dict[str, Any] _short_term_memory: MemoryObject _config: AgentConfiguration @@ -108,6 +109,18 @@ class LocalRunnerContext(RunnerContext): def get_resource(self, name: str, type: ResourceType) -> Resource: return self.__agent_plan.get_resource(name, type) + @override + def get_action_config(self) -> Dict[str, Any]: + """Get config of the action.""" + return self.__agent_plan.get_action_config(action_name=self.action_name) + + @override + def get_action_config_value(self, key: str) -> Any: + """Get config option value of the key.""" + return self.__agent_plan.get_action_config_value( + action_name=self.action_name, key=key + ) + @override def get_short_term_memory(self) -> MemoryObject: """Get the short-term memory object associated with this context. 
@@ -229,6 +242,7 @@ class LocalRunner(AgentRunner): event_type = f"{event.__class__.__module__}.{event.__class__.__name__}" for action in self.__agent_plan.get_actions(event_type): logger.info("key: %s, performing action: %s", key, action.name) + context.action_name = action.name func_result = action.exec(event, context) if isinstance(func_result, Generator): try: diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java index a6c4e7a..368d4ab 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java @@ -115,4 +115,8 @@ public class RunnerContextImpl implements RunnerContext { public ReadableConfiguration getConfig() { return agentPlan.getConfig(); } + + public String getActionName() { + return actionName; + } }
