This is an automated email from the ASF dual-hosted git repository.
xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
The following commit(s) were added to refs/heads/main by this push:
new 1ac8583 [Feature][runtime] Support the use of Java ChatModel in
Python (#373)
1ac8583 is described below
commit 1ac8583c55831619c220ce3c2901d1345fb12ee9
Author: Eugene <[email protected]>
AuthorDate: Wed Dec 17 23:27:31 2025 +0800
[Feature][runtime] Support the use of Java ChatModel in Python (#373)
---
.../agents/api/resource/ResourceDescriptor.java | 64 +++++-
.../resource/test/ChatModelCrossLanguageAgent.java | 2 +-
.../agents/plan/resource/python/PythonPrompt.java | 88 +++++++
.../agents/plan/resource/python/PythonTool.java | 94 ++++++++
.../resourceprovider/JavaResourceProvider.java | 8 +-
.../resourceprovider/PythonResourceProvider.java | 61 ++---
.../PythonSerializableResourceProvider.java | 18 +-
.../ResourceProviderJsonDeserializer.java | 18 +-
.../serializer/ResourceProviderJsonSerializer.java | 23 +-
.../apache/flink/agents/plan/AgentPlanTest.java | 17 --
.../compatibility/CreateJavaAgentPlanFromJson.java | 10 +-
.../plan/resource/python/PythonPromptTest.java | 112 +++++++++
.../plan/resource/python/PythonToolTest.java | 252 +++++++++++++++++++++
.../ResourceProviderDeserializerTest.java | 11 +-
.../serializer/ResourceProviderSerializerTest.java | 10 +-
.../python_resource_provider.json | 12 +-
.../api/chat_models/java_chat_model.py | 46 ++++
python/flink_agents/api/decorators.py | 5 +
python/flink_agents/api/resource.py | 118 ++++++++--
.../e2e_tests_resource_cross_language/__init__.py | 17 ++
.../chat_model_cross_language_agent.py | 151 ++++++++++++
.../chat_model_cross_language_test.py | 100 ++++++++
.../java_chat_module_input/input_data.txt | 2 +
python/flink_agents/plan/agent_plan.py | 30 ++-
python/flink_agents/plan/resource_provider.py | 67 ++++--
.../plan/tests/resources/agent_plan.json | 52 +++--
.../plan/tests/resources/resource_provider.json | 12 +-
.../plan/tests/test_resource_provider.py | 6 +-
.../flink_agents/runtime/flink_runner_context.py | 7 +-
.../flink_agents/runtime/java/java_chat_model.py | 124 ++++++++++
python/flink_agents/runtime/python_java_utils.py | 11 +
.../runtime/operator/ActionExecutionOperator.java | 7 +-
.../runtime/python/utils/JavaResourceAdapter.java | 124 ++++++++++
.../runtime/python/utils/PythonActionExecutor.java | 10 +-
34 files changed, 1480 insertions(+), 209 deletions(-)
diff --git
a/api/src/main/java/org/apache/flink/agents/api/resource/ResourceDescriptor.java
b/api/src/main/java/org/apache/flink/agents/api/resource/ResourceDescriptor.java
index f2803ba..1de8955 100644
---
a/api/src/main/java/org/apache/flink/agents/api/resource/ResourceDescriptor.java
+++
b/api/src/main/java/org/apache/flink/agents/api/resource/ResourceDescriptor.java
@@ -20,36 +20,69 @@ package org.apache.flink.agents.api.resource;
import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
-import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonTypeInfo;
import java.util.HashMap;
import java.util.Map;
+import java.util.Objects;
/** Helper class to describe a {@link Resource} */
public class ResourceDescriptor {
- private static final String FIELD_CLAZZ = "clazz";
- private static final String FIELD_INITIAL_ARGUMENTS = "initialArguments";
+ private static final String FIELD_CLAZZ = "target_clazz";
+ private static final String FIELD_MODULE = "target_module";
+ private static final String FIELD_INITIAL_ARGUMENTS = "arguments";
@JsonProperty(FIELD_CLAZZ)
private final String clazz;
- // TODO: support nested map/list with non primitive value.
+ @JsonProperty(FIELD_MODULE)
+ private final String module;
+
@JsonProperty(FIELD_INITIAL_ARGUMENTS)
- @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS)
private final Map<String, Object> initialArguments;
+    /**
+     * Creates a ResourceDescriptor.
+     *
+     * <p>This constructor is also the Jackson {@code @JsonCreator}: in JSON the three parameters
+     * appear as {@code target_module}, {@code target_clazz} and {@code arguments}. It allows
+     * declaring both Java and Python resources.
+     *
+     * @param module The Python module path of the resource, e.g. "your_module.submodule". Empty
+     *     (as set by {@link #ResourceDescriptor(String, Map)}) or null for Java resources.
+     * @param clazz The class identifier for the resource. Its meaning depends on the resource
+     *     type:
+     *     <ul>
+     *       <li><b>For Java resources:</b> the fully qualified Java class name (e.g.,
+     *           "com.example.YourJavaClass"); {@code module} should be empty or null.
+     *       <li><b>For Python resources (when declaring from Java):</b> the simple Python class
+     *           name (not the module path, e.g., "YourPythonClass"); the module path goes in
+     *           {@code module}.
+     *     </ul>
+     *
+     * @param initialArguments Additional arguments for resource initialization. May be null or an
+     *     empty map if no initial arguments are needed. The map is stored as-is, not copied.
+     */
@JsonCreator
public ResourceDescriptor(
+ @JsonProperty(FIELD_MODULE) String module,
@JsonProperty(FIELD_CLAZZ) String clazz,
@JsonProperty(FIELD_INITIAL_ARGUMENTS) Map<String, Object>
initialArguments) {
this.clazz = clazz;
+ this.module = module;
this.initialArguments = initialArguments;
}
+    /**
+     * Creates a descriptor whose module is the empty string, i.e. one describing a Java resource
+     * whose {@code clazz} is a fully qualified Java class name.
+     *
+     * @param clazz fully qualified Java class name of the resource
+     * @param initialArguments arguments for resource initialization; may be null or empty
+     */
+    public ResourceDescriptor(String clazz, Map<String, Object> initialArguments) {
+        this(/* module= */ "", clazz, initialArguments);
+    }
+
public String getClazz() {
return clazz;
}
+    /**
+     * Returns the Python module path of the described resource.
+     *
+     * @return the module path; empty or null when the descriptor describes a Java resource
+     */
+    public String getModule() {
+        return this.module;
+    }
+
public Map<String, Object> getInitialArguments() {
return initialArguments;
}
@@ -64,6 +97,27 @@ public class ResourceDescriptor {
return value != null ? value : defaultValue;
}
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ ResourceDescriptor that = (ResourceDescriptor) o;
+ return Objects.equals(this.clazz, that.clazz)
+ && Objects.equals(this.module, that.module)
+ && Objects.equals(this.initialArguments,
that.initialArguments);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(clazz, module, initialArguments);
+ }
+
public static class Builder {
private final String clazz;
private final Map<String, Object> initialArguments;
diff --git
a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageAgent.java
b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageAgent.java
index 4511044..cf3ec33 100644
---
a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageAgent.java
+++
b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageAgent.java
@@ -75,10 +75,10 @@ public class ChatModelCrossLanguageAgent extends Agent {
@ChatModelSetup
public static ResourceDescriptor chatModel() {
return
ResourceDescriptor.Builder.newBuilder(PythonChatModelSetup.class.getName())
- .addInitialArgument("connection", "chatModelConnection")
.addInitialArgument(
"module",
"flink_agents.integrations.chat_models.ollama_chat_model")
.addInitialArgument("clazz", "OllamaChatModelSetup")
+ .addInitialArgument("connection", "chatModelConnection")
.addInitialArgument("model", OLLAMA_MODEL)
.addInitialArgument(
"tools",
diff --git
a/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonPrompt.java
b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonPrompt.java
new file mode 100644
index 0000000..fee43f4
--- /dev/null
+++
b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonPrompt.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.plan.resource.python;
+
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.prompt.Prompt;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * PythonPrompt is a subclass of Prompt that provides a method to parse a
Python prompt from a
+ * serialized map.
+ */
+public class PythonPrompt extends Prompt {
+ public PythonPrompt(String template) {
+ super(template);
+ }
+
+ public PythonPrompt(List<ChatMessage> template) {
+ super(template);
+ }
+
+ public static PythonPrompt fromSerializedMap(Map<String, Object>
serialized) {
+ if (serialized == null || !serialized.containsKey("template")) {
+ throw new IllegalArgumentException("Map must contain 'template'
key");
+ }
+
+ Object templateObj = serialized.get("template");
+ if (templateObj instanceof String) {
+ return new PythonPrompt((String) templateObj);
+ } else if (templateObj instanceof List) {
+ List<?> templateList = (List<?>) templateObj;
+ if (templateList.isEmpty()) {
+ throw new IllegalArgumentException("Template list cannot be
empty");
+ }
+
+ List<ChatMessage> messages = new ArrayList<>();
+ for (Object item : templateList) {
+ if (!(item instanceof Map)) {
+ throw new IllegalArgumentException("Each template item
must be a Map");
+ }
+
+ Map<String, Object> messageMap = (Map<String, Object>) item;
+ ChatMessage chatMessage = parseChatMessage(messageMap);
+ messages.add(chatMessage);
+ }
+
+ return new PythonPrompt(messages);
+ }
+ throw new IllegalArgumentException(
+ "Python prompt parsing failed. Template is not a string or
list.");
+ }
+
+ /** Parse a single ChatMessage from a Map representation. */
+ @SuppressWarnings("unchecked")
+ private static ChatMessage parseChatMessage(Map<String, Object>
messageMap) {
+ String roleValue = messageMap.get("role").toString();
+ MessageRole role = MessageRole.fromValue(roleValue);
+
+ Object contentObj = messageMap.get("content");
+ String content = contentObj != null ? contentObj.toString() : "";
+
+ List<Map<String, Object>> toolCalls =
+ (List<Map<String, Object>>) messageMap.get("tool_calls");
+
+ Map<String, Object> extraArgs = (Map<String, Object>)
messageMap.get("extra_args");
+
+ return new ChatMessage(role, content, toolCalls, extraArgs);
+ }
+}
diff --git
a/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonTool.java
b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonTool.java
new file mode 100644
index 0000000..64eea7e
--- /dev/null
+++
b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonTool.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.plan.resource.python;
+
+import org.apache.flink.agents.api.tools.Tool;
+import org.apache.flink.agents.api.tools.ToolMetadata;
+import org.apache.flink.agents.api.tools.ToolParameters;
+import org.apache.flink.agents.api.tools.ToolResponse;
+import org.apache.flink.agents.api.tools.ToolType;
+import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
+
+import java.util.Map;
+
+/**
+ * A {@link Tool} that carries only the metadata of a Python tool, reconstructed from its
+ * serialized map form. It cannot be invoked from Java: {@link #call(ToolParameters)} always
+ * throws.
+ */
+public class PythonTool extends Tool {
+    /** Keys that must be present in the {@code metadata} map, checked in this order. */
+    private static final String[] REQUIRED_METADATA_KEYS = {"name", "description", "args_schema"};
+
+    /** Shared mapper: ObjectMapper is thread-safe once configured and costly to create per call. */
+    private static final ObjectMapper MAPPER = new ObjectMapper();
+
+    protected PythonTool(ToolMetadata metadata) {
+        super(metadata);
+    }
+
+    /**
+     * Parses a Python tool from its serialized map form.
+     *
+     * @param serialized map holding a {@code metadata} map with {@code name}, {@code description}
+     *     and {@code args_schema} entries
+     * @return the parsed tool; its input schema is {@code args_schema} rendered as JSON
+     * @throws IllegalArgumentException if {@code serialized} is null, or a required entry is
+     *     missing, null or of the wrong type
+     * @throws JsonProcessingException if {@code args_schema} cannot be written as JSON
+     */
+    public static PythonTool fromSerializedMap(Map<String, Object> serialized)
+            throws JsonProcessingException {
+        if (serialized == null) {
+            throw new IllegalArgumentException("Serialized map cannot be null");
+        }
+        if (!serialized.containsKey("metadata")) {
+            throw new IllegalArgumentException("Map must contain 'metadata' key");
+        }
+
+        Object metadataObj = serialized.get("metadata");
+        if (!(metadataObj instanceof Map)) {
+            throw new IllegalArgumentException("'metadata' must be a Map");
+        }
+        @SuppressWarnings("unchecked")
+        Map<String, Object> metadata = (Map<String, Object>) metadataObj;
+
+        for (String key : REQUIRED_METADATA_KEYS) {
+            if (!metadata.containsKey(key)) {
+                throw new IllegalArgumentException("Metadata must contain '" + key + "' key");
+            }
+        }
+
+        String name = requireString(metadata, "name");
+        String description = requireString(metadata, "description");
+        String inputSchema = MAPPER.writeValueAsString(metadata.get("args_schema"));
+        return new PythonTool(new ToolMetadata(name, description, inputSchema));
+    }
+
+    /**
+     * Returns {@code metadata.get(key)} as a String.
+     *
+     * @throws IllegalArgumentException if the value is null or not a String (instead of the
+     *     ClassCastException a bare cast would raise)
+     */
+    private static String requireString(Map<String, Object> metadata, String key) {
+        Object value = metadata.get(key);
+        if (value == null) {
+            throw new IllegalArgumentException("'" + key + "' cannot be null");
+        }
+        if (!(value instanceof String)) {
+            throw new IllegalArgumentException("'" + key + "' must be a String");
+        }
+        return (String) value;
+    }
+
+    @Override
+    public ToolType getToolType() {
+        return ToolType.REMOTE_FUNCTION;
+    }
+
+    /** Always throws: a Python tool's implementation is not available on the Java side. */
+    @Override
+    public ToolResponse call(ToolParameters parameters) {
+        throw new UnsupportedOperationException("PythonTool does not support call method.");
+    }
+}
diff --git
a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/JavaResourceProvider.java
b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/JavaResourceProvider.java
index cd4317b..373c3c4 100644
---
a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/JavaResourceProvider.java
+++
b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/JavaResourceProvider.java
@@ -37,7 +37,13 @@ public class JavaResourceProvider extends ResourceProvider {
@Override
public Resource provide(BiFunction<String, ResourceType, Resource>
getResource)
throws Exception {
- Class<?> clazz = Class.forName(descriptor.getClazz());
+ String clazzName;
+ if (descriptor.getModule() == null ||
descriptor.getModule().isEmpty()) {
+ clazzName = descriptor.getClazz();
+ } else {
+ clazzName =
descriptor.getInitialArguments().remove("java_clazz").toString();
+ }
+ Class<?> clazz = Class.forName(clazzName);
Constructor<?> constructor =
clazz.getConstructor(ResourceDescriptor.class,
BiFunction.class);
return (Resource) constructor.newInstance(descriptor, getResource);
diff --git
a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonResourceProvider.java
b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonResourceProvider.java
index ce621a4..f70833f 100644
---
a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonResourceProvider.java
+++
b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonResourceProvider.java
@@ -26,7 +26,6 @@ import pemja.core.object.PyObject;
import java.lang.reflect.Constructor;
import java.util.HashMap;
-import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;
@@ -39,37 +38,12 @@ import static
org.apache.flink.util.Preconditions.checkState;
* class, and initialization arguments.
*/
public class PythonResourceProvider extends ResourceProvider {
- private final String module;
- private final String clazz;
- private final Map<String, Object> kwargs;
private final ResourceDescriptor descriptor;
protected PythonResourceAdapter pythonResourceAdapter;
- public PythonResourceProvider(
- String name,
- ResourceType type,
- String module,
- String clazz,
- Map<String, Object> kwargs) {
- super(name, type);
- this.module = module;
- this.clazz = clazz;
- this.kwargs = kwargs;
- this.descriptor = null;
- }
-
public PythonResourceProvider(String name, ResourceType type,
ResourceDescriptor descriptor) {
super(name, type);
- this.kwargs = new HashMap<>(descriptor.getInitialArguments());
- module = (String) kwargs.remove("module");
- if (module == null || module.isEmpty()) {
- throw new IllegalArgumentException("module should not be null or
empty.");
- }
- clazz = (String) kwargs.remove("clazz");
- if (clazz == null || clazz.isEmpty()) {
- throw new IllegalArgumentException("clazz should not be null or
empty.");
- }
this.descriptor = descriptor;
}
@@ -77,16 +51,8 @@ public class PythonResourceProvider extends ResourceProvider
{
this.pythonResourceAdapter = pythonResourceAdapter;
}
- public String getModule() {
- return module;
- }
-
- public String getClazz() {
- return clazz;
- }
-
- public Map<String, Object> getKwargs() {
- return kwargs;
+ public ResourceDescriptor getDescriptor() {
+ return descriptor;
}
@Override
@@ -94,8 +60,17 @@ public class PythonResourceProvider extends ResourceProvider
{
throws Exception {
checkState(pythonResourceAdapter != null, "PythonResourceAdapter is
not set");
Class<?> clazz = Class.forName(descriptor.getClazz());
- PyObject pyResource =
- pythonResourceAdapter.initPythonResource(this.module,
this.clazz, kwargs);
+
+ HashMap<String, Object> kwargs = new
HashMap<>(descriptor.getInitialArguments());
+ String pyModule = (String) kwargs.remove("module");
+ if (pyModule == null || pyModule.isEmpty()) {
+ throw new IllegalArgumentException("module should not be null or
empty.");
+ }
+ String pyClazz = (String) kwargs.remove("clazz");
+ if (pyClazz == null || pyClazz.isEmpty()) {
+ throw new IllegalArgumentException("clazz should not be null or
empty.");
+ }
+ PyObject pyResource =
pythonResourceAdapter.initPythonResource(pyModule, pyClazz, kwargs);
Constructor<?> constructor =
clazz.getConstructor(
PythonResourceAdapter.class,
@@ -119,17 +94,11 @@ public class PythonResourceProvider extends
ResourceProvider {
PythonResourceProvider that = (PythonResourceProvider) o;
return Objects.equals(this.getName(), that.getName())
&& Objects.equals(this.getType(), that.getType())
- && Objects.equals(this.module, that.module)
- && Objects.equals(this.clazz, that.clazz)
- && Objects.equals(this.kwargs, that.kwargs);
+ && Objects.equals(this.getDescriptor(), that.getDescriptor());
}
@Override
public int hashCode() {
- return Objects.hash(this.getName(), this.getType(), module, clazz,
kwargs);
- }
-
- public ResourceDescriptor getDescriptor() {
- return descriptor;
+ return Objects.hash(this.getName(), this.getType(),
this.getDescriptor());
}
}
diff --git
a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonSerializableResourceProvider.java
b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonSerializableResourceProvider.java
index d39cb37..283c59c 100644
---
a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonSerializableResourceProvider.java
+++
b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonSerializableResourceProvider.java
@@ -21,6 +21,8 @@ package org.apache.flink.agents.plan.resourceprovider;
import org.apache.flink.agents.api.resource.Resource;
import org.apache.flink.agents.api.resource.ResourceType;
import org.apache.flink.agents.api.resource.SerializableResource;
+import org.apache.flink.agents.plan.resource.python.PythonPrompt;
+import org.apache.flink.agents.plan.resource.python.PythonTool;
import java.util.Map;
import java.util.Objects;
@@ -68,11 +70,17 @@ public class PythonSerializableResourceProvider extends
SerializableResourceProv
@Override
public Resource provide(BiFunction<String, ResourceType, Resource>
getResource)
throws Exception {
- // TODO: Implement Python resource deserialization logic
- // This would typically involve calling into Python runtime to
deserialize the
- // resource
- throw new UnsupportedOperationException(
- "Python resource deserialization not yet implemented in Java
runtime");
+ if (resource == null) {
+ if (this.getType() == ResourceType.PROMPT) {
+ resource = PythonPrompt.fromSerializedMap(serialized);
+ } else if (this.getType() == ResourceType.TOOL) {
+ resource = PythonTool.fromSerializedMap(serialized);
+ } else {
+ throw new UnsupportedOperationException(
+ "Unsupported resource type: " + this.getType());
+ }
+ }
+ return resource;
}
@Override
diff --git
a/plan/src/main/java/org/apache/flink/agents/plan/serializer/ResourceProviderJsonDeserializer.java
b/plan/src/main/java/org/apache/flink/agents/plan/serializer/ResourceProviderJsonDeserializer.java
index df44e27..2caf154 100644
---
a/plan/src/main/java/org/apache/flink/agents/plan/serializer/ResourceProviderJsonDeserializer.java
+++
b/plan/src/main/java/org/apache/flink/agents/plan/serializer/ResourceProviderJsonDeserializer.java
@@ -75,24 +75,12 @@ public class ResourceProviderJsonDeserializer extends
StdDeserializer<ResourcePr
String name = node.get("name").asText();
String type = node.get("type").asText();
try {
- if (node.has("descriptor")) {
- ResourceDescriptor descriptor =
- mapper.treeToValue(node.get("descriptor"),
ResourceDescriptor.class);
- return new PythonResourceProvider(name,
ResourceType.fromValue(type), descriptor);
- }
+ ResourceDescriptor descriptor =
+ mapper.treeToValue(node.get("descriptor"),
ResourceDescriptor.class);
+ return new PythonResourceProvider(name,
ResourceType.fromValue(type), descriptor);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
- String module = node.get("module").asText();
- String clazz = node.get("clazz").asText();
-
- JsonNode kwargsNode = node.get("kwargs");
- Map<String, Object> kwargs = new HashMap<>();
- if (kwargsNode != null && kwargsNode.isObject()) {
- kwargs = mapper.convertValue(kwargsNode, Map.class);
- }
- return new PythonResourceProvider(
- name, ResourceType.fromValue(type), module, clazz, kwargs);
}
private PythonSerializableResourceProvider
deserializePythonSerializableResourceProvider(
diff --git
a/plan/src/main/java/org/apache/flink/agents/plan/serializer/ResourceProviderJsonSerializer.java
b/plan/src/main/java/org/apache/flink/agents/plan/serializer/ResourceProviderJsonSerializer.java
index 4b61c83..e876f37 100644
---
a/plan/src/main/java/org/apache/flink/agents/plan/serializer/ResourceProviderJsonSerializer.java
+++
b/plan/src/main/java/org/apache/flink/agents/plan/serializer/ResourceProviderJsonSerializer.java
@@ -68,28 +68,7 @@ public class ResourceProviderJsonSerializer extends
StdSerializer<ResourceProvid
throws IOException {
gen.writeStringField("name", provider.getName());
gen.writeStringField("type", provider.getType().getValue());
- gen.writeStringField("module", provider.getModule());
- gen.writeStringField("clazz", provider.getClazz());
-
- if (provider.getDescriptor() != null) {
- gen.writeObjectField("descriptor", provider.getDescriptor());
- }
-
- gen.writeFieldName("kwargs");
- gen.writeStartObject();
- provider.getKwargs()
- .forEach(
- (name, value) -> {
- try {
- gen.writeObjectField(name, value);
- } catch (IOException e) {
- throw new RuntimeException(
- "Error writing kwargs of
PythonResourceProvider: " + name,
- e);
- }
- });
- gen.writeEndObject();
-
+ gen.writeObjectField("descriptor", provider.getDescriptor());
gen.writeStringField("__resource_provider_type__",
"PythonResourceProvider");
}
diff --git a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java
b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java
index 88af18b..b19ef40 100644
--- a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java
+++ b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java
@@ -412,12 +412,6 @@ public class AgentPlanTest {
assertThat(pythonChatModelProvider).isInstanceOf(PythonResourceProvider.class);
assertThat(pythonChatModelProvider.getName()).isEqualTo("pythonChatModel");
assertThat(pythonChatModelProvider.getType()).isEqualTo(ResourceType.CHAT_MODEL);
-
- // Test PythonResourceProvider specific methods
- PythonResourceProvider pythonResourceProvider =
- (PythonResourceProvider) pythonChatModelProvider;
- assertThat(pythonResourceProvider.getClazz()).isEqualTo("TestClazz");
-
assertThat(pythonResourceProvider.getModule()).isEqualTo("test.module");
}
@Test
@@ -453,15 +447,4 @@ public class AgentPlanTest {
Resource myToolAgain = agentPlan.getResource("myTool",
ResourceType.TOOL);
assertThat(myTool).isSameAs(myToolAgain);
}
-
- @Test
- public void testExtractIllegalResourceProviderFromAgent() throws Exception
{
- // Create an agent with resource annotations
- TestAgentWithIllegalPythonResource agent = new
TestAgentWithIllegalPythonResource();
-
- // Expect IllegalArgumentException when creating AgentPlan with
illegal resource
- assertThatThrownBy(() -> new AgentPlan(agent))
- .isInstanceOf(IllegalArgumentException.class)
- .hasMessageContaining("module should not be null or empty");
- }
}
diff --git
a/plan/src/test/java/org/apache/flink/agents/plan/compatibility/CreateJavaAgentPlanFromJson.java
b/plan/src/test/java/org/apache/flink/agents/plan/compatibility/CreateJavaAgentPlanFromJson.java
index 235de33..e0e576e 100644
---
a/plan/src/test/java/org/apache/flink/agents/plan/compatibility/CreateJavaAgentPlanFromJson.java
+++
b/plan/src/test/java/org/apache/flink/agents/plan/compatibility/CreateJavaAgentPlanFromJson.java
@@ -18,6 +18,7 @@
package org.apache.flink.agents.plan.compatibility;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.apache.flink.agents.api.resource.ResourceType;
import org.apache.flink.agents.plan.AgentPlan;
import org.apache.flink.agents.plan.PythonFunction;
@@ -145,13 +146,14 @@ public class CreateJavaAgentPlanFromJson {
kwargs.put("name", "chat_model");
kwargs.put("prompt", "prompt");
kwargs.put("tools", List.of("add"));
- PythonResourceProvider resourceProvider =
- new PythonResourceProvider(
- "chat_model",
- ResourceType.CHAT_MODEL,
+ ResourceDescriptor chatModelDescriptor =
+ new ResourceDescriptor(
"flink_agents.plan.tests.compatibility.python_agent_plan_compatibility_test_agent",
"MockChatModel",
kwargs);
+ PythonResourceProvider resourceProvider =
+ new PythonResourceProvider(
+ "chat_model", ResourceType.CHAT_MODEL,
chatModelDescriptor);
Map<String, Object> serialized = new HashMap<>();
diff --git
a/plan/src/test/java/org/apache/flink/agents/plan/resource/python/PythonPromptTest.java
b/plan/src/test/java/org/apache/flink/agents/plan/resource/python/PythonPromptTest.java
new file mode 100644
index 0000000..aa65bb9
--- /dev/null
+++
b/plan/src/test/java/org/apache/flink/agents/plan/resource/python/PythonPromptTest.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.plan.resource.python;
+
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link PythonPrompt#fromSerializedMap(Map)}. */
+public class PythonPromptTest {
+
+    /** Wraps {@code template} in a serialized-prompt map under the "template" key. */
+    private static Map<String, Object> serializedWith(Object template) {
+        Map<String, Object> serialized = new HashMap<>();
+        serialized.put("template", template);
+        return serialized;
+    }
+
+    /** Builds a serialized chat message with the given role and content. */
+    private static Map<String, Object> message(String role, String content) {
+        Map<String, Object> msg = new HashMap<>();
+        msg.put("role", role);
+        msg.put("content", content);
+        return msg;
+    }
+
+    @Test
+    public void testFromSerializedMapWithStringTemplate() {
+        PythonPrompt prompt = PythonPrompt.fromSerializedMap(serializedWith("Hello, {name}!"));
+        assertThat(prompt).isNotNull();
+
+        // The parsed prompt must substitute placeholders like the Python original.
+        Map<String, String> variables = new HashMap<>();
+        variables.put("name", "Bob");
+        assertThat(prompt.formatString(variables)).isEqualTo("Hello, Bob!");
+    }
+
+    @Test
+    public void testFromSerializedMapWithMessageListTemplate() {
+        List<Map<String, Object>> template = new ArrayList<>();
+        template.add(message("system", "You are a helpful assistant."));
+        template.add(message("user", "Hello!"));
+
+        PythonPrompt prompt = PythonPrompt.fromSerializedMap(serializedWith(template));
+        assertThat(prompt).isNotNull();
+
+        List<ChatMessage> formatted = prompt.formatMessages(MessageRole.SYSTEM, new HashMap<>());
+        assertThat(formatted).hasSize(2);
+
+        ChatMessage first = formatted.get(0);
+        assertThat(first.getRole()).isEqualTo(MessageRole.SYSTEM);
+        assertThat(first.getContent()).isEqualTo("You are a helpful assistant.");
+
+        ChatMessage second = formatted.get(1);
+        assertThat(second.getRole()).isEqualTo(MessageRole.USER);
+        assertThat(second.getContent()).isEqualTo("Hello!");
+    }
+
+    @Test
+    public void testFromSerializedMapWithMissingTemplateKey() {
+        Map<String, Object> withoutTemplate = new HashMap<>();
+
+        assertThatThrownBy(() -> PythonPrompt.fromSerializedMap(withoutTemplate))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessage("Map must contain 'template' key");
+    }
+
+    @Test
+    public void testFromSerializedMapWithEmptyList() {
+        Map<String, Object> serialized = serializedWith(new ArrayList<>());
+
+        assertThatThrownBy(() -> PythonPrompt.fromSerializedMap(serialized))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessage("Template list cannot be empty");
+    }
+
+    @Test
+    public void testFromSerializedMapWithInvalidTemplateType() {
+        // An Integer is neither a String nor a List.
+        Map<String, Object> serialized = serializedWith(123);
+
+        assertThatThrownBy(() -> PythonPrompt.fromSerializedMap(serialized))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessage("Python prompt parsing failed. Template is not a string or list.");
+    }
+}
diff --git
a/plan/src/test/java/org/apache/flink/agents/plan/resource/python/PythonToolTest.java
b/plan/src/test/java/org/apache/flink/agents/plan/resource/python/PythonToolTest.java
new file mode 100644
index 0000000..fe8cd91
--- /dev/null
+++
b/plan/src/test/java/org/apache/flink/agents/plan/resource/python/PythonToolTest.java
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.plan.resource.python;
+
+import org.apache.flink.agents.api.tools.ToolParameters;
+import org.apache.flink.agents.api.tools.ToolType;
+import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+import org.junit.jupiter.api.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Test class for {@link PythonTool}. */
+public class PythonToolTest {
+
+ @Test
+ public void testFromSerializedMapSuccess() throws JsonProcessingException {
+ // Create test data
+ Map<String, Object> argsSchema = new HashMap<>();
+ argsSchema.put("type", "function");
+ argsSchema.put("properties", new HashMap<>());
+
+ Map<String, Object> metadata = new HashMap<>();
+ metadata.put("name", "test_tool");
+ metadata.put("description", "A test tool for validation");
+ metadata.put("args_schema", argsSchema);
+
+ Map<String, Object> serialized = new HashMap<>();
+ serialized.put("metadata", metadata);
+
+ // Test the method
+ PythonTool tool = PythonTool.fromSerializedMap(serialized);
+
+ // Verify the result
+ assertThat(tool).isNotNull();
+ assertThat(tool.getMetadata().getName()).isEqualTo("test_tool");
+ assertThat(tool.getMetadata().getDescription()).isEqualTo("A test tool
for validation");
+
assertThat(tool.getMetadata().getInputSchema()).contains("\"type\":\"function\"");
+ assertThat(tool.getToolType()).isEqualTo(ToolType.REMOTE_FUNCTION);
+ }
+
+ @Test
+ public void testFromSerializedMapWithComplexArgsSchema() throws
JsonProcessingException {
+ // Create complex args schema
+ Map<String, Object> properties = new HashMap<>();
+ Map<String, Object> nameProperty = new HashMap<>();
+ nameProperty.put("type", "string");
+ nameProperty.put("description", "The name parameter");
+ properties.put("name", nameProperty);
+
+ Map<String, Object> ageProperty = new HashMap<>();
+ ageProperty.put("type", "integer");
+ ageProperty.put("description", "The age parameter");
+ properties.put("age", ageProperty);
+
+ Map<String, Object> argsSchema = new HashMap<>();
+ argsSchema.put("type", "object");
+ argsSchema.put("properties", properties);
+ argsSchema.put("required", new String[] {"name"});
+
+ Map<String, Object> metadata = new HashMap<>();
+ metadata.put("name", "complex_tool");
+ metadata.put("description", "A tool with complex schema");
+ metadata.put("args_schema", argsSchema);
+
+ Map<String, Object> serialized = new HashMap<>();
+ serialized.put("metadata", metadata);
+
+ // Test the method
+ PythonTool tool = PythonTool.fromSerializedMap(serialized);
+
+ // Verify the result
+ assertThat(tool).isNotNull();
+ assertThat(tool.getMetadata().getName()).isEqualTo("complex_tool");
+ assertThat(tool.getMetadata().getDescription()).isEqualTo("A tool with
complex schema");
+ String inputSchema = tool.getMetadata().getInputSchema();
+ assertThat(inputSchema).contains("\"name\"");
+ assertThat(inputSchema).contains("\"age\"");
+ assertThat(inputSchema).contains("\"required\"");
+ }
+
+ @Test
+ public void testFromSerializedMapWithNullMap() {
+ assertThatThrownBy(() -> PythonTool.fromSerializedMap(null))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("Serialized map cannot be null");
+ }
+
+ @Test
+ public void testFromSerializedMapMissingMetadata() {
+ Map<String, Object> serialized = new HashMap<>();
+ // Missing metadata key
+
+ assertThatThrownBy(() -> PythonTool.fromSerializedMap(serialized))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("Map must contain 'metadata' key");
+ }
+
+ @Test
+ public void testFromSerializedMapWithInvalidMetadataType() {
+ Map<String, Object> serialized = new HashMap<>();
+ serialized.put("metadata", "invalid_type"); // Should be Map
+
+ assertThatThrownBy(() -> PythonTool.fromSerializedMap(serialized))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("'metadata' must be a Map");
+ }
+
+ @Test
+ public void testFromSerializedMapMissingName() {
+ Map<String, Object> metadata = new HashMap<>();
+ metadata.put("description", "A test tool");
+ metadata.put("args_schema", new HashMap<>());
+ // Missing name
+
+ Map<String, Object> serialized = new HashMap<>();
+ serialized.put("metadata", metadata);
+
+ assertThatThrownBy(() -> PythonTool.fromSerializedMap(serialized))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("Metadata must contain 'name' key");
+ }
+
+ @Test
+ public void testFromSerializedMapMissingDescription() {
+ Map<String, Object> metadata = new HashMap<>();
+ metadata.put("name", "test_tool");
+ metadata.put("args_schema", new HashMap<>());
+ // Missing description
+
+ Map<String, Object> serialized = new HashMap<>();
+ serialized.put("metadata", metadata);
+
+ assertThatThrownBy(() -> PythonTool.fromSerializedMap(serialized))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("Metadata must contain 'description' key");
+ }
+
+ @Test
+ public void testFromSerializedMapMissingArgsSchema() {
+ Map<String, Object> metadata = new HashMap<>();
+ metadata.put("name", "test_tool");
+ metadata.put("description", "A test tool");
+ // Missing args_schema
+
+ Map<String, Object> serialized = new HashMap<>();
+ serialized.put("metadata", metadata);
+
+ assertThatThrownBy(() -> PythonTool.fromSerializedMap(serialized))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("Metadata must contain 'args_schema' key");
+ }
+
+ @Test
+ public void testFromSerializedMapWithNullName() {
+ Map<String, Object> metadata = new HashMap<>();
+ metadata.put("name", null);
+ metadata.put("description", "A test tool");
+ metadata.put("args_schema", new HashMap<>());
+
+ Map<String, Object> serialized = new HashMap<>();
+ serialized.put("metadata", metadata);
+
+ assertThatThrownBy(() -> PythonTool.fromSerializedMap(serialized))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("'name' cannot be null");
+ }
+
+ @Test
+ public void testFromSerializedMapWithNullDescription() {
+ Map<String, Object> metadata = new HashMap<>();
+ metadata.put("name", "test_tool");
+ metadata.put("description", null);
+ metadata.put("args_schema", new HashMap<>());
+
+ Map<String, Object> serialized = new HashMap<>();
+ serialized.put("metadata", metadata);
+
+ assertThatThrownBy(() -> PythonTool.fromSerializedMap(serialized))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessage("'description' cannot be null");
+ }
+
+ @Test
+ public void testGetToolType() throws JsonProcessingException {
+ Map<String, Object> metadata = new HashMap<>();
+ metadata.put("name", "test_tool");
+ metadata.put("description", "A test tool");
+ metadata.put("args_schema", new HashMap<>());
+
+ Map<String, Object> serialized = new HashMap<>();
+ serialized.put("metadata", metadata);
+
+ PythonTool tool = PythonTool.fromSerializedMap(serialized);
+
+ assertThat(tool.getToolType()).isEqualTo(ToolType.REMOTE_FUNCTION);
+ }
+
+ @Test
+ public void testCallMethodThrowsUnsupportedOperationException() throws
JsonProcessingException {
+ Map<String, Object> metadata = new HashMap<>();
+ metadata.put("name", "test_tool");
+ metadata.put("description", "A test tool");
+ metadata.put("args_schema", new HashMap<>());
+
+ Map<String, Object> serialized = new HashMap<>();
+ serialized.put("metadata", metadata);
+
+ PythonTool tool = PythonTool.fromSerializedMap(serialized);
+
+ assertThatThrownBy(() -> tool.call(new ToolParameters()))
+ .isInstanceOf(UnsupportedOperationException.class)
+ .hasMessage("PythonTool does not support call method.");
+ }
+
+ @Test
+ public void testFromSerializedMapWithEmptyArgsSchema() throws
JsonProcessingException {
+ Map<String, Object> metadata = new HashMap<>();
+ metadata.put("name", "empty_schema_tool");
+ metadata.put("description", "Tool with empty schema");
+ metadata.put("args_schema", new HashMap<>());
+
+ Map<String, Object> serialized = new HashMap<>();
+ serialized.put("metadata", metadata);
+
+ PythonTool tool = PythonTool.fromSerializedMap(serialized);
+
+ assertThat(tool).isNotNull();
+
assertThat(tool.getMetadata().getName()).isEqualTo("empty_schema_tool");
+ assertThat(tool.getMetadata().getDescription()).isEqualTo("Tool with
empty schema");
+ assertThat(tool.getMetadata().getInputSchema()).isEqualTo("{}");
+ }
+}
diff --git
a/plan/src/test/java/org/apache/flink/agents/plan/serializer/ResourceProviderDeserializerTest.java
b/plan/src/test/java/org/apache/flink/agents/plan/serializer/ResourceProviderDeserializerTest.java
index 9dd40fe..30bfbaa 100644
---
a/plan/src/test/java/org/apache/flink/agents/plan/serializer/ResourceProviderDeserializerTest.java
+++
b/plan/src/test/java/org/apache/flink/agents/plan/serializer/ResourceProviderDeserializerTest.java
@@ -18,6 +18,7 @@
package org.apache.flink.agents.plan.serializer;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.apache.flink.agents.api.resource.ResourceType;
import org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider;
import
org.apache.flink.agents.plan.resourceprovider.PythonSerializableResourceProvider;
@@ -48,16 +49,16 @@ public class ResourceProviderDeserializerTest {
PythonResourceProvider pythonResourceProvider =
(PythonResourceProvider) provider;
assertEquals("my_chat_model", pythonResourceProvider.getName());
assertEquals(ResourceType.CHAT_MODEL,
pythonResourceProvider.getType());
- assertEquals(
- "flink_agents.plan.tests.test_resource_provider",
- pythonResourceProvider.getModule());
- assertEquals("MockChatModelImpl", pythonResourceProvider.getClazz());
+
+ ResourceDescriptor descriptor = pythonResourceProvider.getDescriptor();
+ assertEquals("flink_agents.plan.tests.test_resource_provider",
descriptor.getModule());
+ assertEquals("MockChatModelImpl", descriptor.getClazz());
Map<String, Object> kwargs = new HashMap<>();
kwargs.put("host", "8.8.8.8");
kwargs.put("desc", "mock chat model");
- assertEquals(kwargs, pythonResourceProvider.getKwargs());
+ assertEquals(kwargs, descriptor.getInitialArguments());
}
@Test
diff --git
a/plan/src/test/java/org/apache/flink/agents/plan/serializer/ResourceProviderSerializerTest.java
b/plan/src/test/java/org/apache/flink/agents/plan/serializer/ResourceProviderSerializerTest.java
index d4287ba..c7424c9 100644
---
a/plan/src/test/java/org/apache/flink/agents/plan/serializer/ResourceProviderSerializerTest.java
+++
b/plan/src/test/java/org/apache/flink/agents/plan/serializer/ResourceProviderSerializerTest.java
@@ -18,6 +18,7 @@
package org.apache.flink.agents.plan.serializer;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.apache.flink.agents.api.resource.ResourceType;
import org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider;
import
org.apache.flink.agents.plan.resourceprovider.PythonSerializableResourceProvider;
@@ -60,13 +61,14 @@ public class ResourceProviderSerializerTest {
kwargs.put("host", "8.8.8.8");
kwargs.put("desc", "mock chat model");
// Create a resource provider.
- PythonResourceProvider provider =
- new PythonResourceProvider(
- "my_chat_model",
- ResourceType.CHAT_MODEL,
+ ResourceDescriptor mockChatModelImpl =
+ new ResourceDescriptor(
"flink_agents.plan.tests.test_resource_provider",
"MockChatModelImpl",
kwargs);
+ PythonResourceProvider provider =
+ new PythonResourceProvider(
+ "my_chat_model", ResourceType.CHAT_MODEL,
mockChatModelImpl);
// Serialize the resource provider to JSON
String json =
diff --git
a/plan/src/test/resources/resource_providers/python_resource_provider.json
b/plan/src/test/resources/resource_providers/python_resource_provider.json
index 91e7b71..51c6863 100644
--- a/plan/src/test/resources/resource_providers/python_resource_provider.json
+++ b/plan/src/test/resources/resource_providers/python_resource_provider.json
@@ -1,11 +1,13 @@
{
"name" : "my_chat_model",
"type" : "chat_model",
- "module" : "flink_agents.plan.tests.test_resource_provider",
- "clazz" : "MockChatModelImpl",
- "kwargs" : {
- "host" : "8.8.8.8",
- "desc" : "mock chat model"
+ "descriptor" : {
+ "target_clazz" : "MockChatModelImpl",
+ "target_module" : "flink_agents.plan.tests.test_resource_provider",
+ "arguments" : {
+ "host" : "8.8.8.8",
+ "desc" : "mock chat model"
+ }
},
"__resource_provider_type__" : "PythonResourceProvider"
}
\ No newline at end of file
diff --git a/python/flink_agents/api/chat_models/java_chat_model.py
b/python/flink_agents/api/chat_models/java_chat_model.py
new file mode 100644
index 0000000..4c7415b
--- /dev/null
+++ b/python/flink_agents/api/chat_models/java_chat_model.py
@@ -0,0 +1,46 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+from flink_agents.api.chat_models.chat_model import (
+ BaseChatModelConnection,
+ BaseChatModelSetup,
+)
+from flink_agents.api.decorators import java_resource
+
+
+@java_resource
+class JavaChatModelConnection(BaseChatModelConnection):
+ """Java-based implementation of ChatModelConnection that wraps a Java chat
model
+ object.
+
+ This class serves as a bridge between Python and Java chat model
environments, but
+ unlike JavaChatModelSetup, it does not provide direct chat functionality
in Python.
+ """
+
+ java_class_name: str=""
+
+@java_resource
+class JavaChatModelSetup(BaseChatModelSetup):
+ """Java-based implementation of ChatModelSetup that bridges Python and
Java chat
+ model functionality.
+
+ This class wraps a Java chat model setup object and provides Python
interface
+ compatibility while delegating actual chat operations to the underlying
Java
+ implementation.
+ """
+
+ java_class_name: str=""
diff --git a/python/flink_agents/api/decorators.py
b/python/flink_agents/api/decorators.py
index b45e55a..3289d8a 100644
--- a/python/flink_agents/api/decorators.py
+++ b/python/flink_agents/api/decorators.py
@@ -188,3 +188,8 @@ def vector_store(func: Callable) -> Callable:
"""
func._is_vector_store = True
return func
+
+def java_resource(cls: Type) -> Type:
+ """Decorator to mark a class as Java resource."""
+ cls._is_java_resource = True
+ return cls
diff --git a/python/flink_agents/api/resource.py
b/python/flink_agents/api/resource.py
index f50f76c..056a850 100644
--- a/python/flink_agents/api/resource.py
+++ b/python/flink_agents/api/resource.py
@@ -15,11 +15,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
+import importlib
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, Dict, Type
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, model_serializer, model_validator
class ResourceType(Enum):
@@ -73,23 +74,100 @@ class SerializableResource(Resource, ABC):
return self
-class ResourceDescriptor:
- """Descriptor of resource, includes the class and the initialize
arguments."""
+class ResourceDescriptor(BaseModel):
+ """Descriptor for Resource instances, storing metadata for serialization
and
+ instantiation.
- _clazz: Type[Resource]
- _arguments: Dict[str, Any]
-
- def __init__(self, *, clazz: Type[Resource], **arguments: Any) -> None:
- """Init method."""
- self._clazz = clazz
- self._arguments = arguments
-
- @property
- def clazz(self) -> Type[Resource]:
- """Get the class of the resource."""
- return self._clazz
-
- @property
- def arguments(self) -> Dict[str, Any]:
- """Get the initialize arguments of the resource."""
- return self._arguments
+ Attributes:
+ clazz: The Python Resource class name.
+ arguments: Dictionary containing resource initialization parameters.
+ """
+ clazz: Type[Resource] | None = None
+ arguments: Dict[str, Any]
+
+ def __init__(self, /,
+ *,
+ clazz: Type[Resource] | None = None,
+ **arguments: Any) -> None:
+ """Initialize ResourceDescriptor.
+
+ Args:
+ clazz: The Resource class type to create a descriptor for.
+ **arguments: Additional arguments for resource initialization.
+
+ Usage:
+ descriptor = ResourceDescriptor(clazz=YourResourceClass,
+ param1="value1",
+ param2="value2")
+ """
+ super().__init__(clazz=clazz, arguments=arguments)
+
+ @model_serializer
+ def __custom_serializer(self) -> dict[str, Any]:
+ """Serialize ResourceDescriptor to dictionary.
+
+ Returns:
+ Dictionary containing python_clazz, python_module, java_clazz, and
+ arguments.
+ """
+ return {
+ "target_clazz": self.clazz.__name__,
+ "target_module": self.clazz.__module__,
+ "arguments": self.arguments,
+ }
+
+ @model_validator(mode="before")
+ @classmethod
+ def __custom_deserialize(cls, data: dict[str, Any]) -> dict[str, Any]:
+ """Deserialize data to ResourceDescriptor fields.
+
+ Handles both new format (with python_module) and legacy format
+ (full path in python_clazz).
+
+ Args:
+ data: Dictionary or other data to deserialize.
+
+ Returns:
+ Dictionary with normalized field structure.
+ """
+ if "clazz" in data and data["clazz"] is not None:
+ return data
+
+ args = data["arguments"]
+ python_clazz = args.pop("target_clazz")
+ python_module = args.pop("target_module")
+ data["clazz"] = get_resource_class(python_module, python_clazz)
+ data["arguments"] = args["arguments"]
+ return data
+
+ def __eq__(self, other: object) -> bool:
+ """Compare ResourceDescriptor objects, ignoring private _clazz field.
+
+ This ensures that deserialized objects (with _clazz=None) can be
compared
+ equal to runtime objects (with _clazz set) as long as their
serializable
+ fields match.
+ """
+ if not isinstance(other, ResourceDescriptor):
+ return False
+ return (
+ self.clazz == other.clazz
+ and self.arguments == other.arguments
+ )
+
+ def __hash__(self) -> int:
+ """Generate hash for ResourceDescriptor."""
+ return hash((self.clazz, tuple(sorted(self.arguments.items()))))
+
+
+def get_resource_class(module_path: str, class_name: str) -> Type[Resource]:
+ """Get Resource class from separate module path and class name.
+
+ Args:
+ module_path: Python module path (e.g., 'your.module.path').
+ class_name: Class name (e.g., 'YourResourceClass').
+
+ Returns:
+ The Resource class type.
+ """
+ module = importlib.import_module(module_path)
+ return getattr(module, class_name)
diff --git
a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/__init__.py
b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/__init__.py
new file mode 100644
index 0000000..e154fad
--- /dev/null
+++
b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/__init__.py
@@ -0,0 +1,17 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
diff --git
a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/chat_model_cross_language_agent.py
b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/chat_model_cross_language_agent.py
new file mode 100644
index 0000000..2159068
--- /dev/null
+++
b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/chat_model_cross_language_agent.py
@@ -0,0 +1,151 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+import os
+
+from flink_agents.api.agent import Agent
+from flink_agents.api.chat_message import ChatMessage, MessageRole
+from flink_agents.api.chat_models.java_chat_model import (
+ JavaChatModelConnection,
+ JavaChatModelSetup,
+)
+from flink_agents.api.decorators import (
+ action,
+ chat_model_connection,
+ chat_model_setup,
+ prompt,
+ tool,
+)
+from flink_agents.api.events.chat_event import ChatRequestEvent,
ChatResponseEvent
+from flink_agents.api.events.event import InputEvent, OutputEvent
+from flink_agents.api.prompts.prompt import Prompt
+from flink_agents.api.resource import ResourceDescriptor
+from flink_agents.api.runner_context import RunnerContext
+
+
+class ChatModelCrossLanguageAgent(Agent):
+ """Example agent demonstrating cross-language integration testing.
+
+ This test includes:
+ - Python scheduling Java ChatModel
+ - Java ChatModel scheduling Python Prompt
+ - Java ChatModel scheduling Python Tool
+ """
+
+ @prompt
+ @staticmethod
+ def from_text_prompt() -> Prompt:
+ """Prompt for instruction."""
+ return Prompt.from_text("Please answer the user's question.")
+
+ @prompt
+ @staticmethod
+ def from_messages_prompt() -> Prompt:
+ """Prompt for instruction."""
+ return Prompt.from_messages(
+ messages=[
+ ChatMessage(
+ role=MessageRole.SYSTEM,
+ content="Please answer the user's question.",
+ ),
+ ],
+ )
+
+ @chat_model_connection
+ @staticmethod
+ def ollama_connection() -> ResourceDescriptor:
+ """ChatModelConnection responsible for ollama model service
connection."""
+ return ResourceDescriptor(
+ clazz=JavaChatModelConnection,
+
java_clazz="org.apache.flink.agents.integrations.chatmodels.ollama.OllamaChatModelConnection",
+ endpoint="http://localhost:11434",
+ requestTimeout=120,
+ )
+
+ @chat_model_setup
+ @staticmethod
+ def math_chat_model() -> ResourceDescriptor:
+ """ChatModel which focus on math, and reuse ChatModelConnection."""
+ return ResourceDescriptor(
+ clazz=JavaChatModelSetup,
+ connection="ollama_connection",
+
java_clazz="org.apache.flink.agents.integrations.chatmodels.ollama.OllamaChatModelSetup",
+ model=os.environ.get("OLLAMA_CHAT_MODEL", "qwen3:1.7b"),
+ prompt="from_messages_prompt",
+ tools=["add"],
+ extract_reasoning=True,
+ )
+
+ @chat_model_setup
+ @staticmethod
+ def creative_chat_model() -> ResourceDescriptor:
+ """ChatModel which focus on text generate, and reuse
ChatModelConnection."""
+ return ResourceDescriptor(
+ clazz=JavaChatModelSetup,
+ connection="ollama_connection",
+
java_clazz="org.apache.flink.agents.integrations.chatmodels.ollama.OllamaChatModelSetup",
+ model=os.environ.get("OLLAMA_CHAT_MODEL", "qwen3:1.7b"),
+ prompt="from_text_prompt",
+ extract_reasoning=True,
+ )
+
+ @tool
+ @staticmethod
+ def add(a: int, b: int) -> int:
+ """Calculate the sum of a and b.
+
+ Parameters
+ ----------
+ a : int
+ The first operand
+ b : int
+ The second operand
+
+ Returns:
+ -------
+ int:
+ The sum of a and b
+ """
+ return a + b
+
+ @action(InputEvent)
+ @staticmethod
+ def process_input(event: InputEvent, ctx: RunnerContext) -> None:
+ """User defined action for processing input.
+
+ In this action, we will send ChatRequestEvent to trigger built-in
actions.
+ """
+ input_text = event.input.lower()
+ model_name = (
+ "math_chat_model"
+ if ("calculate" in input_text or "sum" in input_text)
+ else "creative_chat_model"
+ )
+ ctx.send_event(
+ ChatRequestEvent(
+ model=model_name,
+ messages=[ChatMessage(role=MessageRole.USER,
content=event.input)],
+ )
+ )
+
+ @action(ChatResponseEvent)
+ @staticmethod
+ def process_chat_response(event: ChatResponseEvent, ctx: RunnerContext) ->
None:
+ """User defined action for processing chat model response."""
+ input = event.response
+ if event.response and input.content:
+ ctx.send_event(OutputEvent(output=input.content))
diff --git
a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/chat_model_cross_language_test.py
b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/chat_model_cross_language_test.py
new file mode 100644
index 0000000..3ff79cd
--- /dev/null
+++
b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/chat_model_cross_language_test.py
@@ -0,0 +1,100 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+import os
+from pathlib import Path
+
+import pytest
+from pyflink.common import Encoder, WatermarkStrategy
+from pyflink.common.typeinfo import Types
+from pyflink.datastream import (
+ RuntimeExecutionMode,
+ StreamExecutionEnvironment,
+)
+from pyflink.datastream.connectors.file_system import (
+ FileSource,
+ StreamFormat,
+ StreamingFileSink,
+)
+
+from flink_agents.api.execution_environment import AgentsExecutionEnvironment
+from
flink_agents.e2e_tests.e2e_tests_resource_cross_language.chat_model_cross_language_agent
import (
+ ChatModelCrossLanguageAgent,
+)
+from flink_agents.e2e_tests.test_utils import pull_model
+
+current_dir = Path(__file__).parent
+
+OLLAMA_MODEL = os.environ.get("OLLAMA_CHAT_MODEL", "qwen3:1.7b")
+os.environ["OLLAMA_CHAT_MODEL"] = OLLAMA_MODEL
+
+client = pull_model(OLLAMA_MODEL)
+
[email protected](client is None, reason="Ollama client is not available or
test model is missing.")
+def test_java_chat_model_integration(tmp_path: Path) -> None: # noqa: D103
+ env = StreamExecutionEnvironment.get_execution_environment()
+ env.set_runtime_mode(RuntimeExecutionMode.STREAMING)
+ env.set_parallelism(1)
+
+ # currently, bounded source is not supported due to runtime
implementation, so
+ # we use continuous file source here.
+ input_datastream = env.from_source(
+ source=FileSource.for_record_stream_format(
+ StreamFormat.text_line_format(),
f"file:///{current_dir}/../resources/java_chat_module_input"
+ ).build(),
+ watermark_strategy=WatermarkStrategy.no_watermarks(),
+ source_name="streaming_agent_example",
+ )
+
+ deserialize_datastream = input_datastream.map(
+ lambda x: str(x)
+ )
+
+ agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env)
+ output_datastream = (
+ agents_env.from_datastream(
+ input=deserialize_datastream, key_selector= lambda x: "orderKey"
+ )
+ .apply(ChatModelCrossLanguageAgent())
+ .to_datastream()
+ )
+
+ result_dir = tmp_path / "results"
+ result_dir.mkdir(parents=True, exist_ok=True)
+
+ (output_datastream.map(lambda x: str(x).replace('\n', '')
+ .replace('\r', ''), Types.STRING()).add_sink(
+ StreamingFileSink.for_row_format(
+ base_path=str(result_dir.absolute()),
+ encoder=Encoder.simple_string_encoder(),
+ ).build()
+ ))
+
+ agents_env.execute()
+
+ actual_result = []
+ for file in result_dir.iterdir():
+ if file.is_dir():
+ for child in file.iterdir():
+ with child.open() as f:
+ actual_result.extend(f.readlines())
+ if file.is_file():
+ with file.open() as f:
+ actual_result.extend(f.readlines())
+
+ assert "3" in actual_result[0]
+ assert "cat" in actual_result[1]
diff --git
a/python/flink_agents/e2e_tests/resources/java_chat_module_input/input_data.txt
b/python/flink_agents/e2e_tests/resources/java_chat_module_input/input_data.txt
new file mode 100644
index 0000000..3614125
--- /dev/null
+++
b/python/flink_agents/e2e_tests/resources/java_chat_module_input/input_data.txt
@@ -0,0 +1,2 @@
+calculate the sum of 1 and 2.
+Tell me a joke about cats.
\ No newline at end of file
diff --git a/python/flink_agents/plan/agent_plan.py
b/python/flink_agents/plan/agent_plan.py
index 8f6c48d..7ff2f6d 100644
--- a/python/flink_agents/plan/agent_plan.py
+++ b/python/flink_agents/plan/agent_plan.py
@@ -58,6 +58,7 @@ class AgentPlan(BaseModel):
resource_providers: Dict[ResourceType, Dict[str, ResourceProvider]] | None
= None
config: AgentConfiguration | None = None
__resources: Dict[ResourceType, Dict[str, Resource]] = {}
+ __j_resource_adapter: Any = None
@field_serializer("resource_providers")
def __serialize_resource_providers(
@@ -210,12 +211,18 @@ class AgentPlan(BaseModel):
self.__resources[type] = {}
if name not in self.__resources[type]:
resource_provider = self.resource_providers[type][name]
+ if isinstance(resource_provider, JavaResourceProvider):
+
resource_provider.set_java_resource_adapter(self.__j_resource_adapter)
resource = resource_provider.provide(
get_resource=self.get_resource, config=self.config
)
self.__resources[type][name] = resource
return self.__resources[type][name]
+    def set_java_resource_adapter(self, j_resource_adapter: Any) -> None:
+        """Set the Java resource adapter used by Java resource providers.
+
+        The stored adapter is handed to each ``JavaResourceProvider`` via its
+        ``set_java_resource_adapter`` just before that provider creates a
+        resource in ``get_resource``. Until this is called, the stored adapter is
+        ``None``, and that ``None`` is what such providers receive.
+
+        Args:
+            j_resource_adapter: The Java-side adapter object (type opaque here —
+                presumably a bridged Java object; confirm with the caller).
+        """
+        self.__j_resource_adapter = j_resource_adapter
+
def _get_actions(agent: Agent) -> List[Action]:
"""Extract all registered agent actions from an agent.
@@ -284,9 +291,15 @@ def _get_resource_providers(agent: Agent) ->
List[ResourceProvider]:
value = value.__func__
if callable(value):
- resource_providers.append(
- PythonResourceProvider.get(name=name, descriptor=value())
- )
+ descriptor = value()
+ if hasattr(descriptor.clazz, "_is_java_resource"):
+ resource_providers.append(
+ JavaResourceProvider.get(name=name, descriptor=value())
+ )
+ else:
+ resource_providers.append(
+ PythonResourceProvider.get(name=name,
descriptor=value())
+ )
elif hasattr(value, "_is_tool"):
if isinstance(value, staticmethod):
@@ -341,9 +354,14 @@ def _get_resource_providers(agent: Agent) ->
List[ResourceProvider]:
ResourceType.VECTOR_STORE,
]:
for name, descriptor in agent.resources[resource_type].items():
- resource_providers.append(
- PythonResourceProvider.get(name=name, descriptor=descriptor)
- )
+ if hasattr(descriptor.clazz, "_is_java_resource"):
+ resource_providers.append(
+ JavaResourceProvider.get(name=name, descriptor=descriptor)
+ )
+ else:
+ resource_providers.append(
+ PythonResourceProvider.get(name=name,
descriptor=descriptor)
+ )
return resource_providers
diff --git a/python/flink_agents/plan/resource_provider.py
b/python/flink_agents/plan/resource_provider.py
index 2d14d7e..1263cc2 100644
--- a/python/flink_agents/plan/resource_provider.py
+++ b/python/flink_agents/plan/resource_provider.py
@@ -27,6 +27,7 @@ from flink_agents.api.resource import (
ResourceDescriptor,
ResourceType,
SerializableResource,
+ get_resource_class,
)
from flink_agents.plan.configuration import AgentConfiguration
@@ -89,9 +90,7 @@ class PythonResourceProvider(ResourceProvider):
The initialization arguments of the resource.
"""
- module: str
- clazz: str
- kwargs: Dict[str, Any]
+ descriptor: ResourceDescriptor
@staticmethod
def get(name: str, descriptor: ResourceDescriptor) ->
"PythonResourceProvider":
@@ -100,23 +99,20 @@ class PythonResourceProvider(ResourceProvider):
return PythonResourceProvider(
name=name,
type=clazz.resource_type(),
- module=clazz.__module__,
- clazz=clazz.__name__,
- kwargs=descriptor.arguments,
+ descriptor= descriptor,
)
def provide(self, get_resource: Callable, config: AgentConfiguration) ->
Resource:
"""Create resource in runtime."""
- module = importlib.import_module(self.module)
- cls = getattr(module, self.clazz)
+ cls = self.descriptor.clazz
+ # Config data looked up by the class name is applied first, so arguments
+ # given explicitly in the descriptor override it.
final_kwargs = {}
- resource_class_config = config.get_config_data_by_prefix(self.clazz)
+ resource_class_config = config.get_config_data_by_prefix(cls.__name__)
final_kwargs.update(resource_class_config)
- final_kwargs.update(self.kwargs)
+ final_kwargs.update(self.descriptor.arguments)
resource = cls(**final_kwargs, get_resource=get_resource)
return resource
@@ -158,21 +154,60 @@ class
PythonSerializableResourceProvider(SerializableResourceProvider):
self.resource = clazz.model_validate(self.serialized)
return self.resource
+JAVA_RESOURCE_MAPPING: dict[ResourceType, str] = {
+ ResourceType.CHAT_MODEL:
"flink_agents.runtime.java.java_chat_model.JavaChatModelSetupImpl",
+ ResourceType.CHAT_MODEL_CONNECTION:
"flink_agents.runtime.java.java_chat_model.JavaChatModelConnectionImpl",
+}
-# TODO: implementation
class JavaResourceProvider(ResourceProvider):
"""Represent Resource Provider declared by Java.
Currently, this class only used for deserializing Java agent plan json
"""
+ descriptor: ResourceDescriptor
+ _j_resource_adapter: Any = None
+
+ @staticmethod
+ def get(name: str, descriptor: ResourceDescriptor) ->
"JavaResourceProvider":
+ """Create JavaResourceProvider instance."""
+ wrapper_clazz = descriptor.clazz
+ kwargs = {}
+ kwargs.update(descriptor.arguments)
+
+ clazz = descriptor.arguments.get("java_clazz", "")
+ if len(clazz) <1:
+ err_msg = f"java_clazz are not set for {wrapper_clazz.__name__}"
+ raise KeyError(err_msg)
+
+ return JavaResourceProvider(
+ name=name,
+ type=wrapper_clazz.resource_type(),
+ descriptor=descriptor,
+ )
+
def provide(self, get_resource: Callable, config: AgentConfiguration) ->
Resource:
"""Create resource in runtime."""
- err_msg = (
- "Currently, flink-agents doesn't support create resource "
- "by JavaResourceProvider in python."
- )
- raise NotImplementedError(err_msg)
+ if not self._j_resource_adapter:
+ err_msg = "java resource adapter is not set"
+ raise RuntimeError(err_msg)
+
+ j_resource = self._j_resource_adapter.getResource(self.name,
self.type.value)
+
+ class_path = JAVA_RESOURCE_MAPPING.get(self.type)
+ if not class_path:
+ err_msg = f"No Java resource mapping found for {self.type.value}"
+ raise ValueError(err_msg)
+ module_path, class_name = class_path.rsplit(".", 1)
+ cls = get_resource_class(module_path, class_name)
+ kwargs = self.descriptor.arguments
+
+ return cls(**kwargs, get_resource=get_resource, j_resource=j_resource,
j_resource_adapter= self._j_resource_adapter)
+
+
+ def set_java_resource_adapter(self, j_resource_adapter: Any) -> None:
+ """Set java resource adapter for java resource initialization."""
+ self._j_resource_adapter = j_resource_adapter
# TODO: implementation
diff --git a/python/flink_agents/plan/tests/resources/agent_plan.json
b/python/flink_agents/plan/tests/resources/agent_plan.json
index a885e9d..61daa60 100644
--- a/python/flink_agents/plan/tests/resources/agent_plan.json
+++ b/python/flink_agents/plan/tests/resources/agent_plan.json
@@ -89,12 +89,14 @@
"mock": {
"name": "mock",
"type": "chat_model",
- "module": "flink_agents.plan.tests.test_agent_plan",
- "clazz": "MockChatModelImpl",
- "kwargs": {
- "host": "8.8.8.8",
- "desc": "mock resource just for testing.",
- "connection": "mock"
+ "descriptor": {
+ "target_module": "flink_agents.plan.tests.test_agent_plan",
+ "target_clazz": "MockChatModelImpl",
+ "arguments": {
+ "host": "8.8.8.8",
+ "desc": "mock resource just for testing.",
+ "connection": "mock"
+ }
},
"__resource_provider_type__": "PythonResourceProvider"
}
@@ -103,11 +105,13 @@
"mock_embedding": {
"name": "mock_embedding",
"type": "embedding_model",
- "module": "flink_agents.plan.tests.test_agent_plan",
- "clazz": "MockEmbeddingModelSetup",
- "kwargs": {
- "model": "test-model",
- "connection": "mock_embedding_conn"
+ "descriptor": {
+ "target_module": "flink_agents.plan.tests.test_agent_plan",
+ "target_clazz": "MockEmbeddingModelSetup",
+ "arguments": {
+ "model": "test-model",
+ "connection": "mock_embedding_conn"
+ }
},
"__resource_provider_type__": "PythonResourceProvider"
}
@@ -116,10 +120,12 @@
"mock_embedding_conn": {
"name": "mock_embedding_conn",
"type": "embedding_model_connection",
- "module": "flink_agents.plan.tests.test_agent_plan",
- "clazz": "MockEmbeddingModelConnection",
- "kwargs": {
- "api_key": "mock-api-key"
+ "descriptor": {
+ "target_module": "flink_agents.plan.tests.test_agent_plan",
+ "target_clazz": "MockEmbeddingModelConnection",
+ "arguments": {
+ "api_key": "mock-api-key"
+ }
},
"__resource_provider_type__": "PythonResourceProvider"
}
@@ -128,13 +134,15 @@
"mock_vector_store": {
"name": "mock_vector_store",
"type": "vector_store",
- "module": "flink_agents.plan.tests.test_agent_plan",
- "clazz": "MockVectorStore",
- "kwargs": {
- "embedding_model": "mock_embedding",
- "host": "localhost",
- "port": 8000,
- "collection_name": "test_collection"
+ "descriptor": {
+ "target_module": "flink_agents.plan.tests.test_agent_plan",
+ "target_clazz": "MockVectorStore",
+ "arguments": {
+ "embedding_model": "mock_embedding",
+ "host": "localhost",
+ "port": 8000,
+ "collection_name": "test_collection"
+ }
},
"__resource_provider_type__": "PythonResourceProvider"
}
diff --git a/python/flink_agents/plan/tests/resources/resource_provider.json
b/python/flink_agents/plan/tests/resources/resource_provider.json
index c5ce41e..6326cbb 100644
--- a/python/flink_agents/plan/tests/resources/resource_provider.json
+++ b/python/flink_agents/plan/tests/resources/resource_provider.json
@@ -1,10 +1,12 @@
{
"name": "mock",
"type": "chat_model",
- "module": "flink_agents.plan.tests.test_resource_provider",
- "clazz": "MockChatModelImpl",
- "kwargs": {
- "host": "8.8.8.8",
- "desc": "mock chat model"
+ "descriptor": {
+ "target_module": "flink_agents.plan.tests.test_resource_provider",
+ "target_clazz": "MockChatModelImpl",
+ "arguments": {
+ "host": "8.8.8.8",
+ "desc": "mock chat model"
+ }
}
}
\ No newline at end of file
diff --git a/python/flink_agents/plan/tests/test_resource_provider.py
b/python/flink_agents/plan/tests/test_resource_provider.py
index fd3f434..5c64f87 100644
--- a/python/flink_agents/plan/tests/test_resource_provider.py
+++ b/python/flink_agents/plan/tests/test_resource_provider.py
@@ -20,7 +20,7 @@ from pathlib import Path
import pytest
-from flink_agents.api.resource import Resource, ResourceType
+from flink_agents.api.resource import Resource, ResourceDescriptor,
ResourceType
from flink_agents.plan.resource_provider import PythonResourceProvider,
ResourceProvider
current_dir = Path(__file__).parent
@@ -40,9 +40,7 @@ def resource_provider() -> ResourceProvider: # noqa: D103
return PythonResourceProvider(
name="mock",
type=MockChatModelImpl.resource_type(),
- module=MockChatModelImpl.__module__,
- clazz=MockChatModelImpl.__name__,
- kwargs={"host": "8.8.8.8", "desc": "mock chat model"},
+ descriptor=ResourceDescriptor(clazz=MockChatModelImpl, host="8.8.8.8",
desc="mock chat model"),
)
diff --git a/python/flink_agents/runtime/flink_runner_context.py
b/python/flink_agents/runtime/flink_runner_context.py
index 56f2ebb..69334cc 100644
--- a/python/flink_agents/runtime/flink_runner_context.py
+++ b/python/flink_agents/runtime/flink_runner_context.py
@@ -41,7 +41,7 @@ class FlinkRunnerContext(RunnerContext):
__agent_plan: AgentPlan
def __init__(
- self, j_runner_context: Any, agent_plan_json: str, executor:
ThreadPoolExecutor
+ self, j_runner_context: Any, agent_plan_json: str, executor:
ThreadPoolExecutor, j_resource_adapter: Any
) -> None:
"""Initialize a flink runner context with the given java runner
context.
@@ -52,6 +52,7 @@ class FlinkRunnerContext(RunnerContext):
"""
self._j_runner_context = j_runner_context
self.__agent_plan = AgentPlan.model_validate_json(agent_plan_json)
+ self.__agent_plan.set_java_resource_adapter(j_resource_adapter)
self.executor = executor
@override
@@ -182,10 +183,10 @@ class FlinkRunnerContext(RunnerContext):
def create_flink_runner_context(
- j_runner_context: Any, agent_plan_json: str, executor: ThreadPoolExecutor
+ j_runner_context: Any, agent_plan_json: str, executor: ThreadPoolExecutor,
j_resource_adapter: Any
) -> FlinkRunnerContext:
"""Used to create a FlinkRunnerContext Python object in Pemja
environment."""
+ # j_resource_adapter is the Java-side resource adapter; FlinkRunnerContext
+ # attaches it to the agent plan so Java-declared resources can be resolved.
- return FlinkRunnerContext(j_runner_context, agent_plan_json, executor)
+ return FlinkRunnerContext(j_runner_context, agent_plan_json, executor,
j_resource_adapter)
def create_async_thread_pool() -> ThreadPoolExecutor:
diff --git a/python/flink_agents/runtime/java/java_chat_model.py
b/python/flink_agents/runtime/java/java_chat_model.py
new file mode 100644
index 0000000..a30e54c
--- /dev/null
+++ b/python/flink_agents/runtime/java/java_chat_model.py
@@ -0,0 +1,124 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+from typing import Any, Dict, List, Sequence
+
+from typing_extensions import override
+
+from flink_agents.api.chat_message import ChatMessage
+from flink_agents.api.chat_models.java_chat_model import (
+ JavaChatModelConnection,
+ JavaChatModelSetup,
+)
+from flink_agents.api.tools.tool import Tool
+
+
+class JavaChatModelConnectionImpl(JavaChatModelConnection):
+ """Java-based implementation of ChatModelConnection that wraps a Java chat
model
+ object.
+
+ This class serves as a bridge between Python and Java chat model
environments, but
+ unlike JavaChatModelSetup, it does not provide direct chat functionality
in Python.
+ """
+
+
+ j_resource: Any
+
+ @override
+ def chat(
+ self,
+ messages: Sequence[ChatMessage],
+ tools: List[Tool] | None = None,
+ **kwargs: Any,
+ ) -> ChatMessage:
+ """Chat method that throws UnsupportedOperationException.
+
+ This connection serves as a Java resource wrapper only.
+ Chat operations should be performed on the Java side using the
underlying Java
+ chat model object.
+ """
+ err_msg = (
+ "Chat method of JavaChatModelConnection cannot be called directly
from Python runtime. "
+ "This connection serves as a Java resource wrapper only. "
+ "Chat operations should be performed on the Java side using the
underlying Java chat model object."
+ )
+ raise NotImplementedError(err_msg)
+
+
+class JavaChatModelSetupImpl(JavaChatModelSetup):
+ """Java-based implementation of ChatModelSetup that bridges Python and
Java chat
+ model functionality.
+
+ This class wraps a Java chat model setup object and provides Python
interface
+ compatibility while delegating actual chat operations to the underlying
Java
+ implementation.
+ """
+
+ _j_resource: Any
+ _j_resource_adapter: Any
+
+ def __init__(self, j_resource: Any, j_resource_adapter: Any, **kwargs:
Any) -> None:
+ """Creates a new JavaChatModelSetup.
+
+ Args:
+ j_resource: The Java resource object
+ j_resource_adapter: The Java resource adapter for method invocation
+ **kwargs: Additional keyword arguments
+ """
+ super().__init__(**kwargs)
+ self._j_resource=j_resource
+ self._j_resource_adapter=j_resource_adapter
+
+ @property
+ @override
+ def model_kwargs(self) -> Dict[str, Any]:
+ """Return chat model settings.
+
+ Returns:
+ Empty dictionary as parameters are managed by Java side
+ """
+ return {}
+
+ @override
+ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) ->
ChatMessage:
+ """Execute chat conversation by delegating to Java implementation.
+
+ 1. Convert Python messages to Java format
+ 2. Call Java chat method
+ 3. Convert Java response back to Python format
+
+ Parameters
+ ----------
+ messages : Sequence[ChatMessage]
+ Input message sequence
+ **kwargs : Any
+ Additional parameters passed to the model service
+
+ Returns:
+ -------
+ ChatMessage
+ Model response message
+ """
+ # Convert Python messages to Java format
+ java_messages =
[self._j_resource_adapter.fromPythonChatMessage(message) for message in
messages]
+ j_response_message = self._j_resource.chat(java_messages, kwargs)
+
+ # Convert Java response back to Python format
+ from flink_agents.runtime.python_java_utils import (
+ from_java_chat_message,
+ )
+ return from_java_chat_message(j_response_message)
diff --git a/python/flink_agents/runtime/python_java_utils.py
b/python/flink_agents/runtime/python_java_utils.py
index 3c38985..9debb3c 100644
--- a/python/flink_agents/runtime/python_java_utils.py
+++ b/python/flink_agents/runtime/python_java_utils.py
@@ -150,6 +150,17 @@ def to_java_chat_message(chat_message: ChatMessage) -> Any:
return j_chat_message
+# TODO: Replace this with `to_java_chat_message()` when the `find_class` bug
is fixed.
+def update_java_chat_message(chat_message: ChatMessage, j_chat_message: Any)
-> str:
+ """Update a Java chat message using Python chat message."""
+ j_chat_message.setContent(chat_message.content)
+ j_chat_message.setExtraArgs(chat_message.extra_args)
+ if chat_message.tool_calls:
+ tool_calls = [normalize_tool_call_id(tool_call) for tool_call in
chat_message.tool_calls]
+ j_chat_message.setToolCalls(tool_calls)
+
+ return chat_message.role.value
+
def call_method(obj: Any, method_name: str, kwargs: Dict[str, Any]) -> Any:
"""Calls a method on `obj` by name and passes in positional and keyword
arguments.
diff --git
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
index e755916..5f66440 100644
---
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
+++
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
@@ -47,6 +47,7 @@ import
org.apache.flink.agents.runtime.operator.queue.SegmentedQueue;
import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl;
import org.apache.flink.agents.runtime.python.event.PythonEvent;
import org.apache.flink.agents.runtime.python.operator.PythonActionTask;
+import org.apache.flink.agents.runtime.python.utils.JavaResourceAdapter;
import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor;
import org.apache.flink.agents.runtime.python.utils.PythonResourceAdapterImpl;
import org.apache.flink.agents.runtime.utils.EventUtil;
@@ -552,9 +553,13 @@ public class ActionExecutionOperator<IN, OUT> extends
AbstractStreamOperator<OUT
}
private void initPythonActionExecutor() throws Exception {
+ // Gives Python code access to the Java resources declared in the agent plan; the
+ // adapter reuses the operator's Python interpreter for chat message conversion.
+ JavaResourceAdapter javaResourceAdapter =
+ new JavaResourceAdapter(agentPlan, pythonInterpreter);
pythonActionExecutor =
new PythonActionExecutor(
- pythonInterpreter, new
ObjectMapper().writeValueAsString(agentPlan));
+ pythonInterpreter,
+ new ObjectMapper().writeValueAsString(agentPlan),
+ javaResourceAdapter);
pythonActionExecutor.open();
}
diff --git
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java
new file mode 100644
index 0000000..fcdc6dc
--- /dev/null
+++
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.python.utils;
+
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.resource.Resource;
+import org.apache.flink.agents.api.resource.ResourceType;
+import org.apache.flink.agents.plan.AgentPlan;
+import org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider;
+import org.apache.flink.agents.plan.resourceprovider.ResourceProvider;
+import pemja.core.PythonInterpreter;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * Adapter for managing Java resources and facilitating Python-Java interoperability.
+ *
+ * <p>Resources are created lazily from the agent plan's resource providers and cached per type
+ * and name in concurrent maps. If two threads request the same uncached resource at the same
+ * time, both may invoke the provider, but only the first cached instance is ever returned.
+ */
+public class JavaResourceAdapter {
+    private final Map<ResourceType, Map<String, ResourceProvider>> resourceProviders;
+
+    private final transient PythonInterpreter interpreter;
+
+    /** Cache for instantiated resources. */
+    private final transient Map<ResourceType, Map<String, Resource>> resourceCache;
+
+    /**
+     * Creates an adapter for the resources declared in the given agent plan.
+     *
+     * @param agentPlan the agent plan whose resource providers are used
+     * @param interpreter the Python interpreter used for chat message conversion
+     */
+    public JavaResourceAdapter(AgentPlan agentPlan, PythonInterpreter interpreter) {
+        Map<ResourceType, Map<String, ResourceProvider>> providers =
+                agentPlan.getResourceProviders();
+        // Treat a plan without providers as empty, so lookups report "not found".
+        this.resourceProviders = providers != null ? providers : Collections.emptyMap();
+        this.interpreter = interpreter;
+        this.resourceCache = new ConcurrentHashMap<>();
+    }
+
+    /**
+     * Retrieves a Java resource by name and type value. This method is intended for use by the
+     * Python interpreter.
+     *
+     * @param name the name of the resource to retrieve
+     * @param typeValue the type value of the resource
+     * @return the resource
+     * @throws Exception if the resource cannot be retrieved
+     */
+    public Resource getResource(String name, String typeValue) throws Exception {
+        return getResource(name, ResourceType.fromValue(typeValue));
+    }
+
+    /**
+     * Retrieves a Java resource by name and type, creating and caching it on first access.
+     *
+     * @param name the name of the resource to retrieve
+     * @param type the type of the resource
+     * @return the resource
+     * @throws IllegalArgumentException if no provider is registered for the name and type
+     * @throws UnsupportedOperationException if the resource is provided by a {@link
+     *     PythonResourceProvider}
+     * @throws Exception if the provider fails to create the resource
+     */
+    public Resource getResource(String name, ResourceType type) throws Exception {
+        Map<String, Resource> cachedOfType = resourceCache.get(type);
+        if (cachedOfType != null) {
+            Resource cached = cachedOfType.get(name);
+            if (cached != null) {
+                return cached;
+            }
+        }
+
+        Map<String, ResourceProvider> providersOfType = resourceProviders.get(type);
+        ResourceProvider provider = providersOfType == null ? null : providersOfType.get(name);
+        if (provider == null) {
+            throw new IllegalArgumentException(
+                    "Resource not found: " + name + " of type " + type);
+        }
+
+        if (provider instanceof PythonResourceProvider) {
+            // TODO: Support getting resources from PythonResourceProvider in JavaResourceAdapter.
+            throw new UnsupportedOperationException("PythonResourceProvider is not supported.");
+        }
+
+        Resource resource =
+                provider.provide(
+                        (String anotherName, ResourceType anotherType) -> {
+                            try {
+                                return this.getResource(anotherName, anotherType);
+                            } catch (RuntimeException e) {
+                                // Already unchecked; rethrow without an extra wrapper.
+                                throw e;
+                            } catch (Exception e) {
+                                throw new RuntimeException(e);
+                            }
+                        });
+
+        // Keep the first instance if another thread cached the same resource concurrently.
+        Resource existing =
+                resourceCache
+                        .computeIfAbsent(type, k -> new ConcurrentHashMap<>())
+                        .putIfAbsent(name, resource);
+        return existing != null ? existing : resource;
+    }
+
+    /**
+     * Convert a Python chat message to a Java chat message. This method is intended for use by the
+     * Python interpreter.
+     *
+     * @param pythonChatMessage the Python chat message
+     * @return the Java chat message
+     * @throws IllegalStateException if no Python interpreter is set
+     */
+    public ChatMessage fromPythonChatMessage(Object pythonChatMessage) {
+        // TODO: Delete this method after the pemja findClass method is fixed.
+        if (interpreter == null) {
+            throw new IllegalStateException("Python interpreter is not set.");
+        }
+        // The Python helper fills content, extra args and tool calls in place and returns the
+        // role value, which is applied here.
+        ChatMessage chatMessage = new ChatMessage();
+        String roleValue =
+                (String)
+                        interpreter.invoke(
+                                "python_java_utils.update_java_chat_message",
+                                pythonChatMessage,
+                                chatMessage);
+        chatMessage.setRole(MessageRole.fromValue(roleValue));
+        return chatMessage;
+    }
+}
diff --git
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
index 9b8c7c4..2cef288 100644
---
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
+++
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
@@ -59,11 +59,16 @@ public class PythonActionExecutor {
private final PythonInterpreter interpreter;
private final String agentPlanJson;
+ private final JavaResourceAdapter javaResourceAdapter;
private Object pythonAsyncThreadPool;
- public PythonActionExecutor(PythonInterpreter interpreter, String
agentPlanJson) {
+ /**
+ * Creates a PythonActionExecutor.
+ *
+ * @param interpreter the Python interpreter used to invoke Python code
+ * @param agentPlanJson the agent plan serialized as JSON
+ * @param javaResourceAdapter adapter handed to the Python runner context so that Python code
+ * can obtain Java-declared resources
+ */
+ public PythonActionExecutor(
+ PythonInterpreter interpreter,
+ String agentPlanJson,
+ JavaResourceAdapter javaResourceAdapter) {
this.interpreter = interpreter;
this.agentPlanJson = agentPlanJson;
+ this.javaResourceAdapter = javaResourceAdapter;
}
public void open() throws Exception {
@@ -93,7 +98,8 @@ public class PythonActionExecutor {
CREATE_FLINK_RUNNER_CONTEXT,
runnerContext,
agentPlanJson,
- pythonAsyncThreadPool);
+ pythonAsyncThreadPool,
+ javaResourceAdapter);
Object pythonEventObject =
interpreter.invoke(CONVERT_TO_PYTHON_OBJECT, event.getEvent());