[FLINK-6436] [table] Fix code-gen bug when using a scalar UDF in a UDTF join condition.
This closes #3815. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/e2cb2215 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/e2cb2215 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/e2cb2215 Branch: refs/heads/master Commit: e2cb2215917e33d35fff5b07ed6a64c05e14abce Parents: f26a911 Author: godfreyhe <godfre...@163.com> Authored: Wed May 3 20:59:34 2017 +0800 Committer: Fabian Hueske <fhue...@apache.org> Committed: Tue May 9 18:50:20 2017 +0200 ---------------------------------------------------------------------- .../table/plan/nodes/CommonCorrelate.scala | 13 ++++++++++- .../utils/UserDefinedScalarFunctions.scala | 9 ++++++-- .../DataSetUserDefinedFunctionITCase.scala | 24 +++++++++++++++++--- .../DataStreamUserDefinedFunctionITCase.scala | 22 ++++++++++++++++-- 4 files changed, 60 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/e2cb2215/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala index 44a109e3..c95f2f7 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala @@ -18,7 +18,7 @@ package org.apache.flink.table.plan.nodes import org.apache.calcite.rel.`type`.RelDataType -import org.apache.calcite.rex.{RexCall, RexNode} +import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode, RexShuttle} import org.apache.calcite.sql.SemiJoinType import org.apache.flink.api.common.functions.FlatMapFunction import org.apache.flink.api.common.typeinfo.TypeInformation @@ -143,6 +143,17 @@ trait CommonCorrelate[T] { |getCollector().collect(${crossResultExpr.resultTerm}); |""".stripMargin } else { + + // adjust indicies of InputRefs to adhere to schema expected by generator + val changeInputRefIndexShuttle = new RexShuttle { + override def visitInputRef(inputRef: RexInputRef): RexNode = { + new RexInputRef(inputSchema.physicalArity + inputRef.getIndex, inputRef.getType) + } + } + // Run generateExpression to add init statements (ScalarFunctions) of condition to generator. + // The generated expression is discarded. + generator.generateExpression(condition.get.accept(changeInputRefIndexShuttle)) + val filterGenerator = new CodeGenerator(config, false, udtfTypeInfo, None, pojoFieldMapping) filterGenerator.input1Term = filterGenerator.input2Term val filterCondition = filterGenerator.generateExpression(condition.get) http://git-wip-us.apache.org/repos/asf/flink/blob/e2cb2215/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala index 8972a77..5285569 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala @@ -25,11 +25,10 @@ import org.apache.flink.table.api.Types import org.apache.flink.table.functions.{ScalarFunction, FunctionContext} import org.junit.Assert +import scala.annotation.varargs import scala.collection.mutable import scala.io.Source -import scala.annotation.varargs - case class SimplePojo(name: String, age: Int) object Func0 extends ScalarFunction { @@ -263,3 +262,9 @@ object Func17 extends ScalarFunction { a.mkString(", ") } } + +object Func18 extends ScalarFunction { + def eval(str: String, prefix: String): Boolean = { + str.startsWith(prefix) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/e2cb2215/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala index 20bbf8b..b69dd49 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala @@ -24,9 +24,9 @@ import org.apache.flink.api.scala.util.CollectionDataSets import org.apache.flink.table.api.TableEnvironment import org.apache.flink.table.api.java.utils.UserDefinedTableFunctions.JavaTableFunc0 import org.apache.flink.table.api.scala._ -import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode import org.apache.flink.table.api.scala.batch.utils.TableProgramsClusterTestBase -import org.apache.flink.table.expressions.utils.{Func13, RichFunc2} +import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode +import org.apache.flink.table.expressions.utils.{Func1, Func13, Func18, RichFunc2} import org.apache.flink.table.utils._ import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode import org.apache.flink.test.util.TestBaseUtils @@ -143,7 +143,7 @@ class DataSetUserDefinedFunctionITCase( val pojo = new PojoTableFunc() val result = in .join(pojo('c)) - .where(('age > 20)) + .where('age > 20) .select('c, 'name, 'age) .toDataSet[Row] @@ -171,6 +171,24 @@ class DataSetUserDefinedFunctionITCase( } @Test + def testUserDefinedTableFunctionWithScalarFunctionInCondition(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tableEnv = TableEnvironment.getTableEnvironment(env, config) + val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) + val func0 = new TableFunc0 + + val result = in + .join(func0('c)) + .where(Func18('name, "J") && (Func1('a) < 3) && Func1('age) > 20) + .select('c, 'name, 'age) + .toDataSet[Row] + + val results = result.collect() + val expected = "Jack#22,Jack,22" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test def testLongAndTemporalTypes(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val tableEnv = TableEnvironment.getTableEnvironment(env, config) http://git-wip-us.apache.org/repos/asf/flink/blob/e2cb2215/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamUserDefinedFunctionITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamUserDefinedFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamUserDefinedFunctionITCase.scala index 2e8a065..b3d9c6f 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamUserDefinedFunctionITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamUserDefinedFunctionITCase.scala @@ -23,7 +23,7 @@ import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase import org.apache.flink.table.api.TableEnvironment import org.apache.flink.table.api.scala._ import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData} -import org.apache.flink.table.expressions.utils.{Func13, RichFunc2} +import org.apache.flink.table.expressions.utils.{Func13, Func18, RichFunc2} import org.apache.flink.table.utils._ import org.apache.flink.types.Row import org.junit.Assert._ @@ -51,7 +51,7 @@ class DataStreamUserDefinedFunctionITCase extends StreamingMultipleProgramsTestB .join(func0('c) as('d, 'e)) .select('c, 'd, 'e) .join(pojoFunc0('c)) - .where(('age > 20)) + .where('age > 20) .select('c, 'name, 'age) .toDataStream[Row] @@ -82,6 +82,24 @@ class DataStreamUserDefinedFunctionITCase extends StreamingMultipleProgramsTestB } @Test + def testUserDefinedTableFunctionWithScalarFunction(): Unit = { + val t = testData(env).toTable(tEnv).as('a, 'b, 'c) + val func0 = new TableFunc0 + + val result = t + .join(func0('c) as('d, 'e)) + .where(Func18('d, "J")) + .select('c, 'd, 'e) + .toDataStream[Row] + + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList("Jack#22,Jack,22", "John#19,John,19") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test def testUserDefinedTableFunctionWithParameter(): Unit = { val tableFunc1 = new RichTableFunc1 tEnv.registerFunction("RichTableFunc1", tableFunc1)