This is an automated email from the ASF dual-hosted git repository.

starocean999 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new f3fdb306e4f [feat](nereids) add merge aggregate rule (#31811)
f3fdb306e4f is described below

commit f3fdb306e4fd7d405250d478463276cdc55b4e25
Author: feiniaofeiafei <53502832+feiniaofeia...@users.noreply.github.com>
AuthorDate: Wed Mar 13 12:05:47 2024 +0800

    [feat](nereids) add merge aggregate rule (#31811)
---
 .../doris/nereids/jobs/executor/Rewriter.java      |   4 +-
 .../org/apache/doris/nereids/rules/RuleType.java   |   1 +
 .../doris/nereids/rules/rewrite/ColumnPruning.java |  46 +++-
 .../nereids/rules/rewrite/MergeAggregate.java      | 211 ++++++++++++++++++
 .../nereids/trees/plans/logical/LogicalUnion.java  |   6 +
 .../org/apache/doris/nereids/util/PlanUtils.java   |   6 +
 .../merge_aggregate/merge_aggregate.out            | 248 +++++++++++++++++++++
 .../merge_aggregate/merge_aggregate.groovy         | 177 +++++++++++++++
 8 files changed, 695 insertions(+), 4 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index db463354120..c4fa3abd0b9 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -81,6 +81,7 @@ import org.apache.doris.nereids.rules.rewrite.InferPredicates;
 import org.apache.doris.nereids.rules.rewrite.InferSetOperatorDistinct;
 import org.apache.doris.nereids.rules.rewrite.InlineLogicalView;
 import org.apache.doris.nereids.rules.rewrite.LimitSortToTopN;
+import org.apache.doris.nereids.rules.rewrite.MergeAggregate;
 import org.apache.doris.nereids.rules.rewrite.MergeFilters;
 import org.apache.doris.nereids.rules.rewrite.MergeOneRowRelationIntoUnion;
 import org.apache.doris.nereids.rules.rewrite.MergeProjects;
@@ -341,7 +342,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
             ),
 
             topic("Eliminate GroupBy",
-                    topDown(new EliminateGroupBy())
+                    topDown(new EliminateGroupBy(),
+                            new MergeAggregate())
             ),
 
             topic("Eager aggregation",
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 90f2c222091..f7ca87a844b 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -197,6 +197,7 @@ public enum RuleType {
     MERGE_LIMITS(RuleTypeClass.REWRITE),
     MERGE_GENERATES(RuleTypeClass.REWRITE),
     // Eliminate plan
+    MERGE_AGGREGATE(RuleTypeClass.REWRITE),
     ELIMINATE_AGGREGATE(RuleTypeClass.REWRITE),
     ELIMINATE_LIMIT(RuleTypeClass.REWRITE),
     ELIMINATE_LIMIT_ON_ONE_ROW_RELATION(RuleTypeClass.REWRITE),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java
index cf94caa25c8..4d0c9be368d 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java
@@ -151,9 +151,7 @@ public class ColumnPruning extends 
DefaultPlanRewriter<PruneContext> implements
         if (union.getQualifier() == Qualifier.DISTINCT) {
             return skipPruneThisAndFirstLevelChildren(union);
         }
-
-        LogicalUnion prunedOutputUnion = pruneOutput(union, 
union.getOutputs(), union::pruneOutputs, context);
-
+        LogicalUnion prunedOutputUnion = pruneUnionOutput(union, context);
         // start prune children of union
         List<Slot> originOutput = union.getOutput();
         Set<Slot> prunedOutput = prunedOutputUnion.getOutputSet();
@@ -303,6 +301,48 @@ public class ColumnPruning extends 
DefaultPlanRewriter<PruneContext> implements
         }
     }
 
+    /**
+     * Prunes the output columns of {@code union} that are not in {@code context.requiredSlots}, and
+     * removes the same columns from every constant-expression row so each row stays aligned with
+     * the pruned output list.
+     *
+     * <p>If no output column is required at all, a single column is kept. A column from
+     * {@code keys} is preferred; otherwise the one chosen by
+     * {@link ExpressionUtils#selectMinimumColumn} among all outputs is kept. The constant rows keep
+     * exactly that column as well.
+     *
+     * @param union the union to prune
+     * @param context pruning context holding the slots required by the parent
+     * @return {@code union} itself when nothing changes, otherwise a copy with pruned outputs and
+     *         pruned constant rows
+     */
+    private LogicalUnion pruneUnionOutput(LogicalUnion union, PruneContext context) {
+        List<NamedExpression> originOutput = union.getOutputs();
+        if (originOutput.isEmpty()) {
+            return union;
+        }
+        List<NamedExpression> prunedOutputs = Lists.newArrayList();
+        // positions (in originOutput) of the columns that survive, in output order
+        List<Integer> extractColumnIndex = Lists.newArrayList();
+        for (int i = 0; i < originOutput.size(); i++) {
+            NamedExpression output = originOutput.get(i);
+            if (context.requiredSlots.contains(output.toSlot())) {
+                prunedOutputs.add(output);
+                extractColumnIndex.add(i);
+            }
+        }
+
+        if (prunedOutputs.isEmpty()) {
+            List<NamedExpression> candidates = Lists.newArrayList(originOutput);
+            candidates.retainAll(keys);
+            if (candidates.isEmpty()) {
+                candidates = originOutput;
+            }
+            NamedExpression minimumColumn = ExpressionUtils.selectMinimumColumn(candidates);
+            prunedOutputs = ImmutableList.of(minimumColumn);
+            // the constant rows must keep the same fallback column, otherwise they end up with
+            // zero columns while the union still outputs one
+            extractColumnIndex = ImmutableList.of(originOutput.indexOf(minimumColumn));
+        }
+
+        if (prunedOutputs.equals(originOutput)) {
+            return union;
+        }
+
+        List<List<NamedExpression>> prunedConstantExprsList = Lists.newArrayList();
+        for (List<NamedExpression> row : union.getConstantExprsList()) {
+            List<NamedExpression> newRow = new ArrayList<>(extractColumnIndex.size());
+            for (int idx : extractColumnIndex) {
+                newRow.add(row.get(idx));
+            }
+            prunedConstantExprsList.add(newRow);
+        }
+        return union.withNewOutputsAndConstExprsList(prunedOutputs, prunedConstantExprsList);
+    }
+
+    /**
+     * Convenience overload: calls {@code pruneChildren(plan, ImmutableSet.of())}, i.e. the
+     * two-argument overload with an empty set.
+     * NOTE(review): the meaning of that set is defined by the two-argument overload, which is not
+     * visible in this hunk; confirm there.
+     */
     private <P extends Plan> P pruneChildren(P plan) {
         return pruneChildren(plan, ImmutableSet.of());
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java
new file mode 100644
index 00000000000..3bdfbc582ac
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java
@@ -0,0 +1,211 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.ExprId;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.PlanUtils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**MergeAggregate*/
+/**
+ * Merges an aggregate into the aggregate directly below it, optionally through a project that
+ * only forwards or renames columns, so that a single aggregation is computed over the inner
+ * aggregate's input.
+ *
+ * <p>Outer aggregate functions must be one of min, max, sum or any_value, with a bare column as
+ * argument. An outer function over an inner aggregate result must be the same function
+ * (max(max), min(min), sum(sum), any_value(any_value)), or sum(count) when the outer aggregate has
+ * group-by keys. An outer min/max/any_value over a column that is not an inner aggregate result is
+ * also accepted. The outer group-by keys must be a subset of the inner ones, and DISTINCT (inner or
+ * outer) is only accepted when both aggregates group by the same keys.
+ *
+ * <p>The rule keeps no per-match state: everything is recomputed from the matched plan.
+ */
+public class MergeAggregate implements RewriteRuleFactory {
+    private static final ImmutableSet<String> ALLOW_MERGE_AGGREGATE_FUNCTIONS =
+            ImmutableSet.of("min", "max", "sum", "any_value");
+
+    @Override
+    public List<Rule> buildRules() {
+        return ImmutableList.of(
+                logicalAggregate(logicalAggregate()).when(this::canMergeAggregateWithoutProject)
+                        .then(this::mergeTwoAggregate)
+                        .toRule(RuleType.MERGE_AGGREGATE),
+                logicalAggregate(logicalProject(logicalAggregate()))
+                        .when(this::canMergeAggregateWithProject)
+                        .then(this::mergeAggProjectAgg)
+                        .toRule(RuleType.MERGE_AGGREGATE));
+    }
+
+    /**
+     * before:
+     * LogicalAggregate
+     *   +--LogicalAggregate
+     * after:
+     * LogicalAggregate
+     *
+     * <p>The outer aggregate keeps its group-by keys and output ids. Aggregate functions over an
+     * inner aggregate result are replaced by the inner function, and the aggregate is re-parented
+     * onto the inner aggregate's children.
+     */
+    private Plan mergeTwoAggregate(LogicalAggregate<LogicalAggregate<Plan>> outerAgg) {
+        LogicalAggregate<Plan> innerAgg = outerAgg.child();
+        Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = collectInnerAggFunctions(innerAgg);
+        List<NamedExpression> newOutputExpressions = outerAgg.getOutputExpressions().stream()
+                .map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc))
+                .collect(Collectors.toList());
+        return outerAgg.withAggOutput(newOutputExpressions).withChildren(innerAgg.children());
+    }
+
+    /**
+     * before:
+     * LogicalAggregate (outputExpressions = [col2, sum(col1)], groupByKeys = [col2])
+     *   +--LogicalProject (projects = [a as col2, col1])
+     *     +--LogicalAggregate (outputExpressions = [a, b, sum(c) as col1], groupByKeys = [a,b])
+     * after:
+     * LogicalProject (projects = [a as col2, sum(col1)])
+     *   +--LogicalAggregate (outputExpression = [a, sum(c) as sum(col1)], groupByKeys = [a])
+     *
+     * <p>The upper project reproduces the outer aggregate's output expressions in the same order
+     * and with the same expression ids, so plans above the rewritten node keep resolving.
+     */
+    private Plan mergeAggProjectAgg(LogicalAggregate<LogicalProject<LogicalAggregate<Plan>>> outerAgg) {
+        LogicalProject<LogicalAggregate<Plan>> project = outerAgg.child();
+        LogicalAggregate<Plan> innerAgg = project.child();
+        List<NamedExpression> projects = project.getProjects();
+        Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = collectInnerAggFunctions(innerAgg);
+
+        // First resolve the outer function arguments through the project so they refer to the
+        // inner aggregate's output, then collapse them onto the inner function,
+        // e.g. max(y) with y = mb and mb = max(b) becomes max(b).
+        List<Expression> outerAggFuncs = outerAgg.getOutputExpressions().stream()
+                .filter(MergeAggregate::isAliasOfAggregateFunction)
+                .collect(Collectors.toList());
+        List<NamedExpression> newAggFuncs = PlanUtils.replaceExpressionByProjections(projects, outerAggFuncs)
+                .stream()
+                .map(expr -> rewriteAggregateFunction((NamedExpression) expr, innerAggExprIdToAggFunc))
+                .collect(Collectors.toList());
+
+        // group by the underlying inner columns; a column exposed under several names is kept once
+        List<Expression> newGroupBy = PlanUtils.replaceExpressionByProjections(projects,
+                outerAgg.getGroupByExpressions()).stream().distinct().collect(Collectors.toList());
+        List<NamedExpression> newOutputExpressions = ImmutableList.<NamedExpression>builder()
+                .addAll(newGroupBy.stream().map(NamedExpression.class::cast).iterator())
+                .addAll(newAggFuncs)
+                .build();
+        LogicalAggregate<Plan> resAgg = outerAgg.withGroupByAndOutput(newGroupBy, newOutputExpressions)
+                .withChildren(innerAgg.children());
+
+        // Re-expose every outer output under its original id, in its original position: aggregate
+        // results come from the new aggregate, group-by keys from the original project expression.
+        Map<ExprId, NamedExpression> newAggFuncById = newAggFuncs.stream()
+                .collect(Collectors.toMap(NamedExpression::getExprId, func -> func, (first, second) -> first));
+        Map<ExprId, NamedExpression> projectById = projects.stream()
+                .collect(Collectors.toMap(NamedExpression::getExprId, expr -> expr, (first, second) -> first));
+        List<NamedExpression> upperProjects = outerAgg.getOutputExpressions().stream()
+                .<NamedExpression>map(output -> {
+                    NamedExpression newAggFunc = newAggFuncById.get(output.getExprId());
+                    if (newAggFunc != null) {
+                        return newAggFunc.toSlot();
+                    }
+                    return projectById.getOrDefault(output.getExprId(), output);
+                })
+                .collect(Collectors.toList());
+        return new LogicalProject<Plan>(upperProjects, resAgg);
+    }
+
+    /**
+     * Replaces an {@code Alias(aggFunc(slot))} whose argument slot is produced by an inner
+     * aggregate function with {@code Alias(innerFunc)}, keeping the alias id and name. Any other
+     * expression is returned unchanged.
+     */
+    private NamedExpression rewriteAggregateFunction(NamedExpression e,
+            Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc) {
+        return (NamedExpression) e.rewriteDownShortCircuit(expr -> {
+            if (!isAliasOfAggregateFunction(expr)) {
+                return expr;
+            }
+            Alias alias = (Alias) expr;
+            AggregateFunction aggFunc = (AggregateFunction) alias.child();
+            if (aggFunc.children().isEmpty() || !(aggFunc.child(0) instanceof SlotReference)) {
+                return expr;
+            }
+            AggregateFunction innerFunc = innerAggExprIdToAggFunc.get(((SlotReference) aggFunc.child(0)).getExprId());
+            return innerFunc == null ? expr : new Alias(alias.getExprId(), innerFunc, alias.getName());
+        });
+    }
+
+    /**
+     * Checks whether every aggregate function of {@code outerAgg} can be computed directly from the
+     * input of {@code innerAgg}, when the two aggregates are directly connected.
+     */
+    boolean commonCheck(LogicalAggregate<? extends Plan> outerAgg, LogicalAggregate<Plan> innerAgg,
+            boolean sameGroupBy) {
+        return commonCheck(outerAgg, innerAgg, sameGroupBy, ImmutableList.of());
+    }
+
+    /**
+     * Same as the three-argument overload, except that the arguments of the outer aggregate
+     * functions are first resolved through {@code projects}, the projection between the two
+     * aggregates (empty when there is none).
+     */
+    private boolean commonCheck(LogicalAggregate<? extends Plan> outerAgg, LogicalAggregate<Plan> innerAgg,
+            boolean sameGroupBy, List<NamedExpression> projects) {
+        Map<ExprId, AggregateFunction> innerAggExprIdToAggFunc = collectInnerAggFunctions(innerAgg);
+        Set<AggregateFunction> aggregateFunctions = outerAgg.getAggregateFunctions();
+        for (AggregateFunction outerFunc : aggregateFunctions) {
+            String outerName = outerFunc.getName();
+            if (!ALLOW_MERGE_AGGREGATE_FUNCTIONS.contains(outerName)) {
+                return false;
+            }
+            if (outerFunc.isDistinct() && !sameGroupBy) {
+                return false;
+            }
+            // not support outerAggFunc: sum(a+1),sum(a+b)
+            if (!(outerFunc.child(0) instanceof SlotReference)) {
+                return false;
+            }
+            // look through the project (if any) to the column actually produced by the inner aggregate
+            Expression argument = PlanUtils.replaceExpressionByProjections(projects,
+                    ImmutableList.of(outerFunc.child(0))).get(0);
+            if (!(argument instanceof SlotReference)) {
+                return false;
+            }
+            AggregateFunction innerFunc = innerAggExprIdToAggFunc.get(((SlotReference) argument).getExprId());
+            if (innerFunc != null) {
+                if (innerFunc.isDistinct() && !sameGroupBy) {
+                    return false;
+                }
+                // support sum(sum),min(min),max(max),any_value(any_value),sum(count)
+                // sum(count) -> count() needs outerAgg to have group by keys (reason: nullable;
+                // over zero rows sum yields NULL while count yields 0)
+                boolean sumOfCount = outerName.equals("sum") && innerFunc.getName().equals("count")
+                        && !outerAgg.getGroupByExpressions().isEmpty();
+                if (!sumOfCount && !innerFunc.getName().equals(outerName)) {
+                    return false;
+                }
+            } else {
+                // select a, max(b), min(b), any_value(b) from (select a,b from t1 group by a, b) group by a;
+                // equals select a, max(b), min(b), any_value(b) from t1 group by a;
+                if (!outerName.equals("max") && !outerName.equals("min") && !outerName.equals("any_value")) {
+                    return false;
+                }
+            }
+        }
+        return true;
+    }
+
+    private boolean canMergeAggregateWithoutProject(LogicalAggregate<LogicalAggregate<Plan>> outerAgg) {
+        LogicalAggregate<Plan> innerAgg = outerAgg.child();
+        Set<Expression> innerGroupBy = new HashSet<>(innerAgg.getGroupByExpressions());
+        Set<Expression> outerGroupBy = new HashSet<>(outerAgg.getGroupByExpressions());
+        if (!innerGroupBy.containsAll(outerGroupBy)) {
+            return false;
+        }
+        // outer keys are a subset of the inner keys, so equal set sizes means identical key sets
+        boolean sameGroupBy = innerGroupBy.size() == outerGroupBy.size();
+        return commonCheck(outerAgg, innerAgg, sameGroupBy);
+    }
+
+    private boolean canMergeAggregateWithProject(LogicalAggregate<LogicalProject<LogicalAggregate<Plan>>> outerAgg) {
+        LogicalProject<LogicalAggregate<Plan>> project = outerAgg.child();
+        LogicalAggregate<Plan> innerAgg = project.child();
+        // project cannot have expressions like a+1
+        if (ExpressionUtils.anyMatch(project.getProjects(),
+                expr -> !(expr instanceof SlotReference) && !(expr instanceof Alias))) {
+            return false;
+        }
+        Set<Expression> innerGroupBy = new HashSet<>(innerAgg.getGroupByExpressions());
+        Set<Expression> outerGroupBy = new HashSet<>(PlanUtils.replaceExpressionByProjections(
+                project.getProjects(), outerAgg.getGroupByExpressions()));
+        if (!innerGroupBy.containsAll(outerGroupBy)) {
+            return false;
+        }
+        // several outer keys may resolve to the same inner column, so compare distinct key sets
+        boolean sameGroupBy = innerGroupBy.size() == outerGroupBy.size();
+        return commonCheck(outerAgg, innerAgg, sameGroupBy, project.getProjects());
+    }
+
+    /** Maps the output id of each inner {@code Alias(aggregateFunction)} to that aggregate function. */
+    private static Map<ExprId, AggregateFunction> collectInnerAggFunctions(LogicalAggregate<Plan> innerAgg) {
+        return innerAgg.getOutputExpressions().stream()
+                .filter(MergeAggregate::isAliasOfAggregateFunction)
+                .collect(Collectors.toMap(NamedExpression::getExprId,
+                        output -> (AggregateFunction) output.child(0),
+                        (existValue, newValue) -> existValue, HashMap::new));
+    }
+
+    /** True for an {@code Alias} whose direct child is an aggregate function, e.g. {@code sum(b) AS col}. */
+    private static boolean isAliasOfAggregateFunction(Expression expr) {
+        return expr instanceof Alias && expr.child(0) instanceof AggregateFunction;
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java
index 3a88020ac94..dac6996c0ca 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java
@@ -165,6 +165,12 @@ public class LogicalUnion extends LogicalSetOperation 
implements Union, OutputPr
                 hasPushedFilter, Optional.empty(), Optional.empty(), children);
     }
 
+    /**
+     * Returns a new union with {@code newOutputs} as its outputs and {@code constantExprsList} as
+     * its constant-expression rows. The qualifier, regular children outputs, pushed-filter flag and
+     * children are copied from this union; the two {@code Optional} constructor arguments are
+     * passed as empty.
+     */
+    public LogicalUnion withNewOutputsAndConstExprsList(List<NamedExpression> newOutputs,
+            List<List<NamedExpression>> constantExprsList) {
+        return new LogicalUnion(
+                qualifier,
+                newOutputs,
+                regularChildrenOutputs,
+                constantExprsList,
+                hasPushedFilter,
+                Optional.empty(),
+                Optional.empty(),
+                children);
+    }
+
     public LogicalUnion withChildrenAndConstExprsList(List<Plan> children,
             List<List<SlotReference>> childrenOutputs, 
List<List<NamedExpression>> constantExprsList) {
         return new LogicalUnion(qualifier, outputs, childrenOutputs, 
constantExprsList, hasPushedFilter, children);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
index 3e8d5cd1d9f..a4e25e21418 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
@@ -100,6 +100,12 @@ public class PlanUtils {
         return ExpressionUtils.replaceNamedExpressions(parentProjects, 
replaceMap);
     }
 
+    /**
+     * Rewrites each expression of {@code targetExpression} by substituting the slots output by
+     * {@code childProjects} with the expressions those projects define, using
+     * {@link ExpressionUtils#generateReplaceMap} and {@link ExpressionUtils#replace}
+     * (e.g. with a child project {@code a AS x}, a reference to {@code x} becomes {@code a}).
+     */
+    public static List<Expression> replaceExpressionByProjections(List<NamedExpression> childProjects,
+            List<Expression> targetExpression) {
+        return ExpressionUtils.replace(targetExpression, ExpressionUtils.generateReplaceMap(childProjects));
+    }
+
     public static Plan skipProjectFilterLimit(Plan plan) {
         if (plan instanceof LogicalProject && ((LogicalProject<?>) 
plan).isAllSlots()
                 || plan instanceof LogicalFilter || plan instanceof 
LogicalLimit) {
diff --git 
a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out 
b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out
new file mode 100644
index 00000000000..ba5b127a56f
--- /dev/null
+++ b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out
@@ -0,0 +1,248 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !sumCount_empty_table --
+\N
+
+-- !maxMax_minMin_sumSum_sumCount --
+1      1       1       1
+2      2       2       1
+6      6       6       1
+7      2       20      5
+8      6       26      4
+8      8       8       1
+9      5       20      3
+
+-- !maxGroupKey_minGroupKey --
+\N     \N      6       6
+1      1       2       1
+2      2       3       3
+3      3       2       1
+4      4       2       2
+5      5       4       3
+7      7       6       6
+
+-- !agg_project_agg --
+\N     \N      \N      6       1
+1      1       1       20      5
+2      2       2       8       1
+3      3       3       20      3
+4      4       4       2       1
+5      5       5       26      4
+7      7       7       1       1
+
+-- !upper_plan_can_use_name --
+2
+3
+7
+8
+9
+9
+10
+
+-- !outer_agg_has_distinct_same_keys --
+1      1       1       1
+2      2       2       1
+4      2       6       2
+6      6       6       1
+6      6       6       1
+6      6       6       1
+7      3       14      3
+8      6       20      3
+8      8       8       1
+9      5       14      2
+
+-- !inner_agg_has_distinct_same_keys --
+1      1       1       1
+2      2       2       1
+4      2       6       2
+6      6       6       1
+6      6       6       1
+6      6       6       1
+7      3       14      3
+8      6       14      3
+8      8       8       1
+9      5       14      2
+
+-- !sumCount_empty_table_shape --
+PhysicalResultSink
+--hashAgg[GLOBAL]
+----PhysicalDistribute[DistributionSpecGather]
+------hashAgg[LOCAL]
+--------PhysicalProject
+----------hashAgg[GLOBAL]
+------------PhysicalDistribute[DistributionSpecHash]
+--------------hashAgg[LOCAL]
+----------------PhysicalProject
+------------------PhysicalOlapScan[mal_test2]
+
+-- !agg_project_agg_shape --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalDistribute[DistributionSpecGather]
+------PhysicalQuickSort[LOCAL_SORT]
+--------PhysicalProject
+----------hashAgg[GLOBAL]
+------------PhysicalDistribute[DistributionSpecHash]
+--------------hashAgg[LOCAL]
+----------------PhysicalProject
+------------------PhysicalOlapScan[mal_test1]
+
+-- !maxMax_minMin_sumSum_sumCount_shape --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalDistribute[DistributionSpecGather]
+------PhysicalQuickSort[LOCAL_SORT]
+--------PhysicalProject
+----------hashAgg[GLOBAL]
+------------PhysicalDistribute[DistributionSpecHash]
+--------------hashAgg[LOCAL]
+----------------PhysicalProject
+------------------PhysicalOlapScan[mal_test1]
+
+-- !maxGroupKey_minGroupKey_shape --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalDistribute[DistributionSpecGather]
+------PhysicalQuickSort[LOCAL_SORT]
+--------PhysicalProject
+----------hashAgg[GLOBAL]
+------------PhysicalDistribute[DistributionSpecHash]
+--------------hashAgg[LOCAL]
+----------------PhysicalProject
+------------------PhysicalOlapScan[mal_test1]
+
+-- !outer_agg_has_distinct_same_keys_shape --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalDistribute[DistributionSpecGather]
+------PhysicalQuickSort[LOCAL_SORT]
+--------PhysicalProject
+----------hashAgg[LOCAL]
+------------PhysicalOlapScan[mal_test1]
+
+-- !inner_agg_has_distinct_same_keys_shape --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalDistribute[DistributionSpecGather]
+------PhysicalQuickSort[LOCAL_SORT]
+--------PhysicalProject
+----------hashAgg[DISTINCT_LOCAL]
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------PhysicalOlapScan[mal_test1]
+
+-- !middle_project_has_expression_cannot_merge_shape1 --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalDistribute[DistributionSpecGather]
+------PhysicalQuickSort[LOCAL_SORT]
+--------PhysicalProject
+----------hashAgg[GLOBAL]
+------------PhysicalDistribute[DistributionSpecHash]
+--------------hashAgg[LOCAL]
+----------------PhysicalProject
+------------------hashAgg[LOCAL]
+--------------------PhysicalProject
+----------------------PhysicalOlapScan[mal_test1]
+
+-- !middle_project_has_expression_cannot_merge_shape2 --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalDistribute[DistributionSpecGather]
+------PhysicalQuickSort[LOCAL_SORT]
+--------PhysicalProject
+----------hashAgg[GLOBAL]
+------------PhysicalDistribute[DistributionSpecHash]
+--------------hashAgg[LOCAL]
+----------------PhysicalProject
+------------------hashAgg[LOCAL]
+--------------------PhysicalOlapScan[mal_test1]
+
+-- !maxGroupKey_minGroupKey_sumGroupKey_cannot_merge_shape --
+PhysicalResultSink
+--PhysicalDistribute[DistributionSpecGather]
+----PhysicalProject
+------hashAgg[GLOBAL]
+--------PhysicalDistribute[DistributionSpecHash]
+----------hashAgg[LOCAL]
+------------hashAgg[LOCAL]
+--------------PhysicalProject
+----------------PhysicalOlapScan[mal_test1]
+
+-- !maxMin_cannot_merge_shape --
+PhysicalResultSink
+--PhysicalDistribute[DistributionSpecGather]
+----PhysicalProject
+------hashAgg[GLOBAL]
+--------PhysicalDistribute[DistributionSpecHash]
+----------hashAgg[LOCAL]
+------------PhysicalProject
+--------------hashAgg[LOCAL]
+----------------PhysicalOlapScan[mal_test1]
+
+-- !group_key_not_contain_cannot_merge_shape --
+PhysicalResultSink
+--PhysicalDistribute[DistributionSpecGather]
+----PhysicalProject
+------hashAgg[GLOBAL]
+--------PhysicalDistribute[DistributionSpecHash]
+----------hashAgg[LOCAL]
+------------PhysicalProject
+--------------hashAgg[LOCAL]
+----------------PhysicalOlapScan[mal_test1]
+
+-- !outer_agg_has_distinct_cannot_merge_shape --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalDistribute[DistributionSpecGather]
+------PhysicalQuickSort[LOCAL_SORT]
+--------PhysicalProject
+----------hashAgg[GLOBAL]
+------------PhysicalDistribute[DistributionSpecHash]
+--------------hashAgg[LOCAL]
+----------------PhysicalProject
+------------------hashAgg[LOCAL]
+--------------------PhysicalOlapScan[mal_test1]
+
+-- !inner_agg_has_distinct_cannot_merge_shape --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalDistribute[DistributionSpecGather]
+------PhysicalQuickSort[LOCAL_SORT]
+--------PhysicalProject
+----------hashAgg[GLOBAL]
+------------PhysicalDistribute[DistributionSpecHash]
+--------------hashAgg[LOCAL]
+----------------PhysicalProject
+------------------hashAgg[DISTINCT_LOCAL]
+--------------------hashAgg[GLOBAL]
+----------------------hashAgg[LOCAL]
+------------------------PhysicalOlapScan[mal_test1]
+
+-- !agg_with_expr_cannot_merge_shape1 --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalDistribute[DistributionSpecGather]
+------PhysicalQuickSort[LOCAL_SORT]
+--------PhysicalProject
+----------hashAgg[GLOBAL]
+------------PhysicalDistribute[DistributionSpecHash]
+--------------hashAgg[LOCAL]
+----------------PhysicalProject
+------------------hashAgg[LOCAL]
+--------------------PhysicalProject
+----------------------PhysicalOlapScan[mal_test1]
+
+-- !agg_with_expr_cannot_merge_shape2 --
+PhysicalResultSink
+--PhysicalQuickSort[MERGE_SORT]
+----PhysicalDistribute[DistributionSpecGather]
+------PhysicalQuickSort[LOCAL_SORT]
+--------PhysicalProject
+----------hashAgg[GLOBAL]
+------------PhysicalDistribute[DistributionSpecHash]
+--------------hashAgg[LOCAL]
+----------------PhysicalProject
+------------------hashAgg[LOCAL]
+--------------------PhysicalProject
+----------------------PhysicalOlapScan[mal_test1]
+
diff --git 
a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy
 
b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy
new file mode 100644
index 00000000000..44c256e2f57
--- /dev/null
+++ 
b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy
@@ -0,0 +1,177 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+suite("merge_aggregate") {
+    // force the Nereids planner, with no fallback to the original planner
+    sql "SET enable_nereids_planner=true"
+    sql "SET enable_fallback_to_original_planner=false"
+    // mal_test1: small data set with duplicate rows and NULLs in a and b
+    sql """
+          DROP TABLE IF EXISTS mal_test1
+         """
+
+    sql """
+         create table mal_test1(pk int, a int, b int) distributed by hash(pk) 
buckets 10
+         properties('replication_num' = '1'); 
+         """
+
+    sql """
+         insert into mal_test1 
values(2,1,3),(1,1,2),(3,5,6),(6,null,6),(4,5,6),(2,1,4),(2,3,5),(1,1,4)
+        
,(3,5,6),(3,5,null),(6,7,1),(2,1,7),(2,4,2),(2,3,9),(1,3,6),(3,5,8),(3,2,8);
+      """
+    // mal_test2: same schema as mal_test1 but intentionally left empty
+    sql "drop table if exists mal_test2"
+    sql """
+        create table mal_test2(pk int, a int, b int) distributed by hash(pk) 
buckets 10
+        properties('replication_num' = '1');
+    """
+
+    sql "sync"
+
+
+    // ---- result checks: rows are compared with the matching sections of merge_aggregate.out ----
+    // sum(count) over an empty table; the expected result is a single NULL row
+    qt_sumCount_empty_table """
+        select sum(col) from (select count(a) col from mal_test2 group by a) t;
+    """
+
+    qt_maxMax_minMin_sumSum_sumCount """
+        select max(col1), min(col2), sum(col3), sum(col4) from 
+        (select pk,a,max(b) as col1, min(b) as col2, sum(b) as col3, count(b) 
as col4 
+        from mal_test1 group by pk,a) t group by a order by 1,2,3,4;
+     """
+
+    qt_maxGroupKey_minGroupKey """
+        select max(a),min(a),max(pk),min(pk) from 
+        (select pk,a from mal_test1 group by pk,a) t 
+        group by a order by 1,2,3,4;
+    """
+
+    qt_agg_project_agg """
+        select col2, max(col2),min(col2),sum(col3),sum(col4) from 
+        (select pk as col1,a as col2,sum(b) col3, count(b) col4 from mal_test1 
group by pk,a) t 
+        group by col2 order by 1,2,3,4;
+    """
+
+    qt_upper_plan_can_use_name """
+        select c1+1 from (
+        select max(col1) c1, min(col2) c2, sum(col3) c3, sum(col4) c4 from 
+        (select pk,a,max(b) as col1, min(b) as col2, sum(b) as col3, count(b) 
as col4 from mal_test1 group by pk,a) t 
+        group by a order by 1,2,3,4) outert order by 1;
+    """
+
+    qt_outer_agg_has_distinct_same_keys """
+        select max(col1), min(col2), sum(col3), sum(DISTINCT col4) from 
+        (select pk,a,max(b) as col1, min(b) as col2, sum(b) as col3, count(b) 
as col4 from mal_test1 group by pk,a) t 
+        group by pk,a order by 1,2,3,4;    
+    """
+
+    qt_inner_agg_has_distinct_same_keys """
+        select max(col1), min(col2), sum(col3), sum(col4) from 
+        (select pk,a,max(b) as col1, min(b) as col2, sum(distinct b) as col3, 
count(b) as col4 from mal_test1 group by pk,a) t 
+        group by a,pk order by 1,2,3,4;
+    """
+
+    // ---- plan-shape checks, compared with the *_shape sections of merge_aggregate.out ----
+    // sum(count) without outer group-by keys: the expected shape keeps two aggregation levels
+    qt_sumCount_empty_table_shape """
+        explain shape plan select sum(col) from (select count(a) col from 
mal_test2 group by a) t;
+    """
+
+    qt_agg_project_agg_shape """
+        explain shape plan select max(col2),min(col2),sum(col3),sum(col4) from 
+        (select pk as col1,a as col2,sum(b) col3, count(b) col4 from mal_test1 
group by pk,a) t 
+        group by col2 order by 1,2,3,4;
+    """
+
+    qt_maxMax_minMin_sumSum_sumCount_shape """
+        explain shape plan select max(col1), min(col2), sum(col3), sum(col4) 
from 
+        (select pk,a,max(b) as col1, min(b) as col2, sum(b) as col3, count(b) 
as col4 
+        from mal_test1 group by pk,a) t group by a order by 1,2,3,4;
+     """
+
+    qt_maxGroupKey_minGroupKey_shape """
+        explain shape plan select max(a),min(a),max(pk),min(pk) from 
+        (select pk,a from mal_test1 group by pk,a) t 
+        group by a order by 1,2,3,4;
+    """
+
+    qt_outer_agg_has_distinct_same_keys_shape """
+        explain shape plan
+        select max(col1), min(col2), sum(col3), sum(DISTINCT col4) from 
+        (select pk,a,max(b) as col1, min(b) as col2, sum(b) as col3, count(b) 
as col4 from mal_test1 group by pk,a) t 
+        group by pk,a order by 1,2,3,4;    
+    """
+
+    qt_inner_agg_has_distinct_same_keys_shape """
+        explain shape plan
+        select max(col1), min(col2), sum(col3), sum(col4) from 
+        (select pk,a,max(b) as col1, min(b) as col2, sum(distinct b) as col3, 
count(b) as col4 from mal_test1 group by pk,a) t 
+        group by a,pk order by 1,2,3,4;
+    """
+
+    // ---- negative cases: the expected shapes still contain the inner aggregate ----
+    qt_middle_project_has_expression_cannot_merge_shape1 """
+        explain shape plan 
+        select max(col1),min(col1) from 
+        (select pk+1 as col1,a from mal_test1 group by pk,a) t 
+        group by a order by 1,2;
+    """
+
+    qt_middle_project_has_expression_cannot_merge_shape2 """
+        explain shape plan
+        select max(col1), min(col2), sum(col3), sum(col4) from
+        (select pk,a,max(b)+1 as col1, min(b) as col2, sum(b) as col3, 
count(b) as col4 from mal_test1 group by pk,a) t
+        group by a order by 1,2,3,4;
+    """
+
+    qt_maxGroupKey_minGroupKey_sumGroupKey_cannot_merge_shape """
+        explain shape plan select max(a),min(a),max(pk),min(pk),sum(pk) from 
+        (select pk,a from mal_test1 group by pk,a) t 
+        group by a;
+    """
+
+    qt_maxMin_cannot_merge_shape """
+        explain shape plan select max(col), max(col2) from 
+        (select pk,a,min(b) col,max(b) col2 from mal_test1 group by pk,a) t 
+        group by a;
+    """
+
+    qt_group_key_not_contain_cannot_merge_shape """
+        explain shape plan select max(col2) from 
+        (select pk,a,max(b) col2 from mal_test1 group by pk,a) t 
+        group by a,col2;
+    """
+
+    qt_outer_agg_has_distinct_cannot_merge_shape """
+        explain shape plan
+        select max(col1), min(col2), sum(col3), sum(DISTINCT col4) from 
+        (select pk,a,max(b) as col1, min(b) as col2, sum(b) as col3, count(b) 
as col4 from mal_test1 group by pk,a) t 
+        group by a order by 1,2,3,4;    
+    """
+
+    qt_inner_agg_has_distinct_cannot_merge_shape """
+        explain shape plan
+        select max(col1), min(col2), sum(col3), sum(col4) from 
+        (select pk,a,max(b) as col1, min(b) as col2, sum(distinct b) as col3, 
count(b) as col4 from mal_test1 group by pk,a) t 
+        group by a order by 1,2,3,4;
+    """
+
+    qt_agg_with_expr_cannot_merge_shape1 """
+        explain shape plan select max(col1+a),min(col1) from 
+        (select pk as col1, a from mal_test1 group by pk,a) t 
+        group by a order by 1,2;
+    """
+
+    qt_agg_with_expr_cannot_merge_shape2 """
+        explain shape plan select max(col1+1),min(col1) from 
+        (select pk as col1, a from mal_test1 group by pk,a) t 
+        group by a order by 1,2;
+    """
+
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to