This is an automated email from the ASF dual-hosted git repository.
shengkai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 7d14d2a476d [FLINK-37792][table] Reuse LookupJoin operator to run ML_PREDICT
7d14d2a476d is described below
commit 7d14d2a476d4e75db04663732f61542a389524fe
Author: Shengkai <[email protected]>
AuthorDate: Wed Jun 4 09:26:53 2025 +0800
[FLINK-37792][table] Reuse LookupJoin operator to run ML_PREDICT
This closes #26630.
---
.../generated/execution_config_configuration.html | 19 ++
.../shortcodes/generated/sink_configuration.html | 18 ++
.../table/api/config/ExecutionConfigOptions.java | 31 ++
.../flink/table/planner/calcite/RexModelCall.java | 4 +
.../functions/inference/LookupCallContext.java | 16 +-
.../nodes/exec/common/CommonExecLookupJoin.java | 19 +-
.../plan/nodes/exec/spec/MLPredictSpec.java | 46 +++
.../planner/plan/nodes/exec/spec/ModelSpec.java | 45 +++
.../stream/StreamExecMLPredictTableFunction.java | 340 +++++++++++++++++++++
.../StreamPhysicalMLPredictTableFunction.java | 147 ++++++++-
.../StreamPhysicalMLPredictTableFunctionRule.java | 6 -
.../planner/plan/utils/ExecNodeMetadataUtil.java | 7 +-
.../table/planner/plan/utils/LookupJoinUtil.java | 24 ++
.../planner/codegen/LookupJoinCodeGenerator.scala | 54 ++--
.../codegen/calls/BridgingFunctionGenUtil.scala | 40 ++-
.../planner/factories/TestValuesModelFactory.java | 260 ++++++++++++++++
.../runtime/stream/table/AsyncMLPredictITCase.java | 302 ++++++++++++++++++
.../runtime/stream/table/MLPredictITCase.java | 173 +++++++++++
.../org.apache.flink.table.factories.Factory | 3 +-
.../ml/ModelPredictRuntimeProviderContext.java | 46 +++
20 files changed, 1531 insertions(+), 69 deletions(-)
diff --git
a/docs/layouts/shortcodes/generated/execution_config_configuration.html
b/docs/layouts/shortcodes/generated/execution_config_configuration.html
index 770f2f5b524..5de80d29a10 100644
--- a/docs/layouts/shortcodes/generated/execution_config_configuration.html
+++ b/docs/layouts/shortcodes/generated/execution_config_configuration.html
@@ -26,6 +26,25 @@
<td>Duration</td>
<td>The async timeout for the asynchronous operation to
complete.</td>
</tr>
+        <tr>
+            <td><h5>table.exec.async-ml-predict.buffer-capacity</h5><br> <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span></td>
+            <td style="word-wrap: break-word;">10</td>
+            <td>Integer</td>
+            <td>The max number of async I/O operations that the async ML_PREDICT can trigger.</td>
+        </tr>
+        <tr>
+            <td><h5>table.exec.async-ml-predict.output-mode</h5><br> <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span></td>
+            <td style="word-wrap: break-word;">ORDERED</td>
+            <td><p>Enum</p></td>
+            <td>Output mode for async ML_PREDICT, which describes whether the output should be ordered. The supported options are: ALLOW_UNORDERED means the operator emits the result when the execution finishes; the planner will attempt to use ALLOW_UNORDERED when it doesn't affect the correctness of the results.
+ORDERED ensures that the operator emits the result in the same order as the data enters it. This is the default.<br /><br />Possible values:<ul><li>"ORDERED"</li><li>"ALLOW_UNORDERED"</li></ul></td>
+        </tr>
+        <tr>
+            <td><h5>table.exec.async-ml-predict.timeout</h5><br> <span class="label label-primary">Batch</span> <span class="label label-primary">Streaming</span></td>
+            <td style="word-wrap: break-word;">3 min</td>
+            <td>Duration</td>
+            <td>The async timeout for the asynchronous operation to complete. If the operation does not complete before the deadline, a timeout exception will be thrown to indicate the error.</td>
+        </tr>
<tr>
<td><h5>table.exec.async-scalar.buffer-capacity</h5><br> <span
class="label label-primary">Streaming</span></td>
<td style="word-wrap: break-word;">10</td>
diff --git a/docs/layouts/shortcodes/generated/sink_configuration.html
b/docs/layouts/shortcodes/generated/sink_configuration.html
new file mode 100644
index 00000000000..3ecdb2d4e56
--- /dev/null
+++ b/docs/layouts/shortcodes/generated/sink_configuration.html
@@ -0,0 +1,18 @@
+<table class="configuration table table-bordered">
+ <thead>
+ <tr>
+ <th class="text-left" style="width: 20%">Key</th>
+ <th class="text-left" style="width: 15%">Default</th>
+ <th class="text-left" style="width: 10%">Type</th>
+ <th class="text-left" style="width: 55%">Description</th>
+ </tr>
+ </thead>
+ <tbody>
+ <tr>
+ <td><h5>sink.committer.retries</h5></td>
+ <td style="word-wrap: break-word;">10</td>
+ <td>Integer</td>
+ <td>The number of retries a Flink application attempts for
committable operations (such as transactions) on retriable errors, as specified
by the sink connector, before Flink fails and potentially restarts.</td>
+ </tr>
+ </tbody>
+</table>
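
As an aside, a minimal sketch of raising this retry budget, assuming the
"sink.committer.retries" key documented above is honored when set in the job's
Configuration (the surrounding job setup is hypothetical):

    import org.apache.flink.configuration.Configuration;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

    public class CommitterRetriesExample {
        public static void main(String[] args) {
            Configuration conf = new Configuration();
            // Allow up to 20 commit attempts on retriable errors before failing the job.
            conf.setString("sink.committer.retries", "20");
            StreamExecutionEnvironment env =
                    StreamExecutionEnvironment.getExecutionEnvironment(conf);
        }
    }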
diff --git
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/ExecutionConfigOptions.java
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/ExecutionConfigOptions.java
index a1847e93b9e..9c32c7bb5a2 100644
---
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/ExecutionConfigOptions.java
+++
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/ExecutionConfigOptions.java
@@ -443,6 +443,37 @@ public class ExecutionConfigOptions {
"The max number of async retry attempts to make
before task "
+ "execution is failed.");
+ // ------------------------------------------------------------------------
+ // Async ML_PREDICT Options
+ // ------------------------------------------------------------------------
+
+    @Documentation.TableOption(execMode = Documentation.ExecMode.BATCH_STREAMING)
+    public static final ConfigOption<Integer> TABLE_EXEC_ASYNC_ML_PREDICT_BUFFER_CAPACITY =
+            key("table.exec.async-ml-predict.buffer-capacity")
+                    .intType()
+                    .defaultValue(10)
+                    .withDescription(
+                            "The max number of async I/O operations that the async ML_PREDICT can trigger.");
+
+    @Documentation.TableOption(execMode = Documentation.ExecMode.BATCH_STREAMING)
+    public static final ConfigOption<Duration> TABLE_EXEC_ASYNC_ML_PREDICT_TIMEOUT =
+            key("table.exec.async-ml-predict.timeout")
+                    .durationType()
+                    .defaultValue(Duration.ofMinutes(3))
+                    .withDescription(
+                            "The async timeout for the asynchronous operation to complete. If the operation does not complete before the deadline, a timeout exception will be thrown to indicate the error.");
+
+    @Documentation.TableOption(execMode = Documentation.ExecMode.BATCH_STREAMING)
+    public static final ConfigOption<AsyncOutputMode> TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE =
+            key("table.exec.async-ml-predict.output-mode")
+                    .enumType(AsyncOutputMode.class)
+                    .defaultValue(AsyncOutputMode.ORDERED)
+                    .withDescription(
+                            "Output mode for async ML_PREDICT, which describes whether the output should be ordered. The supported options are: "
+                                    + "ALLOW_UNORDERED means the operator emits the result when the execution finishes; the planner will attempt to use ALLOW_UNORDERED when it doesn't affect "
+                                    + "the correctness of the results.\n"
+                                    + "ORDERED ensures that the operator emits the result in the same order as the data enters it. This is the default.");
+
// ------------------------------------------------------------------------
// MiniBatch Options
// ------------------------------------------------------------------------
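
For reference, a minimal sketch of overriding the three new options per job;
the TableEnvironment setup is illustrative, the option constants are the ones
introduced above:

    import java.time.Duration;

    import org.apache.flink.table.api.EnvironmentSettings;
    import org.apache.flink.table.api.TableEnvironment;
    import org.apache.flink.table.api.config.ExecutionConfigOptions;

    public class MLPredictOptionsExample {
        public static void main(String[] args) {
            TableEnvironment tEnv =
                    TableEnvironment.create(EnvironmentSettings.inStreamingMode());
            // Allow more in-flight async requests and a shorter timeout than the defaults.
            tEnv.getConfig()
                    .set(ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_BUFFER_CAPACITY, 100);
            tEnv.getConfig()
                    .set(ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_TIMEOUT,
                            Duration.ofSeconds(30));
            // Let results be emitted as soon as they are ready when ordering does not
            // affect correctness.
            tEnv.getConfig()
                    .set(ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE,
                            ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED);
        }
    }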
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/calcite/RexModelCall.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/calcite/RexModelCall.java
index 2e71d217f70..dc348c25a06 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/calcite/RexModelCall.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/calcite/RexModelCall.java
@@ -47,6 +47,10 @@ public class RexModelCall extends RexCall {
this.modelProvider = modelProvider;
}
+ public ContextResolvedModel getContextResolvedModel() {
+ return contextResolvedModel;
+ }
+
public ModelProvider getModelProvider() {
return modelProvider;
}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/LookupCallContext.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/LookupCallContext.java
index e96f10e9e39..2a56ac670c6 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/LookupCallContext.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/LookupCallContext.java
@@ -32,7 +32,6 @@ import org.apache.calcite.rex.RexLiteral;
import java.util.AbstractList;
import java.util.List;
-import java.util.Map;
import java.util.Optional;
import static
org.apache.flink.table.functions.UserDefinedFunctionHelper.generateInlineFunctionName;
@@ -43,9 +42,7 @@ import static
org.apache.flink.table.types.utils.TypeConversions.fromLogicalToDa
@Internal
public class LookupCallContext extends AbstractSqlCallContext {
- private final Map<Integer, LookupKey> lookupKeys;
-
- private final int[] lookupKeyOrder;
+ private final List<LookupKey> lookupKeys;
private final List<DataType> argumentDataTypes;
@@ -55,14 +52,12 @@ public class LookupCallContext extends
AbstractSqlCallContext {
DataTypeFactory dataTypeFactory,
UserDefinedFunction function,
LogicalType inputType,
- Map<Integer, LookupKey> lookupKeys,
- int[] lookupKeyOrder,
+ List<LookupKey> lookupKeys,
LogicalType lookupType) {
super(dataTypeFactory, function, generateInlineFunctionName(function),
false);
this.lookupKeys = lookupKeys;
- this.lookupKeyOrder = lookupKeyOrder;
this.argumentDataTypes =
- new AbstractList<DataType>() {
+ new AbstractList<>() {
@Override
public DataType get(int index) {
final LookupKey key = getKey(index);
@@ -79,7 +74,7 @@ public class LookupCallContext extends AbstractSqlCallContext
{
@Override
public int size() {
- return lookupKeyOrder.length;
+ return lookupKeys.size();
}
};
this.outputDataType = fromLogicalToDataType(lookupType);
@@ -123,7 +118,6 @@ public class LookupCallContext extends
AbstractSqlCallContext {
//
--------------------------------------------------------------------------------------------
private LookupKey getKey(int pos) {
- final int index = lookupKeyOrder[pos];
- return lookupKeys.get(index);
+ return lookupKeys.get(pos);
}
}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java
index f30268510fb..8b3e0504c50 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java
@@ -92,11 +92,13 @@ import org.apache.commons.lang3.StringUtils;
import javax.annotation.Nullable;
+import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.stream.Collectors;
import static
org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType;
import static
org.apache.flink.table.planner.utils.ShortcutUtils.unwrapTypeFactory;
@@ -415,6 +417,11 @@ public abstract class CommonExecLookupJoin extends
ExecNodeBase<RowData> {
DataTypeFactory dataTypeFactory =
ShortcutUtils.unwrapContext(relBuilder).getCatalogManager().getDataTypeFactory();
+ List<LookupJoinUtil.LookupKey> convertedKeys =
+
Arrays.stream(LookupJoinUtil.getOrderedLookupKeys(allLookupKeys.keySet()))
+ .mapToObj(allLookupKeys::get)
+ .collect(Collectors.toList());
+
LookupJoinCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData,
Object>>
generatedFuncWithType =
LookupJoinCodeGenerator.generateAsyncLookupFunction(
@@ -424,8 +431,7 @@ public abstract class CommonExecLookupJoin extends
ExecNodeBase<RowData> {
inputRowType,
tableSourceRowType,
resultRowType,
- allLookupKeys,
-
LookupJoinUtil.getOrderedLookupKeys(allLookupKeys.keySet()),
+ convertedKeys,
asyncLookupFunction,
StringUtils.join(temporalTable.getQualifiedName(), "."));
@@ -551,8 +557,10 @@ public abstract class CommonExecLookupJoin extends
ExecNodeBase<RowData> {
DataTypeFactory dataTypeFactory =
ShortcutUtils.unwrapContext(relBuilder).getCatalogManager().getDataTypeFactory();
- int[] orderedLookupKeys =
LookupJoinUtil.getOrderedLookupKeys(allLookupKeys.keySet());
-
+ List<LookupJoinUtil.LookupKey> convertedKeys =
+
Arrays.stream(LookupJoinUtil.getOrderedLookupKeys(allLookupKeys.keySet()))
+ .mapToObj(allLookupKeys::get)
+ .collect(Collectors.toList());
GeneratedFunction<FlatMapFunction<RowData, RowData>> generatedFetcher =
LookupJoinCodeGenerator.generateSyncLookupFunction(
config,
@@ -561,8 +569,7 @@ public abstract class CommonExecLookupJoin extends
ExecNodeBase<RowData> {
inputRowType,
tableSourceRowType,
resultRowType,
- allLookupKeys,
- orderedLookupKeys,
+ convertedKeys,
syncLookupFunction,
StringUtils.join(temporalTable.getQualifiedName(),
"."),
isObjectReuseEnabled);
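
The conversion above replaces the former Map<Integer, LookupKey> plus int[]
order pair with a single ordered list. The same pattern as a generic helper,
for illustration only (toOrderedList is a hypothetical name, not part of the
patch):

    import java.util.Arrays;
    import java.util.List;
    import java.util.Map;
    import java.util.stream.Collectors;

    public class OrderedKeys {
        // Each position in `order` selects one map entry, so the resulting list's
        // iteration order encodes what the int[] used to carry separately.
        static <V> List<V> toOrderedList(Map<Integer, V> keyedValues, int[] order) {
            return Arrays.stream(order).mapToObj(keyedValues::get).collect(Collectors.toList());
        }
    }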
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/MLPredictSpec.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/MLPredictSpec.java
new file mode 100644
index 00000000000..50dd2b7cb33
--- /dev/null
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/MLPredictSpec.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.nodes.exec.spec;
+
+import org.apache.flink.table.planner.plan.utils.LookupJoinUtil;
+
+import java.util.List;
+import java.util.Map;
+
+/** Spec to describe {@code ML_PREDICT}. */
+public class MLPredictSpec {
+
+ private final List<LookupJoinUtil.LookupKey> features;
+
+ private final Map<String, String> runtimeConfig;
+
+ public MLPredictSpec(
+ List<LookupJoinUtil.LookupKey> features, Map<String, String>
runtimeConfig) {
+ this.features = features;
+ this.runtimeConfig = runtimeConfig;
+ }
+
+ public List<LookupJoinUtil.LookupKey> getFeatures() {
+ return features;
+ }
+
+ public Map<String, String> getRuntimeConfig() {
+ return runtimeConfig;
+ }
+}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/ModelSpec.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/ModelSpec.java
new file mode 100644
index 00000000000..31b78b9b39d
--- /dev/null
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/ModelSpec.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.nodes.exec.spec;
+
+import org.apache.flink.table.catalog.ContextResolvedModel;
+import org.apache.flink.table.ml.ModelProvider;
+
+/** Spec to describe model. */
+public class ModelSpec {
+
+ private final ContextResolvedModel contextResolvedModel;
+ private ModelProvider provider;
+
+ public ModelSpec(ContextResolvedModel contextResolvedModel) {
+ this.contextResolvedModel = contextResolvedModel;
+ }
+
+ public ContextResolvedModel getContextResolvedModel() {
+ return contextResolvedModel;
+ }
+
+ public void setModelProvider(ModelProvider provider) {
+ this.provider = provider;
+ }
+
+ public ModelProvider getModelProvider() {
+ return provider;
+ }
+}
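
A hypothetical sketch of wiring the two specs together, mirroring what
StreamPhysicalMLPredictTableFunction#buildModelSpec does later in this patch;
`resolvedModel` and `provider` are assumed to come from a RexModelCall:

    // The model spec carries the resolved catalog model plus its provider.
    ModelSpec modelSpec = new ModelSpec(resolvedModel);
    modelSpec.setModelProvider(provider);
    // The predict spec lists the feature columns (here: input field 0) and the
    // per-query runtime config (empty in this sketch).
    MLPredictSpec predictSpec =
            new MLPredictSpec(
                    Collections.singletonList(new LookupJoinUtil.FieldRefLookupKey(0)),
                    Collections.emptyMap());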
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java
new file mode 100644
index 00000000000..e0cb0fc5a2d
--- /dev/null
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java
@@ -0,0 +1,340 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.nodes.exec.stream;
+
+import org.apache.flink.FlinkVersion;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.dag.Transformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.datastream.AsyncDataStream;
+import org.apache.flink.streaming.api.functions.async.AsyncFunction;
+import org.apache.flink.streaming.api.operators.ProcessOperator;
+import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.streaming.api.operators.async.AsyncWaitOperatorFactory;
+import org.apache.flink.streaming.api.transformations.OneInputTransformation;
+import org.apache.flink.streaming.api.transformations.PartitionTransformation;
+import
org.apache.flink.streaming.runtime.partitioner.KeyGroupStreamPartitioner;
+import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.catalog.DataTypeFactory;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.conversion.DataStructureConverter;
+import org.apache.flink.table.data.conversion.DataStructureConverters;
+import org.apache.flink.table.functions.AsyncPredictFunction;
+import org.apache.flink.table.functions.PredictFunction;
+import org.apache.flink.table.functions.UserDefinedFunction;
+import org.apache.flink.table.ml.AsyncPredictRuntimeProvider;
+import org.apache.flink.table.ml.ModelProvider;
+import org.apache.flink.table.ml.PredictRuntimeProvider;
+import org.apache.flink.table.planner.calcite.FlinkContext;
+import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
+import org.apache.flink.table.planner.codegen.FilterCodeGenerator;
+import org.apache.flink.table.planner.codegen.LookupJoinCodeGenerator;
+import org.apache.flink.table.planner.delegation.PlannerBase;
+import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
+import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
+import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig;
+import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeContext;
+import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeMetadata;
+import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
+import
org.apache.flink.table.planner.plan.nodes.exec.MultipleTransformationTranslator;
+import org.apache.flink.table.planner.plan.nodes.exec.spec.MLPredictSpec;
+import org.apache.flink.table.planner.plan.nodes.exec.spec.ModelSpec;
+import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil;
+import org.apache.flink.table.planner.plan.utils.KeySelectorUtil;
+import org.apache.flink.table.planner.plan.utils.LookupJoinUtil;
+import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
+import org.apache.flink.table.runtime.collector.ListenableCollector;
+import org.apache.flink.table.runtime.collector.TableFunctionResultFuture;
+import
org.apache.flink.table.runtime.functions.ml.ModelPredictRuntimeProviderContext;
+import org.apache.flink.table.runtime.generated.GeneratedCollector;
+import org.apache.flink.table.runtime.generated.GeneratedFunction;
+import org.apache.flink.table.runtime.generated.GeneratedResultFuture;
+import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
+import
org.apache.flink.table.runtime.operators.join.lookup.AsyncLookupJoinRunner;
+import org.apache.flink.table.runtime.operators.join.lookup.LookupJoinRunner;
+import org.apache.flink.table.runtime.typeutils.InternalSerializers;
+import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
+import org.apache.flink.table.types.logical.RowType;
+
+import javax.annotation.Nullable;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Optional;
+import java.util.stream.IntStream;
+
+/** Stream {@link ExecNode} for {@code ML_PREDICT}. */
+@ExecNodeMetadata(
+ name = "stream-exec-ml-predict-table-function",
+ version = 1,
+ producedTransformations =
StreamExecMLPredictTableFunction.ML_PREDICT_TRANSFORMATION,
+ minPlanVersion = FlinkVersion.V2_1,
+ minStateVersion = FlinkVersion.V2_1)
+public class StreamExecMLPredictTableFunction extends ExecNodeBase<RowData>
+ implements MultipleTransformationTranslator<RowData>,
StreamExecNode<RowData> {
+
+ public static final String PARTITIONER_TRANSFORMATION = "partitioner";
+
+ public static final String ML_PREDICT_TRANSFORMATION =
"ml-predict-table-function";
+
+ private final MLPredictSpec mlPredictSpec;
+ private final ModelSpec modelSpec;
+ private final @Nullable LookupJoinUtil.AsyncLookupOptions
asyncLookupOptions;
+ private final @Nullable int[] inputUpsertKeys;
+
+ public StreamExecMLPredictTableFunction(
+ ReadableConfig persistedConfig,
+ MLPredictSpec mlPredictSpec,
+ ModelSpec modelSpec,
+ @Nullable LookupJoinUtil.AsyncLookupOptions asyncLookupOptions,
+ @Nullable int[] inputUpsertKeys,
+ InputProperty inputProperty,
+ RowType outputType,
+ String description) {
+ super(
+ ExecNodeContext.newNodeId(),
+
ExecNodeContext.newContext(StreamExecMLPredictTableFunction.class),
+ persistedConfig,
+ Collections.singletonList(inputProperty),
+ outputType,
+ description);
+ this.mlPredictSpec = mlPredictSpec;
+ this.modelSpec = modelSpec;
+ this.asyncLookupOptions = asyncLookupOptions;
+ this.inputUpsertKeys = inputUpsertKeys;
+ }
+
+ @Override
+ protected Transformation<RowData> translateToPlanInternal(
+ PlannerBase planner, ExecNodeConfig config) {
+ Transformation<RowData> inputTransformation =
+ (Transformation<RowData>)
getInputEdges().get(0).translateToPlan(planner);
+
+ ModelProvider provider = modelSpec.getModelProvider();
+ boolean async = asyncLookupOptions != null;
+ UserDefinedFunction predictFunction = findModelFunction(provider,
async);
+ FlinkContext context = planner.getFlinkContext();
+ DataTypeFactory dataTypeFactory =
context.getCatalogManager().getDataTypeFactory();
+
+ RowType inputType = (RowType) getInputEdges().get(0).getOutputType();
+ RowType modelOutputType =
+ (RowType)
+ modelSpec
+ .getContextResolvedModel()
+ .getResolvedModel()
+ .getResolvedOutputSchema()
+ .toPhysicalRowDataType()
+ .getLogicalType();
+ return async
+ ? createAsyncModelPredict(
+ inputTransformation,
+ config,
+ planner.getFlinkContext().getClassLoader(),
+ dataTypeFactory,
+ inputType,
+ modelOutputType,
+ (RowType) getOutputType(),
+ (AsyncPredictFunction) predictFunction)
+ : createModelPredict(
+ inputTransformation,
+ config,
+ planner.getFlinkContext().getClassLoader(),
+ dataTypeFactory,
+ inputType,
+ modelOutputType,
+ (RowType) getOutputType(),
+ (PredictFunction) predictFunction);
+ }
+
+ private Transformation<RowData> createModelPredict(
+ Transformation<RowData> inputTransformation,
+ ExecNodeConfig config,
+ ClassLoader classLoader,
+ DataTypeFactory dataTypeFactory,
+ RowType inputRowType,
+ RowType modelOutputType,
+ RowType resultRowType,
+ PredictFunction predictFunction) {
+ GeneratedFunction<FlatMapFunction<RowData, RowData>> generatedFetcher =
+ LookupJoinCodeGenerator.generateSyncLookupFunction(
+ config,
+ classLoader,
+ dataTypeFactory,
+ inputRowType,
+ modelOutputType,
+ resultRowType,
+ mlPredictSpec.getFeatures(),
+ predictFunction,
+ "MLPredict",
+ true);
+ GeneratedCollector<ListenableCollector<RowData>> generatedCollector =
+ LookupJoinCodeGenerator.generateCollector(
+ new CodeGeneratorContext(config, classLoader),
+ inputRowType,
+ modelOutputType,
+ (RowType) getOutputType(),
+ JavaScalaConversionUtil.toScala(Optional.empty()),
+ JavaScalaConversionUtil.toScala(Optional.empty()),
+ true);
+ LookupJoinRunner mlPredictRunner =
+ new LookupJoinRunner(
+ generatedFetcher,
+ generatedCollector,
+ FilterCodeGenerator.generateFilterCondition(
+ config, classLoader, null, inputRowType),
+ false,
+ modelOutputType.getFieldCount());
+ SimpleOperatorFactory<RowData> operatorFactory =
+ SimpleOperatorFactory.of(new
ProcessOperator<>(mlPredictRunner));
+ return ExecNodeUtil.createOneInputTransformation(
+ inputTransformation,
+ createTransformationMeta(ML_PREDICT_TRANSFORMATION, config),
+ operatorFactory,
+ InternalTypeInfo.of(getOutputType()),
+ inputTransformation.getParallelism(),
+ false);
+ }
+
+ @SuppressWarnings("unchecked")
+ private Transformation<RowData> createAsyncModelPredict(
+ Transformation<RowData> inputTransformation,
+ ExecNodeConfig config,
+ ClassLoader classLoader,
+ DataTypeFactory dataTypeFactory,
+ RowType inputRowType,
+ RowType modelOutputType,
+ RowType resultRowType,
+ AsyncPredictFunction asyncPredictFunction) {
+
LookupJoinCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData,
Object>>
+ generatedFuncWithType =
+ LookupJoinCodeGenerator.generateAsyncLookupFunction(
+ config,
+ classLoader,
+ dataTypeFactory,
+ inputRowType,
+ modelOutputType,
+ resultRowType,
+ mlPredictSpec.getFeatures(),
+ asyncPredictFunction,
+ "AsyncMLPredict");
+
+ GeneratedResultFuture<TableFunctionResultFuture<RowData>>
generatedResultFuture =
+ LookupJoinCodeGenerator.generateTableAsyncCollector(
+ config,
+ classLoader,
+ "TableFunctionResultFuture",
+ inputRowType,
+ modelOutputType,
+ JavaScalaConversionUtil.toScala(Optional.empty()));
+
+ DataStructureConverter<?, ?> fetcherConverter =
+
DataStructureConverters.getConverter(generatedFuncWithType.dataType());
+ AsyncFunction<RowData, RowData> asyncFunc =
+ new AsyncLookupJoinRunner(
+ generatedFuncWithType.tableFunc(),
+ (DataStructureConverter<RowData, Object>)
fetcherConverter,
+ generatedResultFuture,
+ FilterCodeGenerator.generateFilterCondition(
+ config, classLoader, null, inputRowType),
+ InternalSerializers.create(modelOutputType),
+ false,
+ asyncLookupOptions.asyncBufferCapacity);
+ if (asyncLookupOptions.asyncOutputMode ==
AsyncDataStream.OutputMode.UNORDERED) {
+ // The input stream is insert-only.
+ return ExecNodeUtil.createOneInputTransformation(
+ inputTransformation,
+ createTransformationMeta(ML_PREDICT_TRANSFORMATION,
config),
+ new AsyncWaitOperatorFactory<>(
+ asyncFunc,
+ asyncLookupOptions.asyncTimeout,
+ asyncLookupOptions.asyncBufferCapacity,
+ asyncLookupOptions.asyncOutputMode),
+ InternalTypeInfo.of(getOutputType()),
+ inputTransformation.getParallelism(),
+ false);
+ } else if (asyncLookupOptions.asyncOutputMode ==
AsyncDataStream.OutputMode.ORDERED) {
+            // The input stream is a CDC stream.
+            int[] shuffleKeys = inputUpsertKeys;
+            // If no upsert key is specified, use the whole row.
+ if (shuffleKeys == null || shuffleKeys.length == 0) {
+ shuffleKeys = IntStream.range(0,
inputRowType.getFieldCount()).toArray();
+ }
+ Arrays.sort(shuffleKeys);
+
+ // Shuffle the data
+ RowDataKeySelector keySelector =
+ KeySelectorUtil.getRowDataSelector(
+ classLoader, shuffleKeys,
InternalTypeInfo.of(inputRowType));
+ final KeyGroupStreamPartitioner<RowData, RowData> partitioner =
+ new KeyGroupStreamPartitioner<>(
+ keySelector,
+
KeyGroupRangeAssignment.DEFAULT_LOWER_BOUND_MAX_PARALLELISM);
+ Transformation<RowData> partitionedTransform =
+ new PartitionTransformation<>(inputTransformation,
partitioner);
+ createTransformationMeta(
+ PARTITIONER_TRANSFORMATION, "Partitioner",
"Partitioner", config)
+ .fill(partitionedTransform);
+
+            // Add the operator. The async operator emits data in the same order as the
+            // data enters the operator.
+ OneInputTransformation<RowData, RowData> transformation =
+ ExecNodeUtil.createOneInputTransformation(
+ partitionedTransform,
+
createTransformationMeta(ML_PREDICT_TRANSFORMATION, config),
+ new AsyncWaitOperatorFactory<>(
+ asyncFunc,
+ asyncLookupOptions.asyncTimeout,
+ asyncLookupOptions.asyncBufferCapacity,
+ asyncLookupOptions.asyncOutputMode),
+ InternalTypeInfo.of(getOutputType()),
+ inputTransformation.getParallelism(),
+ false);
+ transformation.setStateKeySelector(keySelector);
+ transformation.setStateKeyType(keySelector.getProducedType());
+ return transformation;
+ } else {
+ throw new TableException(
+ String.format("Unknown output mode: %s.",
asyncLookupOptions.asyncOutputMode));
+ }
+ }
+
+ private UserDefinedFunction findModelFunction(ModelProvider provider,
boolean async) {
+ ModelPredictRuntimeProviderContext context =
+ new ModelPredictRuntimeProviderContext(
+ modelSpec.getContextResolvedModel().getResolvedModel(),
+
Configuration.fromMap(mlPredictSpec.getRuntimeConfig()));
+ if (async) {
+ if (provider instanceof AsyncPredictRuntimeProvider) {
+ return ((AsyncPredictRuntimeProvider)
provider).createAsyncPredictFunction(context);
+ }
+ } else {
+ if (provider instanceof PredictRuntimeProvider) {
+ return ((PredictRuntimeProvider)
provider).createPredictFunction(context);
+ }
+ }
+
+        throw new TableException(
+                "The planner requires "
+                        + (async ? "an async" : "a sync")
+                        + " model function, but the ModelProvider "
+                        + "does not offer a valid model function.");
+ }
+}
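
A minimal usage sketch of what this exec node runs; the model `m` and table
`src` are hypothetical and must be registered first, and the call shape matches
the operands the planner extracts (table argument, model argument, feature
descriptor):

    // Runs ML_PREDICT over `src`, feeding column `feature` into model `m`.
    tEnv.executeSql(
                    "SELECT * FROM ML_PREDICT(TABLE src, MODEL m, DESCRIPTOR(feature))")
            .print();

Whether the sync or async path above is taken depends on which runtime provider
the model's ModelProvider offers.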
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java
index faaf96f5dfe..39bb83aa167 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunction.java
@@ -18,17 +18,45 @@
package org.apache.flink.table.planner.plan.nodes.physical.stream;
+import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.connector.ChangelogMode;
+import org.apache.flink.table.ml.AsyncPredictRuntimeProvider;
+import org.apache.flink.table.ml.ModelProvider;
+import org.apache.flink.table.ml.PredictRuntimeProvider;
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
+import org.apache.flink.table.planner.calcite.RexModelCall;
+import org.apache.flink.table.planner.calcite.RexTableArgCall;
+import org.apache.flink.table.planner.plan.metadata.FlinkRelMetadataQuery;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
+import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
+import org.apache.flink.table.planner.plan.nodes.exec.spec.MLPredictSpec;
+import org.apache.flink.table.planner.plan.nodes.exec.spec.ModelSpec;
+import
org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecMLPredictTableFunction;
import
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
+import org.apache.flink.table.planner.plan.utils.ChangelogPlanUtils;
+import org.apache.flink.table.planner.plan.utils.LookupJoinUtil;
+import org.apache.flink.table.planner.plan.utils.UpsertKeyUtil;
+import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
+import org.apache.flink.table.planner.utils.ShortcutUtils;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
+import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlDescriptorOperator;
+import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
/** Stream physical RelNode for ml predict table function. */
public class StreamPhysicalMLPredictTableFunction extends SingleRel implements
StreamPhysicalRel {
@@ -60,7 +88,21 @@ public class StreamPhysicalMLPredictTableFunction extends
SingleRel implements S
@Override
public ExecNode<?> translateToExecNode() {
- return null;
+ int[] upsertKeys =
+ UpsertKeyUtil.smallestKey(
+
FlinkRelMetadataQuery.reuseOrCreate(getCluster().getMetadataQuery())
+ .getUpsertKeys(getInput()))
+ .orElse(null);
+ RexModelCall modelCall = extractOperand(operand -> operand instanceof
RexModelCall);
+ return new StreamExecMLPredictTableFunction(
+ ShortcutUtils.unwrapTableConfig(this),
+ buildMLPredictSpec(),
+ buildModelSpec(modelCall),
+ buildAsyncOptions(modelCall),
+ upsertKeys,
+ InputProperty.DEFAULT,
+ FlinkTypeFactory.toLogicalRowType(getRowType()),
+ getRelDetailedDescription());
}
@Override
@@ -74,4 +116,107 @@ public class StreamPhysicalMLPredictTableFunction extends
SingleRel implements S
.item("invocation", scan.getCall())
.item("rowType", getRowType());
}
+
+ private MLPredictSpec buildMLPredictSpec() {
+ RexTableArgCall tableCall = extractOperand(operand -> operand
instanceof RexTableArgCall);
+ RexCall descriptorCall =
+ extractOperand(
+ operand ->
+ operand instanceof RexCall
+ && ((RexCall) operand).getOperator()
+ instanceof
SqlDescriptorOperator);
+ Map<String, Integer> column2Index = new HashMap<>();
+ List<String> fieldNames = tableCall.getType().getFieldNames();
+ for (int i = 0; i < fieldNames.size(); i++) {
+ column2Index.put(fieldNames.get(i), i);
+ }
+ List<LookupJoinUtil.LookupKey> features =
+ descriptorCall.getOperands().stream()
+ .map(
+ operand -> {
+ if (operand instanceof RexLiteral) {
+ RexLiteral literal = (RexLiteral)
operand;
+ String fieldName =
RexLiteral.stringValue(literal);
+ Integer index =
column2Index.get(fieldName);
+ if (index == null) {
+ throw new TableException(
+ String.format(
+ "Field %s is not
found in input schema: %s.",
+ fieldName,
tableCall.getType()));
+ }
+ return new
LookupJoinUtil.FieldRefLookupKey(index);
+ } else {
+ throw new TableException(
+ String.format(
+ "Unknown operand for
descriptor operator: %s.",
+ operand));
+ }
+ })
+ .collect(Collectors.toList());
+ return new MLPredictSpec(features, Collections.emptyMap());
+ }
+
+ private ModelSpec buildModelSpec(RexModelCall modelCall) {
+ ModelSpec modelSpec = new
ModelSpec(modelCall.getContextResolvedModel());
+ modelSpec.setModelProvider(modelCall.getModelProvider());
+ return modelSpec;
+ }
+
+ private LookupJoinUtil.AsyncLookupOptions buildAsyncOptions(RexModelCall
modelCall) {
+ boolean isAsyncEnabled =
isAsyncMLPredict(modelCall.getModelProvider());
+ if (isAsyncEnabled) {
+ return LookupJoinUtil.getMergedMLPredictAsyncOptions(
+ // TODO: extract runtime config
+ Collections.emptyMap(),
+ ShortcutUtils.unwrapTableConfig(getCluster()),
+ getInputChangelogMode(getInput()));
+ } else {
+ return null;
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ private <T> T extractOperand(Predicate<RexNode> predicate) {
+ return (T)
+ ((RexCall) scan.getCall())
+ .getOperands().stream()
+ .filter(predicate)
+ .findFirst()
+ .orElseThrow(
+ () ->
+ new TableException(
+ String.format(
+ "MLPredict
doesn't contain specified operand: %s",
+
scan.getCall().toString())));
+ }
+
+ private boolean isAsyncMLPredict(ModelProvider provider) {
+ boolean syncFound = false;
+ boolean asyncFound = false;
+ if (provider instanceof PredictRuntimeProvider) {
+ syncFound = true;
+ }
+ if (provider instanceof AsyncPredictRuntimeProvider) {
+ asyncFound = true;
+ }
+
+ if (!syncFound && !asyncFound) {
+ throw new TableException(
+ String.format(
+ "Unknown model provider found: %s.",
provider.getClass().getName()));
+ }
+ return asyncFound;
+ }
+
+ private ChangelogMode getInputChangelogMode(RelNode rel) {
+ if (rel instanceof StreamPhysicalRel) {
+ return JavaScalaConversionUtil.toJava(
+
ChangelogPlanUtils.getChangelogMode((StreamPhysicalRel) rel))
+ .orElse(ChangelogMode.insertOnly());
+ } else if (rel instanceof HepRelVertex) {
+ return getInputChangelogMode(((HepRelVertex) rel).getCurrentRel());
+ } else {
+ return ChangelogMode.insertOnly();
+ }
+ }
}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunctionRule.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunctionRule.java
index 640c01e495d..69d1afb4d95 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunctionRule.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalMLPredictTableFunctionRule.java
@@ -72,12 +72,6 @@ public class StreamPhysicalMLPredictTableFunctionRule
extends ConverterRule {
final RelTraitSet providedTraitSet =
rel.getTraitSet().replace(FlinkConventions.STREAM_PHYSICAL());
-
- // TODO:
- // Get model provider and context resolved model from RexModelCall
- // Get table input from descriptor
- // Get config from map
- // Create ModelProviderSpec similar to DynamicTableSourceSpec and
TemporalTableSourceSpec
return new StreamPhysicalMLPredictTableFunction(
scan.getCluster(), providedTraitSet, newInput, scan,
scan.getRowType());
}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/ExecNodeMetadataUtil.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/ExecNodeMetadataUtil.java
index 484d0883be1..49adb7afc78 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/ExecNodeMetadataUtil.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/ExecNodeMetadataUtil.java
@@ -71,6 +71,7 @@ import
org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecLimit;
import
org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecLocalGroupAggregate;
import
org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecLocalWindowAggregate;
import
org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecLookupJoin;
+import
org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecMLPredictTableFunction;
import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecMatch;
import
org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecMiniBatchAssigner;
import
org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecMultipleInput;
@@ -170,6 +171,7 @@ public final class ExecNodeMetadataUtil {
add(StreamExecPythonGroupAggregate.class);
add(StreamExecPythonGroupWindowAggregate.class);
add(StreamExecPythonOverAggregate.class);
+ add(StreamExecMLPredictTableFunction.class);
// Batch execution mode
add(BatchExecSink.class);
add(BatchExecTableSourceScan.class);
@@ -213,6 +215,7 @@ public final class ExecNodeMetadataUtil {
add(StreamExecGroupTableAggregate.class);
add(StreamExecPythonGroupTableAggregate.class);
add(StreamExecMultipleInput.class);
+ add(StreamExecMLPredictTableFunction.class);
}
};
@@ -283,7 +286,9 @@ public final class ExecNodeMetadataUtil {
}
private static void addToLookupMap(Class<? extends ExecNode<?>>
execNodeClass) {
- if (!hasJsonCreatorAnnotation(execNodeClass)) {
+ // TODO: remove the logic when StreamExecMLPredictTableFunction
supports serde.
+ if (!hasJsonCreatorAnnotation(execNodeClass)
+ && execNodeClass != StreamExecMLPredictTableFunction.class) {
throw new IllegalStateException(
String.format(
"ExecNode: %s does not implement @JsonCreator
annotation on "
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/LookupJoinUtil.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/LookupJoinUtil.java
index 79f25be2dd0..4506b41a9a7 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/LookupJoinUtil.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/LookupJoinUtil.java
@@ -392,6 +392,30 @@ public final class LookupJoinUtil {
.TABLE_EXEC_ASYNC_LOOKUP_OUTPUT_MODE))));
}
+ public static AsyncLookupOptions getMergedMLPredictAsyncOptions(
+ Map<String, String> runtimeConfig,
+ TableConfig config,
+ ChangelogMode inputChangelogMode) {
+ Configuration queryConf = Configuration.fromMap(runtimeConfig);
+ ExecutionConfigOptions.AsyncOutputMode asyncOutputMode =
+ coalesce(
+ queryConf.get(ASYNC_OUTPUT_MODE),
+
config.get(ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE));
+
+ return new AsyncLookupOptions(
+ coalesce(
+ queryConf.get(ASYNC_CAPACITY),
+ config.get(
+ ExecutionConfigOptions
+
.TABLE_EXEC_ASYNC_ML_PREDICT_BUFFER_CAPACITY)),
+ coalesce(
+ queryConf.get(ASYNC_TIMEOUT),
+ config.get(
+
ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_TIMEOUT))
+ .toMillis(),
+ convert(inputChangelogMode, asyncOutputMode));
+ }
+
/**
* This method determines whether async lookup is enabled according to the
given lookup keys
* with considering lookup {@link RelHint} and required upsertMaterialize.
Note: it will not
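
The merge above gives per-query runtime config precedence over TableConfig
defaults. A sketch of the `coalesce` helper it relies on (a hypothetical
stand-in for the private method used in this class):

    // Returns the first non-null value, so a query-level setting wins over the
    // session-level default.
    static <T> T coalesce(T preferred, T fallback) {
        return preferred != null ? preferred : fallback;
    }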
diff --git
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala
index f18186ae3c1..59547130561 100644
---
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala
+++
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala
@@ -18,14 +18,14 @@
package org.apache.flink.table.planner.codegen
import org.apache.flink.api.common.functions.{FlatMapFunction, Function,
OpenContext}
-import org.apache.flink.configuration.{Configuration, ReadableConfig}
+import org.apache.flink.configuration.ReadableConfig
import org.apache.flink.streaming.api.functions.async.AsyncFunction
import org.apache.flink.table.api.ValidationException
import org.apache.flink.table.catalog.DataTypeFactory
import org.apache.flink.table.connector.source.{LookupTableSource,
ScanTableSource}
import org.apache.flink.table.data.{GenericRowData, RowData}
import org.apache.flink.table.data.utils.JoinedRowData
-import org.apache.flink.table.functions.{AsyncLookupFunction,
AsyncTableFunction, LookupFunction, TableFunction, UserDefinedFunction,
UserDefinedFunctionHelper}
+import org.apache.flink.table.functions.{AsyncLookupFunction,
AsyncPredictFunction, AsyncTableFunction, LookupFunction, PredictFunction,
TableFunction, UserDefinedFunction, UserDefinedFunctionHelper}
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.codegen.CodeGenUtils._
import org.apache.flink.table.planner.codegen.GenerateUtils._
@@ -39,7 +39,7 @@ import
org.apache.flink.table.planner.plan.utils.RexLiteralUtil
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala
import org.apache.flink.table.runtime.collector.{ListenableCollector,
TableFunctionResultFuture}
import
org.apache.flink.table.runtime.collector.ListenableCollector.CollectListener
-import org.apache.flink.table.runtime.generated.{GeneratedCollector,
GeneratedFilterCondition, GeneratedFunction, GeneratedResultFuture}
+import org.apache.flink.table.runtime.generated.{GeneratedCollector,
GeneratedFunction, GeneratedResultFuture}
import org.apache.flink.table.types.DataType
import
org.apache.flink.table.types.extraction.ExtractionUtils.extractSimpleGeneric
import org.apache.flink.table.types.inference.{TypeInference, TypeStrategies,
TypeTransformations}
@@ -71,8 +71,7 @@ object LookupJoinCodeGenerator {
inputType: LogicalType,
tableSourceType: LogicalType,
returnType: LogicalType,
- lookupKeys: util.Map[Integer, LookupKey],
- lookupKeyOrder: Array[Int],
+ lookupKeys: util.List[LookupKey],
syncLookupFunction: TableFunction[_],
functionName: String,
fieldCopy: Boolean): GeneratedFunction[FlatMapFunction[RowData,
RowData]] = {
@@ -94,7 +93,6 @@ object LookupJoinCodeGenerator {
tableSourceType,
returnType,
lookupKeys,
- lookupKeyOrder,
classOf[TableFunction[_]],
syncLookupFunction,
functionName,
@@ -111,8 +109,7 @@ object LookupJoinCodeGenerator {
inputType: LogicalType,
tableSourceType: LogicalType,
returnType: LogicalType,
- lookupKeys: util.Map[Integer, LookupKey],
- lookupKeyOrder: Array[Int],
+ lookupKeys: util.List[LookupKey],
asyncLookupFunction: AsyncTableFunction[_],
functionName: String):
GeneratedTableFunctionWithDataType[AsyncFunction[RowData, AnyRef]] = {
@@ -125,7 +122,6 @@ object LookupJoinCodeGenerator {
tableSourceType,
returnType,
lookupKeys,
- lookupKeyOrder,
classOf[AsyncTableFunction[_]],
asyncLookupFunction,
functionName,
@@ -142,21 +138,15 @@ object LookupJoinCodeGenerator {
inputType: LogicalType,
tableSourceType: LogicalType,
returnType: LogicalType,
- lookupKeys: util.Map[Integer, LookupKey],
- lookupKeyOrder: Array[Int],
+ lookupKeys: util.List[LookupKey],
lookupFunctionBase: Class[_],
lookupFunction: UserDefinedFunction,
functionName: String,
fieldCopy: Boolean,
bodyCode: GeneratedExpression => String):
GeneratedTableFunctionWithDataType[F] = {
- val callContext = new LookupCallContext(
- dataTypeFactory,
- lookupFunction,
- inputType,
- lookupKeys,
- lookupKeyOrder,
- tableSourceType)
+ val callContext =
+ new LookupCallContext(dataTypeFactory, lookupFunction, inputType,
lookupKeys, tableSourceType)
// create the final UDF for runtime
val udf = UserDefinedFunctionHelper.createSpecializedFunction(
@@ -172,7 +162,14 @@ object LookupJoinCodeGenerator {
createLookupTypeInference(dataTypeFactory, callContext,
lookupFunctionBase, udf, functionName)
val ctx = new CodeGeneratorContext(tableConfig, classLoader)
- val operands = prepareOperands(ctx, inputType, lookupKeys, lookupKeyOrder,
fieldCopy)
+ val operands = prepareOperands(ctx, inputType, lookupKeys, fieldCopy)
+
+    // TODO: filter all records when there are any nulls on the join key, because
+    // "IS NOT DISTINCT FROM" is not supported yet.
+    // Note: AsyncPredictFunction and PredictFunction do not use lookup syntax.
+ val skipIfArgsNull = !lookupFunction.isInstanceOf[PredictFunction] &&
!lookupFunction
+ .isInstanceOf[AsyncPredictFunction]
+
val callWithDataType =
BridgingFunctionGenUtil.generateFunctionAwareCallWithDataType(
ctx,
operands,
@@ -181,9 +178,7 @@ object LookupJoinCodeGenerator {
callContext,
udf,
functionName,
- // TODO: filter all records when there is any nulls on the join key,
because
- // "IS NOT DISTINCT FROM" is not supported yet.
- skipIfArgsNull = true
+ skipIfArgsNull = skipIfArgsNull
)
val function = FunctionCodeGenerator.generateFunction(
@@ -200,13 +195,10 @@ object LookupJoinCodeGenerator {
private def prepareOperands(
ctx: CodeGeneratorContext,
inputType: LogicalType,
- lookupKeys: util.Map[Integer, LookupKey],
- lookupKeyOrder: Array[Int],
+ lookupKeys: util.List[LookupKey],
fieldCopy: Boolean): Seq[GeneratedExpression] = {
- lookupKeyOrder
- .map(Integer.valueOf)
- .map(lookupKeys.get)
+ lookupKeys.asScala
.map {
case constantKey: ConstantLookupKey =>
val res = RexLiteralUtil.toFlinkInternalValue(constantKey.literal)
@@ -250,7 +242,10 @@ object LookupJoinCodeGenerator {
val defaultOutputDataType = callContext.getOutputDataType.get()
val outputClass =
- if (udf.isInstanceOf[LookupFunction] ||
udf.isInstanceOf[AsyncLookupFunction]) {
+ if (
+ udf.isInstanceOf[LookupFunction] ||
udf.isInstanceOf[AsyncLookupFunction] || udf
+ .isInstanceOf[PredictFunction] ||
udf.isInstanceOf[AsyncPredictFunction]
+ ) {
Some(classOf[RowData])
} else {
toScala(extractSimpleGeneric(baseClass, udf.getClass, 0))
@@ -396,7 +391,8 @@ object LookupJoinCodeGenerator {
// TODO we should update code splitter's grammar file to accept
lambda expressions.
if (getCollectListener().isPresent()) {
- ((${classOf[CollectListener[_]].getCanonicalName})
getCollectListener().get()).onCollect(record);
+ ((${classOf[CollectListener[_]].getCanonicalName})
getCollectListener().get())
+ .onCollect(record);
}
${ctx.reuseLocalVariableCode()}
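
The skipIfArgsNull decision above, mirrored in Java for readability (a hedged
restatement, not the Scala source): only lookup-style functions keep the
null-key filtering, while predict functions receive null arguments as-is.

    boolean skipIfArgsNull =
            !(lookupFunction instanceof PredictFunction)
                    && !(lookupFunction instanceof AsyncPredictFunction);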
diff --git
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingFunctionGenUtil.scala
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingFunctionGenUtil.scala
index 465e5220bf2..c31b42ae141 100644
---
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingFunctionGenUtil.scala
+++
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingFunctionGenUtil.scala
@@ -152,7 +152,7 @@ object BridgingFunctionGenUtil {
contextTerm
)
} else if (udf.getKind == FunctionKind.ASYNC_TABLE) {
- generateAsyncTableFunctionCall(functionTerm, externalOperands,
returnType)
+ generateAsyncTableFunctionCall(functionTerm, externalOperands,
returnType, skipIfArgsNull)
} else if (udf.getKind == FunctionKind.ASYNC_SCALAR) {
generateAsyncScalarFunctionCall(
ctx,
@@ -205,22 +205,34 @@ object BridgingFunctionGenUtil {
private def generateAsyncTableFunctionCall(
functionTerm: String,
externalOperands: Seq[GeneratedExpression],
- outputType: LogicalType): GeneratedExpression = {
+ outputType: LogicalType,
+ skipIfArgsNull: Boolean): GeneratedExpression = {
val DELEGATE = className[DelegatingResultFuture[_]]
- val functionCallCode =
- s"""
- |${externalOperands.map(_.code).mkString("\n")}
- |if (${externalOperands.map(_.nullTerm).mkString(" || ")}) {
- |
$DEFAULT_COLLECTOR_TERM.complete(java.util.Collections.emptyList());
- |} else {
- | $DELEGATE delegates = new $DELEGATE($DEFAULT_COLLECTOR_TERM);
- | $functionTerm.eval(
- | delegates.getCompletableFuture(),
- | ${externalOperands.map(_.resultTerm).mkString(", ")});
- |}
- |""".stripMargin
+ val functionCallCode = {
+ if (skipIfArgsNull) {
+ s"""
+ |${externalOperands.map(_.code).mkString("\n")}
+ |if (${externalOperands.map(_.nullTerm).mkString(" || ")}) {
+ |
$DEFAULT_COLLECTOR_TERM.complete(java.util.Collections.emptyList());
+ |} else {
+ | $DELEGATE delegates = new $DELEGATE($DEFAULT_COLLECTOR_TERM);
+ | $functionTerm.eval(
+ | delegates.getCompletableFuture(),
+ | ${externalOperands.map(_.resultTerm).mkString(", ")});
+ |}
+ |""".stripMargin
+ } else {
+ s"""
+ |${externalOperands.map(_.code).mkString("\n")}
+ |$DELEGATE delegates = new $DELEGATE($DEFAULT_COLLECTOR_TERM);
+ |$functionTerm.eval(
+ | delegates.getCompletableFuture(),
+ | ${externalOperands.map(_.resultTerm).mkString(", ")});
+ |""".stripMargin
+ }
+ }
// has no result
GeneratedExpression(NO_CODE, NEVER_NULL, functionCallCode, outputType)
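
Roughly the Java that the skipIfArgsNull = false branch now generates, with
illustrative identifiers (`f`, `arg0`, `resultFuture` stand in for generated
terms):

    // No null-key short-circuit: the async function is always invoked.
    DelegatingResultFuture<RowData> delegates = new DelegatingResultFuture<>(resultFuture);
    f.eval(delegates.getCompletableFuture(), arg0);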
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesModelFactory.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesModelFactory.java
new file mode 100644
index 00000000000..c25a8fb5ef3
--- /dev/null
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesModelFactory.java
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.factories;
+
+import org.apache.flink.configuration.ConfigOption;
+import org.apache.flink.configuration.ConfigOptions;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.conversion.RowRowConverter;
+import org.apache.flink.table.factories.FactoryUtil;
+import org.apache.flink.table.factories.ModelProviderFactory;
+import org.apache.flink.table.functions.AsyncPredictFunction;
+import org.apache.flink.table.functions.FunctionContext;
+import org.apache.flink.table.functions.PredictFunction;
+import org.apache.flink.table.ml.AsyncPredictRuntimeProvider;
+import org.apache.flink.table.ml.ModelProvider;
+import org.apache.flink.table.ml.PredictRuntimeProvider;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+
+/** Values model factory. */
+public class TestValuesModelFactory implements ModelProviderFactory {
+
+ private static final AtomicInteger idCounter = new AtomicInteger(0);
+ protected static final Map<String, Map<Row, List<Row>>> REGISTERED_DATA =
new HashMap<>();
+
+ public static final ConfigOption<String> DATA_ID =
+ ConfigOptions.key("data-id").stringType().noDefaultValue();
+
+ public static final ConfigOption<Boolean> ASYNC =
+ ConfigOptions.key("async").booleanType().defaultValue(false);
+
+ public static String registerData(Map<Row, List<Row>> data) {
+ String id = String.valueOf(idCounter.incrementAndGet());
+ REGISTERED_DATA.put(id, data);
+ return id;
+ }
+
+ public static void clearAllData() {
+ REGISTERED_DATA.clear();
+ }
+
+ @Override
+ public ModelProvider createModelProvider(Context context) {
+ FactoryUtil.ModelProviderFactoryHelper helper =
+ FactoryUtil.createModelProviderFactoryHelper(this, context);
+ helper.validate();
+ String dataId = helper.getOptions().get(DATA_ID);
+ Map<Row, List<Row>> rows = REGISTERED_DATA.getOrDefault(dataId,
Collections.emptyMap());
+
+ RowRowConverter inputConverter =
+ RowRowConverter.create(
+
context.getCatalogModel().getResolvedInputSchema().toPhysicalRowDataType());
+ RowRowConverter outputConverter =
+ RowRowConverter.create(
+ context.getCatalogModel()
+ .getResolvedOutputSchema()
+ .toPhysicalRowDataType());
+
+ if (helper.getOptions().get(ASYNC)) {
+ return new AsyncValuesModelProvider(
+ new ValuesPredictFunction(rows, inputConverter, outputConverter),
+ new AsyncValuesPredictFunction(rows, inputConverter, outputConverter));
+ } else {
+ return new ValuesModelProvider(
+ new ValuesPredictFunction(rows, inputConverter, outputConverter));
+ }
+ }
+
+ @Override
+ public String factoryIdentifier() {
+ return "values";
+ }
+
+ @Override
+ public Set<ConfigOption<?>> requiredOptions() {
+ return Collections.emptySet();
+ }
+
+ @Override
+ public Set<ConfigOption<?>> optionalOptions() {
+ return new HashSet<>(Arrays.asList(ASYNC, DATA_ID));
+ }
+
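+ /** Converts the registered external {@link Row} data into internal {@link RowData} so incoming feature rows can be looked up directly. */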
+ private static Map<RowData, List<RowData>> toInternal(
+ Map<Row, List<Row>> data,
+ RowRowConverter keyConverter,
+ RowRowConverter valueConverter,
+ FunctionContext context) {
+ keyConverter.open(context.getUserCodeClassLoader());
+ valueConverter.open(context.getUserCodeClassLoader());
+
+ Map<RowData, List<RowData>> converted = new HashMap<>();
+ for (Map.Entry<Row, List<Row>> entry : data.entrySet()) {
+ RowData input = keyConverter.toInternal(entry.getKey());
+ converted.put(
+ input,
+ entry.getValue().stream()
+ .map(valueConverter::toInternal)
+ .collect(Collectors.toList()));
+ }
+ return converted;
+ }
+
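+ /** Model provider that always returns the given synchronous predict function. */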
+ public static class ValuesModelProvider implements PredictRuntimeProvider {
+
+ private final PredictFunction function;
+
+ public ValuesModelProvider(PredictFunction function) {
+ this.function = function;
+ }
+
+ @Override
+ public PredictFunction createPredictFunction(Context context) {
+ return function;
+ }
+
+ @Override
+ public ModelProvider copy() {
+ return new ValuesModelProvider(function);
+ }
+ }
+
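+ /** Model provider that supports both synchronous and asynchronous prediction. */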
+ public static class AsyncValuesModelProvider
+ implements AsyncPredictRuntimeProvider, PredictRuntimeProvider {
+
+ private final PredictFunction predictFunction;
+ private final AsyncPredictFunction asyncPredictFunction;
+
+ public AsyncValuesModelProvider(
+ PredictFunction predictFunction, AsyncPredictFunction asyncPredictFunction) {
+ this.predictFunction = predictFunction;
+ this.asyncPredictFunction = asyncPredictFunction;
+ }
+
+ @Override
+ public AsyncPredictFunction createAsyncPredictFunction(Context context) {
+ return asyncPredictFunction;
+ }
+
+ @Override
+ public PredictFunction createPredictFunction(Context context) {
+ return predictFunction;
+ }
+
+ @Override
+ public ModelProvider copy() {
+ return new AsyncValuesModelProvider(predictFunction, asyncPredictFunction);
+ }
+ }
+
+ /** Synchronous predict function backed by the registered values. */
+ public static class ValuesPredictFunction extends PredictFunction {
+
+ private final Map<Row, List<Row>> data;
+ private final RowRowConverter inputConverter;
+ private final RowRowConverter outputConverter;
+
+ private transient Map<RowData, List<RowData>> converted;
+
+ public ValuesPredictFunction(
+ Map<Row, List<Row>> data,
+ RowRowConverter inputConverter,
+ RowRowConverter outputConverter) {
+ this.data = data;
+ this.inputConverter = inputConverter;
+ this.outputConverter = outputConverter;
+ }
+
+ @Override
+ public void open(FunctionContext context) throws Exception {
+ super.open(context);
+ converted = toInternal(data, inputConverter, outputConverter, context);
+ }
+
+ @Override
+ public Collection<RowData> predict(RowData features) {
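+ // The tests register a result for every feature row they send, so a lookup miss is a test bug.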
+ return Preconditions.checkNotNull(converted.get(features));
+ }
+ }
+
+ /** Asynchronous predict function backed by the registered values. */
+ public static class AsyncValuesPredictFunction extends AsyncPredictFunction {
+
+ private final Map<Row, List<Row>> data;
+ private final RowRowConverter inputConverter;
+ private final RowRowConverter outputConverter;
+
+ private transient Random random;
+ private transient Map<RowData, List<RowData>> converted;
+ private transient ExecutorService executors;
+
+ public AsyncValuesPredictFunction(
+ Map<Row, List<Row>> data,
+ RowRowConverter inputConverter,
+ RowRowConverter outputConverter) {
+ this.data = data;
+ this.inputConverter = inputConverter;
+ this.outputConverter = outputConverter;
+ }
+
+ @Override
+ public void open(FunctionContext context) throws Exception {
+ super.open(context);
+ random = new Random();
+ converted = toInternal(data, inputConverter, outputConverter, context);
+ executors = Executors.newFixedThreadPool(5);
+ }
+
+ @Override
+ public CompletableFuture<Collection<RowData>> asyncPredict(RowData features) {
+ return CompletableFuture.supplyAsync(
+ () -> {
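+ // Sleep up to a second to simulate async I/O latency and exercise the output modes.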
+ try {
+ Thread.sleep(random.nextInt(1000));
+ return Preconditions.checkNotNull(converted.get(features));
+ } catch (Exception e) {
+ throw new RuntimeException("Execution is
interrupted.", e);
+ }
+ },
+ executors);
+ }
+
+ @Override
+ public void close() throws Exception {
+ super.close();
+ executors.shutdownNow();
+ }
+ }
+}
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncMLPredictITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncMLPredictITCase.java
new file mode 100644
index 00000000000..d11f767124d
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncMLPredictITCase.java
@@ -0,0 +1,302 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.runtime.stream.table;
+
+import org.apache.flink.table.api.config.ExecutionConfigOptions;
+import org.apache.flink.table.planner.factories.TestValuesModelFactory;
+import org.apache.flink.table.planner.factories.TestValuesTableFactory;
+import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase;
+import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
+import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CollectionUtil;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.TestTemplate;
+import org.junit.jupiter.api.extension.ExtendWith;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.table.planner.factories.TestValuesTableFactory.changelogRow;
+import static org.assertj.core.api.Assertions.assertThatList;
+
+/** ITCase for async ML_PREDICT. */
+@ExtendWith(ParameterizedTestExtension.class)
+public class AsyncMLPredictITCase extends StreamingWithStateTestBase {
+
+ private final Boolean objectReuse;
+ private final ExecutionConfigOptions.AsyncOutputMode asyncOutputMode;
+
+ public AsyncMLPredictITCase(
+ StateBackendMode backend,
+ Boolean objectReuse,
+ ExecutionConfigOptions.AsyncOutputMode asyncOutputMode) {
+ super(backend);
+
+ this.objectReuse = objectReuse;
+ this.asyncOutputMode = asyncOutputMode;
+ }
+
+ private final List<Row> data =
+ Arrays.asList(
+ Row.of(1L, 12, "Julian"),
+ Row.of(2L, 15, "Hello"),
+ Row.of(3L, 15, "Fabian"),
+ Row.of(8L, 11, "Hello world"),
+ Row.of(9L, 12, "Hello world!"));
+
+ private final List<Row> dataWithNull =
+ Arrays.asList(
+ Row.of(15L, null, "Hello"),
+ Row.of(3L, 15, "Fabian"),
+ Row.of(11L, null, "Hello world"),
+ Row.of(9L, 12, "Hello world!"));
+
+ private final List<Row> cdcRowData =
+ Arrays.asList(
+ changelogRow("+I", 1L, 12, "Julian"),
+ changelogRow("-U", 1L, 12, "Julian"),
+ changelogRow("+U", 1L, 13, "Julian"),
+ changelogRow("-D", 1L, 13, "Julian"),
+ changelogRow("+I", 1L, 14, "Julian"),
+ changelogRow("+I", 2L, 16, "Hello"),
+ changelogRow("-U", 2L, 16, "Hello"),
+ changelogRow("+U", 2L, 17, "Hello"),
+ changelogRow("+I", 3L, 19, "Fabian"),
+ changelogRow("-D", 3L, 19, "Fabian"));
+
+ private final Map<Row, List<Row>> id2features = new HashMap<>();
+
+ {
+ id2features.put(Row.of(1L), Collections.singletonList(Row.of("x1", 1, "z1")));
+ id2features.put(Row.of(2L), Collections.singletonList(Row.of("x2", 2, "z2")));
+ id2features.put(Row.of(3L), Collections.singletonList(Row.of("x3", 3, "z3")));
+ id2features.put(Row.of(8L), Collections.singletonList(Row.of("x8", 8, "z8")));
+ id2features.put(Row.of(9L), Collections.singletonList(Row.of("x9", 9, "z9")));
+ }
+
+ private final Map<Row, List<Row>> idLen2features = new HashMap<>();
+
+ {
+ idLen2features.put(Row.of(15L, null), Collections.singletonList(Row.of("x1", 1, "zNull15")));
+ idLen2features.put(Row.of(15L, 15), Collections.singletonList(Row.of("x1", 1, "z1515")));
+ idLen2features.put(Row.of(3L, 15), Collections.singletonList(Row.of("x2", 2, "z315")));
+ idLen2features.put(Row.of(11L, null), Collections.singletonList(Row.of("x3", 3, "zNull11")));
+ idLen2features.put(Row.of(11L, 11), Collections.singletonList(Row.of("x3", 3, "z1111")));
+ idLen2features.put(Row.of(9L, 12), Collections.singletonList(Row.of("x8", 8, "z912")));
+ idLen2features.put(Row.of(12L, 12), Collections.singletonList(Row.of("x8", 8, "z1212")));
+ }
+
+ private final Map<Row, List<Row>> content2vector = new HashMap<>();
+
+ {
+ content2vector.put(
+ Row.of("Julian"),
+ Collections.singletonList(Row.of((Object) new Float[] {1.0f, 2.0f, 3.0f})));
+ content2vector.put(
+ Row.of("Hello"),
+ Collections.singletonList(Row.of((Object) new Float[] {2.0f, 3.0f, 4.0f})));
+ content2vector.put(
+ Row.of("Fabian"),
+ Collections.singletonList(Row.of((Object) new Float[] {3.0f, 4.0f, 5.0f})));
+ }
+
+ @BeforeEach
+ public void before() {
+ super.before();
+ if (objectReuse) {
+ env().getConfig().enableObjectReuse();
+ } else {
+ env().getConfig().disableObjectReuse();
+ }
+ tEnv().getConfig()
+ .set(ExecutionConfigOptions.TABLE_EXEC_ASYNC_ML_PREDICT_OUTPUT_MODE, asyncOutputMode);
+
+ createScanTable("src", data);
+ createScanTable("nullable_src", dataWithNull);
+ createScanTable("cdc_src", cdcRowData);
+
+ tEnv().executeSql(
+ String.format(
+ "CREATE MODEL m1\n"
+ + "INPUT (a BIGINT)\n"
+ + "OUTPUT (x STRING, y INT, z
STRING)\n"
+ + "WITH (\n"
+ + " 'provider' = 'values',"
+ + " 'async' = 'true',"
+ + " 'data-id' = '%s'"
+ + ")",
+ TestValuesModelFactory.registerData(id2features)));
+ tEnv().executeSql(
+ String.format(
+ "CREATE MODEL m2\n"
+ + "INPUT (a BIGINT, b INT)\n"
+ + "OUTPUT (x STRING, y INT, z
STRING)\n"
+ + "WITH (\n"
+ + " 'provider' = 'values',"
+ + " 'async' = 'true',"
+ + " 'data-id' = '%s'"
+ + ")",
+ TestValuesModelFactory.registerData(idLen2features)));
+ tEnv().executeSql(
+ String.format(
+ "CREATE MODEL m3\n"
+ + "INPUT (content STRING)\n"
+ + "OUTPUT (vector ARRAY<FLOAT>)\n"
+ + "WITH (\n"
+ + " 'provider' = 'values',"
+ + " 'data-id' = '%s',"
+ + " 'async' = 'true'"
+ + ")",
+ TestValuesModelFactory.registerData(content2vector)));
+ }
+
+ @TestTemplate
+ public void testAsyncMLPredict() {
+ assertThatList(
+ CollectionUtil.iteratorToList(
+ tEnv().executeSql(
+ "SELECT id, z FROM
ML_PREDICT(TABLE src, MODEL m1, DESCRIPTOR(`id`))")
+ .collect()))
+ .containsExactlyInAnyOrder(
+ Row.of(1L, "z1"),
+ Row.of(2L, "z2"),
+ Row.of(3L, "z3"),
+ Row.of(8L, "z8"),
+ Row.of(9L, "z9"));
+ }
+
+ @TestTemplate
+ public void testAsyncMLPredictWithMultipleFields() {
+ assertThatList(
+ CollectionUtil.iteratorToList(
+ tEnv().executeSql(
+ "SELECT id, len, z FROM
ML_PREDICT(TABLE nullable_src, MODEL m2, DESCRIPTOR(`id`, `len`))")
+ .collect()))
+ .containsExactlyInAnyOrder(
+ Row.of(3L, 15, "z315"),
+ Row.of(9L, 12, "z912"),
+ Row.of(11L, null, "zNull11"),
+ Row.of(15L, null, "zNull15"));
+ }
+
+ @TestTemplate
+ public void testAsyncMLPredictWithConstantValues() {
+ assertThatList(
+ CollectionUtil.iteratorToList(
+ tEnv().executeSql(
+ "WITH v(id) AS (SELECT * FROM
(VALUES (CAST(1 AS BIGINT)), (CAST(2 AS BIGINT)))) "
+ + "SELECT * FROM
ML_PREDICT(INPUT => TABLE v, MODEL => MODEL `m1`, ARGS => DESCRIPTOR(`id`))")
+ .collect()))
+ .containsExactlyInAnyOrder(Row.of(1L, "x1", 1, "z1"),
Row.of(2L, "x2", 2, "z2"));
+ }
+
+ @TestTemplate
+ public void testAsyncPredictWithCDCData() throws Exception {
+ tEnv().executeSql(
+ "CREATE TEMPORARY TABLE sink("
+ + " id BIGINT,"
+ + " vector ARRAY<FLOAT>"
+ + ") WITH ("
+ + " 'connector' = 'values',"
+ + " 'sink-insert-only' = 'false'"
+ + ")");
+ tEnv().executeSql(
+ "INSERT INTO sink SELECT id, vector FROM
ML_PREDICT(TABLE cdc_src, MODEL m3, DESCRIPTOR(`content`))")
+ .await();
+
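+ // After the retractions above materialize, only ids 1 and 2 remain (id 3 was deleted).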
+ assertThatList(TestValuesTableFactory.getResults("sink"))
+ .containsExactlyInAnyOrder(
+ Row.of(2L, new Float[] {2.0f, 3.0f, 4.0f}),
+ Row.of(1L, new Float[] {1.0f, 2.0f, 3.0f}));
+ }
+
+ private void createScanTable(String tableName, List<Row> data) {
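+ // Declare a primary key and full changelog mode so the source can also serve CDC data.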
+ String dataId = TestValuesTableFactory.registerData(data);
+ tEnv().executeSql(
+ String.format(
+ "CREATE TABLE `%s`(\n"
+ + " id BIGINT,"
+ + " len INT,"
+ + " content STRING,"
+ + " PRIMARY KEY (`id`) NOT ENFORCED"
+ + ") WITH ("
+ + " 'connector' = 'values',"
+ + " 'data-id' = '%s',"
+ + " 'changelog-mode' = 'I,UA,UB,D'"
+ + ")",
+ tableName, dataId));
+ }
+
+ @Parameters(name = "backend = {0}, objectReuse = {1}, asyncOutputMode = {2}")
+ public static Collection<Object[]> parameters() {
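+ // Cross product of state backend, object reuse, and async output mode.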
+ return Arrays.asList(
+ new Object[][] {
+ {
+ StreamingWithStateTestBase.HEAP_BACKEND(),
+ true,
+ ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED
+ },
+ {
+ StreamingWithStateTestBase.HEAP_BACKEND(),
+ true,
+ ExecutionConfigOptions.AsyncOutputMode.ORDERED
+ },
+ {
+ StreamingWithStateTestBase.HEAP_BACKEND(),
+ false,
+ ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED
+ },
+ {
+ StreamingWithStateTestBase.HEAP_BACKEND(),
+ false,
+ ExecutionConfigOptions.AsyncOutputMode.ORDERED
+ },
+ {
+ StreamingWithStateTestBase.ROCKSDB_BACKEND(),
+ true,
+ ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED
+ },
+ {
+ StreamingWithStateTestBase.ROCKSDB_BACKEND(),
+ true,
+ ExecutionConfigOptions.AsyncOutputMode.ORDERED
+ },
+ {
+ StreamingWithStateTestBase.ROCKSDB_BACKEND(),
+ false,
+ ExecutionConfigOptions.AsyncOutputMode.ALLOW_UNORDERED
+ },
+ {
+ StreamingWithStateTestBase.ROCKSDB_BACKEND(),
+ false,
+ ExecutionConfigOptions.AsyncOutputMode.ORDERED
+ }
+ });
+ }
+}
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/MLPredictITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/MLPredictITCase.java
new file mode 100644
index 00000000000..cb6fdfaf3f7
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/MLPredictITCase.java
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.runtime.stream.table;
+
+import org.apache.flink.table.planner.factories.TestValuesModelFactory;
+import org.apache.flink.table.planner.factories.TestValuesTableFactory;
+import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecMLPredictTableFunction;
+import org.apache.flink.table.planner.runtime.utils.StreamingTestBase;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CollectionUtil;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThatList;
+
+/** ITCase for {@link StreamExecMLPredictTableFunction}. */
+public class MLPredictITCase extends StreamingTestBase {
+
+ private final List<Row> data =
+ Arrays.asList(
+ Row.of(1L, 12, "Julian"),
+ Row.of(2L, 15, "Hello"),
+ Row.of(3L, 15, "Fabian"),
+ Row.of(8L, 11, "Hello world"),
+ Row.of(9L, 12, "Hello world!"));
+
+ private final List<Row> dataWithNull =
+ Arrays.asList(
+ Row.of(null, 15, "Hello"),
+ Row.of(3L, 15, "Fabian"),
+ Row.of(null, 11, "Hello world"),
+ Row.of(9L, 12, "Hello world!"));
+
+ private final Map<Row, List<Row>> id2features = new HashMap<>();
+
+ {
+ id2features.put(Row.of(1L), Collections.singletonList(Row.of("x1", 1, "z1")));
+ id2features.put(Row.of(2L), Collections.singletonList(Row.of("x2", 2, "z2")));
+ id2features.put(Row.of(3L), Collections.singletonList(Row.of("x3", 3, "z3")));
+ id2features.put(Row.of(8L), Collections.singletonList(Row.of("x8", 8, "z8")));
+ id2features.put(Row.of(9L), Collections.singletonList(Row.of("x9", 9, "z9")));
+ }
+
+ private final Map<Row, List<Row>> idLen2features = new HashMap<>();
+
+ {
+ idLen2features.put(Row.of(null, 15), Collections.singletonList(Row.of("x1", 1, "zNull15")));
+ idLen2features.put(Row.of(15L, 15), Collections.singletonList(Row.of("x1", 1, "z1515")));
+ idLen2features.put(Row.of(3L, 15), Collections.singletonList(Row.of("x2", 2, "z315")));
+ idLen2features.put(Row.of(null, 11), Collections.singletonList(Row.of("x3", 3, "zNull11")));
+ idLen2features.put(Row.of(11L, 11), Collections.singletonList(Row.of("x3", 3, "z1111")));
+ idLen2features.put(Row.of(9L, 12), Collections.singletonList(Row.of("x8", 8, "z912")));
+ idLen2features.put(Row.of(12L, 12), Collections.singletonList(Row.of("x8", 8, "z1212")));
+ }
+
+ @BeforeEach
+ public void before() throws Exception {
+ super.before();
+ createScanTable("src", data);
+ createScanTable("nullable_src", dataWithNull);
+
+ tEnv().executeSql(
+ String.format(
+ "CREATE MODEL m1\n"
+ + "INPUT (a BIGINT)\n"
+ + "OUTPUT (x STRING, y INT, z
STRING)\n"
+ + "WITH (\n"
+ + " 'provider' = 'values',"
+ + " 'data-id' = '%s'"
+ + ")",
+ TestValuesModelFactory.registerData(id2features)));
+ tEnv().executeSql(
+ String.format(
+ "CREATE MODEL m2\n"
+ + "INPUT (a BIGINT, b INT)\n"
+ + "OUTPUT (x STRING, y INT, z
STRING)\n"
+ + "WITH (\n"
+ + " 'provider' = 'values',"
+ + " 'data-id' = '%s'"
+ + ")",
+ TestValuesModelFactory.registerData(idLen2features)));
+ }
+
+ @Test
+ public void testMLPredict() {
+ List<Row> result =
+ CollectionUtil.iteratorToList(
+ tEnv().executeSql(
+ "SELECT id, z "
+ + "FROM ML_PREDICT(TABLE src,
MODEL m1, DESCRIPTOR(`id`)) ")
+ .collect());
+
+ assertThatList(result)
+ .containsExactlyInAnyOrder(
+ Row.of(1L, "z1"),
+ Row.of(2L, "z2"),
+ Row.of(3L, "z3"),
+ Row.of(8L, "z8"),
+ Row.of(9L, "z9"));
+ }
+
+ @Test
+ public void testMLPredictWithMultipleFields() {
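+ // NULL keys still match: the model data registers entries keyed on (null, len).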
+ List<Row> result =
+ CollectionUtil.iteratorToList(
+ tEnv().executeSql(
+ "SELECT id, len, z "
+ + "FROM ML_PREDICT(TABLE
nullable_src, MODEL m2, DESCRIPTOR(`id`, `len`)) ")
+ .collect());
+
+ assertThatList(result)
+ .containsExactlyInAnyOrder(
+ Row.of(3L, 15, "z315"),
+ Row.of(9L, 12, "z912"),
+ Row.of(null, 11, "zNull11"),
+ Row.of(null, 15, "zNull15"));
+ }
+
+ @Test
+ public void testPredictWithConstantValues() {
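+ // Exercises the named-argument form of ML_PREDICT (INPUT / MODEL / ARGS).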
+ List<Row> result =
+ CollectionUtil.iteratorToList(
+ tEnv().executeSql(
+ "WITH v(id) AS (SELECT * FROM (VALUES
(CAST(1 AS BIGINT)), (CAST(2 AS BIGINT)))) "
+ + "SELECT * FROM ML_PREDICT( "
+ + " INPUT => TABLE v, "
+ + " MODEL => MODEL `m1`, "
+ + " ARGS => DESCRIPTOR(`id`) "
+ + ")")
+ .collect());
+
+ assertThatList(result)
+ .containsExactlyInAnyOrder(Row.of(1L, "x1", 1, "z1"),
Row.of(2L, "x2", 2, "z2"));
+ }
+
+ private void createScanTable(String tableName, List<Row> data) {
+ String dataId = TestValuesTableFactory.registerData(data);
+ tEnv().executeSql(
+ String.format(
+ "CREATE TABLE `%s`(\n"
+ + " id BIGINT,\n"
+ + " len INT,\n"
+ + " content STRING\n"
+ + ") WITH (\n"
+ + " 'connector' = 'values',\n"
+ + " 'data-id' = '%s'\n"
+ + ")",
+ tableName, dataId));
+ }
+}
diff --git a/flink-table/flink-table-planner/src/test/resources/META-INF/services/org.apache.flink.table.factories.Factory b/flink-table/flink-table-planner/src/test/resources/META-INF/services/org.apache.flink.table.factories.Factory
index 4a0e2319b3a..a7640c70f42 100644
--- a/flink-table/flink-table-planner/src/test/resources/META-INF/services/org.apache.flink.table.factories.Factory
+++ b/flink-table/flink-table-planner/src/test/resources/META-INF/services/org.apache.flink.table.factories.Factory
@@ -21,4 +21,5 @@
org.apache.flink.table.planner.plan.stream.sql.TestTableFactory
org.apache.flink.table.planner.factories.TestUpdateDeleteTableFactory
org.apache.flink.table.planner.factories.TestSupportsStagingTableFactory
org.apache.flink.table.planner.factories.TestProcedureCatalogFactory
-org.apache.flink.table.planner.utils.TestSimpleDynamicTableSourceFactory
\ No newline at end of file
+org.apache.flink.table.planner.utils.TestSimpleDynamicTableSourceFactory
+org.apache.flink.table.planner.factories.TestValuesModelFactory
\ No newline at end of file
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/ml/ModelPredictRuntimeProviderContext.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/ml/ModelPredictRuntimeProviderContext.java
new file mode 100644
index 00000000000..69681d13a9b
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/ml/ModelPredictRuntimeProviderContext.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.functions.ml;
+
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.table.catalog.ResolvedCatalogModel;
+import org.apache.flink.table.ml.ModelProvider;
+
+/** Context that exposes the resolved catalog model and the runtime configuration to a {@link ModelProvider}. */
+public class ModelPredictRuntimeProviderContext implements ModelProvider.Context {
+
+ private final ResolvedCatalogModel catalogModel;
+ private final ReadableConfig runtimeConfig;
+
+ public ModelPredictRuntimeProviderContext(
+ ResolvedCatalogModel catalogModel, ReadableConfig runtimeConfig) {
+ this.catalogModel = catalogModel;
+ this.runtimeConfig = runtimeConfig;
+ }
+
+ @Override
+ public ResolvedCatalogModel getCatalogModel() {
+ return catalogModel;
+ }
+
+ @Override
+ public ReadableConfig runtimeConfig() {
+ return runtimeConfig;
+ }
+}