This is an automated email from the ASF dual-hosted git repository.

kxiao pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git

commit b2ac326f4caf6368bd58e00a608c794320c4e35b
Author: 谢健 <[email protected]>
AuthorDate: Tue Aug 29 15:01:26 2023 +0800

    [fix](Nereids) keep agg output unchanged after normalization (#23499)
    
    The NormalizeAggregate rule can change the output of an aggregate.
    
    For example:
    ```
    select c1 as c, c1 from t having c1 > 0
    ```
    The NormalizeAggregate rule produces a plan whose output is `c`, which
    causes the HAVING filter on `c1` to fail.
    
    Therefore, the output exprIds must remain unchanged after normalization.
---
 .../nereids/rules/rewrite/NormalizeAggregate.java  | 49 ++++++++++++++--------
 .../doris/nereids/trees/expressions/CaseWhen.java  | 11 ++++-
 .../nereids/trees/expressions/WhenClause.java      |  2 +-
 .../suites/nereids_p0/aggregate/aggregate.groovy   |  3 ++
 4 files changed, 45 insertions(+), 20 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregate.java
index 90e997941a..eb683e8b58 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeAggregate.java
@@ -33,6 +33,7 @@ import 
org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.util.ExpressionUtils;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableList.Builder;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
@@ -185,24 +186,6 @@ public class NormalizeAggregate extends 
OneRewriteRuleFactory implements Normali
             bottomProjects.addAll(aggInputSlots);
             // build group by exprs
             List<Expression> normalizedGroupExprs = 
groupByToSlotContext.normalizeToUseSlotRef(groupingByExprs);
-            // build upper project, use two context to do pop up, because agg 
output maybe contain two part:
-            //   group by keys and agg expressions
-            List<NamedExpression> upperProjects = groupByToSlotContext
-                    
.normalizeToUseSlotRefWithoutWindowFunction(aggregateOutput);
-            upperProjects = 
normalizedAggFuncsToSlotContext.normalizeToUseSlotRefWithoutWindowFunction(upperProjects);
-            // process Expression like Alias(SlotReference#0)#0
-            upperProjects = upperProjects.stream().map(e -> {
-                if (e instanceof Alias) {
-                    Alias alias = (Alias) e;
-                    if (alias.child() instanceof SlotReference) {
-                        SlotReference slotReference = (SlotReference) 
alias.child();
-                        if 
(slotReference.getExprId().equals(alias.getExprId())) {
-                            return slotReference;
-                        }
-                    }
-                }
-                return e;
-            }).collect(Collectors.toList());
 
             Plan bottomPlan;
             if (!bottomProjects.isEmpty()) {
@@ -211,11 +194,41 @@ public class NormalizeAggregate extends 
OneRewriteRuleFactory implements Normali
                 bottomPlan = aggregate.child();
             }
 
+            List<NamedExpression> upperProjects = 
normalizeOutput(aggregateOutput,
+                    groupByToSlotContext, normalizedAggFuncsToSlotContext);
+
             return new LogicalProject<>(upperProjects,
                     aggregate.withNormalized(normalizedGroupExprs, 
normalizedAggOutput, bottomPlan));
         }).toRule(RuleType.NORMALIZE_AGGREGATE);
     }
 
+    private List<NamedExpression> normalizeOutput(List<NamedExpression> 
aggregateOutput,
+            NormalizeToSlotContext groupByToSlotContext, 
NormalizeToSlotContext normalizedAggFuncsToSlotContext) {
+        // build upper project, use two context to do pop up, because agg 
output maybe contain two part:
+        //   group by keys and agg expressions
+        List<NamedExpression> upperProjects = groupByToSlotContext
+                .normalizeToUseSlotRefWithoutWindowFunction(aggregateOutput);
+        upperProjects = 
normalizedAggFuncsToSlotContext.normalizeToUseSlotRefWithoutWindowFunction(upperProjects);
+
+        Builder<NamedExpression> builder = new ImmutableList.Builder<>();
+        for (int i = 0; i < aggregateOutput.size(); i++) {
+            NamedExpression e = upperProjects.get(i);
+            // process Expression like Alias(SlotReference#0)#0
+            if (e instanceof Alias && e.child(0) instanceof SlotReference) {
+                SlotReference slotReference = (SlotReference) e.child(0);
+                if (slotReference.getExprId().equals(e.getExprId())) {
+                    e = slotReference;
+                }
+            }
+            // Make the output ExprId unchanged
+            if (!e.getExprId().equals(aggregateOutput.get(i).getExprId())) {
+                e = new Alias(aggregateOutput.get(i).getExprId(), e, 
aggregateOutput.get(i).getName());
+            }
+            builder.add(e);
+        }
+        return builder.build();
+    }
+
     private static class CollectNonWindowedAggFuncs extends 
DefaultExpressionVisitor<Void, List<AggregateFunction>> {
 
         private static final CollectNonWindowedAggFuncs INSTANCE = new 
CollectNonWindowedAggFuncs();
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java
index 1b0e9eb80f..03e9d17d41 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java
@@ -97,7 +97,16 @@ public class CaseWhen extends Expression {
 
     @Override
     public String toString() {
-        return toSql();
+        StringBuilder output = new StringBuilder("CASE");
+        for (Expression child : children()) {
+            if (child instanceof WhenClause) {
+                output.append(child);
+            } else {
+                output.append(" ELSE ").append(child.toString());
+            }
+        }
+        output.append(" END");
+        return output.toString();
     }
 
     @Override
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java
index 147dc5f601..33d3f2d2d2 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java
@@ -112,6 +112,6 @@ public class WhenClause extends Expression implements 
BinaryExpression, ExpectsI
 
     @Override
     public String toString() {
-        return toSql();
+        return " WHEN " + left().toString() + " THEN " + right().toString();
     }
 }
diff --git a/regression-test/suites/nereids_p0/aggregate/aggregate.groovy 
b/regression-test/suites/nereids_p0/aggregate/aggregate.groovy
index 7ac3fbe9c5..e1ae3131b2 100644
--- a/regression-test/suites/nereids_p0/aggregate/aggregate.groovy
+++ b/regression-test/suites/nereids_p0/aggregate/aggregate.groovy
@@ -314,4 +314,7 @@ suite("aggregate") {
     qt_aggregate """ select avg(distinct c_bigint), avg(distinct c_double) 
from regression_test_nereids_p0_aggregate.${tableName} """
     qt_aggregate """ select count(distinct c_bigint),count(distinct 
c_double),count(distinct c_string),count(distinct c_date_1),count(distinct 
c_timestamp_1),count(distinct c_timestamp_2),count(distinct 
c_timestamp_3),count(distinct c_boolean) from 
regression_test_nereids_p0_aggregate.${tableName} """
     qt_select_quantile_percent """ select 
QUANTILE_PERCENT(QUANTILE_UNION(TO_QUANTILE_STATE(c_bigint,2048)),0.5) from 
regression_test_nereids_p0_aggregate.${tableName};  """
+
+    sql "select k1 as k, k1 from tempbaseall group by k1 having k1 > 0"
+    sql "select k1 as k, k1 from tempbaseall group by k1 having k > 0"
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to