fsk119 commented on code in PR #26667:
URL: https://github.com/apache/flink/pull/26667#discussion_r2375143319
##########
flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/MLEvaluateTableFunctionTest.java:
##########
@@ -112,15 +112,15 @@ public void testOptionalNamedArgumentsWithTaskAndConfig()
{
+ "ARGS => DESCRIPTOR(a, b), "
+ "TASK => 'classification', "
+ "CONFIG => MAP['metrics', 'accuracy,f1']))";
- assertReachOptimizer(sql);
+ util.verifyRelPlan(sql);
Review Comment:
Which options can we specify in ML_EVALUATE?
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/MLEvaluationAggregationFunction.java:
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions.sql.ml;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.catalog.DataTypeFactory;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.FunctionDefinition;
+import org.apache.flink.table.ml.TaskType;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.ArgumentCount;
+import org.apache.flink.table.types.inference.CallContext;
+import org.apache.flink.table.types.inference.InputTypeStrategy;
+import org.apache.flink.table.types.inference.Signature;
+import org.apache.flink.table.types.inference.TypeInference;
+import org.apache.flink.types.Row;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/** Aggregation function for evaluating models based on task type. */
+@Internal
+public class MLEvaluationAggregationFunction extends AggregateFunction<Row,
Object> {
Review Comment:
private static final long serialVersionUID = 1L;
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/ExpandMLEvaluateTableFunctionRule.java:
##########
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.catalog.ContextResolvedFunction;
+import org.apache.flink.table.functions.FunctionIdentifier;
+import org.apache.flink.table.ml.AsyncPredictRuntimeProvider;
+import org.apache.flink.table.ml.PredictRuntimeProvider;
+import org.apache.flink.table.ml.TaskType;
+import org.apache.flink.table.planner.calcite.FlinkContext;
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
+import org.apache.flink.table.planner.calcite.RexModelCall;
+import
org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.MLEvaluationAggregationFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.SqlMLEvaluateTableFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.SqlMLPredictTableFunction;
+import
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
+import org.apache.flink.table.planner.utils.ShortcutUtils;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexCallBinding;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.util.NlsString;
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Rule that expands ML evaluation table function calls.
+ *
+ * <p>This rule matches {@link FlinkLogicalTableFunctionScan} with a {@link
+ * SqlMLEvaluateTableFunction} call and expands it into ml predict table
function and an aggregation
+ * function following it.
+ */
+@Internal
+@Value.Enclosing
+public class ExpandMLEvaluateTableFunctionRule
+ extends RelRule<ExpandMLEvaluateTableFunctionRule.Config> {
+
+ public static final RelOptRule INSTANCE = new
ExpandMLEvaluateTableFunctionRule(Config.DEFAULT);
+
+ public ExpandMLEvaluateTableFunctionRule(Config config) {
+ super(config);
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ final LogicalTableFunctionScan scan = call.rel(0);
+ final RelDataType resultType = scan.getRowType();
+ final RelBuilder relBuilder = call.builder().push(scan.getInput(0));
+
+ final RexCall rexCall = (RexCall) scan.getCall();
+
+ RelDataType predictOutputType = addPredictTableFunction(relBuilder,
rexCall);
+ addProjection(relBuilder, rexCall, predictOutputType);
+ addAggregate(relBuilder, rexCall, resultType);
+
+ call.transformTo(relBuilder.build());
+ }
+
+ private void addAggregate(RelBuilder relBuilder, RexCall rexCall,
RelDataType resultType) {
+ final String task = getTask(rexCall);
+ final MLEvaluationAggregationFunction aggregationFunction =
+ new MLEvaluationAggregationFunction(task);
+ final FlinkContext context =
ShortcutUtils.unwrapContext(relBuilder.getCluster());
+ final FlinkTypeFactory typeFactory =
+ ShortcutUtils.unwrapTypeFactory(relBuilder.getCluster());
+ relBuilder.aggregate(
+ relBuilder.groupKey(),
+ List.of(
+ AggregateCall.create(
+ BridgingSqlAggFunction.of(
+ context,
+ typeFactory,
+ ContextResolvedFunction.temporary(
+
FunctionIdentifier.of("ml_evaluate"),
+ aggregationFunction)),
+ false,
+ false,
+ false,
+ List.of(0, 1),
+ -1,
+ null,
+ RelCollations.EMPTY,
+ resultType.getFieldList().get(0).getType(),
+ "result")));
+ }
+
+ private void addProjection(
+ RelBuilder relBuilder, RexCall rexCall, RelDataType
predictOutputType) {
+ final RexCall labelDescriptor = (RexCall) rexCall.getOperands().get(2);
+ final String labelCol =
+ ((RexLiteral) labelDescriptor.getOperands().get(0))
+ .getValueAs(NlsString.class)
+ .getValue();
+
+ // Project the label column and the last column (prediction). Only one
label and prediction
+ // column is expected. Validation is done in
SqlMLEvaluateTableFunction.
+ final List<RexNode> projectFields =
+ predictOutputType.getFieldList().stream()
+ .filter(
+ field ->
+ field.getName().equals(labelCol)
+ || field.getIndex()
+ ==
predictOutputType.getFieldCount() - 1)
+ .map(field -> relBuilder.field(field.getIndex()))
+ .collect(Collectors.toList());
+ relBuilder.project(projectFields);
+ }
+
+ private RelDataType addPredictTableFunction(RelBuilder relBuilder, RexCall
rexCall) {
+ final RexCall tableArg = (RexCall) rexCall.getOperands().get(0);
+ final RexCall modelCall = (RexCall) rexCall.getOperands().get(1);
+ final RexCall featuresDescriptor = (RexCall)
rexCall.getOperands().get(3);
+
+ // Get optional config map if present
+ final RexCall configMap = getConfigMap(rexCall);
+
+ final List<RexNode> predictOperands = new ArrayList<>();
+ predictOperands.add(tableArg);
+ predictOperands.add(modelCall);
+ predictOperands.add(featuresDescriptor);
+ if (configMap != null) {
+ predictOperands.add(configMap);
+ }
+
+ final RexCall predictCall =
+ (RexCall)
+ relBuilder
+ .getRexBuilder()
+ .makeCall(new SqlMLPredictTableFunction(),
predictOperands);
Review Comment:
nit: use
org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable#ML_PREDICT
instead
##########
flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/MLEvaluateTableFunctionTest.java:
##########
@@ -54,7 +54,7 @@ public void setup() {
+ " b BIGINT,\n"
+ " c STRING,\n"
+ " d DECIMAL(10, 3),\n"
- + " label STRING,\n"
+ + " label FLOAT,\n"
Review Comment:
I found that I cannot run the following SQL:
```
SELECT 1, result
FROM TABLE(ML_EVALUATE(TABLE MyTable, MODEL MyModel, DESCRIPTOR(label),
DESCRIPTOR(a, b), 'classification'))
```
##########
flink-table/flink-table-common/src/main/java/org/apache/flink/table/ml/TaskType.java:
##########
@@ -19,8 +19,11 @@
package org.apache.flink.table.ml;
import org.apache.flink.annotation.Experimental;
+import org.apache.flink.table.api.ValidationException;
import java.util.Arrays;
+import java.util.Optional;
+import java.util.stream.Collectors;
Review Comment:
Why does TaskType need a string value? It seems we could just use a
no-parameter constructor.
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/MLEvaluationAggregationFunction.java:
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions.sql.ml;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.catalog.DataTypeFactory;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.FunctionDefinition;
+import org.apache.flink.table.ml.TaskType;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.ArgumentCount;
+import org.apache.flink.table.types.inference.CallContext;
+import org.apache.flink.table.types.inference.InputTypeStrategy;
+import org.apache.flink.table.types.inference.Signature;
+import org.apache.flink.table.types.inference.TypeInference;
+import org.apache.flink.types.Row;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/** Aggregation function for evaluating models based on task type. */
+@Internal
+public class MLEvaluationAggregationFunction extends AggregateFunction<Row,
Object> {
Review Comment:
Why not extend BuiltInAggregateFunction? Take a look at ArrayAggFunction.
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/catalog/CatalogSchemaModel.java:
##########
@@ -96,7 +96,7 @@ public RexNode toRex(SqlRexContext rexContext) {
FlinkContext context = ShortcutUtils.unwrapContext(cluster);
ModelProvider modelProvider = createModelProvider(context,
contextResolvedModel);
return new RexModelCall(
- getInputRowType(validator.getTypeFactory()),
contextResolvedModel, modelProvider);
+ getOutputRowType(validator.getTypeFactory()),
contextResolvedModel, modelProvider);
Review Comment:
We need to open a new PR to fix this part. Do you mind opening a PR for
release-2.1?
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/MLEvaluationAggregationFunction.java:
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions.sql.ml;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.catalog.DataTypeFactory;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.FunctionDefinition;
+import org.apache.flink.table.ml.TaskType;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.ArgumentCount;
+import org.apache.flink.table.types.inference.CallContext;
+import org.apache.flink.table.types.inference.InputTypeStrategy;
+import org.apache.flink.table.types.inference.Signature;
+import org.apache.flink.table.types.inference.TypeInference;
+import org.apache.flink.types.Row;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/** Aggregation function for evaluating models based on task type. */
+@Internal
+public class MLEvaluationAggregationFunction extends AggregateFunction<Row,
Object> {
Review Comment:
I think we should move this class to the table-runtime module because this
class is related to execution.
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/MLEvaluationAggregationFunction.java:
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions.sql.ml;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.catalog.DataTypeFactory;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.FunctionDefinition;
+import org.apache.flink.table.ml.TaskType;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.ArgumentCount;
+import org.apache.flink.table.types.inference.CallContext;
+import org.apache.flink.table.types.inference.InputTypeStrategy;
+import org.apache.flink.table.types.inference.Signature;
+import org.apache.flink.table.types.inference.TypeInference;
+import org.apache.flink.types.Row;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/** Aggregation function for evaluating models based on task type. */
+@Internal
+public class MLEvaluationAggregationFunction extends AggregateFunction<Row,
Object> {
+
+ public static final Map<String, DataType> TASK_TYPE_MAP =
Review Comment:
Why not `Map<TaskType, DataType>`?
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/ExpandMLEvaluateTableFunctionRule.java:
##########
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.catalog.ContextResolvedFunction;
+import org.apache.flink.table.functions.FunctionIdentifier;
+import org.apache.flink.table.ml.AsyncPredictRuntimeProvider;
+import org.apache.flink.table.ml.PredictRuntimeProvider;
+import org.apache.flink.table.ml.TaskType;
+import org.apache.flink.table.planner.calcite.FlinkContext;
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
+import org.apache.flink.table.planner.calcite.RexModelCall;
+import
org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.MLEvaluationAggregationFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.SqlMLEvaluateTableFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.SqlMLPredictTableFunction;
+import
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
+import org.apache.flink.table.planner.utils.ShortcutUtils;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexCallBinding;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.util.NlsString;
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Rule that expands ML evaluation table function calls.
+ *
+ * <p>This rule matches {@link FlinkLogicalTableFunctionScan} with a {@link
+ * SqlMLEvaluateTableFunction} call and expands it into ml predict table
function and an aggregation
+ * function following it.
+ */
+@Internal
+@Value.Enclosing
+public class ExpandMLEvaluateTableFunctionRule
+ extends RelRule<ExpandMLEvaluateTableFunctionRule.Config> {
+
+ public static final RelOptRule INSTANCE = new
ExpandMLEvaluateTableFunctionRule(Config.DEFAULT);
+
+ public ExpandMLEvaluateTableFunctionRule(Config config) {
Review Comment:
nit: private constructor
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/MLEvaluationAggregationFunction.java:
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions.sql.ml;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.catalog.DataTypeFactory;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.FunctionDefinition;
+import org.apache.flink.table.ml.TaskType;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.ArgumentCount;
+import org.apache.flink.table.types.inference.CallContext;
+import org.apache.flink.table.types.inference.InputTypeStrategy;
+import org.apache.flink.table.types.inference.Signature;
+import org.apache.flink.table.types.inference.TypeInference;
+import org.apache.flink.types.Row;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/** Aggregation function for evaluating models based on task type. */
+@Internal
+public class MLEvaluationAggregationFunction extends AggregateFunction<Row,
Object> {
+
+ public static final Map<String, DataType> TASK_TYPE_MAP =
+ Map.of(
+ TaskType.TEXT_GENERATION.getName(),
+ DataTypes.STRING(),
+ TaskType.CLUSTERING.getName(),
+ DataTypes.DOUBLE(),
+ TaskType.EMBEDDING.getName(),
+ DataTypes.ARRAY(DataTypes.FLOAT()),
+ TaskType.CLASSIFICATION.getName(),
+ DataTypes.DOUBLE(),
+ TaskType.REGRESSION.getName(),
+ DataTypes.DOUBLE());
+
+ private final String task;
+
+ public MLEvaluationAggregationFunction(String task) {
+ TaskType.throwOrReturnInvalidTaskType(task, true);
+ this.task = task;
+ }
+
+ private TypeInference typeInference() {
+ return TypeInference.newBuilder()
+ .inputTypeStrategy(
+ new InputTypeStrategy() {
+ @Override
+ public ArgumentCount getArgumentCount() {
+ return new ArgumentCount() {
+ @Override
+ public boolean isValidCount(int count) {
+ return count == 2;
+ }
+
+ @Override
+ public Optional<Integer> getMinCount() {
+ return Optional.of(2);
+ }
+
+ @Override
+ public Optional<Integer> getMaxCount() {
+ return Optional.of(2);
+ }
+ };
+ }
+
+ @Override
+ public Optional<List<DataType>> inferInputTypes(
+ CallContext callContext, boolean
throwOnFailure) {
+ DataType argumentType =
TASK_TYPE_MAP.get(task.toLowerCase());
+ final List<DataType> args =
List.of(argumentType, argumentType);
+ return Optional.of(args);
+ }
+
+ @Override
+ public List<Signature> getExpectedSignatures(
+ FunctionDefinition definition) {
+ final List<Signature.Argument> arguments = new
ArrayList<>();
+ arguments.add(Signature.Argument.of("label"));
+
arguments.add(Signature.Argument.of("prediction"));
+ return
Collections.singletonList(Signature.of(arguments));
+ }
+ })
+ .outputTypeStrategy(
+ callContext ->
+ Optional.of(
+ DataTypes.MAP(
+
DataTypes.STRING().notNull(),
+
DataTypes.DOUBLE().notNull())
+ .notNull()
+ .bridgedTo(Map.class)))
+ .accumulatorTypeStrategy(callContext ->
Optional.of(DataTypes.DOUBLE()))
Review Comment:
The accumulator is a double? It seems we need a map to hold the results.
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/MLEvaluationAggregationFunction.java:
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions.sql.ml;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.catalog.DataTypeFactory;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.FunctionDefinition;
+import org.apache.flink.table.ml.TaskType;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.ArgumentCount;
+import org.apache.flink.table.types.inference.CallContext;
+import org.apache.flink.table.types.inference.InputTypeStrategy;
+import org.apache.flink.table.types.inference.Signature;
+import org.apache.flink.table.types.inference.TypeInference;
+import org.apache.flink.types.Row;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/** Aggregation function for evaluating models based on task type. */
+@Internal
+public class MLEvaluationAggregationFunction extends AggregateFunction<Row,
Object> {
Review Comment:
Rename to MLEvaluateAggFunction?
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/MLEvaluationAggregationFunction.java:
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions.sql.ml;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.catalog.DataTypeFactory;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.FunctionDefinition;
+import org.apache.flink.table.ml.TaskType;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.ArgumentCount;
+import org.apache.flink.table.types.inference.CallContext;
+import org.apache.flink.table.types.inference.InputTypeStrategy;
+import org.apache.flink.table.types.inference.Signature;
+import org.apache.flink.table.types.inference.TypeInference;
+import org.apache.flink.types.Row;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/** Aggregation function for evaluating models based on task type. */
+@Internal
+public class MLEvaluationAggregationFunction extends AggregateFunction<Row,
Object> {
+
+ public static final Map<String, DataType> TASK_TYPE_MAP =
Review Comment:
How about splitting `MLEvaluationAggregationFunction` into separate classes,
e.g. ClassificationEvaluateAggFunction, ClusteringEvaluateAggFunction?
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/ExpandMLEvaluateTableFunctionRule.java:
##########
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.catalog.ContextResolvedFunction;
+import org.apache.flink.table.functions.FunctionIdentifier;
+import org.apache.flink.table.ml.AsyncPredictRuntimeProvider;
+import org.apache.flink.table.ml.PredictRuntimeProvider;
+import org.apache.flink.table.ml.TaskType;
+import org.apache.flink.table.planner.calcite.FlinkContext;
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
+import org.apache.flink.table.planner.calcite.RexModelCall;
+import
org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.MLEvaluationAggregationFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.SqlMLEvaluateTableFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.SqlMLPredictTableFunction;
+import
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
+import org.apache.flink.table.planner.utils.ShortcutUtils;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexCallBinding;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.util.NlsString;
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Rule that expands ML evaluation table function calls.
+ *
+ * <p>This rule matches {@link FlinkLogicalTableFunctionScan} with a {@link
+ * SqlMLEvaluateTableFunction} call and expands it into ml predict table
function and an aggregation
+ * function following it.
+ */
+@Internal
+@Value.Enclosing
+public class ExpandMLEvaluateTableFunctionRule
+ extends RelRule<ExpandMLEvaluateTableFunctionRule.Config> {
+
+ public static final RelOptRule INSTANCE = new
ExpandMLEvaluateTableFunctionRule(Config.DEFAULT);
+
+ public ExpandMLEvaluateTableFunctionRule(Config config) {
+ super(config);
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ final LogicalTableFunctionScan scan = call.rel(0);
+ final RelDataType resultType = scan.getRowType();
+ final RelBuilder relBuilder = call.builder().push(scan.getInput(0));
+
+ final RexCall rexCall = (RexCall) scan.getCall();
+
+ RelDataType predictOutputType = addPredictTableFunction(relBuilder,
rexCall);
+ addProjection(relBuilder, rexCall, predictOutputType);
+ addAggregate(relBuilder, rexCall, resultType);
+
+ call.transformTo(relBuilder.build());
+ }
+
+ private void addAggregate(RelBuilder relBuilder, RexCall rexCall,
RelDataType resultType) {
+ final String task = getTask(rexCall);
+ final MLEvaluationAggregationFunction aggregationFunction =
+ new MLEvaluationAggregationFunction(task);
+ final FlinkContext context =
ShortcutUtils.unwrapContext(relBuilder.getCluster());
+ final FlinkTypeFactory typeFactory =
+ ShortcutUtils.unwrapTypeFactory(relBuilder.getCluster());
+ relBuilder.aggregate(
+ relBuilder.groupKey(),
+ List.of(
+ AggregateCall.create(
+ BridgingSqlAggFunction.of(
+ context,
+ typeFactory,
+ ContextResolvedFunction.temporary(
+
FunctionIdentifier.of("ml_evaluate"),
+ aggregationFunction)),
+ false,
+ false,
+ false,
+ List.of(0, 1),
+ -1,
+ null,
+ RelCollations.EMPTY,
+ resultType.getFieldList().get(0).getType(),
+ "result")));
+ }
+
+ private void addProjection(
+ RelBuilder relBuilder, RexCall rexCall, RelDataType
predictOutputType) {
+ final RexCall labelDescriptor = (RexCall) rexCall.getOperands().get(2);
+ final String labelCol =
+ ((RexLiteral) labelDescriptor.getOperands().get(0))
+ .getValueAs(NlsString.class)
+ .getValue();
+
+ // Project the label column and the last column (prediction). Only one
label and prediction
+ // column is expected. Validation is done in
SqlMLEvaluateTableFunction.
+ final List<RexNode> projectFields =
+ predictOutputType.getFieldList().stream()
+ .filter(
+ field ->
+ field.getName().equals(labelCol)
+ || field.getIndex()
+ ==
predictOutputType.getFieldCount() - 1)
+ .map(field -> relBuilder.field(field.getIndex()))
+ .collect(Collectors.toList());
+ relBuilder.project(projectFields);
+ }
+
+ private RelDataType addPredictTableFunction(RelBuilder relBuilder, RexCall
rexCall) {
+ final RexCall tableArg = (RexCall) rexCall.getOperands().get(0);
+ final RexCall modelCall = (RexCall) rexCall.getOperands().get(1);
+ final RexCall featuresDescriptor = (RexCall)
rexCall.getOperands().get(3);
+
+ // Get optional config map if present
+ final RexCall configMap = getConfigMap(rexCall);
+
+ final List<RexNode> predictOperands = new ArrayList<>();
+ predictOperands.add(tableArg);
+ predictOperands.add(modelCall);
+ predictOperands.add(featuresDescriptor);
+ if (configMap != null) {
+ predictOperands.add(configMap);
+ }
+
+ final RexCall predictCall =
+ (RexCall)
+ relBuilder
+ .getRexBuilder()
+ .makeCall(new SqlMLPredictTableFunction(),
predictOperands);
+
+ RexCallBinding callBinding =
Review Comment:
RelBuilder supports type inference. You can use
```
relBuilder.functionScan(ML_PREDICT, 1, predictOperands)
```
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/ExpandMLEvaluateTableFunctionRule.java:
##########
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.catalog.ContextResolvedFunction;
+import org.apache.flink.table.functions.FunctionIdentifier;
+import org.apache.flink.table.ml.AsyncPredictRuntimeProvider;
+import org.apache.flink.table.ml.PredictRuntimeProvider;
+import org.apache.flink.table.ml.TaskType;
+import org.apache.flink.table.planner.calcite.FlinkContext;
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
+import org.apache.flink.table.planner.calcite.RexModelCall;
+import
org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.MLEvaluationAggregationFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.SqlMLEvaluateTableFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.SqlMLPredictTableFunction;
+import
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
+import org.apache.flink.table.planner.utils.ShortcutUtils;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexCallBinding;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.util.NlsString;
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Rule that expands ML evaluation table function calls.
+ *
+ * <p>This rule matches {@link FlinkLogicalTableFunctionScan} with a {@link
+ * SqlMLEvaluateTableFunction} call and expands it into ml predict table
function and an aggregation
+ * function following it.
+ */
+@Internal
+@Value.Enclosing
+public class ExpandMLEvaluateTableFunctionRule
+ extends RelRule<ExpandMLEvaluateTableFunctionRule.Config> {
+
+ public static final RelOptRule INSTANCE = new
ExpandMLEvaluateTableFunctionRule(Config.DEFAULT);
+
+ public ExpandMLEvaluateTableFunctionRule(Config config) {
+ super(config);
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ final LogicalTableFunctionScan scan = call.rel(0);
+ final RelDataType resultType = scan.getRowType();
+ final RelBuilder relBuilder = call.builder().push(scan.getInput(0));
+
+ final RexCall rexCall = (RexCall) scan.getCall();
+
+ RelDataType predictOutputType = addPredictTableFunction(relBuilder,
rexCall);
+ addProjection(relBuilder, rexCall, predictOutputType);
+ addAggregate(relBuilder, rexCall, resultType);
+
+ call.transformTo(relBuilder.build());
+ }
+
+ private void addAggregate(RelBuilder relBuilder, RexCall rexCall,
RelDataType resultType) {
+ final String task = getTask(rexCall);
+ final MLEvaluationAggregationFunction aggregationFunction =
+ new MLEvaluationAggregationFunction(task);
+ final FlinkContext context =
ShortcutUtils.unwrapContext(relBuilder.getCluster());
+ final FlinkTypeFactory typeFactory =
+ ShortcutUtils.unwrapTypeFactory(relBuilder.getCluster());
+ relBuilder.aggregate(
+ relBuilder.groupKey(),
+ List.of(
+ AggregateCall.create(
+ BridgingSqlAggFunction.of(
+ context,
+ typeFactory,
+ ContextResolvedFunction.temporary(
+
FunctionIdentifier.of("ml_evaluate"),
+ aggregationFunction)),
+ false,
+ false,
+ false,
+ List.of(0, 1),
+ -1,
+ null,
+ RelCollations.EMPTY,
+ resultType.getFieldList().get(0).getType(),
+ "result")));
+ }
+
+ private void addProjection(
+ RelBuilder relBuilder, RexCall rexCall, RelDataType
predictOutputType) {
+ final RexCall labelDescriptor = (RexCall) rexCall.getOperands().get(2);
+ final String labelCol =
+ ((RexLiteral) labelDescriptor.getOperands().get(0))
+ .getValueAs(NlsString.class)
+ .getValue();
+
+ // Project the label column and the last column (prediction). Only one
label and prediction
+ // column is expected. Validation is done in
SqlMLEvaluateTableFunction.
+ final List<RexNode> projectFields =
+ predictOutputType.getFieldList().stream()
+ .filter(
+ field ->
+ field.getName().equals(labelCol)
+ || field.getIndex()
+ ==
predictOutputType.getFieldCount() - 1)
+ .map(field -> relBuilder.field(field.getIndex()))
+ .collect(Collectors.toList());
+ relBuilder.project(projectFields);
+ }
+
+ private RelDataType addPredictTableFunction(RelBuilder relBuilder, RexCall
rexCall) {
+ final RexCall tableArg = (RexCall) rexCall.getOperands().get(0);
+ final RexCall modelCall = (RexCall) rexCall.getOperands().get(1);
+ final RexCall featuresDescriptor = (RexCall)
rexCall.getOperands().get(3);
+
+ // Get optional config map if present
+ final RexCall configMap = getConfigMap(rexCall);
+
+ final List<RexNode> predictOperands = new ArrayList<>();
+ predictOperands.add(tableArg);
+ predictOperands.add(modelCall);
+ predictOperands.add(featuresDescriptor);
+ if (configMap != null) {
+ predictOperands.add(configMap);
+ }
+
+ final RexCall predictCall =
+ (RexCall)
+ relBuilder
+ .getRexBuilder()
+ .makeCall(new SqlMLPredictTableFunction(),
predictOperands);
+
+ RexCallBinding callBinding =
+ new RexCallBinding(
+ relBuilder.getTypeFactory(),
+ predictCall.getOperator(),
+ predictOperands,
+ List.of());
+
+ RelDataType predictReturnType =
+ ((SqlMLPredictTableFunction) predictCall.getOperator())
+ .getRowTypeInference()
+ .inferReturnType(callBinding);
+
+ relBuilder.push(
+ LogicalTableFunctionScan.create(
+ relBuilder.getCluster(),
+ List.of(relBuilder.build()),
+ predictCall,
+ null,
+ predictReturnType,
+ Collections.emptySet()));
+
+ return predictReturnType;
+ }
+
+ private static RexCall getConfigMap(RexCall rexCall) {
+ if (rexCall.getOperands().size() > 5) {
+ return (RexCall) rexCall.getOperands().get(5);
+ }
+ if (rexCall.getOperands().size() > 4) {
+ RexNode node = rexCall.getOperands().get(4);
+ if (node instanceof RexCall
+ && ((RexCall) node).getOperator().getKind() ==
SqlKind.MAP_VALUE_CONSTRUCTOR) {
+ return (RexCall) node;
+ }
+ }
+ return null;
+ }
+
+ private static String getTask(RexCall rexCall) {
+ final RexNode taskNode = rexCall.getOperands().get(4);
+ String task = null;
+ if (taskNode instanceof RexLiteral) {
+ task = ((RexLiteral)
taskNode).getValueAs(NlsString.class).getValue();
+ if (task == null || task.isEmpty()) {
+ task = null;
+ }
+ }
+ if (task == null) {
+ throw new ValidationException(
+ "Task type must be specified as a parameter to the
ML_EVALUATE function.");
+ }
+ TaskType.throwOrReturnInvalidTaskType(task, true);
+ return task;
+ }
+
+ /** Rule configuration. */
+ @Value.Immutable(singleton = false)
+ public interface Config extends RelRule.Config {
+ Config DEFAULT =
+ ImmutableExpandMLEvaluateTableFunctionRule.Config.builder()
+ .build()
+ .withDescription("ExpandMLEvaluateTableFunctionRule")
+ .as(Config.class)
+ .onMLEvaluateFunction();
+
+ @Override
+ default RelOptRule toRule() {
+ return new ExpandMLEvaluateTableFunctionRule(this);
+ }
+
+ default Config onMLEvaluateFunction() {
+ final RelRule.OperandTransform scanTransform =
+ operandBuilder ->
+ operandBuilder
+ .operand(LogicalTableFunctionScan.class)
+ .predicate(
+ scan -> {
+ if (!(scan.getCall()
instanceof RexCall)) {
+ return false;
+ }
+ RexCall call = (RexCall)
scan.getCall();
+ if (!(call.getOperator()
+ instanceof
SqlMLEvaluateTableFunction)) {
+ return false;
+ }
+ final RexModelCall modelCall =
+ (RexModelCall)
call.getOperands().get(1);
+ return
modelCall.getModelProvider()
Review Comment:
Why do we need to check the model provider type?
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/ExpandMLEvaluateTableFunctionRule.java:
##########
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.catalog.ContextResolvedFunction;
+import org.apache.flink.table.functions.FunctionIdentifier;
+import org.apache.flink.table.ml.AsyncPredictRuntimeProvider;
+import org.apache.flink.table.ml.PredictRuntimeProvider;
+import org.apache.flink.table.ml.TaskType;
+import org.apache.flink.table.planner.calcite.FlinkContext;
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
+import org.apache.flink.table.planner.calcite.RexModelCall;
+import
org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.MLEvaluationAggregationFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.SqlMLEvaluateTableFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.SqlMLPredictTableFunction;
+import
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
+import org.apache.flink.table.planner.utils.ShortcutUtils;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexCallBinding;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.util.NlsString;
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Rule that expands ML evaluation table function calls.
+ *
+ * <p>This rule matches {@link FlinkLogicalTableFunctionScan} with a {@link
+ * SqlMLEvaluateTableFunction} call and expands it into ml predict table
function and an aggregation
+ * function following it.
+ */
+@Internal
+@Value.Enclosing
+public class ExpandMLEvaluateTableFunctionRule
+ extends RelRule<ExpandMLEvaluateTableFunctionRule.Config> {
+
+ public static final RelOptRule INSTANCE = new
ExpandMLEvaluateTableFunctionRule(Config.DEFAULT);
+
+ public ExpandMLEvaluateTableFunctionRule(Config config) {
+ super(config);
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ final LogicalTableFunctionScan scan = call.rel(0);
+ final RelDataType resultType = scan.getRowType();
+ final RelBuilder relBuilder = call.builder().push(scan.getInput(0));
+
+ final RexCall rexCall = (RexCall) scan.getCall();
+
+ RelDataType predictOutputType = addPredictTableFunction(relBuilder,
rexCall);
+ addProjection(relBuilder, rexCall, predictOutputType);
+ addAggregate(relBuilder, rexCall, resultType);
+
+ call.transformTo(relBuilder.build());
+ }
+
+ private void addAggregate(RelBuilder relBuilder, RexCall rexCall,
RelDataType resultType) {
+ final String task = getTask(rexCall);
+ final MLEvaluationAggregationFunction aggregationFunction =
+ new MLEvaluationAggregationFunction(task);
+ final FlinkContext context =
ShortcutUtils.unwrapContext(relBuilder.getCluster());
+ final FlinkTypeFactory typeFactory =
+ ShortcutUtils.unwrapTypeFactory(relBuilder.getCluster());
+ relBuilder.aggregate(
+ relBuilder.groupKey(),
+ List.of(
+ AggregateCall.create(
+ BridgingSqlAggFunction.of(
+ context,
+ typeFactory,
+ ContextResolvedFunction.temporary(
Review Comment:
Use
org.apache.flink.table.catalog.ContextResolvedFunction#permanent(org.apache.flink.table.functions.FunctionIdentifier,
org.apache.flink.table.functions.FunctionDefinition) to build it. I don't think
it should be a built-in function.
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/ExpandMLEvaluateTableFunctionRule.java:
##########
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.catalog.ContextResolvedFunction;
+import org.apache.flink.table.functions.FunctionIdentifier;
+import org.apache.flink.table.ml.AsyncPredictRuntimeProvider;
+import org.apache.flink.table.ml.PredictRuntimeProvider;
+import org.apache.flink.table.ml.TaskType;
+import org.apache.flink.table.planner.calcite.FlinkContext;
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
+import org.apache.flink.table.planner.calcite.RexModelCall;
+import
org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.MLEvaluationAggregationFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.SqlMLEvaluateTableFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.SqlMLPredictTableFunction;
+import
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
+import org.apache.flink.table.planner.utils.ShortcutUtils;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexCallBinding;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.util.NlsString;
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Rule that expands ML evaluation table function calls.
+ *
+ * <p>This rule matches {@link FlinkLogicalTableFunctionScan} with a {@link
+ * SqlMLEvaluateTableFunction} call and expands it into ml predict table
function and an aggregation
+ * function following it.
+ */
+@Internal
+@Value.Enclosing
+public class ExpandMLEvaluateTableFunctionRule
+ extends RelRule<ExpandMLEvaluateTableFunctionRule.Config> {
+
+ public static final RelOptRule INSTANCE = new
ExpandMLEvaluateTableFunctionRule(Config.DEFAULT);
+
+ public ExpandMLEvaluateTableFunctionRule(Config config) {
+ super(config);
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ final LogicalTableFunctionScan scan = call.rel(0);
+ final RelDataType resultType = scan.getRowType();
+ final RelBuilder relBuilder = call.builder().push(scan.getInput(0));
+
+ final RexCall rexCall = (RexCall) scan.getCall();
+
+ RelDataType predictOutputType = addPredictTableFunction(relBuilder,
rexCall);
+ addProjection(relBuilder, rexCall, predictOutputType);
+ addAggregate(relBuilder, rexCall, resultType);
+
+ call.transformTo(relBuilder.build());
+ }
+
+ private void addAggregate(RelBuilder relBuilder, RexCall rexCall,
RelDataType resultType) {
+ final String task = getTask(rexCall);
+ final MLEvaluationAggregationFunction aggregationFunction =
+ new MLEvaluationAggregationFunction(task);
+ final FlinkContext context =
ShortcutUtils.unwrapContext(relBuilder.getCluster());
+ final FlinkTypeFactory typeFactory =
+ ShortcutUtils.unwrapTypeFactory(relBuilder.getCluster());
+ relBuilder.aggregate(
+ relBuilder.groupKey(),
+ List.of(
+ AggregateCall.create(
+ BridgingSqlAggFunction.of(
+ context,
+ typeFactory,
+ ContextResolvedFunction.temporary(
+
FunctionIdentifier.of("ml_evaluate"),
+ aggregationFunction)),
+ false,
+ false,
+ false,
+ List.of(0, 1),
+ -1,
+ null,
+ RelCollations.EMPTY,
+ resultType.getFieldList().get(0).getType(),
+ "result")));
+ }
+
+ private void addProjection(
+ RelBuilder relBuilder, RexCall rexCall, RelDataType
predictOutputType) {
+ final RexCall labelDescriptor = (RexCall) rexCall.getOperands().get(2);
+ final String labelCol =
+ ((RexLiteral) labelDescriptor.getOperands().get(0))
+ .getValueAs(NlsString.class)
+ .getValue();
+
+ // Project the label column and the last column (prediction). Only one
label and prediction
+ // column is expected. Validation is done in
SqlMLEvaluateTableFunction.
+ final List<RexNode> projectFields =
+ predictOutputType.getFieldList().stream()
+ .filter(
+ field ->
+ field.getName().equals(labelCol)
+ || field.getIndex()
+ ==
predictOutputType.getFieldCount() - 1)
+ .map(field -> relBuilder.field(field.getIndex()))
+ .collect(Collectors.toList());
+ relBuilder.project(projectFields);
Review Comment:
```
relBuilder.project(
relBuilder.field(relBuilder.peek().getRowType().getFieldCount() - 1),
relBuilder.field(labelCol));
```
##########
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/ExpandMLEvaluateTableFunctionRule.java:
##########
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.catalog.ContextResolvedFunction;
+import org.apache.flink.table.functions.FunctionIdentifier;
+import org.apache.flink.table.ml.AsyncPredictRuntimeProvider;
+import org.apache.flink.table.ml.PredictRuntimeProvider;
+import org.apache.flink.table.ml.TaskType;
+import org.apache.flink.table.planner.calcite.FlinkContext;
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
+import org.apache.flink.table.planner.calcite.RexModelCall;
+import
org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.MLEvaluationAggregationFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.SqlMLEvaluateTableFunction;
+import
org.apache.flink.table.planner.functions.sql.ml.SqlMLPredictTableFunction;
+import
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
+import org.apache.flink.table.planner.utils.ShortcutUtils;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexCallBinding;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.util.NlsString;
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Rule that expands ML evaluation table function calls.
+ *
+ * <p>This rule matches {@link FlinkLogicalTableFunctionScan} with a {@link
+ * SqlMLEvaluateTableFunction} call and expands it into ml predict table
function and an aggregation
+ * function following it.
+ */
+@Internal
+@Value.Enclosing
+public class ExpandMLEvaluateTableFunctionRule
+ extends RelRule<ExpandMLEvaluateTableFunctionRule.Config> {
+
+ public static final RelOptRule INSTANCE = new
ExpandMLEvaluateTableFunctionRule(Config.DEFAULT);
+
+ public ExpandMLEvaluateTableFunctionRule(Config config) {
+ super(config);
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ final LogicalTableFunctionScan scan = call.rel(0);
+ final RelDataType resultType = scan.getRowType();
+ final RelBuilder relBuilder = call.builder().push(scan.getInput(0));
+
+ final RexCall rexCall = (RexCall) scan.getCall();
+
+ RelDataType predictOutputType = addPredictTableFunction(relBuilder,
rexCall);
+ addProjection(relBuilder, rexCall, predictOutputType);
+ addAggregate(relBuilder, rexCall, resultType);
+
+ call.transformTo(relBuilder.build());
+ }
+
+ private void addAggregate(RelBuilder relBuilder, RexCall rexCall,
RelDataType resultType) {
+ final String task = getTask(rexCall);
+ final MLEvaluationAggregationFunction aggregationFunction =
+ new MLEvaluationAggregationFunction(task);
+ final FlinkContext context =
ShortcutUtils.unwrapContext(relBuilder.getCluster());
+ final FlinkTypeFactory typeFactory =
+ ShortcutUtils.unwrapTypeFactory(relBuilder.getCluster());
+ relBuilder.aggregate(
+ relBuilder.groupKey(),
+ List.of(
+ AggregateCall.create(
+ BridgingSqlAggFunction.of(
+ context,
+ typeFactory,
+ ContextResolvedFunction.temporary(
+
FunctionIdentifier.of("ml_evaluate"),
+ aggregationFunction)),
+ false,
+ false,
+ false,
+ List.of(0, 1),
+ -1,
+ null,
+ RelCollations.EMPTY,
+ resultType.getFieldList().get(0).getType(),
+ "result")));
+ }
+
+ private void addProjection(
+ RelBuilder relBuilder, RexCall rexCall, RelDataType
predictOutputType) {
+ final RexCall labelDescriptor = (RexCall) rexCall.getOperands().get(2);
+ final String labelCol =
+ ((RexLiteral) labelDescriptor.getOperands().get(0))
+ .getValueAs(NlsString.class)
+ .getValue();
+
+ // Project the label column and the last column (prediction). Only one
label and prediction
+ // column is expected. Validation is done in
SqlMLEvaluateTableFunction.
+ final List<RexNode> projectFields =
+ predictOutputType.getFieldList().stream()
+ .filter(
+ field ->
+ field.getName().equals(labelCol)
+ || field.getIndex()
+ ==
predictOutputType.getFieldCount() - 1)
Review Comment:
It would be better to do the validation in the Sql2Rel phase.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]