This is an automated email from the ASF dual-hosted git repository.
xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
The following commit(s) were added to refs/heads/main by this push:
new d1a419b [Feature][Integrations][Java] Support explicit tool call handling in OllamaChatModel (#279)
d1a419b is described below
commit d1a419b25601ceba7b3b6eac5faaaa3784cd5aaf
Author: twosom <[email protected]>
AuthorDate: Wed Oct 22 11:32:47 2025 +0900
[Feature][Integrations][Java] Support explicit tool call handling in OllamaChatModel (#279)
---
integrations/chat-models/ollama/pom.xml | 2 +-
.../ollama/OllamaChatModelConnection.java | 152 ++++++++++++---------
.../chatmodels/ollama/OllamaChatModelSetup.java | 3 +
.../flink/agents/plan/actions/ChatModelAction.java | 12 +-
4 files changed, 98 insertions(+), 71 deletions(-)
diff --git a/integrations/chat-models/ollama/pom.xml b/integrations/chat-models/ollama/pom.xml
index 96850c4..0477962 100644
--- a/integrations/chat-models/ollama/pom.xml
+++ b/integrations/chat-models/ollama/pom.xml
@@ -46,7 +46,7 @@ under the License.
<dependency>
<groupId>io.github.ollama4j</groupId>
<artifactId>ollama4j</artifactId>
- <version>1.1.0</version>
+ <version>1.1.2</version>
</dependency>
</dependencies>
diff --git a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
index 72b291c..633514c 100644
--- a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
+++ b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
@@ -20,11 +20,9 @@ package org.apache.flink.agents.integrations.chatmodels.ollama;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
-import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.exceptions.RoleNotFoundException;
-import io.github.ollama4j.models.chat.OllamaChatMessage;
-import io.github.ollama4j.models.chat.OllamaChatMessageRole;
-import io.github.ollama4j.models.chat.OllamaChatResult;
+import io.github.ollama4j.models.chat.*;
+import io.github.ollama4j.models.request.OllamaChatEndpointCaller;
import io.github.ollama4j.tools.Tools;
import org.apache.flink.agents.api.chat.messages.ChatMessage;
import org.apache.flink.agents.api.chat.messages.MessageRole;
@@ -33,14 +31,9 @@ import org.apache.flink.agents.api.resource.Resource;
import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.apache.flink.agents.api.resource.ResourceType;
import org.apache.flink.agents.api.tools.Tool;
-import org.apache.flink.agents.api.tools.ToolParameters;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
import java.util.function.BiFunction;
-import java.util.regex.Matcher;
-import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
@@ -66,8 +59,8 @@ import java.util.stream.Collectors;
* }</pre>
*/
public class OllamaChatModelConnection extends BaseChatModelConnection {
- private final OllamaAPI client;
- private final Pattern pattern;
+
+ private final OllamaChatEndpointCaller caller;
/**
* Creates a new ollama chat model connection.
@@ -83,13 +76,10 @@ public class OllamaChatModelConnection extends BaseChatModelConnection {
if (endpoint == null || endpoint.isEmpty()) {
throw new IllegalArgumentException("endpoint should not be null or
empty.");
}
- this.client = new OllamaAPI(endpoint);
- Integer maxChatToolCallRetries =
descriptor.getArgument("maxChatToolCallRetries");
- this.client.setMaxChatToolCallRetries(
- maxChatToolCallRetries != null ? maxChatToolCallRetries : 10);
Integer requestTimeout = descriptor.getArgument("requestTimeout");
- this.client.setRequestTimeoutSeconds(requestTimeout != null ?
requestTimeout : 10);
- this.pattern = Pattern.compile("<think>(.*?)</think>", Pattern.DOTALL);
+ this.caller =
+ new OllamaChatEndpointCaller(
+ endpoint, null, requestTimeout != null ?
requestTimeout : 10);
}
/**
@@ -108,18 +98,20 @@ public class OllamaChatModelConnection extends BaseChatModelConnection {
}
/**
- * Registers tools with the Ollama client based on tool resource names.
+ * Converts Flink Agent tools to Ollama compatible tool specifications.
*
* <p>Each tool's input schema is expected to be a JSON schema containing
"properties" and
* "required" keys. The schema is converted into the function/tool
specification that Ollama
- * understands, and a callable is wired to invoke the underlying BaseTool
with ToolParameters.
+ * understands, and each tool is properly formatted for Ollama API
integration.
*
- * @param tools tools to be registered to the client
- * @throws RuntimeException if schema parsing or registration fails
+ * @param tools List of Flink Agent tools to be converted to Ollama tools
+ * @return List of Ollama compatible tool specifications
+ * @throws RuntimeException if schema parsing or conversion fails
*/
@SuppressWarnings("unchecked")
- private void registerTools(List<Tool> tools) {
+ private List<Tools.Tool> convertToOllamaTools(List<Tool> tools) {
final ObjectMapper mapper = new ObjectMapper();
+ final List<Tools.Tool> ollamaTools = new ArrayList<>();
try {
for (Tool tool : tools) {
final Map<String, Object> schema =
@@ -130,7 +122,7 @@ public class OllamaChatModelConnection extends BaseChatModelConnection {
(Map<String, Map<String, String>>)
schema.get("properties");
final List<String> required = (List<String>)
schema.get("required");
- Map<String, Tools.PromptFuncDefinition.Property> propertiesMap
= new HashMap<>();
+ Map<String, Tools.Property> propertiesMap = new HashMap<>();
for (Map.Entry<String, Map<String, String>> entry :
properties.entrySet()) {
final String paramName = entry.getKey();
@@ -140,40 +132,26 @@ public class OllamaChatModelConnection extends BaseChatModelConnection {
propertiesMap.put(
paramName,
- Tools.PromptFuncDefinition.Property.builder()
+ Tools.Property.builder()
.type(type)
.description(description)
.required(required.contains(paramName))
.build());
}
- final Tools.ToolSpecification toolSpec =
- Tools.ToolSpecification.builder()
- .functionName(tool.getName())
- .functionDescription(tool.getDescription())
- .toolPrompt(
- Tools.PromptFuncDefinition.builder()
- .type("prompt")
- .function(
-
Tools.PromptFuncDefinition.PromptFuncSpec
- .builder()
-
.name(tool.getName())
-
.description(tool.getDescription())
- .parameters(
-
Tools.PromptFuncDefinition
-
.Parameters
-
.builder()
-
.type("object")
-
.properties(
-
propertiesMap)
-
.build())
- .build())
+ final Tools.Tool toolSpec =
+ Tools.Tool.builder()
+ .toolSpec(
+ Tools.ToolSpec.builder()
+ .name(tool.getName())
+
.description(tool.getDescription())
+
.parameters(Tools.Parameters.of(propertiesMap))
.build())
- .toolFunction(arguments -> tool.call(new
ToolParameters(arguments)))
.build();
-
- this.client.registerTool(toolSpec);
+ ollamaTools.add(toolSpec);
}
+
+ return ollamaTools;
} catch (Exception e) {
throw new RuntimeException(e);
}
@@ -201,32 +179,78 @@ public class OllamaChatModelConnection extends BaseChatModelConnection {
public ChatMessage chat(
List<ChatMessage> messages, List<Tool> tools, Map<String, Object>
arguments) {
try {
- registerTools(tools);
+ final boolean extractReasoning =
+ (boolean) arguments.getOrDefault("extract_reasoning",
false);
+
+ final List<Tools.Tool> ollamaTools =
this.convertToOllamaTools(tools);
final List<OllamaChatMessage> ollamaChatMessages =
messages.stream()
.map(this::convertToOllamaChatMessages)
.collect(Collectors.toList());
- final OllamaChatResult ollamaChatResult =
- this.client.chat((String) arguments.get("model"),
ollamaChatMessages);
+ final OllamaChatRequest chatRequest =
+ OllamaChatRequest.builder()
+ .withMessages(ollamaChatMessages)
+ .withModel((String) arguments.get("model"))
+ .withThinking(extractReasoning)
+ .withUseTools(false)
+ .build();
+
+ chatRequest.setTools(ollamaTools);
+ final OllamaChatResult ollamaChatResult =
this.caller.callSync(chatRequest);
+ final OllamaChatResponseModel ollamaChatResponse =
ollamaChatResult.getResponseModel();
+ final OllamaChatMessage ollamaChatMessage =
ollamaChatResponse.getMessage();
+
+ Map<String, Object> extraArgs = new HashMap<>();
+ if (extractReasoning) {
+ extraArgs.put("reasoning", ollamaChatMessage.getThinking());
+ }
+
+ final List<OllamaChatToolCalls> ollamaToolCalls =
ollamaChatMessage.getToolCalls();
+ final ChatMessage chatMessage =
ChatMessage.assistant(ollamaChatMessage.getResponse());
+ chatMessage.setExtraArgs(extraArgs);
+
+ if (ollamaToolCalls != null) {
+ final List<Map<String, Object>> toolCalls =
convertToAgentsTools(ollamaToolCalls);
+ chatMessage.setToolCalls(toolCalls);
+ }
- return extraReasoning(ollamaChatResult.getResponse());
+ return chatMessage;
} catch (Exception e) {
throw new RuntimeException(e);
}
}
- private ChatMessage extraReasoning(String response) {
- Matcher matcher = pattern.matcher(response);
- StringBuilder reasoning = new StringBuilder();
- while (matcher.find()) {
- reasoning.append(matcher.group(1));
+ /**
+ * Converts Ollama tool calls to the format expected by the Flink Agents
framework.
+ *
+ * <p>This method transforms Ollama-specific tool call representations
into a generic format
+ * that can be used by the Flink Agents framework. Each tool call is
assigned a unique ID and
+ * structured with the appropriate function name and arguments.
+ *
+ * @param ollamaToolCalls the list of tool calls returned from Ollama API
+ * @return a list of tool calls formatted for Flink Agents, where each
tool call is represented
+ * as a map containing id, type, and function details
+ */
+ private List<Map<String, Object>> convertToAgentsTools(
+ List<OllamaChatToolCalls> ollamaToolCalls) {
+ final List<Map<String, Object>> toolCalls = new
ArrayList<>(ollamaToolCalls.size());
+ for (OllamaChatToolCalls ollamaToolCall : ollamaToolCalls) {
+ final UUID id = UUID.randomUUID();
+ final Map<String, Object> toolCall =
+ Map.of(
+ "id",
+ id,
+ "type",
+ "function",
+ "function",
+ Map.of(
+ "name",
+ ollamaToolCall.getFunction().getName(),
+ "arguments",
+
ollamaToolCall.getFunction().getArguments()));
+ toolCalls.add(toolCall);
}
- response = matcher.replaceAll("").strip();
- ChatMessage responseMessage = ChatMessage.assistant(response);
- Map<String, Object> extraArgs = new HashMap<>();
- extraArgs.put("reasoning", reasoning.toString().strip());
- responseMessage.setExtraArgs(extraArgs);
- return responseMessage;
+ return toolCalls;
}
}
diff --git a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java
index 80bdeb3..8b78f8f 100644
--- a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java
+++ b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java
@@ -56,11 +56,13 @@ import java.util.function.BiFunction;
public class OllamaChatModelSetup extends BaseChatModelSetup {
private final String model;
+ private final boolean extractReasoning;
public OllamaChatModelSetup(
ResourceDescriptor descriptor, BiFunction<String, ResourceType,
Resource> getResource) {
super(descriptor, getResource);
this.model = descriptor.getArgument("model");
+ this.extractReasoning =
Boolean.parseBoolean(descriptor.getArgument("extract_reasoning"));
}
/**
@@ -88,6 +90,7 @@ public class OllamaChatModelSetup extends BaseChatModelSetup {
public Map<String, Object> getParameters() {
Map<String, Object> params = new HashMap<>();
params.put("model", model);
+ params.put("extract_reasoning", extractReasoning);
return params;
}
}
diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
index f20fdc1..6fdbd92 100644
--- a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
+++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
@@ -31,10 +31,7 @@ import org.apache.flink.agents.api.resource.ResourceType;
import org.apache.flink.agents.api.tools.ToolResponse;
import org.apache.flink.agents.plan.JavaFunction;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.UUID;
+import java.util.*;
/** Built-in action for processing chat request and tool call result. */
public class ChatModelAction {
@@ -83,7 +80,8 @@ public class ChatModelAction {
toolCallContext.put(initialRequestId, messages);
}
List<ChatMessage> messageContext =
- (List<ChatMessage>) toolCallContext.get(initialRequestId);
+ new ArrayList<>((List<ChatMessage>)
toolCallContext.get(initialRequestId));
+
messageContext.add(response);
stm.set(TOOL_CALL_CONTEXT, toolCallContext);
@@ -159,7 +157,9 @@ public class ChatModelAction {
Map<UUID, Object> toolCallContext =
(Map<UUID, Object>) stm.get(TOOL_CALL_CONTEXT).getValue();
// update tool call context
- List<ChatMessage> messages = (List<ChatMessage>)
toolCallContext.get(initialRequestId);
+ List<ChatMessage> messages =
+ new ArrayList<>((List<ChatMessage>)
toolCallContext.get(initialRequestId));
+
for (Map.Entry<String, ToolResponse> entry : responses.entrySet())
{
Map<String, Object> extraArgs = new HashMap<>();
String toolCallId = entry.getKey();