weiqingy commented on code in PR #854:
URL: https://github.com/apache/flink-agents/pull/854#discussion_r3464165131
##########
python/flink_agents/plan/actions/tool_call_action.py:
##########
@@ -39,27 +46,80 @@ async def process_tool_request(event: Event, ctx: RunnerContext) -> None:
responses = {}
external_ids = {}
for tool_call in event.tool_calls:
- id = tool_call["id"]
+ call_id = tool_call["id"]
name = tool_call["function"]["name"]
kwargs = tool_call["function"]["arguments"]
- tool = ctx.get_resource(name, ResourceType.TOOL)
external_id = tool_call.get("original_id")
+
+ tool = ctx.get_resource(name, ResourceType.TOOL)
if not tool:
- response = f"Tool `{name}` does not exist."
+ responses[call_id] = f"Tool `{name}` does not exist."
+ external_ids[call_id] = external_id
+ continue
else:
- if tool_call_async:
- response = await ctx.durable_execute_async(tool.call, **kwargs)
- else:
- response = ctx.durable_execute(tool.call, **kwargs)
- responses[id] = response
- external_ids[id] = external_id
+ try:
+ call_kwargs = dict(kwargs or {})
+ # Framework-owned injected args must win over model-provided values so
+ # hidden context such as tenant ids cannot be spoofed by tool calls.
+ call_kwargs.update(_resolve_injected_arguments(tool, ctx))
+ if tool_call_async:
+ response = await ctx.durable_execute_async(
+ tool.call, **call_kwargs
+ )
+ else:
+ response = ctx.durable_execute(tool.call, **call_kwargs)
+ responses[call_id] = response
+ except Exception as e:
+ responses[call_id] = str(e)
Review Comment:
This new per-call `except Exception as e: responses[call_id] = str(e)`
routes an injection failure (missing config, missing memory path) into the
response channel. But Python `ToolResponseEvent` carries only `responses` —
there's no success/error channel like Java's — and the consumer at
`chat_model_action.py:439` does `content=str(response)` for every entry. So a
failed injection gets fed back to the LLM as an ordinary string,
indistinguishable from a real tool result. The Java side emits
`ToolResponse.error(...)` and populates the error map, so there the failure
stays structurally distinguishable.
To be clear, the *absence* of a Python error channel is pre-existing; what's
new here is routing injection failures into the plain `responses` map. Is
feeding the error text back to the model the intended contract, or should an
injection misconfiguration fail loudly, or at least be distinguishable from a
successful result?
CLAUDE.md asks for Java/Python parity, and this is where the two sides diverge
most.
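One way to keep the two cases apart on the Python side, as a minimal sketch under assumptions: `ToolCallOutcome`, `run_tool_call`, and the separate `errors` map are hypothetical names, not existing flink-agents APIs; `resolve_injected` stands in for `_resolve_injected_arguments(tool, ctx)` and `tool_fn` for `tool.call`. Failures land in `errors`, loosely mirroring Java's error map, so they never look like ordinary tool results.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Mapping, Optional


@dataclass
class ToolCallOutcome:
    """Hypothetical result holder with separate success and error maps."""

    responses: Dict[str, Any] = field(default_factory=dict)
    errors: Dict[str, str] = field(default_factory=dict)


def run_tool_call(
    call_id: str,
    tool_fn: Callable[..., Any],
    model_kwargs: Optional[Mapping[str, Any]],
    resolve_injected: Callable[[], Mapping[str, Any]],
    outcome: ToolCallOutcome,
) -> None:
    """Run one tool call, keeping failures out of the plain responses map."""
    try:
        injected = resolve_injected()
    except Exception as e:
        # Injection misconfiguration (e.g. missing config key or memory path).
        outcome.errors[call_id] = f"injection failed: {e}"
        return

    call_kwargs = dict(model_kwargs or {})
    # Framework-owned values win over model-provided ones.
    call_kwargs.update(injected)
    try:
        outcome.responses[call_id] = tool_fn(**call_kwargs)
    except Exception as e:
        # Tool execution failure: also recorded as an error, not as a result.
        outcome.errors[call_id] = f"tool failed: {e}"
```

Whether the consumer in `chat_model_action.py` then shows `errors` to the model, raises, or both is exactly the contract question above; the sketch only shows that a separate map makes that choice possible.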
##########
python/flink_agents/api/tools/tool_parameter_injection.py:
##########
@@ -0,0 +1,89 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+from enum import Enum
+
+from pydantic import BaseModel
+
+
+class ToolParameterSource(str, Enum):
+ """Source for a framework-injected tool parameter."""
+
+ CONFIG = "config"
+ SENSORY_MEMORY = "sensory_memory"
+ SHORT_TERM_MEMORY = "short_term_memory"
+
+ @classmethod
+ def _missing_(cls, value: object) -> "ToolParameterSource | None":
+ if not isinstance(value, str):
+ return None
+ normalized = value.lower()
+ if normalized == "action_config":
+ return cls.CONFIG
+ for source in cls:
+ if source.value == normalized:
+ return source
+ return None
+
+
+class InjectedArg(BaseModel):
+ """Declarative source binding for a framework-injected tool parameter."""
+
+ source: ToolParameterSource = ToolParameterSource.SENSORY_MEMORY
+ key: str | None = None
+
+ @staticmethod
+ def from_config(key: str) -> "InjectedArg":
+ """Create an injected argument read from global agent config."""
+ return InjectedArg(source=ToolParameterSource.CONFIG, key=key)
+
+ @staticmethod
+ def from_sensory_memory(path: str) -> "InjectedArg":
+ """Create an injected argument read from sensory memory."""
+ return InjectedArg(source=ToolParameterSource.SENSORY_MEMORY, key=path)
+
+ @staticmethod
+ def from_short_term_memory(path: str) -> "InjectedArg":
+ """Create an injected argument read from short-term memory."""
+ return InjectedArg(source=ToolParameterSource.SHORT_TERM_MEMORY, key=path)
+
+ def with_default_key(self, key: str) -> "InjectedArg":
+ """Use the tool parameter name when no explicit key is configured."""
+ if self.key:
+ return self
+ return InjectedArg(source=self.source, key=key)
+
+
+def normalize_injected_args(
Review Comment:
The reflection path (`@ToolParam(injected=true)`) is safe by construction,
but the declarative paths aren't: the YAML `injected_args` keys and the
`@tool(injected_args={...})` decorator keys accept arbitrary names with no
check that the key matches an actual function parameter. A typo like
`tenent_id` gets merged into the call kwargs and fails late, with a
`TypeError` ("got an unexpected keyword argument 'tenent_id'") in Python,
which the per-call `except` then swallows as described in the response-channel
point above. Was up-front validation at tool construction considered? A cheap
key-against-signature check would turn an opaque runtime tool failure into a
clear config error.
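A minimal sketch of that check, assuming the tool exposes its underlying Python function at construction time; `validate_injected_keys` is a hypothetical helper, not an existing flink-agents API. It would run once, wherever `@tool(injected_args=...)` or the YAML loader builds the tool.

```python
import inspect
from typing import Callable, Iterable

# Parameter kinds that can receive a value passed by keyword.
_KEYWORD_KINDS = (
    inspect.Parameter.POSITIONAL_OR_KEYWORD,
    inspect.Parameter.KEYWORD_ONLY,
)


def validate_injected_keys(fn: Callable, injected_keys: Iterable[str]) -> None:
    """Raise ValueError if an injected key names no keyword-passable parameter."""
    signature = inspect.signature(fn)
    params = signature.parameters
    # A **kwargs parameter accepts any name, so there is nothing to check.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return
    accepted = {name for name, p in params.items() if p.kind in _KEYWORD_KINDS}
    unknown = sorted(set(injected_keys) - accepted)
    if unknown:
        name = getattr(fn, "__name__", repr(fn))
        raise ValueError(
            f"injected_args {unknown} do not match any parameter of {name}{signature}"
        )
```

Positional-only parameters are excluded on purpose, since they could not receive an injected keyword argument either.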
##########
python/flink_agents/plan/tests/resources/agent_plan.json:
##########
@@ -102,30 +102,30 @@
"__resource_provider_type__": "PythonResourceProvider"
}
},
- "embedding_model": {
- "mock_embedding": {
- "name": "mock_embedding",
- "type": "embedding_model",
+ "embedding_model_connection": {
Review Comment:
This hunk swaps the positions of the `embedding_model` and
`embedding_model_connection` blocks; their contents are byte-identical before
and after, and no `injected_args` is added here — it looks like unrelated
regeneration churn. Worth reverting this file to keep the diff focused on
injection?
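A quick way to confirm the hunk is order-only churn, as a sketch: parse the base and head versions of `agent_plan.json` (for example the output of `git show <base-commit>:<path>` and `git show <head-commit>:<path>`) and compare the parsed values. `same_json_content` is a hypothetical helper.

```python
import json


def same_json_content(before: str, after: str) -> bool:
    """True if two JSON texts parse to equal values, ignoring object key order."""
    # json.loads builds dicts, and dict equality does not depend on key order;
    # arrays still compare element by element, in order.
    return json.loads(before) == json.loads(after)
```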
##########
api/src/main/java/org/apache/flink/agents/api/tools/ToolParameterSource.java:
##########
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.api.tools;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonValue;
+
+/** Source for a framework-injected tool parameter. */
+public enum ToolParameterSource {
+ CONFIG("config"),
+ SENSORY_MEMORY("sensory_memory"),
+ SHORT_TERM_MEMORY("short_term_memory");
+
+ private final String value;
+
+ ToolParameterSource(String value) {
+ this.value = value;
+ }
+
+ @JsonValue
+ public String getValue() {
+ return value;
+ }
+
+ @JsonCreator
+ public static ToolParameterSource fromValue(String value) {
+ if ("action_config".equals(value)) {
Review Comment:
This branch maps the string `"action_config"` → `CONFIG` (mirrored in Python
at `tool_parameter_injection.py:35`), but `action_config` is a brand-new name
with no test, no doc, and no YAML/JSON reference. It also collides with the
existing, unrelated per-action `action_config` concept
(`RunnerContext.action_config`, the per-action config dict — not the global
config this maps to). It reads as vestigial, likely a leftover from the removed
hook iteration. Is `action_config` meant to be a public source name? If not,
dropping both branches keeps the contract to the documented `config` source.
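For the Python side, dropping the alias would leave `_missing_` as a plain case-insensitive lookup over the documented values. A sketch of what that might look like (the Java `fromValue` would drop its `"action_config"` check the same way):

```python
from enum import Enum
from typing import Optional


class ToolParameterSource(str, Enum):
    """Source for a framework-injected tool parameter."""

    CONFIG = "config"
    SENSORY_MEMORY = "sensory_memory"
    SHORT_TERM_MEMORY = "short_term_memory"

    @classmethod
    def _missing_(cls, value: object) -> Optional["ToolParameterSource"]:
        # Case-insensitive match on documented values only; no "action_config".
        if not isinstance(value, str):
            return None
        normalized = value.lower()
        for source in cls:
            if source.value == normalized:
                return source
        # Returning None makes Enum raise ValueError for unknown names.
        return None
```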
##########
plan/src/test/java/org/apache/flink/agents/plan/actions/ToolCallActionTest.java:
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.plan.actions;
+
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.api.annotation.ToolParam;
+import org.apache.flink.agents.api.configuration.ReadableConfiguration;
+import org.apache.flink.agents.api.context.DurableCallable;
+import org.apache.flink.agents.api.context.MemoryObject;
+import org.apache.flink.agents.api.context.MemoryRef;
+import org.apache.flink.agents.api.context.RunnerContext;
+import org.apache.flink.agents.api.event.ToolRequestEvent;
+import org.apache.flink.agents.api.event.ToolResponseEvent;
+import org.apache.flink.agents.api.memory.BaseLongTermMemory;
+import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
+import org.apache.flink.agents.api.resource.Resource;
+import org.apache.flink.agents.api.resource.ResourceType;
+import org.apache.flink.agents.api.tools.Tool;
+import org.apache.flink.agents.api.tools.ToolMetadata;
+import org.apache.flink.agents.api.tools.ToolParameterInjection;
+import org.apache.flink.agents.api.tools.ToolParameters;
+import org.apache.flink.agents.api.tools.ToolResponse;
+import org.apache.flink.agents.api.tools.ToolType;
+import org.apache.flink.agents.plan.AgentConfiguration;
+import org.apache.flink.agents.plan.JavaFunction;
+import org.apache.flink.agents.plan.tools.FunctionTool;
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class ToolCallActionTest {
+
+ public static String queryOrder(
+ @ToolParam(name = "orderId") String orderId,
+ @ToolParam(name = "tenant_id") String tenantId) {
+ return tenantId + ":" + orderId;
+ }
+
+ static class FailingTool extends Tool {
+ FailingTool() {
+ super(new ToolMetadata("queryOrder", "Query order.", "{}"));
+ }
+
+ @Override
+ public ToolType getToolType() {
+ return ToolType.FUNCTION;
+ }
+
+ @Override
+ public ToolResponse call(ToolParameters parameters) {
+ throw new RuntimeException("database timeout");
+ }
+ }
+
+ static class FakeRunnerContext implements RunnerContext {
+ private final List<Event> sentEvents = new ArrayList<>();
+ private final AgentConfiguration config =
+ new AgentConfiguration(Map.of("tenant_id", "tenant-1"));
+ private ToolParameterInjection injection = ToolParameterInjection.fromConfig("tenant_id");
+ private MemoryObject sensoryMemory;
+ private MemoryObject shortTermMemory;
+
+ FakeRunnerContext withInjection(ToolParameterInjection injection) {
+ this.injection = injection;
+ return this;
+ }
+
+ FakeRunnerContext withSensoryMemory(Map<String, Object> values) {
+ this.sensoryMemory = new FakeMemoryObject(values);
+ return this;
+ }
+
+ FakeRunnerContext withShortTermMemory(Map<String, Object> values) {
+ this.shortTermMemory = new FakeMemoryObject(values);
+ return this;
+ }
+
+ @Override
+ public void sendEvent(Event event) {
+ sentEvents.add(event);
+ }
+
+ @Override
+ public MemoryObject getSensoryMemory() {
+ return sensoryMemory;
+ }
+
+ @Override
+ public MemoryObject getShortTermMemory() {
+ return shortTermMemory;
+ }
+
+ @Override
+ public BaseLongTermMemory getLongTermMemory() {
+ return null;
+ }
+
+ @Override
+ public FlinkAgentsMetricGroup getAgentMetricGroup() {
+ return null;
+ }
+
+ @Override
+ public FlinkAgentsMetricGroup getActionMetricGroup() {
+ return null;
+ }
+
+ @Override
+ public Resource getResource(String name, ResourceType type) throws Exception {
+ assertThat(name).isEqualTo("queryOrder");
+ assertThat(type).isEqualTo(ResourceType.TOOL);
+ return new FunctionTool(
+ new ToolMetadata("queryOrder", "Query order.", "{}"),
+ new JavaFunction(
+ ToolCallActionTest.class,
+ "queryOrder",
+ new Class[] {String.class, String.class}),
+ Map.of("tenant_id", injection));
+ }
+
+ @Override
+ public ReadableConfiguration getConfig() {
+ return config;
+ }
+
+ @Override
+ public Map<String, Object> getActionConfig() {
+ return Map.of();
+ }
+
+ @Override
+ public Object getActionConfigValue(String key) {
+ return null;
+ }
+
+ @Override
+ public <T> T durableExecute(DurableCallable<T> callable) throws Exception {
+ return callable.call();
+ }
+
+ @Override
+ public <T> T durableExecuteAsync(DurableCallable<T> callable) throws Exception {
+ return callable.call();
+ }
+
+ @Override
+ public void close() {}
+ }
+
+ static class FakeMemoryObject implements MemoryObject {
+ private final Map<String, Object> values;
+ private final Object value;
+
+ FakeMemoryObject(Map<String, Object> values) {
+ this(values, null);
+ }
+
+ FakeMemoryObject(Map<String, Object> values, Object value) {
+ this.values = values;
+ this.value = value;
+ }
+
+ @Override
+ public MemoryObject get(String path) throws Exception {
+ if (!values.containsKey(path)) {
+ throw new IllegalArgumentException("Missing path: " + path);
+ }
+ return new FakeMemoryObject(values, values.get(path));
+ }
+
+ @Override
+ public MemoryObject get(MemoryRef ref) throws Exception {
+ return get(ref.getPath());
+ }
+
+ @Override
+ public MemoryRef set(String path, Object value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public MemoryObject newObject(String path, boolean overwrite) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean isExist(String path) {
+ return values.containsKey(path);
+ }
+
+ @Override
+ public List<String> getFieldNames() {
+ return new ArrayList<>(values.keySet());
+ }
+
+ @Override
+ public Map<String, Object> getFields() {
+ return Collections.unmodifiableMap(values);
+ }
+
+ @Override
+ public Object getValue() {
+ return value;
+ }
+
+ @Override
+ public boolean isNestedObject() {
+ return value == null;
+ }
+ }
+
+ @Test
+ void processToolRequestInjectsArgsFromConfigBeforeDurableToolCall() throws Exception {
+ FakeRunnerContext ctx = new FakeRunnerContext();
+
+ Map<String, Object> arguments = new HashMap<>(Map.of("orderId", "order-1"));
+ ToolRequestEvent event =
+ new ToolRequestEvent(
+ "model",
+ List.of(
+ Map.of(
+ "id",
+ "call-1",
+ "type",
+ "function",
+ "function",
+ Map.of("name", "queryOrder", "arguments", arguments))));
+
+ ToolCallAction.processToolRequest(event, ctx);
+
+ ToolResponseEvent response = ToolResponseEvent.fromEvent(ctx.sentEvents.get(0));
+ assertThat(response.getResponses().get("call-1").getResult()).isEqualTo("tenant-1:order-1");
+ assertThat(arguments).containsOnly(Map.entry("orderId", "order-1"));
Review Comment:
The comment in `ToolCallAction` says injected args must win over
model-provided values "so hidden context such as tenant ids cannot be spoofed
by a tool call payload" — but this config-injection test passes only
non-conflicting args (`orderId`), so it would still pass even if Java used
`putIfAbsent` instead of `putAll`. The security-relevant override is verified
only on the Python side
(`test_tool_call_action_injected_arg_overrides_model_argument`). If useful,
add a Java case where the model also supplies `tenant_id` and assert that the
injected value wins, so the Java override is actually pinned.
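Why the current test cannot tell the two apart, as a minimal Python sketch with hypothetical helper names (the Java code itself would use `putAll` vs `putIfAbsent` on the arguments map): both strategies agree whenever the model omits `tenant_id` and diverge only when it supplies one.

```python
from typing import Any, Dict, Mapping


def merge_override(model_args: Mapping[str, Any], injected: Mapping[str, Any]) -> Dict[str, Any]:
    """Injected values win, like dict.update in Python or putAll in Java."""
    merged = dict(model_args)
    merged.update(injected)
    return merged


def merge_if_absent(model_args: Mapping[str, Any], injected: Mapping[str, Any]) -> Dict[str, Any]:
    """Model values win, like calling putIfAbsent per injected key in Java."""
    merged = dict(model_args)
    for key, value in injected.items():
        merged.setdefault(key, value)
    return merged
```

The Java case would then build `arguments` with both `orderId` and a spoofed `tenant_id` and assert that the result is still `tenant-1:order-1`.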