This is an automated email from the ASF dual-hosted git repository. wenjin272 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 4e4689fc5c515f71d1c3f7127f1896337e2b5dba Author: WenjinXie <[email protected]> AuthorDate: Tue May 19 14:59:23 2026 +0800 [api][plan][runtime] Cross-language Function descriptors and FunctionTool Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]> --- .../org/apache/flink/agents/api/agents/Agent.java | 34 +++-- .../FunctionTool.java => function/Function.java} | 34 ++--- .../flink/agents/api/function/JavaFunction.java | 113 +++++++++++++++ .../flink/agents/api/function/PythonFunction.java | 76 ++++++++++ .../api/resource/python/PythonResourceAdapter.java | 32 +++++ .../flink/agents/api/tools/FunctionTool.java | 29 ++-- .../org/apache/flink/agents/api/tools/Tool.java | 13 +- .../agents/api/agents/AgentAddActionTest.java | 79 +++++++++++ .../flink/agents/api/function/FunctionTest.java} | 30 ++-- .../agents/api/function/JavaFunctionTest.java | 78 ++++++++++ .../agents/api/function/PythonFunctionTest.java | 51 +++++++ .../flink/agents/api/tools/FunctionToolTest.java | 61 ++++++++ .../org/apache/flink/agents/plan/AgentPlan.java | 157 ++++++++++++++++++--- .../flink/agents/plan/tools/FunctionTool.java | 101 +++++++++---- .../tools/FunctionToolSetPythonAdapterTest.java | 84 +++++++++++ python/flink_agents/api/tools/utils.py | 9 +- python/flink_agents/runtime/python_java_utils.py | 43 ++++++ .../apache/flink/agents/runtime/ResourceCache.java | 6 + .../python/utils/PythonResourceAdapterImpl.java | 23 +++ .../flink/agents/runtime/ResourceCacheTest.java | 10 ++ 20 files changed, 959 insertions(+), 104 deletions(-) diff --git a/api/src/main/java/org/apache/flink/agents/api/agents/Agent.java b/api/src/main/java/org/apache/flink/agents/api/agents/Agent.java index 230e5a7b..f7337248 100644 --- a/api/src/main/java/org/apache/flink/agents/api/agents/Agent.java +++ b/api/src/main/java/org/apache/flink/agents/api/agents/Agent.java @@ -18,6 +18,8 @@ package org.apache.flink.agents.api.agents; +import org.apache.flink.agents.api.function.Function; +import org.apache.flink.agents.api.function.JavaFunction; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.SerializableResource; @@ -31,7 +33,7 @@ import java.util.Map; /** Base class for defining agent logic. */ public class Agent { - private final Map<String, Tuple3<String[], Method, Map<String, Object>>> actions; + private final Map<String, Tuple3<String[], Function, Map<String, Object>>> actions; private final Map<ResourceType, Map<String, Object>> resources; @@ -43,7 +45,7 @@ public class Agent { this.actions = new HashMap<>(); } - public Map<String, Tuple3<String[], Method, Map<String, Object>>> getActions() { + public Map<String, Tuple3<String[], Function, Map<String, Object>>> getActions() { return actions; } @@ -60,12 +62,7 @@ public class Agent { */ public Agent addAction( String[] eventTypes, Method method, @Nullable Map<String, Object> config) { - String name = method.getName(); - if (actions.containsKey(name)) { - throw new IllegalArgumentException(String.format("Action %s already defined.", name)); - } - actions.put(name, new Tuple3<>(eventTypes, method, config)); - return this; + return addAction(method.getName(), eventTypes, JavaFunction.fromMethod(method), config); } /** @@ -78,6 +75,27 @@ public class Agent { return addAction(eventTypes, method, null); } + /** + * Add action to agent. + * + * @param name The action name. Must be unique within this agent. + * @param eventTypes The event type strings this action listens to. + * @param function The api-layer function descriptor; will be promoted to a plan-layer + * executable at {@code AgentPlan} construction. + * @param config Optional config for this action. + */ + public Agent addAction( + String name, + String[] eventTypes, + Function function, + @Nullable Map<String, Object> config) { + if (actions.containsKey(name)) { + throw new IllegalArgumentException(String.format("Action %s already defined.", name)); + } + actions.put(name, new Tuple3<>(eventTypes, function, config)); + return this; + } + public void addResourcesIfAbsent(Map<ResourceType, Map<String, Object>> resources) { for (ResourceType type : resources.keySet()) { Map<String, Object> typedResources = resources.get(type); diff --git a/api/src/main/java/org/apache/flink/agents/api/tools/FunctionTool.java b/api/src/main/java/org/apache/flink/agents/api/function/Function.java similarity index 57% copy from api/src/main/java/org/apache/flink/agents/api/tools/FunctionTool.java copy to api/src/main/java/org/apache/flink/agents/api/function/Function.java index 3ccba674..097a78ee 100644 --- a/api/src/main/java/org/apache/flink/agents/api/tools/FunctionTool.java +++ b/api/src/main/java/org/apache/flink/agents/api/function/Function.java @@ -15,28 +15,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.flink.agents.api.function; -package org.apache.flink.agents.api.tools; - -import org.apache.flink.agents.api.resource.ResourceType; -import org.apache.flink.agents.api.resource.SerializableResource; - -import java.lang.reflect.Method; - -/** Tool keeps a method, will be converted to tool after compile. */ -public class FunctionTool extends SerializableResource { - private final Method method; - - public FunctionTool(Method method) { - this.method = method; - } - - @Override - public ResourceType getResourceType() { - return ResourceType.TOOL; - } - - public Method getMethod() { - return method; - } -} +/** + * Pure-data marker for user-defined function descriptors carried on the api layer. + * + * <p>Implementations describe <em>which</em> function ({@link PythonFunction}, {@link + * JavaFunction}) but do not execute it. The plan-layer twins ({@code + * org.apache.flink.agents.plan.Function} and friends) own execution; the conversion from api → plan + * happens during {@code AgentPlan} construction. + */ +public interface Function {} diff --git a/api/src/main/java/org/apache/flink/agents/api/function/JavaFunction.java b/api/src/main/java/org/apache/flink/agents/api/function/JavaFunction.java new file mode 100644 index 00000000..2e438d9c --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/function/JavaFunction.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.api.function; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.Serializable; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +/** + * Pure-data descriptor for a Java method, identified by declaring class FQN, method name, and + * parameter types as strings. + * + * <p>Parameter types are strings — JVM primitive names ({@code int}, {@code long}, {@code boolean}, + * …) or fully-qualified reference type names ({@code java.lang.String}, {@code java.util.List}). No + * generic parameters. The wire form keeps the descriptor pure data; class resolution is deferred to + * the plan-layer twin. + */ +public final class JavaFunction implements Function, Serializable { + + private static final String FIELD_QUAL_NAME = "qualName"; + private static final String FIELD_METHOD_NAME = "methodName"; + private static final String FIELD_PARAMETER_TYPES = "parameterTypes"; + + @JsonProperty(FIELD_QUAL_NAME) + private final String qualName; + + @JsonProperty(FIELD_METHOD_NAME) + private final String methodName; + + @JsonProperty(FIELD_PARAMETER_TYPES) + private final List<String> parameterTypes; + + @JsonCreator + public JavaFunction( + @JsonProperty(FIELD_QUAL_NAME) String qualName, + @JsonProperty(FIELD_METHOD_NAME) String methodName, + @JsonProperty(FIELD_PARAMETER_TYPES) List<String> parameterTypes) { + this.qualName = Objects.requireNonNull(qualName, "qualName"); + this.methodName = Objects.requireNonNull(methodName, "methodName"); + this.parameterTypes = + parameterTypes == null + ? Collections.emptyList() + : Collections.unmodifiableList(new ArrayList<>(parameterTypes)); + } + + /** + * Build a descriptor from a reflected {@link Method}. Each parameter type is captured via + * {@link Class#getName()} — the same form {@link Class#forName(String)} accepts when the api + * descriptor is later promoted to its plan-layer twin. For primitives this is the keyword + * ({@code int}, {@code long}); for reference types the fully-qualified name; for array types + * the JVM-internal descriptor ({@code [I}, {@code [Ljava.lang.String;}). + */ + public static JavaFunction fromMethod(Method method) { + List<String> params = new ArrayList<>(method.getParameterCount()); + for (Class<?> p : method.getParameterTypes()) { + params.add(p.getName()); + } + return new JavaFunction(method.getDeclaringClass().getName(), method.getName(), params); + } + + public String getQualName() { + return qualName; + } + + public String getMethodName() { + return methodName; + } + + public List<String> getParameterTypes() { + return parameterTypes; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof JavaFunction)) return false; + JavaFunction that = (JavaFunction) o; + return qualName.equals(that.qualName) + && methodName.equals(that.methodName) + && parameterTypes.equals(that.parameterTypes); + } + + @Override + public int hashCode() { + return Objects.hash(qualName, methodName, parameterTypes); + } + + @Override + public String toString() { + return "JavaFunction{" + qualName + "#" + methodName + parameterTypes + "}"; + } +} diff --git a/api/src/main/java/org/apache/flink/agents/api/function/PythonFunction.java b/api/src/main/java/org/apache/flink/agents/api/function/PythonFunction.java new file mode 100644 index 00000000..e249a693 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/function/PythonFunction.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.api.function; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.Serializable; +import java.util.Objects; + +/** + * Pure-data descriptor for a Python callable, identified by its module and qualified name. + * + * <p>Carries no execution behavior — the plan-layer {@code + * org.apache.flink.agents.plan.PythonFunction} owns invocation via the Pemja interpreter. + */ +public final class PythonFunction implements Function, Serializable { + + private static final String FIELD_MODULE = "module"; + private static final String FIELD_QUAL_NAME = "qualName"; + + @JsonProperty(FIELD_MODULE) + private final String module; + + @JsonProperty(FIELD_QUAL_NAME) + private final String qualName; + + @JsonCreator + public PythonFunction( + @JsonProperty(FIELD_MODULE) String module, + @JsonProperty(FIELD_QUAL_NAME) String qualName) { + this.module = Objects.requireNonNull(module, "module"); + this.qualName = Objects.requireNonNull(qualName, "qualName"); + } + + public String getModule() { + return module; + } + + public String getQualName() { + return qualName; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof PythonFunction)) return false; + PythonFunction that = (PythonFunction) o; + return module.equals(that.module) && qualName.equals(that.qualName); + } + + @Override + public int hashCode() { + return Objects.hash(module, qualName); + } + + @Override + public String toString() { + return "PythonFunction{" + module + ":" + qualName + "}"; + } +} diff --git a/api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceAdapter.java b/api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceAdapter.java index a28006f8..03eb8248 100644 --- a/api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceAdapter.java +++ b/api/src/main/java/org/apache/flink/agents/api/resource/python/PythonResourceAdapter.java @@ -128,4 +128,36 @@ public interface PythonResourceAdapter { * @return the result of the method invocation */ Object invoke(String name, Object... args); + + /** + * Look up tool metadata for a Python function across the JVM→Python bridge. + * + * <p>The Java side asks the Python side to introspect a callable identified by {@code module} + + * {@code qualName}, and returns a flat {@code Map<String, String>} with keys {@code "name"}, + * {@code "description"}, and {@code "inputSchema"} (a JSON schema string compatible with {@code + * ToolMetadata.inputSchema}). + * + * <p>The return shape is intentionally flat — pemja can SIGSEGV when returning arbitrary Python + * objects to Java on non-main-interpreter threads. + * + * @param module the Python module containing the callable + * @param qualName the qualified name of the callable inside the module (e.g. {@code "fn"} or + * {@code "MyClass.method"}) + * @return flat map with keys "name", "description", "inputSchema" + */ + Map<String, String> getPythonToolMetadata(String module, String qualName); + + /** + * Invoke a Python callable as a tool, passing keyword arguments. Used when a Java chat model's + * tool list contains a {@code plan.FunctionTool} whose function descriptor is a {@code + * PythonFunction}: instead of routing the invocation through Java reflection, dispatch it + * across the bridge so the underlying Python function runs in the Pemja interpreter. + * + * @param module the Python module containing the callable + * @param qualName the qualified name of the callable inside the module + * @param kwargs keyword arguments to pass to the callable; LLM tool calls always arrive as + * keyword arguments + * @return the raw return value from the Python callable + */ + Object invokePythonTool(String module, String qualName, Map<String, Object> kwargs); } diff --git a/api/src/main/java/org/apache/flink/agents/api/tools/FunctionTool.java b/api/src/main/java/org/apache/flink/agents/api/tools/FunctionTool.java index 3ccba674..cf9dbd8a 100644 --- a/api/src/main/java/org/apache/flink/agents/api/tools/FunctionTool.java +++ b/api/src/main/java/org/apache/flink/agents/api/tools/FunctionTool.java @@ -18,25 +18,38 @@ package org.apache.flink.agents.api.tools; +import org.apache.flink.agents.api.function.Function; +import org.apache.flink.agents.api.function.JavaFunction; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.SerializableResource; import java.lang.reflect.Method; +import java.util.Objects; -/** Tool keeps a method, will be converted to tool after compile. */ +/** + * Pure-data tool descriptor: carries an {@link Function} reference. Used at agent-construction + * time; compiled to the plan-layer executable {@code plan.tools.FunctionTool} when the agent + * becomes an {@code AgentPlan}. + */ public class FunctionTool extends SerializableResource { - private final Method method; - public FunctionTool(Method method) { - this.method = method; + private final Function func; + + public FunctionTool(Function func) { + this.func = Objects.requireNonNull(func, "func"); + } + + /** Convenience factory: derive a {@link JavaFunction} from a reflected method. */ + public static FunctionTool fromMethod(Method method) { + return new FunctionTool(JavaFunction.fromMethod(method)); + } + + public Function getFunc() { + return func; } @Override public ResourceType getResourceType() { return ResourceType.TOOL; } - - public Method getMethod() { - return method; - } } diff --git a/api/src/main/java/org/apache/flink/agents/api/tools/Tool.java b/api/src/main/java/org/apache/flink/agents/api/tools/Tool.java index a0238400..11f0356d 100644 --- a/api/src/main/java/org/apache/flink/agents/api/tools/Tool.java +++ b/api/src/main/java/org/apache/flink/agents/api/tools/Tool.java @@ -30,12 +30,21 @@ import java.lang.reflect.Method; */ public abstract class Tool extends SerializableResource { - protected final ToolMetadata metadata; + protected ToolMetadata metadata; protected Tool(ToolMetadata metadata) { this.metadata = java.util.Objects.requireNonNull(metadata, "metadata cannot be null"); } + /** + * Replace this tool's metadata. Intended for subclasses that derive metadata lazily once a + * runtime bridge becomes available (e.g. {@code FunctionTool} backed by a {@code + * PythonFunction} refreshing placeholder metadata via the JVM→Python adapter). + */ + protected void setMetadata(ToolMetadata metadata) { + this.metadata = java.util.Objects.requireNonNull(metadata, "metadata cannot be null"); + } + /** Get the metadata of this tool. */ public final ToolMetadata getMetadata() { return metadata; @@ -68,6 +77,6 @@ public abstract class Tool extends SerializableResource { /** Get tool keeps a method. */ public static FunctionTool fromMethod(Method method) { - return new FunctionTool(method); + return FunctionTool.fromMethod(method); } } diff --git a/api/src/test/java/org/apache/flink/agents/api/agents/AgentAddActionTest.java b/api/src/test/java/org/apache/flink/agents/api/agents/AgentAddActionTest.java new file mode 100644 index 00000000..fd906bd5 --- /dev/null +++ b/api/src/test/java/org/apache/flink/agents/api/agents/AgentAddActionTest.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.agents; + +import org.apache.flink.agents.api.function.Function; +import org.apache.flink.agents.api.function.JavaFunction; +import org.apache.flink.agents.api.function.PythonFunction; +import org.apache.flink.api.java.tuple.Tuple3; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Method; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class AgentAddActionTest { + + public static void onInput(Object event, Object ctx) {} + + @Test + void newFunctionOverloadStoresApiFunction() { + Agent agent = new Agent(); + PythonFunction pf = new PythonFunction("pkg", "fn"); + agent.addAction("act", new String[] {"_input_event"}, pf, Map.of("k", "v")); + + Map<String, Tuple3<String[], Function, Map<String, Object>>> actions = agent.getActions(); + Tuple3<String[], Function, Map<String, Object>> entry = actions.get("act"); + assertThat(entry).isNotNull(); + assertThat(entry.f0).containsExactly("_input_event"); + assertThat(entry.f1).isSameAs(pf); + assertThat(entry.f2).containsEntry("k", "v"); + } + + @Test + void methodOverloadDelegatesToFunctionAsJavaFunction() throws Exception { + Method m = + AgentAddActionTest.class.getDeclaredMethod("onInput", Object.class, Object.class); + Agent agent = new Agent(); + agent.addAction(new String[] {"_input_event"}, m); + + Tuple3<String[], Function, Map<String, Object>> entry = agent.getActions().get("onInput"); + assertThat(entry.f1).isInstanceOf(JavaFunction.class); + JavaFunction jf = (JavaFunction) entry.f1; + assertThat(jf.getQualName()).isEqualTo(AgentAddActionTest.class.getName()); + assertThat(jf.getMethodName()).isEqualTo("onInput"); + } + + @Test + void duplicateNameRejected() { + Agent agent = new Agent(); + agent.addAction("act", new String[] {"_input_event"}, new PythonFunction("p", "q"), null); + assertThatThrownBy( + () -> + agent.addAction( + "act", + new String[] {"_input_event"}, + new PythonFunction("p", "q"), + null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("act"); + } +} diff --git a/api/src/main/java/org/apache/flink/agents/api/tools/FunctionTool.java b/api/src/test/java/org/apache/flink/agents/api/function/FunctionTest.java similarity index 58% copy from api/src/main/java/org/apache/flink/agents/api/tools/FunctionTool.java copy to api/src/test/java/org/apache/flink/agents/api/function/FunctionTest.java index 3ccba674..b10314ca 100644 --- a/api/src/main/java/org/apache/flink/agents/api/tools/FunctionTool.java +++ b/api/src/test/java/org/apache/flink/agents/api/function/FunctionTest.java @@ -15,28 +15,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.flink.agents.api.function; -package org.apache.flink.agents.api.tools; +import org.junit.jupiter.api.Test; -import org.apache.flink.agents.api.resource.ResourceType; -import org.apache.flink.agents.api.resource.SerializableResource; +import static org.assertj.core.api.Assertions.assertThat; -import java.lang.reflect.Method; +class FunctionTest { -/** Tool keeps a method, will be converted to tool after compile. */ -public class FunctionTool extends SerializableResource { - private final Method method; - - public FunctionTool(Method method) { - this.method = method; - } - - @Override - public ResourceType getResourceType() { - return ResourceType.TOOL; - } - - public Method getMethod() { - return method; + /** Function is a pure marker; the concrete data subtypes implement it. */ + @Test + void pythonAndJavaFunctionImplementMarker() { + Function py = new PythonFunction("pkg.mod", "fn"); + Function ja = new JavaFunction("com.example.X", "m", java.util.List.of()); + assertThat(py).isInstanceOf(Function.class); + assertThat(ja).isInstanceOf(Function.class); } } diff --git a/api/src/test/java/org/apache/flink/agents/api/function/JavaFunctionTest.java b/api/src/test/java/org/apache/flink/agents/api/function/JavaFunctionTest.java new file mode 100644 index 00000000..b79eaa79 --- /dev/null +++ b/api/src/test/java/org/apache/flink/agents/api/function/JavaFunctionTest.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.api.function; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Method; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +class JavaFunctionTest { + + public static int add(int a, int b) { + return a + b; + } + + @Test + void exposesAllFields() { + JavaFunction fn = + new JavaFunction("com.example.X", "add", List.of("int", "java.lang.String")); + assertThat(fn.getQualName()).isEqualTo("com.example.X"); + assertThat(fn.getMethodName()).isEqualTo("add"); + assertThat(fn.getParameterTypes()).containsExactly("int", "java.lang.String"); + } + + @Test + void fromMethodCapturesDeclaringClassAndPrimitiveAndReferenceParams() throws Exception { + Method m = JavaFunctionTest.class.getDeclaredMethod("add", int.class, int.class); + JavaFunction fn = JavaFunction.fromMethod(m); + assertThat(fn.getQualName()) + .isEqualTo("org.apache.flink.agents.api.function.JavaFunctionTest"); + assertThat(fn.getMethodName()).isEqualTo("add"); + assertThat(fn.getParameterTypes()).containsExactly("int", "int"); + } + + @Test + void parameterTypesListIsDefensiveCopy() { + var src = new java.util.ArrayList<>(List.of("int")); + JavaFunction fn = new JavaFunction("X", "m", src); + src.add("mutated"); + assertThat(fn.getParameterTypes()).containsExactly("int"); + } + + @Test + void equalsBasedOnAllFields() { + JavaFunction a = new JavaFunction("X", "m", List.of("int")); + JavaFunction b = new JavaFunction("X", "m", List.of("int")); + JavaFunction c = new JavaFunction("X", "m", List.of("long")); + assertThat(a).isEqualTo(b).isNotEqualTo(c); + assertThat(a).hasSameHashCodeAs(b); + } + + @Test + void jacksonRoundTrip() throws Exception { + ObjectMapper m = new ObjectMapper(); + JavaFunction fn = new JavaFunction("com.example.X", "m", List.of("int")); + String json = m.writeValueAsString(fn); + JavaFunction back = m.readValue(json, JavaFunction.class); + assertThat(back).isEqualTo(fn); + } +} diff --git a/api/src/test/java/org/apache/flink/agents/api/function/PythonFunctionTest.java b/api/src/test/java/org/apache/flink/agents/api/function/PythonFunctionTest.java new file mode 100644 index 00000000..d2ae089c --- /dev/null +++ b/api/src/test/java/org/apache/flink/agents/api/function/PythonFunctionTest.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.api.function; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +class PythonFunctionTest { + + @Test + void exposesModuleAndQualName() { + PythonFunction fn = new PythonFunction("pkg.mod", "MyClass.method"); + assertThat(fn.getModule()).isEqualTo("pkg.mod"); + assertThat(fn.getQualName()).isEqualTo("MyClass.method"); + } + + @Test + void equalsBasedOnModuleAndQualName() { + PythonFunction a = new PythonFunction("m", "q"); + PythonFunction b = new PythonFunction("m", "q"); + PythonFunction c = new PythonFunction("m", "other"); + assertThat(a).isEqualTo(b).isNotEqualTo(c); + assertThat(a).hasSameHashCodeAs(b); + } + + @Test + void jacksonRoundTrip() throws Exception { + ObjectMapper m = new ObjectMapper(); + PythonFunction fn = new PythonFunction("pkg.mod", "fn"); + String json = m.writeValueAsString(fn); + PythonFunction back = m.readValue(json, PythonFunction.class); + assertThat(back).isEqualTo(fn); + } +} diff --git a/api/src/test/java/org/apache/flink/agents/api/tools/FunctionToolTest.java b/api/src/test/java/org/apache/flink/agents/api/tools/FunctionToolTest.java new file mode 100644 index 00000000..77e7eb99 --- /dev/null +++ b/api/src/test/java/org/apache/flink/agents/api/tools/FunctionToolTest.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.tools; + +import org.apache.flink.agents.api.function.JavaFunction; +import org.apache.flink.agents.api.function.PythonFunction; +import org.apache.flink.agents.api.resource.ResourceType; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Method; + +import static org.assertj.core.api.Assertions.assertThat; + +class FunctionToolTest { + + public static int demo(int a) { + return a; + } + + @Test + void holdsPythonFunction() { + PythonFunction pf = new PythonFunction("pkg.mod", "fn"); + FunctionTool tool = new FunctionTool(pf); + assertThat(tool.getFunc()).isSameAs(pf); + assertThat(tool.getResourceType()).isEqualTo(ResourceType.TOOL); + } + + @Test + void holdsJavaFunction() { + JavaFunction jf = new JavaFunction("X", "m", java.util.List.of("int")); + FunctionTool tool = new FunctionTool(jf); + assertThat(tool.getFunc()).isSameAs(jf); + } + + @Test + void fromMethodBuildsJavaFunction() throws Exception { + Method m = FunctionToolTest.class.getDeclaredMethod("demo", int.class); + FunctionTool tool = FunctionTool.fromMethod(m); + assertThat(tool.getFunc()).isInstanceOf(JavaFunction.class); + JavaFunction jf = (JavaFunction) tool.getFunc(); + assertThat(jf.getQualName()).isEqualTo(FunctionToolTest.class.getName()); + assertThat(jf.getMethodName()).isEqualTo("demo"); + assertThat(jf.getParameterTypes()).containsExactly("int"); + } +} diff --git a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java index 3a77f866..62eb0876 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java @@ -181,24 +181,22 @@ public class AgentPlan implements Serializable { } private void extractActions( - String[] listenEventTypeStrings, Method method, Map<String, Object> config) + String actionName, + String[] listenEventTypeStrings, + org.apache.flink.agents.plan.Function function, + Map<String, Object> config) throws Exception { List<String> eventTypeNames = new ArrayList<>(Arrays.asList(listenEventTypeStrings)); if (eventTypeNames.isEmpty()) { throw new IllegalArgumentException( - "Action method " - + method.getName() + "Action " + + actionName + " must specify at least one event type via listenEventTypes."); } - // Create a JavaFunction for this method - JavaFunction javaFunction = - new JavaFunction( - method.getDeclaringClass(), method.getName(), method.getParameterTypes()); - // Create an Action - Action action = new Action(method.getName(), javaFunction, eventTypeNames, config); + Action action = new Action(actionName, function, eventTypeNames, config); // Add to actions map actions.put(action.getName(), action); @@ -235,14 +233,26 @@ public class AgentPlan implements Serializable { String[] listenEventTypeStrings = Objects.requireNonNull(actionAnnotation).listenEventTypes(); - extractActions(listenEventTypeStrings, method, null); + org.apache.flink.agents.plan.JavaFunction javaFunction = + new org.apache.flink.agents.plan.JavaFunction( + method.getDeclaringClass(), + method.getName(), + method.getParameterTypes()); + extractActions(method.getName(), listenEventTypeStrings, javaFunction, null); } } - for (Map.Entry<String, Tuple3<String[], Method, Map<String, Object>>> action : - agent.getActions().entrySet()) { - Tuple3<String[], Method, Map<String, Object>> tuple = action.getValue(); - extractActions(tuple.f0, tuple.f1, tuple.f2); + for (Map.Entry< + String, + Tuple3< + String[], + org.apache.flink.agents.api.function.Function, + Map<String, Object>>> + action : agent.getActions().entrySet()) { + String actionName = action.getKey(); + Tuple3<String[], org.apache.flink.agents.api.function.Function, Map<String, Object>> + tuple = action.getValue(); + extractActions(actionName, tuple.f0, toPlanFunction(tuple.f1), tuple.f2); } } @@ -484,9 +494,21 @@ public class AgentPlan implements Serializable { } } else if (type == TOOL) { for (Map.Entry<String, Object> kv : entry.getValue().entrySet()) { - extractTool( - ((org.apache.flink.agents.api.tools.FunctionTool) kv.getValue()) - .getMethod()); + String resourceName = kv.getKey(); + Object value = kv.getValue(); + if (value instanceof org.apache.flink.agents.api.tools.FunctionTool) { + registerApiFunctionTool( + resourceName, + (org.apache.flink.agents.api.tools.FunctionTool) value); + } else if (value instanceof SerializableResource) { + // Plan-layer tools added directly (MCP-generated, etc.) — pass through. + addResourceProvider( + JavaSerializableResourceProvider.createResourceProvider( + resourceName, TOOL, (SerializableResource) value)); + } else { + throw new IllegalStateException( + "Unsupported tool resource '" + resourceName + "': " + value); + } } } else if (type == ResourceType.SKILLS) { for (Map.Entry<String, Object> kv : entry.getValue().entrySet()) { @@ -584,4 +606,105 @@ public class AgentPlan implements Serializable { .computeIfAbsent(provider.getType(), k -> new HashMap<>()) .put(provider.getName(), provider); } + + /** + * Promote an api-layer {@link org.apache.flink.agents.api.function.Function} descriptor to its + * plan-layer twin. Java parameter type strings are resolved to {@link Class} here; Python + * descriptors pass through unchanged. + */ + private static org.apache.flink.agents.plan.Function toPlanFunction( + org.apache.flink.agents.api.function.Function f) throws Exception { + if (f instanceof org.apache.flink.agents.api.function.JavaFunction) { + org.apache.flink.agents.api.function.JavaFunction jf = + (org.apache.flink.agents.api.function.JavaFunction) f; + Class<?>[] params = resolveParameterTypes(jf.getParameterTypes()); + Class<?> clazz = + Class.forName( + jf.getQualName(), true, Thread.currentThread().getContextClassLoader()); + return new org.apache.flink.agents.plan.JavaFunction(clazz, jf.getMethodName(), params); + } + if (f instanceof org.apache.flink.agents.api.function.PythonFunction) { + org.apache.flink.agents.api.function.PythonFunction pf = + (org.apache.flink.agents.api.function.PythonFunction) f; + return new org.apache.flink.agents.plan.PythonFunction( + pf.getModule(), pf.getQualName()); + } + throw new IllegalStateException("Unknown api.function.Function: " + f); + } + + private static Class<?>[] resolveParameterTypes(List<String> names) + throws ClassNotFoundException { + Class<?>[] out = new Class<?>[names.size()]; + for (int i = 0; i < names.size(); i++) { + out[i] = resolveParameterType(names.get(i)); + } + return out; + } + + private static Class<?> resolveParameterType(String name) throws ClassNotFoundException { + switch (name) { + case "boolean": + return boolean.class; + case "byte": + return byte.class; + case "short": + return short.class; + case "int": + return int.class; + case "long": + return long.class; + case "float": + return float.class; + case "double": + return double.class; + case "char": + return char.class; + case "void": + return void.class; + default: + return Class.forName(name, true, Thread.currentThread().getContextClassLoader()); + } + } + + /** + * Promote an api-layer {@link org.apache.flink.agents.api.tools.FunctionTool} to a plan-layer + * executable {@link FunctionTool} and register it under the YAML-declared resource name. + */ + private void registerApiFunctionTool( + String resourceName, org.apache.flink.agents.api.tools.FunctionTool apiTool) + throws Exception { + org.apache.flink.agents.api.function.Function func = apiTool.getFunc(); + if (func instanceof org.apache.flink.agents.api.function.JavaFunction) { + org.apache.flink.agents.api.function.JavaFunction jf = + (org.apache.flink.agents.api.function.JavaFunction) func; + Class<?>[] params = resolveParameterTypes(jf.getParameterTypes()); + Class<?> clazz = + Class.forName( + jf.getQualName(), true, Thread.currentThread().getContextClassLoader()); + Method method = clazz.getMethod(jf.getMethodName(), params); + ToolMetadata metadata = ToolMetadataFactory.fromStaticMethod(method); + org.apache.flink.agents.plan.JavaFunction planFunc = + new org.apache.flink.agents.plan.JavaFunction(clazz, method.getName(), params); + FunctionTool tool = new FunctionTool(metadata, planFunc); + addResourceProvider( + JavaSerializableResourceProvider.createResourceProvider( + resourceName, TOOL, tool)); + } else if (func instanceof org.apache.flink.agents.api.function.PythonFunction) { + org.apache.flink.agents.api.function.PythonFunction pf = + (org.apache.flink.agents.api.function.PythonFunction) func; + org.apache.flink.agents.plan.PythonFunction planFunc = + new org.apache.flink.agents.plan.PythonFunction( + pf.getModule(), pf.getQualName()); + // Placeholder metadata: ResourceCache will replace it with introspected values from + // the Python bridge via FunctionTool.setPythonResourceAdapter at first resolve. + ToolMetadata metadata = new ToolMetadata(resourceName, "", "{}"); + FunctionTool tool = new FunctionTool(metadata, planFunc); + addResourceProvider( + JavaSerializableResourceProvider.createResourceProvider( + resourceName, TOOL, tool)); + } else { + throw new IllegalStateException( + "Unknown api.function.Function for tool '" + resourceName + "': " + func); + } + } } diff --git a/plan/src/main/java/org/apache/flink/agents/plan/tools/FunctionTool.java b/plan/src/main/java/org/apache/flink/agents/plan/tools/FunctionTool.java index 66ef21a9..69247004 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/tools/FunctionTool.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/tools/FunctionTool.java @@ -20,9 +20,11 @@ package org.apache.flink.agents.plan.tools; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize; import org.apache.flink.agents.api.annotation.ToolParam; +import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; import org.apache.flink.agents.api.tools.Tool; import org.apache.flink.agents.api.tools.ToolMetadata; import org.apache.flink.agents.api.tools.ToolParameters; @@ -30,12 +32,15 @@ import org.apache.flink.agents.api.tools.ToolResponse; import org.apache.flink.agents.api.tools.ToolType; import org.apache.flink.agents.plan.Function; import org.apache.flink.agents.plan.JavaFunction; +import org.apache.flink.agents.plan.PythonFunction; import org.apache.flink.agents.plan.tools.serializer.FunctionToolJsonDeserializer; import org.apache.flink.agents.plan.tools.serializer.FunctionToolJsonSerializer; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.lang.reflect.Parameter; +import java.util.HashMap; +import java.util.Map; /** * Plan-level implementation of a tool that wraps a static Java method. This belongs in the plan @@ -48,6 +53,8 @@ public class FunctionTool extends Tool { private final Function function; + @JsonIgnore private transient PythonResourceAdapter pythonResourceAdapter; + /** Create a FunctionTool from ToolMetadata and Function */ public FunctionTool(ToolMetadata metadata, Function function) { super(metadata); @@ -104,38 +111,84 @@ public class FunctionTool extends Tool { @Override public ToolResponse call(ToolParameters parameters) { try { - // Map ToolParameters to method arguments by name and type - Method method = ((JavaFunction) function).getMethod(); - Parameter[] methodParams = method.getParameters(); - Object[] args = new Object[methodParams.length]; - for (int i = 0; i < methodParams.length; i++) { - Parameter p = methodParams[i]; - String paramName = p.getName(); - if (p.isAnnotationPresent(ToolParam.class)) { - ToolParam ann = p.getAnnotation(ToolParam.class); - if (!ann.name().isEmpty()) { - paramName = ann.name(); - } + if (function instanceof PythonFunction) { + return callPython((PythonFunction) function, parameters); + } + return callJava(parameters); + } catch (Exception e) { + return ToolResponse.error(e); + } + } + + private ToolResponse callJava(ToolParameters parameters) throws Exception { + // Map ToolParameters to method arguments by name and type + Method method = ((JavaFunction) function).getMethod(); + Parameter[] methodParams = method.getParameters(); + Object[] args = new Object[methodParams.length]; + for (int i = 0; i < methodParams.length; i++) { + Parameter p = methodParams[i]; + String paramName = p.getName(); + if (p.isAnnotationPresent(ToolParam.class)) { + ToolParam ann = p.getAnnotation(ToolParam.class); + if (!ann.name().isEmpty()) { + paramName = ann.name(); } - Object value = parameters.getParameter(paramName, p.getType()); - if (value == null && p.isAnnotationPresent(ToolParam.class)) { - ToolParam ann = p.getAnnotation(ToolParam.class); - if (ann.required() && ann.defaultValue().isEmpty()) { - throw new IllegalArgumentException( - "Missing required parameter: " + paramName); - } + } + Object value = parameters.getParameter(paramName, p.getType()); + if (value == null && p.isAnnotationPresent(ToolParam.class)) { + ToolParam ann = p.getAnnotation(ToolParam.class); + if (ann.required() && ann.defaultValue().isEmpty()) { + throw new IllegalArgumentException("Missing required parameter: " + paramName); } - args[i] = value; } + args[i] = value; + } + Object result = function.call(args); + return ToolResponse.success(result); + } - Object result = function.call(args); - return ToolResponse.success(result); - } catch (Exception e) { - return ToolResponse.error(e); + private ToolResponse callPython(PythonFunction pf, ToolParameters parameters) { + if (pythonResourceAdapter == null) { + return ToolResponse.error( + new IllegalStateException( + "Python tool '" + + pf.getQualName() + + "' has no PythonResourceAdapter; runtime should inject one" + + " before invocation.")); + } + Map<String, Object> kwargs = new HashMap<>(); + for (String name : parameters.getParameterNames()) { + kwargs.put(name, parameters.getParameter(name)); } + Object result = + pythonResourceAdapter.invokePythonTool(pf.getModule(), pf.getQualName(), kwargs); + return ToolResponse.success(result); } public Function getFunction() { return function; } + + /** + * Refresh this tool's metadata via the Python bridge when the underlying function is a {@link + * PythonFunction}. No-op for Java-backed tools. + * + * <p>Called by the runtime resource cache the first time the tool is resolved, so the + * placeholder metadata that {@code AgentPlan.registerApiFunctionTool} writes for Python tools + * gets replaced with real introspected values (name, description, inputSchema) sourced from the + * Python callable's signature and docstring. + */ + public void setPythonResourceAdapter(PythonResourceAdapter adapter) { + if (!(function instanceof PythonFunction)) { + return; + } + this.pythonResourceAdapter = adapter; + PythonFunction pf = (PythonFunction) function; + Map<String, String> flat = adapter.getPythonToolMetadata(pf.getModule(), pf.getQualName()); + setMetadata( + new ToolMetadata( + flat.get("name"), + flat.getOrDefault("description", ""), + flat.getOrDefault("inputSchema", "{}"))); + } } diff --git a/plan/src/test/java/org/apache/flink/agents/plan/tools/FunctionToolSetPythonAdapterTest.java b/plan/src/test/java/org/apache/flink/agents/plan/tools/FunctionToolSetPythonAdapterTest.java new file mode 100644 index 00000000..8e7d1876 --- /dev/null +++ b/plan/src/test/java/org/apache/flink/agents/plan/tools/FunctionToolSetPythonAdapterTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.plan.tools; + +import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; +import org.apache.flink.agents.api.tools.ToolMetadata; +import org.apache.flink.agents.plan.JavaFunction; +import org.apache.flink.agents.plan.PythonFunction; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class FunctionToolSetPythonAdapterTest { + + @Test + void replacesPlaceholderMetadataForPythonFunction() { + ToolMetadata placeholder = new ToolMetadata("notify", "", "{}"); + PythonFunction pf = new PythonFunction("pkg.mod", "notify"); + FunctionTool tool = new FunctionTool(placeholder, pf); + + PythonResourceAdapter adapter = Mockito.mock(PythonResourceAdapter.class); + when(adapter.getPythonToolMetadata("pkg.mod", "notify")) + .thenReturn( + Map.of( + "name", "notify", + "description", "Send a notification.", + "inputSchema", + "{\"properties\":{\"id\":{\"type\":\"string\"," + + "\"description\":\"recipient id\"}}}")); + + tool.setPythonResourceAdapter(adapter); + + assertThat(tool.getMetadata().getName()).isEqualTo("notify"); + assertThat(tool.getMetadata().getDescription()).isEqualTo("Send a notification."); + assertThat(tool.getMetadata().getInputSchema()).contains("recipient id"); + verify(adapter, times(1)).getPythonToolMetadata(eq("pkg.mod"), eq("notify")); + } + + @Test + void noOpForJavaFunction() throws Exception { + ToolMetadata original = new ToolMetadata("add", "Adds.", "{\"properties\":{}}"); + JavaFunction jf = + new JavaFunction( + FunctionToolSetPythonAdapterTest.class, + "stubMethod", + new Class<?>[] {int.class}); + FunctionTool tool = new FunctionTool(original, jf); + + PythonResourceAdapter adapter = Mockito.mock(PythonResourceAdapter.class); + tool.setPythonResourceAdapter(adapter); + + // Metadata untouched + assertThat(tool.getMetadata()).isSameAs(original); + verify(adapter, never()).getPythonToolMetadata(Mockito.anyString(), Mockito.anyString()); + } + + /** Helper static method to back JavaFunction in the no-op test. */ + public static int stubMethod(int x) { + return x; + } +} diff --git a/python/flink_agents/api/tools/utils.py b/python/flink_agents/api/tools/utils.py index f94fc818..f1145b5c 100644 --- a/python/flink_agents/api/tools/utils.py +++ b/python/flink_agents/api/tools/utils.py @@ -18,7 +18,7 @@ import json import typing from inspect import signature -from typing import Any, Callable, Optional, Type, Union +from typing import Any, Callable, Dict, Optional, Type, Union from docstring_parser import parse from pydantic import BaseModel, create_model @@ -210,6 +210,7 @@ def create_java_tool_schema_str_from_model(model: type[BaseModel]) -> str: REVERSE_TYPE_MAPPING = {v: k for k, v in TYPE_MAPPING.items()} properties = {} + required = [] for field_name, field_info in model.model_fields.items(): field_type = field_info.annotation @@ -228,7 +229,11 @@ def create_java_tool_schema_str_from_model(model: type[BaseModel]) -> str: description = f"Parameter: {field_name}" properties[field_name] = {"type": json_type, "description": description} + if field_info.is_required(): + required.append(field_name) - json_schema = {"properties": properties} + json_schema: Dict[str, Any] = {"properties": properties} + if required: + json_schema["required"] = required return json.dumps(json_schema, ensure_ascii=False, indent=2) diff --git a/python/flink_agents/runtime/python_java_utils.py b/python/flink_agents/runtime/python_java_utils.py index 58389c82..23ed4f5c 100644 --- a/python/flink_agents/runtime/python_java_utils.py +++ b/python/flink_agents/runtime/python_java_utils.py @@ -126,6 +126,49 @@ def from_java_tool(j_tool: Any) -> JavaTool: return JavaTool(metadata=metadata) +def get_python_tool_metadata(module: str, qual_name: str) -> Dict[str, str]: + """Introspect a Python callable into the flat tool-metadata shape expected by + the Java-side ``PythonResourceAdapter.getPythonToolMetadata``. + + Mirrors the Python side's eager-metadata derivation for + ``PythonFunction``-backed ``FunctionTool``s. Returns the same three-key shape + ``JavaResourceAdapter.getJavaToolMetadata`` returns in the reverse direction + so the Java side can rebuild ``ToolMetadata`` from String fields only — + avoiding pemja's SIGSEGV when wrapping arbitrary Python objects on non-main + interpreter threads. + """ + from docstring_parser import parse + + from flink_agents.api.function import PythonFunction + from flink_agents.api.tools.utils import ( + create_java_tool_schema_str_from_model, + create_schema_from_function, + ) + + descriptor = PythonFunction(module=module, qualname=qual_name) + callable_ = descriptor.as_callable() + name = callable_.__name__ + description = (parse(callable_.__doc__).description or "") if callable_.__doc__ else "" + args_schema_model = create_schema_from_function(name, callable_) + input_schema = create_java_tool_schema_str_from_model(args_schema_model) + return {"name": name, "description": description, "inputSchema": input_schema} + + +def invoke_python_tool( + module: str, qual_name: str, kwargs: Dict[str, Any] +) -> Any: + """Invoke a Python callable as a tool, passing the provided keyword arguments. + + Used by the Java-side ``PythonResourceAdapter.invokePythonTool`` so a Java host can + dispatch a Python function tool from a Java chat model without the Python side + needing to know about Pemja's threading model. + """ + from flink_agents.api.function import PythonFunction + + descriptor = PythonFunction(module=module, qualname=qual_name) + return descriptor.as_callable()(**kwargs) + + def from_java_prompt(j_prompt: Any) -> JavaPrompt: """Convert a Java prompt object to a Python JavaPrompt instance. diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/ResourceCache.java b/runtime/src/main/java/org/apache/flink/agents/runtime/ResourceCache.java index db8e5dc3..8e56bb98 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/ResourceCache.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/ResourceCache.java @@ -23,6 +23,7 @@ import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; import org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider; import org.apache.flink.agents.plan.resourceprovider.ResourceProvider; +import org.apache.flink.agents.plan.tools.FunctionTool; import org.apache.flink.agents.runtime.resource.ResourceContextImpl; import java.util.HashMap; @@ -96,6 +97,11 @@ public class ResourceCache implements AutoCloseable { throw new RuntimeException(e); } })); + + if (pythonResourceAdapter != null && resource instanceof FunctionTool) { + ((FunctionTool) resource).setPythonResourceAdapter(pythonResourceAdapter); + } + resource.open(); cache.computeIfAbsent(type, k -> new ConcurrentHashMap<>()).put(name, resource); return resource; diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImpl.java index 238e0e8c..f4284e48 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonResourceAdapterImpl.java @@ -72,6 +72,11 @@ public class PythonResourceAdapterImpl implements PythonResourceAdapter { static final String FROM_JAVA_VECTOR_STORE_QUERY = PYTHON_MODULE_PREFIX + "from_java_vector_store_query"; + static final String GET_PYTHON_TOOL_METADATA = + PYTHON_MODULE_PREFIX + "get_python_tool_metadata"; + + static final String INVOKE_PYTHON_TOOL = PYTHON_MODULE_PREFIX + "invoke_python_tool"; + private final ResourceContext resourceContext; private final PythonInterpreter interpreter; private final JavaResourceAdapter javaResourceAdapter; @@ -199,4 +204,22 @@ public class PythonResourceAdapterImpl implements PythonResourceAdapter { public Object invoke(String name, Object... args) { return interpreter.invoke(name, args); } + + @Override + public Map<String, String> getPythonToolMetadata(String module, String qualName) { + @SuppressWarnings("unchecked") + Map<String, String> result = + (Map<String, String>) + interpreter.invoke(GET_PYTHON_TOOL_METADATA, module, qualName); + if (result == null) { + throw new IllegalStateException( + "Python get_python_tool_metadata returned null for " + module + ":" + qualName); + } + return result; + } + + @Override + public Object invokePythonTool(String module, String qualName, Map<String, Object> kwargs) { + return interpreter.invoke(INVOKE_PYTHON_TOOL, module, qualName, kwargs); + } } diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/ResourceCacheTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/ResourceCacheTest.java index ecf8ef18..a48f618e 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/ResourceCacheTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/ResourceCacheTest.java @@ -188,6 +188,16 @@ public class ResourceCacheTest { public Object invoke(String name, Object... args) { return null; } + + @Override + public Map<String, String> getPythonToolMetadata(String module, String qualName) { + return Map.of("name", qualName, "description", "", "inputSchema", "{}"); + } + + @Override + public Object invokePythonTool(String module, String qualName, Map<String, Object> kwargs) { + return null; + } } @Test
