[ https://issues.apache.org/jira/browse/FLINK-5832?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16605885#comment-16605885 ]
ASF GitHub Bot commented on FLINK-5832: --------------------------------------- twalthr closed pull request #3456: [FLINK-5832] [table] Support for simple hive UDF URL: https://github.com/apache/flink/pull/3456 This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/flink-connectors/flink-hcatalog/pom.xml b/flink-connectors/flink-hcatalog/pom.xml index ba0e142692b..afb6914281b 100644 --- a/flink-connectors/flink-hcatalog/pom.xml +++ b/flink-connectors/flink-hcatalog/pom.xml @@ -42,6 +42,12 @@ under the License. <scope>provided</scope> </dependency> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-table_2.10</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> <groupId>org.apache.flink</groupId> <artifactId>flink-hadoop-compatibility_2.10</artifactId> @@ -50,8 +56,8 @@ under the License. <dependency> <groupId>org.apache.hive.hcatalog</groupId> - <artifactId>hcatalog-core</artifactId> - <version>0.12.0</version> + <artifactId>hive-hcatalog-core</artifactId> + <version>0.13.0</version> <exclusions> <exclusion> <groupId>org.json</groupId> diff --git a/flink-connectors/flink-hcatalog/src/main/scala/org/apache/flink/table/hive/functions/HiveFunctionWrapper.scala b/flink-connectors/flink-hcatalog/src/main/scala/org/apache/flink/table/hive/functions/HiveFunctionWrapper.scala new file mode 100644 index 00000000000..27f042036d9 --- /dev/null +++ b/flink-connectors/flink-hcatalog/src/main/scala/org/apache/flink/table/hive/functions/HiveFunctionWrapper.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.hive.functions + +import org.apache.hadoop.hive.ql.exec.UDF + +private[hive] case class HiveFunctionWrapper( + var functionClassName: String, + private var instance: AnyRef = null) { + + def createFunction[UDFType <: AnyRef](): UDFType = { + if (instance != null) { + instance.asInstanceOf[UDFType] + } else { + val func = getClassLoader.loadClass(functionClassName).newInstance().asInstanceOf[UDFType] + if (!func.isInstanceOf[UDF]) { + instance = func + } + func + } + } + + def getClassLoader: ClassLoader = { + Thread.currentThread.getContextClassLoader + } + +} diff --git a/flink-connectors/flink-hcatalog/src/main/scala/org/apache/flink/table/hive/functions/HiveSimpleUDF.scala b/flink-connectors/flink-hcatalog/src/main/scala/org/apache/flink/table/hive/functions/HiveSimpleUDF.scala new file mode 100644 index 00000000000..00018bb47d0 --- /dev/null +++ b/flink-connectors/flink-hcatalog/src/main/scala/org/apache/flink/table/hive/functions/HiveSimpleUDF.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.hive.functions + +import java.lang.reflect.Method +import java.math.BigDecimal +import java.util + +import org.apache.flink.table.functions.ScalarFunction +import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, UDF} +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory +import org.apache.hadoop.hive.serde2.typeinfo.{PrimitiveTypeInfo, TypeInfo, TypeInfoFactory} + +import scala.annotation.varargs + +/** + * A Hive UDF Wrapper which behaves as a Flink-table ScalarFunction. + * + * This class has to have a method with @varargs annotation. For scala will compile + * <code> eval(args: Any*) </code> to <code>eval(args: Seq)</code>. + * This will cause an exception in Janino compiler. + */ +class HiveSimpleUDF(className: String) extends ScalarFunction { + + @transient + private lazy val functionWrapper = HiveFunctionWrapper(className) + + @transient + private lazy val function = functionWrapper.createFunction[UDF]() + + @transient + private var typeInfos: util.List[TypeInfo] = _ + + @transient + private var objectInspectors: Array[ObjectInspector] = _ + + @transient + private var conversionHelper: ConversionHelper = _ + + @transient + private var method: Method = _ + + @varargs + def eval(args: AnyRef*) : Any = { + if (null == typeInfos) { + typeInfos = new util.ArrayList[TypeInfo]() + args.foreach(arg => { + typeInfos.add(TypeInfoFactory.getPrimitiveTypeInfoFromJavaPrimitive(arg.getClass)) + }) + method = function.getResolver.getEvalMethod(typeInfos) + + objectInspectors = new Array[ObjectInspector](typeInfos.size()) + args.zipWithIndex.foreach { case (_, i) => + objectInspectors(i) = PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + typeInfos.get(i).asInstanceOf[PrimitiveTypeInfo]) + } + conversionHelper = new ConversionHelper(method, objectInspectors) + } + + val mappedArgs = args.map { + case arg: BigDecimal => + arg.asInstanceOf[BigDecimal].doubleValue().asInstanceOf[AnyRef] + case arg: AnyRef => + arg + } + + FunctionRegistry.invoke(method, function, + conversionHelper.convertIfNecessary(mappedArgs: _*): _*) + } +} diff --git a/flink-connectors/flink-hcatalog/src/test/scala/org/apache/flink/table/hive/functions/HiveScalarFunctionTest.scala b/flink-connectors/flink-hcatalog/src/test/scala/org/apache/flink/table/hive/functions/HiveScalarFunctionTest.scala new file mode 100644 index 00000000000..98dd3666e37 --- /dev/null +++ b/flink-connectors/flink-hcatalog/src/test/scala/org/apache/flink/table/hive/functions/HiveScalarFunctionTest.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.hive.functions + +import java.sql.{Date, Time, Timestamp} + +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.table.api.Types +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.functions.ScalarFunction +import org.apache.flink.table.hive.functions.utils.{ExpressionTestBase, SimplePojo} +import org.apache.flink.types.Row +import org.junit.Test + +class HiveScalarFunctionTest extends ExpressionTestBase { + + @Test + def testHiveSimpleFunctions(): Unit = { + val HiveUDFAcos = new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFAcos") + testAllApis( + HiveUDFAcos(1.0), + "HiveUDFAcos(1.0)", + "HiveUDFAcos(1.0)", + "0.0" + ) + + val HiveUDFAscii = new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFAscii") + testAllApis( + HiveUDFAscii("0"), + "HiveUDFAscii('0')", + "HiveUDFAscii('0')", + "48" + ) + + val HiveUDFAsin = new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFAsin") + testAllApis( + HiveUDFAsin("0"), + "HiveUDFAsin('0')", + "HiveUDFAsin('0')", + "0.0" + ) + + val HiveUDFBin = new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFBin") + testAllApis( + HiveUDFBin(13), + "HiveUDFBin(13)", + "HiveUDFBin(13)", + "1101" + ) + + val HiveUDFConv = new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFConv") + testAllApis( + HiveUDFConv("100", 2, 10), + "HiveUDFConv('100', 2, 10)", + "HiveUDFConv('100', 2, 10)", + "4" + ) + testAllApis( + HiveUDFConv(-10, 16, -10), + "HiveUDFConv(-10, 16, -10)", + "HiveUDFConv(-10, 16, -10)", + "-16" + ) + + val HiveUDFCos = new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFCos") + testAllApis( + HiveUDFCos(0.0), + "HiveUDFCos(0.0)", + "HiveUDFCos(0.0)", + "1.0" + ) + + val HiveUDFDayOfMonth = new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFDayOfMonth") + testAllApis( + HiveUDFDayOfMonth("2009-07-30"), + "HiveUDFDayOfMonth('2009-07-30')", + "HiveUDFDayOfMonth('2009-07-30')", + "30" + ) + } + + // ---------------------------------------------------------------------------------------------- + + override def testData: Any = { + val testData = new Row(9) + testData.setField(0, 42) + testData.setField(1, "Test") + testData.setField(2, null) + testData.setField(3, SimplePojo("Bob", 36)) + testData.setField(4, Date.valueOf("1990-10-14")) + testData.setField(5, Time.valueOf("12:10:10")) + testData.setField(6, Timestamp.valueOf("1990-10-14 12:10:10")) + testData.setField(7, 12) + testData.setField(8, 1000L) + testData + } + + override def typeInfo: TypeInformation[Any] = { + new RowTypeInfo( + Types.INT, + Types.STRING, + Types.BOOLEAN, + TypeInformation.of(classOf[SimplePojo]), + Types.DATE, + Types.TIME, + Types.TIMESTAMP, + Types.INTERVAL_MONTHS, + Types.INTERVAL_MILLIS + ).asInstanceOf[TypeInformation[Any]] + } + + override def functions: Map[String, ScalarFunction] = Map( + "HiveUDFAcos" -> new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFAcos"), + "HiveUDFAscii" -> new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFAscii"), + "HiveUDFAsin" -> new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFAsin"), + "HiveUDFBin" -> new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFBin"), + "HiveUDFConv" -> new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFConv"), + "HiveUDFCos" -> new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFCos"), + "HiveUDFDayOfMonth" -> new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFDayOfMonth") + ) +} diff --git a/flink-connectors/flink-hcatalog/src/test/scala/org/apache/flink/table/hive/functions/utils/ExpressionTestBase.scala b/flink-connectors/flink-hcatalog/src/test/scala/org/apache/flink/table/hive/functions/utils/ExpressionTestBase.scala new file mode 100644 index 00000000000..f9eb1bf445f --- /dev/null +++ b/flink-connectors/flink-hcatalog/src/test/scala/org/apache/flink/table/hive/functions/utils/ExpressionTestBase.scala @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.hive.functions.utils + +import java.util +import java.util.concurrent.Future + +import org.apache.calcite.plan.hep.{HepMatchOrder, HepPlanner, HepProgramBuilder} +import org.apache.calcite.rex.RexNode +import org.apache.calcite.sql.`type`.SqlTypeName._ +import org.apache.calcite.sql2rel.RelDecorrelator +import org.apache.calcite.tools.{Programs, RelBuilder} +import org.apache.flink.api.common.TaskInfo +import org.apache.flink.api.common.accumulators.Accumulator +import org.apache.flink.api.common.functions._ +import org.apache.flink.api.common.functions.util.RuntimeUDFContext +import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.api.java.{DataSet => JDataSet} +import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment} +import org.apache.flink.configuration.Configuration +import org.apache.flink.core.fs.Path +import org.apache.flink.table.api.{BatchTableEnvironment, TableConfig, TableEnvironment} +import org.apache.flink.table.calcite.FlinkPlannerImpl +import org.apache.flink.table.codegen.{CodeGenerator, Compiler, GeneratedFunction} +import org.apache.flink.table.expressions.{Expression, ExpressionParser} +import org.apache.flink.table.functions.ScalarFunction +import org.apache.flink.table.plan.nodes.dataset.{DataSetCalc, DataSetConvention} +import org.apache.flink.table.plan.rules.FlinkRuleSets +import org.apache.flink.types.Row +import org.junit.Assert._ +import org.junit.{After, Before} +import org.mockito.Mockito._ + +import scala.collection.mutable + +/** + * Base test class for expression tests. + */ +abstract class ExpressionTestBase { + + private val testExprs = mutable.ArrayBuffer[(RexNode, String)]() + + // setup test utils + private val tableName = "testTable" + private val context = prepareContext(typeInfo) + private val planner = new FlinkPlannerImpl( + context._2.getFrameworkConfig, + context._2.getPlanner, + context._2.getTypeFactory) + private val optProgram = Programs.ofRules(FlinkRuleSets.DATASET_OPT_RULES) + + private def hepPlanner = { + val builder = new HepProgramBuilder + builder.addMatchOrder(HepMatchOrder.BOTTOM_UP) + val it = FlinkRuleSets.DATASET_NORM_RULES.iterator() + while (it.hasNext) { + builder.addRuleInstance(it.next()) + } + new HepPlanner(builder.build, context._2.getFrameworkConfig.getContext) + } + + private def prepareContext(typeInfo: TypeInformation[Any]) + : (RelBuilder, TableEnvironment, ExecutionEnvironment) = { + // create DataSetTable + val dataSetMock = mock(classOf[DataSet[Any]]) + val jDataSetMock = mock(classOf[JDataSet[Any]]) + when(dataSetMock.javaSet).thenReturn(jDataSetMock) + when(jDataSetMock.getType).thenReturn(typeInfo) + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + tEnv.registerDataSet(tableName, dataSetMock) + functions.foreach(f => tEnv.registerFunction(f._1, f._2)) + + // prepare RelBuilder + val relBuilder = tEnv.getRelBuilder + relBuilder.scan(tableName) + + (relBuilder, tEnv, env) + } + + def testData: Any + + def typeInfo: TypeInformation[Any] + + def functions: Map[String, ScalarFunction] = Map() + + @Before + def resetTestExprs() = { + testExprs.clear() + } + + @After + def evaluateExprs() = { + val relBuilder = context._1 + val config = new TableConfig() + val generator = new CodeGenerator(config, false, typeInfo) + + // cast expressions to String + val stringTestExprs = testExprs.map(expr => relBuilder.cast(expr._1, VARCHAR)) + + // generate code + val resultType = new RowTypeInfo(Seq.fill(testExprs.size)(STRING_TYPE_INFO): _*) + val genExpr = generator.generateResultExpression( + resultType, + resultType.getFieldNames, + stringTestExprs) + + val bodyCode = + s""" + |${genExpr.code} + |return ${genExpr.resultTerm}; + |""".stripMargin + + val genFunc = generator.generateFunction[MapFunction[Any, Row], Row]( + "TestFunction", + classOf[MapFunction[Any, Row]], + bodyCode, + resultType) + + // compile and evaluate + val clazz = new TestCompiler[MapFunction[Any, Row], Row]().compile(genFunc) + val mapper = clazz.newInstance() + + val isRichFunction = mapper.isInstanceOf[RichFunction] + + // call setRuntimeContext method and open method for RichFunction + if (isRichFunction) { + val richMapper = mapper.asInstanceOf[RichMapFunction[_, _]] + val t = new RuntimeUDFContext( + new TaskInfo("ExpressionTest", 1, 0, 1, 1), + null, + context._3.getConfig, + new util.HashMap[String, Future[Path]](), + new util.HashMap[String, Accumulator[_, _]](), + null) + richMapper.setRuntimeContext(t) + richMapper.open(new Configuration()) + } + + val result = mapper.map(testData) + + // call close method for RichFunction + if (isRichFunction) { + mapper.asInstanceOf[RichMapFunction[_, _]].close() + } + + // compare + testExprs + .zipWithIndex + .foreach { + case ((expr, expected), index) => + val actual = result.getField(index) + assertEquals( + s"Wrong result for: $expr", + expected, + if (actual == null) "null" else actual) + } + } + + private def addSqlTestExpr(sqlExpr: String, expected: String): Unit = { + // create RelNode from SQL expression + val parsed = planner.parse(s"SELECT $sqlExpr FROM $tableName") + val validated = planner.validate(parsed) + val converted = planner.rel(validated).rel + + val decorPlan = RelDecorrelator.decorrelateQuery(converted) + + // normalize + val normalizedPlan = if (FlinkRuleSets.DATASET_NORM_RULES.iterator().hasNext) { + val planner = hepPlanner + planner.setRoot(decorPlan) + planner.findBestExp + } else { + decorPlan + } + + // create DataSetCalc + val flinkOutputProps = converted.getTraitSet.replace(DataSetConvention.INSTANCE).simplify() + val dataSetCalc = optProgram.run(context._2.getPlanner, normalizedPlan, flinkOutputProps) + + // extract RexNode + val calcProgram = dataSetCalc + .asInstanceOf[DataSetCalc] + .calcProgram + val expanded = calcProgram.expandLocalRef(calcProgram.getProjectList.get(0)) + + testExprs += ((expanded, expected)) + } + + private def addTableApiTestExpr(tableApiExpr: Expression, expected: String): Unit = { + // create RelNode from Table API expression + val env = context._2 + val converted = env + .asInstanceOf[BatchTableEnvironment] + .scan(tableName) + .select(tableApiExpr) + .getRelNode + + // create DataSetCalc + val decorPlan = RelDecorrelator.decorrelateQuery(converted) + val flinkOutputProps = converted.getTraitSet.replace(DataSetConvention.INSTANCE).simplify() + val dataSetCalc = optProgram.run(context._2.getPlanner, decorPlan, flinkOutputProps) + + // extract RexNode + val calcProgram = dataSetCalc + .asInstanceOf[DataSetCalc] + .calcProgram + val expanded = calcProgram.expandLocalRef(calcProgram.getProjectList.get(0)) + + testExprs += ((expanded, expected)) + } + + private def addTableApiTestExpr(tableApiString: String, expected: String): Unit = { + addTableApiTestExpr(ExpressionParser.parseExpression(tableApiString), expected) + } + + def testAllApis( + expr: Expression, + exprString: String, + sqlExpr: String, + expected: String) + : Unit = { + addTableApiTestExpr(expr, expected) + addTableApiTestExpr(exprString, expected) + addSqlTestExpr(sqlExpr, expected) + } + + def testTableApi( + expr: Expression, + exprString: String, + expected: String) + : Unit = { + addTableApiTestExpr(expr, expected) + addTableApiTestExpr(exprString, expected) + } + + def testSqlApi( + sqlExpr: String, + expected: String) + : Unit = { + addSqlTestExpr(sqlExpr, expected) + } + + // ---------------------------------------------------------------------------------------------- + + // TestCompiler that uses current class loader + class TestCompiler[F <: Function, T <: Any] extends Compiler[F] { + def compile(genFunc: GeneratedFunction[F, T]): Class[F] = + compile(getClass.getClassLoader, genFunc.name, genFunc.code) + } +} diff --git a/flink-connectors/flink-hcatalog/src/test/scala/org/apache/flink/table/hive/functions/utils/SimplePojo.scala b/flink-connectors/flink-hcatalog/src/test/scala/org/apache/flink/table/hive/functions/utils/SimplePojo.scala new file mode 100644 index 00000000000..1ebe8bcad72 --- /dev/null +++ b/flink-connectors/flink-hcatalog/src/test/scala/org/apache/flink/table/hive/functions/utils/SimplePojo.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.hive.functions.utils + +case class SimplePojo(name: String, age: Int) + diff --git a/flink-libraries/flink-table/pom.xml b/flink-libraries/flink-table/pom.xml index c6071b06162..2d9f6a38d73 100644 --- a/flink-libraries/flink-table/pom.xml +++ b/flink-libraries/flink-table/pom.xml @@ -92,7 +92,6 @@ under the License. </exclusions> </dependency> - <!-- test dependencies --> <dependency> diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala index 7ff18eb6332..2a8ba28c3ca 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala @@ -44,14 +44,22 @@ class ScalarFunctionCallGen( operands: Seq[GeneratedExpression]) : GeneratedExpression = { // determine function signature and result class - val matchingSignature = getSignature(scalarFunction, signature) + val matchingMethod = getEvalMethod(scalarFunction, signature) .getOrElse(throw new CodeGenException("No matching signature found.")) + val matchingSignature = matchingMethod.getParameterTypes val resultClass = getResultTypeClass(scalarFunction, matchingSignature) + // zip for variable signatures + var paramToOperands = matchingSignature.zip(operands) + if (operands.length > matchingSignature.length) { + operands.drop(matchingSignature.length).foreach(op => + paramToOperands = paramToOperands :+ + (matchingSignature.last.getComponentType, op) + ) + } + // convert parameters for function (output boxing) - val parameters = matchingSignature - .zip(operands) - .map { case (paramClass, operandExpr) => + val parameters = paramToOperands.map { case (paramClass, operandExpr) => if (paramClass.isPrimitive) { operandExpr } else if (ClassUtils.isPrimitiveWrapper(paramClass) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala index da652e043d1..11021203498 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala @@ -112,9 +112,16 @@ object ScalarSqlFunction { .getParameterTypes(foundSignature) .map(typeFactory.createTypeFromTypeInfo) - inferredTypes.zipWithIndex.foreach { - case (inferredType, i) => - operandTypes(i) = inferredType + operandTypes.zipWithIndex.foreach { + case (_, i) => + if (i < inferredTypes.length - 1) { + operandTypes(i) = inferredTypes(i) + } else if (null != inferredTypes.last.getComponentType) { + // last arguments is a collection, the array type + operandTypes(i) = inferredTypes.last.getComponentType + } else { + operandTypes(i) = inferredTypes.last + } } } } @@ -136,8 +143,18 @@ object ScalarSqlFunction { } override def getOperandCountRange: SqlOperandCountRange = { - val signatureLengths = signatures.map(_.length) - SqlOperandCountRanges.between(signatureLengths.min, signatureLengths.max) + var min = 255 + var max = -1 + signatures.foreach(sig => { + var len = sig.length + if (len > 0 && sig(sig.length - 1).isArray) { + max = 254 // according to JVM spec 4.3.3 + len = sig.length - 1 + } + max = Math.max(len, max) + min = Math.min(len, min) + }) + SqlOperandCountRanges.between(min, max) } override def checkOperandTypes( diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala index 21d28b5e591..2f0756b9afc 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala @@ -78,20 +78,7 @@ object UserDefinedFunctionUtils { function: UserDefinedFunction, signature: Seq[TypeInformation[_]]) : Option[Array[Class[_]]] = { - // We compare the raw Java classes not the TypeInformation. - // TypeInformation does not matter during runtime (e.g. within a MapFunction). - val actualSignature = typeInfoToClass(signature) - val signatures = getSignatures(function) - - signatures - // go over all signatures and find one matching actual signature - .find { curSig => - // match parameters of signature to actual parameters - actualSignature.length == curSig.length && - curSig.zipWithIndex.forall { case (clazz, i) => - parameterTypeEquals(actualSignature(i), clazz) - } - } + getEvalMethod(function, signature).map(_.getParameterTypes) } /** @@ -106,16 +93,53 @@ object UserDefinedFunctionUtils { val actualSignature = typeInfoToClass(signature) val evalMethods = checkAndExtractEvalMethods(function) - evalMethods - // go over all eval methods and find one matching - .find { cur => - val signatures = cur.getParameterTypes - // match parameters of signature to actual parameters - actualSignature.length == signatures.length && - signatures.zipWithIndex.forall { case (clazz, i) => - parameterTypeEquals(actualSignature(i), clazz) + val filtered = evalMethods + // go over all eval methods and filter out matching methods + .filter { + case cur if !cur.isVarArgs => + val signatures = cur.getParameterTypes + // match parameters of signature to actual par(ameters + actualSignature.length == signatures.length && + signatures.zipWithIndex.forall { case (clazz, i) => + parameterTypeEquals(actualSignature(i), clazz) + } + case cur if cur.isVarArgs => + val signatures = cur.getParameterTypes + actualSignature.zipWithIndex.forall { + case (clazz, i) if i < signatures.length - 1 => + parameterTypeEquals(clazz, signatures(i)) + case (clazz, i) if i >= signatures.length - 1 => + parameterTypeEquals(clazz, signatures.last.getComponentType) + } || + (actualSignature.isEmpty && signatures.length == 1) + } + + // if there is a fixed method, compiler will call the method preferentially + val fixedMethods = filtered.count{!_.isVarArgs} + val found = filtered.filter { cur => + fixedMethods > 0 && !cur.isVarArgs || + fixedMethods == 0 && cur.isVarArgs + } + + if (found.isEmpty && + // does there exist scala type variable arguments + evalMethods.exists{ evalMethod => + val signatures = evalMethod.getParameterTypes + signatures.zipWithIndex.forall { + case (clazz, i) if i < signatures.length - 1 => + parameterTypeEquals(actualSignature(i), clazz) + case (clazz, i) if i == signatures.length - 1 => + clazz.getName.equals("scala.collection.Seq") } + }) { + throw new ValidationException("The 'eval' method do not support Scala type of " + + "variable args eg. Type*, please add a @scala.annotation.varargs annotation " + + "to your 'eval' method") + } else if (found.length > 1) { + throw new ValidationException("Found multiple 'eval' methods which " + + "match the signature.") } + found.headOption } /** @@ -133,7 +157,7 @@ object UserDefinedFunctionUtils { /** * Extracts "eval" methods and throws a [[ValidationException]] if no implementation - * can be found. + * can be found, or implementation does not match the requirements */ def checkAndExtractEvalMethods(function: UserDefinedFunction): Array[Method] = { val methods = function @@ -152,9 +176,9 @@ object UserDefinedFunctionUtils { s"Function class '${function.getClass.getCanonicalName}' does not implement at least " + s"one method named 'eval' which is public, not abstract and " + s"(in case of table functions) not static.") - } else { - methods } + + methods } def getSignatures(function: UserDefinedFunction): Array[Array[Class[_]]] = { @@ -317,6 +341,7 @@ object UserDefinedFunctionUtils { private def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean = candidate == null || candidate == expected || + expected == classOf[Object] || expected.isPrimitive && Primitives.wrap(expected) == candidate || candidate == classOf[Date] && (expected == classOf[Int] || expected == classOf[JInt]) || candidate == classOf[Time] && (expected == classOf[Int] || expected == classOf[JInt]) || diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java index e817f06b4e1..56f866d2b11 100644 --- a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java +++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java @@ -33,4 +33,24 @@ public String eval(Integer a, int b, Long c) { } } + public static class JavaFunc2 extends ScalarFunction { + public String eval(String s, Integer... a) { + int m = 1; + for (int n : a) { + m *= n; + } + return s + m; + } + } + + public static class JavaFunc3 extends ScalarFunction { + public int eval(String a, int... b) { + return b.length; + } + + public String eval(String c) { + return c; + } + } + } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala index a6c1760c9b8..4985e410eee 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala @@ -24,8 +24,8 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.types.Row -import org.apache.flink.table.api.Types -import org.apache.flink.table.api.java.utils.UserDefinedScalarFunctions.{JavaFunc0, JavaFunc1} +import org.apache.flink.table.api.{Types, ValidationException} +import org.apache.flink.table.api.java.utils.UserDefinedScalarFunctions.{JavaFunc0, JavaFunc1, JavaFunc2, JavaFunc3} import org.apache.flink.table.api.scala._ import org.apache.flink.table.expressions.utils._ import org.apache.flink.table.functions.ScalarFunction @@ -180,6 +180,85 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { "+0 00:00:01.000") } + @Test + def testVariableArgs(): Unit = { + testAllApis( + Func14(1, 2, 3, 4), + "Func14(1, 2, 3, 4)", + "Func14(1, 2, 3, 4)", + "10") + + // Test for empty arguments + testAllApis( + Func14(), + "Func14()", + "Func14()", + "0") + + // Test for override + testAllApis( + Func15("Hello"), + "Func15('Hello')", + "Func15('Hello')", + "Hello" + ) + + testAllApis( + Func15('f1), + "Func15(f1)", + "Func15(f1)", + "Test" + ) + + testAllApis( + Func15("Hello", 1, 2, 3), + "Func15('Hello', 1, 2, 3)", + "Func15('Hello', 1, 2, 3)", + "Hello3" + ) + + testAllApis( + Func16('f9), + "Func16(f9)", + "Func16(f9)", + "Hello, World" + ) + + try { + testAllApis( + Func17("Hello", "World"), + "Func17('Hello', 'World')", + "Func17('Hello', 'World')", + "Hello, World" + ) + throw new RuntimeException("Shouldn't be reached here!") + } catch { + case ex: ValidationException => + // It's normal + } + + val JavaFunc2 = new JavaFunc2 + testAllApis( + JavaFunc2("Hi", 1, 3, 5, 7), + "JavaFunc2('Hi', 1, 3, 5, 7)", + "JavaFunc2('Hi', 1, 3, 5, 7)", + "Hi105") + + // Test for override + val JavaFunc3 = new JavaFunc3 + testAllApis( + JavaFunc3("Hi"), + "JavaFunc3('Hi')", + "JavaFunc3('Hi')", + "Hi") + + testAllApis( + JavaFunc3('f1), + "JavaFunc3(f1)", + "JavaFunc3(f1)", + "Test") + } + @Test def testJavaBoxedPrimitives(): Unit = { val JavaFunc0 = new JavaFunc0() @@ -235,10 +314,11 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { "#Test") } + // ---------------------------------------------------------------------------------------------- override def testData: Any = { - val testData = new Row(9) + val testData = new Row(10) testData.setField(0, 42) testData.setField(1, "Test") testData.setField(2, null) @@ -248,6 +328,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { testData.setField(6, Timestamp.valueOf("1990-10-14 12:10:10")) testData.setField(7, 12) testData.setField(8, 1000L) + testData.setField(9, Seq("Hello", "World")) testData } @@ -261,7 +342,8 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { Types.TIME, Types.TIMESTAMP, Types.INTERVAL_MONTHS, - Types.INTERVAL_MILLIS + Types.INTERVAL_MILLIS, + TypeInformation.of(classOf[Seq[String]]) ).asInstanceOf[TypeInformation[Any]] } @@ -279,8 +361,14 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { "Func10" -> Func10, "Func11" -> Func11, "Func12" -> Func12, + "Func14" -> Func14, + "Func15" -> Func15, + "Func16" -> Func16, + "Func17" -> Func17, "JavaFunc0" -> new JavaFunc0, "JavaFunc1" -> new JavaFunc1, + "JavaFunc2" -> new JavaFunc2, + "JavaFunc3" -> new JavaFunc3, "RichFunc0" -> new RichFunc0, "RichFunc1" -> new RichFunc1, "RichFunc2" -> new RichFunc2 diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala index 1258137df7e..982a1d6625c 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala @@ -28,6 +28,8 @@ import org.junit.Assert import scala.collection.mutable import scala.io.Source +import scala.annotation.varargs + case class SimplePojo(name: String, age: Int) object Func0 extends ScalarFunction { @@ -227,3 +229,37 @@ class Func13(prefix: String) extends ScalarFunction { } } +object Func14 extends ScalarFunction { + + @varargs + def eval(a: Int*): Int = { + a.sum + } +} + +object Func15 extends ScalarFunction { + + @varargs + def eval(a: String, b: Int*): String = { + a + b.length + } + + def eval(a: String): String = { + a + } +} + +object Func16 extends ScalarFunction { + + def eval(a: Seq[String]): String = { + a.mkString(", ") + } +} + +object Func17 extends ScalarFunction { + + // Without @varargs, it will throw exception + def eval(a: String*): String = { + a.mkString(", ") + } +} ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org > Support for simple hive UDF > --------------------------- > > Key: FLINK-5832 > URL: https://issues.apache.org/jira/browse/FLINK-5832 > Project: Flink > Issue Type: Sub-task > Components: Table API & SQL > Reporter: Zhuoluo Yang > Assignee: Zhuoluo Yang > Priority: Major > Labels: pull-request-available > > The first step of FLINK-5802 is to support simple Hive UDF. -- This message was sent by Atlassian JIRA (v7.6.3#76005)