This is an automated email from the ASF dual-hosted git repository. jchan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new b957480112c [FLINK-33941][table-planner] Use field reference index to compute window aggregate time attribute column b957480112c is described below commit b957480112c00d9d777247fc48b602e9908652a2 Author: Xuyang <xyzhong...@163.com> AuthorDate: Thu Jan 4 10:47:00 2024 +0800 [FLINK-33941][table-planner] Use field reference index to compute window aggregate time attribute column This closes #23991 --- .../stream/StreamExecGroupWindowAggregate.java | 7 +--- .../StreamExecPythonGroupWindowAggregate.java | 7 +--- .../logical/FlinkAggregateProjectMergeRule.java | 48 +++++++++++++++++----- .../BatchPhysicalPythonWindowAggregateRule.java | 4 +- .../table/planner/plan/logical/groupWindows.scala | 14 +++++++ .../nodes/calcite/LogicalWindowAggregate.scala | 18 ++++++++ .../batch/BatchPhysicalWindowAggregateRule.scala | 3 +- .../table/planner/plan/utils/AggregateUtil.scala | 9 ---- 8 files changed, 74 insertions(+), 36 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupWindowAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupWindowAggregate.java index 40471878046..d1a48c64077 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupWindowAggregate.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupWindowAggregate.java @@ -84,7 +84,6 @@ import static org.apache.flink.table.planner.plan.utils.AggregateUtil.hasTimeInt import static org.apache.flink.table.planner.plan.utils.AggregateUtil.isProctimeAttribute; import static org.apache.flink.table.planner.plan.utils.AggregateUtil.isRowtimeAttribute; import static org.apache.flink.table.planner.plan.utils.AggregateUtil.isTableAggregate; -import static org.apache.flink.table.planner.plan.utils.AggregateUtil.timeFieldIndex; import static org.apache.flink.table.planner.plan.utils.AggregateUtil.toDuration; import static org.apache.flink.table.planner.plan.utils.AggregateUtil.toLong; import static org.apache.flink.table.planner.plan.utils.AggregateUtil.transformToStreamAggregateInfoList; @@ -211,11 +210,7 @@ public class StreamExecGroupWindowAggregate extends StreamExecAggregateBase { final int inputTimeFieldIndex; if (isRowtimeAttribute(window.timeAttribute())) { - inputTimeFieldIndex = - timeFieldIndex( - planner.getTypeFactory().buildRelNodeRowType(inputRowType), - planner.createRelBuilder(), - window.timeAttribute()); + inputTimeFieldIndex = window.timeAttribute().getFieldIndex(); if (inputTimeFieldIndex < 0) { throw new TableException( "Group window must defined on a time attribute, " diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java index ec908fb1181..d6fc11fee73 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java @@ -93,7 +93,6 @@ import static org.apache.flink.table.planner.plan.utils.AggregateUtil.hasRowInte import static org.apache.flink.table.planner.plan.utils.AggregateUtil.hasTimeIntervalType; import static org.apache.flink.table.planner.plan.utils.AggregateUtil.isProctimeAttribute; import static org.apache.flink.table.planner.plan.utils.AggregateUtil.isRowtimeAttribute; -import static org.apache.flink.table.planner.plan.utils.AggregateUtil.timeFieldIndex; import static org.apache.flink.table.planner.plan.utils.AggregateUtil.toDuration; import static org.apache.flink.table.planner.plan.utils.AggregateUtil.toLong; import static org.apache.flink.table.planner.plan.utils.AggregateUtil.transformToStreamAggregateInfoList; @@ -234,11 +233,7 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas final int inputTimeFieldIndex; if (isRowtimeAttribute(window.timeAttribute())) { - inputTimeFieldIndex = - timeFieldIndex( - planner.getTypeFactory().buildRelNodeRowType(inputRowType), - planner.createRelBuilder(), - window.timeAttribute()); + inputTimeFieldIndex = window.timeAttribute().getFieldIndex(); if (inputTimeFieldIndex < 0) { throw new TableException( "Group window must defined on a time attribute, " diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkAggregateProjectMergeRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkAggregateProjectMergeRule.java index 8f39f2f001b..dc5514872cc 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkAggregateProjectMergeRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkAggregateProjectMergeRule.java @@ -18,6 +18,8 @@ package org.apache.flink.table.planner.plan.rules.logical; +import org.apache.flink.table.expressions.FieldReferenceExpression; +import org.apache.flink.table.planner.plan.logical.LogicalWindow; import org.apache.flink.table.planner.plan.nodes.calcite.LogicalWindowAggregate; import org.apache.flink.table.types.logical.utils.LogicalTypeChecks; @@ -54,7 +56,8 @@ import static java.util.Objects.requireNonNull; * <p>FLINK modifications are at lines * * <ol> - * <li>Should be removed after legacy groupWindowAggregate was removed: Lines 83 ~ 101 + * <li>Should be removed after legacy groupWindowAggregate was removed: Lines 85 ~ 105, Lines 136 + * ~ 156 * </ol> */ public class FlinkAggregateProjectMergeRule extends AggregateProjectMergeRule { @@ -79,18 +82,19 @@ public class FlinkAggregateProjectMergeRule extends AggregateProjectMergeRule { RelOptRuleCall call, Aggregate aggregate, Project project) { // Find all fields which we need to be straightforward field projections. final Set<Integer> interestingFields = RelOptUtil.getAllFields(aggregate); + boolean isProctimeWindowAgg = false; // Should add the field of timeAttribute in a LogicalWindowAggregate node which uses rowTime if (aggregate instanceof LogicalWindowAggregate) { LogicalWindowAggregate winAgg = (LogicalWindowAggregate) aggregate; // isRowtimeAttribute can't be used here because the time_indicator phase comes later - boolean isProcTime = + isProctimeWindowAgg = LogicalTypeChecks.isProctimeAttribute( winAgg.getWindow() .timeAttribute() .getOutputDataType() .getLogicalType()); - if (!isProcTime) { + if (!isProctimeWindowAgg) { // no need to consider the inputIndex because LogicalWindowAggregate is single input interestingFields.add( ((LogicalWindowAggregate) aggregate) @@ -127,13 +131,37 @@ public class FlinkAggregateProjectMergeRule extends AggregateProjectMergeRule { aggCalls.add(aggregateCall.transform(targetMapping)); } - final Aggregate newAggregate = - aggregate.copy( - aggregate.getTraitSet(), - project.getInput(), - newGroupSet, - newGroupingSets, - aggCalls.build()); + final Aggregate newAggregate; + + if (aggregate instanceof LogicalWindowAggregate && !isProctimeWindowAgg) { + // update the index of the time field in window + LogicalWindowAggregate winAgg = (LogicalWindowAggregate) aggregate; + LogicalWindow window = winAgg.getWindow(); + int newTimeIndex = map.get(window.timeAttribute().getFieldIndex()); + LogicalWindow newWindow = + window.copy( + new FieldReferenceExpression( + window.timeAttribute().getName(), + window.timeAttribute().getOutputDataType(), + window.timeAttribute().getInputIndex(), + newTimeIndex)); + newAggregate = + winAgg.copy( + aggregate.getTraitSet(), + project.getInput(), + newGroupSet, + newGroupingSets, + aggCalls.build(), + newWindow); + } else { + newAggregate = + aggregate.copy( + aggregate.getTraitSet(), + project.getInput(), + newGroupSet, + newGroupingSets, + aggCalls.build()); + } // Add a project if the group set is not in the same order or // contains duplicates. diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalPythonWindowAggregateRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalPythonWindowAggregateRule.java index 5e35384c490..1e6d680390b 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalPythonWindowAggregateRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalPythonWindowAggregateRule.java @@ -130,9 +130,7 @@ public class BatchPhysicalPythonWindowAggregateRule extends RelOptRule { null); UserDefinedFunction[] aggFunctions = aggBufferTypesAndFunctions._3(); - int inputTimeFieldIndex = - AggregateUtil.timeFieldIndex( - input.getRowType(), call.builder(), window.timeAttribute()); + int inputTimeFieldIndex = window.timeAttribute().getFieldIndex(); RelDataType inputTimeFieldType = input.getRowType().getFieldList().get(inputTimeFieldIndex).getType(); boolean inputTimeIsDate = inputTimeFieldType.getSqlTypeName() == SqlTypeName.DATE; diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/logical/groupWindows.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/logical/groupWindows.scala index 35ef24122a4..a36837710bd 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/logical/groupWindows.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/logical/groupWindows.scala @@ -45,6 +45,8 @@ abstract class LogicalWindow( Objects.equals(timeAttribute, that.timeAttribute) } + def copy(newTimeAttribute: FieldReferenceExpression): LogicalWindow + protected def isValueLiteralExpressionEqual( l1: ValueLiteralExpression, l2: ValueLiteralExpression): Boolean = { @@ -90,6 +92,10 @@ case class TumblingGroupWindow( } } + override def copy(newTimeField: FieldReferenceExpression): LogicalWindow = { + TumblingGroupWindow(alias, newTimeField, size) + } + override def toString: String = s"TumblingGroupWindow($alias, $timeField, $size)" } @@ -113,6 +119,10 @@ case class SlidingGroupWindow( } } + override def copy(newTimeField: FieldReferenceExpression): LogicalWindow = { + SlidingGroupWindow(alias, newTimeField, size, slide) + } + override def toString: String = s"SlidingGroupWindow($alias, $timeField, $size, $slide)" } @@ -134,5 +144,9 @@ case class SessionGroupWindow( } } + override def copy(newTimeField: FieldReferenceExpression): LogicalWindow = { + SessionGroupWindow(alias, newTimeField, gap) + } + override def toString: String = s"SessionGroupWindow($alias, $timeField, $gap)" } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/calcite/LogicalWindowAggregate.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/calcite/LogicalWindowAggregate.scala index f4acea9036e..0d9e223b268 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/calcite/LogicalWindowAggregate.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/calcite/LogicalWindowAggregate.scala @@ -64,6 +64,24 @@ final class LogicalWindowAggregate( window, namedProperties) } + + def copy( + traitSet: RelTraitSet, + input: RelNode, + groupSet: ImmutableBitSet, + // retain this to follow "Aggregate#copy" + groupSets: util.List[ImmutableBitSet], + aggCalls: util.List[AggregateCall], + window: LogicalWindow): Aggregate = { + new LogicalWindowAggregate( + cluster, + traitSet, + input, + groupSet, + aggCalls, + window, + namedProperties) + } } object LogicalWindowAggregate { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalWindowAggregateRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalWindowAggregateRule.scala index 0402b0bd861..4b23f840431 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalWindowAggregateRule.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalWindowAggregateRule.scala @@ -160,8 +160,7 @@ class BatchPhysicalWindowAggregateRule // TODO aggregate include projection now, so do not provide new trait will be safe val aggProvidedTraitSet = input.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL) - val inputTimeFieldIndex = - AggregateUtil.timeFieldIndex(input.getRowType, call.builder(), window.timeAttribute) + val inputTimeFieldIndex = window.timeAttribute.getFieldIndex val inputTimeFieldType = agg.getInput.getRowType.getFieldList.get(inputTimeFieldIndex).getType val inputTimeIsDate = inputTimeFieldType.getSqlTypeName == SqlTypeName.DATE // local-agg output order: groupSet | assignTs | auxGroupSet | aggCalls diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala index 73c5f2a09da..dbb85d2b72a 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala @@ -63,7 +63,6 @@ import org.apache.calcite.sql.`type`.{SqlTypeName, SqlTypeUtil} import org.apache.calcite.sql.{SqlAggFunction, SqlKind, SqlRankFunction} import org.apache.calcite.sql.fun._ import org.apache.calcite.sql.validate.SqlMonotonicity -import org.apache.calcite.tools.RelBuilder import java.time.Duration import java.util @@ -1116,14 +1115,6 @@ object AggregateUtil extends Enumeration { new CountBundleTrigger[RowData](size) } - /** Compute field index of given timeField expression. */ - def timeFieldIndex( - inputType: RelDataType, - relBuilder: RelBuilder, - timeField: FieldReferenceExpression): Int = { - relBuilder.values(inputType).field(timeField.getName).getIndex - } - /** Computes the positions of (window start, window end, row time). */ private[flink] def computeWindowPropertyPos( properties: Seq[NamedWindowProperty]): (Option[Int], Option[Int], Option[Int]) = {