Repository: flink
Updated Branches:
  refs/heads/master 9b179beae -> 04aee61d8
[FLINK-5882] [table] TableFunction (UDTF) should support variable types and variable arguments This closes #3407. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/04aee61d Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/04aee61d Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/04aee61d Branch: refs/heads/master Commit: 04aee61d86f9ba30715c133380560739282feb81 Parents: 9b179be Author: Zhuoluo Yang <zhuoluo....@alibaba-inc.com> Authored: Tue Mar 7 12:02:46 2017 +0800 Committer: twalthr <twal...@apache.org> Committed: Mon Mar 13 10:55:10 2017 +0100 ---------------------------------------------------------------------- .../codegen/calls/TableFunctionCallGen.scala | 17 ++++++--- .../DataSetUserDefinedFunctionITCase.scala | 37 ++++++++++++++++++++ .../DataStreamUserDefinedFunctionITCase.scala | 34 ++++++++++++++++-- .../table/utils/UserDefinedTableFunctions.scala | 11 +++++- 4 files changed, 91 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/04aee61d/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala index 890b6bd..ba90292 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala @@ -44,14 +44,21 @@ class TableFunctionCallGen( codeGenerator: CodeGenerator, operands: Seq[GeneratedExpression]) : GeneratedExpression = { - // determine function signature - val 
matchingSignature = getSignature(tableFunction, signature) + // determine function method + val matchingMethod = getEvalMethod(tableFunction, signature) .getOrElse(throw new CodeGenException("No matching signature found.")) + val matchingSignature = matchingMethod.getParameterTypes + + // zip for variable signatures + var paramToOperands = matchingSignature.zip(operands) + if (operands.length > matchingSignature.length) { + operands.drop(matchingSignature.length).foreach(op => + paramToOperands = paramToOperands :+ (matchingSignature.last.getComponentType, op) + ) + } // convert parameters for function (output boxing) - val parameters = matchingSignature - .zip(operands) - .map { case (paramClass, operandExpr) => + val parameters = paramToOperands.map { case (paramClass, operandExpr) => if (paramClass.isPrimitive) { operandExpr } else if (ClassUtils.isPrimitiveWrapper(paramClass) http://git-wip-us.apache.org/repos/asf/flink/blob/04aee61d/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala index 33b2439..20bbf8b 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala @@ -31,6 +31,7 @@ import org.apache.flink.table.utils._ import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode import org.apache.flink.test.util.TestBaseUtils import org.apache.flink.types.Row +import org.junit.Assert._ import org.junit.Test import org.junit.runner.RunWith import 
org.junit.runners.Parameterized @@ -277,6 +278,42 @@ class DataSetUserDefinedFunctionITCase( TestBaseUtils.compareResultAsText(results.asJava, expected) } + @Test + def testTableFunctionWithVariableArguments(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tableEnv = TableEnvironment.getTableEnvironment(env, config) + val varArgsFunc0 = new VarArgsFunc0 + tableEnv.registerFunction("VarArgsFunc0", varArgsFunc0) + + val result = testData(env) + .toTable(tableEnv, 'a, 'b, 'c) + .select('c) + .join(varArgsFunc0("1", "2", 'c)) + + val expected = "Anna#44,1\n" + + "Anna#44,2\n" + + "Anna#44,Anna#44\n" + + "Jack#22,1\n" + + "Jack#22,2\n" + + "Jack#22,Jack#22\n" + + "John#19,1\n" + + "John#19,2\n" + + "John#19,John#19\n" + + "nosharp,1\n" + + "nosharp,2\n" + + "nosharp,nosharp" + val results = result.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + + // Test for empty cases + val result0 = testData(env) + .toTable(tableEnv, 'a, 'b, 'c) + .select('c) + .join(varArgsFunc0()) + val results0 = result0.toDataSet[Row].collect() + assertTrue(results0.isEmpty) + } + private def testData( env: ExecutionEnvironment) : DataSet[(Int, Long, String)] = { http://git-wip-us.apache.org/repos/asf/flink/blob/04aee61d/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamUserDefinedFunctionITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamUserDefinedFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamUserDefinedFunctionITCase.scala index e7ce457..853c771 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamUserDefinedFunctionITCase.scala +++ 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamUserDefinedFunctionITCase.scala @@ -24,8 +24,7 @@ import org.apache.flink.table.api.TableEnvironment import org.apache.flink.table.api.scala._ import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData} import org.apache.flink.table.expressions.utils.{Func13, RichFunc2} -import org.apache.flink.table.utils.{RichTableFunc1, TableFunc0, TableFunc3, UserDefinedFunctionTestUtils} -import org.apache.flink.table.utils.PojoTableFunc +import org.apache.flink.table.utils._ import org.apache.flink.types.Row import org.junit.Assert._ import org.junit.Test @@ -196,6 +195,37 @@ class DataStreamUserDefinedFunctionITCase extends StreamingMultipleProgramsTestB assertEquals(expected.sorted, StreamITCase.testResults.sorted) } + @Test + def testTableFunctionWithVariableArguments(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tableEnv = TableEnvironment.getTableEnvironment(env) + val varArgsFunc0 = new VarArgsFunc0 + tableEnv.registerFunction("VarArgsFunc0", varArgsFunc0) + + val result = testData(env) + .toTable(tableEnv, 'a, 'b, 'c) + .select('c) + .join(varArgsFunc0("1", "2", 'c)) + + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "Anna#44,1", + "Anna#44,2", + "Anna#44,Anna#44", + "Jack#22,1", + "Jack#22,2", + "Jack#22,Jack#22", + "John#19,1", + "John#19,2", + "John#19,John#19", + "nosharp,1", + "nosharp,2", + "nosharp,nosharp") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + private def testData( env: StreamExecutionEnvironment) : DataStream[(Int, Long, String)] = { http://git-wip-us.apache.org/repos/asf/flink/blob/04aee61d/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala index 88917a2..d0ffade 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala @@ -23,10 +23,12 @@ import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.tuple.Tuple3 import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.table.api.ValidationException -import org.apache.flink.table.functions.{TableFunction, FunctionContext} +import org.apache.flink.table.functions.{FunctionContext, TableFunction} import org.apache.flink.types.Row import org.junit.Assert +import scala.annotation.varargs + case class SimpleUser(name: String, age: Int) @@ -203,3 +205,10 @@ class RichTableFunc1 extends TableFunction[String] { separator = None } } + +class VarArgsFunc0 extends TableFunction[String] { + @varargs + def eval(str: String*): Unit = { + str.foreach(collect) + } +}