This is an automated email from the ASF dual-hosted git repository.
shengkai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new f14fcd30ca3 [FLINK-38104][table] Add table api support for ML_PREDICT
(#27108)
f14fcd30ca3 is described below
commit f14fcd30ca3c100c4e7cb0660a7abd045a9947d4
Author: Hao Li <[email protected]>
AuthorDate: Sat Nov 8 19:37:18 2025 -0800
[FLINK-38104][table] Add table api support for ML_PREDICT (#27108)
---
.../table/tests/test_environment_completeness.py | 4 +-
.../tests/test_table_environment_completeness.py | 1 +
.../java/org/apache/flink/table/api/Model.java | 146 ++++++++++++++++
.../apache/flink/table/api/TableEnvironment.java | 43 +++++
.../apache/flink/table/api/internal/ModelImpl.java | 120 ++++++++++++++
.../table/api/internal/TableEnvironmentImpl.java | 42 +++++
.../flink/table/catalog/ContextResolvedModel.java | 60 ++++++-
.../table/expressions/ApiExpressionUtils.java | 7 +
.../table/expressions/ApiExpressionVisitor.java | 4 +
.../expressions/ModelReferenceExpression.java | 153 +++++++++++++++++
.../expressions/ResolvedExpressionVisitor.java | 4 +
.../resolver/rules/ResolveCallByArgumentsRule.java | 37 +++++
.../utils/ApiExpressionDefaultVisitor.java | 6 +
.../utils/ResolvedExpressionDefaultVisitor.java | 5 +
.../utils/OperationExpressionsUtils.java | 6 +
.../test/program/FailingTableApiTestStep.java | 140 ++++++++++++++++
.../flink/table/test/program/TableApiTestStep.java | 18 ++
.../flink/table/test/program/TableTestProgram.java | 17 ++
.../apache/flink/table/test/program/TestStep.java | 3 +-
.../java/org/apache/flink/types/ColumnList.java | 4 +
.../expressions/converter/ExpressionConverter.java | 57 +++++++
.../planner/plan/QueryOperationConverter.java | 1 -
.../StreamPhysicalMLPredictTableFunction.java | 4 +
.../StreamNonDeterministicUpdatePlanVisitor.java | 14 ++
.../table/planner/plan/utils/FunctionCallUtil.java | 44 ++++-
.../table/api/QueryOperationSqlSemanticTest.java | 12 +-
.../api/QueryOperationSqlSerializationTest.java | 7 +-
.../table/api/QueryOperationTestPrograms.java | 114 +++++++++++++
.../nodes/exec/stream/MLPredictSemanticTests.java | 46 ++++++
.../nodes/exec/stream/MLPredictTestPrograms.java | 184 +++++++++++++++++++--
.../nodes/exec/testutils/SemanticTestBase.java | 23 ++-
.../stream/sql/MLPredictTableFunctionTest.java | 2 +-
.../runtime/stream/table/MLPredictITCase.java | 20 +++
.../flink/table/api/TableEnvironmentTest.scala | 24 ++-
.../flink/table/planner/utils/TableTestBase.scala | 2 +-
35 files changed, 1340 insertions(+), 34 deletions(-)
diff --git a/flink-python/pyflink/table/tests/test_environment_completeness.py
b/flink-python/pyflink/table/tests/test_environment_completeness.py
index d1305c7ef6a..ab9c69a0e0a 100644
--- a/flink-python/pyflink/table/tests/test_environment_completeness.py
+++ b/flink-python/pyflink/table/tests/test_environment_completeness.py
@@ -16,8 +16,9 @@
# limitations under the License.
################################################################################
-from pyflink.testing.test_case_utils import PythonAPICompletenessTestCase,
PyFlinkTestCase
from pyflink.table import TableEnvironment
+from pyflink.testing.test_case_utils import PythonAPICompletenessTestCase, \
+ PyFlinkTestCase
class EnvironmentAPICompletenessTests(PythonAPICompletenessTestCase,
PyFlinkTestCase):
@@ -40,6 +41,7 @@ class
EnvironmentAPICompletenessTests(PythonAPICompletenessTestCase, PyFlinkTest
'getCompletionHints',
'fromValues',
'fromCall',
+ 'fromModel',
# See FLINK-25986
'loadPlan',
'compilePlanSql',
diff --git
a/flink-python/pyflink/table/tests/test_table_environment_completeness.py
b/flink-python/pyflink/table/tests/test_table_environment_completeness.py
index 48c6369bae6..2e0cc5c2b39 100644
--- a/flink-python/pyflink/table/tests/test_table_environment_completeness.py
+++ b/flink-python/pyflink/table/tests/test_table_environment_completeness.py
@@ -44,6 +44,7 @@ class
TableEnvironmentAPICompletenessTests(PythonAPICompletenessTestCase, PyFlin
"from",
"registerFunction",
"fromCall",
+ "fromModel",
}
diff --git
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/Model.java
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/Model.java
new file mode 100644
index 00000000000..dd27940519f
--- /dev/null
+++
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/Model.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.api;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.table.api.config.MLPredictRuntimeConfigOptions;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.types.ColumnList;
+
+import java.util.Map;
+
+/**
+ * The {@link Model} object is the core abstraction for ML model resources in
the Table API.
+ *
+ * <p>A {@link Model} object describes a machine learning model resource that
can be used for
+ * inference operations. It provides methods to perform prediction on data
tables.
+ *
+ * <p>The {@link Model} interface offers the following main operations:
+ *
+ * <ul>
+ * <li>{@link #predict(Table, ColumnList)} - Applies the model to make
predictions on input data
+ * </ul>
+ *
+ * <p>{@code ml_predict} operation supports runtime options for configuring
execution parameters
+ * such as asynchronous execution mode.
+ *
+ * <p>Every {@link Model} object has input and output schemas that describe
the expected data
+ * structure for model operations, available through {@link
#getResolvedInputSchema()} and {@link
+ * #getResolvedOutputSchema()}.
+ *
+ * <p>Example usage:
+ *
+ * <pre>{@code
+ * Model model = tableEnv.fromModel("my_model");
+ *
+ * // Simple prediction
+ * Table predictions = model.predict(inputTable, ColumnList.of("feature1",
"feature2"));
+ *
+ * // Prediction with options
+ * Map<String, String> options = Map.of("max-concurrent-operations", "100",
"timeout", "30s", "async", "true");
+ * Table predictions = model.predict(inputTable, ColumnList.of("feature1",
"feature2"), options);
+ * }</pre>
+ */
+@PublicEvolving
+public interface Model {
+
+ /**
+ * Returns the resolved input schema of this model.
+ *
+ * <p>The input schema describes the structure and data types of the input
columns that the
+ * model expects for inference operations.
+ *
+ * @return the resolved input schema.
+ */
+ ResolvedSchema getResolvedInputSchema();
+
+ /**
+ * Returns the resolved output schema of this model.
+ *
+ * <p>The output schema describes the structure and data types of the
output columns that the
+ * model produces during inference operations.
+ *
+ * @return the resolved output schema.
+ */
+ ResolvedSchema getResolvedOutputSchema();
+
+ /**
+ * Performs prediction on the given table using specified input columns.
+ *
+ * <p>This method applies the model to the input data to generate
predictions. The input columns
+ * must match the model's expected input schema.
+ *
+ * <p>Example:
+ *
+ * <pre>{@code
+ * Table predictions = model.predict(inputTable, ColumnList.of("feature1",
"feature2"));
+ * }</pre>
+ *
+ * @param table the input table containing data for prediction
+ * @param inputColumns the columns from the input table to use as model
input
+ * @return a table containing the input data along with prediction results
+ */
+ Table predict(Table table, ColumnList inputColumns);
+
+ /**
+ * Performs prediction on the given table using specified input columns
with runtime options.
+ *
+ * <p>This method applies the model to the input data to generate
predictions with additional
+ * runtime configuration options such as max-concurrent-operations,
timeout, and execution mode
+ * settings.
+ *
+ * <p>For common runtime options, see {@link
MLPredictRuntimeConfigOptions}.
+ *
+ * <p>Example:
+ *
+ * <pre>{@code
+ * Map<String, String> options = Map.of("max-concurrent-operations",
"100", "timeout", "30s", "async", "true");
+ * Table predictions = model.predict(inputTable,
+ * ColumnList.of("feature1", "feature2"), options);
+ * }</pre>
+ *
+ * @param table the input table containing data for prediction
+ * @param inputColumns the columns from the input table to use as model
input
+ * @param options runtime options for configuring the prediction operation
+ * @return a table containing the input data along with prediction results
+ */
+ Table predict(Table table, ColumnList inputColumns, Map<String, String>
options);
+
+ /**
+ * Converts this model object into a named argument.
+ *
+ * <p>This method is intended for use in function calls that accept model
arguments,
+ * particularly in process table functions (PTFs) or other operations that
work with models.
+ *
+ * <p>Example:
+ *
+ * <pre>{@code
+ * env.fromCall(
+ * "ML_PREDICT",
+ * inputTable.asArgument("INPUT"),
+ * model.asArgument("MODEL"),
+ * Expressions.descriptor(ColumnList.of("feature1",
"feature2")).asArgument("ARGS")
+ * )
+ * }</pre>
+ *
+ * @param name the name to assign to this model argument
+ * @return an expression that can be passed to functions expecting model
arguments
+ */
+ ApiExpression asArgument(String name);
+}
diff --git
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/TableEnvironment.java
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/TableEnvironment.java
index dada96ae619..5f6c584cced 100644
---
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/TableEnvironment.java
+++
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/TableEnvironment.java
@@ -1175,6 +1175,49 @@ public interface TableEnvironment {
*/
Table fromCall(Class<? extends UserDefinedFunction> function, Object...
arguments);
+ /**
+ * Returns a {@link Model} object that is backed by the specified model
path.
+ *
+ * <p>This method creates a {@link Model} object from a given model path
in the catalog. The
+ * model path can be fully or partially qualified (e.g.,
"catalog.db.model" or just "model"),
+ * depending on the current catalog and database context.
+ *
+ * <p>The returned {@link Model} object can be used for further
transformations or as input to
+ * other operations in the Table API.
+ *
+ * <p>Example:
+ *
+ * <pre>{@code
+ * Model model = tableEnv.fromModel("my_model");
+ * }</pre>
+ *
+ * @param modelPath The path of the model in the catalog.
+ * @return The {@link Model} object describing the model resource.
+ */
+ Model fromModel(String modelPath);
+
+ /**
+ * Returns a {@link Model} object that is backed by the specified {@link
ModelDescriptor}.
+ *
+ * <p>This method creates a {@link Model} object using the provided {@link
ModelDescriptor},
+ * which contains the necessary information to identify and configure the
model resource in the
+ * catalog.
+ *
+ * <p>The returned {@link Model} object can be used for further
transformations or as input to
+ * other operations in the Table API.
+ *
+ * <p>Example:
+ *
+ * <pre>{@code
+ * ModelDescriptor descriptor = ...;
+ * Model model = tableEnv.fromModel(descriptor);
+ * }</pre>
+ *
+ * @param descriptor The {@link ModelDescriptor} describing the model
resource.
+ * @return The {@link Model} object representing the model resource.
+ */
+ Model fromModel(ModelDescriptor descriptor);
+
/**
* Gets the names of all catalogs registered in this environment.
*
diff --git
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/ModelImpl.java
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/ModelImpl.java
new file mode 100644
index 00000000000..5ff566a326d
--- /dev/null
+++
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/ModelImpl.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.api.internal;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.ApiExpression;
+import org.apache.flink.table.api.Expressions;
+import org.apache.flink.table.api.Model;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.TableEnvironment;
+import org.apache.flink.table.catalog.ContextResolvedModel;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.expressions.ApiExpressionUtils;
+import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
+import org.apache.flink.types.ColumnList;
+
+import java.util.ArrayList;
+import java.util.Map;
+
+import static org.apache.flink.table.api.Expressions.lit;
+import static
org.apache.flink.table.expressions.ApiExpressionUtils.valueLiteral;
+
+/** Implementation of {@link Model} that works with the Table API. */
+@Internal
+public class ModelImpl implements Model {
+
+ private final TableEnvironmentInternal tableEnvironment;
+ private final ContextResolvedModel model;
+
+ private ModelImpl(TableEnvironmentInternal tableEnvironment,
ContextResolvedModel model) {
+ this.tableEnvironment = tableEnvironment;
+ this.model = model;
+ }
+
+ public static ModelImpl createModel(
+ TableEnvironmentInternal tableEnvironment, ContextResolvedModel
model) {
+ return new ModelImpl(tableEnvironment, model);
+ }
+
+ public ContextResolvedModel getModel() {
+ return model;
+ }
+
+ @Override
+ public ResolvedSchema getResolvedInputSchema() {
+ return model.getResolvedModel().getResolvedInputSchema();
+ }
+
+ @Override
+ public ResolvedSchema getResolvedOutputSchema() {
+ return model.getResolvedModel().getResolvedOutputSchema();
+ }
+
+ public TableEnvironment getTableEnv() {
+ return tableEnvironment;
+ }
+
+ @Override
+ public Table predict(Table table, ColumnList inputColumns) {
+ return predict(table, inputColumns, Map.of());
+ }
+
+ @Override
+ public Table predict(Table table, ColumnList inputColumns, Map<String,
String> options) {
+ // Use Expressions.map() instead of Expressions.lit() to create a MAP
+ // literal, since lit() cannot be serialized to SQL.
+ if (options.isEmpty()) {
+ return tableEnvironment.fromCall(
+ BuiltInFunctionDefinitions.ML_PREDICT.getName(),
+ table.asArgument("INPUT"),
+ this.asArgument("MODEL"),
+ new
ApiExpression(valueLiteral(inputColumns)).asArgument("ARGS"));
+ }
+ ArrayList<String> configKVs = new ArrayList<>();
+ options.forEach(
+ (k, v) -> {
+ configKVs.add(k);
+ configKVs.add(v);
+ });
+ return tableEnvironment.fromCall(
+ BuiltInFunctionDefinitions.ML_PREDICT.getName(),
+ table.asArgument("INPUT"),
+ this.asArgument("MODEL"),
+ new
ApiExpression(valueLiteral(inputColumns)).asArgument("ARGS"),
+ Expressions.map(
+ configKVs.get(0),
+ configKVs.get(1),
+ configKVs.subList(2,
configKVs.size()).toArray())
+ .asArgument("CONFIG"));
+ }
+
+ @Override
+ public ApiExpression asArgument(String name) {
+ return new ApiExpression(
+ ApiExpressionUtils.unresolvedCall(
+ BuiltInFunctionDefinitions.ASSIGNMENT,
+ lit(name),
+ ApiExpressionUtils.modelRef(name, this)));
+ }
+
+ public TableEnvironment getTableEnvironment() {
+ return tableEnvironment;
+ }
+}
diff --git
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableEnvironmentImpl.java
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableEnvironmentImpl.java
index 058f7a92be9..6a40fa9c238 100644
---
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableEnvironmentImpl.java
+++
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableEnvironmentImpl.java
@@ -33,6 +33,7 @@ import org.apache.flink.table.api.ExplainDetail;
import org.apache.flink.table.api.ExplainFormat;
import org.apache.flink.table.api.Expressions;
import org.apache.flink.table.api.FunctionDescriptor;
+import org.apache.flink.table.api.Model;
import org.apache.flink.table.api.ModelDescriptor;
import org.apache.flink.table.api.PlanReference;
import org.apache.flink.table.api.ResultKind;
@@ -55,12 +56,14 @@ import org.apache.flink.table.catalog.CatalogManager;
import org.apache.flink.table.catalog.CatalogStore;
import org.apache.flink.table.catalog.CatalogStoreHolder;
import org.apache.flink.table.catalog.Column;
+import org.apache.flink.table.catalog.ContextResolvedModel;
import org.apache.flink.table.catalog.ContextResolvedTable;
import org.apache.flink.table.catalog.FunctionCatalog;
import org.apache.flink.table.catalog.FunctionLanguage;
import org.apache.flink.table.catalog.GenericInMemoryCatalog;
import org.apache.flink.table.catalog.ObjectIdentifier;
import org.apache.flink.table.catalog.QueryOperationCatalogView;
+import org.apache.flink.table.catalog.ResolvedCatalogModel;
import org.apache.flink.table.catalog.ResolvedCatalogTable;
import org.apache.flink.table.catalog.ResolvedSchema;
import org.apache.flink.table.catalog.StagedTable;
@@ -77,6 +80,7 @@ import
org.apache.flink.table.execution.StagingSinkJobStatusHook;
import org.apache.flink.table.expressions.ApiExpressionUtils;
import org.apache.flink.table.expressions.DefaultSqlFactory;
import org.apache.flink.table.expressions.Expression;
+import org.apache.flink.table.expressions.ModelReferenceExpression;
import org.apache.flink.table.expressions.TableReferenceExpression;
import org.apache.flink.table.expressions.utils.ApiExpressionDefaultVisitor;
import org.apache.flink.table.factories.CatalogStoreFactory;
@@ -682,6 +686,29 @@ public class TableEnvironmentImpl implements
TableEnvironmentInternal {
operationTreeBuilder.tableFunction(Expressions.call(function,
arguments)));
}
+ @Override
+ public Model fromModel(String modelPath) {
+ UnresolvedIdentifier unresolvedIdentifier =
getParser().parseIdentifier(modelPath);
+ ObjectIdentifier modelIdentifier =
catalogManager.qualifyIdentifier(unresolvedIdentifier);
+ return catalogManager
+ .getModel(modelIdentifier)
+ .map(this::createModel)
+ .orElseThrow(
+ () ->
+ new ValidationException(
+ String.format(
+ "Model %s was not found.",
unresolvedIdentifier)));
+ }
+
+ @Override
+ public Model fromModel(ModelDescriptor descriptor) {
+ Preconditions.checkNotNull(descriptor, "Model descriptor must not be
null.");
+
+ final ResolvedCatalogModel resolvedCatalogModel =
+
catalogManager.resolveCatalogModel(descriptor.toCatalogModel());
+ return
createModel(ContextResolvedModel.anonymous(resolvedCatalogModel));
+ }
+
private Optional<SourceQueryOperation> scanInternal(UnresolvedIdentifier
identifier) {
ObjectIdentifier tableIdentifier =
catalogManager.qualifyIdentifier(identifier);
@@ -1487,6 +1514,10 @@ public class TableEnvironmentImpl implements
TableEnvironmentInternal {
functionCatalog.asLookup(getParser()::parseIdentifier));
}
+ public ModelImpl createModel(ContextResolvedModel model) {
+ return ModelImpl.createModel(this, model);
+ }
+
@Override
public String explainPlan(InternalPlan compiledPlan, ExplainDetail...
extraDetails) {
return planner.explainPlan(compiledPlan, extraDetails);
@@ -1531,5 +1562,16 @@ public class TableEnvironmentImpl implements
TableEnvironmentInternal {
}
return null;
}
+
+ @Override
+ public Void visit(ModelReferenceExpression modelRef) {
+ super.visit(modelRef);
+ if (modelRef.getTableEnvironment() != null
+ && modelRef.getTableEnvironment() !=
TableEnvironmentImpl.this) {
+ throw new ValidationException(
+ "All model references must use the same
TableEnvironment.");
+ }
+ return null;
+ }
}
}
diff --git
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ContextResolvedModel.java
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ContextResolvedModel.java
index dbfba0b7a58..74854135de7 100644
---
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ContextResolvedModel.java
+++
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ContextResolvedModel.java
@@ -19,12 +19,14 @@
package org.apache.flink.table.catalog;
import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.factories.FactoryUtil;
import org.apache.flink.util.Preconditions;
import javax.annotation.Nullable;
import java.util.Objects;
import java.util.Optional;
+import java.util.concurrent.atomic.AtomicInteger;
/**
* This class contains information about a model and its relationship with a
{@link Catalog}, if
@@ -48,28 +50,46 @@ import java.util.Optional;
@Internal
public final class ContextResolvedModel {
+ private static final AtomicInteger uniqueId = new AtomicInteger(0);
private final ObjectIdentifier objectIdentifier;
private final @Nullable Catalog catalog;
private final ResolvedCatalogModel resolvedModel;
+ private final boolean anonymous;
public static ContextResolvedModel permanent(
ObjectIdentifier identifier, Catalog catalog, ResolvedCatalogModel
resolvedModel) {
return new ContextResolvedModel(
- identifier, Preconditions.checkNotNull(catalog),
resolvedModel);
+ identifier, Preconditions.checkNotNull(catalog),
resolvedModel, false);
}
public static ContextResolvedModel temporary(
ObjectIdentifier identifier, ResolvedCatalogModel resolvedModel) {
- return new ContextResolvedModel(identifier, null, resolvedModel);
+ return new ContextResolvedModel(identifier, null, resolvedModel,
false);
+ }
+
+ public static ContextResolvedModel anonymous(ResolvedCatalogModel
resolvedModel) {
+ return anonymous(null, resolvedModel);
+ }
+
+ public static ContextResolvedModel anonymous(
+ @Nullable String hint, ResolvedCatalogModel resolvedModel) {
+ return new ContextResolvedModel(
+ ObjectIdentifier.ofAnonymous(
+ generateAnonymousStringIdentifier(hint,
resolvedModel)),
+ null,
+ resolvedModel,
+ true);
}
private ContextResolvedModel(
ObjectIdentifier objectIdentifier,
@Nullable Catalog catalog,
- ResolvedCatalogModel resolvedModel) {
+ ResolvedCatalogModel resolvedModel,
+ boolean anonymous) {
this.objectIdentifier = Preconditions.checkNotNull(objectIdentifier);
this.catalog = catalog;
this.resolvedModel = Preconditions.checkNotNull(resolvedModel);
+ this.anonymous = anonymous;
}
/**
@@ -83,6 +103,10 @@ public final class ContextResolvedModel {
return !isTemporary();
}
+ public boolean isAnonymous() {
+ return anonymous;
+ }
+
public ObjectIdentifier getIdentifier() {
return objectIdentifier;
}
@@ -116,13 +140,39 @@ public final class ContextResolvedModel {
return false;
}
ContextResolvedModel that = (ContextResolvedModel) o;
- return Objects.equals(objectIdentifier, that.objectIdentifier)
+ return anonymous == that.anonymous
+ && Objects.equals(objectIdentifier, that.objectIdentifier)
&& Objects.equals(catalog, that.catalog)
&& Objects.equals(resolvedModel, that.resolvedModel);
}
@Override
public int hashCode() {
- return Objects.hash(objectIdentifier, catalog, resolvedModel);
+ return Objects.hash(objectIdentifier, catalog, resolvedModel,
anonymous);
+ }
+
+ /**
+ * This method tries to return the provider name of the model, trying to
provide a bit more
+ * helpful toString for anonymous models. It's only to help users to
debug, and its return value
+ * should not be relied on.
+ */
+ private static String generateAnonymousStringIdentifier(
+ @Nullable String hint, ResolvedCatalogModel resolvedModel) {
+ // The planner can apply optimizations that squash two sources together in the same
+ // operator. Because this logic is string based, anonymous models
still need some kind of
+ // unique string based identifier that can be used later by the
planner.
+ if (hint == null) {
+ try {
+ hint =
resolvedModel.getOptions().get(FactoryUtil.PROVIDER.key());
+ } catch (Exception ignored) {
+ }
+ }
+
+ int id = uniqueId.incrementAndGet();
+ if (hint == null) {
+ return "*anonymous$" + id + "*";
+ }
+
+ return "*anonymous_" + hint + "$" + id + "*";
}
}
diff --git
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java
index 8748fafceb7..bcad8cccfce 100644
---
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java
+++
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionUtils.java
@@ -21,9 +21,11 @@ package org.apache.flink.table.expressions;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.api.ApiExpression;
import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Model;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.api.internal.ModelImpl;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.table.catalog.ContextResolvedFunction;
import org.apache.flink.table.functions.BuiltInFunctionDefinition;
@@ -301,6 +303,11 @@ public final class ApiExpressionUtils {
return new TableReferenceExpression(name, queryOperation, env);
}
+ public static ModelReferenceExpression modelRef(String name, Model model) {
+ return new ModelReferenceExpression(
+ name, ((ModelImpl) model).getModel(), ((ModelImpl)
model).getTableEnvironment());
+ }
+
public static LookupCallExpression lookupCall(String name, Expression...
args) {
return new LookupCallExpression(
name,
diff --git
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionVisitor.java
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionVisitor.java
index f0b34bbddfa..b64b77d24fb 100644
---
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionVisitor.java
+++
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ApiExpressionVisitor.java
@@ -29,6 +29,8 @@ public abstract class ApiExpressionVisitor<R> implements
ExpressionVisitor<R> {
return visit((UnresolvedReferenceExpression) other);
} else if (other instanceof TableReferenceExpression) {
return visit((TableReferenceExpression) other);
+ } else if (other instanceof ModelReferenceExpression) {
+ return visit((ModelReferenceExpression) other);
} else if (other instanceof LocalReferenceExpression) {
return visit((LocalReferenceExpression) other);
} else if (other instanceof LookupCallExpression) {
@@ -49,6 +51,8 @@ public abstract class ApiExpressionVisitor<R> implements
ExpressionVisitor<R> {
public abstract R visit(TableReferenceExpression tableReference);
+ public abstract R visit(ModelReferenceExpression modelReferenceExpression);
+
public abstract R visit(LocalReferenceExpression localReference);
/** For resolved expressions created by the planner. */
diff --git
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ModelReferenceExpression.java
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ModelReferenceExpression.java
new file mode 100644
index 00000000000..01f88cbcfb2
--- /dev/null
+++
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ModelReferenceExpression.java
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.expressions;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.Model;
+import org.apache.flink.table.api.TableEnvironment;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.catalog.ContextResolvedModel;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * A reference to a {@link Model} in an expression context.
+ *
+ * <p>This expression is used when a model needs to be passed as an argument
to functions or
+ * operations that accept model references. It wraps a model object and
provides the necessary
+ * expression interface for use in the Table API expression system.
+ *
+ * <p>The expression carries a string representation of the model and uses a
special data type to
+ * indicate that this is a model reference rather than a regular data value.
+ */
+@Internal
+public final class ModelReferenceExpression implements ResolvedExpression {
+
+ private final String name;
+ private final ContextResolvedModel model;
+ // The environment serves validation purposes to ensure that all
+ // referenced models belong to the same environment.
+ private final TableEnvironment env;
+
+ public ModelReferenceExpression(String name, ContextResolvedModel model,
TableEnvironment env) {
+ this.name = Preconditions.checkNotNull(name);
+ this.model = Preconditions.checkNotNull(model);
+ this.env = Preconditions.checkNotNull(env);
+ }
+
+ /**
+ * Returns the name of this model reference.
+ *
+ * @return the model reference name
+ */
+ public String getName() {
+ return name;
+ }
+
+ /**
+ * Returns the ContextResolvedModel associated with this model reference.
+ *
+ * @return the context resolved model associated with this reference
+ */
+ public ContextResolvedModel getModel() {
+ return model;
+ }
+
+ public @Nullable TableEnvironment getTableEnvironment() {
+ return env;
+ }
+
+ /**
+ * Returns the input data type expected by this model reference.
+ *
+ * <p>This method extracts the input data type from the model's input
schema, which describes
+ * the structure and data types that the model expects for inference
operations.
+ *
+ * @return the input data type expected by the model
+ */
+ public DataType getInputDataType() {
+ return
model.getResolvedModel().getResolvedInputSchema().toPhysicalRowDataType();
+ }
+
+ @Override
+ public DataType getOutputDataType() {
+ return
model.getResolvedModel().getResolvedOutputSchema().toPhysicalRowDataType();
+ }
+
+ @Override
+ public List<ResolvedExpression> getResolvedChildren() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public String asSerializableString(SqlFactory sqlFactory) {
+ if (model.isAnonymous()) {
+ throw new ValidationException("Anonymous models cannot be
serialized.");
+ }
+
+ return "MODEL " + model.getIdentifier().asSerializableString();
+ }
+
+ @Override
+ public String asSummaryString() {
+ return name;
+ }
+
+ @Override
+ public List<Expression> getChildren() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public <R> R accept(ExpressionVisitor<R> visitor) {
+ return visitor.visit(this);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ ModelReferenceExpression that = (ModelReferenceExpression) o;
+ return Objects.equals(name, that.name)
+ && Objects.equals(model, that.model)
+ // Effectively means reference equality
+ && Objects.equals(env, that.env);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(name, model, env);
+ }
+
+ @Override
+ public String toString() {
+ return asSummaryString();
+ }
+}
diff --git
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ResolvedExpressionVisitor.java
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ResolvedExpressionVisitor.java
index 42ee52d98d5..88a15e3929d 100644
---
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ResolvedExpressionVisitor.java
+++
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/ResolvedExpressionVisitor.java
@@ -32,6 +32,8 @@ public abstract class ResolvedExpressionVisitor<R> implements
ExpressionVisitor<
public final R visit(Expression other) {
if (other instanceof TableReferenceExpression) {
return visit((TableReferenceExpression) other);
+ } else if (other instanceof ModelReferenceExpression) {
+ return visit((ModelReferenceExpression) other);
} else if (other instanceof LocalReferenceExpression) {
return visit((LocalReferenceExpression) other);
} else if (other instanceof ResolvedExpression) {
@@ -42,6 +44,8 @@ public abstract class ResolvedExpressionVisitor<R> implements
ExpressionVisitor<
public abstract R visit(TableReferenceExpression tableReference);
+ public abstract R visit(ModelReferenceExpression modelReferenceExpression);
+
public abstract R visit(LocalReferenceExpression localReference);
/** For resolved expressions created by the planner. */
diff --git
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java
index f5c585087dd..e298953f69e 100644
---
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java
+++
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java
@@ -28,6 +28,7 @@ import org.apache.flink.table.connector.ChangelogMode;
import org.apache.flink.table.expressions.CallExpression;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.ExpressionUtils;
+import org.apache.flink.table.expressions.ModelReferenceExpression;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.expressions.TableReferenceExpression;
import org.apache.flink.table.expressions.TypeLiteralExpression;
@@ -38,6 +39,7 @@ import
org.apache.flink.table.functions.BuiltInFunctionDefinitions;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.FunctionIdentifier;
import org.apache.flink.table.functions.FunctionKind;
+import org.apache.flink.table.functions.ModelSemantics;
import org.apache.flink.table.functions.ScalarFunctionDefinition;
import org.apache.flink.table.functions.TableAggregateFunctionDefinition;
import org.apache.flink.table.functions.TableFunctionDefinition;
@@ -651,6 +653,22 @@ final class ResolveCallByArgumentsRule implements
ResolverRule {
return Optional.of(semantics);
}
+ @Override
+ public Optional<ModelSemantics> getModelSemantics(int pos) {
+ final StaticArgument staticArg =
+ Optional.ofNullable(staticArguments).map(args ->
args.get(pos)).orElse(null);
+ if (staticArg == null || !staticArg.is(StaticArgumentTrait.MODEL))
{
+ return Optional.empty();
+ }
+ final ResolvedExpression arg = getArgument(pos);
+ if (!(arg instanceof ModelReferenceExpression)) {
+ return Optional.empty();
+ }
+ final ModelReferenceExpression modelRef =
(ModelReferenceExpression) arg;
+ final ModelSemantics semantics = new
TableApiModelSemantics(modelRef);
+ return Optional.of(semantics);
+ }
+
@Override
public String getName() {
return functionName;
@@ -732,4 +750,23 @@ final class ResolveCallByArgumentsRule implements
ResolverRule {
return Optional.empty();
}
}
+
+ private static class TableApiModelSemantics implements ModelSemantics {
+
+ private final ModelReferenceExpression modelRef;
+
+ private TableApiModelSemantics(ModelReferenceExpression modelRef) {
+ this.modelRef = modelRef;
+ }
+
+ @Override
+ public DataType inputDataType() {
+ return modelRef.getInputDataType();
+ }
+
+ @Override
+ public DataType outputDataType() {
+ return modelRef.getOutputDataType();
+ }
+ }
}
diff --git
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionDefaultVisitor.java
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionDefaultVisitor.java
index 9797f53a331..6e030345409 100644
---
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionDefaultVisitor.java
+++
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionDefaultVisitor.java
@@ -25,6 +25,7 @@ import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.expressions.LocalReferenceExpression;
import org.apache.flink.table.expressions.LookupCallExpression;
+import org.apache.flink.table.expressions.ModelReferenceExpression;
import org.apache.flink.table.expressions.NestedFieldReferenceExpression;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.expressions.SqlCallExpression;
@@ -76,6 +77,11 @@ public abstract class ApiExpressionDefaultVisitor<T> extends
ApiExpressionVisito
return defaultMethod(tableReference);
}
+ @Override
+ public T visit(ModelReferenceExpression modelReference) {
+ return defaultMethod(modelReference);
+ }
+
@Override
public T visit(LocalReferenceExpression localReference) {
return defaultMethod(localReference);
diff --git
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ResolvedExpressionDefaultVisitor.java
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ResolvedExpressionDefaultVisitor.java
index 3bf93880d7d..841ff5a0339 100644
---
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ResolvedExpressionDefaultVisitor.java
+++
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ResolvedExpressionDefaultVisitor.java
@@ -22,6 +22,7 @@ import org.apache.flink.annotation.Internal;
import org.apache.flink.table.expressions.CallExpression;
import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.expressions.LocalReferenceExpression;
+import org.apache.flink.table.expressions.ModelReferenceExpression;
import org.apache.flink.table.expressions.NestedFieldReferenceExpression;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.expressions.ResolvedExpressionVisitor;
@@ -41,6 +42,10 @@ public abstract class ResolvedExpressionDefaultVisitor<T>
extends ResolvedExpres
return defaultMethod(tableReference);
}
+ public T visit(ModelReferenceExpression modelReferenceExpression) {
+ return defaultMethod(modelReferenceExpression);
+ }
+
@Override
public T visit(LocalReferenceExpression localReference) {
return defaultMethod(localReference);
diff --git
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationExpressionsUtils.java
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationExpressionsUtils.java
index 4d7af03afa8..4f39f3ecb84 100644
---
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationExpressionsUtils.java
+++
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/OperationExpressionsUtils.java
@@ -24,6 +24,7 @@ import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.expressions.LocalReferenceExpression;
import org.apache.flink.table.expressions.LookupCallExpression;
+import org.apache.flink.table.expressions.ModelReferenceExpression;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.expressions.TableReferenceExpression;
import org.apache.flink.table.expressions.UnresolvedCallExpression;
@@ -278,6 +279,11 @@ public class OperationExpressionsUtils {
return Optional.of(tableReference.getName());
}
+ @Override
+ public Optional<String> visit(ModelReferenceExpression modelReference)
{
+ return Optional.of(modelReference.getName());
+ }
+
@Override
public Optional<String> visit(FieldReferenceExpression fieldReference)
{
return Optional.of(fieldReference.getName());
diff --git
a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/FailingTableApiTestStep.java
b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/FailingTableApiTestStep.java
new file mode 100644
index 00000000000..ab91f955025
--- /dev/null
+++
b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/FailingTableApiTestStep.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.test.program;
+
+import org.apache.flink.table.api.Model;
+import org.apache.flink.table.api.ModelDescriptor;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.TableEnvironment;
+import org.apache.flink.table.api.TableRuntimeException;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.expressions.DefaultSqlFactory;
+import org.apache.flink.table.functions.UserDefinedFunction;
+import org.apache.flink.table.test.program.TableApiTestStep.TableEnvAccessor;
+import org.apache.flink.table.types.AbstractDataType;
+import org.apache.flink.util.Preconditions;
+
+import java.util.function.Function;
+
+import static org.apache.flink.core.testutils.FlinkAssertions.anyCauseMatches;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Test step for executing a Table API query that will fail eventually with either {@link
+ * ValidationException} (during planning time) or {@link
TableRuntimeException} (during execution
+ * time).
+ *
+ * <p>Similar to {@link FailingSqlTestStep} but uses Table API instead of SQL.
+ */
+public final class FailingTableApiTestStep implements TestStep {
+
+ private final Function<TableEnvAccessor, Table> tableQuery;
+ private final String sinkName;
+ public final Class<? extends Exception> expectedException;
+ public final String expectedErrorMessage;
+
+ FailingTableApiTestStep(
+ Function<TableEnvAccessor, Table> tableQuery,
+ String sinkName,
+ Class<? extends Exception> expectedException,
+ String expectedErrorMessage) {
+ Preconditions.checkArgument(
+ expectedException == ValidationException.class
+ || expectedException == TableRuntimeException.class,
+ "Usually a Table API query should fail with either validation
or runtime exception. "
+ + "Otherwise this might require an update to the
exception design.");
+ this.tableQuery = tableQuery;
+ this.sinkName = sinkName;
+ this.expectedException = expectedException;
+ this.expectedErrorMessage = expectedErrorMessage;
+ }
+
+ @Override
+ public TestKind getKind() {
+ return TestKind.FAILING_TABLE_API;
+ }
+
+ public Table toTable(TableEnvironment env) {
+ return tableQuery.apply(
+ new TableEnvAccessor() {
+ @Override
+ public Table from(String path) {
+ return env.from(path);
+ }
+
+ @Override
+ public Table fromCall(String path, Object... arguments) {
+ return env.fromCall(path, arguments);
+ }
+
+ @Override
+ public Table fromCall(
+ Class<? extends UserDefinedFunction> function,
Object... arguments) {
+ return env.fromCall(function, arguments);
+ }
+
+ @Override
+ public Table fromValues(Object... values) {
+ return env.fromValues(values);
+ }
+
+ @Override
+ public Table fromValues(AbstractDataType<?> dataType,
Object... values) {
+ return env.fromValues(dataType, values);
+ }
+
+ @Override
+ public Table sqlQuery(String query) {
+ return env.sqlQuery(query);
+ }
+
+ @Override
+ public Model fromModel(String modelPath) {
+ return env.fromModel(modelPath);
+ }
+
+ @Override
+ public Model from(ModelDescriptor modelDescriptor) {
+ return env.fromModel(modelDescriptor);
+ }
+ });
+ }
+
+ public void apply(TableEnvironment env) {
+ assertThatThrownBy(
+ () -> {
+ final Table table = toTable(env);
+ table.executeInsert(sinkName).await();
+ })
+ .satisfies(anyCauseMatches(expectedException,
expectedErrorMessage));
+ }
+
+ public void applyAsSql(TableEnvironment env) {
+ assertThatThrownBy(
+ () -> {
+ final Table table = toTable(env);
+ final String query =
+ table.getQueryOperation()
+
.asSerializableString(DefaultSqlFactory.INSTANCE);
+ env.executeSql(String.format("INSERT INTO %s %s",
sinkName, query))
+ .await();
+ })
+ .satisfies(anyCauseMatches(expectedException,
expectedErrorMessage));
+ }
+}
diff --git
a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableApiTestStep.java
b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableApiTestStep.java
index 4a375ce4f59..4968c86aa8c 100644
---
a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableApiTestStep.java
+++
b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableApiTestStep.java
@@ -18,6 +18,8 @@
package org.apache.flink.table.test.program;
+import org.apache.flink.table.api.Model;
+import org.apache.flink.table.api.ModelDescriptor;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.api.TableResult;
@@ -75,6 +77,16 @@ public class TableApiTestStep implements TestStep {
public Table sqlQuery(String query) {
return env.sqlQuery(query);
}
+
+ @Override
+ public Model fromModel(String modelPath) {
+ return env.fromModel(modelPath);
+ }
+
+ @Override
+ public Model from(ModelDescriptor modelDescriptor) {
+ return env.fromModel(modelDescriptor);
+ }
});
}
@@ -111,5 +123,11 @@ public class TableApiTestStep implements TestStep {
/** See {@link TableEnvironment#sqlQuery(String)}. */
Table sqlQuery(String query);
+
+ /** See {@link TableEnvironment#fromModel(String)}. */
+ Model fromModel(String modelPath);
+
+ /** See {@link TableEnvironment#fromModel(ModelDescriptor)}. */
+ Model from(ModelDescriptor modelDescriptor);
}
}
diff --git
a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableTestProgram.java
b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableTestProgram.java
index bb8ca296cae..d37c4c8dc95 100644
---
a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableTestProgram.java
+++
b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TableTestProgram.java
@@ -21,6 +21,7 @@ package org.apache.flink.table.test.program;
import org.apache.flink.configuration.ConfigOption;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableRuntimeException;
+import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.functions.UserDefinedFunction;
import org.apache.flink.table.test.program.FunctionTestStep.FunctionBehavior;
@@ -355,6 +356,22 @@ public class TableTestProgram {
return this;
}
+ /**
+ * Run step for executing a Table API query that will fail eventually
with either {@link
+ * ValidationException} (during planning time) or {@link
TableRuntimeException} (during
+ * execution time).
+ */
+ public Builder runFailingTableApi(
+ Function<TableEnvAccessor, Table> toTable,
+ String sinkName,
+ Class<? extends Exception> expectedException,
+ String expectedErrorMessage) {
+ this.runSteps.add(
+ new FailingTableApiTestStep(
+ toTable, sinkName, expectedException,
expectedErrorMessage));
+ return this;
+ }
+
public Builder runTableApi(Function<TableEnvAccessor, Table> toTable,
String sinkName) {
this.runSteps.add(new TableApiTestStep(toTable, sinkName));
return this;
diff --git
a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TestStep.java
b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TestStep.java
index fc4245df79f..db2fe754f36 100644
---
a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TestStep.java
+++
b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/test/program/TestStep.java
@@ -51,7 +51,8 @@ public interface TestStep {
SINK_WITHOUT_DATA,
SINK_WITH_DATA,
SINK_WITH_RESTORE_DATA,
- FAILING_SQL
+ FAILING_SQL,
+ FAILING_TABLE_API
}
TestKind getKind();
diff --git
a/flink-table/flink-table-common/src/main/java/org/apache/flink/types/ColumnList.java
b/flink-table/flink-table-common/src/main/java/org/apache/flink/types/ColumnList.java
index d2d02bc498f..bc70a660662 100644
---
a/flink-table/flink-table-common/src/main/java/org/apache/flink/types/ColumnList.java
+++
b/flink-table/flink-table-common/src/main/java/org/apache/flink/types/ColumnList.java
@@ -61,6 +61,10 @@ public final class ColumnList implements Serializable {
return of(names, List.of());
}
+ public static ColumnList of(String... names) {
+ return of(List.of(names));
+ }
+
/**
* Returns a list of column names.
*
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/converter/ExpressionConverter.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/converter/ExpressionConverter.java
index b4c0cf39626..3a532d6f15f 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/converter/ExpressionConverter.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/converter/ExpressionConverter.java
@@ -19,6 +19,8 @@
package org.apache.flink.table.planner.expressions.converter;
import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.catalog.Catalog;
+import org.apache.flink.table.catalog.ContextResolvedModel;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.data.DecimalData;
import org.apache.flink.table.expressions.CallExpression;
@@ -26,18 +28,27 @@ import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.ExpressionVisitor;
import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.expressions.LocalReferenceExpression;
+import org.apache.flink.table.expressions.ModelReferenceExpression;
import org.apache.flink.table.expressions.NestedFieldReferenceExpression;
import org.apache.flink.table.expressions.TimeIntervalUnit;
import org.apache.flink.table.expressions.TimePointUnit;
import org.apache.flink.table.expressions.TypeLiteralExpression;
import org.apache.flink.table.expressions.ValueLiteralExpression;
+import org.apache.flink.table.factories.FactoryUtil;
+import org.apache.flink.table.factories.ModelProviderFactory;
+import org.apache.flink.table.ml.ModelProvider;
+import org.apache.flink.table.module.Module;
+import org.apache.flink.table.planner.calcite.FlinkContext;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.calcite.RexFieldVariable;
+import org.apache.flink.table.planner.calcite.RexModelCall;
import org.apache.flink.table.planner.expressions.RexNodeExpression;
import
org.apache.flink.table.planner.expressions.converter.CallExpressionConvertRule.ConvertContext;
import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable;
+import org.apache.flink.table.planner.utils.ShortcutUtils;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
+import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.TimeType;
import org.apache.flink.types.ColumnList;
@@ -69,6 +80,7 @@ import static
org.apache.flink.table.planner.typeutils.SymbolUtil.commonToCalcit
import static org.apache.flink.table.planner.utils.ShortcutUtils.unwrapContext;
import static
org.apache.flink.table.planner.utils.TimestampStringUtils.fromLocalDateTime;
import static
org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType;
+import static org.apache.flink.util.OptionalUtils.firstPresent;
/** Visit expression to generate {@link RexNode}. */
public class ExpressionConverter implements ExpressionVisitor<RexNode> {
@@ -256,12 +268,57 @@ public class ExpressionConverter implements
ExpressionVisitor<RexNode> {
local.getName(),
typeFactory.createFieldTypeFromLogicalType(
fromDataTypeToLogicalType(local.getOutputDataType())));
+ } else if (other instanceof ModelReferenceExpression) {
+ return visit((ModelReferenceExpression) other);
} else {
throw new UnsupportedOperationException(
other.getClass().getSimpleName() + ":" + other.toString());
}
}
+ public RexNode visit(ModelReferenceExpression modelRef) {
+ final ContextResolvedModel contextResolvedModel = modelRef.getModel();
+ final FlinkContext flinkContext =
ShortcutUtils.unwrapContext(relBuilder);
+
+ final Optional<ModelProviderFactory> factoryFromCatalog =
+ contextResolvedModel
+ .getCatalog()
+ .flatMap(Catalog::getFactory)
+ .map(
+ f ->
+ f instanceof ModelProviderFactory
+ ? (ModelProviderFactory) f
+ : null);
+
+ final Optional<ModelProviderFactory> factoryFromModule =
+
flinkContext.getModuleManager().getFactory(Module::getModelProviderFactory);
+
+ // Since the catalog is more specific, we give it
+ // precedence over a factory provided by any
+ // modules.
+ final ModelProviderFactory factory =
+ firstPresent(factoryFromCatalog,
factoryFromModule).orElse(null);
+
+ final ModelProvider modelProvider =
+ FactoryUtil.createModelProvider(
+ factory,
+ contextResolvedModel.getIdentifier(),
+ contextResolvedModel.getResolvedModel(),
+ flinkContext.getTableConfig(),
+ flinkContext.getClassLoader(),
+ contextResolvedModel.isTemporary());
+ final LogicalType modelOutputType =
+ contextResolvedModel
+ .getResolvedModel()
+ .getResolvedOutputSchema()
+ .toPhysicalRowDataType()
+ .getLogicalType();
+ final RelDataType modelOutputRelDataType =
+ typeFactory.buildRelNodeRowType((RowType) modelOutputType);
+
+ return new RexModelCall(modelOutputRelDataType, contextResolvedModel,
modelProvider);
+ }
+
public static List<RexNode> toRexNodes(ConvertContext context,
List<Expression> expr) {
return
expr.stream().map(context::toRexNode).collect(Collectors.toList());
}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java
index 08ebe72f9e4..3120e23359a 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java
@@ -547,7 +547,6 @@ public class QueryOperationConverter extends
QueryOperationDefaultVisitor<RelNod
dataStreamQueryOperation.getResolvedSchema(),
dataStreamQueryOperation.getIdentifier());
}
-
throw new TableException("Unknown table operation: " + other);
}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java
index d0b294ed5ea..0016ec9893d 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java
@@ -114,6 +114,10 @@ public class StreamPhysicalMLPredictTableFunction extends
SingleRel implements S
.item("rowType", getRowType());
}
+ public RexNode getMLPredictCall() {
+ return scan.getCall();
+ }
+
private MLPredictSpec buildMLPredictSpec(Map<String, String>
runtimeConfig) {
RexTableArgCall tableCall = extractOperand(operand -> operand
instanceof RexTableArgCall);
RexCall descriptorCall =
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/StreamNonDeterministicUpdatePlanVisitor.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/StreamNonDeterministicUpdatePlanVisitor.java
index 875a088ec59..c34dda9ea6b 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/StreamNonDeterministicUpdatePlanVisitor.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/optimize/StreamNonDeterministicUpdatePlanVisitor.java
@@ -43,6 +43,7 @@ import
org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalL
import
org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalLegacyTableSourceScan;
import
org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalLimit;
import
org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalLookupJoin;
+import
org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMLPredictTableFunction;
import
org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMatch;
import
org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMiniBatchAssigner;
import
org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMultiJoin;
@@ -198,6 +199,9 @@ public class StreamNonDeterministicUpdatePlanVisitor {
(StreamPhysicalWindowTableFunction) rel,
requireDeterminism);
} else if (rel instanceof StreamPhysicalDeltaJoin) {
return visitDeltaJoin((StreamPhysicalDeltaJoin) rel,
requireDeterminism);
+ } else if (rel instanceof StreamPhysicalMLPredictTableFunction) {
+ return visitMLPredictTableFunction(
+ (StreamPhysicalMLPredictTableFunction) rel,
requireDeterminism);
} else if (rel instanceof StreamPhysicalChangelogNormalize
|| rel instanceof StreamPhysicalDropUpdateBefore
|| rel instanceof StreamPhysicalMiniBatchAssigner
@@ -328,6 +332,16 @@ public class StreamNonDeterministicUpdatePlanVisitor {
}
}
+ private StreamPhysicalRel visitMLPredictTableFunction(
+ final StreamPhysicalMLPredictTableFunction predictTableFunction,
+ final ImmutableBitSet requireDeterminism) {
+ if (!inputInsertOnly(predictTableFunction) &&
!requireDeterminism.isEmpty()) {
+ throwNonDeterministicConditionError(
+ "ML_PREDICT", predictTableFunction.getMLPredictCall(),
predictTableFunction);
+ }
+ return transmitDeterminismRequirement(predictTableFunction,
NO_REQUIRED_DETERMINISM);
+ }
+
private StreamPhysicalRel visitCorrelate(
final StreamPhysicalCorrelateBase correlate, final ImmutableBitSet
requireDeterminism) {
if (inputInsertOnly(correlate) || requireDeterminism.isEmpty()) {
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/FunctionCallUtil.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/FunctionCallUtil.java
index 18091c4ae13..7369210019d 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/FunctionCallUtil.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/FunctionCallUtil.java
@@ -35,9 +35,11 @@ import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonSub
import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonTypeInfo;
import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonTypeName;
+import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
import java.util.HashMap;
import java.util.List;
@@ -45,10 +47,15 @@ import java.util.Map;
import java.util.Objects;
import static org.apache.calcite.sql.SqlKind.MAP_VALUE_CONSTRUCTOR;
+import static
org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType;
+import static
org.apache.flink.table.types.logical.LogicalTypeFamily.CHARACTER_STRING;
/** Common utils for function call, e.g. ML_PREDICT and Lookup Join. */
public abstract class FunctionCallUtil {
+ private static final String CONFIG_ERROR_MESSAGE =
+ "Config parameter should be a MAP data type consisting of String
literals.";
+
/** A field used as an equal condition when querying content from a
dimension table. */
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include =
JsonTypeInfo.As.PROPERTY, property = "type")
@JsonSubTypes({
@@ -225,18 +232,41 @@ public abstract class FunctionCallUtil {
for (int i = 0; i < mapConstructor.getOperands().size(); i += 2) {
RexNode keyNode = mapConstructor.getOperands().get(i);
RexNode valueNode = mapConstructor.getOperands().get(i + 1);
- // Both key and value should be string literals
- if (!(keyNode instanceof RexLiteral) || !(valueNode instanceof
RexLiteral)) {
- throw new ValidationException(
- "Config parameter should be a MAP data type consisting
String literals.");
- }
- String key = RexLiteral.stringValue(keyNode);
- String value = RexLiteral.stringValue(valueNode);
+ String key = getStringLiteral(keyNode);
+ String value = getStringLiteral(valueNode);
reducedConfig.put(key, value);
}
return reducedConfig;
}
+ private static String getStringLiteral(RexNode node) {
+ // Cast from string to string is used when Expressions.lit(Map(...)) is used as the config map
+ // from the Table API
+ if (node instanceof RexCall && node.getKind() == SqlKind.CAST) {
+ final RexCall castCall = (RexCall) node;
+ // Unwrap CAST if present
+ final RexNode castOperand = castCall.getOperands().get(0);
+ if (!(castOperand instanceof RexLiteral)) {
+ throw new ValidationException(CONFIG_ERROR_MESSAGE);
+ }
+ final RelDataType operandType = castOperand.getType();
+ if (!toLogicalType(operandType).is(CHARACTER_STRING)) {
+ throw new ValidationException(CONFIG_ERROR_MESSAGE);
+ }
+ final RelDataType castType = castCall.getType();
+ if (!toLogicalType(castType).is(CHARACTER_STRING)) {
+ throw new ValidationException(CONFIG_ERROR_MESSAGE);
+ }
+ return RexLiteral.stringValue(castOperand);
+ }
+ // Both key and value should be string literals
+ if (!(node instanceof RexLiteral)) {
+ throw new ValidationException(CONFIG_ERROR_MESSAGE);
+ }
+
+ return RexLiteral.stringValue(node);
+ }
+
public static String explainFunctionParam(FunctionParam param,
List<String> fieldNames) {
if (param instanceof Constant) {
return RelExplainUtil.literalToString(((Constant) param).literal);
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSemanticTest.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSemanticTest.java
index 73ef260d56b..aeb8c7028a8 100644
---
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSemanticTest.java
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSemanticTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.table.api;
import org.apache.flink.table.operations.QueryOperation;
import
org.apache.flink.table.planner.plan.nodes.exec.testutils.SemanticTestBase;
+import org.apache.flink.table.test.program.FailingTableApiTestStep;
import org.apache.flink.table.test.program.TableApiTestStep;
import org.apache.flink.table.test.program.TableTestProgram;
import org.apache.flink.table.test.program.TestStep;
@@ -57,7 +58,11 @@ public class QueryOperationSqlSemanticTest extends
SemanticTestBase {
QueryOperationTestPrograms.OVER_WINDOW_LAG,
QueryOperationTestPrograms.ACCESSING_NESTED_COLUMN,
QueryOperationTestPrograms.ROW_SEMANTIC_TABLE_PTF,
- QueryOperationTestPrograms.SET_SEMANTIC_TABLE_PTF);
+ QueryOperationTestPrograms.SET_SEMANTIC_TABLE_PTF,
+ QueryOperationTestPrograms.ML_PREDICT_MODEL_API,
+
QueryOperationTestPrograms.ASYNC_ML_PREDICT_TABLE_API_MAP_EXPRESSION_CONFIG,
+ QueryOperationTestPrograms.ASYNC_ML_PREDICT_MODEL_API,
+ QueryOperationTestPrograms.ML_PREDICT_ANON_MODEL_API);
}
@Override
@@ -65,6 +70,9 @@ public class QueryOperationSqlSemanticTest extends
SemanticTestBase {
if (testStep instanceof TableApiTestStep) {
final TableApiTestStep tableApiStep = (TableApiTestStep) testStep;
tableApiStep.applyAsSql(env).await();
+ } else if (testStep instanceof FailingTableApiTestStep) {
+ final FailingTableApiTestStep failingTableApiStep =
(FailingTableApiTestStep) testStep;
+ failingTableApiStep.applyAsSql(env);
} else {
super.runStep(testStep, env);
}
@@ -72,6 +80,6 @@ public class QueryOperationSqlSemanticTest extends
SemanticTestBase {
@Override
public EnumSet<TestKind> supportedRunSteps() {
- return EnumSet.of(TestKind.TABLE_API, TestKind.SQL);
+ return EnumSet.of(TestKind.TABLE_API, TestKind.SQL,
TestKind.FAILING_TABLE_API);
}
}
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSerializationTest.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSerializationTest.java
index fb35916dc24..639a3eb4e5a 100644
---
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSerializationTest.java
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSerializationTest.java
@@ -70,7 +70,10 @@ public class QueryOperationSqlSerializationTest implements
TableTestProgramRunne
QueryOperationTestPrograms.ACCESSING_NESTED_COLUMN,
QueryOperationTestPrograms.ROW_SEMANTIC_TABLE_PTF,
QueryOperationTestPrograms.SET_SEMANTIC_TABLE_PTF,
- QueryOperationTestPrograms.INLINE_FUNCTION_SERIALIZATION);
+ QueryOperationTestPrograms.INLINE_FUNCTION_SERIALIZATION,
+ QueryOperationTestPrograms.ML_PREDICT_MODEL_API,
+ QueryOperationTestPrograms.ASYNC_ML_PREDICT_MODEL_API,
+
QueryOperationTestPrograms.ASYNC_ML_PREDICT_TABLE_API_MAP_EXPRESSION_CONFIG);
}
@ParameterizedTest
@@ -139,6 +142,7 @@ public class QueryOperationSqlSerializationTest implements
TableTestProgramRunne
final Map<String, String> connectorOptions = new HashMap<>();
connectorOptions.put("connector", "values");
program.getSetupSourceTestSteps().forEach(s -> s.apply(env,
connectorOptions));
+ program.getSetupModelTestSteps().forEach(s -> s.apply(env,
Map.of("provider", "values")));
program.getSetupSinkTestSteps().forEach(s -> s.apply(env,
connectorOptions));
program.getSetupFunctionTestSteps().forEach(f -> f.apply(env));
program.getSetupSqlTestSteps().forEach(s -> s.apply(env));
@@ -149,6 +153,7 @@ public class QueryOperationSqlSerializationTest implements
TableTestProgramRunne
public EnumSet<TestKind> supportedSetupSteps() {
return EnumSet.of(
TestKind.CONFIG,
+ TestKind.MODEL,
TestKind.SQL,
TestKind.FUNCTION,
TestKind.SOURCE_WITH_DATA,
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java
index 53315269c2f..6640b4da778 100644
---
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java
@@ -19,9 +19,12 @@
package org.apache.flink.table.api;
import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.config.ExecutionConfigOptions;
+import
org.apache.flink.table.api.config.ExecutionConfigOptions.AsyncOutputMode;
import org.apache.flink.table.api.config.OptimizerConfigOptions;
import org.apache.flink.table.functions.ScalarFunction;
import org.apache.flink.table.operations.QueryOperation;
+import org.apache.flink.table.planner.factories.TestValuesModelFactory;
import
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.ChainedReceivingFunction;
import
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.ChainedSendingFunction;
import
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.RowSemanticTableFunction;
@@ -30,6 +33,7 @@ import
org.apache.flink.table.planner.runtime.utils.JavaUserDefinedTableFunction
import org.apache.flink.table.test.program.SinkTestStep;
import org.apache.flink.table.test.program.SourceTestStep;
import org.apache.flink.table.test.program.TableTestProgram;
+import org.apache.flink.types.ColumnList;
import org.apache.flink.types.Row;
import org.apache.flink.types.RowKind;
@@ -39,6 +43,7 @@ import java.time.LocalDateTime;
import java.time.LocalTime;
import java.time.ZoneId;
import java.util.Collections;
+import java.util.Map;
import static org.apache.flink.table.api.Expressions.$;
import static org.apache.flink.table.api.Expressions.UNBOUNDED_ROW;
@@ -49,6 +54,10 @@ import static org.apache.flink.table.api.Expressions.lag;
import static org.apache.flink.table.api.Expressions.lit;
import static org.apache.flink.table.api.Expressions.nullOf;
import static org.apache.flink.table.api.Expressions.row;
+import static
org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.ASYNC_MODEL;
+import static
org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.SIMPLE_FEATURES_SOURCE;
+import static
org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.SIMPLE_SINK;
+import static
org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.SYNC_MODEL;
import static
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.BASE_SINK_SCHEMA;
import static
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.BASIC_VALUES;
import static
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.KEYED_TIMED_BASE_SINK_SCHEMA;
@@ -1117,6 +1126,111 @@ public class QueryOperationTestPrograms {
"sink")
.build();
+ public static final TableTestProgram ML_PREDICT_MODEL_API =
+ TableTestProgram.of("ml-predict-model-api", "ml-predict using
model API")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .setupModel(SYNC_MODEL)
+ .setupTableSink(SIMPLE_SINK)
+ .runSql(
+ "SELECT `$$T_FUNC`.`id`, `$$T_FUNC`.`feature`,
`$$T_FUNC`.`category` FROM TABLE(\n"
+ + " ML_PREDICT((\n"
+ + " SELECT `$$T_SOURCE`.`id`,
`$$T_SOURCE`.`feature` FROM `default_catalog`.`default_database`.`features`
$$T_SOURCE\n"
+ + " ), MODEL
`default_catalog`.`default_database`.`chatgpt`, DESCRIPTOR(`feature`),
DEFAULT)\n"
+ + ") $$T_FUNC")
+ .runTableApi(
+ env ->
+ env.fromModel("chatgpt")
+ .predict(
+ env.from("features"),
ColumnList.of("feature")),
+ "sink")
+ .build();
+
+ public static final TableTestProgram ASYNC_ML_PREDICT_MODEL_API =
+ TableTestProgram.of("async-ml-predict-model-api", "async
ml-predict using model API")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .setupModel(ASYNC_MODEL)
+ .setupTableSink(SIMPLE_SINK)
+ .setupConfig(
+
ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE,
+ AsyncOutputMode.ALLOW_UNORDERED)
+ .runSql(
+ "SELECT `$$T_FUNC`.`id`, `$$T_FUNC`.`feature`,
`$$T_FUNC`.`category` FROM TABLE(\n"
+ + " ML_PREDICT((\n"
+ + " SELECT `$$T_SOURCE`.`id`,
`$$T_SOURCE`.`feature` FROM `default_catalog`.`default_database`.`features`
$$T_SOURCE\n"
+ + " ), MODEL
`default_catalog`.`default_database`.`chatgpt`, DESCRIPTOR(`feature`),
MAP['async', 'true'])\n"
+ + ") $$T_FUNC")
+ .runTableApi(
+ env ->
+ env.fromModel("chatgpt")
+ .predict(
+ env.from("features"),
+ ColumnList.of("feature"),
+ Map.of("async", "true")),
+ "sink")
+ .build();
+
+ public static final TableTestProgram ML_PREDICT_ANON_MODEL_API =
+ TableTestProgram.of(
+ "ml-predict-anonymous-model-api",
+ "ml-predict using anonymous model API")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .runFailingTableApi(
+ env ->
+ env.from(
+
ModelDescriptor.forProvider("values")
+ .inputSchema(
+
Schema.newBuilder()
+
.column(
+
"feature",
+
"STRING")
+
.build())
+ .outputSchema(
+
Schema.newBuilder()
+
.column(
+
"category",
+
"STRING")
+
.build())
+ .option(
+ "data-id",
+
TestValuesModelFactory
+
.registerData(
+
SYNC_MODEL
+
.data))
+ .build())
+ .predict(
+ env.from("features"),
ColumnList.of("feature")),
+ "sink",
+ ValidationException.class,
+ "Anonymous models cannot be serialized.")
+ .build();
+
+ public static final TableTestProgram
ASYNC_ML_PREDICT_TABLE_API_MAP_EXPRESSION_CONFIG =
+ TableTestProgram.of(
+ "async-ml-predict-table-api-map-expression-config",
+ "ml-predict in async mode using Table API and map
expression.")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .setupModel(ASYNC_MODEL)
+ .setupTableSink(SIMPLE_SINK)
+ .setupConfig(
+
ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE,
+
ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED)
+ .runSql(
+ "SELECT `$$T_FUNC`.`id`, `$$T_FUNC`.`feature`,
`$$T_FUNC`.`category` FROM TABLE(\n"
+ + " ML_PREDICT((\n"
+ + " SELECT `$$T_SOURCE`.`id`,
`$$T_SOURCE`.`feature` FROM `default_catalog`.`default_database`.`features`
$$T_SOURCE\n"
+ + " ), MODEL
`default_catalog`.`default_database`.`chatgpt`, DESCRIPTOR(`feature`),
MAP['async', 'true'])\n"
+ + ") $$T_FUNC")
+ .runTableApi(
+ env ->
+ env.fromCall(
+ "ML_PREDICT",
+
env.from("features").asArgument("INPUT"),
+
env.fromModel("chatgpt").asArgument("MODEL"),
+
descriptor("feature").asArgument("ARGS"),
+ Expressions.map("async",
"true").asArgument("CONFIG")),
+ "sink")
+ .build();
+
/**
* A function that will be used as an inline function in {@link
#INLINE_FUNCTION_SERIALIZATION}.
*/
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictSemanticTests.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictSemanticTests.java
new file mode 100644
index 00000000000..cb15ff4533d
--- /dev/null
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictSemanticTests.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.nodes.exec.stream;
+
+import
org.apache.flink.table.planner.plan.nodes.exec.testutils.SemanticTestBase;
+import org.apache.flink.table.test.program.TableTestProgram;
+
+import java.util.List;
+
+import static
org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.ASYNC_ML_PREDICT_MODEL_API;
+import static
org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.ASYNC_ML_PREDICT_TABLE_API;
+import static
org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.ASYNC_ML_PREDICT_TABLE_API_MAP_EXPRESSION_CONFIG;
+import static
org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.ML_PREDICT_ANON_MODEL_API;
+import static
org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.ML_PREDICT_MODEL_API;
+import static
org.apache.flink.table.planner.plan.nodes.exec.stream.MLPredictTestPrograms.SYNC_ML_PREDICT_TABLE_API;
+
+/** Semantic tests for {@link StreamExecMLPredictTableFunction} using Table
API. */
+public class MLPredictSemanticTests extends SemanticTestBase {
+
+ @Override
+ public List<TableTestProgram> programs() {
+ return List.of(
+ SYNC_ML_PREDICT_TABLE_API,
+ ASYNC_ML_PREDICT_TABLE_API,
+ ASYNC_ML_PREDICT_TABLE_API_MAP_EXPRESSION_CONFIG,
+ ML_PREDICT_MODEL_API,
+ ASYNC_ML_PREDICT_MODEL_API,
+ ML_PREDICT_ANON_MODEL_API);
+ }
+}
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictTestPrograms.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictTestPrograms.java
index 7d430e143b2..26d4903a3e4 100644
---
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictTestPrograms.java
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MLPredictTestPrograms.java
@@ -18,11 +18,17 @@
package org.apache.flink.table.planner.plan.nodes.exec.stream;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Expressions;
+import org.apache.flink.table.api.ModelDescriptor;
+import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.config.ExecutionConfigOptions;
+import org.apache.flink.table.planner.factories.TestValuesModelFactory;
import org.apache.flink.table.test.program.ModelTestStep;
import org.apache.flink.table.test.program.SinkTestStep;
import org.apache.flink.table.test.program.SourceTestStep;
import org.apache.flink.table.test.program.TableTestProgram;
+import org.apache.flink.types.ColumnList;
import org.apache.flink.types.Row;
import org.apache.flink.types.RowKind;
@@ -30,6 +36,8 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import static org.apache.flink.table.api.Expressions.descriptor;
+
/** Programs for verifying {@link StreamExecMLPredictTableFunction}. */
public class MLPredictTestPrograms {
@@ -48,7 +56,13 @@ public class MLPredictTestPrograms {
Row.ofKind(RowKind.INSERT, 4, "Mysql"),
Row.ofKind(RowKind.INSERT, 5, "Postgres")
};
- static final SourceTestStep FEATURES_TABLE =
+ public static final SourceTestStep SIMPLE_FEATURES_SOURCE =
+ SourceTestStep.newBuilder("features")
+ .addSchema(FEATURES_SCHEMA)
+ .producedValues(FEATURES_BEFORE_DATA)
+ .build();
+
+ static final SourceTestStep RESTORE_FEATURES_TABLE =
SourceTestStep.newBuilder("features")
.addSchema(FEATURES_SCHEMA)
.producedBeforeRestore(FEATURES_BEFORE_DATA)
@@ -61,7 +75,7 @@ public class MLPredictTestPrograms {
static final String[] MODEL_OUTPUT_SCHEMA = new String[] {"category
STRING"};
static final Map<Row, List<Row>> MODEL_DATA =
- new HashMap<Row, List<Row>>() {
+ new HashMap<>() {
{
put(
Row.ofKind(RowKind.INSERT, "Flink"),
@@ -82,14 +96,14 @@ public class MLPredictTestPrograms {
}
};
- static final ModelTestStep SYNC_MODEL =
+ public static final ModelTestStep SYNC_MODEL =
ModelTestStep.newBuilder("chatgpt")
.addInputSchema(MODEL_INPUT_SCHEMA)
.addOutputSchema(MODEL_OUTPUT_SCHEMA)
.data(MODEL_DATA)
.build();
- static final ModelTestStep ASYNC_MODEL =
+ public static final ModelTestStep ASYNC_MODEL =
ModelTestStep.newBuilder("chatgpt")
.addInputSchema(MODEL_INPUT_SCHEMA)
.addOutputSchema(MODEL_OUTPUT_SCHEMA)
@@ -102,7 +116,7 @@ public class MLPredictTestPrograms {
static final String[] SINK_SCHEMA =
new String[] {"id INT PRIMARY KEY NOT ENFORCED", "feature STRING",
"category STRING"};
- static final SinkTestStep SINK_TABLE =
+ static final SinkTestStep RESTORE_SINK_TABLE =
SinkTestStep.newBuilder("sink_t")
.addSchema(SINK_SCHEMA)
.consumedBeforeRestore(
@@ -112,22 +126,31 @@ public class MLPredictTestPrograms {
.consumedAfterRestore("+I[4, Mysql, Database]", "+I[5,
Postgres, Database]")
.build();
+ public static final SinkTestStep SIMPLE_SINK =
+ SinkTestStep.newBuilder("sink")
+ .addSchema(SINK_SCHEMA)
+ .consumedValues(
+ "+I[1, Flink, Big Data]",
+ "+I[2, Spark, Big Data]",
+ "+I[3, Hive, Big Data]")
+ .build();
+
//
-------------------------------------------------------------------------------------------
public static final TableTestProgram SYNC_ML_PREDICT =
TableTestProgram.of("sync-ml-predict", "ml-predict in sync mode.")
- .setupTableSource(FEATURES_TABLE)
+ .setupTableSource(RESTORE_FEATURES_TABLE)
.setupModel(SYNC_MODEL)
- .setupTableSink(SINK_TABLE)
+ .setupTableSink(RESTORE_SINK_TABLE)
.runSql(
"INSERT INTO sink_t SELECT * FROM ML_PREDICT(TABLE
features, MODEL chatgpt, DESCRIPTOR(feature))")
.build();
public static final TableTestProgram ASYNC_UNORDERED_ML_PREDICT =
TableTestProgram.of("async-unordered-ml-predict", "ml-predict in
async unordered mode.")
- .setupTableSource(FEATURES_TABLE)
+ .setupTableSource(RESTORE_FEATURES_TABLE)
.setupModel(ASYNC_MODEL)
- .setupTableSink(SINK_TABLE)
+ .setupTableSink(RESTORE_SINK_TABLE)
.setupConfig(
ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE,
ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED)
@@ -139,11 +162,148 @@ public class MLPredictTestPrograms {
TableTestProgram.of(
"sync-ml-predict-with-runtime-options",
"ml-predict in sync mode with runtime config.")
- .setupTableSource(FEATURES_TABLE)
+ .setupTableSource(RESTORE_FEATURES_TABLE)
.setupModel(ASYNC_MODEL)
- .setupTableSink(SINK_TABLE)
+ .setupTableSink(RESTORE_SINK_TABLE)
.runSql(
"INSERT INTO sink_t SELECT * FROM ML_PREDICT(TABLE
features, MODEL chatgpt, DESCRIPTOR(feature), MAP['async', 'false'])")
.build();
- ;
+
+ public static final TableTestProgram SYNC_ML_PREDICT_TABLE_API =
+ TableTestProgram.of(
+ "sync-ml-predict-table-api", "ml-predict in sync
mode using Table API.")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .setupModel(SYNC_MODEL)
+ .setupTableSink(SIMPLE_SINK)
+ .runTableApi(
+ env ->
+ env.fromCall(
+ "ML_PREDICT",
+
env.from("features").asArgument("INPUT"),
+
env.fromModel("chatgpt").asArgument("MODEL"),
+
descriptor("feature").asArgument("ARGS")),
+ "sink")
+ .build();
+
+ public static final TableTestProgram ASYNC_ML_PREDICT_TABLE_API =
+ TableTestProgram.of(
+ "async-ml-predict-table-api",
+ "ml-predict in async mode using Table API.")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .setupModel(ASYNC_MODEL)
+ .setupTableSink(SIMPLE_SINK)
+ .setupConfig(
+
ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE,
+
ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED)
+ .runTableApi(
+ env ->
+ env.fromCall(
+ "ML_PREDICT",
+
env.from("features").asArgument("INPUT"),
+
env.fromModel("chatgpt").asArgument("MODEL"),
+
descriptor("feature").asArgument("ARGS"),
+ Expressions.lit(
+ Map.of("async",
"true"),
+ DataTypes.MAP(
+
DataTypes.STRING(),
+
DataTypes.STRING())
+ .notNull())
+ .asArgument("CONFIG")),
+ "sink")
+ .build();
+
+ public static final TableTestProgram
ASYNC_ML_PREDICT_TABLE_API_MAP_EXPRESSION_CONFIG =
+ TableTestProgram.of(
+ "async-ml-predict-table-api-map-expression-config",
+ "ml-predict in async mode using Table API and map
expression.")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .setupModel(ASYNC_MODEL)
+ .setupTableSink(SIMPLE_SINK)
+ .setupConfig(
+
ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE,
+
ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED)
+ .runTableApi(
+ env ->
+ env.fromCall(
+ "ML_PREDICT",
+
env.from("features").asArgument("INPUT"),
+
env.fromModel("chatgpt").asArgument("MODEL"),
+
descriptor("feature").asArgument("ARGS"),
+ Expressions.map(
+ "async",
+ "true",
+
"max-concurrent-operations",
+ "10")
+ .asArgument("CONFIG")),
+ "sink")
+ .build();
+
+ public static final TableTestProgram ML_PREDICT_MODEL_API =
+ TableTestProgram.of("ml-predict-model-api", "ml-predict using
model API")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .setupModel(SYNC_MODEL)
+ .setupTableSink(SIMPLE_SINK)
+ .runTableApi(
+ env ->
+ env.fromModel("chatgpt")
+ .predict(
+ env.from("features"),
ColumnList.of("feature")),
+ "sink")
+ .build();
+
+ public static final TableTestProgram ASYNC_ML_PREDICT_MODEL_API =
+ TableTestProgram.of("async-ml-predict-model-api", "async
ml-predict using model API")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .setupModel(ASYNC_MODEL)
+ .setupTableSink(SIMPLE_SINK)
+ .setupConfig(
+
ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE,
+
ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED)
+ .runTableApi(
+ env ->
+ env.fromModel("chatgpt")
+ .predict(
+ env.from("features"),
+ ColumnList.of("feature"),
+ Map.of(
+ "async",
+ "true",
+
"max-concurrent-operations",
+ "10")),
+ "sink")
+ .build();
+
+ public static final TableTestProgram ML_PREDICT_ANON_MODEL_API =
+ TableTestProgram.of(
+ "ml-predict-anonymous-model-api",
+ "ml-predict using anonymous model API")
+ .setupTableSource(SIMPLE_FEATURES_SOURCE)
+ .setupTableSink(SIMPLE_SINK)
+ .runTableApi(
+ env ->
+ env.from(
+
ModelDescriptor.forProvider("values")
+ .inputSchema(
+
Schema.newBuilder()
+
.column(
+
"feature",
+
"STRING")
+
.build())
+ .outputSchema(
+
Schema.newBuilder()
+
.column(
+
"category",
+
"STRING")
+
.build())
+ .option(
+ "data-id",
+
TestValuesModelFactory
+
.registerData(
+
SYNC_MODEL
+
.data))
+ .build())
+ .predict(
+ env.from("features"),
ColumnList.of("feature")),
+ "sink")
+ .build();
}
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/testutils/SemanticTestBase.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/testutils/SemanticTestBase.java
index 9c65cec88c8..4ba4c3cc865 100644
---
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/testutils/SemanticTestBase.java
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/testutils/SemanticTestBase.java
@@ -22,10 +22,13 @@ import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.TableConfig;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.api.config.OptimizerConfigOptions;
+import org.apache.flink.table.planner.factories.TestValuesModelFactory;
import org.apache.flink.table.planner.factories.TestValuesTableFactory;
import org.apache.flink.table.test.program.ConfigOptionTestStep;
import org.apache.flink.table.test.program.FailingSqlTestStep;
+import org.apache.flink.table.test.program.FailingTableApiTestStep;
import org.apache.flink.table.test.program.FunctionTestStep;
+import org.apache.flink.table.test.program.ModelTestStep;
import org.apache.flink.table.test.program.SinkTestStep;
import org.apache.flink.table.test.program.SourceTestStep;
import org.apache.flink.table.test.program.SqlTestStep;
@@ -62,6 +65,7 @@ public abstract class SemanticTestBase implements
TableTestProgramRunner {
public EnumSet<TestKind> supportedSetupSteps() {
return EnumSet.of(
TestKind.CONFIG,
+ TestKind.MODEL,
TestKind.SOURCE_WITH_DATA,
TestKind.SINK_WITH_DATA,
TestKind.FUNCTION,
@@ -70,7 +74,8 @@ public abstract class SemanticTestBase implements
TableTestProgramRunner {
@Override
public EnumSet<TestKind> supportedRunSteps() {
- return EnumSet.of(TestKind.SQL, TestKind.FAILING_SQL,
TestKind.TABLE_API);
+ return EnumSet.of(
+ TestKind.SQL, TestKind.FAILING_SQL, TestKind.TABLE_API,
TestKind.FAILING_TABLE_API);
}
@AfterEach
@@ -145,6 +150,22 @@ public abstract class SemanticTestBase implements
TableTestProgramRunner {
sqlTestStep.apply(env);
}
break;
+ case FAILING_TABLE_API:
+ {
+ final FailingTableApiTestStep tableApiTestStep =
+ (FailingTableApiTestStep) testStep;
+ tableApiTestStep.apply(env);
+ }
+ break;
+ case MODEL:
+ {
+ final ModelTestStep modelTestStep = (ModelTestStep)
testStep;
+ final Map<String, String> options = new HashMap<>();
+ options.put("provider", "values");
+ options.put("data-id",
TestValuesModelFactory.registerData(modelTestStep.data));
+ modelTestStep.apply(env, options);
+ }
+ break;
case TABLE_API:
{
final TableApiTestStep apiTestStep = (TableApiTestStep)
testStep;
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/MLPredictTableFunctionTest.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/MLPredictTableFunctionTest.java
index 526b92413a9..1c20fdb7ed9 100644
---
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/MLPredictTableFunctionTest.java
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/MLPredictTableFunctionTest.java
@@ -314,7 +314,7 @@ public class MLPredictTableFunctionTest extends
TableTestBase {
+ "FROM TABLE(ML_PREDICT(TABLE
MyTable, MODEL MyModel, DESCRIPTOR(a, b), MAP['async', 'true', 'capacity',
CAST(-1 AS STRING)]))"))
.hasCauseInstanceOf(ValidationException.class)
.hasStackTraceContaining(
- "Config parameter should be a MAP data type consisting
String literals.");
+ "Config parameter should be a MAP data type consisting
of String literals.");
assertThatThrownBy(
() ->
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/MLPredictITCase.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/MLPredictITCase.java
index cb6fdfaf3f7..4bf0f44f465 100644
---
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/MLPredictITCase.java
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/MLPredictITCase.java
@@ -18,10 +18,13 @@
package org.apache.flink.table.planner.runtime.stream.table;
+import org.apache.flink.table.api.Model;
+import org.apache.flink.table.api.Table;
import org.apache.flink.table.planner.factories.TestValuesModelFactory;
import org.apache.flink.table.planner.factories.TestValuesTableFactory;
import
org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecMLPredictTableFunction;
import org.apache.flink.table.planner.runtime.utils.StreamingTestBase;
+import org.apache.flink.types.ColumnList;
import org.apache.flink.types.Row;
import org.apache.flink.util.CollectionUtil;
@@ -156,6 +159,23 @@ public class MLPredictITCase extends StreamingTestBase {
.containsExactlyInAnyOrder(Row.of(1L, "x1", 1, "z1"),
Row.of(2L, "x2", 2, "z2"));
}
+ @Test
+ public void testPredictTableApiWithView() {
+ Model model = tEnv().fromModel("m1");
+ Table table = tEnv().from("src");
+ tEnv().createView("view_src", model.predict(table,
ColumnList.of("id")));
+ List<Row> results =
+ CollectionUtil.iteratorToList(
+ tEnv().executeSql("select * from view_src").collect());
+ assertThatList(results)
+ .containsExactlyInAnyOrder(
+ Row.of(1L, 12, "Julian", "x1", 1, "z1"),
+ Row.of(2L, 15, "Hello", "x2", 2, "z2"),
+ Row.of(3L, 15, "Fabian", "x3", 3, "z3"),
+ Row.of(8L, 11, "Hello world", "x8", 8, "z8"),
+ Row.of(9L, 12, "Hello world!", "x9", 9, "z9"));
+ }
+
private void createScanTable(String tableName, List<Row> data) {
String dataId = TestValuesTableFactory.registerData(data);
tEnv().executeSql(
diff --git
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/TableEnvironmentTest.scala
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/TableEnvironmentTest.scala
index 6e9b9285bec..27b02c67898 100644
---
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/TableEnvironmentTest.scala
+++
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/TableEnvironmentTest.scala
@@ -24,7 +24,7 @@ import
org.apache.flink.core.testutils.FlinkAssertions.anyCauseMatches
import org.apache.flink.sql.parser.error.SqlValidateException
import org.apache.flink.streaming.api.environment.{LocalStreamEnvironment,
StreamExecutionEnvironment}
import org.apache.flink.table.api.bridge.scala._
-import org.apache.flink.table.api.internal.TableEnvironmentInternal
+import org.apache.flink.table.api.internal.{ModelImpl,
TableEnvironmentInternal}
import org.apache.flink.table.catalog._
import org.apache.flink.table.factories.{TableFactoryUtil,
TableSourceFactoryContextImpl}
import org.apache.flink.table.functions.TestGenericUDF
@@ -3250,6 +3250,28 @@ class TableEnvironmentTest {
checkData(util.Arrays.asList(Row.of("your_model")).iterator(),
tableResult3.collect())
}
+ @Test
+ def testGetNonExistModel(): Unit = {
+ assertThatThrownBy(() => tableEnv.fromModel("MyModel"))
+ .hasMessageContaining("Model `MyModel` was not found")
+ .isInstanceOf[ValidationException]
+ }
+
+ @Test
+ def testGetModel(): Unit = {
+ val inputSchema = Schema.newBuilder().column("feature",
DataTypes.STRING()).build()
+
+ val outputSchema = Schema.newBuilder().column("response",
DataTypes.DOUBLE()).build()
+ tableEnv.createModel(
+ "MyModel",
+ ModelDescriptor
+ .forProvider("openai")
+ .inputSchema(inputSchema)
+ .outputSchema(outputSchema)
+ .build())
+ assertThat(tableEnv.fromModel("MyModel")).isInstanceOf(classOf[ModelImpl])
+ }
+
@Test
def testTemporaryOperationListener(): Unit = {
val listener = new ListenerCatalog("listener_cat")
diff --git
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala
index be3d14227e7..b194d08637c 100644
---
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala
+++
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala
@@ -37,7 +37,7 @@ import
org.apache.flink.table.api.bridge.java.{StreamTableEnvironment => JavaStr
import org.apache.flink.table.api.bridge.scala.{StreamTableEnvironment =>
ScalaStreamTableEnv}
import org.apache.flink.table.api.config.{ExecutionConfigOptions,
OptimizerConfigOptions}
import
org.apache.flink.table.api.config.OptimizerConfigOptions.AdaptiveBroadcastJoinStrategy.NONE
-import org.apache.flink.table.api.internal.{StatementSetImpl,
TableEnvironmentImpl, TableEnvironmentInternal, TableImpl}
+import org.apache.flink.table.api.internal._
import org.apache.flink.table.api.typeutils.CaseClassTypeInfo
import org.apache.flink.table.catalog._
import org.apache.flink.table.connector.ChangelogMode