This is an automated email from the ASF dual-hosted git repository.

snuyanzin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 40fb49dd17b [FLINK-27741][table-planner] Fix NPE when using dense_rank() and rank()
40fb49dd17b is described below

commit 40fb49dd17b3e1b6c5aa0249514273730ebe9226
Author: chenzihao <chenzih...@xiaomi.com>
AuthorDate: Tue May 14 22:18:05 2024 +0200

    [FLINK-27741][table-planner] Fix NPE when using dense_rank() and rank()
    
    Co-authored-by: Sergey Nuyanzin <snuyan...@gmail.com>
    
    This closes apache#19797
---
 .../aggfunctions/RankLikeAggFunctionBase.java      |  2 +-
 .../planner/plan/utils/AggFunctionFactory.scala    | 17 +++---
 .../plan/batch/sql/agg/OverAggregateTest.xml       | 44 ++++++++++++++++
 .../plan/batch/sql/agg/OverAggregateTest.scala     | 13 +++++
 .../runtime/stream/sql/OverAggregateITCase.scala   | 60 ++++++++++++++++++++++
 5 files changed, 129 insertions(+), 7 deletions(-)
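
For context, a minimal reproducer sketch (hedged: it mirrors the new tests added below; `tEnv` and `MyTable` are assumed names). Before this fix, planning a RANK()/DENSE_RANK() over window ordered by a proctime attribute could fail with a NullPointerException:

    // Scala sketch (assumed setup): tEnv is a StreamTableEnvironment and
    // MyTable is a view exposing a processing-time attribute `proctime`.
    val sqlQuery =
      "SELECT a, DENSE_RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTable"
    tEnv.sqlQuery(sqlQuery) // planning this previously threw an NPE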

diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/RankLikeAggFunctionBase.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/RankLikeAggFunctionBase.java
index 2a556d7b741..898939aedb9 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/RankLikeAggFunctionBase.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/RankLikeAggFunctionBase.java
@@ -99,7 +99,7 @@ public abstract class RankLikeAggFunctionBase extends DeclarativeAggregateFuncti
                             equalTo(lasValue, operand(i)));
         }
         Optional<Expression> ret = Arrays.stream(orderKeyEquals).reduce(ExpressionBuilder::and);
-        return ret.orElseGet(() -> literal(true));
+        return ret.orElseGet(() -> literal(false));
     }
 
     protected Expression generateInitLiteral(LogicalType orderType) {
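
Aside: the orElseGet change above flips what the reduced order-key-equality expression defaults to when there are no order keys at all. A minimal Scala sketch of the mechanics (plain Booleans stand in for Flink's Expression tree, so this is illustrative only):

    // With no order keys there is nothing to reduce, so the default decides
    // whether the current row counts as a tie with the previous one.
    // Defaulting to true would treat every row as a tie; false does not.
    val orderKeyEquals: Array[Boolean] = Array.empty
    val isTie: Boolean = orderKeyEquals.reduceOption(_ && _).getOrElse(false)
    println(isTie) // prints: false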
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala
index 4ecd4363863..6ca84314fc7 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala
@@ -532,18 +532,23 @@ class AggFunctionFactory(
   }
 
   private def createRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = {
-    val argTypes = orderKeyIndexes.map(inputRowType.getChildren.get(_))
-    new RankAggFunction(argTypes)
+    new RankAggFunction(getArgTypesOrEmpty())
   }
 
   private def createDenseRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = {
-    val argTypes = orderKeyIndexes.map(inputRowType.getChildren.get(_))
-    new DenseRankAggFunction(argTypes)
+    new DenseRankAggFunction(getArgTypesOrEmpty())
   }
 
   private def createPercentRankAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = {
-    val argTypes = orderKeyIndexes.map(inputRowType.getChildren.get(_))
-    new PercentRankAggFunction(argTypes)
+    new PercentRankAggFunction(getArgTypesOrEmpty())
+  }
+
+  private def getArgTypesOrEmpty(): Array[LogicalType] = {
+    if (orderKeyIndexes != null) {
+      orderKeyIndexes.map(inputRowType.getChildren.get(_))
+    } else {
+      Array[LogicalType]()
+    }
   }
 
   private def createNTILEAggFUnction(argTypes: Array[LogicalType]): UserDefinedFunction = {
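
The factory change above replaces three copies of the same mapping with a null-safe helper. A self-contained sketch of the guard (hedged: String stands in for LogicalType, and the names are illustrative rather than the exact Flink internals):

    import java.util.{List => JList}

    // If no order keys were recorded, orderKeyIndexes is null and mapping
    // over it threw the NPE; the guard falls back to an empty array.
    def argTypesOrEmpty(orderKeyIndexes: Array[Int],
                        inputFieldTypes: JList[String]): Array[String] =
      if (orderKeyIndexes != null) orderKeyIndexes.map(inputFieldTypes.get(_))
      else Array.empty[String]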
diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml
index 909efe170f7..0ca5ec28442 100644
--- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml
+++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml
@@ -280,6 +280,50 @@ OverAggregate(partitionBy=[c], window#0=[COUNT(*) AS w0$o0 RANG BETWEEN UNBOUNDE
 ]]>
     </Resource>
   </TestCase>
+       <TestCase name="testDenseRankOnOrder">
+               <Resource name="sql">
+                       <![CDATA[SELECT a, DENSE_RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTableWithProctime]]>
+               </Resource>
+               <Resource name="ast">
+                       <![CDATA[
+LogicalProject(a=[$0], EXPR$1=[DENSE_RANK() OVER (PARTITION BY $0 ORDER BY $3 NULLS FIRST)])
++- LogicalTableScan(table=[[default_catalog, default_database, MyTableWithProctime, source: [TestTableSource(a, b, c, proctime)]]])
+]]>
+               </Resource>
+               <Resource name="optimized exec plan">
+                       <![CDATA[
+Calc(select=[a, w0$o0 AS $1])
++- OverAggregate(partitionBy=[a], orderBy=[proctime ASC], window#0=[DENSE_RANK(*) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, proctime, w0$o0])
+   +- Exchange(distribution=[forward])
+      +- Sort(orderBy=[a ASC, proctime ASC])
+         +- Exchange(distribution=[hash[a]])
+            +- Calc(select=[a, proctime])
+               +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTableWithProctime, source: [TestTableSource(a, b, c, proctime)]]], fields=[a, b, c, proctime])
+]]>
+               </Resource>
+       </TestCase>
+       <TestCase name="testRankOnOver">
+               <Resource name="sql">
+                       <![CDATA[SELECT a, RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTableWithProctime]]>
+               </Resource>
+               <Resource name="ast">
+                       <![CDATA[
+LogicalProject(a=[$0], EXPR$1=[RANK() OVER (PARTITION BY $0 ORDER BY $3 NULLS FIRST)])
++- LogicalTableScan(table=[[default_catalog, default_database, MyTableWithProctime, source: [TestTableSource(a, b, c, proctime)]]])
+]]>
+               </Resource>
+               <Resource name="optimized exec plan">
+                       <![CDATA[
+Calc(select=[a, w0$o0 AS $1])
++- OverAggregate(partitionBy=[a], orderBy=[proctime ASC], window#0=[RANK(*) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, proctime, w0$o0])
+   +- Exchange(distribution=[forward])
+      +- Sort(orderBy=[a ASC, proctime ASC])
+         +- Exchange(distribution=[hash[a]])
+            +- Calc(select=[a, proctime])
+               +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTableWithProctime, source: [TestTableSource(a, b, c, proctime)]]], fields=[a, b, c, proctime])
+]]>
+               </Resource>
+       </TestCase>
   <TestCase name="testOverWindowWithoutPartitionBy">
     <Resource name="sql">
       <![CDATA[SELECT c, SUM(a) OVER (ORDER BY b) FROM MyTable]]>
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala
index 1fb6ad9028a..f71325beb57 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala
@@ -31,6 +31,7 @@ class OverAggregateTest extends TableTestBase {
 
   private val util = batchTestUtil()
   util.addTableSource[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+  util.addTableSource[(Int, Long, String, Long)]("MyTableWithProctime", 'a, 'b, 'c, 'proctime)
 
   @Test
   def testOverWindowWithoutPartitionByOrderBy(): Unit = {
@@ -47,6 +48,18 @@ class OverAggregateTest extends TableTestBase {
     util.verifyExecPlan("SELECT c, SUM(a) OVER (ORDER BY b) FROM MyTable")
   }
 
+  @Test
+  def testDenseRankOnOrder(): Unit = {
+    util.verifyExecPlan(
+      "SELECT a, DENSE_RANK() OVER (PARTITION BY a ORDER BY proctime) FROM 
MyTableWithProctime")
+  }
+
+  @Test
+  def testRankOnOver(): Unit = {
+    util.verifyExecPlan(
+      "SELECT a, RANK() OVER (PARTITION BY a ORDER BY proctime) FROM 
MyTableWithProctime")
+  }
+
   @Test
   def testDiffPartitionKeysWithSameOrderKeys(): Unit = {
     val sqlQuery =
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala
index f4897e8b14f..9bf39d8d0e2 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala
@@ -165,6 +165,66 @@ class OverAggregateITCase(mode: StateBackendMode) extends StreamingWithStateTest
     assertThat(sink.getAppendResults.sorted).isEqualTo(expected.sorted)
   }
 
+  @TestTemplate
+  def testDenseRankOnOver(): Unit = {
+    val t = failingDataSource(TestData.tupleData5)
+      .toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime.proctime)
+    tEnv.createTemporaryView("MyTable", t)
+    val sqlQuery = "SELECT a, DENSE_RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTable"
+
+    val sink = new TestingAppendSink
+    tEnv.sqlQuery(sqlQuery).toDataStream.addSink(sink)
+    env.execute()
+
+    val expected = List(
+      "1,1",
+      "2,1",
+      "2,2",
+      "3,1",
+      "3,2",
+      "3,3",
+      "4,1",
+      "4,2",
+      "4,3",
+      "4,4",
+      "5,1",
+      "5,2",
+      "5,3",
+      "5,4",
+      "5,5")
+    assertThat(expected.sorted).isEqualTo(sink.getAppendResults.sorted)
+  }
+
+  @TestTemplate
+  def testRankOnOver(): Unit = {
+    val t = failingDataSource(TestData.tupleData5)
+      .toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime.proctime)
+    tEnv.createTemporaryView("MyTable", t)
+    val sqlQuery = "SELECT a, RANK() OVER (PARTITION BY a ORDER BY proctime) FROM MyTable"
+
+    val sink = new TestingAppendSink
+    tEnv.sqlQuery(sqlQuery).toDataStream.addSink(sink)
+    env.execute()
+
+    val expected = List(
+      "1,1",
+      "2,1",
+      "2,2",
+      "3,1",
+      "3,2",
+      "3,3",
+      "4,1",
+      "4,2",
+      "4,3",
+      "4,4",
+      "5,1",
+      "5,2",
+      "5,3",
+      "5,4",
+      "5,5")
+    assertThat(expected.sorted).isEqualTo(sink.getAppendResults.sorted)
+  }
+
   @TestTemplate
   def testProcTimeBoundedPartitionedRowsOver(): Unit = {
     val t = failingDataSource(TestData.tupleData5)
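
A note on the expected rows in the new ITCases above: in TestData.tupleData5 each key a = k occurs exactly k times, and within a partition every row carries a distinct proctime, so RANK and DENSE_RANK agree and the per-partition ranks simply run 1..k. A hedged Scala one-liner that derives the same expectation:

    // Key k (1 to 5) occurs k times, so its ranks are 1..k.
    val expected = (1 to 5).flatMap(k => (1 to k).map(r => s"$k,$r")).toList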
