This is an automated email from the ASF dual-hosted git repository. kurt pushed a commit to branch release-1.9 in repository https://gitbox.apache.org/repos/asf/flink.git
commit 7b5c1f87d2a730c691c44bc21bb06f5ebb50235f Author: JingsongLi <lzljs3620...@aliyun.com> AuthorDate: Sun Jul 28 20:23:56 2019 +0800 [FLINK-13225][table-planner-blink] Fix type inference for hive udaf (cherry picked from commit 3d502e3069faa0e898b9b0a1059622eac2b1c2f0) --- .../catalog/FunctionCatalogOperatorTable.java | 12 +++- .../planner/expressions/SqlAggFunctionVisitor.java | 3 +- .../functions/utils/HiveAggSqlFunction.java | 83 ++++++++++++++++++++++ .../planner/functions/utils/AggSqlFunction.scala | 10 ++- .../planner/plan/utils/AggFunctionFactory.scala | 11 ++- 5 files changed, 113 insertions(+), 6 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java index bc60f27..3c875dd 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java @@ -27,6 +27,7 @@ import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.functions.ScalarFunctionDefinition; import org.apache.flink.table.functions.TableFunctionDefinition; import org.apache.flink.table.planner.calcite.FlinkTypeFactory; +import org.apache.flink.table.planner.functions.utils.HiveAggSqlFunction; import org.apache.flink.table.planner.functions.utils.HiveScalarSqlFunction; import org.apache.flink.table.planner.functions.utils.HiveTableSqlFunction; import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils; @@ -99,7 +100,16 @@ public class FunctionCatalogOperatorTable implements SqlOperatorTable { String name, FunctionDefinition functionDefinition) { if (functionDefinition instanceof AggregateFunctionDefinition) { - return 
convertAggregateFunction(name, (AggregateFunctionDefinition) functionDefinition); + AggregateFunctionDefinition def = (AggregateFunctionDefinition) functionDefinition; + if (isHiveFunc(def.getAggregateFunction())) { + return Optional.of(new HiveAggSqlFunction( + name, + name, + def.getAggregateFunction(), + typeFactory)); + } else { + return convertAggregateFunction(name, (AggregateFunctionDefinition) functionDefinition); + } } else if (functionDefinition instanceof ScalarFunctionDefinition) { ScalarFunctionDefinition def = (ScalarFunctionDefinition) functionDefinition; if (isHiveFunc(def.getScalarFunction())) { diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java index 8221a78..4b3b9a7 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/expressions/SqlAggFunctionVisitor.java @@ -90,7 +90,8 @@ public class SqlAggFunctionVisitor extends ExpressionDefaultVisitor<SqlAggFuncti fromLegacyInfoToDataType(aggDef.getResultTypeInfo()), fromLegacyInfoToDataType(aggDef.getAccumulatorTypeInfo()), typeFactory, - aggFunc.requiresOver()); + aggFunc.requiresOver(), + scala.Option.empty()); } else { throw new UnsupportedOperationException("TableAggregateFunction is not supported yet!"); } diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/utils/HiveAggSqlFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/utils/HiveAggSqlFunction.java new file mode 100644 index 0000000..54bde22 --- /dev/null +++ 
b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/utils/HiveAggSqlFunction.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.functions.utils; + +import org.apache.flink.api.java.typeutils.GenericTypeInfo; +import org.apache.flink.table.functions.AggregateFunction; +import org.apache.flink.table.planner.calcite.FlinkTypeFactory; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.util.InstantiationUtil; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.sql.type.SqlReturnTypeInference; + +import java.io.IOException; +import java.util.List; + +import scala.Some; + +import static org.apache.flink.table.planner.functions.utils.HiveFunctionUtils.invokeGetResultType; +import static org.apache.flink.table.planner.functions.utils.HiveFunctionUtils.invokeSetArgs; +import static org.apache.flink.table.types.utils.TypeConversions.fromLegacyInfoToDataType; + +/** + * Hive {@link AggSqlFunction}. + * Override makeFunction to clone the function and invoke {@code HiveUDAF#setArgumentTypesAndConstants}. 
+ * Override SqlReturnTypeInference to invoke {@code HiveUDAF#getHiveResultType} instead of + * {@code HiveUDAF#getResultType}. + * + * @deprecated TODO hack code, its logic should be integrated into AggSqlFunction + */ +@Deprecated +public class HiveAggSqlFunction extends AggSqlFunction { + + private final AggregateFunction aggregateFunction; + + public HiveAggSqlFunction(String name, String displayName, + AggregateFunction aggregateFunction, FlinkTypeFactory typeFactory) { + super(name, displayName, aggregateFunction, fromLegacyInfoToDataType(new GenericTypeInfo<>(Object.class)), + fromLegacyInfoToDataType(new GenericTypeInfo<>(Object.class)), typeFactory, false, + new Some<>(createReturnTypeInference(aggregateFunction, typeFactory))); + this.aggregateFunction = aggregateFunction; + } + + @Override + public AggregateFunction makeFunction(Object[] constantArguments, LogicalType[] argTypes) { + AggregateFunction clone; + try { + clone = InstantiationUtil.clone(aggregateFunction); + } catch (IOException | ClassNotFoundException e) { + throw new RuntimeException(e); + } + return (AggregateFunction) invokeSetArgs(clone, constantArguments, argTypes); + } + + private static SqlReturnTypeInference createReturnTypeInference( + AggregateFunction function, FlinkTypeFactory typeFactory) { + return opBinding -> { + List<RelDataType> sqlTypes = opBinding.collectOperandTypes(); + LogicalType[] parameters = UserDefinedFunctionUtils.getOperandTypeArray(opBinding); + + Object[] constantArguments = new Object[sqlTypes.size()]; + // Cannot touch the literals; Calcite makes them in a previous RelNode. + // Here, all inputs are input refs. 
+ return invokeGetResultType(function, constantArguments, parameters, typeFactory); + }; + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/AggSqlFunction.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/AggSqlFunction.scala index b2b14fc..5c5f744 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/AggSqlFunction.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/AggSqlFunction.scala @@ -54,10 +54,12 @@ class AggSqlFunction( val externalResultType: DataType, val externalAccType: DataType, typeFactory: FlinkTypeFactory, - requiresOver: Boolean) + requiresOver: Boolean, + returnTypeInfer: Option[SqlReturnTypeInference] = None) extends SqlUserDefinedAggFunction( new SqlIdentifier(name, SqlParserPos.ZERO), - createReturnTypeInference(fromDataTypeToLogicalType(externalResultType), typeFactory), + returnTypeInfer.getOrElse(createReturnTypeInference( + fromDataTypeToLogicalType(externalResultType), typeFactory)), createOperandTypeInference(name, aggregateFunction, typeFactory), createOperandTypeChecker(name, aggregateFunction), // Do not need to provide a calcite aggregateFunction here. 
Flink aggregateion function @@ -69,7 +71,9 @@ class AggSqlFunction( typeFactory ) { - def getFunction: AggregateFunction[_, _] = aggregateFunction + def makeFunction( + constants: Array[AnyRef], argTypes: Array[LogicalType]): AggregateFunction[_, _] = + aggregateFunction override def isDeterministic: Boolean = aggregateFunction.isDeterministic diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala index e8707f2..d5f9e56 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala @@ -43,6 +43,8 @@ import org.apache.calcite.rel.core.AggregateCall import org.apache.calcite.sql.fun._ import org.apache.calcite.sql.{SqlAggFunction, SqlKind, SqlRankFunction} +import java.util + import scala.collection.JavaConversions._ /** @@ -122,7 +124,14 @@ class AggFunctionFactory( case a: SqlAggFunction if a.getKind == SqlKind.COLLECT => createCollectAggFunction(argTypes) - case udagg: AggSqlFunction => udagg.getFunction + case udagg: AggSqlFunction => + // Can not touch the literals, Calcite make them in previous RelNode. + // In here, all inputs are input refs. + val constants = new util.ArrayList[AnyRef]() + argTypes.foreach(t => constants.add(null)) + udagg.makeFunction( + constants.toArray, + argTypes) case unSupported: SqlAggFunction => throw new TableException(s"Unsupported Function: '${unSupported.getName}'")