This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git

commit 5f668c7bfab9b1d2f0fde8c7848e6d91751b977f
Author: WenjinXie <[email protected]>
AuthorDate: Mon Sep 22 13:39:56 2025 +0800

    [api] Support built-in ReAct Agent in java.
    
    add license.
---
 .../apache/flink/agents/api/agents/ReActAgent.java | 296 +++++++++++++++++++++
 .../flink/agents/api/tools/ToolResponse.java       |  12 +-
 .../flink/agents/api/agents/ReActAgentTest.java    |  45 ++++
 .../flink/agents/examples/ReActAgentExample.java   | 141 ++++++++++
 .../ollama/OllamaChatModelConnection.java          |  23 +-
 5 files changed, 506 insertions(+), 11 deletions(-)

diff --git 
a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java 
b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java
new file mode 100644
index 0000000..035f2fa
--- /dev/null
+++ b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java
@@ -0,0 +1,296 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.api.agents;
+
+import org.apache.commons.lang3.ClassUtils;
+import org.apache.flink.agents.api.Agent;
+import org.apache.flink.agents.api.InputEvent;
+import org.apache.flink.agents.api.OutputEvent;
+import org.apache.flink.agents.api.annotation.Action;
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.context.RunnerContext;
+import org.apache.flink.agents.api.event.ChatRequestEvent;
+import org.apache.flink.agents.api.event.ChatResponseEvent;
+import org.apache.flink.agents.api.prompt.Prompt;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import org.apache.flink.agents.api.resource.ResourceType;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JacksonException;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonGenerator;
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonParser;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.DeserializationContext;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonMappingException;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.SerializerProvider;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.annotation.JsonDeserialize;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.annotation.JsonSerialize;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.deser.std.StdDeserializer;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ser.std.StdSerializer;
+import org.apache.flink.types.Row;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+/** Built-in ReAct Agent implementation based on the function-calling ability of 
the LLM. */
+public class ReActAgent extends Agent {
+    private static final String DEFAULT_CHAT_MODEL = "_default_chat_model";
+    private static final String DEFAULT_SCHEMA_PROMPT = 
"_default_schema_prompt";
+    private static final String DEFAULT_USER_PROMPT = "_default_user_prompt";
+    private static final ObjectMapper mapper = new ObjectMapper();
+
+    public ReActAgent(
+            ResourceDescriptor descriptor, @Nullable Prompt prompt, @Nullable 
Object outputSchema) {
+        this.addChatModelSetup(DEFAULT_CHAT_MODEL, descriptor);
+
+        if (outputSchema != null) {
+            String jsonSchema;
+            if (outputSchema instanceof RowTypeInfo) {
+                jsonSchema = outputSchema.toString();
+                outputSchema = new OutputSchema((RowTypeInfo) outputSchema);
+            } else if (outputSchema instanceof Class) {
+                try {
+                    jsonSchema = mapper.generateJsonSchema((Class<?>) 
outputSchema).toString();
+                } catch (JsonMappingException e) {
+                    throw new RuntimeException(e);
+                }
+            } else {
+                throw new IllegalArgumentException(
+                        "Output schema must be RowTypeInfo or Pojo class.");
+            }
+            Prompt schemaPrompt =
+                    new Prompt(
+                            String.format(
+                                    "The final response should be json format, 
and match the schema %s",
+                                    jsonSchema));
+            this.addPrompt(DEFAULT_SCHEMA_PROMPT, schemaPrompt);
+        }
+
+        if (prompt != null) {
+            this.addPrompt(DEFAULT_USER_PROMPT, prompt);
+        }
+
+        Map<String, Object> actionConfig = new HashMap<>();
+        actionConfig.put("output_schema", outputSchema);
+
+        try {
+            Method method =
+                    this.getClass()
+                            .getMethod("stopAction", ChatResponseEvent.class, 
RunnerContext.class);
+            this.addAction(new Class[] {ChatResponseEvent.class}, method, 
actionConfig);
+        } catch (NoSuchMethodException e) {
+            throw new IllegalStateException(
+                    "Can't find the method stopAction, this must be a bug.");
+        }
+    }
+
+    @Action(listenEvents = {InputEvent.class})
+    public static void startAction(InputEvent event, RunnerContext ctx) {
+        Object input = event.getInput();
+
+        Prompt userPrompt;
+        try {
+            userPrompt = (Prompt) ctx.getResource(DEFAULT_USER_PROMPT, 
ResourceType.PROMPT);
+        } catch (Exception e) {
+            userPrompt = null;
+        }
+
+        List<ChatMessage> inputMessages = new ArrayList<>();
+        if (ClassUtils.isPrimitiveOrWrapper(input.getClass())) {
+            if (userPrompt != null) {
+                inputMessages =
+                        userPrompt.formatMessages(
+                                MessageRole.USER, Map.of("input", 
String.valueOf(input)));
+            } else {
+                inputMessages.add(new ChatMessage(MessageRole.USER, 
String.valueOf(input)));
+            }
+        } else {
+            if (userPrompt == null) {
+                throw new RuntimeException(
+                        String.format(
+                                "The input type is %s, which is not primitive 
types,"
+                                        + " user should provide prompt to help 
convert it to ChatMessage",
+                                input.getClass()));
+            }
+
+            Map<String, String> fields = new HashMap<>();
+            if (input instanceof Row) {
+                Row userInput = (Row) input;
+                for (String name : 
Objects.requireNonNull(userInput.getFieldNames(true))) {
+                    fields.put(name, String.valueOf(userInput.getField(name)));
+                }
+            } else { // regard as pojo
+                ObjectMapper objectMapper = new ObjectMapper();
+                try {
+                    fields = 
mapper.readValue(objectMapper.writeValueAsString(input), Map.class);
+                } catch (JsonProcessingException e) {
+                    throw new RuntimeException(
+                            String.format(
+                                    "Input must be primitive type, Row or 
Pojo, but is %s",
+                                    input.getClass()));
+                }
+            }
+
+            inputMessages = userPrompt.formatMessages(MessageRole.USER, 
fields);
+        }
+
+        Prompt schmaPrompt;
+        try {
+            schmaPrompt = (Prompt) ctx.getResource(DEFAULT_SCHEMA_PROMPT, 
ResourceType.PROMPT);
+        } catch (Exception e) {
+            schmaPrompt = null;
+        }
+
+        if (schmaPrompt != null) {
+            List<ChatMessage> instruct = 
schmaPrompt.formatMessages(MessageRole.SYSTEM, Map.of());
+            inputMessages.addAll(0, instruct);
+        }
+
+        ctx.sendEvent(new ChatRequestEvent(DEFAULT_CHAT_MODEL, inputMessages));
+    }
+
+    public static void stopAction(ChatResponseEvent event, RunnerContext ctx)
+            throws JsonProcessingException {
+        Object output = String.valueOf(event.getResponse().getContent());
+
+        Object outputSchema = ctx.getActionConfigValue("output_schema");
+
+        // TODO: handle parse error according to configured strategy.
+        if (outputSchema != null) {
+            if (outputSchema instanceof Class) {
+                output = mapper.readValue(String.valueOf(output), (Class<?>) 
outputSchema);
+            } else if (outputSchema instanceof OutputSchema) {
+                RowTypeInfo info = ((OutputSchema) outputSchema).getSchema();
+                Map<String, Object> fields = 
mapper.readValue(String.valueOf(output), Map.class);
+                output = Row.withNames();
+                for (String name : info.getFieldNames()) {
+                    ((Row) output).setField(name, fields.get(name));
+                }
+            }
+        }
+
+        ctx.sendEvent(new OutputEvent(output));
+    }
+
+    /**
+     * Helper class for {@link RowTypeInfo} serialization.
+     *
+     * <p>Currently, only rows whose fields are all basic types are supported.
+     */
+    @VisibleForTesting
+    @JsonSerialize(using = OutputSchemaJsonSerializer.class)
+    @JsonDeserialize(using = OutputSchemaJsonDeserializer.class)
+    public static class OutputSchema {
+        private final RowTypeInfo schema;
+
+        public OutputSchema(RowTypeInfo schema) {
+            this.schema = schema;
+            for (TypeInformation<?> info : schema.getFieldTypes()) {
+                if (!info.isBasicType()) {
+                    throw new IllegalArgumentException(
+                            "Currently, output schema only support row 
contains basic type.");
+                }
+            }
+        }
+
+        public RowTypeInfo getSchema() {
+            return schema;
+        }
+    }
+
+    public static class OutputSchemaJsonSerializer extends 
StdSerializer<OutputSchema> {
+
+        protected OutputSchemaJsonSerializer() {
+            super(OutputSchema.class);
+        }
+
+        @Override
+        public void serialize(
+                OutputSchema schema,
+                JsonGenerator jsonGenerator,
+                SerializerProvider serializerProvider)
+                throws IOException {
+            RowTypeInfo typeInfo = schema.getSchema();
+            jsonGenerator.writeStartObject();
+
+            jsonGenerator.writeFieldName("fieldNames");
+            jsonGenerator.writeStartArray();
+            for (String name : typeInfo.getFieldNames()) {
+                jsonGenerator.writeString(name);
+            }
+            jsonGenerator.writeEndArray();
+
+            // TODO: support type information which is not basic.
+            jsonGenerator.writeFieldName("types");
+            jsonGenerator.writeStartArray();
+            for (TypeInformation<?> info : typeInfo.getFieldTypes()) {
+                jsonGenerator.writeObject(info.getTypeClass());
+            }
+            jsonGenerator.writeEndArray();
+
+            jsonGenerator.writeEndObject();
+        }
+    }
+
+    public static class OutputSchemaJsonDeserializer extends 
StdDeserializer<OutputSchema> {
+        private static final ObjectMapper mapper = new ObjectMapper();
+
+        protected OutputSchemaJsonDeserializer() {
+            super(OutputSchema.class);
+        }
+
+        @Override
+        public OutputSchema deserialize(
+                JsonParser jsonParser, DeserializationContext 
deserializationContext)
+                throws IOException, JacksonException {
+            JsonNode node = jsonParser.getCodec().readTree(jsonParser);
+            List<String> fieldNames = new ArrayList<>();
+            node.get("fieldNames").forEach(fieldNameNode -> 
fieldNames.add(fieldNameNode.asText()));
+            List<TypeInformation<?>> types = new ArrayList<>();
+            node.get("types")
+                    .forEach(
+                            typeNode -> {
+                                try {
+                                    types.add(
+                                            BasicTypeInfo.getInfoFor(
+                                                    
mapper.treeToValue(typeNode, Class.class)));
+                                } catch (JsonProcessingException e) {
+                                    throw new RuntimeException(e);
+                                }
+                            });
+
+            return new OutputSchema(
+                    new RowTypeInfo(
+                            types.toArray(new TypeInformation[0]),
+                            fieldNames.toArray(new String[0])));
+        }
+    }
+}
diff --git 
a/api/src/main/java/org/apache/flink/agents/api/tools/ToolResponse.java 
b/api/src/main/java/org/apache/flink/agents/api/tools/ToolResponse.java
index 3f38669..93ad9f3 100644
--- a/api/src/main/java/org/apache/flink/agents/api/tools/ToolResponse.java
+++ b/api/src/main/java/org/apache/flink/agents/api/tools/ToolResponse.java
@@ -175,17 +175,9 @@ public class ToolResponse {
     @Override
     public String toString() {
         if (success) {
-            return String.format(
-                    "ToolResponse{success=true, result=%s, 
executionTime=%dms%s}",
-                    result,
-                    executionTimeMs,
-                    toolName != null ? ", toolName='" + toolName + "'" : "");
+            return result.toString();
         } else {
-            return String.format(
-                    "ToolResponse{success=false, error='%s', 
executionTime=%dms%s}",
-                    error,
-                    executionTimeMs,
-                    toolName != null ? ", toolName='" + toolName + "'" : "");
+            return error;
         }
     }
 }
diff --git 
a/api/src/test/java/org/apache/flink/agents/api/agents/ReActAgentTest.java 
b/api/src/test/java/org/apache/flink/agents/api/agents/ReActAgentTest.java
new file mode 100644
index 0000000..d559a9c
--- /dev/null
+++ b/api/src/test/java/org/apache/flink/agents/api/agents/ReActAgentTest.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.api.agents;
+
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+public class ReActAgentTest {
+    @Test
+    public void testOutputSchemaSerialization() throws JsonProcessingException 
{
+        ObjectMapper mapper = new ObjectMapper();
+        RowTypeInfo typeInfo =
+                new RowTypeInfo(
+                        new TypeInformation[] {
+                            BasicTypeInfo.INT_TYPE_INFO, 
BasicTypeInfo.STRING_TYPE_INFO
+                        },
+                        new String[] {"a", "b"});
+        ReActAgent.OutputSchema schema = new ReActAgent.OutputSchema(typeInfo);
+        String json = mapper.writeValueAsString(schema);
+        ReActAgent.OutputSchema deserialized =
+                mapper.readValue(json, ReActAgent.OutputSchema.class);
+        Assertions.assertEquals(typeInfo, deserialized.getSchema());
+    }
+}
diff --git 
a/examples/src/main/java/org/apache/flink/agents/examples/ReActAgentExample.java
 
b/examples/src/main/java/org/apache/flink/agents/examples/ReActAgentExample.java
new file mode 100644
index 0000000..dd8ffe9
--- /dev/null
+++ 
b/examples/src/main/java/org/apache/flink/agents/examples/ReActAgentExample.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.examples;
+
+import org.apache.flink.agents.api.Agent;
+import org.apache.flink.agents.api.AgentsExecutionEnvironment;
+import org.apache.flink.agents.api.agents.ReActAgent;
+import org.apache.flink.agents.api.annotation.Tool;
+import org.apache.flink.agents.api.annotation.ToolParam;
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.prompt.Prompt;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import 
org.apache.flink.agents.integrations.chatmodels.ollama.OllamaChatModelConnection;
+import 
org.apache.flink.agents.integrations.chatmodels.ollama.OllamaChatModelSetup;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.jetbrains.annotations.NotNull;
+
+import java.util.List;
+
+public class ReActAgentExample {
+    @Tool(description = "Useful function to add two numbers.")
+    public static double add(@ToolParam(name = "a") Double a, @ToolParam(name 
= "b") Double b) {
+        return a + b;
+    }
+
+    @Tool(description = "Useful function to multiply two numbers.")
+    public static double multiply(
+            @ToolParam(name = "a") Double a, @ToolParam(name = "b") Double b) {
+        return a * b;
+    }
+
+    /** Runs the example pipeline. */
+    public static void main(String[] args) throws Exception {
+        StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(1);
+
+        // Create the table environment
+        StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
+        tableEnv.getConfig().set("table.exec.result.display.max-column-width", 
"100");
+
+        // Create agents execution environment
+        AgentsExecutionEnvironment agentsEnv =
+                AgentsExecutionEnvironment.getExecutionEnvironment(env);
+
+        // Register resources with the agents execution environment.
+        agentsEnv
+                .addChatModelConnection(
+                        "ollama",
+                        ResourceDescriptor.Builder.newBuilder(
+                                        
OllamaChatModelConnection.class.getName())
+                                .addInitialArgument("endpoint", 
"http://localhost:11434")
+                                .build())
+                .addTool(ReActAgentExample.class.getMethod("add", 
Double.class, Double.class))
+                .addTool(ReActAgentExample.class.getMethod("multiply", 
Double.class, Double.class));
+
+        // Declare the ReAct agent.
+        Agent agent = getAgent();
+
+        // Create input table from sample data
+        Table inputTable =
+                tableEnv.fromValues(
+                        DataTypes.ROW(
+                                DataTypes.FIELD("a", DataTypes.DOUBLE()),
+                                DataTypes.FIELD("b", DataTypes.DOUBLE()),
+                                DataTypes.FIELD("c", DataTypes.DOUBLE())),
+                        Row.of(1, 2, 3));
+
+        // Define output schema
+        Schema outputSchema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.ROW(DataTypes.FIELD("result", 
DataTypes.DOUBLE())))
+                        .build();
+
+        // Apply agent to the Table
+        Table outputTable =
+                agentsEnv
+                        .fromTable(
+                                inputTable,
+                                tableEnv,
+                                (KeySelector<Object, Double>)
+                                        value -> (Double) ((Row) 
value).getField("a"))
+                        .apply(agent)
+                        .toTable(outputSchema);
+
+        // Print the results to fully display the data
+        tableEnv.toDataStream(outputTable)
+                .map((MapFunction<Row, Row>) x -> (Row) x.getField("f0"))
+                .print();
+        env.execute();
+    }
+
+    // Create the ReAct agent.
+    private static @NotNull Agent getAgent() {
+        ResourceDescriptor chatModelDescriptor =
+                
ResourceDescriptor.Builder.newBuilder(OllamaChatModelSetup.class.getName())
+                        .addInitialArgument("connection", "ollama")
+                        .addInitialArgument("model", "qwen3:8b")
+                        .addInitialArgument("tools", List.of("add", 
"multiply"))
+                        .build();
+
+        Prompt prompt =
+                new Prompt(
+                        List.of(
+                                new ChatMessage(
+                                        MessageRole.SYSTEM,
+                                        "An example of output is {\"result\": 
30.32}"),
+                                new ChatMessage(MessageRole.USER, "What is 
({a} + {b}) * {c}")));
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        new TypeInformation[] {BasicTypeInfo.DOUBLE_TYPE_INFO},
+                        new String[] {"result"});
+        return new ReActAgent(chatModelDescriptor, prompt, outputTypeInfo);
+    }
+}
diff --git 
a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
 
b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
index 1452e41..0b8d21b 100644
--- 
a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
+++ 
b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
@@ -39,6 +39,8 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.function.BiFunction;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 
 /**
@@ -65,6 +67,7 @@ import java.util.stream.Collectors;
  */
 public class OllamaChatModelConnection extends BaseChatModelConnection {
     private final OllamaAPI client;
+    private final Pattern pattern;
     /**
      * Creates a new ollama chat model connection.
      *
@@ -80,6 +83,10 @@ public class OllamaChatModelConnection extends 
BaseChatModelConnection {
             throw new IllegalArgumentException("endpoint should not be null or 
empty.");
         }
         this.client = new OllamaAPI(endpoint);
+        Integer maxChatToolCallRetries = 
descriptor.getArgument("maxChatToolCallRetries");
+        this.client.setMaxChatToolCallRetries(
+                maxChatToolCallRetries != null ? maxChatToolCallRetries : 10);
+        this.pattern = Pattern.compile("<think>(.*?)</think>", Pattern.DOTALL);
     }
 
     /**
@@ -200,9 +207,23 @@ public class OllamaChatModelConnection extends 
BaseChatModelConnection {
             final OllamaChatResult ollamaChatResult =
                     this.client.chat((String) arguments.get("model"), 
ollamaChatMessages);
 
-            return ChatMessage.assistant(ollamaChatResult.getResponse());
+            return extraReasoning(ollamaChatResult.getResponse());
         } catch (Exception e) {
             throw new RuntimeException(e);
         }
     }
+
+    private ChatMessage extraReasoning(String response) {
+        Matcher matcher = pattern.matcher(response);
+        StringBuilder reasoning = new StringBuilder();
+        while (matcher.find()) {
+            reasoning.append(matcher.group(1));
+        }
+        response = matcher.replaceAll("").strip();
+        ChatMessage responseMessage = ChatMessage.assistant(response);
+        Map<String, Object> extraArgs = new HashMap<>();
+        extraArgs.put("reasoning", reasoning.toString().strip());
+        responseMessage.setExtraArgs(extraArgs);
+        return responseMessage;
+    }
 }

Reply via email to