This is an automated email from the ASF dual-hosted git repository.
sxnan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new e7639297bec [FLINK-38572][model] Support controlling randomness and
format of Chat Model Function (#27155)
e7639297bec is described below
commit e7639297becce21e5300ddd04a8bbd38d6fb6518
Author: yunfengzhou-hub <[email protected]>
AuthorDate: Tue Oct 28 10:09:43 2025 +0800
[FLINK-38572][model] Support controlling randomness and format of Chat
Model Function (#27155)
---
docs/content.zh/docs/connectors/models/openai.md | 36 ++++
docs/content/docs/connectors/models/openai.md | 36 ++++
.../model/openai/OpenAIChatModelFunction.java | 108 +++++++++---
.../model/openai/OpenAIModelProviderFactory.java | 4 +
.../flink/model/openai/OpenAIChatModelTest.java | 191 ++++++++++++---------
5 files changed, 265 insertions(+), 110 deletions(-)
diff --git a/docs/content.zh/docs/connectors/models/openai.md
b/docs/content.zh/docs/connectors/models/openai.md
index 6b72036b34a..1632518508e 100644
--- a/docs/content.zh/docs/connectors/models/openai.md
+++ b/docs/content.zh/docs/connectors/models/openai.md
@@ -218,6 +218,42 @@ FROM ML_PREDICT(
<td>Long</td>
<td>生成的最大token数。参考<a
href="https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens">max
tokens</a></td>
</tr>
+ <tr>
+ <td>
+ <h5>presence-penalty</h5>
+ </td>
+ <td>可选</td>
+ <td style="word-wrap: break-word;">(none)</td>
+ <td>Double</td>
+
<td>数值范围为-2.0到2.0之间。正值会根据新token是否出现在当前文本中对其进行惩罚,从而增加模型讨论新话题的可能性。</td>
+ </tr>
+ <tr>
+ <td>
+ <h5>n</h5>
+ </td>
+ <td>可选</td>
+ <td style="word-wrap: break-word;">(none)</td>
+ <td>Long</td>
+
<td>为每个输入消息生成的聊天完成选项数量。请注意,您将根据所有选项生成的token数量进行收费。为最小化成本,需将n保持为1。</td>
+ </tr>
+ <tr>
+ <td>
+ <h5>seed</h5>
+ </td>
+ <td>可选</td>
+ <td style="word-wrap: break-word;">(none)</td>
+ <td>Long</td>
+
<td>如果指定,模型平台将尽最大努力进行确定性采样,使得使用相同种子和参数的重复请求应返回相同的结果。但不保证结果一定是确定的。</td>
+ </tr>
+ <tr>
+ <td>
+ <h5>response-format</h5>
+ </td>
+ <td>可选</td>
+ <td style="word-wrap: break-word;">(none)</td>
+ <td>Enum</td>
+ <td>响应的格式,可选值为 'text' 或 'json_object'。使用 'json_object' 时,还需通过提示词(例如 'system-prompt')明确要求模型输出JSON。</td>
+ </tr>
</tbody>
</table>
diff --git a/docs/content/docs/connectors/models/openai.md
b/docs/content/docs/connectors/models/openai.md
index 6398eb599cc..350d21387cb 100644
--- a/docs/content/docs/connectors/models/openai.md
+++ b/docs/content/docs/connectors/models/openai.md
@@ -218,6 +218,42 @@ FROM ML_PREDICT(
<td>Long</td>
<td>Maximum number of tokens to generate. See <a
href="https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens">max
tokens</a></td>
</tr>
+ <tr>
+ <td>
+ <h5>presence-penalty</h5>
+ </td>
+ <td>optional</td>
+ <td style="word-wrap: break-word;">(none)</td>
+ <td>Double</td>
+ <td>Number between -2.0 and 2.0. Positive values penalize new
tokens based on whether they appear in the text so far, increasing the model's
likelihood to talk about new topics.</td>
+ </tr>
+ <tr>
+ <td>
+ <h5>n</h5>
+ </td>
+ <td>optional</td>
+ <td style="word-wrap: break-word;">(none)</td>
+ <td>Long</td>
+ <td>How many chat completion choices to generate for each input
message. Note that you will be charged based on the number of generated tokens
across all of the choices. Keep n as 1 to minimize costs.</td>
+ </tr>
+ <tr>
+ <td>
+ <h5>seed</h5>
+ </td>
+ <td>optional</td>
+ <td style="word-wrap: break-word;">(none)</td>
+ <td>Long</td>
+ <td>If specified, the model platform will make a best effort to
sample deterministically, such that repeated requests with the same seed and
parameters should return the same result. Determinism is not guaranteed.</td>
+ </tr>
+ <tr>
+ <td>
+ <h5>response-format</h5>
+ </td>
+ <td>optional</td>
+ <td style="word-wrap: break-word;">(none)</td>
+ <td>Enum</td>
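+ <td>The format of the response. Supported values are 'text' and 'json_object'. When using 'json_object', the prompt (e.g. via 'system-prompt') must also instruct the model to produce JSON.</td>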
+ </tr>
</tbody>
</table>
diff --git
a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java
b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java
index e931d77d303..c4df81ccb88 100644
---
a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java
+++
b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java
@@ -19,6 +19,7 @@ package org.apache.flink.model.openai;
import org.apache.flink.configuration.ConfigOption;
import org.apache.flink.configuration.ConfigOptions;
+import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
@@ -27,10 +28,11 @@ import
org.apache.flink.table.factories.ModelProviderFactory;
import org.apache.flink.table.functions.AsyncPredictFunction;
import org.apache.flink.table.types.logical.VarCharType;
+import com.openai.models.ResponseFormatJsonObject;
+import com.openai.models.ResponseFormatText;
import com.openai.models.chat.completions.ChatCompletion;
import com.openai.models.chat.completions.ChatCompletionCreateParams;
-
-import javax.annotation.Nullable;
+import
com.openai.models.chat.completions.ChatCompletionCreateParams.ResponseFormat;
import java.util.Arrays;
import java.util.Collection;
@@ -77,25 +79,49 @@ public class OpenAIChatModelFunction extends
AbstractOpenAIModelFunction {
.noDefaultValue()
.withDescription("Maximum number of tokens to generate.");
+ public static final ConfigOption<Double> PRESENCE_PENALTY =
+ ConfigOptions.key("presence-penalty")
+ .doubleType()
+ .noDefaultValue()
+ .withDescription(
+ "Number between -2.0 and 2.0."
+ + " Positive values penalize new tokens
based on whether they appear in the text so far,"
+ + " increasing the model's likelihood to
talk about new topics.");
+
+ public static final ConfigOption<Long> N =
+ ConfigOptions.key("n")
+ .longType()
+ .noDefaultValue()
+ .withDescription(
+ "How many chat completion choices to generate for
each input message."
+ + " Note that you will be charged based on
the number of generated tokens across all of the choices."
+ + " Keep n as 1 to minimize costs.");
+
+ public static final ConfigOption<Long> SEED =
+ ConfigOptions.key("seed")
+ .longType()
+ .noDefaultValue()
+ .withDescription(
+ "If specified, the model platform will make a best
effort to sample deterministically,"
+ + " such that repeated requests with the
same seed and parameters should return the same result."
+ + " Determinism is not guaranteed.");
+
+ public static final ConfigOption<ChatModelResponseFormat> RESPONSE_FORMAT =
+ ConfigOptions.key("response-format")
+ .enumType(ChatModelResponseFormat.class)
+ .noDefaultValue()
+ .withDescription("The format of the response, e.g., 'text'
or 'json_object'.");
+
private final String model;
private final String systemPrompt;
- @Nullable private final Double temperature;
- @Nullable private final Double topP;
- @Nullable private final List<String> stop;
- @Nullable private final Long maxTokens;
+ private final Configuration config;
public OpenAIChatModelFunction(
ModelProviderFactory.Context factoryContext, ReadableConfig
config) {
super(factoryContext, config);
model = config.get(MODEL);
systemPrompt = config.get(SYSTEM_PROMPT);
- temperature = config.get(TEMPERATURE);
- topP = config.get(TOP_P);
- stop =
- config.get(STOP) == null
- ? null
- :
Arrays.asList(config.get(STOP).split(STOP_SEPARATOR));
- maxTokens = config.get(MAX_TOKENS);
+ this.config = Configuration.fromMap(config.toMap());
validateSingleColumnSchema(
factoryContext.getCatalogModel().getResolvedOutputSchema(),
new VarCharType(VarCharType.MAX_LENGTH),
@@ -114,18 +140,18 @@ public class OpenAIChatModelFunction extends
AbstractOpenAIModelFunction {
.addSystemMessage(systemPrompt)
.addUserMessage(input)
.model(model);
- if (temperature != null) {
- builder.temperature(temperature);
- }
- if (topP != null) {
- builder.topP(topP);
- }
- if (stop != null) {
- builder.stopOfStrings(stop);
- }
- if (maxTokens != null) {
- builder.maxTokens(maxTokens);
- }
+ this.config.getOptional(TEMPERATURE).ifPresent(builder::temperature);
+ this.config.getOptional(TOP_P).ifPresent(builder::topP);
+ this.config
+ .getOptional(STOP)
+ .ifPresent(x ->
builder.stopOfStrings(Arrays.asList(x.split(STOP_SEPARATOR))));
+ this.config.getOptional(MAX_TOKENS).ifPresent(builder::maxTokens);
+
this.config.getOptional(PRESENCE_PENALTY).ifPresent(builder::presencePenalty);
+ this.config.getOptional(N).ifPresent(builder::n);
+ this.config.getOptional(SEED).ifPresent(builder::seed);
+ this.config
+ .getOptional(RESPONSE_FORMAT)
+ .ifPresent(x -> builder.responseFormat(x.getResponseFormat()));
return client.chat()
.completions()
@@ -142,4 +168,36 @@ public class OpenAIChatModelFunction extends
AbstractOpenAIModelFunction {
choice.message().content().orElse(""))))
.collect(Collectors.toList());
}
+
+ /**
+ * Response formats supported by the chat model function's 'response-format' option.
+ *
+ * <p>This is an enum representation of the OpenAI SDK {@link ResponseFormat}: each constant
+ * builds the matching SDK object via {@link #getResponseFormat()}, which is then set on the
+ * chat completion request.
+ */
+ public enum ChatModelResponseFormat {
+ /** Maps to {@link ResponseFormatText}. */
+ TEXT("text") {
+ @Override
+ public ResponseFormat getResponseFormat() {
+ return
ResponseFormat.ofText(ResponseFormatText.builder().build());
+ }
+ },
+ /** Maps to {@link ResponseFormatJsonObject}. */
+ JSON_OBJECT("json_object") {
+ @Override
+ public ResponseFormat getResponseFormat() {
+ return
ResponseFormat.ofJsonObject(ResponseFormatJsonObject.builder().build());
+ }
+ };
+
+ /** Lower-case option value (e.g. "json_object"), returned by {@link #toString()}. */
+ private final String value;
+
+ ChatModelResponseFormat(String value) {
+ this.value = value;
+ }
+
+ /** Returns a newly built OpenAI SDK response format corresponding to this constant. */
+ public abstract ResponseFormat getResponseFormat();
+
+ // Exposes the lower-case option value instead of the constant name.
+ // NOTE(review): presumably so Flink's enum option parsing and generated docs use
+ // 'text' / 'json_object' — confirm against Flink's enum option conversion.
+ @Override
+ public String toString() {
+ return value;
+ }
+ }
}
diff --git
a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIModelProviderFactory.java
b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIModelProviderFactory.java
index e81ea6a17d6..6f573639ae2 100644
---
a/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIModelProviderFactory.java
+++
b/flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIModelProviderFactory.java
@@ -74,6 +74,10 @@ public class OpenAIModelProviderFactory implements
ModelProviderFactory {
set.add(OpenAIChatModelFunction.TOP_P);
set.add(OpenAIChatModelFunction.STOP);
set.add(OpenAIChatModelFunction.MAX_TOKENS);
+ set.add(OpenAIChatModelFunction.PRESENCE_PENALTY);
+ set.add(OpenAIChatModelFunction.N);
+ set.add(OpenAIChatModelFunction.SEED);
+ set.add(OpenAIChatModelFunction.RESPONSE_FORMAT);
set.add(OpenAIEmbeddingModelFunction.DIMENSION);
return set;
}
diff --git
a/flink-models/flink-model-openai/src/test/java/org/apache/flink/model/openai/OpenAIChatModelTest.java
b/flink-models/flink-model-openai/src/test/java/org/apache/flink/model/openai/OpenAIChatModelTest.java
index 7d3bda11e72..835ac2e7beb 100644
---
a/flink-models/flink-model-openai/src/test/java/org/apache/flink/model/openai/OpenAIChatModelTest.java
+++
b/flink-models/flink-model-openai/src/test/java/org/apache/flink/model/openai/OpenAIChatModelTest.java
@@ -27,6 +27,7 @@ import org.apache.flink.table.catalog.CatalogManager;
import org.apache.flink.table.catalog.CatalogModel;
import org.apache.flink.table.catalog.ObjectIdentifier;
import org.apache.flink.types.Row;
+import org.apache.flink.types.variant.Variant;
import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode;
import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
@@ -47,6 +48,7 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -100,15 +102,7 @@ public class OpenAIChatModelTest {
@Test
public void testChat() {
- CatalogManager catalogManager = ((TableEnvironmentImpl)
tEnv).getCatalogManager();
- catalogManager.createModel(
- CatalogModel.of(INPUT_SCHEMA, OUTPUT_SCHEMA, modelOptions,
"This is a new model."),
- ObjectIdentifier.of(
- catalogManager.getCurrentCatalog(),
- catalogManager.getCurrentDatabase(),
- MODEL_NAME),
- false);
-
+ createModel();
TableResult tableResult =
tEnv.executeSql(
String.format(
@@ -128,16 +122,8 @@ public class OpenAIChatModelTest {
@Test
public void testMaxToken() {
int maxTokens = 20;
- CatalogManager catalogManager = ((TableEnvironmentImpl)
tEnv).getCatalogManager();
- Map<String, String> modelOptions = new HashMap<>(this.modelOptions);
modelOptions.put("max-tokens", Integer.toString(maxTokens));
- catalogManager.createModel(
- CatalogModel.of(INPUT_SCHEMA, OUTPUT_SCHEMA, modelOptions,
"This is a new model."),
- ObjectIdentifier.of(
- catalogManager.getCurrentCatalog(),
- catalogManager.getCurrentDatabase(),
- MODEL_NAME),
- false);
+ createModel();
TableResult tableResult =
tEnv.executeSql(
@@ -159,16 +145,8 @@ public class OpenAIChatModelTest {
@Test
public void testStop() {
String stop = "a,the";
- CatalogManager catalogManager = ((TableEnvironmentImpl)
tEnv).getCatalogManager();
- Map<String, String> modelOptions = new HashMap<>(this.modelOptions);
modelOptions.put("stop", stop);
- catalogManager.createModel(
- CatalogModel.of(INPUT_SCHEMA, OUTPUT_SCHEMA, modelOptions,
"This is a new model."),
- ObjectIdentifier.of(
- catalogManager.getCurrentCatalog(),
- catalogManager.getCurrentDatabase(),
- MODEL_NAME),
- false);
+ createModel();
TableResult tableResult =
tEnv.executeSql(
@@ -189,18 +167,10 @@ public class OpenAIChatModelTest {
@Test
public void testMaxContextSize() {
- CatalogManager catalogManager = ((TableEnvironmentImpl)
tEnv).getCatalogManager();
- Map<String, String> modelOptions = new HashMap<>(this.modelOptions);
modelOptions.put("model", "gpt-4");
modelOptions.put("max-context-size", "2");
modelOptions.put("context-overflow-action", "skipped");
- catalogManager.createModel(
- CatalogModel.of(INPUT_SCHEMA, OUTPUT_SCHEMA, modelOptions,
"This is a new model."),
- ObjectIdentifier.of(
- catalogManager.getCurrentCatalog(),
- catalogManager.getCurrentDatabase(),
- MODEL_NAME),
- false);
+ createModel();
TableResult tableResult =
tEnv.executeSql(
@@ -211,21 +181,62 @@ public class OpenAIChatModelTest {
assertThat(result).isEmpty();
}
+ @Test
+ public void testN() {
+ modelOptions.put("n", "2");
+ modelOptions.put("model", "qwen-plus");
+ createModel();
+ TableResult tableResult =
+ tEnv.executeSql(
+ String.format(
+ "SELECT input, content FROM ML_PREDICT(TABLE
MyTable, MODEL %s, DESCRIPTOR(`input`))",
+ MODEL_NAME));
+ List<Row> result = IteratorUtils.toList(tableResult.collect());
+ assertThat(result).hasSize(20);
+ for (Row row : result) {
+ assertThat(row.getField(0)).isInstanceOf(String.class);
+ assertThat(row.getField(1)).isInstanceOf(String.class);
+ assertThat((String) row.getFieldAs(1))
+ .isEqualTo(
+ "This is a mocked response continuation
continuation continuation continuation continuation continuation continuation
continuation continuation continuation");
+ }
+ }
+
+ @Test
+ public void testResponseFormat() {
+ modelOptions.put("response-format", "json_object");
+ modelOptions.put(
+ "system-prompt",
+ "You are a helpful assistant. Please output your response in
json format.");
+ createModel();
+ TableResult tableResult =
+ tEnv.executeSql(
+ String.format(
+ "SELECT input, content,
TRY_PARSE_JSON(content) as content_json FROM ML_PREDICT(TABLE MyTable, MODEL
%s, DESCRIPTOR(`input`))",
+ MODEL_NAME));
+ List<Row> result = IteratorUtils.toList(tableResult.collect());
+ assertThat(result).hasSize(10);
+ for (Row row : result) {
+ assertThat(row.getField(0)).isInstanceOf(String.class);
+ assertThat(row.getField(1)).isInstanceOf(String.class);
+ String content = row.getFieldAs(1);
+ assertThat(content).isNotEmpty();
+ assertThat(row.getField(2))
+ .withFailMessage("%s is not a valid json object", content)
+ .isInstanceOf(Variant.class);
+ assertThat((Variant) row.getFieldAs(2))
+ .withFailMessage("%s is not a valid json object", content)
+ .isNotNull();
+ }
+ }
+
@Test
public void testNullValue() {
tEnv.executeSql(
"CREATE TABLE MyTableWithNull(input STRING, invalid_input
DOUBLE) "
+ "WITH ( 'connector' = 'datagen', 'number-of-rows' =
'10', 'fields.input.null-rate' = '1')");
- CatalogManager catalogManager = ((TableEnvironmentImpl)
tEnv).getCatalogManager();
- catalogManager.createModel(
- CatalogModel.of(INPUT_SCHEMA, OUTPUT_SCHEMA, modelOptions,
"This is a new model."),
- ObjectIdentifier.of(
- catalogManager.getCurrentCatalog(),
- catalogManager.getCurrentDatabase(),
- MODEL_NAME),
- false);
-
+ createModel();
TableResult tableResult =
tEnv.executeSql(
String.format(
@@ -237,24 +248,10 @@ public class OpenAIChatModelTest {
@Test
public void testInvalidInputSchema() {
- CatalogManager catalogManager = ((TableEnvironmentImpl)
tEnv).getCatalogManager();
- ObjectIdentifier modelIdentifier =
- ObjectIdentifier.of(
- catalogManager.getCurrentCatalog(),
- catalogManager.getCurrentDatabase(),
- MODEL_NAME);
-
Schema inputSchemaWithInvalidColumnType =
Schema.newBuilder().column("input",
DataTypes.DOUBLE()).build();
- catalogManager.createModel(
- CatalogModel.of(
- inputSchemaWithInvalidColumnType,
- OUTPUT_SCHEMA,
- modelOptions,
- "This is a new model."),
- modelIdentifier,
- false);
+ createModel(inputSchemaWithInvalidColumnType, OUTPUT_SCHEMA);
assertThatThrownBy(
() ->
tEnv.executeSql(
@@ -268,24 +265,10 @@ public class OpenAIChatModelTest {
@Test
public void testInvalidOutputSchema() {
- CatalogManager catalogManager = ((TableEnvironmentImpl)
tEnv).getCatalogManager();
- ObjectIdentifier modelIdentifier =
- ObjectIdentifier.of(
- catalogManager.getCurrentCatalog(),
- catalogManager.getCurrentDatabase(),
- MODEL_NAME);
-
Schema outputSchemaWithInvalidColumnType =
Schema.newBuilder().column("output",
DataTypes.DOUBLE()).build();
- catalogManager.createModel(
- CatalogModel.of(
- INPUT_SCHEMA,
- outputSchemaWithInvalidColumnType,
- modelOptions,
- "This is a new model."),
- modelIdentifier,
- false);
+ createModel(INPUT_SCHEMA, outputSchemaWithInvalidColumnType);
assertThatThrownBy(
() ->
tEnv.executeSql(
@@ -297,6 +280,23 @@ public class OpenAIChatModelTest {
.hasMessageContainingAll("output", "DOUBLE", "STRING");
}
+ /** Registers the test model using the default {@code INPUT_SCHEMA} and {@code OUTPUT_SCHEMA}. */
+ private void createModel() {
+ createModel(INPUT_SCHEMA, OUTPUT_SCHEMA);
+ }
+
+ /**
+ * Registers {@code MODEL_NAME} in the current catalog and database with the given schemas and
+ * the current contents of the {@code modelOptions} field.
+ *
+ * <p>Tests customise the model by putting entries into {@code modelOptions} before calling
+ * this. NOTE(review): tests now mutate the shared field rather than a per-test copy, so this
+ * relies on {@code modelOptions} being re-initialised before every test — confirm in setup.
+ */
+ private void createModel(Schema inputSchema, Schema outputSchema) {
+ CatalogManager catalogManager = ((TableEnvironmentImpl)
tEnv).getCatalogManager();
+ ObjectIdentifier modelIdentifier =
+ ObjectIdentifier.of(
+
Objects.requireNonNull(catalogManager.getCurrentCatalog()),
+
Objects.requireNonNull(catalogManager.getCurrentDatabase()),
+ MODEL_NAME);
+ catalogManager.createModel(
+ CatalogModel.of(inputSchema, outputSchema, modelOptions, "This
is a new model."),
+ modelIdentifier,
+ false);
+ }
+
private static class TestDispatcher extends Dispatcher {
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
@@ -318,6 +318,11 @@ public class OpenAIChatModelTest {
root.get("stop").forEach(node -> stop.add(node.asText()));
}
+ int n = 1;
+ if (root.has("n")) {
+ n = root.get("n").intValue();
+ }
+
StringBuilder contentBuilder = new StringBuilder("This is a
mocked response");
contentBuilder.append(" continuation".repeat(Math.max(0,
maxTokens - 6)));
for (String stopWord : stop) {
@@ -329,22 +334,38 @@ public class OpenAIChatModelTest {
}
}
+ String content = contentBuilder.toString();
+ if (root.has("response_format")
+ && "\"json_object\""
+ .equalsIgnoreCase(
+
root.get("response_format").get("type").toString())) {
+ content = "{\\\"content\\\": \\\"" + content + "\\\"}";
+ }
+
+ List<String> choices = new ArrayList<>();
+ for (int i = 0; i < n; i++) {
+ choices.add(
+ "{"
+ + " \"index\": 0,"
+ + " \"message\": {"
+ + " \"role\": \"assistant\","
+ + " \"content\": \""
+ + content
+ + "\""
+ + " },"
+ + " \"finish_reason\": \"stop\""
+ + " }");
+ }
+
String responseBody =
"{"
+ " \"id\": \"chatcmpl-1234567890ABCD\","
+ " \"object\": \"chat.completion\","
+ " \"created\": 1717029203,"
+ " \"model\": \"gpt-3.5-turbo-0125\","
- + " \"choices\": [{"
- + " \"index\": 0,"
- + " \"message\": {"
- + " \"role\": \"assistant\","
- + " \"content\": \""
- + contentBuilder
- + "\""
- + " },"
- + " \"finish_reason\": \"stop\""
- + " }],"
+ + " \"choices\": "
+ + choices
+ + ","
+ " \"usage\": {"
+ " \"prompt_tokens\": 9,"
+ " \"completion_tokens\": "