GreatEugenius commented on code in PR #854: URL: https://github.com/apache/flink-agents/pull/854#discussion_r3465785900
########## plan/src/test/java/org/apache/flink/agents/plan/actions/ToolCallActionTest.java: ########## @@ -0,0 +1,376 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.plan.actions; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.annotation.ToolParam; +import org.apache.flink.agents.api.configuration.ReadableConfiguration; +import org.apache.flink.agents.api.context.DurableCallable; +import org.apache.flink.agents.api.context.MemoryObject; +import org.apache.flink.agents.api.context.MemoryRef; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.event.ToolRequestEvent; +import org.apache.flink.agents.api.event.ToolResponseEvent; +import org.apache.flink.agents.api.memory.BaseLongTermMemory; +import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.tools.Tool; +import org.apache.flink.agents.api.tools.ToolMetadata; +import org.apache.flink.agents.api.tools.ToolParameterInjection; +import org.apache.flink.agents.api.tools.ToolParameters; +import org.apache.flink.agents.api.tools.ToolResponse; +import org.apache.flink.agents.api.tools.ToolType; +import org.apache.flink.agents.plan.AgentConfiguration; +import org.apache.flink.agents.plan.JavaFunction; +import org.apache.flink.agents.plan.tools.FunctionTool; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +public class ToolCallActionTest { + + public static String queryOrder( + @ToolParam(name = "orderId") String orderId, + @ToolParam(name = "tenant_id") String tenantId) { + return tenantId + ":" + orderId; + } + + static class FailingTool extends Tool { + FailingTool() { + super(new ToolMetadata("queryOrder", "Query order.", "{}")); + } + + @Override + public ToolType getToolType() { + return ToolType.FUNCTION; + } + + @Override + public ToolResponse call(ToolParameters parameters) { + throw new RuntimeException("database timeout"); + } + } + + static class FakeRunnerContext implements RunnerContext { + private final List<Event> sentEvents = new ArrayList<>(); + private final AgentConfiguration config = + new AgentConfiguration(Map.of("tenant_id", "tenant-1")); + private ToolParameterInjection injection = ToolParameterInjection.fromConfig("tenant_id"); + private MemoryObject sensoryMemory; + private MemoryObject shortTermMemory; + + FakeRunnerContext withInjection(ToolParameterInjection injection) { + this.injection = injection; + return this; + } + + FakeRunnerContext withSensoryMemory(Map<String, Object> values) { + this.sensoryMemory = new FakeMemoryObject(values); + return this; + } + + FakeRunnerContext withShortTermMemory(Map<String, Object> values) { + this.shortTermMemory = new FakeMemoryObject(values); + return this; + } + + @Override + public void sendEvent(Event event) { + sentEvents.add(event); + } + + @Override + public MemoryObject getSensoryMemory() { + return sensoryMemory; + } + + @Override + public MemoryObject getShortTermMemory() { + return shortTermMemory; + } + + @Override + public BaseLongTermMemory getLongTermMemory() { + return null; + } + + @Override + public FlinkAgentsMetricGroup getAgentMetricGroup() { + return null; + } + + @Override + public FlinkAgentsMetricGroup getActionMetricGroup() { + return null; + } + + @Override + public Resource getResource(String name, ResourceType type) throws Exception { + assertThat(name).isEqualTo("queryOrder"); + assertThat(type).isEqualTo(ResourceType.TOOL); + return new FunctionTool( + new ToolMetadata("queryOrder", "Query order.", "{}"), + new JavaFunction( + ToolCallActionTest.class, + "queryOrder", + new Class[] {String.class, String.class}), + Map.of("tenant_id", injection)); + } + + @Override + public ReadableConfiguration getConfig() { + return config; + } + + @Override + public Map<String, Object> getActionConfig() { + return Map.of(); + } + + @Override + public Object getActionConfigValue(String key) { + return null; + } + + @Override + public <T> T durableExecute(DurableCallable<T> callable) throws Exception { + return callable.call(); + } + + @Override + public <T> T durableExecuteAsync(DurableCallable<T> callable) throws Exception { + return callable.call(); + } + + @Override + public void close() {} + } + + static class FakeMemoryObject implements MemoryObject { + private final Map<String, Object> values; + private final Object value; + + FakeMemoryObject(Map<String, Object> values) { + this(values, null); + } + + FakeMemoryObject(Map<String, Object> values, Object value) { + this.values = values; + this.value = value; + } + + @Override + public MemoryObject get(String path) throws Exception { + if (!values.containsKey(path)) { + throw new IllegalArgumentException("Missing path: " + path); + } + return new FakeMemoryObject(values, values.get(path)); + } + + @Override + public MemoryObject get(MemoryRef ref) throws Exception { + return get(ref.getPath()); + } + + @Override + public MemoryRef set(String path, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public MemoryObject newObject(String path, boolean overwrite) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isExist(String path) { + return values.containsKey(path); + } + + @Override + public List<String> getFieldNames() { + return new ArrayList<>(values.keySet()); + } + + @Override + public Map<String, Object> getFields() { + return Collections.unmodifiableMap(values); + } + + @Override + public Object getValue() { + return value; + } + + @Override + public boolean isNestedObject() { + return value == null; + } + } + + @Test + void processToolRequestInjectsArgsFromConfigBeforeDurableToolCall() throws Exception { + FakeRunnerContext ctx = new FakeRunnerContext(); + + Map<String, Object> arguments = new HashMap<>(Map.of("orderId", "order-1")); + ToolRequestEvent event = + new ToolRequestEvent( + "model", + List.of( + Map.of( + "id", + "call-1", + "type", + "function", + "function", + Map.of("name", "queryOrder", "arguments", arguments)))); + + ToolCallAction.processToolRequest(event, ctx); + + ToolResponseEvent response = ToolResponseEvent.fromEvent(ctx.sentEvents.get(0)); + assertThat(response.getResponses().get("call-1").getResult()).isEqualTo("tenant-1:order-1"); + assertThat(arguments).containsOnly(Map.entry("orderId", "order-1")); Review Comment: Good point. I added a Java test that covers the security-relevant case directly: the model supplies `tenant_id`, the injected source supplies another value, and the injected value wins. The test also checks that the original model argument map is not mutated. ########## python/flink_agents/plan/tests/resources/agent_plan.json: ########## @@ -102,30 +102,30 @@ "__resource_provider_type__": "PythonResourceProvider" } }, - "embedding_model": { - "mock_embedding": { - "name": "mock_embedding", - "type": "embedding_model", + "embedding_model_connection": { Review Comment: Agreed, this was unrelated regeneration churn. I reverted the resource-order-only change so the snapshot diff stays focused on the injected-args behavior. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
