This is an automated email from the ASF dual-hosted git repository.

sxnan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 7ab3e03ec9be7ec3363f7120821678cb8acae9cb
Author: Yunfeng Zhou <[email protected]>
AuthorDate: Tue Oct 28 18:49:51 2025 +0800

    [FLINK-38581][model] Model Function supports error handling strategy
---
 .../generated/model_openai_common_section.html     |  18 ++
 .../shortcodes/generated/openai_configuration.html |  18 ++
 flink-models/flink-model-openai/pom.xml            |   6 +
 .../model/openai/AbstractOpenAIModelFunction.java  |  76 +++++-
 .../model/openai/OpenAIChatModelFunction.java      |  13 +-
 .../model/openai/OpenAIEmbeddingModelFunction.java |  10 +-
 .../model/openai/OpenAIModelProviderFactory.java   |   3 +
 .../apache/flink/model/openai/OpenAIOptions.java   |  28 ++
 .../ModelFunctionErrorHandlingStrategyTest.java    | 294 +++++++++++++++++++++
 .../src/test/resources/log4j2-test.properties      |  28 ++
 10 files changed, 475 insertions(+), 19 deletions(-)

diff --git a/docs/layouts/shortcodes/generated/model_openai_common_section.html 
b/docs/layouts/shortcodes/generated/model_openai_common_section.html
index e003338b58e..435be03a136 100644
--- a/docs/layouts/shortcodes/generated/model_openai_common_section.html
+++ b/docs/layouts/shortcodes/generated/model_openai_common_section.html
@@ -26,6 +26,12 @@
             <td>String</td>
             <td>Full URL of the OpenAI API endpoint, e.g., <code 
class="highlighter-rouge">https://api.openai.com/v1/chat/completions</code> or 
<code class="highlighter-rouge">https://api.openai.com/v1/embeddings</code></td>
         </tr>
+        <tr>
+            <td><h5>error-handling-strategy</h5></td>
+            <td style="word-wrap: break-word;">RETRY</td>
+            <td><p>Enum</p></td>
+            <td>Strategy for handling errors during model requests.<br /><br 
/>Possible values:<ul><li>"RETRY": Retry sending the 
request.</li><li>"FAILOVER": Throw exceptions and fail the Flink 
job.</li><li>"IGNORE": Ignore the input that caused the error and continue. The 
error itself would be recorded in log.</li></ul></td>
+        </tr>
         <tr>
             <td><h5>max-context-size</h5></td>
             <td style="word-wrap: break-word;">(none)</td>
@@ -38,5 +44,17 @@
             <td>String</td>
             <td>Model name, e.g., <code 
class="highlighter-rouge">gpt-3.5-turbo</code>, <code 
class="highlighter-rouge">text-embedding-ada-002</code>.</td>
         </tr>
+        <tr>
+            <td><h5>retry-fallback-strategy</h5></td>
+            <td style="word-wrap: break-word;">FAILOVER</td>
+            <td><p>Enum</p></td>
+            <td>Fallback strategy to employ if the retry attempts are 
exhausted. This strategy is applied when error-handling-strategy is set to 
retry.<br /><br />Possible values:<ul><li>"FAILOVER": Throw exceptions and fail 
the Flink job.</li><li>"IGNORE": Ignore the input that caused the error and 
continue. The error itself would be recorded in log.</li></ul></td>
+        </tr>
+        <tr>
+            <td><h5>retry-num</h5></td>
+            <td style="word-wrap: break-word;">100</td>
+            <td>Integer</td>
+            <td>Number of retry for OpenAI client requests.</td>
+        </tr>
     </tbody>
 </table>
diff --git a/docs/layouts/shortcodes/generated/openai_configuration.html 
b/docs/layouts/shortcodes/generated/openai_configuration.html
index 6cea54aa9fe..7a6e87e6a06 100644
--- a/docs/layouts/shortcodes/generated/openai_configuration.html
+++ b/docs/layouts/shortcodes/generated/openai_configuration.html
@@ -32,6 +32,12 @@
             <td>String</td>
             <td>Full URL of the OpenAI API endpoint, e.g., <code 
class="highlighter-rouge">https://api.openai.com/v1/chat/completions</code> or 
<code class="highlighter-rouge">https://api.openai.com/v1/embeddings</code></td>
         </tr>
+        <tr>
+            <td><h5>error-handling-strategy</h5></td>
+            <td style="word-wrap: break-word;">RETRY</td>
+            <td><p>Enum</p></td>
+            <td>Strategy for handling errors during model requests.<br /><br 
/>Possible values:<ul><li>"RETRY": Retry sending the 
request.</li><li>"FAILOVER": Throw exceptions and fail the Flink 
job.</li><li>"IGNORE": Ignore the input that caused the error and continue. The 
error itself would be recorded in log.</li></ul></td>
+        </tr>
         <tr>
             <td><h5>max-context-size</h5></td>
             <td style="word-wrap: break-word;">(none)</td>
@@ -68,6 +74,18 @@
             <td><p>Enum</p></td>
             <td>The format of the response, e.g., 'text' or 'json_object'.<br 
/><br />Possible values:<ul><li>"text"</li><li>"json_object"</li></ul></td>
         </tr>
+        <tr>
+            <td><h5>retry-fallback-strategy</h5></td>
+            <td style="word-wrap: break-word;">FAILOVER</td>
+            <td><p>Enum</p></td>
+            <td>Fallback strategy to employ if the retry attempts are 
exhausted. This strategy is applied when error-handling-strategy is set to 
retry.<br /><br />Possible values:<ul><li>"FAILOVER": Throw exceptions and fail 
the Flink job.</li><li>"IGNORE": Ignore the input that caused the error and 
continue. The error itself would be recorded in log.</li></ul></td>
+        </tr>
+        <tr>
+            <td><h5>retry-num</h5></td>
+            <td style="word-wrap: break-word;">100</td>
+            <td>Integer</td>
+            <td>Number of retry for OpenAI client requests.</td>
+        </tr>
         <tr>
             <td><h5>seed</h5></td>
             <td style="word-wrap: break-word;">(none)</td>
diff --git a/flink-models/flink-model-openai/pom.xml 
b/flink-models/flink-model-openai/pom.xml
index 984caea39d8..f251b2e0e26 100644
--- a/flink-models/flink-model-openai/pom.xml
+++ b/flink-models/flink-model-openai/pom.xml
@@ -127,6 +127,12 @@ under the License.
                        <version>${project.version}</version>
                        <scope>test</scope>
                </dependency>
+               <dependency>
+                       <groupId>org.apache.flink</groupId>
+                       <artifactId>flink-test-utils-junit</artifactId>
+                       <version>${project.version}</version>
+                       <scope>test</scope>
+               </dependency>
        </dependencies>
 
        <repositories>
diff --git 
a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/AbstractOpenAIModelFunction.java
 
b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/AbstractOpenAIModelFunction.java
index 1d02580e7b1..22ca44aa042 100644
--- 
a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/AbstractOpenAIModelFunction.java
+++ 
b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/AbstractOpenAIModelFunction.java
@@ -18,8 +18,9 @@
 
 package org.apache.flink.model.openai;
 
+import org.apache.flink.configuration.DescribedEnum;
 import org.apache.flink.configuration.ReadableConfig;
-import org.apache.flink.table.api.config.ExecutionConfigOptions;
+import org.apache.flink.configuration.description.InlineElement;
 import org.apache.flink.table.catalog.Column;
 import org.apache.flink.table.catalog.ResolvedSchema;
 import org.apache.flink.table.data.RowData;
@@ -41,13 +42,17 @@ import java.util.List;
 import java.util.concurrent.CompletableFuture;
 import java.util.stream.Collectors;
 
+import static org.apache.flink.configuration.description.TextElement.text;
+
 /** Abstract parent class for {@link AsyncPredictFunction}s for OpenAI API. */
 public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction 
{
     private static final Logger LOG = 
LoggerFactory.getLogger(AbstractOpenAIModelFunction.class);
 
     protected transient OpenAIClientAsync client;
 
+    private final ErrorHandlingStrategy errorHandlingStrategy;
     private final int numRetry;
+    private final RetryFallbackStrategy retryFallbackStrategy;
     private final String baseUrl;
     private final String apiKey;
     private final String model;
@@ -59,19 +64,16 @@ public abstract class AbstractOpenAIModelFunction extends 
AsyncPredictFunction {
         String endpoint = config.get(OpenAIOptions.ENDPOINT);
         this.baseUrl = endpoint.replaceAll(String.format("/%s/*$", 
getEndpointSuffix()), "");
         this.apiKey = config.get(OpenAIOptions.API_KEY);
-        // The model service enforces rate-limiting constraints, necessitating 
retry mechanisms in
-        // most operational scenarios. Within the asynchronous operator 
framework, the system is
-        // designed to process up to
-        // 
config.get(ExecutionConfigOptions.TABLE_EXEC_ASYNC_LOOKUP_BUFFER_CAPACITY) 
concurrent
-        // requests in parallel. To mitigate potential performance degradation 
from simultaneous
-        // requests, a dynamic retry strategy is implemented where the maximum 
retry count is
-        // directly proportional to the configured parallelism level, ensuring 
robust error
-        // resilience while maintaining throughput efficiency.
+
+        this.errorHandlingStrategy = 
config.get(OpenAIOptions.ERROR_HANDLING_STRATEGY);
         this.numRetry =
-                
config.get(ExecutionConfigOptions.TABLE_EXEC_ASYNC_LOOKUP_BUFFER_CAPACITY) * 10;
+                this.errorHandlingStrategy == ErrorHandlingStrategy.RETRY
+                        ? config.get(OpenAIOptions.RETRY_NUM)
+                        : 0;
         this.model = config.get(OpenAIOptions.MODEL);
         this.maxContextSize = config.get(OpenAIOptions.MAX_CONTEXT_SIZE);
         this.contextOverflowAction = 
config.get(OpenAIOptions.CONTEXT_OVERFLOW_ACTION);
+        this.retryFallbackStrategy = 
config.get(OpenAIOptions.RETRY_FALLBACK_STRATEGY);
 
         validateSingleColumnSchema(
                 factoryContext.getCatalogModel().getResolvedInputSchema(),
@@ -148,4 +150,74 @@ public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction {
                             column.getDataType().getLogicalType()));
         }
     }
+
+    /**
+     * Resolves a failed model request according to the configured error handling strategy.
+     *
+     * <p>If the strategy is {@link ErrorHandlingStrategy#RETRY}, this is only reached once the
+     * retry attempts are exhausted, and the configured {@link RetryFallbackStrategy} decides the
+     * outcome instead.
+     *
+     * @param t the failure reported for the model request
+     * @return an empty collection if the effective strategy is {@link
+     *     ErrorHandlingStrategy#IGNORE}
+     * @throws RuntimeException wrapping {@code t} if the effective strategy is {@link
+     *     ErrorHandlingStrategy#FAILOVER}
+     */
+    protected Collection<RowData> handleErrorsAndRespond(Throwable t) {
+        ErrorHandlingStrategy finalErrorHandlingStrategy =
+                this.errorHandlingStrategy == ErrorHandlingStrategy.RETRY
+                        ? this.retryFallbackStrategy.strategy
+                        : this.errorHandlingStrategy;
+
+        if (finalErrorHandlingStrategy == ErrorHandlingStrategy.FAILOVER) {
+            throw new RuntimeException(t);
+        } else if (finalErrorHandlingStrategy == ErrorHandlingStrategy.IGNORE) {
+            // IGNORE is documented to record the error in the log, so it must not be dropped
+            // silently.
+            LOG.warn("Model request failed. Ignoring the input that caused the error.", t);
+            return Collections.emptyList();
+        } else {
+            throw new UnsupportedOperationException(
+                    "Unsupported error handling strategy: " + finalErrorHandlingStrategy);
+        }
+    }
+
+    /** Strategy for handling errors during model requests. */
+    public enum ErrorHandlingStrategy implements DescribedEnum {
+        RETRY("Retry sending the request."),
+        FAILOVER("Throw exceptions and fail the Flink job."),
+        IGNORE(
+                "Ignore the input that caused the error and continue. The error itself would be recorded in log.");
+
+        private final String description;
+
+        ErrorHandlingStrategy(String description) {
+            this.description = description;
+        }
+
+        @Override
+        public InlineElement getDescription() {
+            return text(description);
+        }
+    }
+
+    /**
+     * The fallback strategy for when retry attempts are exhausted. It should be identical to {@link
+     * ErrorHandlingStrategy} except that it does not support {@link ErrorHandlingStrategy#RETRY}.
+     */
+    public enum RetryFallbackStrategy implements DescribedEnum {
+        FAILOVER(ErrorHandlingStrategy.FAILOVER),
+        IGNORE(ErrorHandlingStrategy.IGNORE);
+        private final ErrorHandlingStrategy strategy;
+
+        RetryFallbackStrategy(ErrorHandlingStrategy strategy) {
+            this.strategy = strategy;
+        }
+
+        @Override
+        public InlineElement getDescription() {
+            return text(strategy.description);
+        }
+    }
 }
diff --git 
a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java
 
b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java
index 3e96a4a2d9a..2a2d3cb0ae4 100644
--- 
a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java
+++ 
b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java
@@ -34,7 +34,6 @@ import 
com.openai.models.chat.completions.ChatCompletionCreateParams.ResponseFor
 
 import java.util.Arrays;
 import java.util.Collection;
-import java.util.List;
 import java.util.concurrent.CompletableFuture;
 import java.util.stream.Collectors;
 
@@ -87,13 +86,15 @@ public class OpenAIChatModelFunction extends 
AbstractOpenAIModelFunction {
                 .getOptional(OpenAIOptions.RESPONSE_FORMAT)
                 .ifPresent(x -> builder.responseFormat(x.getResponseFormat()));
 
-        return client.chat()
-                .completions()
-                .create(builder.build())
-                .thenApply(this::convertToRowData);
+        return 
client.chat().completions().create(builder.build()).handle(this::convertToRowData);
     }
 
-    private List<RowData> convertToRowData(ChatCompletion chatCompletion) {
+    private Collection<RowData> convertToRowData(
+            ChatCompletion chatCompletion, Throwable throwable) {
+        if (throwable != null) {
+            return handleErrorsAndRespond(throwable);
+        }
+
         return chatCompletion.choices().stream()
                 .map(
                         choice ->
diff --git 
a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIEmbeddingModelFunction.java
 
b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIEmbeddingModelFunction.java
index 854ca2cf5d0..fd0e78b9ace 100644
--- 
a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIEmbeddingModelFunction.java
+++ 
b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIEmbeddingModelFunction.java
@@ -33,7 +33,6 @@ import 
com.openai.models.embeddings.EmbeddingCreateParams.EncodingFormat;
 import javax.annotation.Nullable;
 
 import java.util.Collection;
-import java.util.List;
 import java.util.concurrent.CompletableFuture;
 import java.util.stream.Collectors;
 
@@ -73,10 +72,15 @@ public class OpenAIEmbeddingModelFunction extends 
AbstractOpenAIModelFunction {
             builder.dimensions(dimensions);
         }
 
-        return 
client.embeddings().create(builder.build()).thenApply(this::convertToRowData);
+        return 
client.embeddings().create(builder.build()).handle(this::convertToRowData);
     }
 
-    private List<RowData> convertToRowData(CreateEmbeddingResponse response) {
+    private Collection<RowData> convertToRowData(
+            CreateEmbeddingResponse response, Throwable throwable) {
+        if (throwable != null) {
+            return handleErrorsAndRespond(throwable);
+        }
+
         return response.data().stream()
                 .map(
                         embedding ->
diff --git 
a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIModelProviderFactory.java
 
b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIModelProviderFactory.java
index 4c5b340528a..12cbb662255 100644
--- 
a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIModelProviderFactory.java
+++ 
b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIModelProviderFactory.java
@@ -69,6 +69,9 @@ public class OpenAIModelProviderFactory implements 
ModelProviderFactory {
         Set<ConfigOption<?>> set = new HashSet<>();
         set.add(OpenAIOptions.MAX_CONTEXT_SIZE);
         set.add(OpenAIOptions.CONTEXT_OVERFLOW_ACTION);
+        set.add(OpenAIOptions.ERROR_HANDLING_STRATEGY);
+        set.add(OpenAIOptions.RETRY_NUM);
+        set.add(OpenAIOptions.RETRY_FALLBACK_STRATEGY);
         set.add(OpenAIOptions.SYSTEM_PROMPT);
         set.add(OpenAIOptions.TEMPERATURE);
         set.add(OpenAIOptions.TOP_P);
diff --git 
a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIOptions.java
 
b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIOptions.java
index 8e605be51e0..84cf72f5444 100644
--- 
a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIOptions.java
+++ 
b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIOptions.java
@@ -83,6 +83,39 @@ public class OpenAIOptions {
                                     .text("Action to handle context overflows.")
                                     .build());

+    // RETRY re-sends a failed request up to RETRY_NUM times before applying
+    // RETRY_FALLBACK_STRATEGY; FAILOVER and IGNORE act on the first failure without retrying.
+    @Documentation.Section({Documentation.Sections.MODEL_OPENAI_COMMON})
+    public static final ConfigOption<AbstractOpenAIModelFunction.ErrorHandlingStrategy>
+            ERROR_HANDLING_STRATEGY =
+                    ConfigOptions.key("error-handling-strategy")
+                            .enumType(AbstractOpenAIModelFunction.ErrorHandlingStrategy.class)
+                            .defaultValue(AbstractOpenAIModelFunction.ErrorHandlingStrategy.RETRY)
+                            .withDescription("Strategy for handling errors during model requests.");
+
+    // The model service enforces rate-limiting constraints, necessitating retry mechanisms in
+    // most operational scenarios. This value only takes effect when ERROR_HANDLING_STRATEGY is
+    // RETRY; for any other strategy the retry count is 0.
+    @Documentation.Section({Documentation.Sections.MODEL_OPENAI_COMMON})
+    public static final ConfigOption<Integer> RETRY_NUM =
+            ConfigOptions.key("retry-num")
+                    .intType()
+                    .defaultValue(100)
+                    .withDescription("Number of retry for OpenAI client requests.");
+
+    // Only FAILOVER and IGNORE are offered here: the fallback is applied once the retry attempts
+    // are exhausted, so retrying again is not a meaningful option.
+    @Documentation.Section({Documentation.Sections.MODEL_OPENAI_COMMON})
+    public static final ConfigOption<AbstractOpenAIModelFunction.RetryFallbackStrategy>
+            RETRY_FALLBACK_STRATEGY =
+                    ConfigOptions.key("retry-fallback-strategy")
+                            .enumType(AbstractOpenAIModelFunction.RetryFallbackStrategy.class)
+                            .defaultValue(
+                                    AbstractOpenAIModelFunction.RetryFallbackStrategy.FAILOVER)
+                            .withDescription(
+                                    "Fallback strategy to employ if the retry attempts are exhausted."
+                                            + " This strategy is applied when error-handling-strategy is set to retry.");
+
     // ------------------------------------------------------------------------
     // Options for Chat Completion Model Functions
     // ------------------------------------------------------------------------
diff --git 
a/flink-models/flink-model-openai/src/test/java/org/apache/flink/model/openai/ModelFunctionErrorHandlingStrategyTest.java
 
b/flink-models/flink-model-openai/src/test/java/org/apache/flink/model/openai/ModelFunctionErrorHandlingStrategyTest.java
new file mode 100644
index 00000000000..f7d293a201a
--- /dev/null
+++ 
b/flink-models/flink-model-openai/src/test/java/org/apache/flink/model/openai/ModelFunctionErrorHandlingStrategyTest.java
@@ -0,0 +1,315 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.openai;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.core.testutils.FlinkAssertions;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.TableResult;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableEnvironmentImpl;
+import org.apache.flink.table.catalog.CatalogManager;
+import org.apache.flink.table.catalog.CatalogModel;
+import org.apache.flink.table.catalog.ObjectIdentifier;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode;
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
+
+import okhttp3.mockwebserver.Dispatcher;
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+import okhttp3.mockwebserver.RecordedRequest;
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for the error handling strategies of OpenAI model functions. */
+@SuppressWarnings("unchecked")
+public class ModelFunctionErrorHandlingStrategyTest {
+    private static final String RETRYABLE_INPUT_DATA = "Return a retryable error code.";
+
+    private static final Row[] INPUT_DATA =
+            new Row[] {
+                Row.of("Why is sky blue?", 1.0),
+                Row.of("Why are stars shining?", 1.0),
+                Row.of("What is the meaning of life?", 1.0),
+                Row.of("What is a heart attack?", 0.0),
+                Row.of("How fast can human run?", 1.0),
+                Row.of("Why is the ocean salty?", 1.0),
+                Row.of("How do airplanes fly?", 1.0),
+                Row.of("What causes earthquakes?", 1.0),
+                Row.of("Why do leaves change color in autumn?", 0.0),
+                Row.of("How does photosynthesis work?", 1.0)
+            };
+
+    private static final Schema INPUT_SCHEMA =
+            Schema.newBuilder().column("input", DataTypes.STRING()).build();
+    private static final Schema OUTPUT_SCHEMA =
+            Schema.newBuilder().column("content", DataTypes.STRING()).build();
+
+    private static MockWebServer server;
+
+    private String modelName;
+
+    private Map<String, String> modelOptions;
+
+    private StreamTableEnvironment tEnv;
+
+    @BeforeAll
+    public static void beforeAll() throws IOException {
+        server = new MockWebServer();
+        server.setDispatcher(new TestDispatcher());
+        server.start();
+    }
+
+    @AfterAll
+    public static void afterAll() throws IOException {
+        if (server != null) {
+            server.close();
+        }
+    }
+
+    @BeforeEach
+    public void setup() {
+        modelName = "Model" + System.currentTimeMillis();
+
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        DataStream<Row> source =
+                env.fromData(INPUT_DATA)
+                        .returns(
+                                Types.ROW_NAMED(
+                                        new String[] {"input", "invalid_input"},
+                                        Types.STRING,
+                                        Types.DOUBLE));
+        tEnv = StreamTableEnvironment.create(env);
+        Table sourceTable =
+                tEnv.fromDataStream(
+                        source,
+                        Schema.newBuilder()
+                                .column("input", DataTypes.STRING())
+                                .column("invalid_input", DataTypes.DOUBLE())
+                                .build());
+        tEnv.createTemporaryView("MyTable", sourceTable);
+
+        modelOptions = new HashMap<>();
+        modelOptions.put("provider", "openai");
+        modelOptions.put("endpoint", server.url("/chat/completions").toString());
+        modelOptions.put("api-key", "foobar");
+        modelOptions.put("model", "qwen-turbo");
+
+        TestDispatcher.INVOKE_COUNT.set(0);
+    }
+
+    @AfterEach
+    public void afterEach() {
+        assertThat(OpenAIUtils.getCache()).isEmpty();
+    }
+
+    @Test
+    public void testSuccessInvoke() {
+        createModel();
+        TableResult tableResult =
+                tEnv.executeSql(
+                        String.format(
+                                "SELECT input, content FROM ML_PREDICT(TABLE MyTable, MODEL %s, DESCRIPTOR(`input`))",
+                                modelName));
+        List<Row> result = IteratorUtils.toList(tableResult.collect());
+        assertThat(result).hasSize(INPUT_DATA.length);
+        for (Row row : result) {
+            assertThat(row.getField(0)).isInstanceOf(String.class);
+            assertThat(row.getField(1)).isInstanceOf(String.class);
+            assertThat((String) row.getFieldAs(1)).isNotEmpty();
+        }
+    }
+
+    @Test
+    public void testIgnoreError() {
+        // The "t" in localhost is missing on purpose to test against invalid endpoint
+        modelOptions.put("endpoint", "http://localhos:9999/chat/completions");
+        modelOptions.put("error-handling-strategy", "ignore");
+        createModel();
+        TableResult tableResult =
+                tEnv.executeSql(
+                        String.format(
+                                "SELECT input, content FROM ML_PREDICT(TABLE MyTable, MODEL %s, DESCRIPTOR(`input`))",
+                                modelName));
+        List<Row> result = IteratorUtils.toList(tableResult.collect());
+        assertThat(result).isEmpty();
+    }
+
+    @Test
+    public void testFailoverStrategy() {
+        modelOptions.put("error-handling-strategy", "failover");
+        createModel();
+        TableResult tableResult =
+                tEnv.executeSql(
+                        String.format(
+                                "WITH v(input) AS (SELECT * FROM (VALUES ('%s'))) "
+                                        + "SELECT * FROM ML_PREDICT( "
+                                        + "  TABLE v, "
+                                        + "  MODEL `%s`, "
+                                        + "  DESCRIPTOR(`input`) "
+                                        + ")",
+                                RETRYABLE_INPUT_DATA, modelName));
+        assertThatThrownBy(() -> IteratorUtils.toList(tableResult.collect()))
+                .rootCause()
+                .satisfies(
+                        FlinkAssertions.anyCauseMatches(
+                                "com.openai.errors.RateLimitException: 429"));
+    }
+
+    @Test
+    public void testRetryWithFailoverStrategy() {
+        modelOptions.put("retry-num", "3");
+        modelOptions.put("error-handling-strategy", "retry");
+        modelOptions.put("retry-fallback-strategy", "failover");
+        createModel();
+        TableResult tableResult =
+                tEnv.executeSql(
+                        String.format(
+                                "WITH v(input) AS (SELECT * FROM (VALUES ('%s'))) "
+                                        + "SELECT * FROM ML_PREDICT( "
+                                        + "  TABLE v, "
+                                        + "  MODEL `%s`, "
+                                        + "  DESCRIPTOR(`input`) "
+                                        + ")",
+                                RETRYABLE_INPUT_DATA, modelName));
+        assertThatThrownBy(() -> IteratorUtils.toList(tableResult.collect()))
+                .rootCause()
+                .satisfies(
+                        FlinkAssertions.anyCauseMatches(
+                                "com.openai.errors.RateLimitException: 429"));
+
+        assertThat(TestDispatcher.INVOKE_COUNT.get()).isEqualTo(4); // retryNum + 1
+    }
+
+    @Test
+    public void testRetryWithIgnoreStrategy() {
+        modelOptions.put("retry-num", "3");
+        modelOptions.put("error-handling-strategy", "retry");
+        modelOptions.put("retry-fallback-strategy", "ignore");
+        createModel();
+        TableResult tableResult =
+                tEnv.executeSql(
+                        String.format(
+                                "WITH v(input) AS (SELECT * FROM (VALUES ('%s'))) "
+                                        + "SELECT * FROM ML_PREDICT( "
+                                        + "  TABLE v, "
+                                        + "  MODEL `%s`, "
+                                        + "  DESCRIPTOR(`input`) "
+                                        + ")",
+                                RETRYABLE_INPUT_DATA, modelName));
+        List<Row> result = IteratorUtils.toList(tableResult.collect());
+        assertThat(result).isEmpty();
+        assertThat(TestDispatcher.INVOKE_COUNT.get()).isEqualTo(4); // retryNum + 1
+    }
+
+    private void createModel() {
+        createModel(INPUT_SCHEMA, OUTPUT_SCHEMA);
+    }
+
+    private void createModel(Schema inputSchema, Schema outputSchema) {
+        CatalogManager catalogManager = ((TableEnvironmentImpl) tEnv).getCatalogManager();
+        ObjectIdentifier modelIdentifier =
+                ObjectIdentifier.of(
+                        Objects.requireNonNull(catalogManager.getCurrentCatalog()),
+                        Objects.requireNonNull(catalogManager.getCurrentDatabase()),
+                        modelName);
+        catalogManager.createModel(
+                CatalogModel.of(inputSchema, outputSchema, modelOptions, "This is a new model."),
+                modelIdentifier,
+                false);
+    }
+
+    private static class TestDispatcher extends Dispatcher {
+        private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+        private static final AtomicInteger INVOKE_COUNT = new AtomicInteger(0);
+
+        @Override
+        public MockResponse dispatch(RecordedRequest request) {
+            INVOKE_COUNT.incrementAndGet();
+
+            String path = request.getRequestUrl().encodedPath();
+            String body = request.getBody().readUtf8();
+
+            if (!path.endsWith("/chat/completions")) {
+                return new MockResponse().setResponseCode(404);
+            }
+
+            try {
+                JsonNode root = OBJECT_MAPPER.readTree(body);
+
+                // The messages contain one system prompt and one user message.
+                Preconditions.checkArgument(root.get("messages").size() == 2);
+                if (root.get("messages")
+                        .get(1)
+                        .get("content")
+                        .toString()
+                        .contains(RETRYABLE_INPUT_DATA)) {
+                    return new MockResponse().setResponseCode(429);
+                }
+
+                String responseBody =
+                        "{"
+                                + "  \"id\": \"chatcmpl-1234567890ABCD\","
+                                + "  \"object\": \"chat.completion\","
+                                + "  \"created\": 1717029203,"
+                                + "  \"model\": \"gpt-3.5-turbo-0125\","
+                                + "  \"choices\": [{"
+                                + "    \"index\": 0,"
+                                + "    \"message\": {"
+                                + "      \"role\": \"assistant\","
+                                + "      \"content\": \"This is a mocked response\""
+                                + "    },"
+                                + "    \"finish_reason\": \"stop\""
+                                + "  }],"
+                                + "  \"usage\": {"
+                                + "    \"prompt_tokens\": 9,"
+                                + "    \"completion_tokens\": 16,"
+                                + "    \"total_tokens\": 25"
+                                + "  }"
+                                + "}";
+
+                return new MockResponse()
+                        .setHeader("Content-Type", "application/json")
+                        .setBody(responseBody);
+            } catch (Exception e) {
+                throw new RuntimeException(e);
+            }
+        }
+    }
+}
diff --git 
a/flink-models/flink-model-openai/src/test/resources/log4j2-test.properties 
b/flink-models/flink-model-openai/src/test/resources/log4j2-test.properties
new file mode 100644
index 00000000000..835c2ec9a3d
--- /dev/null
+++ b/flink-models/flink-model-openai/src/test/resources/log4j2-test.properties
@@ -0,0 +1,28 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Set root logger level to OFF to not flood build logs
+# set manually to INFO for debugging purposes
+rootLogger.level = OFF
+rootLogger.appenderRef.test.ref = TestLogger
+
+appender.testlogger.name = TestLogger
+appender.testlogger.type = CONSOLE
+appender.testlogger.target = SYSTEM_ERR
+appender.testlogger.layout.type = PatternLayout
+appender.testlogger.layout.pattern = %-4r [%t] %-5p %c %x - %m%n

Reply via email to