This is an automated email from the ASF dual-hosted git repository. snuyanzin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new 40fb49dd17b [FLINK-27741][table-planner] Fix NPE when use dense_rank() and rank() 40fb49dd17b is described below commit 40fb49dd17b3e1b6c5aa0249514273730ebe9226 Author: chenzihao <chenzih...@xiaomi.com> AuthorDate: Tue May 14 22:18:05 2024 +0200 [FLINK-27741][table-planner] Fix NPE when use dense_rank() and rank() Co-authored-by: Sergey Nuyanzin <snuyan...@gmail.com> This closes apache#19797 --- .../aggfunctions/RankLikeAggFunctionBase.java | 2 +- .../planner/plan/utils/AggFunctionFactory.scala | 17 +++--- .../plan/batch/sql/agg/OverAggregateTest.xml | 44 ++++++++++++++++ .../plan/batch/sql/agg/OverAggregateTest.scala | 13 +++++ .../runtime/stream/sql/OverAggregateITCase.scala | 60 ++++++++++++++++++++++ 5 files changed, 129 insertions(+), 7 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/RankLikeAggFunctionBase.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/RankLikeAggFunctionBase.java index 2a556d7b741..898939aedb9 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/RankLikeAggFunctionBase.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/RankLikeAggFunctionBase.java @@ -99,7 +99,7 @@ public abstract class RankLikeAggFunctionBase extends DeclarativeAggregateFuncti equalTo(lasValue, operand(i))); } Optional<Expression> ret = Arrays.stream(orderKeyEquals).reduce(ExpressionBuilder::and); - return ret.orElseGet(() -> literal(true)); + return ret.orElseGet(() -> literal(false)); } protected Expression generateInitLiteral(LogicalType orderType) { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala index 4ecd4363863..6ca84314fc7 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala @@ -532,18 +532,23 @@ class AggFunctionFactory( } private def createRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = { - val argTypes = orderKeyIndexes.map(inputRowType.getChildren.get(_)) - new RankAggFunction(argTypes) + new RankAggFunction(getArgTypesOrEmpty()) } private def createDenseRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = { - val argTypes = orderKeyIndexes.map(inputRowType.getChildren.get(_)) - new DenseRankAggFunction(argTypes) + new DenseRankAggFunction(getArgTypesOrEmpty()) } private def createPercentRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = { - val argTypes = orderKeyIndexes.map(inputRowType.getChildren.get(_)) - new PercentRankAggFunction(argTypes) + new PercentRankAggFunction(getArgTypesOrEmpty()) + } + + private def getArgTypesOrEmpty(): Array[LogicalType] = { + if (orderKeyIndexes != null) { + orderKeyIndexes.map(inputRowType.getChildren.get(_)) + } else { + Array[LogicalType]() + } } private def createNTILEAggFUnction(argTypes: Array[LogicalType]): UserDefinedFunction = { diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml index 909efe170f7..0ca5ec28442 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml @@ -280,6 +280,50 @@ OverAggregate(partitionBy=[c], window#0=[COUNT(*) AS w0$o0 RANG BETWEEN UNBOUNDE ]]> </Resource> </TestCase> + <TestCase name="testDenseRankOnOrder"> + <Resource name="sql"> + <![CDATA[SELECT a, DENSE_RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTableWithProctime]]> + </Resource> + <Resource name="ast"> + <![CDATA[ +LogicalProject(a=[$0], EXPR$1=[DENSE_RANK() OVER (PARTITION BY $0 ORDER BY $3 NULLS FIRST)]) ++- LogicalTableScan(table=[[default_catalog, default_database, MyTableWithProctime, source: [TestTableSource(a, b, c, proctime)]]]) +]]> + </Resource> + <Resource name="optimized exec plan"> + <![CDATA[ +Calc(select=[a, w0$o0 AS $1]) ++- OverAggregate(partitionBy=[a], orderBy=[proctime ASC], window#0=[DENSE_RANK(*) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, proctime, w0$o0]) + +- Exchange(distribution=[forward]) + +- Sort(orderBy=[a ASC, proctime ASC]) + +- Exchange(distribution=[hash[a]]) + +- Calc(select=[a, proctime]) + +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTableWithProctime, source: [TestTableSource(a, b, c, proctime)]]], fields=[a, b, c, proctime]) +]]> + </Resource> + </TestCase> + <TestCase name="testRankOnOver"> + <Resource name="sql"> + <![CDATA[SELECT a, RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTableWithProctime]]> + </Resource> + <Resource name="ast"> + <![CDATA[ +LogicalProject(a=[$0], EXPR$1=[RANK() OVER (PARTITION BY $0 ORDER BY $3 NULLS FIRST)]) ++- LogicalTableScan(table=[[default_catalog, default_database, MyTableWithProctime, source: [TestTableSource(a, b, c, proctime)]]]) +]]> + </Resource> + <Resource name="optimized exec plan"> + <![CDATA[ +Calc(select=[a, w0$o0 AS $1]) ++- OverAggregate(partitionBy=[a], orderBy=[proctime ASC], window#0=[RANK(*) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, proctime, w0$o0]) + +- Exchange(distribution=[forward]) + +- Sort(orderBy=[a ASC, proctime ASC]) + +- Exchange(distribution=[hash[a]]) + +- Calc(select=[a, proctime]) + +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTableWithProctime, source: [TestTableSource(a, b, c, proctime)]]], fields=[a, b, c, proctime]) +]]> + </Resource> + </TestCase> <TestCase name="testOverWindowWithoutPartitionBy"> <Resource name="sql"> <![CDATA[SELECT c, SUM(a) OVER (ORDER BY b) FROM MyTable]]> diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala index 1fb6ad9028a..f71325beb57 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala @@ -31,6 +31,7 @@ class OverAggregateTest extends TableTestBase { private val util = batchTestUtil() util.addTableSource[(Int, Long, String)]("MyTable", 'a, 'b, 'c) + util.addTableSource[(Int, Long, String, Long)]("MyTableWithProctime", 'a, 'b, 'c, 'proctime) @Test def testOverWindowWithoutPartitionByOrderBy(): Unit = { @@ -47,6 +48,18 @@ class OverAggregateTest extends TableTestBase { util.verifyExecPlan("SELECT c, SUM(a) OVER (ORDER BY b) FROM MyTable") } + @Test + def testDenseRankOnOrder(): Unit = { + util.verifyExecPlan( + "SELECT a, DENSE_RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTableWithProctime") + } + + @Test + def testRankOnOver(): Unit = { + util.verifyExecPlan( + "SELECT a, RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTableWithProctime") + } + @Test def testDiffPartitionKeysWithSameOrderKeys(): Unit = { val sqlQuery = diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala index f4897e8b14f..9bf39d8d0e2 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala @@ -165,6 +165,66 @@ class OverAggregateITCase(mode: StateBackendMode) extends StreamingWithStateTest assertThat(sink.getAppendResults.sorted).isEqualTo(expected.sorted) } + @TestTemplate + def testDenseRankOnOver(): Unit = { + val t = failingDataSource(TestData.tupleData5) + .toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime.proctime) + tEnv.createTemporaryView("MyTable", t) + val sqlQuery = "SELECT a, DENSE_RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTable" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sqlQuery).toDataStream.addSink(sink) + env.execute() + + val expected = List( + "1,1", + "2,1", + "2,2", + "3,1", + "3,2", + "3,3", + "4,1", + "4,2", + "4,3", + "4,4", + "5,1", + "5,2", + "5,3", + "5,4", + "5,5") + assertThat(expected.sorted).isEqualTo(sink.getAppendResults.sorted) + } + + @TestTemplate + def testRankOnOver(): Unit = { + val t = failingDataSource(TestData.tupleData5) + .toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime.proctime) + tEnv.createTemporaryView("MyTable", t) + val sqlQuery = "SELECT a, RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTable" + + val sink = new TestingAppendSink + tEnv.sqlQuery(sqlQuery).toDataStream.addSink(sink) + env.execute() + + val expected = List( + "1,1", + "2,1", + "2,2", + "3,1", + "3,2", + "3,3", + "4,1", + "4,2", + "4,3", + "4,4", + "5,1", + "5,2", + "5,3", + "5,4", + "5,5") + assertThat(expected.sorted).isEqualTo(sink.getAppendResults.sorted) + } + @TestTemplate def testProcTimeBoundedPartitionedRowsOver(): Unit = { val t = failingDataSource(TestData.tupleData5)