This is an automated email from the ASF dual-hosted git repository. sxnan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 7ab3e03ec9be7ec3363f7120821678cb8acae9cb Author: Yunfeng Zhou <[email protected]> AuthorDate: Tue Oct 28 18:49:51 2025 +0800 [FLINK-38581][model] Model Function supports error handling strategy --- .../generated/model_openai_common_section.html | 18 ++ .../shortcodes/generated/openai_configuration.html | 18 ++ flink-models/flink-model-openai/pom.xml | 6 + .../model/openai/AbstractOpenAIModelFunction.java | 76 +++++- .../model/openai/OpenAIChatModelFunction.java | 13 +- .../model/openai/OpenAIEmbeddingModelFunction.java | 10 +- .../model/openai/OpenAIModelProviderFactory.java | 3 + .../apache/flink/model/openai/OpenAIOptions.java | 28 ++ .../ModelFunctionErrorHandlingStrategyTest.java | 294 +++++++++++++++++++++ .../src/test/resources/log4j2-test.properties | 28 ++ 10 files changed, 475 insertions(+), 19 deletions(-) diff --git a/docs/layouts/shortcodes/generated/model_openai_common_section.html b/docs/layouts/shortcodes/generated/model_openai_common_section.html index e003338b58e..435be03a136 100644 --- a/docs/layouts/shortcodes/generated/model_openai_common_section.html +++ b/docs/layouts/shortcodes/generated/model_openai_common_section.html @@ -26,6 +26,12 @@ <td>String</td> <td>Full URL of the OpenAI API endpoint, e.g., <code class="highlighter-rouge">https://api.openai.com/v1/chat/completions</code> or <code class="highlighter-rouge">https://api.openai.com/v1/embeddings</code></td> </tr> + <tr> + <td><h5>error-handling-strategy</h5></td> + <td style="word-wrap: break-word;">RETRY</td> + <td><p>Enum</p></td> + <td>Strategy for handling errors during model requests.<br /><br />Possible values:<ul><li>"RETRY": Retry sending the request.</li><li>"FAILOVER": Throw exceptions and fail the Flink job.</li><li>"IGNORE": Ignore the input that caused the error and continue. 
The error itself would be recorded in log.</li></ul></td> + </tr> <tr> <td><h5>max-context-size</h5></td> <td style="word-wrap: break-word;">(none)</td> @@ -38,5 +44,17 @@ <td>String</td> <td>Model name, e.g., <code class="highlighter-rouge">gpt-3.5-turbo</code>, <code class="highlighter-rouge">text-embedding-ada-002</code>.</td> </tr> + <tr> + <td><h5>retry-fallback-strategy</h5></td> + <td style="word-wrap: break-word;">FAILOVER</td> + <td><p>Enum</p></td> + <td>Fallback strategy to employ if the retry attempts are exhausted. This strategy is applied when error-handling-strategy is set to retry.<br /><br />Possible values:<ul><li>"FAILOVER": Throw exceptions and fail the Flink job.</li><li>"IGNORE": Ignore the input that caused the error and continue. The error itself would be recorded in log.</li></ul></td> + </tr> + <tr> + <td><h5>retry-num</h5></td> + <td style="word-wrap: break-word;">100</td> + <td>Integer</td> + <td>Number of retry for OpenAI client requests.</td> + </tr> </tbody> </table> diff --git a/docs/layouts/shortcodes/generated/openai_configuration.html b/docs/layouts/shortcodes/generated/openai_configuration.html index 6cea54aa9fe..7a6e87e6a06 100644 --- a/docs/layouts/shortcodes/generated/openai_configuration.html +++ b/docs/layouts/shortcodes/generated/openai_configuration.html @@ -32,6 +32,12 @@ <td>String</td> <td>Full URL of the OpenAI API endpoint, e.g., <code class="highlighter-rouge">https://api.openai.com/v1/chat/completions</code> or <code class="highlighter-rouge">https://api.openai.com/v1/embeddings</code></td> </tr> + <tr> + <td><h5>error-handling-strategy</h5></td> + <td style="word-wrap: break-word;">RETRY</td> + <td><p>Enum</p></td> + <td>Strategy for handling errors during model requests.<br /><br />Possible values:<ul><li>"RETRY": Retry sending the request.</li><li>"FAILOVER": Throw exceptions and fail the Flink job.</li><li>"IGNORE": Ignore the input that caused the error and continue. 
The error itself would be recorded in log.</li></ul></td> + </tr> <tr> <td><h5>max-context-size</h5></td> <td style="word-wrap: break-word;">(none)</td> @@ -68,6 +74,18 @@ <td><p>Enum</p></td> <td>The format of the response, e.g., 'text' or 'json_object'.<br /><br />Possible values:<ul><li>"text"</li><li>"json_object"</li></ul></td> </tr> + <tr> + <td><h5>retry-fallback-strategy</h5></td> + <td style="word-wrap: break-word;">FAILOVER</td> + <td><p>Enum</p></td> + <td>Fallback strategy to employ if the retry attempts are exhausted. This strategy is applied when error-handling-strategy is set to retry.<br /><br />Possible values:<ul><li>"FAILOVER": Throw exceptions and fail the Flink job.</li><li>"IGNORE": Ignore the input that caused the error and continue. The error itself would be recorded in log.</li></ul></td> + </tr> + <tr> + <td><h5>retry-num</h5></td> + <td style="word-wrap: break-word;">100</td> + <td>Integer</td> + <td>Number of retry for OpenAI client requests.</td> + </tr> <tr> <td><h5>seed</h5></td> <td style="word-wrap: break-word;">(none)</td> diff --git a/flink-models/flink-model-openai/pom.xml b/flink-models/flink-model-openai/pom.xml index 984caea39d8..f251b2e0e26 100644 --- a/flink-models/flink-model-openai/pom.xml +++ b/flink-models/flink-model-openai/pom.xml @@ -127,6 +127,12 @@ under the License. 
<version>${project.version}</version> <scope>test</scope> </dependency> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-test-utils-junit</artifactId> + <version>${project.version}</version> + <scope>test</scope> + </dependency> </dependencies> <repositories> diff --git a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/AbstractOpenAIModelFunction.java b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/AbstractOpenAIModelFunction.java index 1d02580e7b1..22ca44aa042 100644 --- a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/AbstractOpenAIModelFunction.java +++ b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/AbstractOpenAIModelFunction.java @@ -18,8 +18,9 @@ package org.apache.flink.model.openai; +import org.apache.flink.configuration.DescribedEnum; import org.apache.flink.configuration.ReadableConfig; -import org.apache.flink.table.api.config.ExecutionConfigOptions; +import org.apache.flink.configuration.description.InlineElement; import org.apache.flink.table.catalog.Column; import org.apache.flink.table.catalog.ResolvedSchema; import org.apache.flink.table.data.RowData; @@ -41,13 +42,17 @@ import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; +import static org.apache.flink.configuration.description.TextElement.text; + /** Abstract parent class for {@link AsyncPredictFunction}s for OpenAI API. 
*/ public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction { private static final Logger LOG = LoggerFactory.getLogger(AbstractOpenAIModelFunction.class); protected transient OpenAIClientAsync client; + private final ErrorHandlingStrategy errorHandlingStrategy; private final int numRetry; + private final RetryFallbackStrategy retryFallbackStrategy; private final String baseUrl; private final String apiKey; private final String model; @@ -59,19 +64,16 @@ public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction { String endpoint = config.get(OpenAIOptions.ENDPOINT); this.baseUrl = endpoint.replaceAll(String.format("/%s/*$", getEndpointSuffix()), ""); this.apiKey = config.get(OpenAIOptions.API_KEY); - // The model service enforces rate-limiting constraints, necessitating retry mechanisms in - // most operational scenarios. Within the asynchronous operator framework, the system is - // designed to process up to - // config.get(ExecutionConfigOptions.TABLE_EXEC_ASYNC_LOOKUP_BUFFER_CAPACITY) concurrent - // requests in parallel. To mitigate potential performance degradation from simultaneous - // requests, a dynamic retry strategy is implemented where the maximum retry count is - // directly proportional to the configured parallelism level, ensuring robust error - // resilience while maintaining throughput efficiency. + + this.errorHandlingStrategy = config.get(OpenAIOptions.ERROR_HANDLING_STRATEGY); this.numRetry = - config.get(ExecutionConfigOptions.TABLE_EXEC_ASYNC_LOOKUP_BUFFER_CAPACITY) * 10; + this.errorHandlingStrategy == ErrorHandlingStrategy.RETRY + ? 
config.get(OpenAIOptions.RETRY_NUM) + : 0; this.model = config.get(OpenAIOptions.MODEL); this.maxContextSize = config.get(OpenAIOptions.MAX_CONTEXT_SIZE); this.contextOverflowAction = config.get(OpenAIOptions.CONTEXT_OVERFLOW_ACTION); + this.retryFallbackStrategy = config.get(OpenAIOptions.RETRY_FALLBACK_STRATEGY); validateSingleColumnSchema( factoryContext.getCatalogModel().getResolvedInputSchema(), @@ -148,4 +150,67 @@ public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction { column.getDataType().getLogicalType())); } } + + /** + * Handles a failed model request according to the effective error handling strategy. If + * the configured strategy is RETRY, the retry fallback strategy is used instead. + * + * @param t the error raised by the model request + * @return an empty collection if the effective strategy is IGNORE + * @throws RuntimeException wrapping {@code t} if the effective strategy is FAILOVER + */ + protected Collection<RowData> handleErrorsAndRespond(Throwable t) { + ErrorHandlingStrategy finalErrorHandlingStrategy = + this.errorHandlingStrategy == ErrorHandlingStrategy.RETRY + ? this.retryFallbackStrategy.strategy + : this.errorHandlingStrategy; + + if (finalErrorHandlingStrategy == ErrorHandlingStrategy.FAILOVER) { + throw new RuntimeException(t); + } else if (finalErrorHandlingStrategy == ErrorHandlingStrategy.IGNORE) { + LOG.warn("Ignoring the input because the model request failed.", t); + return Collections.emptyList(); + } else { + throw new UnsupportedOperationException( + "Unsupported error handling strategy: " + finalErrorHandlingStrategy); + } + } + + /** Strategy for handling errors during model requests. */ + public enum ErrorHandlingStrategy implements DescribedEnum { + RETRY("Retry sending the request."), + FAILOVER("Throw exceptions and fail the Flink job."), + IGNORE( + "Ignore the input that caused the error and continue. The error itself would be recorded in log."); + + private final String description; + + ErrorHandlingStrategy(String description) { + this.description = description; + } + + @Override + public InlineElement getDescription() { + return text(description); + } + } + + /** + * The fallback strategy for when retry attempts are exhausted. It should be identical to {@link + * ErrorHandlingStrategy} except that it does not support {@link ErrorHandlingStrategy#RETRY}.
+ */ + public enum RetryFallbackStrategy implements DescribedEnum { + FAILOVER(ErrorHandlingStrategy.FAILOVER), + IGNORE(ErrorHandlingStrategy.IGNORE); + private final ErrorHandlingStrategy strategy; + + RetryFallbackStrategy(ErrorHandlingStrategy strategy) { + this.strategy = strategy; + } + + @Override + public InlineElement getDescription() { + return text(strategy.description); + } + } } diff --git a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java index 3e96a4a2d9a..2a2d3cb0ae4 100644 --- a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java +++ b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java @@ -34,7 +34,6 @@ import com.openai.models.chat.completions.ChatCompletionCreateParams.ResponseFor import java.util.Arrays; import java.util.Collection; -import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; @@ -87,13 +86,15 @@ public class OpenAIChatModelFunction extends AbstractOpenAIModelFunction { .getOptional(OpenAIOptions.RESPONSE_FORMAT) .ifPresent(x -> builder.responseFormat(x.getResponseFormat())); - return client.chat() - .completions() - .create(builder.build()) - .thenApply(this::convertToRowData); + return client.chat().completions().create(builder.build()).handle(this::convertToRowData); } - private List<RowData> convertToRowData(ChatCompletion chatCompletion) { + private Collection<RowData> convertToRowData( + ChatCompletion chatCompletion, Throwable throwable) { + if (throwable != null) { + return handleErrorsAndRespond(throwable); + } + return chatCompletion.choices().stream() .map( choice -> diff --git a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIEmbeddingModelFunction.java 
b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIEmbeddingModelFunction.java index 854ca2cf5d0..fd0e78b9ace 100644 --- a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIEmbeddingModelFunction.java +++ b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIEmbeddingModelFunction.java @@ -33,7 +33,6 @@ import com.openai.models.embeddings.EmbeddingCreateParams.EncodingFormat; import javax.annotation.Nullable; import java.util.Collection; -import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; @@ -73,10 +72,15 @@ public class OpenAIEmbeddingModelFunction extends AbstractOpenAIModelFunction { builder.dimensions(dimensions); } - return client.embeddings().create(builder.build()).thenApply(this::convertToRowData); + return client.embeddings().create(builder.build()).handle(this::convertToRowData); } - private List<RowData> convertToRowData(CreateEmbeddingResponse response) { + private Collection<RowData> convertToRowData( + CreateEmbeddingResponse response, Throwable throwable) { + if (throwable != null) { + return handleErrorsAndRespond(throwable); + } + return response.data().stream() .map( embedding -> diff --git a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIModelProviderFactory.java b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIModelProviderFactory.java index 4c5b340528a..12cbb662255 100644 --- a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIModelProviderFactory.java +++ b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIModelProviderFactory.java @@ -69,6 +69,9 @@ public class OpenAIModelProviderFactory implements ModelProviderFactory { Set<ConfigOption<?>> set = new HashSet<>(); set.add(OpenAIOptions.MAX_CONTEXT_SIZE); set.add(OpenAIOptions.CONTEXT_OVERFLOW_ACTION); + 
set.add(OpenAIOptions.ERROR_HANDLING_STRATEGY); + set.add(OpenAIOptions.RETRY_NUM); + set.add(OpenAIOptions.RETRY_FALLBACK_STRATEGY); set.add(OpenAIOptions.SYSTEM_PROMPT); set.add(OpenAIOptions.TEMPERATURE); set.add(OpenAIOptions.TOP_P); diff --git a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIOptions.java b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIOptions.java index 8e605be51e0..84cf72f5444 100644 --- a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIOptions.java +++ b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIOptions.java @@ -83,6 +83,34 @@ public class OpenAIOptions { .text("Action to handle context overflows.") .build()); + @Documentation.Section({Documentation.Sections.MODEL_OPENAI_COMMON}) + public static final ConfigOption<AbstractOpenAIModelFunction.ErrorHandlingStrategy> + ERROR_HANDLING_STRATEGY = + ConfigOptions.key("error-handling-strategy") + .enumType(AbstractOpenAIModelFunction.ErrorHandlingStrategy.class) + .defaultValue(AbstractOpenAIModelFunction.ErrorHandlingStrategy.RETRY) + .withDescription("Strategy for handling errors during model requests."); + + // The model service enforces rate-limiting constraints, necessitating retry mechanisms in + // most operational scenarios. 
+ @Documentation.Section({Documentation.Sections.MODEL_OPENAI_COMMON}) + public static final ConfigOption<Integer> RETRY_NUM = + ConfigOptions.key("retry-num") + .intType() + .defaultValue(100) + .withDescription("Number of retry for OpenAI client requests."); + + @Documentation.Section({Documentation.Sections.MODEL_OPENAI_COMMON}) + public static final ConfigOption<AbstractOpenAIModelFunction.RetryFallbackStrategy> + RETRY_FALLBACK_STRATEGY = + ConfigOptions.key("retry-fallback-strategy") + .enumType(AbstractOpenAIModelFunction.RetryFallbackStrategy.class) + .defaultValue( + AbstractOpenAIModelFunction.RetryFallbackStrategy.FAILOVER) + .withDescription( + "Fallback strategy to employ if the retry attempts are exhausted." + + " This strategy is applied when error-handling-strategy is set to retry."); + // ------------------------------------------------------------------------ // Options for Chat Completion Model Functions // ------------------------------------------------------------------------ diff --git a/flink-models/flink-model-openai/src/test/java/org/apache/flink/model/openai/ModelFunctionErrorHandlingStrategyTest.java b/flink-models/flink-model-openai/src/test/java/org/apache/flink/model/openai/ModelFunctionErrorHandlingStrategyTest.java new file mode 100644 index 00000000000..f7d293a201a --- /dev/null +++ b/flink-models/flink-model-openai/src/test/java/org/apache/flink/model/openai/ModelFunctionErrorHandlingStrategyTest.java @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.model.openai; + +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.core.testutils.FlinkAssertions; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.TableResult; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableEnvironmentImpl; +import org.apache.flink.table.catalog.CatalogManager; +import org.apache.flink.table.catalog.CatalogModel; +import org.apache.flink.table.catalog.ObjectIdentifier; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; + +import okhttp3.mockwebserver.Dispatcher; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.apache.commons.collections.IteratorUtils; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import 
java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for Model Function metrics and error handling strategy. */ +@SuppressWarnings("unchecked") +public class ModelFunctionErrorHandlingStrategyTest { + private static final String RETRYABLE_INPUT_DATA = "Return a retryable error code."; + + private static final Row[] INPUT_DATA = + new Row[] { + Row.of("Why is sky blue?", 1.0), + Row.of("Why are stars shining?", 1.0), + Row.of("What is the meaning of life?", 1.0), + Row.of("What is a heart attack?", 0.0), + Row.of("How fast can human run?", 1.0), + Row.of("Why is the ocean salty?", 1.0), + Row.of("How do airplanes fly?", 1.0), + Row.of("What causes earthquakes?", 1.0), + Row.of("Why do leaves change color in autumn?", 0.0), + Row.of("How does photosynthesis work?", 1.0) + }; + + private static final Schema INPUT_SCHEMA = + Schema.newBuilder().column("input", DataTypes.STRING()).build(); + private static final Schema OUTPUT_SCHEMA = + Schema.newBuilder().column("content", DataTypes.STRING()).build(); + + private static MockWebServer server; + + private String modelName; + + private Map<String, String> modelOptions; + + private StreamTableEnvironment tEnv; + + @BeforeAll + public static void beforeAll() throws IOException { + server = new MockWebServer(); + server.setDispatcher(new TestDispatcher()); + server.start(); + } + + @AfterAll + public static void afterAll() throws IOException { + if (server != null) { + server.close(); + } + } + + @BeforeEach + public void setup() { + modelName = "Model" + System.currentTimeMillis(); + + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + DataStream<Row> source = + env.fromData(INPUT_DATA) + .returns( + Types.ROW_NAMED( + new String[] {"input", "invalid_input"}, + Types.STRING, + Types.DOUBLE)); + tEnv = 
StreamTableEnvironment.create(env); + Table sourceTable = + tEnv.fromDataStream( + source, + Schema.newBuilder() + .column("input", DataTypes.STRING()) + .column("invalid_input", DataTypes.DOUBLE()) + .build()); + tEnv.createTemporaryView("MyTable", sourceTable); + + modelOptions = new HashMap<>(); + modelOptions.put("provider", "openai"); + modelOptions.put("endpoint", server.url("/chat/completions").toString()); + modelOptions.put("api-key", "foobar"); + modelOptions.put("model", "qwen-turbo"); + + TestDispatcher.INVOKE_COUNT.set(0); + } + + @AfterEach + public void afterEach() { + assertThat(OpenAIUtils.getCache()).isEmpty(); + } + + @Test + public void testSuccessInvoke() { + createModel(); + TableResult tableResult = + tEnv.executeSql( + String.format( + "SELECT input, content FROM ML_PREDICT(TABLE MyTable, MODEL %s, DESCRIPTOR(`input`))", + modelName)); + List<Row> result = IteratorUtils.toList(tableResult.collect()); + assertThat(result).hasSize(INPUT_DATA.length); + for (Row row : result) { + assertThat(row.getField(0)).isInstanceOf(String.class); + assertThat(row.getField(1)).isInstanceOf(String.class); + assertThat((String) row.getFieldAs(1)).isNotEmpty(); + } + } + + @Test + public void testIgnoreError() { + // The "t" in localhost is missing on purpose to test against invalid endpoint + modelOptions.put("endpoint", "http://localhos:9999/chat/completions"); + modelOptions.put("error-handling-strategy", "ignore"); + createModel(); + TableResult tableResult = + tEnv.executeSql( + String.format( + "SELECT input, content FROM ML_PREDICT(TABLE MyTable, MODEL %s, DESCRIPTOR(`input`))", + modelName)); + List<Row> result = IteratorUtils.toList(tableResult.collect()); + assertThat(result).isEmpty(); + } + + @Test + public void testFailoverStrategy() { + modelOptions.put("error-handling-strategy", "failover"); + createModel(); + TableResult tableResult = + tEnv.executeSql( + String.format( + "WITH v(input) AS (SELECT * FROM (VALUES ('%s'))) " + + "SELECT * FROM 
ML_PREDICT( " + + " TABLE v, " + + " MODEL `%s`, " + + " DESCRIPTOR(`input`) " + + ")", + RETRYABLE_INPUT_DATA, modelName)); + assertThatThrownBy(() -> IteratorUtils.toList(tableResult.collect())) + .rootCause() + .satisfies( + FlinkAssertions.anyCauseMatches( + "com.openai.errors.RateLimitException: 429")); + } + + @Test + public void testRetryWithFailoverStrategy() { + modelOptions.put("retry-num", "3"); + modelOptions.put("error-handling-strategy", "retry"); + modelOptions.put("retry-fallback-strategy", "failover"); + createModel(); + TableResult tableResult = + tEnv.executeSql( + String.format( + "WITH v(input) AS (SELECT * FROM (VALUES ('%s'))) " + + "SELECT * FROM ML_PREDICT( " + + " TABLE v, " + + " MODEL `%s`, " + + " DESCRIPTOR(`input`) " + + ")", + RETRYABLE_INPUT_DATA, modelName)); + assertThatThrownBy(() -> IteratorUtils.toList(tableResult.collect())) + .rootCause() + .satisfies( + FlinkAssertions.anyCauseMatches( + "com.openai.errors.RateLimitException: 429")); + + assertThat(TestDispatcher.INVOKE_COUNT.get()).isEqualTo(4); // retryNum + 1 + } + + private void createModel() { + createModel(INPUT_SCHEMA, OUTPUT_SCHEMA); + } + + private void createModel(Schema inputSchema, Schema outputSchema) { + CatalogManager catalogManager = ((TableEnvironmentImpl) tEnv).getCatalogManager(); + ObjectIdentifier modelIdentifier = + ObjectIdentifier.of( + Objects.requireNonNull(catalogManager.getCurrentCatalog()), + Objects.requireNonNull(catalogManager.getCurrentDatabase()), + modelName); + catalogManager.createModel( + CatalogModel.of(inputSchema, outputSchema, modelOptions, "This is a new model."), + modelIdentifier, + false); + } + + private static class TestDispatcher extends Dispatcher { + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final AtomicInteger INVOKE_COUNT = new AtomicInteger(0); + + @Override + public MockResponse dispatch(RecordedRequest request) { + INVOKE_COUNT.incrementAndGet(); + + String path = 
request.getRequestUrl().encodedPath(); + String body = request.getBody().readUtf8(); + + if (!path.endsWith("/chat/completions")) { + return new MockResponse().setResponseCode(404); + } + + try { + JsonNode root = OBJECT_MAPPER.readTree(body); + + // The messages contain one system prompt and one user message. + Preconditions.checkArgument(root.get("messages").size() == 2); + if (root.get("messages") + .get(1) + .get("content") + .toString() + .contains(RETRYABLE_INPUT_DATA)) { + return new MockResponse().setResponseCode(429); + } + + String responseBody = + "{" + + " \"id\": \"chatcmpl-1234567890ABCD\"," + + " \"object\": \"chat.completion\"," + + " \"created\": 1717029203," + + " \"model\": \"gpt-3.5-turbo-0125\"," + + " \"choices\": [{" + + " \"index\": 0," + + " \"message\": {" + + " \"role\": \"assistant\"," + + " \"content\": \"This is a mocked response\"" + + " }," + + " \"finish_reason\": \"stop\"" + + " }]," + + " \"usage\": {" + + " \"prompt_tokens\": 9," + + " \"completion_tokens\": 16," + + " \"total_tokens\": 25" + + " }" + + "}"; + + return new MockResponse() + .setHeader("Content-Type", "application/json") + .setBody(responseBody); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } +} diff --git a/flink-models/flink-model-openai/src/test/resources/log4j2-test.properties b/flink-models/flink-model-openai/src/test/resources/log4j2-test.properties new file mode 100644 index 00000000000..835c2ec9a3d --- /dev/null +++ b/flink-models/flink-model-openai/src/test/resources/log4j2-test.properties @@ -0,0 +1,28 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +# Set root logger level to OFF to not flood build logs +# set manually to INFO for debugging purposes +rootLogger.level = OFF +rootLogger.appenderRef.test.ref = TestLogger + +appender.testlogger.name = TestLogger +appender.testlogger.type = CONSOLE +appender.testlogger.target = SYSTEM_ERR +appender.testlogger.layout.type = PatternLayout +appender.testlogger.layout.pattern = %-4r [%t] %-5p %c %x - %m%n
