This is an automated email from the ASF dual-hosted git repository.

jchan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new b957480112c [FLINK-33941][table-planner] Use field reference index to compute window aggregate time attribute column
b957480112c is described below

commit b957480112c00d9d777247fc48b602e9908652a2
Author: Xuyang <xyzhong...@163.com>
AuthorDate: Thu Jan 4 10:47:00 2024 +0800

    [FLINK-33941][table-planner] Use field reference index to compute window aggregate time attribute column
    
    This closes #23991
---
 .../stream/StreamExecGroupWindowAggregate.java     |  7 +---
 .../StreamExecPythonGroupWindowAggregate.java      |  7 +---
 .../logical/FlinkAggregateProjectMergeRule.java    | 48 +++++++++++++++++-----
 .../BatchPhysicalPythonWindowAggregateRule.java    |  4 +-
 .../table/planner/plan/logical/groupWindows.scala  | 14 +++++++
 .../nodes/calcite/LogicalWindowAggregate.scala     | 18 ++++++++
 .../batch/BatchPhysicalWindowAggregateRule.scala   |  3 +-
 .../table/planner/plan/utils/AggregateUtil.scala   |  9 ----
 8 files changed, 74 insertions(+), 36 deletions(-)

diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupWindowAggregate.java
 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupWindowAggregate.java
index 40471878046..d1a48c64077 100644
--- 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupWindowAggregate.java
+++ 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupWindowAggregate.java
@@ -84,7 +84,6 @@ import static 
org.apache.flink.table.planner.plan.utils.AggregateUtil.hasTimeInt
 import static 
org.apache.flink.table.planner.plan.utils.AggregateUtil.isProctimeAttribute;
 import static 
org.apache.flink.table.planner.plan.utils.AggregateUtil.isRowtimeAttribute;
 import static 
org.apache.flink.table.planner.plan.utils.AggregateUtil.isTableAggregate;
-import static 
org.apache.flink.table.planner.plan.utils.AggregateUtil.timeFieldIndex;
 import static 
org.apache.flink.table.planner.plan.utils.AggregateUtil.toDuration;
 import static org.apache.flink.table.planner.plan.utils.AggregateUtil.toLong;
 import static 
org.apache.flink.table.planner.plan.utils.AggregateUtil.transformToStreamAggregateInfoList;
@@ -211,11 +210,7 @@ public class StreamExecGroupWindowAggregate extends 
StreamExecAggregateBase {
 
         final int inputTimeFieldIndex;
         if (isRowtimeAttribute(window.timeAttribute())) {
-            inputTimeFieldIndex =
-                    timeFieldIndex(
-                            
planner.getTypeFactory().buildRelNodeRowType(inputRowType),
-                            planner.createRelBuilder(),
-                            window.timeAttribute());
+            inputTimeFieldIndex = window.timeAttribute().getFieldIndex();
             if (inputTimeFieldIndex < 0) {
                 throw new TableException(
                         "Group window must defined on a time attribute, "
diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java
 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java
index ec908fb1181..d6fc11fee73 100644
--- 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java
+++ 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java
@@ -93,7 +93,6 @@ import static 
org.apache.flink.table.planner.plan.utils.AggregateUtil.hasRowInte
 import static 
org.apache.flink.table.planner.plan.utils.AggregateUtil.hasTimeIntervalType;
 import static 
org.apache.flink.table.planner.plan.utils.AggregateUtil.isProctimeAttribute;
 import static 
org.apache.flink.table.planner.plan.utils.AggregateUtil.isRowtimeAttribute;
-import static 
org.apache.flink.table.planner.plan.utils.AggregateUtil.timeFieldIndex;
 import static 
org.apache.flink.table.planner.plan.utils.AggregateUtil.toDuration;
 import static org.apache.flink.table.planner.plan.utils.AggregateUtil.toLong;
 import static 
org.apache.flink.table.planner.plan.utils.AggregateUtil.transformToStreamAggregateInfoList;
@@ -234,11 +233,7 @@ public class StreamExecPythonGroupWindowAggregate extends 
StreamExecAggregateBas
 
         final int inputTimeFieldIndex;
         if (isRowtimeAttribute(window.timeAttribute())) {
-            inputTimeFieldIndex =
-                    timeFieldIndex(
-                            
planner.getTypeFactory().buildRelNodeRowType(inputRowType),
-                            planner.createRelBuilder(),
-                            window.timeAttribute());
+            inputTimeFieldIndex = window.timeAttribute().getFieldIndex();
             if (inputTimeFieldIndex < 0) {
                 throw new TableException(
                         "Group window must defined on a time attribute, "
diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkAggregateProjectMergeRule.java
 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkAggregateProjectMergeRule.java
index 8f39f2f001b..dc5514872cc 100644
--- 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkAggregateProjectMergeRule.java
+++ 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkAggregateProjectMergeRule.java
@@ -18,6 +18,8 @@
 
 package org.apache.flink.table.planner.plan.rules.logical;
 
+import org.apache.flink.table.expressions.FieldReferenceExpression;
+import org.apache.flink.table.planner.plan.logical.LogicalWindow;
 import 
org.apache.flink.table.planner.plan.nodes.calcite.LogicalWindowAggregate;
 import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
 
@@ -54,7 +56,8 @@ import static java.util.Objects.requireNonNull;
  * <p>FLINK modifications are at lines
  *
  * <ol>
- *   <li>Should be removed after legacy groupWindowAggregate was removed: Lines 83 ~ 101
+ *   <li>Should be removed after legacy groupWindowAggregate was removed: Lines 85 ~ 105, Lines 136
+ *       ~ 156
  * </ol>
  */
 public class FlinkAggregateProjectMergeRule extends AggregateProjectMergeRule {
@@ -79,18 +82,19 @@ public class FlinkAggregateProjectMergeRule extends 
AggregateProjectMergeRule {
             RelOptRuleCall call, Aggregate aggregate, Project project) {
         // Find all fields which we need to be straightforward field 
projections.
         final Set<Integer> interestingFields = 
RelOptUtil.getAllFields(aggregate);
+        boolean isProctimeWindowAgg = false;
 
         // Should add the field of timeAttribute in a LogicalWindowAggregate 
node which uses rowTime
         if (aggregate instanceof LogicalWindowAggregate) {
             LogicalWindowAggregate winAgg = (LogicalWindowAggregate) aggregate;
             // isRowtimeAttribute can't be used here because the 
time_indicator phase comes later
-            boolean isProcTime =
+            isProctimeWindowAgg =
                     LogicalTypeChecks.isProctimeAttribute(
                             winAgg.getWindow()
                                     .timeAttribute()
                                     .getOutputDataType()
                                     .getLogicalType());
-            if (!isProcTime) {
+            if (!isProctimeWindowAgg) {
                 // no need to consider the inputIndex because 
LogicalWindowAggregate is single input
                 interestingFields.add(
                         ((LogicalWindowAggregate) aggregate)
@@ -127,13 +131,37 @@ public class FlinkAggregateProjectMergeRule extends 
AggregateProjectMergeRule {
             aggCalls.add(aggregateCall.transform(targetMapping));
         }
 
-        final Aggregate newAggregate =
-                aggregate.copy(
-                        aggregate.getTraitSet(),
-                        project.getInput(),
-                        newGroupSet,
-                        newGroupingSets,
-                        aggCalls.build());
+        final Aggregate newAggregate;
+
+        if (aggregate instanceof LogicalWindowAggregate && 
!isProctimeWindowAgg) {
+            // update the index of the time field in window
+            LogicalWindowAggregate winAgg = (LogicalWindowAggregate) aggregate;
+            LogicalWindow window = winAgg.getWindow();
+            int newTimeIndex = map.get(window.timeAttribute().getFieldIndex());
+            LogicalWindow newWindow =
+                    window.copy(
+                            new FieldReferenceExpression(
+                                    window.timeAttribute().getName(),
+                                    window.timeAttribute().getOutputDataType(),
+                                    window.timeAttribute().getInputIndex(),
+                                    newTimeIndex));
+            newAggregate =
+                    winAgg.copy(
+                            aggregate.getTraitSet(),
+                            project.getInput(),
+                            newGroupSet,
+                            newGroupingSets,
+                            aggCalls.build(),
+                            newWindow);
+        } else {
+            newAggregate =
+                    aggregate.copy(
+                            aggregate.getTraitSet(),
+                            project.getInput(),
+                            newGroupSet,
+                            newGroupingSets,
+                            aggCalls.build());
+        }
 
         // Add a project if the group set is not in the same order or
         // contains duplicates.
diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalPythonWindowAggregateRule.java
 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalPythonWindowAggregateRule.java
index 5e35384c490..1e6d680390b 100644
--- 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalPythonWindowAggregateRule.java
+++ 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalPythonWindowAggregateRule.java
@@ -130,9 +130,7 @@ public class BatchPhysicalPythonWindowAggregateRule extends 
RelOptRule {
                         null);
         UserDefinedFunction[] aggFunctions = aggBufferTypesAndFunctions._3();
 
-        int inputTimeFieldIndex =
-                AggregateUtil.timeFieldIndex(
-                        input.getRowType(), call.builder(), 
window.timeAttribute());
+        int inputTimeFieldIndex = window.timeAttribute().getFieldIndex();
         RelDataType inputTimeFieldType =
                 
input.getRowType().getFieldList().get(inputTimeFieldIndex).getType();
         boolean inputTimeIsDate = inputTimeFieldType.getSqlTypeName() == 
SqlTypeName.DATE;
diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/logical/groupWindows.scala
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/logical/groupWindows.scala
index 35ef24122a4..a36837710bd 100644
--- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/logical/groupWindows.scala
+++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/logical/groupWindows.scala
@@ -45,6 +45,8 @@ abstract class LogicalWindow(
     Objects.equals(timeAttribute, that.timeAttribute)
   }
 
+  def copy(newTimeAttribute: FieldReferenceExpression): LogicalWindow
+
   protected def isValueLiteralExpressionEqual(
       l1: ValueLiteralExpression,
       l2: ValueLiteralExpression): Boolean = {
@@ -90,6 +92,10 @@ case class TumblingGroupWindow(
     }
   }
 
+  override def copy(newTimeField: FieldReferenceExpression): LogicalWindow = {
+    TumblingGroupWindow(alias, newTimeField, size)
+  }
+
   override def toString: String = s"TumblingGroupWindow($alias, $timeField, 
$size)"
 }
 
@@ -113,6 +119,10 @@ case class SlidingGroupWindow(
     }
   }
 
+  override def copy(newTimeField: FieldReferenceExpression): LogicalWindow = {
+    SlidingGroupWindow(alias, newTimeField, size, slide)
+  }
+
   override def toString: String = s"SlidingGroupWindow($alias, $timeField, 
$size, $slide)"
 }
 
@@ -134,5 +144,9 @@ case class SessionGroupWindow(
     }
   }
 
+  override def copy(newTimeField: FieldReferenceExpression): LogicalWindow = {
+    SessionGroupWindow(alias, newTimeField, gap)
+  }
+
   override def toString: String = s"SessionGroupWindow($alias, $timeField, 
$gap)"
 }
diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/calcite/LogicalWindowAggregate.scala
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/calcite/LogicalWindowAggregate.scala
index f4acea9036e..0d9e223b268 100644
--- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/calcite/LogicalWindowAggregate.scala
+++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/calcite/LogicalWindowAggregate.scala
@@ -64,6 +64,24 @@ final class LogicalWindowAggregate(
       window,
       namedProperties)
   }
+
+  def copy(
+      traitSet: RelTraitSet,
+      input: RelNode,
+      groupSet: ImmutableBitSet,
+      // retain this to follow "Aggregate#copy"
+      groupSets: util.List[ImmutableBitSet],
+      aggCalls: util.List[AggregateCall],
+      window: LogicalWindow): Aggregate = {
+    new LogicalWindowAggregate(
+      cluster,
+      traitSet,
+      input,
+      groupSet,
+      aggCalls,
+      window,
+      namedProperties)
+  }
 }
 
 object LogicalWindowAggregate {
diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalWindowAggregateRule.scala
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalWindowAggregateRule.scala
index 0402b0bd861..4b23f840431 100644
--- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalWindowAggregateRule.scala
+++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalWindowAggregateRule.scala
@@ -160,8 +160,7 @@ class BatchPhysicalWindowAggregateRule
     // TODO aggregate include projection now, so do not provide new trait will 
be safe
     val aggProvidedTraitSet = 
input.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL)
 
-    val inputTimeFieldIndex =
-      AggregateUtil.timeFieldIndex(input.getRowType, call.builder(), 
window.timeAttribute)
+    val inputTimeFieldIndex = window.timeAttribute.getFieldIndex
     val inputTimeFieldType = 
agg.getInput.getRowType.getFieldList.get(inputTimeFieldIndex).getType
     val inputTimeIsDate = inputTimeFieldType.getSqlTypeName == SqlTypeName.DATE
     // local-agg output order: groupSet | assignTs | auxGroupSet | aggCalls
diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
index 73c5f2a09da..dbb85d2b72a 100644
--- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
+++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
@@ -63,7 +63,6 @@ import org.apache.calcite.sql.`type`.{SqlTypeName, 
SqlTypeUtil}
 import org.apache.calcite.sql.{SqlAggFunction, SqlKind, SqlRankFunction}
 import org.apache.calcite.sql.fun._
 import org.apache.calcite.sql.validate.SqlMonotonicity
-import org.apache.calcite.tools.RelBuilder
 
 import java.time.Duration
 import java.util
@@ -1116,14 +1115,6 @@ object AggregateUtil extends Enumeration {
     new CountBundleTrigger[RowData](size)
   }
 
-  /** Compute field index of given timeField expression. */
-  def timeFieldIndex(
-      inputType: RelDataType,
-      relBuilder: RelBuilder,
-      timeField: FieldReferenceExpression): Int = {
-    relBuilder.values(inputType).field(timeField.getName).getIndex
-  }
-
   /** Computes the positions of (window start, window end, row time). */
   private[flink] def computeWindowPropertyPos(
       properties: Seq[NamedWindowProperty]): (Option[Int], Option[Int], 
Option[Int]) = {

Reply via email to