This is an automated email from the ASF dual-hosted git repository. sxnan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit c3bf06141bcbdeaccc14c06757f8e8fe453de111 Author: Yunfeng Zhou <[email protected]> AuthorDate: Tue Oct 28 18:49:51 2025 +0800 [FLINK-38581][model] Support surfacing error message This closes #27163 --- docs/content.zh/docs/connectors/models/openai.md | 10 ++ docs/content/docs/connectors/models/openai.md | 14 ++ .../model/openai/AbstractOpenAIModelFunction.java | 158 +++++++++++++++++++-- .../model/openai/OpenAIChatModelFunction.java | 29 +++- .../model/openai/OpenAIEmbeddingModelFunction.java | 33 ++++- .../ModelFunctionErrorHandlingStrategyTest.java | 37 +++++ 6 files changed, 257 insertions(+), 24 deletions(-) diff --git a/docs/content.zh/docs/connectors/models/openai.md b/docs/content.zh/docs/connectors/models/openai.md index ec6da302015..f1ab2ff0337 100644 --- a/docs/content.zh/docs/connectors/models/openai.md +++ b/docs/content.zh/docs/connectors/models/openai.md @@ -115,3 +115,13 @@ FROM ML_PREDICT( </tr> </tbody> </table> + +### 可用元数据 + +当配置 `error-handling-strategy` 为 `ignore` 时,您可以选择额外指定以下元数据列,将故障信息展示到您的输出流中。 + +* error-string(STRING):与错误相关的消息 +* http-status-code(INT):HTTP状态码 +* http-headers-map(MAP<STRING, ARRAY<STRING>>):响应返回的头部信息 + +如果您在Output Schema中定义了这些元数据列,但调用未失败,则这些列将填充为null值。 diff --git a/docs/content/docs/connectors/models/openai.md b/docs/content/docs/connectors/models/openai.md index 58c2842c7b3..b127d9928bd 100644 --- a/docs/content/docs/connectors/models/openai.md +++ b/docs/content/docs/connectors/models/openai.md @@ -94,6 +94,8 @@ FROM ML_PREDICT( ## Schema Requirement +The following table lists the schema requirement for each task. + <table class="table table-bordered"> <thead> <tr> @@ -115,3 +117,15 @@ FROM ML_PREDICT( </tr> </tbody> </table> + +### Available Metadata + +When configuring `error-handling-strategy` as `ignore`, you can choose to additionally specify the +following metadata columns to surface information about failures into your stream. 
+ +* error-string(STRING): A message associated with the error +* http-status-code(INT): The HTTP status code +* http-headers-map(MAP<STRING, ARRAY<STRING>>): The headers returned with the response + +If you define these metadata columns in the output schema but the call succeeds, the columns +will be filled with null values. diff --git a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/AbstractOpenAIModelFunction.java b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/AbstractOpenAIModelFunction.java index 22ca44aa042..193da8a72e7 100644 --- a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/AbstractOpenAIModelFunction.java +++ b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/AbstractOpenAIModelFunction.java @@ -21,25 +21,41 @@ package org.apache.flink.model.openai; import org.apache.flink.configuration.DescribedEnum; import org.apache.flink.configuration.ReadableConfig; import org.apache.flink.configuration.description.InlineElement; +import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.catalog.Column; import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.GenericArrayData; +import org.apache.flink.table.data.GenericMapData; +import org.apache.flink.table.data.GenericRowData; import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.data.binary.BinaryStringData; import org.apache.flink.table.factories.ModelProviderFactory; import org.apache.flink.table.functions.AsyncPredictFunction; import org.apache.flink.table.functions.FunctionContext; +import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.VarCharType; +import org.apache.flink.util.ExceptionUtils; +import 
org.apache.flink.util.Preconditions; import com.openai.client.OpenAIClientAsync; +import com.openai.core.http.Headers; +import com.openai.errors.OpenAIServiceException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.annotation.Nullable; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.function.Function; import java.util.stream.Collectors; import static org.apache.flink.configuration.description.TextElement.text; @@ -58,6 +74,7 @@ public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction { private final String model; @Nullable private final Integer maxContextSize; private final ContextOverflowAction contextOverflowAction; + protected final List<String> outputColumnNames; public AbstractOpenAIModelFunction( ModelProviderFactory.Context factoryContext, ReadableConfig config) { @@ -79,6 +96,9 @@ public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction { factoryContext.getCatalogModel().getResolvedInputSchema(), new VarCharType(VarCharType.MAX_LENGTH), "input"); + + this.outputColumnNames = + factoryContext.getCatalogModel().getResolvedOutputSchema().getColumnNames(); } @Override @@ -123,23 +143,19 @@ public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction { protected void validateSingleColumnSchema( ResolvedSchema schema, LogicalType expectedType, String inputOrOutput) { List<Column> columns = schema.getColumns(); - if (columns.size() != 1) { - throw new IllegalArgumentException( - String.format( - "Model should have exactly one %s column, but actually has %s columns: %s", - inputOrOutput, - columns.size(), - columns.stream().map(Column::getName).collect(Collectors.toList()))); - } - - Column column = columns.get(0); - if (!column.isPhysical()) { + List<String> physicalColumnNames = + columns.stream() + 
.filter(Column::isPhysical) + .map(Column::getName) + .collect(Collectors.toList()); + if (physicalColumnNames.size() != 1) { throw new IllegalArgumentException( String.format( - "%s column %s should be a physical column, but is a %s.", - inputOrOutput, column.getName(), column.getClass())); + "Model should have exactly one %s physical column, but actually has %s physical columns: %s", + inputOrOutput, physicalColumnNames.size(), physicalColumnNames)); } + Column column = schema.getColumn(physicalColumnNames.get(0)).get(); if (!expectedType.equals(column.getDataType().getLogicalType())) { throw new IllegalArgumentException( String.format( @@ -149,6 +165,33 @@ public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction { expectedType, column.getDataType().getLogicalType())); } + + List<Column> metadataColumns = + columns.stream() + .filter(x -> x instanceof Column.MetadataColumn) + .collect(Collectors.toList()); + if (!metadataColumns.isEmpty()) { + Preconditions.checkArgument( + "output".equals(inputOrOutput), "Only output schema supports metadata column"); + + for (Column metadataColumn : metadataColumns) { + ErrorMessageMetadata errorMessageMetadata = + ErrorMessageMetadata.get(metadataColumn.getName()); + Preconditions.checkNotNull( + errorMessageMetadata, + String.format( + "Unexpected metadata column %s. 
Supported metadata columns:\n%s", + metadataColumn.getName(), + ErrorMessageMetadata.getAllKeysAndDescriptions())); + Preconditions.checkArgument( + errorMessageMetadata.dataType.equals(metadataColumn.getDataType()), + String.format( + "Expected metadata column %s to be of type %s, but is of type %s", + metadataColumn.getName(), + errorMessageMetadata.dataType, + metadataColumn.getDataType())); + } + } } protected Collection<RowData> handleErrorsAndRespond(Throwable t) { @@ -160,7 +203,20 @@ public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction { if (finalErrorHandlingStrategy == ErrorHandlingStrategy.FAILOVER) { throw new RuntimeException(t); } else if (finalErrorHandlingStrategy == ErrorHandlingStrategy.IGNORE) { - return Collections.emptyList(); + LOG.warn( + "The input row data failed to acquire a valid response. Ignoring the input.", + t); + GenericRowData rowData = new GenericRowData(this.outputColumnNames.size()); + boolean isMetadataSet = false; + for (int i = 0; i < this.outputColumnNames.size(); i++) { + String columnName = this.outputColumnNames.get(i); + ErrorMessageMetadata errorMessageMetadata = ErrorMessageMetadata.get(columnName); + if (errorMessageMetadata != null) { + rowData.setField(i, errorMessageMetadata.converter.apply(t)); + isMetadataSet = true; + } + } + return isMetadataSet ? Collections.singletonList(rowData) : Collections.emptyList(); } else { throw new UnsupportedOperationException( "Unsupported error handling strategy: " + finalErrorHandlingStrategy); @@ -204,4 +260,78 @@ public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction { return text(strategy.description); } } + + /** + * Metadata that can be read from the output row about error messages. Referenced from Flink + * HTTP Connector's ReadableMetadata. 
+ */ + protected enum ErrorMessageMetadata { + ERROR_STRING( + "error-string", + DataTypes.STRING(), + x -> BinaryStringData.fromString(x.getMessage()), + "A message associated with the error"), + HTTP_STATUS_CODE( + "http-status-code", + DataTypes.INT(), + e -> + ExceptionUtils.findThrowable(e, OpenAIServiceException.class) + .map(OpenAIServiceException::statusCode) + .orElse(null), + "The HTTP status code"), + HTTP_HEADERS_MAP( + "http-headers-map", + DataTypes.MAP(DataTypes.STRING(), DataTypes.ARRAY(DataTypes.STRING())), + e -> + ExceptionUtils.findThrowable(e, OpenAIServiceException.class) + .map( + e1 -> { + Map<StringData, ArrayData> map = new HashMap<>(); + Headers headers = e1.headers(); + for (String name : headers.names()) { + map.put( + BinaryStringData.fromString(name), + new GenericArrayData( + headers.values(name).stream() + .map( + BinaryStringData + ::fromString) + .toArray())); + } + return new GenericMapData(map); + }) + .orElse(null), + "The headers returned with the response"); + + final String key; + final DataType dataType; + final Function<Throwable, Object> converter; + final String description; + + ErrorMessageMetadata( + String key, + DataType dataType, + Function<Throwable, Object> converter, + String description) { + this.key = key; + this.dataType = dataType; + this.converter = converter; + this.description = description; + } + + static @Nullable ErrorMessageMetadata get(String key) { + for (ErrorMessageMetadata value : values()) { + if (value.key.equals(key)) { + return value; + } + } + return null; + } + + static String getAllKeysAndDescriptions() { + return Arrays.stream(values()) + .map(value -> value.key + ":\t" + value.description) + .collect(Collectors.joining("\n")); + } + } } diff --git a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java index 2a2d3cb0ae4..ce426f08dd6 
100644 --- a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java +++ b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java @@ -48,6 +48,7 @@ public class OpenAIChatModelFunction extends AbstractOpenAIModelFunction { private final String model; private final String systemPrompt; private final Configuration config; + private final int outputColumnIndex; public OpenAIChatModelFunction( ModelProviderFactory.Context factoryContext, ReadableConfig config) { @@ -59,6 +60,21 @@ public class OpenAIChatModelFunction extends AbstractOpenAIModelFunction { factoryContext.getCatalogModel().getResolvedOutputSchema(), new VarCharType(VarCharType.MAX_LENGTH), "output"); + this.outputColumnIndex = getOutputColumnIndex(); + } + + private int getOutputColumnIndex() { + for (int i = 0; i < this.outputColumnNames.size(); i++) { + String columnName = this.outputColumnNames.get(i); + if (ErrorMessageMetadata.get(columnName) == null) { + // Prior checks have guaranteed that there is one and only one physical output + // column. + return i; + } + } + throw new IllegalArgumentException( + "There should be one and only one physical output column. 
Actual columns: " + + this.outputColumnNames); } @Override @@ -97,10 +113,15 @@ public class OpenAIChatModelFunction extends AbstractOpenAIModelFunction { return chatCompletion.choices().stream() .map( - choice -> - GenericRowData.of( - BinaryStringData.fromString( - choice.message().content().orElse("")))) + choice -> { + GenericRowData rowData = + new GenericRowData(this.outputColumnNames.size()); + rowData.setField( + this.outputColumnIndex, + BinaryStringData.fromString( + choice.message().content().orElse(""))); + return rowData; + }) .collect(Collectors.toList()); } diff --git a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIEmbeddingModelFunction.java b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIEmbeddingModelFunction.java index fd0e78b9ace..dbe47fa69af 100644 --- a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIEmbeddingModelFunction.java +++ b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIEmbeddingModelFunction.java @@ -44,6 +44,7 @@ public class OpenAIEmbeddingModelFunction extends AbstractOpenAIModelFunction { private final String model; @Nullable private final Long dimensions; + private final int outputColumnIndex; public OpenAIEmbeddingModelFunction( ModelProviderFactory.Context factoryContext, ReadableConfig config) { @@ -55,6 +56,21 @@ public class OpenAIEmbeddingModelFunction extends AbstractOpenAIModelFunction { factoryContext.getCatalogModel().getResolvedOutputSchema(), new ArrayType(new FloatType()), "output"); + this.outputColumnIndex = getOutputColumnIndex(); + } + + private int getOutputColumnIndex() { + for (int i = 0; i < this.outputColumnNames.size(); i++) { + String columnName = this.outputColumnNames.get(i); + if (ErrorMessageMetadata.get(columnName) == null) { + // Prior checks have guaranteed that there is one and only one physical output + // column. 
+ return i; + } + } + throw new IllegalArgumentException( + "There should be one and only one physical output column. Actual columns: " + + this.outputColumnNames); } @Override @@ -83,12 +99,17 @@ public class OpenAIEmbeddingModelFunction extends AbstractOpenAIModelFunction { return response.data().stream() .map( - embedding -> - GenericRowData.of( - new GenericArrayData( - embedding.embedding().stream() - .map(Double::floatValue) - .toArray(Float[]::new)))) + embedding -> { + GenericRowData rowData = + new GenericRowData(this.outputColumnNames.size()); + rowData.setField( + outputColumnIndex, + new GenericArrayData( + embedding.embedding().stream() + .map(Double::floatValue) + .toArray(Float[]::new))); + return rowData; + }) .collect(Collectors.toList()); } } diff --git a/flink-models/flink-model-openai/src/test/java/org/apache/flink/model/openai/ModelFunctionErrorHandlingStrategyTest.java b/flink-models/flink-model-openai/src/test/java/org/apache/flink/model/openai/ModelFunctionErrorHandlingStrategyTest.java index f7d293a201a..066048fa2e3 100644 --- a/flink-models/flink-model-openai/src/test/java/org/apache/flink/model/openai/ModelFunctionErrorHandlingStrategyTest.java +++ b/flink-models/flink-model-openai/src/test/java/org/apache/flink/model/openai/ModelFunctionErrorHandlingStrategyTest.java @@ -171,6 +171,43 @@ public class ModelFunctionErrorHandlingStrategyTest { assertThat(result).isEmpty(); } + @Test + public void testIgnoreAndSurfaceError() { + modelOptions.put("error-handling-strategy", "ignore"); + Schema outputSchemaWithErrorMessage = + Schema.newBuilder() + .columnByMetadata("error-string", DataTypes.STRING()) + .column("content", DataTypes.STRING()) + .columnByMetadata("http-status-code", DataTypes.INT()) + .columnByMetadata( + "http-headers-map", + DataTypes.MAP( + DataTypes.STRING(), DataTypes.ARRAY(DataTypes.STRING()))) + .build(); + + createModel(INPUT_SCHEMA, outputSchemaWithErrorMessage); + TableResult tableResult = + tEnv.executeSql( + 
String.format( + "WITH v(input) AS (SELECT * FROM (VALUES ('%s'))) " + + "SELECT * FROM ML_PREDICT( " + + " TABLE v, " + + " MODEL `%s`, " + + " DESCRIPTOR(`input`) " + + ")", + RETRYABLE_INPUT_DATA, modelName)); + List<Row> result = IteratorUtils.toList(tableResult.collect()); + assertThat(result).hasSize(1); + assertThat(result.get(0).getArity()).isEqualTo(5); + assertThat((String) result.get(0).getFieldAs(0)).isEqualTo(RETRYABLE_INPUT_DATA); + assertThat((String) result.get(0).getFieldAs(1)) + .isEqualTo("com.openai.errors.RateLimitException: 429: null"); + assertThat(result.get(0).getField(2)).isNull(); + assertThat((Integer) result.get(0).getFieldAs(3)).isEqualTo(429); + assertThat((Map<String, String[]>) result.get(0).getFieldAs(4)) + .containsEntry("Content-Length", new String[] {"0"}); + } + @Test public void testFailoverStrategy() { modelOptions.put("error-handling-strategy", "failover");
