morrySnow commented on code in PR #12159:
URL: https://github.com/apache/doris/pull/12159#discussion_r957976831


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -54,86 +56,192 @@
  */
 public class AggregateDisassemble extends OneRewriteRuleFactory {
 
-    @Override
-    public Rule build() {
-        return logicalAggregate().when(agg -> 
!agg.isDisassembled()).thenApply(ctx -> {
-            LogicalAggregate<GroupPlan> aggregate = ctx.root;
-            List<NamedExpression> originOutputExprs = 
aggregate.getOutputExpressions();
-            List<Expression> originGroupByExprs = 
aggregate.getGroupByExpressions();
-
-            // 1. generate a map from local aggregate output to global 
aggregate expr substitution.
-            //    inputSubstitutionMap use for replacing expression in global 
aggregate
-            //    replace rule is:
-            //        a: Expression is a group by key and is a slot reference. 
e.g. group by k1
-            //        b. Expression is a group by key and is an expression. 
e.g. group by k1 + 1
-            //        c. Expression is an aggregate function. e.g. sum(v1) in 
select list
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | situation | origin expression   | local output expression 
| expression in global aggregate |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | a         | Ref(k1)#1           | Ref(k1)#1               
| Ref(k1)#1                      |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | b         | Ref(k1)#1 + 1       | A(Ref(k1)#1 + 1, key)#2 
| Ref(key)#2                     |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | c         | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3     
| AF(af#3)                       |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, 
#x: ExprId x
-            // 2. collect local aggregate output expressions and local 
aggregate group by expression list
-            Map<Expression, Expression> inputSubstitutionMap = 
Maps.newHashMap();
-            List<Expression> localGroupByExprs = 
aggregate.getGroupByExpressions();
-            List<NamedExpression> localOutputExprs = Lists.newArrayList();
-            for (Expression originGroupByExpr : originGroupByExprs) {
-                if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+    private Map<Expression, Expression> inputSubstitutionMap = 
Maps.newHashMap();
+    private List<NamedExpression> distinctAggFunctionParams = new 
ArrayList<>();
+    private List<AggregateFunction> distinctAggFunctions = new ArrayList<>();
+    private List<NamedExpression> distinctOriginOutputExprs = new 
ArrayList<>();

Review Comment:
   ```suggestion
       private final Map<Expression, Expression> inputSubstitutionMap = 
Maps.newHashMap();
       private final List<NamedExpression> distinctAggFunctionParams = 
Lists.newArrayList();
       private final List<AggregateFunction> distinctAggFunctions = 
Lists.newArrayList();
       private final List<NamedExpression> distinctOriginOutputExprs = 
Lists.newArrayList();
   ```



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -54,86 +56,192 @@
  */
 public class AggregateDisassemble extends OneRewriteRuleFactory {
 
-    @Override
-    public Rule build() {
-        return logicalAggregate().when(agg -> 
!agg.isDisassembled()).thenApply(ctx -> {
-            LogicalAggregate<GroupPlan> aggregate = ctx.root;
-            List<NamedExpression> originOutputExprs = 
aggregate.getOutputExpressions();
-            List<Expression> originGroupByExprs = 
aggregate.getGroupByExpressions();
-
-            // 1. generate a map from local aggregate output to global 
aggregate expr substitution.
-            //    inputSubstitutionMap use for replacing expression in global 
aggregate
-            //    replace rule is:
-            //        a: Expression is a group by key and is a slot reference. 
e.g. group by k1
-            //        b. Expression is a group by key and is an expression. 
e.g. group by k1 + 1
-            //        c. Expression is an aggregate function. e.g. sum(v1) in 
select list
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | situation | origin expression   | local output expression 
| expression in global aggregate |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | a         | Ref(k1)#1           | Ref(k1)#1               
| Ref(k1)#1                      |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | b         | Ref(k1)#1 + 1       | A(Ref(k1)#1 + 1, key)#2 
| Ref(key)#2                     |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | c         | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3     
| AF(af#3)                       |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, 
#x: ExprId x
-            // 2. collect local aggregate output expressions and local 
aggregate group by expression list
-            Map<Expression, Expression> inputSubstitutionMap = 
Maps.newHashMap();
-            List<Expression> localGroupByExprs = 
aggregate.getGroupByExpressions();
-            List<NamedExpression> localOutputExprs = Lists.newArrayList();
-            for (Expression originGroupByExpr : originGroupByExprs) {
-                if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+    private Map<Expression, Expression> inputSubstitutionMap = 
Maps.newHashMap();
+    private List<NamedExpression> distinctAggFunctionParams = new 
ArrayList<>();
+    private List<AggregateFunction> distinctAggFunctions = new ArrayList<>();
+    private List<NamedExpression> distinctOriginOutputExprs = new 
ArrayList<>();
+
+    // only support distinct function with group by
+    // TODO: support distinct function without group by. (add second global 
phase)
+    private LogicalAggregate secondDisassemble(LogicalAggregate<GroupPlan> 
aggregate) {
+        // origin sql: select count(distinct a) from t1 group by b;
+        // global agg: select a from t1 group by b, a;
+        // second local agg: select count(distinct a) from t1 group by b;
+        // In order to get the second local agg from global agg:
+        // 1. the distinct expression needs to be removed from the output and 
the group by of global agg
+        // 2. add distinct agg function back to output
+        List<NamedExpression> originFirstGlobalOutputExprs = 
aggregate.getOutputExpressions();
+        List<Expression> originFirstGlobalGroupByExprs = 
aggregate.getGroupByExpressions();
+        List<NamedExpression> secondLocalOutputExprs = new 
ArrayList<>(aggregate.getOutputExpressions());
+        // add origin distinct function back
+        secondLocalOutputExprs.addAll(distinctOriginOutputExprs);
+        List<Expression> secondLocalGroupByExprs = new 
ArrayList<>(aggregate.getGroupByExpressions());
+
+        List<Expression> distinctExprs = new ArrayList<>();
+        // remove distinct param exprs from secondLocalGroupByExprs and 
secondLocalOutputExprs
+        for (NamedExpression expression : distinctAggFunctionParams) {
+            
secondLocalGroupByExprs.remove(inputSubstitutionMap.get(expression));
+            
secondLocalOutputExprs.remove(inputSubstitutionMap.get(expression));
+            distinctExprs.add(inputSubstitutionMap.get(expression));
+        }
+
+        Map<Expression, Expression> secondSubstitutionMap = Maps.newHashMap();
+
+        // replace the original slot reference with the latest one
+        Expression distinctAgg = 
distinctAggFunctions.get(0).withChildren(distinctExprs);
+        secondSubstitutionMap.put(distinctAggFunctions.get(0), distinctAgg);
+        List<NamedExpression> secondLocalOutputNamedExprs = 
secondLocalOutputExprs.stream()
+                .map(e -> ExpressionReplacer.INSTANCE.visit(e, 
secondSubstitutionMap))
+                .map(NamedExpression.class::cast)
+                .collect(Collectors.toList());
+
+        // firstGlobalOutputExprs = originOutputExprs + originGroupByExprs
+        List<NamedExpression> firstGlobalOutputExprs = new 
ArrayList<>(originFirstGlobalOutputExprs);
+        for (Expression originGroupByExpr : originFirstGlobalGroupByExprs) {
+            if (firstGlobalOutputExprs.contains(originGroupByExpr)) {
+                continue;
+            }
+            if (originGroupByExpr instanceof SlotReference) {
+                firstGlobalOutputExprs.add((SlotReference) originGroupByExpr);
+            } else {
+                Preconditions.checkState(false);
+            }
+        }
+
+        // generate new plan
+        LogicalAggregate globalAggregate = new LogicalAggregate<>(
+                originFirstGlobalGroupByExprs,
+                firstGlobalOutputExprs,
+                true,
+                aggregate.isNormalized(),
+                AggPhase.GLOBAL,
+                aggregate.child()
+        );
+
+        return new LogicalAggregate<>(
+                secondLocalGroupByExprs,
+                secondLocalOutputNamedExprs,
+                true,
+                aggregate.isNormalized(),
+                AggPhase.DISTINCT_LOCAL,
+                globalAggregate
+        );
+    }
+
+    private LogicalAggregate firstDisassemble(LogicalAggregate<GroupPlan> 
aggregate) {
+        List<NamedExpression> originOutputExprs = 
aggregate.getOutputExpressions();
+        List<Expression> originGroupByExprs = 
aggregate.getGroupByExpressions();
+        // 1. generate a map from local aggregate output to global aggregate 
expr substitution.
+        //    inputSubstitutionMap use for replacing expression in global 
aggregate
+        //    replace rule is:
+        //        a: Expression is a group by key and is a slot reference. 
e.g. group by k1
+        //        b. Expression is a group by key and is an expression. e.g. 
group by k1 + 1
+        //        c. Expression is an aggregate function. e.g. sum(v1) in 
select list
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    | situation | origin expression   | local output expression | 
expression in global aggregate |
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    | a         | Ref(k1)#1           | Ref(k1)#1               | 
Ref(k1)#1                      |
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    | b         | Ref(k1)#1 + 1       | A(Ref(k1)#1 + 1, key)#2 | 
Ref(key)#2                     |
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    | c         | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3     | 
AF(af#3)                       |
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x: 
ExprId x
+        // 2. collect local aggregate output expressions and local aggregate 
group by expression list
+        List<Expression> localGroupByExprs = aggregate.getGroupByExpressions();
+        List<NamedExpression> localOutputExprs = Lists.newArrayList();
+        for (Expression originGroupByExpr : originGroupByExprs) {
+            if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+                continue;
+            }
+            if (originGroupByExpr instanceof SlotReference) {
+                inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
+                localOutputExprs.add((SlotReference) originGroupByExpr);
+            } else {
+                NamedExpression localOutputExpr = new Alias(originGroupByExpr, 
originGroupByExpr.toSql());
+                inputSubstitutionMap.put(originGroupByExpr, 
localOutputExpr.toSlot());
+                localOutputExprs.add(localOutputExpr);
+            }
+        }
+        for (NamedExpression originOutputExpr : originOutputExprs) {
+            List<AggregateFunction> aggregateFunctions
+                    = 
originOutputExpr.collect(AggregateFunction.class::isInstance);
+            for (AggregateFunction aggregateFunction : aggregateFunctions) {
+                if (inputSubstitutionMap.containsKey(aggregateFunction)) {
                     continue;
                 }
-                if (originGroupByExpr instanceof SlotReference) {
-                    inputSubstitutionMap.put(originGroupByExpr, 
originGroupByExpr);
-                    localOutputExprs.add((SlotReference) originGroupByExpr);
-                } else {
-                    NamedExpression localOutputExpr = new 
Alias(originGroupByExpr, originGroupByExpr.toSql());
-                    inputSubstitutionMap.put(originGroupByExpr, 
localOutputExpr.toSlot());
-                    localOutputExprs.add(localOutputExpr);
-                }
+                NamedExpression localOutputExpr = new Alias(aggregateFunction, 
aggregateFunction.toSql());
+                Expression substitutionValue = aggregateFunction.withChildren(
+                        Lists.newArrayList(localOutputExpr.toSlot()));
+                inputSubstitutionMap.put(aggregateFunction, substitutionValue);
+                localOutputExprs.add(localOutputExpr);
             }
-            for (NamedExpression originOutputExpr : originOutputExprs) {
-                List<AggregateFunction> aggregateFunctions
-                        = 
originOutputExpr.collect(AggregateFunction.class::isInstance);
-                for (AggregateFunction aggregateFunction : aggregateFunctions) 
{
-                    if (inputSubstitutionMap.containsKey(aggregateFunction)) {
-                        continue;
-                    }
-                    NamedExpression localOutputExpr = new 
Alias(aggregateFunction, aggregateFunction.toSql());
-                    Expression substitutionValue = 
aggregateFunction.withChildren(
-                            Lists.newArrayList(localOutputExpr.toSlot()));
-                    inputSubstitutionMap.put(aggregateFunction, 
substitutionValue);
-                    localOutputExprs.add(localOutputExpr);
+        }
+
+        // 3. replace expression in globalOutputExprs and globalGroupByExprs
+        List<NamedExpression> globalOutputExprs = 
aggregate.getOutputExpressions().stream()
+                .map(e -> ExpressionReplacer.INSTANCE.visit(e, 
inputSubstitutionMap))
+                .map(NamedExpression.class::cast)
+                .collect(Collectors.toList());
+        List<Expression> globalGroupByExprs = localGroupByExprs.stream()
+                .map(e -> ExpressionReplacer.INSTANCE.visit(e, 
inputSubstitutionMap)).collect(Collectors.toList());
+
+        // 4. generate new plan
+        LogicalAggregate localAggregate = new LogicalAggregate<>(
+                localGroupByExprs,
+                localOutputExprs,
+                true,
+                aggregate.isNormalized(),
+                AggPhase.LOCAL,
+                aggregate.child()
+        );
+        return new LogicalAggregate<>(
+                globalGroupByExprs,
+                globalOutputExprs,
+                true,
+                aggregate.isNormalized(),
+                AggPhase.GLOBAL,
+                localAggregate
+        );
+    }
+
+    private void 
moveDistinctExprFromOutputToGroupBy(LogicalAggregate<GroupPlan> aggregate) {
+        // for example:
+        // select count(distinct a) from t1 group by b;
+        // => select a from t1 group by b, a;
+        List<NamedExpression> originOutputExprs = 
aggregate.getOutputExpressions();
+        List<Expression> originGroupByExprs = 
aggregate.getGroupByExpressions();
+
+        for (NamedExpression originOutputExpr : originOutputExprs) {
+            List<AggregateFunction> aggregateFunctions =
+                    
originOutputExpr.collect(AggregateFunction.class::isInstance);
+            for (AggregateFunction aggregateFunction : aggregateFunctions) {
+                if (aggregateFunction.isDistinct()) {
+                    distinctAggFunctions.add(aggregateFunction);
+                    distinctOriginOutputExprs.add(originOutputExpr);
                 }
             }
+        }
+        if (!distinctAggFunctions.isEmpty()) {
+            for (Expression expr : distinctAggFunctions.get(0).children()) {

Review Comment:
   Why is only get(0) used here? Please add a comment explaining it.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java:
##########
@@ -63,6 +63,16 @@ public class PlanTranslatorContext {
 
     private final IdGenerator<PlanNodeId> nodeIdGenerator = 
PlanNodeId.createGenerator();
 
+    private boolean hasDistinctAgg = false;
+
+    public boolean hasDistinctAgg() {

Review Comment:
   We should not use a global flag for this. Instead, we should put it into 
PhysicalAggregate, or just return true if any AggregateFunction in the Aggregate 
is distinct.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -230,6 +246,13 @@ public PlanFragment 
visitPhysicalAggregate(PhysicalAggregate<Plan> aggregate, Pl
             case GLOBAL:
                 inputPlanFragment.updateDataPartition(mergePartition);
                 return inputPlanFragment;
+            case DISTINCT_LOCAL:
+                AggregationNode globalAggNode = (AggregationNode) 
aggregationNode.getChild(0);
+                globalAggNode.unsetNeedsFinalize();
+                globalAggNode.setIntermediateTuple();
+                inputPlanFragment.updateDataPartition(mergePartition);

Review Comment:
   When we do a distinct aggregation, the partition exprs are different from the 
group by exprs. Where is this handled?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -54,86 +56,192 @@
  */
 public class AggregateDisassemble extends OneRewriteRuleFactory {
 
-    @Override
-    public Rule build() {
-        return logicalAggregate().when(agg -> 
!agg.isDisassembled()).thenApply(ctx -> {
-            LogicalAggregate<GroupPlan> aggregate = ctx.root;
-            List<NamedExpression> originOutputExprs = 
aggregate.getOutputExpressions();
-            List<Expression> originGroupByExprs = 
aggregate.getGroupByExpressions();
-
-            // 1. generate a map from local aggregate output to global 
aggregate expr substitution.
-            //    inputSubstitutionMap use for replacing expression in global 
aggregate
-            //    replace rule is:
-            //        a: Expression is a group by key and is a slot reference. 
e.g. group by k1
-            //        b. Expression is a group by key and is an expression. 
e.g. group by k1 + 1
-            //        c. Expression is an aggregate function. e.g. sum(v1) in 
select list
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | situation | origin expression   | local output expression 
| expression in global aggregate |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | a         | Ref(k1)#1           | Ref(k1)#1               
| Ref(k1)#1                      |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | b         | Ref(k1)#1 + 1       | A(Ref(k1)#1 + 1, key)#2 
| Ref(key)#2                     |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | c         | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3     
| AF(af#3)                       |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, 
#x: ExprId x
-            // 2. collect local aggregate output expressions and local 
aggregate group by expression list
-            Map<Expression, Expression> inputSubstitutionMap = 
Maps.newHashMap();
-            List<Expression> localGroupByExprs = 
aggregate.getGroupByExpressions();
-            List<NamedExpression> localOutputExprs = Lists.newArrayList();
-            for (Expression originGroupByExpr : originGroupByExprs) {
-                if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+    private Map<Expression, Expression> inputSubstitutionMap = 
Maps.newHashMap();
+    private List<NamedExpression> distinctAggFunctionParams = new 
ArrayList<>();
+    private List<AggregateFunction> distinctAggFunctions = new ArrayList<>();
+    private List<NamedExpression> distinctOriginOutputExprs = new 
ArrayList<>();
+
+    // only support distinct function with group by
+    // TODO: support distinct function without group by. (add second global 
phase)
+    private LogicalAggregate secondDisassemble(LogicalAggregate<GroupPlan> 
aggregate) {
+        // origin sql: select count(distinct a) from t1 group by b;
+        // global agg: select a from t1 group by b, a;
+        // second local agg: select count(distinct a) from t1 group by b;
+        // In order to get the second local agg from global agg:
+        // 1. the distinct expression needs to be removed from the output and 
the group by of global agg
+        // 2. add distinct agg function back to output
+        List<NamedExpression> originFirstGlobalOutputExprs = 
aggregate.getOutputExpressions();

Review Comment:
   We'd better use the same phase names as AggPhase, i.e.:
   - LOCAL
   - GLOBAL
   - DISTINCT_LOCAL
   - DISTINCT_GLOBAL



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -54,86 +56,192 @@
  */
 public class AggregateDisassemble extends OneRewriteRuleFactory {
 
-    @Override
-    public Rule build() {
-        return logicalAggregate().when(agg -> 
!agg.isDisassembled()).thenApply(ctx -> {
-            LogicalAggregate<GroupPlan> aggregate = ctx.root;
-            List<NamedExpression> originOutputExprs = 
aggregate.getOutputExpressions();
-            List<Expression> originGroupByExprs = 
aggregate.getGroupByExpressions();
-
-            // 1. generate a map from local aggregate output to global 
aggregate expr substitution.
-            //    inputSubstitutionMap use for replacing expression in global 
aggregate
-            //    replace rule is:
-            //        a: Expression is a group by key and is a slot reference. 
e.g. group by k1
-            //        b. Expression is a group by key and is an expression. 
e.g. group by k1 + 1
-            //        c. Expression is an aggregate function. e.g. sum(v1) in 
select list
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | situation | origin expression   | local output expression 
| expression in global aggregate |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | a         | Ref(k1)#1           | Ref(k1)#1               
| Ref(k1)#1                      |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | b         | Ref(k1)#1 + 1       | A(Ref(k1)#1 + 1, key)#2 
| Ref(key)#2                     |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | c         | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3     
| AF(af#3)                       |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, 
#x: ExprId x
-            // 2. collect local aggregate output expressions and local 
aggregate group by expression list
-            Map<Expression, Expression> inputSubstitutionMap = 
Maps.newHashMap();
-            List<Expression> localGroupByExprs = 
aggregate.getGroupByExpressions();
-            List<NamedExpression> localOutputExprs = Lists.newArrayList();
-            for (Expression originGroupByExpr : originGroupByExprs) {
-                if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+    private Map<Expression, Expression> inputSubstitutionMap = 
Maps.newHashMap();
+    private List<NamedExpression> distinctAggFunctionParams = new 
ArrayList<>();
+    private List<AggregateFunction> distinctAggFunctions = new ArrayList<>();
+    private List<NamedExpression> distinctOriginOutputExprs = new 
ArrayList<>();
+
+    // only support distinct function with group by
+    // TODO: support distinct function without group by. (add second global 
phase)
+    private LogicalAggregate secondDisassemble(LogicalAggregate<GroupPlan> 
aggregate) {
+        // origin sql: select count(distinct a) from t1 group by b;
+        // global agg: select a from t1 group by b, a;
+        // second local agg: select count(distinct a) from t1 group by b;
+        // In order to get the second local agg from global agg:
+        // 1. the distinct expression needs to be removed from the output and 
the group by of global agg
+        // 2. add distinct agg function back to output
+        List<NamedExpression> originFirstGlobalOutputExprs = 
aggregate.getOutputExpressions();
+        List<Expression> originFirstGlobalGroupByExprs = 
aggregate.getGroupByExpressions();
+        List<NamedExpression> secondLocalOutputExprs = new 
ArrayList<>(aggregate.getOutputExpressions());
+        // add origin distinct function back
+        secondLocalOutputExprs.addAll(distinctOriginOutputExprs);
+        List<Expression> secondLocalGroupByExprs = new 
ArrayList<>(aggregate.getGroupByExpressions());
+
+        List<Expression> distinctExprs = new ArrayList<>();
+        // remove distinct param exprs from secondLocalGroupByExprs and 
secondLocalOutputExprs
+        for (NamedExpression expression : distinctAggFunctionParams) {
+            
secondLocalGroupByExprs.remove(inputSubstitutionMap.get(expression));
+            
secondLocalOutputExprs.remove(inputSubstitutionMap.get(expression));

Review Comment:
   secondLocalOutputExprs contains NamedExpressions, such as Alias and 
SlotReference, but I don't think inputSubstitutionMap has the same Alias as a key.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -54,86 +56,192 @@
  */
 public class AggregateDisassemble extends OneRewriteRuleFactory {
 
-    @Override
-    public Rule build() {
-        return logicalAggregate().when(agg -> 
!agg.isDisassembled()).thenApply(ctx -> {
-            LogicalAggregate<GroupPlan> aggregate = ctx.root;
-            List<NamedExpression> originOutputExprs = 
aggregate.getOutputExpressions();
-            List<Expression> originGroupByExprs = 
aggregate.getGroupByExpressions();
-
-            // 1. generate a map from local aggregate output to global 
aggregate expr substitution.
-            //    inputSubstitutionMap use for replacing expression in global 
aggregate
-            //    replace rule is:
-            //        a: Expression is a group by key and is a slot reference. 
e.g. group by k1
-            //        b. Expression is a group by key and is an expression. 
e.g. group by k1 + 1
-            //        c. Expression is an aggregate function. e.g. sum(v1) in 
select list
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | situation | origin expression   | local output expression 
| expression in global aggregate |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | a         | Ref(k1)#1           | Ref(k1)#1               
| Ref(k1)#1                      |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | b         | Ref(k1)#1 + 1       | A(Ref(k1)#1 + 1, key)#2 
| Ref(key)#2                     |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | c         | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3     
| AF(af#3)                       |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, 
#x: ExprId x
-            // 2. collect local aggregate output expressions and local 
aggregate group by expression list
-            Map<Expression, Expression> inputSubstitutionMap = 
Maps.newHashMap();
-            List<Expression> localGroupByExprs = 
aggregate.getGroupByExpressions();
-            List<NamedExpression> localOutputExprs = Lists.newArrayList();
-            for (Expression originGroupByExpr : originGroupByExprs) {
-                if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+    private Map<Expression, Expression> inputSubstitutionMap = 
Maps.newHashMap();
+    private List<NamedExpression> distinctAggFunctionParams = new 
ArrayList<>();
+    private List<AggregateFunction> distinctAggFunctions = new ArrayList<>();
+    private List<NamedExpression> distinctOriginOutputExprs = new 
ArrayList<>();
+
+    // only support distinct function with group by
+    // TODO: support distinct function without group by. (add second global 
phase)
+    private LogicalAggregate secondDisassemble(LogicalAggregate<GroupPlan> 
aggregate) {
+        // origin sql: select count(distinct a) from t1 group by b;
+        // global agg: select a from t1 group by b, a;
+        // second local agg: select count(distinct a) from t1 group by b;
+        // In order to get the second local agg from global agg:
+        // 1. the distinct expression needs to be removed from the output and 
the group by of global agg
+        // 2. add distinct agg function back to output
+        List<NamedExpression> originFirstGlobalOutputExprs = 
aggregate.getOutputExpressions();
+        List<Expression> originFirstGlobalGroupByExprs = 
aggregate.getGroupByExpressions();
+        List<NamedExpression> secondLocalOutputExprs = new 
ArrayList<>(aggregate.getOutputExpressions());
+        // add origin distinct function back
+        secondLocalOutputExprs.addAll(distinctOriginOutputExprs);

Review Comment:
   secondLocalOutputExprs is generated from the aggregate's output expressions, so 
I think it already contains the original distinct function?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -54,86 +56,192 @@
  */
 public class AggregateDisassemble extends OneRewriteRuleFactory {

Review Comment:
   Please update this class's Javadoc comment to also describe the distinct case.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -54,86 +56,192 @@
  */
 public class AggregateDisassemble extends OneRewriteRuleFactory {
 
-    @Override
-    public Rule build() {
-        return logicalAggregate().when(agg -> 
!agg.isDisassembled()).thenApply(ctx -> {
-            LogicalAggregate<GroupPlan> aggregate = ctx.root;
-            List<NamedExpression> originOutputExprs = 
aggregate.getOutputExpressions();
-            List<Expression> originGroupByExprs = 
aggregate.getGroupByExpressions();
-
-            // 1. generate a map from local aggregate output to global 
aggregate expr substitution.
-            //    inputSubstitutionMap use for replacing expression in global 
aggregate
-            //    replace rule is:
-            //        a: Expression is a group by key and is a slot reference. 
e.g. group by k1
-            //        b. Expression is a group by key and is an expression. 
e.g. group by k1 + 1
-            //        c. Expression is an aggregate function. e.g. sum(v1) in 
select list
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | situation | origin expression   | local output expression 
| expression in global aggregate |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | a         | Ref(k1)#1           | Ref(k1)#1               
| Ref(k1)#1                      |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | b         | Ref(k1)#1 + 1       | A(Ref(k1)#1 + 1, key)#2 
| Ref(key)#2                     |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | c         | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3     
| AF(af#3)                       |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, 
#x: ExprId x
-            // 2. collect local aggregate output expressions and local 
aggregate group by expression list
-            Map<Expression, Expression> inputSubstitutionMap = 
Maps.newHashMap();
-            List<Expression> localGroupByExprs = 
aggregate.getGroupByExpressions();
-            List<NamedExpression> localOutputExprs = Lists.newArrayList();
-            for (Expression originGroupByExpr : originGroupByExprs) {
-                if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+    private Map<Expression, Expression> inputSubstitutionMap = 
Maps.newHashMap();
+    private List<NamedExpression> distinctAggFunctionParams = new 
ArrayList<>();
+    private List<AggregateFunction> distinctAggFunctions = new ArrayList<>();
+    private List<NamedExpression> distinctOriginOutputExprs = new 
ArrayList<>();
+
+    // only support distinct function with group by
+    // TODO: support distinct function without group by. (add second global 
phase)
+    private LogicalAggregate secondDisassemble(LogicalAggregate<GroupPlan> 
aggregate) {
+        // origin sql: select count(distinct a) from t1 group by b;
+        // global agg: select a from t1 group by b, a;
+        // second local agg: select count(distinct a) from t1 group by b;
+        // In order to get the second local agg from global agg:
+        // 1. the distinct expression needs to be removed from the output and 
the group by of global agg
+        // 2. add distinct agg function back to output
+        List<NamedExpression> originFirstGlobalOutputExprs = 
aggregate.getOutputExpressions();
+        List<Expression> originFirstGlobalGroupByExprs = 
aggregate.getGroupByExpressions();
+        List<NamedExpression> secondLocalOutputExprs = new 
ArrayList<>(aggregate.getOutputExpressions());
+        // add origin distinct function back
+        secondLocalOutputExprs.addAll(distinctOriginOutputExprs);
+        List<Expression> secondLocalGroupByExprs = new 
ArrayList<>(aggregate.getGroupByExpressions());
+
+        List<Expression> distinctExprs = new ArrayList<>();
+        // remove distinct param exprs from secondLocalGroupByExprs and 
secondLocalOutputExprs
+        for (NamedExpression expression : distinctAggFunctionParams) {
+            
secondLocalGroupByExprs.remove(inputSubstitutionMap.get(expression));
+            
secondLocalOutputExprs.remove(inputSubstitutionMap.get(expression));
+            distinctExprs.add(inputSubstitutionMap.get(expression));
+        }
+
+        Map<Expression, Expression> secondSubstitutionMap = Maps.newHashMap();
+
+        // replace the original slot reference with the latest one
+        Expression distinctAgg = 
distinctAggFunctions.get(0).withChildren(distinctExprs);
+        secondSubstitutionMap.put(distinctAggFunctions.get(0), distinctAgg);
+        List<NamedExpression> secondLocalOutputNamedExprs = 
secondLocalOutputExprs.stream()
+                .map(e -> ExpressionReplacer.INSTANCE.visit(e, 
secondSubstitutionMap))
+                .map(NamedExpression.class::cast)
+                .collect(Collectors.toList());
+
+        // firstGlobalOutputExprs = originOutputExprs + originGroupByExprs
+        List<NamedExpression> firstGlobalOutputExprs = new 
ArrayList<>(originFirstGlobalOutputExprs);
+        for (Expression originGroupByExpr : originFirstGlobalGroupByExprs) {
+            if (firstGlobalOutputExprs.contains(originGroupByExpr)) {
+                continue;
+            }
+            if (originGroupByExpr instanceof SlotReference) {
+                firstGlobalOutputExprs.add((SlotReference) originGroupByExpr);
+            } else {
+                Preconditions.checkState(false);
+            }
+        }
+
+        // generate new plan
+        LogicalAggregate globalAggregate = new LogicalAggregate<>(
+                originFirstGlobalGroupByExprs,
+                firstGlobalOutputExprs,
+                true,
+                aggregate.isNormalized(),
+                AggPhase.GLOBAL,
+                aggregate.child()
+        );
+
+        return new LogicalAggregate<>(
+                secondLocalGroupByExprs,
+                secondLocalOutputNamedExprs,
+                true,
+                aggregate.isNormalized(),
+                AggPhase.DISTINCT_LOCAL,
+                globalAggregate
+        );
+    }
+
+    private LogicalAggregate firstDisassemble(LogicalAggregate<GroupPlan> 
aggregate) {
+        List<NamedExpression> originOutputExprs = 
aggregate.getOutputExpressions();
+        List<Expression> originGroupByExprs = 
aggregate.getGroupByExpressions();
+        // 1. generate a map from local aggregate output to global aggregate 
expr substitution.
+        //    inputSubstitutionMap use for replacing expression in global 
aggregate
+        //    replace rule is:
+        //        a: Expression is a group by key and is a slot reference. 
e.g. group by k1
+        //        b. Expression is a group by key and is an expression. e.g. 
group by k1 + 1
+        //        c. Expression is an aggregate function. e.g. sum(v1) in 
select list
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    | situation | origin expression   | local output expression | 
expression in global aggregate |
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    | a         | Ref(k1)#1           | Ref(k1)#1               | 
Ref(k1)#1                      |
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    | b         | Ref(k1)#1 + 1       | A(Ref(k1)#1 + 1, key)#2 | 
Ref(key)#2                     |
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    | c         | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3     | 
AF(af#3)                       |
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x: 
ExprId x
+        // 2. collect local aggregate output expressions and local aggregate 
group by expression list
+        List<Expression> localGroupByExprs = aggregate.getGroupByExpressions();
+        List<NamedExpression> localOutputExprs = Lists.newArrayList();
+        for (Expression originGroupByExpr : originGroupByExprs) {
+            if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+                continue;
+            }
+            if (originGroupByExpr instanceof SlotReference) {
+                inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
+                localOutputExprs.add((SlotReference) originGroupByExpr);
+            } else {
+                NamedExpression localOutputExpr = new Alias(originGroupByExpr, 
originGroupByExpr.toSql());
+                inputSubstitutionMap.put(originGroupByExpr, 
localOutputExpr.toSlot());
+                localOutputExprs.add(localOutputExpr);
+            }
+        }
+        for (NamedExpression originOutputExpr : originOutputExprs) {
+            List<AggregateFunction> aggregateFunctions
+                    = 
originOutputExpr.collect(AggregateFunction.class::isInstance);
+            for (AggregateFunction aggregateFunction : aggregateFunctions) {
+                if (inputSubstitutionMap.containsKey(aggregateFunction)) {
                     continue;
                 }
-                if (originGroupByExpr instanceof SlotReference) {
-                    inputSubstitutionMap.put(originGroupByExpr, 
originGroupByExpr);
-                    localOutputExprs.add((SlotReference) originGroupByExpr);
-                } else {
-                    NamedExpression localOutputExpr = new 
Alias(originGroupByExpr, originGroupByExpr.toSql());
-                    inputSubstitutionMap.put(originGroupByExpr, 
localOutputExpr.toSlot());
-                    localOutputExprs.add(localOutputExpr);
-                }
+                NamedExpression localOutputExpr = new Alias(aggregateFunction, 
aggregateFunction.toSql());
+                Expression substitutionValue = aggregateFunction.withChildren(
+                        Lists.newArrayList(localOutputExpr.toSlot()));
+                inputSubstitutionMap.put(aggregateFunction, substitutionValue);
+                localOutputExprs.add(localOutputExpr);
             }
-            for (NamedExpression originOutputExpr : originOutputExprs) {
-                List<AggregateFunction> aggregateFunctions
-                        = 
originOutputExpr.collect(AggregateFunction.class::isInstance);
-                for (AggregateFunction aggregateFunction : aggregateFunctions) 
{
-                    if (inputSubstitutionMap.containsKey(aggregateFunction)) {
-                        continue;
-                    }
-                    NamedExpression localOutputExpr = new 
Alias(aggregateFunction, aggregateFunction.toSql());
-                    Expression substitutionValue = 
aggregateFunction.withChildren(
-                            Lists.newArrayList(localOutputExpr.toSlot()));
-                    inputSubstitutionMap.put(aggregateFunction, 
substitutionValue);
-                    localOutputExprs.add(localOutputExpr);
+        }
+
+        // 3. replace expression in globalOutputExprs and globalGroupByExprs
+        List<NamedExpression> globalOutputExprs = 
aggregate.getOutputExpressions().stream()
+                .map(e -> ExpressionReplacer.INSTANCE.visit(e, 
inputSubstitutionMap))
+                .map(NamedExpression.class::cast)
+                .collect(Collectors.toList());
+        List<Expression> globalGroupByExprs = localGroupByExprs.stream()
+                .map(e -> ExpressionReplacer.INSTANCE.visit(e, 
inputSubstitutionMap)).collect(Collectors.toList());
+
+        // 4. generate new plan
+        LogicalAggregate localAggregate = new LogicalAggregate<>(
+                localGroupByExprs,
+                localOutputExprs,
+                true,
+                aggregate.isNormalized(),
+                AggPhase.LOCAL,
+                aggregate.child()
+        );
+        return new LogicalAggregate<>(
+                globalGroupByExprs,
+                globalOutputExprs,
+                true,
+                aggregate.isNormalized(),
+                AggPhase.GLOBAL,
+                localAggregate
+        );
+    }
+
+    private void 
moveDistinctExprFromOutputToGroupBy(LogicalAggregate<GroupPlan> aggregate) {
+        // for example:
+        // select count(distinct a) from t1 group by b;
+        // => select a from t1 group by b, a;
+        List<NamedExpression> originOutputExprs = 
aggregate.getOutputExpressions();
+        List<Expression> originGroupByExprs = 
aggregate.getGroupByExpressions();
+
+        for (NamedExpression originOutputExpr : originOutputExprs) {
+            List<AggregateFunction> aggregateFunctions =
+                    
originOutputExpr.collect(AggregateFunction.class::isInstance);
+            for (AggregateFunction aggregateFunction : aggregateFunctions) {
+                if (aggregateFunction.isDistinct()) {
+                    distinctAggFunctions.add(aggregateFunction);
+                    distinctOriginOutputExprs.add(originOutputExpr);
                 }
             }
+        }
+        if (!distinctAggFunctions.isEmpty()) {
+            for (Expression expr : distinctAggFunctions.get(0).children()) {
+                distinctAggFunctionParams.add((NamedExpression) expr);

Review Comment:
   Why can we assume expr is a `NamedExpression`? What about `count(distinct a + 1)`?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to