libenchao commented on code in PR #24638:
URL: https://github.com/apache/flink/pull/24638#discussion_r1607579249


##########
flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala:
##########
@@ -188,7 +188,7 @@ class FlinkRelMdUpsertKeysTest extends 
FlinkRelMdHandlerTestBase {
       rank => assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(rank).toSet)
     }
 
-    Array(logicalRowNumber, flinkLogicalRowNumber, streamRowNumber)
+    Array(logicalWindow, logicalRowNumber, flinkLogicalRowNumber, 
streamRowNumber)

Review Comment:
   I'm wondering why the existing tests already seem to have some ability to 
infer `row_number` as part of a unique key?



##########
flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeys.scala:
##########
@@ -186,28 +186,34 @@ class FlinkRelMdUpsertKeys private extends 
MetadataHandler[UpsertKeys] {
   }
 
   def getUpsertKeys(rel: Window, mq: RelMetadataQuery): JSet[ImmutableBitSet] 
= {
-    getUpsertKeysOnOver(rel, mq, rel.groups.map(_.keys): _*)
+    getUpsertKeysOnOver(rel, mq)
   }
 
   def getUpsertKeys(
       rel: BatchPhysicalOverAggregate,
       mq: RelMetadataQuery): JSet[ImmutableBitSet] = {
-    getUpsertKeysOnOver(rel, mq, ImmutableBitSet.of(rel.partitionKeyIndices: 
_*))
+    getUpsertKeysOnOver(rel, mq)
   }
 
   def getUpsertKeys(
       rel: StreamPhysicalOverAggregate,
       mq: RelMetadataQuery): JSet[ImmutableBitSet] = {
-    getUpsertKeysOnOver(rel, mq, rel.logicWindow.groups.map(_.keys): _*)
+    getUpsertKeysOnOver(rel, mq)
   }
 
   private def getUpsertKeysOnOver(
-      rel: SingleRel,
-      mq: RelMetadataQuery,
-      distributionKeys: ImmutableBitSet*): JSet[ImmutableBitSet] = {
-    var inputKeys = 
FlinkRelMetadataQuery.reuseOrCreate(mq).getUpsertKeys(rel.getInput)
-    for (distributionKey <- distributionKeys) {
+      window: SingleRel,
+      mq: RelMetadataQuery): JSet[ImmutableBitSet] = {
+    var (groups, aggStartPos) = 
FlinkRelMdUniqueKeys.INSTANCE.getGroupsAndStartPos(window)
+    var inputKeys = 
FlinkRelMetadataQuery.reuseOrCreate(mq).getUpsertKeys(window.getInput)
+    for (group <- groups) {
+      val distributionKey = group.keys
       inputKeys = filterKeys(inputKeys, distributionKey)
+      FlinkRelMdUniqueKeys.INSTANCE.getUniqueKeysOfWindowGroup(group, 
aggStartPos) match {
+        case Some(upsertKeys) => inputKeys.addAll(upsertKeys)
+        case _ =>
+      }
+      aggStartPos = aggStartPos + group.aggCalls.length

Review Comment:
   Filtering the input keys and adding the keys inferred via 'partition key 
and row_number' are orthogonal. Can we just use two separate for loops, which 
would be clearer?



##########
flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniqueness.scala:
##########
@@ -439,6 +439,18 @@ class FlinkRelMdColumnUniqueness private extends 
MetadataHandler[BuiltInMetadata
       mq: RelMetadataQuery,
       columns: ImmutableBitSet,
       ignoreNulls: Boolean): JBoolean = {
+    var (groups, aggStartPos) = 
FlinkRelMdUniqueKeys.INSTANCE.getGroupsAndStartPos(overAgg)
+    for (group <- groups) {
+      FlinkRelMdUniqueKeys.INSTANCE.getUniqueKeysOfWindowGroup(group, 
aggStartPos) match {
+        case Some(upsertKeys) =>

Review Comment:
   Should this be named `uniqueKeys` instead of `upsertKeys`?



##########
flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala:
##########
@@ -455,11 +455,56 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
     getUniqueKeysOfOverAgg(rel, mq, ignoreNulls)
   }
 
+  def getUniqueKeysOfWindowGroup(
+      group: Window.Group,
+      startPos: Int): Option[JSet[ImmutableBitSet]] = {
+    val retSet = new JHashSet[ImmutableBitSet]
+    val aggCalls = group.aggCalls
+    for ((aggCall, offset) <- aggCalls.zipWithIndex) {
+      // If it's a ROW_NUMBER window, then the unique keys are partition by 
key and row number.
+      if (aggCall.getOperator.equals(SqlStdOperatorTable.ROW_NUMBER)) {
+        val rowNumberColumnIndex = startPos + offset
+        retSet.add(group.keys.union(ImmutableBitSet.of(rowNumberColumnIndex)))
+      }
+    }
+    if (retSet.isEmpty) {
+      None
+    } else {
+      Some(retSet)
+    }
+  }
+
+  def getGroupsAndStartPos(window: SingleRel): Tuple2[JList[Window.Group], 
Int] = {
+    val groups: JList[Window.Group] = window match {
+      case window: Window => window.groups
+      case streamOverAggregate: StreamPhysicalOverAggregate =>
+        streamOverAggregate.logicWindow.groups
+      case batchOverAggregate: BatchPhysicalOverAggregate => 
batchOverAggregate.windowGroups
+      case _ => throw new IllegalArgumentException("Illegal window type.")
+    }
+    val aggCounts = groups.map(_.aggCalls.length).sum
+    val aggStartIndex = window.getRowType.getFieldCount - aggCounts
+    (groups, aggStartIndex)
+  }
+
   private def getUniqueKeysOfOverAgg(
       window: SingleRel,
       mq: RelMetadataQuery,
       ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
-    mq.getUniqueKeys(window.getInput, ignoreNulls)
+    val retSet = new JHashSet[ImmutableBitSet]
+    var (groups, aggStartPos) = getGroupsAndStartPos(window)
+    for (group <- groups) {
+      getUniqueKeysOfWindowGroup(group, aggStartPos) match {
+        case Some(uniqueKeys) => retSet.addAll(uniqueKeys)
+        case _ =>
+      }
+      aggStartPos = aggStartPos + group.aggCalls.length
+    }
+    val inputKeys = mq.getUniqueKeys(window.getInput, ignoreNulls)
+    if (inputKeys != null && inputKeys.nonEmpty) {
+      retSet.addAll(inputKeys)
+    }
+    retSet

Review Comment:
   If `retSet` is empty, I would suggest returning `inputKeys`, since 
'null' and 'empty set' have different meanings: 
https://github.com/apache/calcite/blob/327bfcc7799c4413a84a5ebe0849ff64853a7dd1/core/src/main/java/org/apache/calcite/rel/metadata/RelMetadataQuery.java#L478-L479



##########
flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala:
##########
@@ -2540,7 +2571,7 @@ class FlinkRelMdHandlerTestBase {
   private lazy val overAggGroups = {
     ImmutableList.of(
       new Window.Group(
-        ImmutableBitSet.of(5),
+        ImmutableBitSet.of(4),

Review Comment:
   Why must this be modified?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to