This is an automated email from the ASF dual-hosted git repository. kurt pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit e08117f89ba012c37f70c4aad99d569d8a9ba2b6 Author: JingsongLi <lzljs3620...@aliyun.com> AuthorDate: Sun Jul 28 20:16:14 2019 +0800 [FLINK-13225][table-planner-blink] Fix type inference for hive udf --- .../catalog/FunctionCatalogOperatorTable.java | 14 +++- .../planner/functions/utils/HiveFunctionUtils.java | 80 ++++++++++++++++++++ .../functions/utils/HiveScalarSqlFunction.java | 85 ++++++++++++++++++++++ .../table/planner/codegen/ExprCodeGenerator.scala | 18 ++++- .../functions/utils/ScalarSqlFunction.scala | 9 ++- .../functions/utils/UserDefinedFunctionUtils.scala | 4 + 6 files changed, 205 insertions(+), 5 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java index 87a7bcb..ddf8f60 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/FunctionCatalogOperatorTable.java @@ -26,6 +26,7 @@ import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.functions.ScalarFunctionDefinition; import org.apache.flink.table.functions.TableFunctionDefinition; import org.apache.flink.table.planner.calcite.FlinkTypeFactory; +import org.apache.flink.table.planner.functions.utils.HiveScalarSqlFunction; import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils; import org.apache.flink.table.types.utils.TypeConversions; @@ -40,6 +41,8 @@ import org.apache.calcite.sql.validate.SqlNameMatcher; import java.util.List; import java.util.Optional; +import static org.apache.flink.table.planner.functions.utils.HiveFunctionUtils.isHiveFunc; + /** * Thin adapter between {@link SqlOperatorTable} and {@link FunctionCatalog}. */ @@ -92,7 +95,16 @@ public class FunctionCatalogOperatorTable implements SqlOperatorTable { if (functionDefinition instanceof AggregateFunctionDefinition) { return convertAggregateFunction(name, (AggregateFunctionDefinition) functionDefinition); } else if (functionDefinition instanceof ScalarFunctionDefinition) { - return convertScalarFunction(name, (ScalarFunctionDefinition) functionDefinition); + ScalarFunctionDefinition def = (ScalarFunctionDefinition) functionDefinition; + if (isHiveFunc(def.getScalarFunction())) { + return Optional.of(new HiveScalarSqlFunction( + name, + name, + def.getScalarFunction(), + typeFactory)); + } else { + return convertScalarFunction(name, def); + } } else if (functionDefinition instanceof TableFunctionDefinition && category != null && category.isTableFunction()) { diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/utils/HiveFunctionUtils.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/utils/HiveFunctionUtils.java new file mode 100644 index 0000000..13a82cb --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/utils/HiveFunctionUtils.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.functions.utils; + +import org.apache.flink.table.planner.calcite.FlinkTypeFactory; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.utils.TypeConversions; + +import org.apache.calcite.rel.type.RelDataType; + +import java.io.Serializable; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; + +import static org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType; + +/** + * Hack utils for hive function. + */ +public class HiveFunctionUtils { + + public static boolean isHiveFunc(Object function) { + try { + getSetArgsMethod(function); + return true; + } catch (NoSuchMethodException e) { + return false; + } + } + + private static Method getSetArgsMethod(Object function) throws NoSuchMethodException { + return function.getClass().getMethod( + "setArgumentTypesAndConstants", Object[].class, DataType[].class); + + } + + static Serializable invokeSetArgs( + Serializable function, Object[] constantArguments, LogicalType[] argTypes) { + try { + // See hive HiveFunction + Method method = getSetArgsMethod(function); + method.invoke(function, constantArguments, TypeConversions.fromLogicalToDataType(argTypes)); + return function; + } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + throw new RuntimeException(e); + } + } + + static RelDataType invokeGetResultType( + Object function, Object[] constantArguments, LogicalType[] argTypes, + FlinkTypeFactory typeFactory) { + try { + // See hive HiveFunction + Method method = function.getClass() + .getMethod("getHiveResultType", Object[].class, DataType[].class); + DataType resultType = (DataType) method.invoke( + function, constantArguments, TypeConversions.fromLogicalToDataType(argTypes)); + return typeFactory.createFieldTypeFromLogicalType(fromDataTypeToLogicalType(resultType)); + } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + throw new RuntimeException(e); + } + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/utils/HiveScalarSqlFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/utils/HiveScalarSqlFunction.java new file mode 100644 index 0000000..a44576a --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/utils/HiveScalarSqlFunction.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.functions.utils; + +import org.apache.flink.table.functions.ScalarFunction; +import org.apache.flink.table.planner.calcite.FlinkTypeFactory; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.util.InstantiationUtil; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.sql.type.SqlReturnTypeInference; + +import java.io.IOException; +import java.util.List; + +import scala.Some; + +import static org.apache.flink.table.planner.functions.utils.HiveFunctionUtils.invokeGetResultType; +import static org.apache.flink.table.planner.functions.utils.HiveFunctionUtils.invokeSetArgs; +import static org.apache.flink.table.runtime.types.ClassLogicalTypeConverter.getDefaultExternalClassForType; + +/** + * Hive {@link ScalarSqlFunction}. + * Override getFunction to clone function and invoke {@code HiveScalarFunction#setArgumentTypesAndConstants}. + * Override SqlReturnTypeInference to invoke {@code HiveScalarFunction#getHiveResultType} instead of + * {@code HiveScalarFunction#getResultType(Class[])}. + * + * @deprecated TODO hack code, its logical should be integrated to ScalarSqlFunction + */ +@Deprecated +public class HiveScalarSqlFunction extends ScalarSqlFunction { + + private final ScalarFunction function; + + public HiveScalarSqlFunction( + String name, String displayName, + ScalarFunction function, FlinkTypeFactory typeFactory) { + super(name, displayName, function, typeFactory, new Some<>(createReturnTypeInference(function, typeFactory))); + this.function = function; + } + + @Override + public ScalarFunction makeFunction(Object[] constantArguments, LogicalType[] argTypes) { + ScalarFunction clone; + try { + clone = InstantiationUtil.clone(function); + } catch (IOException | ClassNotFoundException e) { + throw new RuntimeException(e); + } + return (ScalarFunction) invokeSetArgs(clone, constantArguments, argTypes); + } + + private static SqlReturnTypeInference createReturnTypeInference( + ScalarFunction function, FlinkTypeFactory typeFactory) { + return opBinding -> { + List<RelDataType> sqlTypes = opBinding.collectOperandTypes(); + LogicalType[] parameters = UserDefinedFunctionUtils.getOperandTypeArray(opBinding); + + Object[] constantArguments = new Object[sqlTypes.size()]; + for (int i = 0; i < sqlTypes.size(); i++) { + if (!opBinding.isOperandNull(i, false) && opBinding.isOperandLiteral(i, false)) { + constantArguments[i] = opBinding.getOperandLiteralValue( + i, getDefaultExternalClassForType(parameters[i])); + } + } + return invokeGetResultType(function, constantArguments, parameters, typeFactory); + }; + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala index e641708..7c55d73 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala @@ -20,6 +20,7 @@ package org.apache.flink.table.planner.codegen import org.apache.flink.streaming.api.functions.ProcessFunction import org.apache.flink.table.api.TableException +import org.apache.flink.table.dataformat.DataFormatConverters.{DataFormatConverter, getConverterForDataType} import org.apache.flink.table.dataformat._ import org.apache.flink.table.planner.calcite.{FlinkTypeFactory, RexAggLocalVariable, RexDistinctKeyVariable} import org.apache.flink.table.planner.codegen.CodeGenUtils.{requireTemporal, requireTimeInterval, _} @@ -30,6 +31,7 @@ import org.apache.flink.table.planner.codegen.calls.{FunctionGenerator, ScalarFu import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable._ import org.apache.flink.table.planner.functions.sql.SqlThrowExceptionFunction import org.apache.flink.table.planner.functions.utils.{ScalarSqlFunction, TableSqlFunction} +import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromLogicalTypeToDataType import org.apache.flink.table.runtime.types.PlannerTypeUtils.isInteroperable import org.apache.flink.table.runtime.typeutils.TypeCheckUtils import org.apache.flink.table.runtime.typeutils.TypeCheckUtils.{isNumeric, isTemporal, isTimeInterval} @@ -730,7 +732,9 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) GeneratedExpression(nullValue.resultTerm, nullValue.nullTerm, code, resultType) case ssf: ScalarSqlFunction => - new ScalarFunctionCallGen(ssf.getScalarFunction).generate(ctx, operands, resultType) + new ScalarFunctionCallGen( + ssf.makeFunction(getOperandLiterals(operands), operands.map(_.resultType).toArray)) + .generate(ctx, operands, resultType) case tsf: TableSqlFunction => new TableFunctionCallGen(tsf.getTableFunction).generate(ctx, operands, resultType) @@ -757,4 +761,16 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) throw new CodeGenException(s"Unsupported call: $explainCall") } } + + def getOperandLiterals(operands: Seq[GeneratedExpression]): Array[AnyRef] = { + operands.map { expr => + expr.literalValue match { + case None => null + case Some(literal) => + getConverterForDataType(fromLogicalTypeToDataType(expr.resultType)) + .asInstanceOf[DataFormatConverter[AnyRef, AnyRef] + ].toExternal(literal.asInstanceOf[AnyRef]) + } + }.toArray + } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/ScalarSqlFunction.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/ScalarSqlFunction.scala index 35b5b5d..159d4f1 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/ScalarSqlFunction.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/ScalarSqlFunction.scala @@ -26,6 +26,7 @@ import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils.{ import org.apache.flink.table.runtime.types.ClassLogicalTypeConverter.getDefaultExternalClassForType import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType import org.apache.flink.table.runtime.types.TypeInfoLogicalTypeConverter.fromTypeInfoToLogicalType +import org.apache.flink.table.types.logical.LogicalType import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.sql._ @@ -47,16 +48,18 @@ class ScalarSqlFunction( name: String, displayName: String, scalarFunction: ScalarFunction, - typeFactory: FlinkTypeFactory) + typeFactory: FlinkTypeFactory, + returnTypeInfer: Option[SqlReturnTypeInference] = None) extends SqlFunction( new SqlIdentifier(name, SqlParserPos.ZERO), - createReturnTypeInference(name, scalarFunction, typeFactory), + returnTypeInfer.getOrElse(createReturnTypeInference(name, scalarFunction, typeFactory)), createOperandTypeInference(name, scalarFunction, typeFactory), createOperandTypeChecker(name, scalarFunction), null, SqlFunctionCategory.USER_DEFINED_FUNCTION) { - def getScalarFunction: ScalarFunction = scalarFunction + def makeFunction(constants: Array[AnyRef], argTypes: Array[LogicalType]): ScalarFunction = + scalarFunction override def isDeterministic: Boolean = scalarFunction.isDeterministic diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/UserDefinedFunctionUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/UserDefinedFunctionUtils.scala index 1de25dd..e565a6e 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/UserDefinedFunctionUtils.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/functions/utils/UserDefinedFunctionUtils.scala @@ -754,6 +754,10 @@ object UserDefinedFunctionUtils { } } + def getOperandTypeArray(callBinding: SqlOperatorBinding): Array[LogicalType] = { + getOperandType(callBinding).toArray + } + def getOperandType(callBinding: SqlOperatorBinding): Seq[LogicalType] = { val operandTypes = for (i <- 0 until callBinding.getOperandCount) yield callBinding.getOperandType(i)