This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 770315f82b328f644666e58bbf6bef84fb261258 Author: WenjinXie <[email protected]> AuthorDate: Mon Nov 3 12:02:43 2025 +0800 [api] ReAct agent supports error handling strategy. --- .../apache/flink/agents/api/agents/ReActAgent.java | 50 +++++++++++++--- .../agents/api/agents/ReActAgentConfigOptions.java | 31 ++++++++++ .../agents/integration/test/ReActAgentExample.java | 10 +++- .../src/main/resources/log4j2.properties | 25 ++++++++ .../ollama/OllamaChatModelConnection.java | 2 +- .../flink/agents/plan/AgentConfiguration.java | 3 + python/flink_agents/api/agents/react_agent.py | 69 +++++++++++++++++----- 7 files changed, 164 insertions(+), 26 deletions(-) diff --git a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java index 89a0847..37a5d53 100644 --- a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java +++ b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java @@ -49,6 +49,8 @@ import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.annotatio import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.deser.std.StdDeserializer; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ser.std.StdSerializer; import org.apache.flink.types.Row; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import javax.annotation.Nullable; @@ -62,6 +64,8 @@ import java.util.Objects; /** Built-in ReAct Agent implementation based on the function call ability of llm. . */ public class ReActAgent extends Agent { + private static final Logger LOG = LoggerFactory.getLogger(ReActAgent.class); + private static final String DEFAULT_CHAT_MODEL = "_default_chat_model"; private static final String DEFAULT_SCHEMA_PROMPT = "_default_schema_prompt"; private static final String DEFAULT_USER_PROMPT = "_default_user_prompt"; @@ -183,16 +187,29 @@ public class ReActAgent extends Agent { Object outputSchema = ctx.getActionConfigValue("output_schema"); - // TODO: handle parse error according to configured strategy. if (outputSchema != null) { - if (outputSchema instanceof Class) { - output = mapper.readValue(String.valueOf(output), (Class<?>) outputSchema); - } else if (outputSchema instanceof OutputSchema) { - RowTypeInfo info = ((OutputSchema) outputSchema).getSchema(); - Map<String, Object> fields = mapper.readValue(String.valueOf(output), Map.class); - output = Row.withNames(); - for (String name : info.getFieldNames()) { - ((Row) output).setField(name, fields.get(name)); + ErrorHandlingStrategy strategy = + ctx.getConfig().get(ReActAgentConfigOptions.ERROR_HANDLING_STRATEGY); + try { + if (outputSchema instanceof Class) { + output = mapper.readValue(String.valueOf(output), (Class<?>) outputSchema); + } else if (outputSchema instanceof OutputSchema) { + RowTypeInfo info = ((OutputSchema) outputSchema).getSchema(); + Map<String, Object> fields = + mapper.readValue(String.valueOf(output), Map.class); + output = Row.withNames(); + for (String name : info.getFieldNames()) { + ((Row) output).setField(name, fields.get(name)); + } + } + } catch (Exception e) { + if (strategy == ErrorHandlingStrategy.FAIL) { + throw e; + } else if (strategy == ErrorHandlingStrategy.IGNORE) { + LOG.warn( + "The response of llm {} doesn't match schema constraint, ignoring.", + output); + return; } } } @@ -293,4 +310,19 @@ public class ReActAgent extends Agent { fieldNames.toArray(new String[0]))); } } + + public enum ErrorHandlingStrategy { + FAIL("fail"), + IGNORE("ignore"); + + private final String value; + + ErrorHandlingStrategy(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + } } diff --git a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgentConfigOptions.java b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgentConfigOptions.java new file mode 100644 index 0000000..d3edee7 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgentConfigOptions.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.agents; + +import org.apache.flink.agents.api.configuration.ConfigOption; + +/** Config Options for {@link ReActAgent}. */ +public class ReActAgentConfigOptions { + /** The option specifies the error handling strategy for react agent. */ + public static final ConfigOption<ReActAgent.ErrorHandlingStrategy> ERROR_HANDLING_STRATEGY = + new ConfigOption<>( + "error-handling-strategy", + ReActAgent.ErrorHandlingStrategy.class, + ReActAgent.ErrorHandlingStrategy.FAIL); +} diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/ReActAgentExample.java b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/ReActAgentExample.java index ad7f170..09bd4dd 100644 --- a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/ReActAgentExample.java +++ b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/ReActAgentExample.java @@ -21,6 +21,7 @@ package org.apache.flink.agents.integration.test; import org.apache.flink.agents.api.Agent; import org.apache.flink.agents.api.AgentsExecutionEnvironment; import org.apache.flink.agents.api.agents.ReActAgent; +import org.apache.flink.agents.api.agents.ReActAgentConfigOptions; import org.apache.flink.agents.api.annotation.ToolParam; import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; @@ -75,7 +76,7 @@ public class ReActAgentExample { agentsEnv .addResource( "ollama", - ResourceType.CHAT_MODEL, + ResourceType.CHAT_MODEL_CONNECTION, ResourceDescriptor.Builder.newBuilder( OllamaChatModelConnection.class.getName()) .addInitialArgument("endpoint", "http://localhost:11434") @@ -93,6 +94,12 @@ public class ReActAgentExample { ReActAgentExample.class.getMethod( "multiply", Double.class, Double.class))); + agentsEnv + .getConfig() + .set( + ReActAgentConfigOptions.ERROR_HANDLING_STRATEGY, + ReActAgent.ErrorHandlingStrategy.IGNORE); + // Declare the ReAct agent. Agent agent = getAgent(); @@ -135,6 +142,7 @@ public class ReActAgentExample { .addInitialArgument("connection", "ollama") .addInitialArgument("model", "qwen3:8b") .addInitialArgument("tools", List.of("add", "multiply")) + .addInitialArgument("extract_reasoning", "true") .build(); Prompt prompt = diff --git a/e2e-test/integration-test/src/main/resources/log4j2.properties b/e2e-test/integration-test/src/main/resources/log4j2.properties new file mode 100644 index 0000000..9206863 --- /dev/null +++ b/e2e-test/integration-test/src/main/resources/log4j2.properties @@ -0,0 +1,25 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +rootLogger.level = INFO +rootLogger.appenderRef.console.ref = ConsoleAppender + +appender.console.name = ConsoleAppender +appender.console.type = CONSOLE +appender.console.layout.type = PatternLayout +appender.console.layout.pattern = %d{yyyy-MM-dd HH:mm:ss,SSS} %-5p %-60c %x - %m%n diff --git a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java index 633514c..67b0a9e 100644 --- a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java +++ b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java @@ -79,7 +79,7 @@ public class OllamaChatModelConnection extends BaseChatModelConnection { Integer requestTimeout = descriptor.getArgument("requestTimeout"); this.caller = new OllamaChatEndpointCaller( - endpoint, null, requestTimeout != null ? requestTimeout : 10); + endpoint, null, requestTimeout != null ? requestTimeout : 60); } /** diff --git a/plan/src/main/java/org/apache/flink/agents/plan/AgentConfiguration.java b/plan/src/main/java/org/apache/flink/agents/plan/AgentConfiguration.java index 34d6efa..6457efb 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/AgentConfiguration.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/AgentConfiguration.java @@ -19,6 +19,7 @@ package org.apache.flink.agents.plan; import org.apache.flink.agents.api.configuration.ConfigOption; import org.apache.flink.agents.api.configuration.Configuration; +import org.apache.flink.configuration.ConfigurationUtils; import java.util.HashMap; import java.util.Map; @@ -146,6 +147,8 @@ public class AgentConfiguration implements Configuration { return targetType.cast(Double.parseDouble(rawValue.toString())); } else if (Boolean.class.equals(targetType)) { return targetType.cast(Boolean.parseBoolean(rawValue.toString())); + } else if (targetType.isEnum()) { + return ConfigurationUtils.convertValue(rawValue, targetType); } else { throw new ClassCastException( "Unsupported type conversion from " diff --git a/python/flink_agents/api/agents/react_agent.py b/python/flink_agents/api/agents/react_agent.py index 67cb0df..e629556 100644 --- a/python/flink_agents/api/agents/react_agent.py +++ b/python/flink_agents/api/agents/react_agent.py @@ -17,14 +17,22 @@ ################################################################################# import importlib import json +import logging +from enum import Enum from typing import Any, cast -from pydantic import BaseModel, ConfigDict, model_serializer, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + model_serializer, + model_validator, +) from pyflink.common import Row from pyflink.common.typeinfo import BasicType, BasicTypeInfo, RowTypeInfo from flink_agents.api.agent import Agent from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.configuration import ConfigOption from flink_agents.api.decorators import action from flink_agents.api.events.chat_event import ChatRequestEvent, ChatResponseEvent from flink_agents.api.events.event import InputEvent, OutputEvent @@ -38,6 +46,23 @@ _DEFAULT_USER_PROMPT = "_default_user_prompt" _OUTPUT_SCHEMA = "_output_schema" +class ErrorHandlingStrategy(Enum): + """Error handling strategy for ReActAgent.""" + + FAIL = "fail" + IGNORE = "ignore" + + +class ReActAgentOptions: + """Config options for ReActAgent.""" + + ERROR_HANDLING_STRATEGY = ConfigOption( + key="error-handling-strategy", + config_type=ErrorHandlingStrategy, + default=ErrorHandlingStrategy.FAIL, + ) + + class OutputSchema(BaseModel): """Util class to help serialize and deserialize output schema json.""" @@ -134,7 +159,7 @@ class ReActAgent(Agent): tools=["notify_shipping_manager"], ), prompt=prompt, - output_schema=OutputData + output_schema=OutputData, ) """ @@ -241,19 +266,33 @@ class ReActAgent(Agent): def stop_action(event: ChatResponseEvent, ctx: RunnerContext) -> None: """Stop action to output result.""" output = event.response.content - # parse llm response to target schema. - # TODO: config error handle strategy by configuration. output_schema = ctx.get_action_config_value(key="output_schema") - if output_schema: - output_schema = output_schema.output_schema - output = json.loads(output.strip()) - if isinstance(output_schema, type) and issubclass(output_schema, BaseModel): - output = output_schema.model_validate(output) - elif isinstance(output_schema, RowTypeInfo): - field_names = output_schema.get_field_names() - values = {} - for field_name in field_names: - values[field_name] = output[field_name] - output = Row(**values) + + error_handling_strategy = ctx.config.get( + ReActAgentOptions.ERROR_HANDLING_STRATEGY + ) + try: + if output_schema: + output_schema = output_schema.output_schema + output = json.loads(output.strip()) + if isinstance(output_schema, type) and issubclass( + output_schema, BaseModel + ): + output = output_schema.model_validate(output) + elif isinstance(output_schema, RowTypeInfo): + field_names = output_schema.get_field_names() + values = {} + for field_name in field_names: + values[field_name] = output[field_name] + output = Row(**values) + except Exception: + if error_handling_strategy == ErrorHandlingStrategy.IGNORE: + logging.warning( + f"The response of llm {output} doesn't match schema constraint, ignoring." + ) + return + elif error_handling_strategy == ErrorHandlingStrategy.FAIL: + raise + ctx.send_event(OutputEvent(output=output))
