This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 5f668c7bfab9b1d2f0fde8c7848e6d91751b977f Author: WenjinXie <[email protected]> AuthorDate: Mon Sep 22 13:39:56 2025 +0800 [api] Support built-in ReAct Agent in java. add license. --- .../apache/flink/agents/api/agents/ReActAgent.java | 296 +++++++++++++++++++++ .../flink/agents/api/tools/ToolResponse.java | 12 +- .../flink/agents/api/agents/ReActAgentTest.java | 45 ++++ .../flink/agents/examples/ReActAgentExample.java | 141 ++++++++++ .../ollama/OllamaChatModelConnection.java | 23 +- 5 files changed, 506 insertions(+), 11 deletions(-) diff --git a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java new file mode 100644 index 0000000..035f2fa --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java @@ -0,0 +1,296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.agents.api.agents; + +import org.apache.commons.lang3.ClassUtils; +import org.apache.flink.agents.api.Agent; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.api.annotation.Action; +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.event.ChatRequestEvent; +import org.apache.flink.agents.api.event.ChatResponseEvent; +import org.apache.flink.agents.api.prompt.Prompt; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JacksonException; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonGenerator; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonParser; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.DeserializationContext; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonMappingException; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.SerializerProvider; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.annotation.JsonSerialize; 
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.deser.std.StdDeserializer;
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ser.std.StdSerializer;
+import org.apache.flink.types.Row;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Built-in ReAct agent implementation based on the function-calling ability of the LLM.
+ *
+ * <p>{@link #startAction} turns an {@link InputEvent} into a {@link ChatRequestEvent} for the
+ * default chat model, and {@link #stopAction} turns the resulting {@link ChatResponseEvent} into
+ * an {@link OutputEvent}, parsing the response when an output schema was configured.
+ */
+public class ReActAgent extends Agent {
+    private static final String DEFAULT_CHAT_MODEL = "_default_chat_model";
+    private static final String DEFAULT_SCHEMA_PROMPT = "_default_schema_prompt";
+    private static final String DEFAULT_USER_PROMPT = "_default_user_prompt";
+    private static final ObjectMapper mapper = new ObjectMapper();
+
+    /**
+     * Creates a ReAct agent.
+     *
+     * @param descriptor descriptor of the chat model setup, registered as the default chat model
+     * @param prompt optional prompt used to render the input into chat messages
+     * @param outputSchema optional schema of the final response: either a {@link RowTypeInfo}
+     *     whose fields are all basic types, or a POJO class
+     * @throws IllegalArgumentException if {@code outputSchema} is of any other type, or is a row
+     *     type containing a non-basic field
+     */
+    public ReActAgent(
+            ResourceDescriptor descriptor, @Nullable Prompt prompt, @Nullable Object outputSchema) {
+        this.addChatModelSetup(DEFAULT_CHAT_MODEL, descriptor);
+
+        if (outputSchema != null) {
+            String jsonSchema;
+            if (outputSchema instanceof RowTypeInfo) {
+                jsonSchema = outputSchema.toString();
+                // Store the row type wrapped in its JSON-serializable helper.
+                outputSchema = new OutputSchema((RowTypeInfo) outputSchema);
+            } else if (outputSchema instanceof Class) {
+                try {
+                    jsonSchema = mapper.generateJsonSchema((Class<?>) outputSchema).toString();
+                } catch (JsonMappingException e) {
+                    throw new RuntimeException(e);
+                }
+            } else {
+                throw new IllegalArgumentException(
+                        "Output schema must be RowTypeInfo or Pojo class.");
+            }
+            Prompt schemaPrompt =
+                    new Prompt(
+                            String.format(
+                                    "The final response should be json format, and match the schema %s",
+                                    jsonSchema));
+            this.addPrompt(DEFAULT_SCHEMA_PROMPT, schemaPrompt);
+        }
+
+        if (prompt != null) {
+            this.addPrompt(DEFAULT_USER_PROMPT, prompt);
+        }
+
+        Map<String, Object> actionConfig = new HashMap<>();
+        actionConfig.put("output_schema", outputSchema);
+
+        try {
+            Method method =
+                    this.getClass()
+                            .getMethod("stopAction", ChatResponseEvent.class, RunnerContext.class);
+            this.addAction(new Class[] {ChatResponseEvent.class}, method, actionConfig);
+        } catch (NoSuchMethodException e) {
+            // Keep the cause so the failing lookup stays visible in the stack trace.
+            throw new IllegalStateException(
+                    "Can't find the method stopAction, this must be a bug.", e);
+        }
+    }
+
+    /**
+     * Converts the input into chat messages and sends a chat request to the default chat model.
+     *
+     * <p>Primitive wrapper and {@code String} inputs are used as the user message directly, or
+     * bound to the {@code input} variable of the user prompt when one is registered. {@link Row}
+     * and POJO inputs require a user prompt; their fields are bound to the prompt variables by
+     * name, rendered with {@link String#valueOf(Object)}. If a schema prompt is registered, its
+     * messages are placed in front of the input messages as system messages.
+     *
+     * @param event event carrying the user input
+     * @param ctx runner context used to look up the prompts and to send the chat request
+     * @throws NullPointerException if the input is {@code null}, or a {@link Row} input has no
+     *     field names
+     * @throws RuntimeException if a Row/POJO input arrives without a user prompt, or the input
+     *     cannot be converted into prompt variables
+     */
+    @Action(listenEvents = {InputEvent.class})
+    public static void startAction(InputEvent event, RunnerContext ctx) {
+        Object input = Objects.requireNonNull(event.getInput(), "input must not be null");
+        Prompt userPrompt = findPrompt(ctx, DEFAULT_USER_PROMPT);
+
+        List<ChatMessage> inputMessages;
+        if (isScalar(input)) {
+            String text = String.valueOf(input);
+            if (userPrompt != null) {
+                inputMessages = userPrompt.formatMessages(MessageRole.USER, Map.of("input", text));
+            } else {
+                inputMessages = List.of(new ChatMessage(MessageRole.USER, text));
+            }
+        } else {
+            if (userPrompt == null) {
+                throw new RuntimeException(
+                        String.format(
+                                "The input type is %s, which is not primitive types,"
+                                        + " user should provide prompt to help convert it to ChatMessage",
+                                input.getClass()));
+            }
+            inputMessages = userPrompt.formatMessages(MessageRole.USER, toPromptVariables(input));
+        }
+
+        // Assemble into a list owned by this method: the lists returned by formatMessages are
+        // not ours to modify, and the schema instruction has to come first.
+        List<ChatMessage> messages = new ArrayList<>();
+        Prompt schemaPrompt = findPrompt(ctx, DEFAULT_SCHEMA_PROMPT);
+        if (schemaPrompt != null) {
+            messages.addAll(schemaPrompt.formatMessages(MessageRole.SYSTEM, Map.of()));
+        }
+        messages.addAll(inputMessages);
+
+        ctx.sendEvent(new ChatRequestEvent(DEFAULT_CHAT_MODEL, messages));
+    }
+
+    /** Returns whether the input can be used as message text without a prompt. */
+    private static boolean isScalar(Object input) {
+        return input instanceof String || ClassUtils.isPrimitiveOrWrapper(input.getClass());
+    }
+
+    /** Looks up a prompt resource; a failed lookup is treated as "prompt not registered". */
+    @Nullable
+    private static Prompt findPrompt(RunnerContext ctx, String name) {
+        try {
+            return (Prompt) ctx.getResource(name, ResourceType.PROMPT);
+        } catch (Exception ignored) {
+            // Both prompts are only registered when the user supplied them.
+            return null;
+        }
+    }
+
+    /** Flattens a {@link Row} or POJO into prompt variables, stringifying every value. */
+    private static Map<String, String> toPromptVariables(Object input) {
+        Map<String, String> fields = new HashMap<>();
+        if (input instanceof Row) {
+            Row userInput = (Row) input;
+            for (String name :
+                    Objects.requireNonNull(
+                            userInput.getFieldNames(true), "Row input must have field names")) {
+                fields.put(name, String.valueOf(userInput.getField(name)));
+            }
+            return fields;
+        }
+
+        // Regard everything else as a pojo and go through its JSON object form.
+        try {
+            Map<?, ?> properties = mapper.readValue(mapper.writeValueAsString(input), Map.class);
+            for (Map.Entry<?, ?> property : properties.entrySet()) {
+                // JSON values may be numbers, booleans or nested structures; the prompt needs text.
+                fields.put(String.valueOf(property.getKey()), String.valueOf(property.getValue()));
+            }
+        } catch (JsonProcessingException e) {
+            throw new RuntimeException(
+                    String.format(
+                            "Input must be primitive type, Row or Pojo, but is %s",
+                            input.getClass()),
+                    e);
+        }
+        return fields;
+    }
+
+    /**
+     * Emits the chat response as an {@link OutputEvent}.
+     *
+     * <p>Without an output schema the response content is emitted as a string. With a POJO class
+     * it is parsed into an instance of that class. With a row schema it is parsed into a named
+     * {@link Row} whose values are converted to the declared field classes.
+     *
+     * @param event event carrying the chat model response
+     * @param ctx runner context used to read the {@code output_schema} action config and to send
+     *     the output event
+     * @throws JsonProcessingException if the response content cannot be parsed as JSON
+     * @throws IllegalArgumentException if a row field value cannot be converted to its declared
+     *     class
+     */
+    public static void stopAction(ChatResponseEvent event, RunnerContext ctx)
+            throws JsonProcessingException {
+        String content = String.valueOf(event.getResponse().getContent());
+        Object outputSchema = ctx.getActionConfigValue("output_schema");
+
+        // TODO: handle parse error according to configured strategy.
+        Object output = content;
+        if (outputSchema instanceof Class) {
+            output = mapper.readValue(content, (Class<?>) outputSchema);
+        } else if (outputSchema instanceof OutputSchema) {
+            output = toRow(content, ((OutputSchema) outputSchema).getSchema());
+        }
+
+        ctx.sendEvent(new OutputEvent(output));
+    }
+
+    /** Parses a JSON object into a named row holding exactly the fields of {@code schema}. */
+    private static Row toRow(String json, RowTypeInfo schema) throws JsonProcessingException {
+        Map<?, ?> values = mapper.readValue(json, Map.class);
+        String[] names = schema.getFieldNames();
+        TypeInformation<?>[] types = schema.getFieldTypes();
+
+        Row row = Row.withNames();
+        for (int i = 0; i < names.length; i++) {
+            // A JSON number is parsed as Integer, Long or Double depending on how it is written;
+            // convert it so the row really holds the class its schema declares.
+            Object value = mapper.convertValue(values.get(names[i]), types[i].getTypeClass());
+            row.setField(names[i], value);
+        }
+        return row;
+    }
+
+    /**
+     * Helper class for {@link RowTypeInfo} serialization.
+     *
+     * <p>Currently, only rows whose fields are all basic types are supported.
+     */
+    @VisibleForTesting
+    @JsonSerialize(using = OutputSchemaJsonSerializer.class)
+    @JsonDeserialize(using = OutputSchemaJsonDeserializer.class)
+    public static class OutputSchema {
+        private final RowTypeInfo schema;
+
+        /**
+         * @param schema row type to wrap
+         * @throws IllegalArgumentException if any field of the row is not a basic type
+         */
+        public OutputSchema(RowTypeInfo schema) {
+            Objects.requireNonNull(schema, "schema must not be null");
+            for (TypeInformation<?> info : schema.getFieldTypes()) {
+                if (!info.isBasicType()) {
+                    throw new IllegalArgumentException(
+                            "Currently, output schema only support row contains basic type.");
+                }
+            }
+            this.schema = schema;
+        }
+
+        public RowTypeInfo getSchema() {
+            return schema;
+        }
+    }
+
+    /**
+     * Writes an {@link OutputSchema} as {@code {"fieldNames": [...], "types": [...]}}, where each
+     * entry of {@code types} is the type class of the field at the same position.
+     */
+    public static class OutputSchemaJsonSerializer extends StdSerializer<OutputSchema> {
+
+        protected OutputSchemaJsonSerializer() {
+            super(OutputSchema.class);
+        }
+
+        @Override
+        public void serialize(
+                OutputSchema schema,
+                JsonGenerator jsonGenerator,
+                SerializerProvider serializerProvider)
+                throws IOException {
+            RowTypeInfo typeInfo = schema.getSchema();
+            jsonGenerator.writeStartObject();
+
+            jsonGenerator.writeFieldName("fieldNames");
+            jsonGenerator.writeStartArray();
+            for (String name : typeInfo.getFieldNames()) {
+                jsonGenerator.writeString(name);
+            }
+            jsonGenerator.writeEndArray();
+
+            // TODO: support type information which is not basic.
+            jsonGenerator.writeFieldName("types");
+            jsonGenerator.writeStartArray();
+            for (TypeInformation<?> info : typeInfo.getFieldTypes()) {
+                jsonGenerator.writeObject(info.getTypeClass());
+            }
+            jsonGenerator.writeEndArray();
+
+            jsonGenerator.writeEndObject();
+        }
+    }
+
+    /** Reads back the JSON layout produced by {@link OutputSchemaJsonSerializer}. */
+    public static class OutputSchemaJsonDeserializer extends StdDeserializer<OutputSchema> {
+        private static final ObjectMapper mapper = new ObjectMapper();
+
+        protected OutputSchemaJsonDeserializer() {
+            super(OutputSchema.class);
+        }
+
+        @Override
+        public OutputSchema deserialize(
+                JsonParser jsonParser, DeserializationContext deserializationContext)
+                throws IOException, JacksonException {
+            JsonNode node = jsonParser.getCodec().readTree(jsonParser);
+            JsonNode namesNode = node.get("fieldNames");
+            JsonNode typesNode = node.get("types");
+            if (namesNode == null || typesNode == null || namesNode.size() != typesNode.size()) {
+                throw JsonMappingException.from(
+                        jsonParser,
+                        "OutputSchema json must contain 'fieldNames' and 'types' arrays"
+                                + " of the same length.");
+            }
+
+            List<String> fieldNames = new ArrayList<>();
+            for (JsonNode nameNode : namesNode) {
+                fieldNames.add(nameNode.asText());
+            }
+
+            // Plain loops let Jackson's checked exceptions propagate without wrapping.
+            List<TypeInformation<?>> types = new ArrayList<>();
+            for (JsonNode typeNode : typesNode) {
+                Class<?> typeClass = mapper.treeToValue(typeNode, Class.class);
+                TypeInformation<?> info = BasicTypeInfo.getInfoFor(typeClass);
+                if (info == null) {
+                    throw JsonMappingException.from(
+                            jsonParser, "Output schema field type is not a basic type: " + typeClass);
+                }
+                types.add(info);
+            }
+
+            return new OutputSchema(
+                    new RowTypeInfo(
+                            types.toArray(new TypeInformation[0]),
+                            fieldNames.toArray(new String[0])));
+        }
+    }
+}
diff --git a/api/src/main/java/org/apache/flink/agents/api/tools/ToolResponse.java b/api/src/main/java/org/apache/flink/agents/api/tools/ToolResponse.java
index 3f38669..93ad9f3 100644
--- a/api/src/main/java/org/apache/flink/agents/api/tools/ToolResponse.java
+++ b/api/src/main/java/org/apache/flink/agents/api/tools/ToolResponse.java
@@ -175,17 +175,11 @@ public class ToolResponse {
     @Override
     public String toString() {
         if (success) {
-            return String.format(
-                    "ToolResponse{success=true, result=%s, executionTime=%dms%s}",
-                    result,
-                    executionTimeMs,
-                    toolName != null ? ", toolName='" + toolName + "'" : "");
+            // A successful call may carry no result; toString() must not throw for it.
+            return String.valueOf((Object) result);
         } else {
-            return String.format(
-                    "ToolResponse{success=false, error='%s', executionTime=%dms%s}",
-                    error,
-                    executionTimeMs,
-                    toolName != null ? ", toolName='" + toolName + "'" : "");
+            // toString() must never return null, even when no error message was recorded.
+            return String.valueOf(error);
         }
     }
 }
diff --git a/api/src/test/java/org/apache/flink/agents/api/agents/ReActAgentTest.java b/api/src/test/java/org/apache/flink/agents/api/agents/ReActAgentTest.java
new file mode 100644
index 0000000..d559a9c
--- /dev/null
+++ b/api/src/test/java/org/apache/flink/agents/api/agents/ReActAgentTest.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.api.agents;
+
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+/** Checks that {@link ReActAgent.OutputSchema} survives a JSON round trip. */
+public class ReActAgentTest {
+    @Test
+    public void testOutputSchemaSerialization() throws JsonProcessingException {
+        TypeInformation<?>[] fieldTypes = {
+            BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO
+        };
+        String[] fieldNames = {"a", "b"};
+        RowTypeInfo expected = new RowTypeInfo(fieldTypes, fieldNames);
+        // Serialize the wrapped row type and read it straight back.
+        ObjectMapper jackson = new ObjectMapper();
+        String encoded = jackson.writeValueAsString(new ReActAgent.OutputSchema(expected));
+        ReActAgent.OutputSchema restored =
+                jackson.readValue(encoded, ReActAgent.OutputSchema.class);
+        Assertions.assertEquals(expected, restored.getSchema());
+    }
+}
diff --git a/examples/src/main/java/org/apache/flink/agents/examples/ReActAgentExample.java b/examples/src/main/java/org/apache/flink/agents/examples/ReActAgentExample.java
new file mode 100644
index 0000000..dd8ffe9
--- /dev/null
+++ b/examples/src/main/java/org/apache/flink/agents/examples/ReActAgentExample.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.examples;
+
+import org.apache.flink.agents.api.Agent;
+import org.apache.flink.agents.api.AgentsExecutionEnvironment;
+import org.apache.flink.agents.api.agents.ReActAgent;
+import org.apache.flink.agents.api.annotation.Tool;
+import org.apache.flink.agents.api.annotation.ToolParam;
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.prompt.Prompt;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import org.apache.flink.agents.integrations.chatmodels.ollama.OllamaChatModelConnection;
+import org.apache.flink.agents.integrations.chatmodels.ollama.OllamaChatModelSetup;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.jetbrains.annotations.NotNull;
+
+import java.util.List;
+
+/** Example that runs the built-in {@link ReActAgent} over a one-row table of numbers. */
+public class ReActAgentExample {
+    @Tool(description = "Useful function to add two numbers.")
+    public static double add(@ToolParam(name = "a") Double a, @ToolParam(name = "b") Double b) {
+        return Double.sum(a, b);
+    }
+
+    @Tool(description = "Useful function to multiply two numbers.")
+    public static double multiply(
+            @ToolParam(name = "a") Double a, @ToolParam(name = "b") Double b) {
+        return a * b;
+    }
+
+    /** Builds the pipeline, applies the agent to the input table and prints the result rows. */
+    public static void main(String[] args) throws Exception {
+        StreamExecutionEnvironment flinkEnv = StreamExecutionEnvironment.getExecutionEnvironment();
+        flinkEnv.setParallelism(1);
+
+        // Table environment with a display column width of 100.
+        StreamTableEnvironment tables = StreamTableEnvironment.create(flinkEnv);
+        tables.getConfig().set("table.exec.result.display.max-column-width", "100");
+
+        // Agents execution environment on top of the streaming environment.
+        AgentsExecutionEnvironment agentsEnv =
+                AgentsExecutionEnvironment.getExecutionEnvironment(flinkEnv);
+
+        // Register the ollama connection and both tools.
+        ResourceDescriptor ollamaConnection =
+                ResourceDescriptor.Builder.newBuilder(OllamaChatModelConnection.class.getName())
+                        .addInitialArgument("endpoint", "http://localhost:11434")
+                        .build();
+        agentsEnv
+                .addChatModelConnection("ollama", ollamaConnection)
+                .addTool(ReActAgentExample.class.getMethod("add", Double.class, Double.class))
+                .addTool(ReActAgentExample.class.getMethod("multiply", Double.class, Double.class));
+
+        // The ReAct agent under demonstration.
+        Agent agent = getAgent();
+
+        // Input: a single row whose three columns are declared as DOUBLE.
+        Table numbers =
+                tables.fromValues(
+                        DataTypes.ROW(
+                                DataTypes.FIELD("a", DataTypes.DOUBLE()),
+                                DataTypes.FIELD("b", DataTypes.DOUBLE()),
+                                DataTypes.FIELD("c", DataTypes.DOUBLE())),
+                        Row.of(1, 2, 3));
+
+        // Output: one column "f0" holding a row with the DOUBLE field "result".
+        Schema resultSchema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.ROW(DataTypes.FIELD("result", DataTypes.DOUBLE())))
+                        .build();
+
+        // Key the rows by column "a", apply the agent and convert its output to a table.
+        KeySelector<Object, Double> keyByA = value -> (Double) ((Row) value).getField("a");
+        Table answers =
+                agentsEnv
+                        .fromTable(numbers, tables, keyByA)
+                        .apply(agent)
+                        .toTable(resultSchema);
+
+        // Unwrap column "f0" and print the inner result row.
+        MapFunction<Row, Row> unwrap = wrapped -> (Row) wrapped.getField("f0");
+        tables.toDataStream(answers).map(unwrap).print();
+
+        flinkEnv.execute();
+    }
+
+    /**
+     * Creates the ReAct agent: an ollama chat model setup with both tools, a prompt asking for
+     * {@code (a + b) * c}, and a row output schema with the single DOUBLE field "result".
+     */
+    private static @NotNull Agent getAgent() {
+        ResourceDescriptor chatModel =
+                ResourceDescriptor.Builder.newBuilder(OllamaChatModelSetup.class.getName())
+                        .addInitialArgument("connection", "ollama")
+                        .addInitialArgument("model", "qwen3:8b")
+                        .addInitialArgument("tools", List.of("add", "multiply"))
+                        .build();
+
+        // System message shows the expected JSON shape; user message carries the question.
+        ChatMessage formatHint =
+                new ChatMessage(MessageRole.SYSTEM, "An example of output is {\"result\": 30.32}");
+        ChatMessage question = new ChatMessage(MessageRole.USER, "What is ({a} + {b}) * {c}");
+        Prompt prompt = new Prompt(List.of(formatHint, question));
+
+        // Final answer schema: a row with one DOUBLE field named "result".
+        TypeInformation<?>[] fieldTypes = {BasicTypeInfo.DOUBLE_TYPE_INFO};
+        String[] fieldNames = {"result"};
+        RowTypeInfo resultType = new RowTypeInfo(fieldTypes, fieldNames);
+
+        return new ReActAgent(chatModel, prompt, resultType);
+    }
+}
diff --git a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
index 1452e41..0b8d21b 100644
--- 
a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
+++ b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
@@ -39,6 +39,8 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.function.BiFunction;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 
 /**
@@ -65,6 +67,10 @@ import java.util.stream.Collectors;
  */
 public class OllamaChatModelConnection extends BaseChatModelConnection {
+    /** Matches {@code <think>...</think>} sections, including ones spanning several lines. */
+    private static final Pattern THINK_PATTERN =
+            Pattern.compile("<think>(.*?)</think>", Pattern.DOTALL);
+
     private final OllamaAPI client;
     /**
      * Creates a new ollama chat model connection.
      *
@@ -80,6 +86,10 @@ public class OllamaChatModelConnection extends BaseChatModelConnection {
             throw new IllegalArgumentException("endpoint should not be null or empty.");
         }
         this.client = new OllamaAPI(endpoint);
+        // Fall back to 10 retries when the descriptor does not configure a value.
+        Integer maxChatToolCallRetries = descriptor.getArgument("maxChatToolCallRetries");
+        this.client.setMaxChatToolCallRetries(
+                maxChatToolCallRetries != null ? maxChatToolCallRetries : 10);
     }
 
     /**
@@ -200,9 +210,35 @@ public class OllamaChatModelConnection extends BaseChatModelConnection {
             final OllamaChatResult ollamaChatResult =
                     this.client.chat((String) arguments.get("model"), ollamaChatMessages);
 
-            return ChatMessage.assistant(ollamaChatResult.getResponse());
+            return extractReasoning(ollamaChatResult.getResponse());
         } catch (Exception e) {
             throw new RuntimeException(e);
         }
     }
+
+    /**
+     * Splits a raw model response into the visible answer and the model's reasoning.
+     *
+     * <p>Each {@code <think>...</think>} section is cut out of the content; the sections are
+     * kept, in order and newline-separated, under the {@code "reasoning"} extra argument.
+     *
+     * @param response raw response text; {@code null} is treated as empty text
+     * @return assistant message with the stripped content and the extracted reasoning
+     */
+    private ChatMessage extractReasoning(String response) {
+        Matcher matcher = THINK_PATTERN.matcher(response == null ? "" : response);
+        StringBuilder reasoning = new StringBuilder();
+        while (matcher.find()) {
+            if (reasoning.length() > 0) {
+                reasoning.append('\n');
+            }
+            reasoning.append(matcher.group(1).strip());
+        }
+        // replaceAll() resets the matcher first, so calling it after the find() loop is safe.
+        ChatMessage responseMessage = ChatMessage.assistant(matcher.replaceAll("").strip());
+        Map<String, Object> extraArgs = new HashMap<>();
+        extraArgs.put("reasoning", reasoning.toString().strip());
+        responseMessage.setExtraArgs(extraArgs);
+        return responseMessage;
+    }
 }
