This is an automated email from the ASF dual-hosted git repository.
xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
The following commit(s) were added to refs/heads/main by this push:
new 00c7c3d [Feature][Java] Add OpenAI chat model integration (#320)
00c7c3d is described below
commit 00c7c3d88d9bea2c9c3dde278d8705c58021a9b4
Author: Xiang Li <[email protected]>
AuthorDate: Mon Dec 1 22:38:16 2025 -0800
[Feature][Java] Add OpenAI chat model integration (#320)
---
.../pom.xml | 7 +-
.../test/ChatModelIntegrationAgent.java | 15 +
.../integration/test/ChatModelIntegrationTest.java | 2 +-
integrations/{ => chat-models/openai}/pom.xml | 36 +-
.../openai/OpenAIChatModelConnection.java | 438 +++++++++++++++++++++
.../chatmodels/openai/OpenAIChatModelSetup.java | 219 +++++++++++
integrations/chat-models/pom.xml | 3 +-
integrations/pom.xml | 3 +-
.../flink/agents/plan/tools/SchemaUtils.java | 4 +-
9 files changed, 710 insertions(+), 17 deletions(-)
diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml
b/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml
index b3edb3e..65c7d12 100644
--- a/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml
+++ b/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml
@@ -64,6 +64,11 @@ under the License.
<artifactId>flink-agents-integrations-chat-models-azureai</artifactId>
<version>${project.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+
<artifactId>flink-agents-integrations-chat-models-openai</artifactId>
+ <version>${project.version}</version>
+ </dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-agents-integrations-chat-models-ollama</artifactId>
@@ -76,4 +81,4 @@ under the License.
</dependency>
</dependencies>
-</project>
\ No newline at end of file
+</project>
diff --git
a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationAgent.java
b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationAgent.java
index 254041c..9056f99 100644
---
a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationAgent.java
+++
b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationAgent.java
@@ -37,6 +37,8 @@ import
org.apache.flink.agents.integrations.chatmodels.azureai.AzureAIChatModelC
import
org.apache.flink.agents.integrations.chatmodels.azureai.AzureAIChatModelSetup;
import
org.apache.flink.agents.integrations.chatmodels.ollama.OllamaChatModelConnection;
import
org.apache.flink.agents.integrations.chatmodels.ollama.OllamaChatModelSetup;
+import
org.apache.flink.agents.integrations.chatmodels.openai.OpenAIChatModelConnection;
+import
org.apache.flink.agents.integrations.chatmodels.openai.OpenAIChatModelSetup;
import java.util.Collections;
import java.util.List;
@@ -80,6 +82,11 @@ public class ChatModelIntegrationAgent extends Agent {
.addInitialArgument("endpoint", endpoint)
.addInitialArgument("apiKey", apiKey)
.build();
+ } else if (provider.equals("OPENAI")) {
+ String apiKey = System.getenv().get("OPENAI_API_KEY");
+ return
ResourceDescriptor.Builder.newBuilder(OpenAIChatModelConnection.class.getName())
+ .addInitialArgument("api_key", apiKey)
+ .build();
} else {
throw new RuntimeException(String.format("Unknown model provider
%s", provider));
}
@@ -105,6 +112,14 @@ public class ChatModelIntegrationAgent extends Agent {
"tools",
List.of("calculateBMI", "convertTemperature",
"createRandomNumber"))
.build();
+ } else if (provider.equals("OPENAI")) {
+ return
ResourceDescriptor.Builder.newBuilder(OpenAIChatModelSetup.class.getName())
+ .addInitialArgument("connection", "chatModelConnection")
+ .addInitialArgument("model", "gpt-4o-mini")
+ .addInitialArgument(
+ "tools",
+ List.of("calculateBMI", "convertTemperature",
"createRandomNumber"))
+ .build();
} else {
throw new RuntimeException(String.format("Unknown model provider
%s", provider));
}
diff --git
a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java
b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java
index 261e646..85035a5 100644
---
a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java
+++
b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java
@@ -52,7 +52,7 @@ public class ChatModelIntegrationTest extends
OllamaPreparationUtils {
}
@ParameterizedTest()
- @ValueSource(strings = {"OLLAMA", "AZURE"})
+ @ValueSource(strings = {"OLLAMA", "AZURE", "OPENAI"})
public void testChatModeIntegration(String provider) throws Exception {
Assumptions.assumeTrue(
(OLLAMA.equals(provider) && ollamaReady)
diff --git a/integrations/pom.xml b/integrations/chat-models/openai/pom.xml
similarity index 56%
copy from integrations/pom.xml
copy to integrations/chat-models/openai/pom.xml
index a145fb7..df35254 100644
--- a/integrations/pom.xml
+++ b/integrations/chat-models/openai/pom.xml
@@ -22,21 +22,33 @@ under the License.
<parent>
<groupId>org.apache.flink</groupId>
- <artifactId>flink-agents</artifactId>
+ <artifactId>flink-agents-integrations-chat-models</artifactId>
<version>0.2-SNAPSHOT</version>
+ <relativePath>../pom.xml</relativePath>
</parent>
- <artifactId>flink-agents-integrations</artifactId>
- <name>Flink Agents : Integrations:</name>
- <packaging>pom</packaging>
+ <artifactId>flink-agents-integrations-chat-models-openai</artifactId>
+ <name>Flink Agents : Integrations: Chat Models: OpenAI</name>
+ <packaging>jar</packaging>
- <properties>
- <ollama4j.version>1.1.5</ollama4j.version>
- </properties>
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-agents-api</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-agents-plan</artifactId>
+ <version>${project.version}</version>
+ </dependency>
- <modules>
- <module>chat-models</module>
- <module>embedding-models</module>
- </modules>
+ <dependency>
+ <groupId>com.openai</groupId>
+ <artifactId>openai-java</artifactId>
+ <version>${openai.version}</version>
+ </dependency>
+ </dependencies>
+
+</project>
-</project>
\ No newline at end of file
diff --git
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java
new file mode 100644
index 0000000..ff15cf6
--- /dev/null
+++
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java
@@ -0,0 +1,438 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.integrations.chatmodels.openai;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.openai.client.OpenAIClient;
+import com.openai.client.okhttp.OpenAIOkHttpClient;
+import com.openai.core.JsonValue;
+import com.openai.models.ChatModel;
+import com.openai.models.FunctionDefinition;
+import com.openai.models.FunctionParameters;
+import com.openai.models.ReasoningEffort;
+import com.openai.models.chat.completions.ChatCompletion;
+import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam;
+import com.openai.models.chat.completions.ChatCompletionCreateParams;
+import com.openai.models.chat.completions.ChatCompletionFunctionTool;
+import com.openai.models.chat.completions.ChatCompletionMessage;
+import
com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall;
+import com.openai.models.chat.completions.ChatCompletionMessageParam;
+import com.openai.models.chat.completions.ChatCompletionMessageToolCall;
+import com.openai.models.chat.completions.ChatCompletionSystemMessageParam;
+import com.openai.models.chat.completions.ChatCompletionTool;
+import com.openai.models.chat.completions.ChatCompletionToolMessageParam;
+import com.openai.models.chat.completions.ChatCompletionUserMessageParam;
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.chat.model.BaseChatModelConnection;
+import org.apache.flink.agents.api.resource.Resource;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import org.apache.flink.agents.api.resource.ResourceType;
+import org.apache.flink.agents.api.tools.Tool;
+import org.apache.flink.agents.api.tools.ToolMetadata;
+
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.BiFunction;
+import java.util.stream.Collectors;
+
+/**
+ * A chat model integration for the OpenAI Chat Completions service using the official Java SDK.
+ *
+ * <p>Supported connection parameters:
+ *
+ * <ul>
+ * <li><b>api_key</b> (required): OpenAI API key
+ * <li><b>api_base_url</b> (optional): Base URL for OpenAI API (defaults to
+ * https://api.openai.com/v1)
+ * <li><b>timeout</b> (optional): Timeout in seconds for API requests
+ * <li><b>max_retries</b> (optional): Maximum number of retry attempts (default: 2)
+ * <li><b>default_headers</b> (optional): Map of default headers to include in all requests
+ * <li><b>model</b> (optional): Default model to use if not specified in setup
+ * </ul>
+ *
+ * <p>Example usage:
+ *
+ * <pre>{@code
+ * public class MyAgent extends Agent {
+ * @ChatModelConnection
+ * public static ResourceDescriptor openAI() {
+ * return ResourceDescriptor.Builder.newBuilder(OpenAIChatModelConnection.class.getName())
+ * .addInitialArgument("api_key", System.getenv("OPENAI_API_KEY"))
+ * .addInitialArgument("api_base_url", "https://api.openai.com/v1")
+ * .addInitialArgument("timeout", 120)
+ * .addInitialArgument("max_retries", 3)
+ * .addInitialArgument("default_headers", Map.of("X-Custom-Header", "value"))
+ * .build();
+ * }
+ * }
+ * }</pre>
+ */
+public class OpenAIChatModelConnection extends BaseChatModelConnection {
+
+ private static final TypeReference<Map<String, Object>> MAP_TYPE = new
TypeReference<>() {};
+
+ private final ObjectMapper mapper = new ObjectMapper();
+ private final OpenAIClient client;
+ private final String defaultModel;
+
+ public OpenAIChatModelConnection(
+ ResourceDescriptor descriptor, BiFunction<String, ResourceType,
Resource> getResource) {
+ super(descriptor, getResource);
+
+ String apiKey = descriptor.getArgument("api_key");
+ if (apiKey == null || apiKey.isBlank()) {
+ throw new IllegalArgumentException("api_key should not be null or
empty.");
+ }
+
+ OpenAIOkHttpClient.Builder builder = new
OpenAIOkHttpClient.Builder().apiKey(apiKey);
+
+ String apiBaseUrl = descriptor.getArgument("api_base_url");
+ if (apiBaseUrl != null && !apiBaseUrl.isBlank()) {
+ builder.baseUrl(apiBaseUrl);
+ }
+
+ Integer timeoutSeconds = descriptor.getArgument("timeout");
+ if (timeoutSeconds != null && timeoutSeconds > 0) {
+ builder.timeout(Duration.ofSeconds(timeoutSeconds));
+ }
+
+ Integer maxRetries = descriptor.getArgument("max_retries");
+ if (maxRetries != null && maxRetries >= 0) {
+ builder.maxRetries(maxRetries);
+ }
+
+ Map<String, String> defaultHeaders =
descriptor.getArgument("default_headers");
+ if (defaultHeaders != null && !defaultHeaders.isEmpty()) {
+ for (Map.Entry<String, String> header : defaultHeaders.entrySet())
{
+ builder.putHeader(header.getKey(), header.getValue());
+ }
+ }
+
+ this.defaultModel = descriptor.getArgument("model");
+ this.client = builder.build();
+ }
+
+ @Override
+ public ChatMessage chat(
+ List<ChatMessage> messages, List<Tool> tools, Map<String, Object>
arguments) {
+ try {
+ ChatCompletionCreateParams params = buildRequest(messages, tools,
arguments);
+ ChatCompletion completion =
client.chat().completions().create(params);
+ return convertResponse(completion);
+ } catch (Exception e) {
+ throw new RuntimeException("Failed to call OpenAI chat completions
API.", e);
+ }
+ }
+
+ private ChatCompletionCreateParams buildRequest(
+ List<ChatMessage> messages, List<Tool> tools, Map<String, Object>
rawArguments) {
+ Map<String, Object> arguments =
+ rawArguments != null ? new HashMap<>(rawArguments) : new
HashMap<>();
+
+ boolean strictMode = Boolean.TRUE.equals(arguments.remove("strict"));
+ String modelName = (String) arguments.remove("model");
+ if (modelName == null || modelName.isBlank()) {
+ modelName = this.defaultModel;
+ }
+
+ ChatCompletionCreateParams.Builder builder =
+ ChatCompletionCreateParams.builder()
+ .model(ChatModel.of(modelName))
+ .messages(
+ messages.stream()
+ .map(this::convertToOpenAIMessage)
+ .collect(Collectors.toList()));
+
+ if (tools != null && !tools.isEmpty()) {
+ builder.tools(convertTools(tools, strictMode));
+ }
+
+ Object temperature = arguments.remove("temperature");
+ if (temperature instanceof Number) {
+ builder.temperature(((Number) temperature).doubleValue());
+ }
+
+ Object maxTokens = arguments.remove("max_tokens");
+ if (maxTokens instanceof Number) {
+ builder.maxCompletionTokens(((Number) maxTokens).longValue());
+ }
+
+ Object logprobs = arguments.remove("logprobs");
+ boolean logprobsEnabled = Boolean.TRUE.equals(logprobs);
+ if (logprobsEnabled) {
+ builder.logprobs(true);
+ Object topLogprobs = arguments.remove("top_logprobs");
+ if (topLogprobs instanceof Number) {
+ builder.topLogprobs(((Number) topLogprobs).longValue());
+ }
+ } else {
+ arguments.remove("top_logprobs");
+ }
+
+ Object reasoningEffort = arguments.remove("reasoning_effort");
+ if (reasoningEffort instanceof String) {
+ builder.reasoningEffort(ReasoningEffort.of((String)
reasoningEffort));
+ }
+
+ @SuppressWarnings("unchecked")
+ Map<String, Object> additionalKwargs =
+ (Map<String, Object>) arguments.remove("additional_kwargs");
+ if (additionalKwargs != null) {
+ additionalKwargs.forEach(
+ (key, value) -> builder.putAdditionalBodyProperty(key,
toJsonValue(value)));
+ }
+
+ return builder.build();
+ }
+
+ private List<ChatCompletionTool> convertTools(List<Tool> tools, boolean
strictMode) {
+ List<ChatCompletionTool> openaiTools = new ArrayList<>(tools.size());
+ for (Tool tool : tools) {
+ ToolMetadata metadata = tool.getMetadata();
+ FunctionDefinition.Builder functionBuilder =
+ FunctionDefinition.builder()
+ .name(metadata.getName())
+ .description(metadata.getDescription());
+
+ String schema = metadata.getInputSchema();
+ if (schema != null && !schema.isBlank()) {
+ functionBuilder.parameters(parseFunctionParameters(schema));
+ }
+
+ if (strictMode) {
+ functionBuilder.strict(true);
+ }
+
+ ChatCompletionFunctionTool functionTool =
+ ChatCompletionFunctionTool.builder()
+ .function(functionBuilder.build())
+ .type(JsonValue.from("function"))
+ .build();
+
+ openaiTools.add(ChatCompletionTool.ofFunction(functionTool));
+ }
+ return openaiTools;
+ }
+
+ private FunctionParameters parseFunctionParameters(String schemaJson) {
+ try {
+ JsonNode root = mapper.readTree(schemaJson);
+ if (root == null || !root.isObject()) {
+ return FunctionParameters.builder().build();
+ }
+
+ FunctionParameters.Builder builder = FunctionParameters.builder();
+ root.fields()
+ .forEachRemaining(
+ entry ->
+ builder.putAdditionalProperty(
+ entry.getKey(),
+
JsonValue.fromJsonNode(entry.getValue())));
+ return builder.build();
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException("Failed to parse tool schema JSON.", e);
+ }
+ }
+
+ private ChatCompletionMessageParam convertToOpenAIMessage(ChatMessage
message) {
+ MessageRole role = message.getRole();
+ String content = Optional.ofNullable(message.getContent()).orElse("");
+
+ switch (role) {
+ case SYSTEM:
+ return ChatCompletionMessageParam.ofSystem(
+
ChatCompletionSystemMessageParam.builder().content(content).build());
+ case USER:
+ return ChatCompletionMessageParam.ofUser(
+
ChatCompletionUserMessageParam.builder().content(content).build());
+ case ASSISTANT:
+ ChatCompletionAssistantMessageParam.Builder assistantBuilder =
+ ChatCompletionAssistantMessageParam.builder();
+ if (!content.isEmpty()) {
+ assistantBuilder.content(content);
+ }
+ List<Map<String, Object>> toolCalls = message.getToolCalls();
+ if (toolCalls != null && !toolCalls.isEmpty()) {
+
assistantBuilder.toolCalls(convertAssistantToolCalls(toolCalls));
+ }
+ Object refusal = message.getExtraArgs().get("refusal");
+ if (refusal instanceof String) {
+ assistantBuilder.refusal((String) refusal);
+ }
+ return
ChatCompletionMessageParam.ofAssistant(assistantBuilder.build());
+ case TOOL:
+ ChatCompletionToolMessageParam.Builder toolBuilder =
+
ChatCompletionToolMessageParam.builder().content(content);
+ Object toolCallId = message.getExtraArgs().get("externalId");
+ if (toolCallId == null) {
+ throw new IllegalArgumentException(
+ "Tool message must have an externalId in
extraArgs.");
+ }
+ toolBuilder.toolCallId(toolCallId.toString());
+ return ChatCompletionMessageParam.ofTool(toolBuilder.build());
+ default:
+ throw new IllegalArgumentException("Unsupported role: " +
role);
+ }
+ }
+
+ private List<ChatCompletionMessageToolCall> convertAssistantToolCalls(
+ List<Map<String, Object>> toolCalls) {
+ List<ChatCompletionMessageToolCall> result = new
ArrayList<>(toolCalls.size());
+ for (Map<String, Object> call : toolCalls) {
+ Object type = call.getOrDefault("type", "function");
+ if (!"function".equals(String.valueOf(type))) {
+ continue;
+ }
+
+ Map<String, Object> functionPayload = toMap(call.get("function"));
+ ChatCompletionMessageFunctionToolCall.Function.Builder
functionBuilder =
+ ChatCompletionMessageFunctionToolCall.Function.builder();
+
+ Object functionName = functionPayload.get("name");
+ if (functionName != null) {
+ functionBuilder.name(functionName.toString());
+ }
+
+ Object arguments = functionPayload.get("arguments");
+ functionBuilder.arguments(serializeArguments(arguments));
+
+ Object idObj = call.get("id");
+ if (idObj == null) {
+ throw new IllegalArgumentException("Tool call must have an
id.");
+ }
+ String toolCallId = idObj.toString();
+
+ ChatCompletionMessageFunctionToolCall.Builder toolCallBuilder =
+ ChatCompletionMessageFunctionToolCall.builder()
+ .id(toolCallId)
+ .function(functionBuilder.build())
+ .type(JsonValue.from(String.valueOf(type)));
+
+
result.add(ChatCompletionMessageToolCall.ofFunction(toolCallBuilder.build()));
+ }
+ return result;
+ }
+
+ private ChatMessage convertResponse(ChatCompletion completion) {
+ List<ChatCompletion.Choice> choices = completion.choices();
+ if (choices.isEmpty()) {
+ throw new IllegalStateException("OpenAI response did not contain
any choices.");
+ }
+
+ ChatCompletionMessage message = choices.get(0).message();
+ String content = message.content().orElse("");
+ ChatMessage response = ChatMessage.assistant(content);
+
+ message.refusal().ifPresent(refusal ->
response.getExtraArgs().put("refusal", refusal));
+
+ List<ChatCompletionMessageToolCall> toolCalls =
message.toolCalls().orElse(List.of());
+ if (!toolCalls.isEmpty()) {
+ response.setToolCalls(convertResponseToolCalls(toolCalls));
+ }
+
+ return response;
+ }
+
+ private List<Map<String, Object>> convertResponseToolCalls(
+ List<ChatCompletionMessageToolCall> toolCalls) {
+ List<Map<String, Object>> result = new ArrayList<>(toolCalls.size());
+ for (ChatCompletionMessageToolCall toolCall : toolCalls) {
+ if (!toolCall.isFunction()) {
+ continue;
+ }
+
+ ChatCompletionMessageFunctionToolCall functionToolCall =
toolCall.asFunction();
+ Map<String, Object> callMap = new LinkedHashMap<>();
+ String toolCallId = functionToolCall.id();
+ if (toolCallId == null || toolCallId.isBlank()) {
+ throw new IllegalStateException("OpenAI tool call ID is null
or empty.");
+ }
+
+ callMap.put("id", toolCallId);
+ callMap.put("type", "function");
+
+ ChatCompletionMessageFunctionToolCall.Function function =
functionToolCall.function();
+ Map<String, Object> functionMap = new LinkedHashMap<>();
+ functionMap.put("name", function.name());
+ functionMap.put("arguments", parseArguments(function.arguments()));
+ callMap.put("function", functionMap);
+ callMap.put("original_id", toolCallId);
+ result.add(callMap);
+ }
+ return result;
+ }
+
+ private Map<String, Object> parseArguments(String arguments) {
+ if (arguments == null || arguments.isBlank()) {
+ return Map.of();
+ }
+ try {
+ return mapper.readValue(arguments, MAP_TYPE);
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException("Failed to parse tool arguments: " +
arguments, e);
+ }
+ }
+
+ private JsonValue toJsonValue(Object value) {
+ if (value instanceof JsonValue) {
+ return (JsonValue) value;
+ }
+ if (value instanceof String
+ || value instanceof Number
+ || value instanceof Boolean
+ || value == null) {
+ return JsonValue.from(value);
+ }
+ return JsonValue.fromJsonNode(mapper.valueToTree(value));
+ }
+
+ private String serializeArguments(Object arguments) {
+ if (arguments == null) {
+ return "{}";
+ }
+ if (arguments instanceof String) {
+ return (String) arguments;
+ }
+ try {
+ return mapper.writeValueAsString(arguments);
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException("Failed to serialize tool call
arguments.", e);
+ }
+ }
+
+ private Map<String, Object> toMap(Object value) {
+ if (value instanceof Map) {
+ @SuppressWarnings("unchecked")
+ Map<String, Object> casted = (Map<String, Object>) value;
+ return new LinkedHashMap<>(casted);
+ }
+ if (value == null) {
+ return new LinkedHashMap<>();
+ }
+ return mapper.convertValue(value, MAP_TYPE);
+ }
+}
diff --git
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelSetup.java
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelSetup.java
new file mode 100644
index 0000000..270dcb9
--- /dev/null
+++
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelSetup.java
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.integrations.chatmodels.openai;
+
+import org.apache.flink.agents.api.chat.model.BaseChatModelSetup;
+import org.apache.flink.agents.api.resource.Resource;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import org.apache.flink.agents.api.resource.ResourceType;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.function.BiFunction;
+
+/**
+ * Chat model setup for the OpenAI Chat Completions API.
+ *
+ * <p>Responsible for providing per-chat configuration such as model, temperature, tool bindings,
+ * and additional OpenAI parameters. The setup delegates execution to {@link
+ * OpenAIChatModelConnection}.
+ *
+ * <p>Example usage:
+ *
+ * <pre>{@code
+ * public class MyAgent extends Agent {
+ * @ChatModelSetup
+ * public static ResourceDescriptor openAI() {
+ * return ResourceDescriptor.Builder.newBuilder(OpenAIChatModelSetup.class.getName())
+ * .addInitialArgument("connection", "myOpenAIConnection")
+ * .addInitialArgument("model", "gpt-4o-mini")
+ * .addInitialArgument("temperature", 0.3d)
+ * .addInitialArgument("max_tokens", 500)
+ * .addInitialArgument("strict", true)
+ * .addInitialArgument("reasoning_effort", "medium")
+ * .addInitialArgument("tools", List.of("convertTemperature", "calculateBMI"))
+ * .addInitialArgument(
+ * "additional_kwargs",
+ * Map.of("seed", 42, "user", "user-123"))
+ * .build();
+ * }
+ * }
+ * }</pre>
+ */
+public class OpenAIChatModelSetup extends BaseChatModelSetup {
+
+ private static final String DEFAULT_MODEL = "gpt-3.5-turbo";
+ private static final double DEFAULT_TEMPERATURE = 0.1d;
+ private static final int DEFAULT_TOP_LOGPROBS = 0;
+ private static final boolean DEFAULT_STRICT = false;
+ private static final Set<String> VALID_REASONING_EFFORTS = Set.of("low",
"medium", "high");
+
+ private final Double temperature;
+ private final Integer maxTokens;
+ private final Boolean logprobs;
+ private final Integer topLogprobs;
+ private final Boolean strict;
+ private final String reasoningEffort;
+ private final Map<String, Object> additionalArguments;
+
+ public OpenAIChatModelSetup(
+ ResourceDescriptor descriptor, BiFunction<String, ResourceType,
Resource> getResource) {
+ super(descriptor, getResource);
+ this.temperature =
+
Optional.ofNullable(descriptor.<Number>getArgument("temperature"))
+ .map(Number::doubleValue)
+ .orElse(DEFAULT_TEMPERATURE);
+ if (this.temperature < 0.0 || this.temperature > 2.0) {
+ throw new IllegalArgumentException("temperature must be between
0.0 and 2.0");
+ }
+
+ this.maxTokens =
+
Optional.ofNullable(descriptor.<Number>getArgument("max_tokens"))
+ .map(Number::intValue)
+ .orElse(null);
+ if (this.maxTokens != null && this.maxTokens <= 0) {
+ throw new IllegalArgumentException("max_tokens must be greater
than 0");
+ }
+
+ this.logprobs = descriptor.getArgument("logprobs");
+
+ this.topLogprobs =
+
Optional.ofNullable(descriptor.<Number>getArgument("top_logprobs"))
+ .map(Number::intValue)
+ .orElse(DEFAULT_TOP_LOGPROBS);
+ if (this.topLogprobs < 0 || this.topLogprobs > 20) {
+ throw new IllegalArgumentException("top_logprobs must be between 0
and 20");
+ }
+
+ this.strict =
+ Optional.ofNullable(descriptor.<Boolean>getArgument("strict"))
+ .orElse(DEFAULT_STRICT);
+
+ this.reasoningEffort = descriptor.getArgument("reasoning_effort");
+ if (this.reasoningEffort != null
+ && !VALID_REASONING_EFFORTS.contains(this.reasoningEffort)) {
+ throw new IllegalArgumentException(
+ "reasoning_effort must be one of: low, medium, high");
+ }
+
+ Map<String, Object> additional =
+ Optional.ofNullable(
+ descriptor.<Map<String,
Object>>getArgument("additional_kwargs"))
+ .map(HashMap::new)
+ .orElseGet(HashMap::new);
+ this.additionalArguments = additional;
+
+ if (this.model == null || this.model.isBlank()) {
+ this.model = DEFAULT_MODEL;
+ }
+ }
+
+ public OpenAIChatModelSetup(
+ String model,
+ double temperature,
+ Integer maxTokens,
+ Boolean logprobs,
+ Integer topLogprobs,
+ Boolean strict,
+ String reasoningEffort,
+ Map<String, Object> additionalArguments,
+ List<String> tools,
+ BiFunction<String, ResourceType, Resource> getResource) {
+ this(
+ createDescriptor(
+ model,
+ temperature,
+ maxTokens,
+ logprobs,
+ topLogprobs,
+ strict,
+ reasoningEffort,
+ additionalArguments,
+ tools),
+ getResource);
+ }
+
+ @Override
+ public Map<String, Object> getParameters() {
+ Map<String, Object> parameters = new HashMap<>();
+ if (model != null) {
+ parameters.put("model", model);
+ }
+ parameters.put("temperature", temperature);
+ if (maxTokens != null) {
+ parameters.put("max_tokens", maxTokens);
+ }
+ if (Boolean.TRUE.equals(logprobs)) {
+ parameters.put("logprobs", logprobs);
+ parameters.put("top_logprobs", topLogprobs);
+ }
+ if (strict) {
+ parameters.put("strict", strict);
+ }
+ if (reasoningEffort != null) {
+ parameters.put("reasoning_effort", reasoningEffort);
+ }
+ if (additionalArguments != null && !additionalArguments.isEmpty()) {
+ parameters.put("additional_kwargs", additionalArguments);
+ }
+ return parameters;
+ }
+
+ private static ResourceDescriptor createDescriptor(
+ String model,
+ double temperature,
+ Integer maxTokens,
+ Boolean logprobs,
+ Integer topLogprobs,
+ Boolean strict,
+ String reasoningEffort,
+ Map<String, Object> additionalArguments,
+ List<String> tools) {
+ ResourceDescriptor.Builder builder =
+
ResourceDescriptor.Builder.newBuilder(OpenAIChatModelSetup.class.getName())
+ .addInitialArgument("model", model)
+ .addInitialArgument("temperature", temperature);
+
+ if (maxTokens != null) {
+ builder.addInitialArgument("max_tokens", maxTokens);
+ }
+ if (logprobs != null) {
+ builder.addInitialArgument("logprobs", logprobs);
+ }
+ if (topLogprobs != null) {
+ builder.addInitialArgument("top_logprobs", topLogprobs);
+ }
+ if (strict != null) {
+ builder.addInitialArgument("strict", strict);
+ }
+ if (reasoningEffort != null) {
+ builder.addInitialArgument("reasoning_effort", reasoningEffort);
+ }
+ if (additionalArguments != null && !additionalArguments.isEmpty()) {
+ builder.addInitialArgument("additional_kwargs",
additionalArguments);
+ }
+ if (tools != null && !tools.isEmpty()) {
+ builder.addInitialArgument("tools", tools);
+ }
+
+ return builder.build();
+ }
+}
diff --git a/integrations/chat-models/pom.xml b/integrations/chat-models/pom.xml
index dafeae9..4a09cc6 100644
--- a/integrations/chat-models/pom.xml
+++ b/integrations/chat-models/pom.xml
@@ -33,6 +33,7 @@ under the License.
<modules>
<module>azureai</module>
<module>ollama</module>
+ <module>openai</module>
</modules>
-</project>
\ No newline at end of file
+</project>
diff --git a/integrations/pom.xml b/integrations/pom.xml
index a145fb7..dd21927 100644
--- a/integrations/pom.xml
+++ b/integrations/pom.xml
@@ -32,6 +32,7 @@ under the License.
<properties>
<ollama4j.version>1.1.5</ollama4j.version>
+ <openai.version>4.8.0</openai.version>
</properties>
<modules>
@@ -39,4 +40,4 @@ under the License.
<module>embedding-models</module>
</modules>
-</project>
\ No newline at end of file
+</project>
diff --git
a/plan/src/main/java/org/apache/flink/agents/plan/tools/SchemaUtils.java
b/plan/src/main/java/org/apache/flink/agents/plan/tools/SchemaUtils.java
index 44e0882..6829dfa 100644
--- a/plan/src/main/java/org/apache/flink/agents/plan/tools/SchemaUtils.java
+++ b/plan/src/main/java/org/apache/flink/agents/plan/tools/SchemaUtils.java
@@ -61,7 +61,9 @@ public class SchemaUtils {
}
Map<String, Object> paramSchema = getParamSchema(param);
- paramSchema.put("description", paramDescription);
+ if (paramDescription != null) {
+ paramSchema.put("description", paramDescription);
+ }
properties.put(paramName, paramSchema);
}