This is an automated email from the ASF dual-hosted git repository. godfrey pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 683a2a7af5c248e998c0606ecccbfa2960d847cd Author: godfreyhe <[email protected]> AuthorDate: Thu Dec 24 18:31:58 2020 +0800 [FLINK-20737][table-planner-blink] Introduce StreamPhysicalGroupAggregate, and make StreamExecGroupAggregate only extended from ExecNode This closes #14478 --- .../exec/stream/StreamExecGroupAggregate.java | 202 +++++++++++++++++++++ .../plan/metadata/FlinkRelMdColumnInterval.scala | 6 +- .../plan/metadata/FlinkRelMdColumnUniqueness.scala | 2 +- .../FlinkRelMdFilteredColumnInterval.scala | 4 +- .../metadata/FlinkRelMdModifiedMonotonicity.scala | 2 +- .../plan/metadata/FlinkRelMdUniqueKeys.scala | 2 +- .../stream/StreamExecGlobalGroupAggregate.scala | 4 +- .../physical/stream/StreamExecGroupAggregate.scala | 192 -------------------- .../StreamExecIncrementalGroupAggregate.scala | 4 +- .../stream/StreamExecLocalGroupAggregate.scala | 4 +- .../stream/StreamExecPythonGroupAggregate.scala | 4 +- .../stream/StreamPhysicalGroupAggregate.scala | 97 ++++++++++ ...cala => StreamPhysicalGroupAggregateBase.scala} | 2 +- .../FlinkChangelogModeInferenceProgram.scala | 4 +- .../planner/plan/rules/FlinkStreamRuleSets.scala | 2 +- ...cala => StreamPhysicalGroupAggregateRule.scala} | 19 +- .../stream/TwoStageOptimizedAggregateRule.scala | 14 +- .../table/planner/plan/utils/AggregateUtil.scala | 55 ++++-- .../plan/metadata/FlinkRelMdHandlerTestBase.scala | 4 +- 19 files changed, 376 insertions(+), 247 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupAggregate.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupAggregate.java new file mode 100644 index 0000000..1331812 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupAggregate.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.nodes.exec.stream; + +import org.apache.flink.api.dag.Transformation; +import org.apache.flink.streaming.api.operators.KeyedProcessOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.transformations.OneInputTransformation; +import org.apache.flink.table.api.TableConfig; +import org.apache.flink.table.api.config.ExecutionConfigOptions; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.planner.codegen.CodeGeneratorContext; +import org.apache.flink.table.planner.codegen.EqualiserCodeGenerator; +import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator; +import org.apache.flink.table.planner.delegation.PlannerBase; +import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNode; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase; +import org.apache.flink.table.planner.plan.utils.AggregateInfoList; +import org.apache.flink.table.planner.plan.utils.AggregateUtil; +import org.apache.flink.table.planner.plan.utils.KeySelectorUtil; +import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; +import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction; +import org.apache.flink.table.runtime.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.runtime.keyselector.RowDataKeySelector; +import org.apache.flink.table.runtime.operators.aggregate.GroupAggFunction; +import org.apache.flink.table.runtime.operators.aggregate.MiniBatchGroupAggFunction; +import org.apache.flink.table.runtime.operators.bundle.KeyedMapBundleOperator; +import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.util.Preconditions; + +import org.apache.calcite.rel.core.AggregateCall; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.Collections; + +/** + * Stream {@link ExecNode} for unbounded group aggregate. + * + * <p>This node does support un-splittable aggregate function (e.g. STDDEV_POP). + */ +public class StreamExecGroupAggregate extends ExecNodeBase<RowData> + implements StreamExecNode<RowData> { + private static final Logger LOG = LoggerFactory.getLogger(StreamExecGroupAggregate.class); + + private final int[] grouping; + private final AggregateCall[] aggCalls; + /** Each element indicates whether the corresponding agg call needs `retract` method. */ + private final boolean[] aggCallNeedRetractions; + /** Whether this node will generate UPDATE_BEFORE messages. */ + private final boolean generateUpdateBefore; + /** Whether this node consumes retraction messages. */ + private final boolean needRetraction; + + public StreamExecGroupAggregate( + int[] grouping, + AggregateCall[] aggCalls, + boolean[] aggCallNeedRetractions, + boolean generateUpdateBefore, + boolean needRetraction, + ExecEdge inputEdge, + RowType outputType, + String description) { + super(Collections.singletonList(inputEdge), outputType, description); + Preconditions.checkArgument(aggCalls.length == aggCallNeedRetractions.length); + this.grouping = grouping; + this.aggCalls = aggCalls; + this.aggCallNeedRetractions = aggCallNeedRetractions; + this.generateUpdateBefore = generateUpdateBefore; + this.needRetraction = needRetraction; + } + + @SuppressWarnings("unchecked") + @Override + protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) { + final TableConfig tableConfig = planner.getTableConfig(); + if (grouping.length > 0 && tableConfig.getMinIdleStateRetentionTime() < 0) { + LOG.warn( + "No state retention interval configured for a query which accumulates state. " + + "Please provide a query configuration with valid retention interval to prevent excessive " + + "state size. You may specify a retention time of 0 to not clean up the state."); + } + + final ExecNode<RowData> inputNode = (ExecNode<RowData>) getInputNodes().get(0); + final Transformation<RowData> inputTransform = inputNode.translateToPlan(planner); + final RowType inputRowType = (RowType) inputNode.getOutputType(); + + final AggsHandlerCodeGenerator generator = + new AggsHandlerCodeGenerator( + new CodeGeneratorContext(tableConfig), + planner.getRelBuilder(), + JavaScalaConversionUtil.toScala(inputRowType.getChildren()), + // TODO: heap state backend do not copy key currently, + // we have to copy input field + // TODO: copy is not need when state backend is rocksdb, + // improve this in future + // TODO: but other operators do not copy this input field..... + true) + .needAccumulate(); + + if (needRetraction) { + generator.needRetract(); + } + + final AggregateInfoList aggInfoList = + AggregateUtil.transformToStreamAggregateInfoList( + inputRowType, + JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)), + aggCallNeedRetractions, + needRetraction, + true, + true); + final GeneratedAggsHandleFunction aggsHandler = + generator.generateAggsHandler("GroupAggsHandler", aggInfoList); + + final LogicalType[] accTypes = + Arrays.stream(aggInfoList.getAccTypes()) + .map(LogicalTypeDataTypeConverter::fromDataTypeToLogicalType) + .toArray(LogicalType[]::new); + final LogicalType[] aggValueTypes = + Arrays.stream(aggInfoList.getActualValueTypes()) + .map(LogicalTypeDataTypeConverter::fromDataTypeToLogicalType) + .toArray(LogicalType[]::new); + final GeneratedRecordEqualiser recordEqualiser = + new EqualiserCodeGenerator(aggValueTypes) + .generateRecordEqualiser("GroupAggValueEqualiser"); + final int inputCountIndex = aggInfoList.getIndexOfCountStar(); + final boolean isMiniBatchEnabled = + tableConfig + .getConfiguration() + .getBoolean(ExecutionConfigOptions.TABLE_EXEC_MINIBATCH_ENABLED); + + final OneInputStreamOperator<RowData, RowData> operator; + if (isMiniBatchEnabled) { + MiniBatchGroupAggFunction aggFunction = + new MiniBatchGroupAggFunction( + aggsHandler, + recordEqualiser, + accTypes, + inputRowType, + inputCountIndex, + generateUpdateBefore, + tableConfig.getIdleStateRetention().toMillis()); + operator = + new KeyedMapBundleOperator<>( + aggFunction, AggregateUtil.createMiniBatchTrigger(tableConfig)); + } else { + GroupAggFunction aggFunction = + new GroupAggFunction( + aggsHandler, + recordEqualiser, + accTypes, + inputCountIndex, + generateUpdateBefore, + tableConfig.getIdleStateRetention().toMillis()); + operator = new KeyedProcessOperator<>(aggFunction); + } + + // partitioned aggregation + final OneInputTransformation<RowData, RowData> transform = + new OneInputTransformation<>( + inputTransform, + getDesc(), + operator, + InternalTypeInfo.of(getOutputType()), + inputTransform.getParallelism()); + + if (inputsContainSingleton()) { + transform.setParallelism(1); + transform.setMaxParallelism(1); + } + + // set KeyType and Selector for state + final RowDataKeySelector selector = + KeySelectorUtil.getRowDataSelector(grouping, InternalTypeInfo.of(inputRowType)); + transform.setStateKeySelector(selector); + transform.setStateKeyType(selector.getProducedType()); + + return transform; + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala index 36251fa..2e33587 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala @@ -433,7 +433,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { * @return interval of the given column on stream group Aggregate */ def getColumnInterval( - aggregate: StreamExecGroupAggregate, + aggregate: StreamPhysicalGroupAggregate, mq: RelMetadataQuery, index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index) @@ -535,7 +535,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { val input = aggregate.getInput val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq) val groupSet = aggregate match { - case agg: StreamExecGroupAggregate => agg.grouping + case agg: StreamPhysicalGroupAggregate => agg.grouping case agg: StreamExecLocalGroupAggregate => agg.grouping case agg: StreamExecGlobalGroupAggregate => agg.grouping case agg: StreamExecIncrementalGroupAggregate => agg.partialAggGrouping @@ -595,7 +595,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { } else { val aggCallIndex = index - groupSet.length val aggCall = aggregate match { - case agg: StreamExecGroupAggregate if agg.aggCalls.length > aggCallIndex => + case agg: StreamPhysicalGroupAggregate if agg.aggCalls.length > aggCallIndex => agg.aggCalls(aggCallIndex) case agg: StreamExecGlobalGroupAggregate if agg.globalAggInfoList.getActualAggregateCalls.length > aggCallIndex => diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniqueness.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniqueness.scala index ce091ff..5799ab0 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniqueness.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniqueness.scala @@ -341,7 +341,7 @@ class FlinkRelMdColumnUniqueness private extends MetadataHandler[BuiltInMetadata } def areColumnsUnique( - rel: StreamExecGroupAggregate, + rel: StreamPhysicalGroupAggregate, mq: RelMetadataQuery, columns: ImmutableBitSet, ignoreNulls: Boolean): JBoolean = { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala index 11c8383..c851c96 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala @@ -20,7 +20,7 @@ package org.apache.flink.table.planner.plan.metadata import org.apache.flink.table.planner.plan.metadata.FlinkMetadata.FilteredColumnInterval import org.apache.flink.table.planner.plan.nodes.calcite.TableAggregate import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecGroupAggregateBase -import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecGlobalGroupAggregate, StreamExecGroupAggregate, StreamExecGroupTableAggregate, StreamExecGroupWindowAggregate, StreamExecGroupWindowTableAggregate, StreamExecLocalGroupAggregate} +import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecGlobalGroupAggregate, StreamPhysicalGroupAggregate, StreamExecGroupTableAggregate, StreamExecGroupWindowAggregate, StreamExecGroupWindowTableAggregate, StreamExecLocalGroupAggregate} import org.apache.flink.table.planner.plan.stats.ValueInterval import org.apache.flink.table.planner.plan.utils.ColumnIntervalUtil import org.apache.flink.util.Preconditions.checkArgument @@ -184,7 +184,7 @@ class FlinkRelMdFilteredColumnInterval private extends MetadataHandler[FilteredC } def getFilteredColumnInterval( - aggregate: StreamExecGroupAggregate, + aggregate: StreamPhysicalGroupAggregate, mq: RelMetadataQuery, columnIndex: Int, filterArg: Int): ValueInterval = { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala index a17775b..d46ac17 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala @@ -276,7 +276,7 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon mq: RelMetadataQuery): RelModifiedMonotonicity = null def getRelModifiedMonotonicity( - rel: StreamExecGroupAggregate, + rel: StreamPhysicalGroupAggregate, mq: RelMetadataQuery): RelModifiedMonotonicity = { getRelModifiedMonotonicityOnAggregate(rel.getInput, mq, rel.aggCalls.toList, rel.grouping) } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala index ecdf6d1..bf9c1eb 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala @@ -345,7 +345,7 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu } def getUniqueKeys( - rel: StreamExecGroupAggregate, + rel: StreamPhysicalGroupAggregate, mq: RelMetadataQuery, ignoreNulls: Boolean): JSet[ImmutableBitSet] = { getUniqueKeysOnAggregate(rel.grouping, mq, ignoreNulls) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGlobalGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGlobalGroupAggregate.scala index f2842f8..6c44b33 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGlobalGroupAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGlobalGroupAggregate.scala @@ -46,7 +46,7 @@ import scala.collection.JavaConversions._ /** * Stream physical RelNode for unbounded global group aggregate. * - * @see [[StreamExecGroupAggregateBase]] for more info. + * @see [[StreamPhysicalGroupAggregateBase]] for more info. */ class StreamExecGlobalGroupAggregate( cluster: RelOptCluster, @@ -58,7 +58,7 @@ class StreamExecGlobalGroupAggregate( val localAggInfoList: AggregateInfoList, val globalAggInfoList: AggregateInfoList, val partialFinalType: PartialFinalType) - extends StreamExecGroupAggregateBase(cluster, traitSet, inputRel) + extends StreamPhysicalGroupAggregateBase(cluster, traitSet, inputRel) with LegacyStreamExecNode[RowData] { override def requireWatermark: Boolean = false diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupAggregate.scala deleted file mode 100644 index f7ae1d2..0000000 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupAggregate.scala +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.table.planner.plan.nodes.physical.stream - -import org.apache.flink.api.dag.Transformation -import org.apache.flink.streaming.api.operators.KeyedProcessOperator -import org.apache.flink.streaming.api.transformations.OneInputTransformation -import org.apache.flink.table.api.config.ExecutionConfigOptions -import org.apache.flink.table.data.RowData -import org.apache.flink.table.planner.calcite.FlinkTypeFactory -import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator -import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, EqualiserCodeGenerator} -import org.apache.flink.table.planner.delegation.StreamPlanner -import org.apache.flink.table.planner.plan.PartialFinalType -import org.apache.flink.table.planner.plan.nodes.exec.LegacyStreamExecNode -import org.apache.flink.table.planner.plan.utils.{AggregateInfoList, AggregateUtil, ChangelogPlanUtils, KeySelectorUtil, RelExplainUtil} -import org.apache.flink.table.runtime.operators.aggregate.{GroupAggFunction, MiniBatchGroupAggFunction} -import org.apache.flink.table.runtime.operators.bundle.KeyedMapBundleOperator -import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType -import org.apache.flink.table.runtime.typeutils.InternalTypeInfo - -import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} -import org.apache.calcite.rel.`type`.RelDataType -import org.apache.calcite.rel.core.AggregateCall -import org.apache.calcite.rel.{RelNode, RelWriter} - -import java.util - -import scala.collection.JavaConversions._ - -/** - * Stream physical RelNode for unbounded group aggregate. - * - * This node does support un-splittable aggregate function (e.g. STDDEV_POP). - * - * @see [[StreamExecGroupAggregateBase]] for more info. - */ -class StreamExecGroupAggregate( - cluster: RelOptCluster, - traitSet: RelTraitSet, - inputRel: RelNode, - outputRowType: RelDataType, - val grouping: Array[Int], - val aggCalls: Seq[AggregateCall], - var partialFinalType: PartialFinalType = PartialFinalType.NONE) - extends StreamExecGroupAggregateBase(cluster, traitSet, inputRel) - with LegacyStreamExecNode[RowData] { - - val aggInfoList: AggregateInfoList = AggregateUtil.deriveAggregateInfoList( - this, - grouping.length, - aggCalls) - - override def requireWatermark: Boolean = false - - override def deriveRowType(): RelDataType = outputRowType - - override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { - new StreamExecGroupAggregate( - cluster, - traitSet, - inputs.get(0), - outputRowType, - grouping, - aggCalls, - partialFinalType) - } - - override def explainTerms(pw: RelWriter): RelWriter = { - val inputRowType = getInput.getRowType - super.explainTerms(pw) - .itemIf("groupBy", - RelExplainUtil.fieldToString(grouping, inputRowType), grouping.nonEmpty) - .itemIf("partialFinalType", partialFinalType, partialFinalType != PartialFinalType.NONE) - .item("select", RelExplainUtil.streamGroupAggregationToString( - inputRowType, - getRowType, - aggInfoList, - grouping)) - } - - //~ ExecNode methods ----------------------------------------------------------- - - override protected def translateToPlanInternal( - planner: StreamPlanner): Transformation[RowData] = { - - val tableConfig = planner.getTableConfig - - if (grouping.length > 0 && tableConfig.getMinIdleStateRetentionTime < 0) { - LOG.warn("No state retention interval configured for a query which accumulates state. " + - "Please provide a query configuration with valid retention interval to prevent excessive " + - "state size. You may specify a retention time of 0 to not clean up the state.") - } - - val inputTransformation = getInputNodes.get(0).translateToPlan(planner) - .asInstanceOf[Transformation[RowData]] - - val outRowType = FlinkTypeFactory.toLogicalRowType(outputRowType) - val inputRowType = FlinkTypeFactory.toLogicalRowType(getInput.getRowType) - - val generateUpdateBefore = ChangelogPlanUtils.generateUpdateBefore(this) - val needRetraction = !ChangelogPlanUtils.inputInsertOnly(this) - - val generator = new AggsHandlerCodeGenerator( - CodeGeneratorContext(tableConfig), - planner.getRelBuilder, - inputRowType.getChildren, - // TODO: heap state backend do not copy key currently, we have to copy input field - // TODO: copy is not need when state backend is rocksdb, improve this in future - // TODO: but other operators do not copy this input field..... - copyInputField = true) - - if (needRetraction) { - generator.needRetract() - } - - val aggsHandler = generator - .needAccumulate() - .generateAggsHandler("GroupAggsHandler", aggInfoList) - val accTypes = aggInfoList.getAccTypes.map(fromDataTypeToLogicalType) - val aggValueTypes = aggInfoList.getActualValueTypes.map(fromDataTypeToLogicalType) - val recordEqualiser = new EqualiserCodeGenerator(aggValueTypes) - .generateRecordEqualiser("GroupAggValueEqualiser") - val inputCountIndex = aggInfoList.getIndexOfCountStar - - val isMiniBatchEnabled = tableConfig.getConfiguration.getBoolean( - ExecutionConfigOptions.TABLE_EXEC_MINIBATCH_ENABLED) - - val operator = if (isMiniBatchEnabled) { - val aggFunction = new MiniBatchGroupAggFunction( - aggsHandler, - recordEqualiser, - accTypes, - inputRowType, - inputCountIndex, - generateUpdateBefore, - tableConfig.getIdleStateRetention.toMillis) - - new KeyedMapBundleOperator( - aggFunction, - AggregateUtil.createMiniBatchTrigger(tableConfig)) - } else { - val aggFunction = new GroupAggFunction( - aggsHandler, - recordEqualiser, - accTypes, - inputCountIndex, - generateUpdateBefore, - tableConfig.getIdleStateRetention.toMillis) - - val operator = new KeyedProcessOperator[RowData, RowData, RowData](aggFunction) - operator - } - - val selector = KeySelectorUtil.getRowDataSelector( - grouping, - InternalTypeInfo.of(inputRowType)) - - // partitioned aggregation - val ret = new OneInputTransformation( - inputTransformation, - getRelDetailedDescription, - operator, - InternalTypeInfo.of(outRowType), - inputTransformation.getParallelism) - - if (inputsContainSingleton()) { - ret.setParallelism(1) - ret.setMaxParallelism(1) - } - - // set KeyType and Selector for state - ret.setStateKeySelector(selector) - ret.setStateKeyType(selector.getProducedType) - ret - } -} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala index 9ec2b8d..9706253 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala @@ -67,7 +67,7 @@ import scala.collection.JavaConversions._ * +- StreamExecLocalGroupAggregate (partial-local-aggregate) * }}} * - * @see [[StreamExecGroupAggregateBase]] for more info. + * @see [[StreamPhysicalGroupAggregateBase]] for more info. */ class StreamExecIncrementalGroupAggregate( cluster: RelOptCluster, @@ -80,7 +80,7 @@ class StreamExecIncrementalGroupAggregate( val finalAggCalls: Seq[AggregateCall], val finalAggGrouping: Array[Int], val partialAggGrouping: Array[Int]) - extends StreamExecGroupAggregateBase(cluster, traitSet, inputRel) + extends StreamPhysicalGroupAggregateBase(cluster, traitSet, inputRel) with LegacyStreamExecNode[RowData] { override def deriveRowType(): RelDataType = outputRowType diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecLocalGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecLocalGroupAggregate.scala index f7198bc..730695c 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecLocalGroupAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecLocalGroupAggregate.scala @@ -44,7 +44,7 @@ import scala.collection.JavaConversions._ /** * Stream physical RelNode for unbounded local group aggregate. * - * @see [[StreamExecGroupAggregateBase]] for more info. + * @see [[StreamPhysicalGroupAggregateBase]] for more info. */ class StreamExecLocalGroupAggregate( cluster: RelOptCluster, @@ -55,7 +55,7 @@ class StreamExecLocalGroupAggregate( val aggCalls: Seq[AggregateCall], val aggInfoList: AggregateInfoList, val partialFinalType: PartialFinalType) - extends StreamExecGroupAggregateBase(cluster, traitSet, inputRel) + extends StreamPhysicalGroupAggregateBase(cluster, traitSet, inputRel) with LegacyStreamExecNode[RowData] { override def requireWatermark: Boolean = false diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupAggregate.scala index aabf675..3570ca1 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupAggregate.scala @@ -44,7 +44,7 @@ import java.util /** * Stream physical RelNode for Python unbounded group aggregate. * - * @see [[StreamExecGroupAggregateBase]] for more info. + * @see [[StreamPhysicalGroupAggregateBase]] for more info. */ class StreamExecPythonGroupAggregate( cluster: RelOptCluster, @@ -53,7 +53,7 @@ class StreamExecPythonGroupAggregate( outputRowType: RelDataType, val grouping: Array[Int], val aggCalls: Seq[AggregateCall]) - extends StreamExecGroupAggregateBase(cluster, traitSet, inputRel) + extends StreamPhysicalGroupAggregateBase(cluster, traitSet, inputRel) with LegacyStreamExecNode[RowData] with CommonPythonAggregate { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGroupAggregate.scala new file mode 100644 index 0000000..dcac5ec --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGroupAggregate.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.planner.plan.nodes.physical.stream + +import org.apache.flink.table.planner.calcite.FlinkTypeFactory +import org.apache.flink.table.planner.plan.PartialFinalType +import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecGroupAggregate +import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode} +import org.apache.flink.table.planner.plan.utils.{AggregateUtil, ChangelogPlanUtils, RelExplainUtil} + +import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.AggregateCall +import org.apache.calcite.rel.{RelNode, RelWriter} + +import java.util + +/** + * Stream physical RelNode for unbounded group aggregate. + * + * This node does support un-splittable aggregate function (e.g. STDDEV_POP). + * + * @see [[StreamPhysicalGroupAggregateBase]] for more info. + */ +class StreamPhysicalGroupAggregate( + cluster: RelOptCluster, + traitSet: RelTraitSet, + inputRel: RelNode, + outputRowType: RelDataType, + val grouping: Array[Int], + val aggCalls: Seq[AggregateCall], + var partialFinalType: PartialFinalType = PartialFinalType.NONE) + extends StreamPhysicalGroupAggregateBase(cluster, traitSet, inputRel) { + + private val aggInfoList = + AggregateUtil.deriveAggregateInfoList(this, grouping.length, aggCalls) + + override def requireWatermark: Boolean = false + + override def deriveRowType(): RelDataType = outputRowType + + override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { + new StreamPhysicalGroupAggregate( + cluster, + traitSet, + inputs.get(0), + outputRowType, + grouping, + aggCalls, + partialFinalType) + } + + override def explainTerms(pw: RelWriter): RelWriter = { + val inputRowType = getInput.getRowType + super.explainTerms(pw) + .itemIf("groupBy", + RelExplainUtil.fieldToString(grouping, inputRowType), grouping.nonEmpty) + .itemIf("partialFinalType", partialFinalType, partialFinalType != PartialFinalType.NONE) + .item("select", RelExplainUtil.streamGroupAggregationToString( + inputRowType, + getRowType, + aggInfoList, + grouping)) + } + + override def translateToExecNode(): ExecNode[_] = { + val aggCallNeedRetractions = + AggregateUtil.deriveAggCallNeedRetractions(this, grouping.length, aggCalls) + val generateUpdateBefore = ChangelogPlanUtils.generateUpdateBefore(this) + val needRetraction = !ChangelogPlanUtils.inputInsertOnly(this) + new StreamExecGroupAggregate( + grouping, + aggCalls.toArray, + aggCallNeedRetractions, + generateUpdateBefore, + needRetraction, + ExecEdge.DEFAULT, + FlinkTypeFactory.toLogicalRowType(getRowType), + getRelDetailedDescription + ) + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupAggregateBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGroupAggregateBase.scala similarity index 98% rename from flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupAggregateBase.scala rename to flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGroupAggregateBase.scala index 7e4ff91..6bbeacc 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecGroupAggregateBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGroupAggregateBase.scala @@ -44,7 +44,7 @@ import org.apache.calcite.rel.{RelNode, SingleRel} * * <p>NOTES: partial-aggregation supports local-global mode, so does final-aggregation. */ -abstract class StreamExecGroupAggregateBase( +abstract class StreamPhysicalGroupAggregateBase( cluster: RelOptCluster, traitSet: RelTraitSet, inputRel: RelNode) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala index ab6a76d..8d276ca 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala @@ -164,7 +164,7 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti } createNewNode(deduplicate, children, providedTrait, requiredTrait, requester) - case agg: StreamExecGroupAggregate => + case agg: StreamPhysicalGroupAggregate => // agg support all changes in input val children = visitChildren(agg, ModifyKindSetTrait.ALL_CHANGES) val inputModifyKindSet = getModifyKindSet(children.head) @@ -461,7 +461,7 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti } visitSink(sink, sinkRequiredTraits) - case _: StreamExecGroupAggregate | _: StreamExecGroupTableAggregate | + case _: StreamPhysicalGroupAggregate | _: StreamExecGroupTableAggregate | _: StreamPhysicalLimit | _: StreamExecPythonGroupAggregate | _: StreamExecPythonGroupTableAggregate => // Aggregate, TableAggregate and Limit requires update_before if there are updates diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala index 5e83109..d7c695d 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala @@ -420,7 +420,7 @@ object FlinkStreamRuleSets { // expand StreamPhysicalExpandRule.INSTANCE, // group agg - StreamExecGroupAggregateRule.INSTANCE, + StreamPhysicalGroupAggregateRule.INSTANCE, StreamExecGroupTableAggregateRule.INSTANCE, StreamExecPythonGroupAggregateRule.INSTANCE, StreamExecPythonGroupTableAggregateRule.INSTANCE, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecGroupAggregateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalGroupAggregateRule.scala similarity index 88% rename from flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecGroupAggregateRule.scala rename to flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalGroupAggregateRule.scala index c38dc90..b6778f3 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecGroupAggregateRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalGroupAggregateRule.scala @@ -21,9 +21,10 @@ package org.apache.flink.table.planner.plan.rules.physical.stream import org.apache.flink.table.api.TableException import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.planner.plan.nodes.FlinkConventions -import org.apache.flink.table.planner.plan.utils.PythonUtil.isPythonAggregate import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate -import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamExecGroupAggregate +import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalGroupAggregate +import org.apache.flink.table.planner.plan.utils.PythonUtil.isPythonAggregate + import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.convert.ConverterRule @@ -32,14 +33,14 @@ import org.apache.calcite.rel.core.Aggregate.Group import scala.collection.JavaConversions._ /** - * Rule to convert a [[FlinkLogicalAggregate]] into a [[StreamExecGroupAggregate]]. - */ -class StreamExecGroupAggregateRule + * Rule to convert a [[FlinkLogicalAggregate]] into a [[StreamPhysicalGroupAggregate]]. + */ +class StreamPhysicalGroupAggregateRule extends ConverterRule( classOf[FlinkLogicalAggregate], FlinkConventions.LOGICAL, FlinkConventions.STREAM_PHYSICAL, - "StreamExecGroupAggregateRule") { + "StreamPhysicalGroupAggregateRule") { override def matches(call: RelOptRuleCall): Boolean = { val agg: FlinkLogicalAggregate = call.rel(0) @@ -65,7 +66,7 @@ class StreamExecGroupAggregateRule val providedTraitSet = rel.getTraitSet.replace(FlinkConventions.STREAM_PHYSICAL) val newInput: RelNode = RelOptRule.convert(agg.getInput, requiredTraitSet) - new StreamExecGroupAggregate( + new StreamPhysicalGroupAggregate( rel.getCluster, providedTraitSet, newInput, @@ -76,6 +77,6 @@ class StreamExecGroupAggregateRule } } -object StreamExecGroupAggregateRule { - val INSTANCE: RelOptRule = new StreamExecGroupAggregateRule +object StreamPhysicalGroupAggregateRule { + val INSTANCE: RelOptRule = new StreamPhysicalGroupAggregateRule } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/TwoStageOptimizedAggregateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/TwoStageOptimizedAggregateRule.scala index afb24fb..7f01e5c 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/TwoStageOptimizedAggregateRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/TwoStageOptimizedAggregateRule.scala @@ -35,7 +35,7 @@ import org.apache.calcite.rel.RelNode import java.util /** - * Rule that matches [[StreamExecGroupAggregate]] on [[StreamPhysicalExchange]] + * Rule that matches [[StreamPhysicalGroupAggregate]] on [[StreamPhysicalExchange]] * with the following condition: * 1. mini-batch is enabled in given TableConfig, * 2. two-phase aggregation is enabled in given TableConfig, @@ -51,21 +51,21 @@ import java.util * }}} */ class TwoStageOptimizedAggregateRule extends RelOptRule( - operand(classOf[StreamExecGroupAggregate], + operand(classOf[StreamPhysicalGroupAggregate], operand(classOf[StreamPhysicalExchange], operand(classOf[RelNode], any))), "TwoStageOptimizedAggregateRule") { override def matches(call: RelOptRuleCall): Boolean = { val tableConfig = call.getPlanner.getContext.unwrap(classOf[FlinkContext]).getTableConfig - val agg: StreamExecGroupAggregate = call.rel(0) + val agg: StreamPhysicalGroupAggregate = call.rel(0) val realInput: RelNode = call.rel(2) val needRetraction = !ChangelogPlanUtils.isInsertOnly( realInput.asInstanceOf[StreamPhysicalRel]) val fmq = FlinkRelMetadataQuery.reuseOrCreate(call.getMetadataQuery) val monotonicity = fmq.getRelModifiedMonotonicity(agg) - val needRetractionArray = AggregateUtil.getNeedRetractions( + val needRetractionArray = AggregateUtil.deriveAggCallNeedRetractions( agg.grouping.length, agg.aggCalls, needRetraction, monotonicity) val aggInfoList = AggregateUtil.transformToStreamAggregateInfoList( @@ -91,13 +91,13 @@ class TwoStageOptimizedAggregateRule extends RelOptRule( } override def onMatch(call: RelOptRuleCall): Unit = { - val agg: StreamExecGroupAggregate = call.rel(0) + val agg: StreamPhysicalGroupAggregate = call.rel(0) val realInput: RelNode = call.rel(2) val needRetraction = !ChangelogPlanUtils.isInsertOnly( realInput.asInstanceOf[StreamPhysicalRel]) val fmq = FlinkRelMetadataQuery.reuseOrCreate(call.getMetadataQuery) val monotonicity = fmq.getRelModifiedMonotonicity(agg) - val needRetractionArray = AggregateUtil.getNeedRetractions( + val needRetractionArray = AggregateUtil.deriveAggCallNeedRetractions( agg.grouping.length, agg.aggCalls, needRetraction, monotonicity) val localAggInfoList = AggregateUtil.transformToStreamAggregateInfoList( @@ -124,7 +124,7 @@ class TwoStageOptimizedAggregateRule extends RelOptRule( input: RelNode, localAggInfoList: AggregateInfoList, globalAggInfoList: AggregateInfoList, - agg: StreamExecGroupAggregate): StreamExecGlobalGroupAggregate = { + agg: StreamPhysicalGroupAggregate): StreamExecGlobalGroupAggregate = { val localAggRowType = AggregateUtil.inferLocalAggRowType( localAggInfoList, input.getRowType, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala index 01fd4d4..ea1205b 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala @@ -180,22 +180,13 @@ object AggregateUtil extends Enumeration { groupCount: Int, aggCalls: Seq[AggregateCall]): AggregateInfoList = { val input = agg.getInput(0) - // need to call `retract()` if input contains update or delete - val modifyKindSetTrait = input.getTraitSet.getTrait(ModifyKindSetTraitDef.INSTANCE) - val needRetraction = if (modifyKindSetTrait == null) { - // FlinkChangelogModeInferenceProgram is not applied yet, false as default - false - } else { - !modifyKindSetTrait.modifyKindSet.isInsertOnly - } - val fmq = FlinkRelMetadataQuery.reuseOrCreate(agg.getCluster.getMetadataQuery) - val monotonicity = fmq.getRelModifiedMonotonicity(agg) - val needRetractionArray = getNeedRetractions(groupCount, aggCalls, needRetraction, monotonicity) + val aggCallNeedRetractions = deriveAggCallNeedRetractions(agg, groupCount, aggCalls) + val needInputCount = needRetraction(agg) transformToStreamAggregateInfoList( FlinkTypeFactory.toLogicalRowType(input.getRowType), aggCalls, - needRetractionArray, - needInputCount = needRetraction, + aggCallNeedRetractions, + needInputCount, isStateBackendDataViews = true) } @@ -782,10 +773,40 @@ object AggregateUtil extends Enumeration { } /** - * Optimize max or min with retraction agg. MaxWithRetract can be optimized to Max if input is - * update increasing. - */ - def getNeedRetractions( + * Return true if the given agg rel needs retraction message, else false. + */ + def needRetraction(agg: StreamPhysicalRel): Boolean = { + // need to call `retract()` if input contains update or delete + val modifyKindSetTrait = agg.getInput(0).getTraitSet.getTrait(ModifyKindSetTraitDef.INSTANCE) + if (modifyKindSetTrait == null) { + // FlinkChangelogModeInferenceProgram is not applied yet, false as default + false + } else { + !modifyKindSetTrait.modifyKindSet.isInsertOnly + } + } + + /** + * Return the retraction flags for each given agg calls, currently MAX and MIN are supported. + * MaxWithRetract can be optimized to Max if input is update increasing, + * MinWithRetract can be optimized to Min if input is update decreasing. + */ + def deriveAggCallNeedRetractions( + agg: StreamPhysicalRel, + groupCount: Int, + aggCalls: Seq[AggregateCall]): Array[Boolean] = { + val fmq = FlinkRelMetadataQuery.reuseOrCreate(agg.getCluster.getMetadataQuery) + val monotonicity = fmq.getRelModifiedMonotonicity(agg) + val needRetractionFlag = needRetraction(agg) + deriveAggCallNeedRetractions(groupCount, aggCalls, needRetractionFlag, monotonicity) + } + + /** + * Return the retraction flags for each given agg calls, currently max and min are supported. + * MaxWithRetract can be optimized to Max if input is update increasing, + * MinWithRetract can be optimized to Min if input is update decreasing. + */ + def deriveAggCallNeedRetractions( groupCount: Int, aggCalls: Seq[AggregateCall], needRetraction: Boolean, diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala index eaa38bf..f0e623f 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala @@ -1019,7 +1019,7 @@ class FlinkRelMdHandlerTestBase { aggCallToAggFunction, isMerge = false) - val needRetractionArray = AggregateUtil.getNeedRetractions( + val needRetractionArray = AggregateUtil.deriveAggCallNeedRetractions( 1, aggCalls, needRetraction = false, null) val localAggInfoList = transformToStreamAggregateInfoList( @@ -1059,7 +1059,7 @@ class FlinkRelMdHandlerTestBase { val streamExchange2 = new StreamPhysicalExchange(cluster, studentStreamScan.getTraitSet.replace(hash3), studentStreamScan, hash3) - val streamGlobalAggWithoutLocal = new StreamExecGroupAggregate( + val streamGlobalAggWithoutLocal = new StreamPhysicalGroupAggregate( cluster, streamPhysicalTraits, streamExchange2,
