This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new d4b99f1bb19 [Fix](nereids) Only rewrite the slots that appear both in 
trival-agg func and grouping sets (#31600)
d4b99f1bb19 is described below

commit d4b99f1bb197518b4609130f80abd26a25513714
Author: feiniaofeiafei <[email protected]>
AuthorDate: Thu Feb 29 22:17:43 2024 +0800

    [Fix](nereids) Only rewrite the slots that appear both in trival-agg func 
and grouping sets (#31600)
    
    * [Fix](nereids) Only rewrite the slots that appear both in trival-agg func 
and grouping sets
    
    * [Fix](nereids) Only rewrite the slots that appear both in trival-agg func 
and grouping sets
    
    ---------
    
    Co-authored-by: feiniaofeiafei <[email protected]>
---
 .../nereids/rules/analysis/NormalizeRepeat.java    | 84 ++++++++++++++++++----
 ...ot_both_appear_in_agg_fun_and_grouping_sets.out | 61 ++++++++++++++++
 ...both_appear_in_agg_fun_and_grouping_sets.groovy | 30 ++++++++
 3 files changed, 160 insertions(+), 15 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
index 3c893ce4bec..8437dc40b04 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
@@ -29,8 +29,10 @@ import 
org.apache.doris.nereids.trees.expressions.OrderExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
+import org.apache.doris.nereids.trees.expressions.WindowExpression;
 import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import 
org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
+import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.algebra.Repeat;
 import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
@@ -47,6 +49,7 @@ import com.google.common.collect.Sets;
 import com.google.common.collect.Sets.SetView;
 import org.jetbrains.annotations.NotNull;
 
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashSet;
 import java.util.List;
@@ -267,8 +270,11 @@ public class NormalizeRepeat extends 
OneAnalysisRuleFactory {
     private LogicalAggregate<Plan> dealSlotAppearBothInAggFuncAndGroupingSets(
             @NotNull LogicalAggregate<Plan> aggregate) {
         LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
-        Set<Slot> aggUsedSlots = aggregate.getOutputExpressions().stream()
-                .flatMap(e -> 
e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream())
+
+        List<AggregateFunction> aggregateFunctions = Lists.newArrayList();
+        aggregate.getOutputExpressions().forEach(
+                o -> o.accept(PlanUtils.CollectNonWindowedAggFuncs.INSTANCE, 
aggregateFunctions));
+        Set<Slot> aggUsedSlots = aggregateFunctions.stream()
                 .flatMap(e -> 
e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
                 .collect(ImmutableSet.toImmutableSet());
         Set<Slot> groupingSetsUsedSlot = repeat.getGroupingSets().stream()
@@ -308,20 +314,68 @@ public class NormalizeRepeat extends 
OneAnalysisRuleFactory {
                 .build());
         aggregate = aggregate.withChildren(ImmutableList.of(repeat));
 
-        // modify aggregate functions' parameter slot reference to new copied 
slots
         List<NamedExpression> newOutputExpressions = 
aggregate.getOutputExpressions().stream()
-                .map(output -> (NamedExpression) 
output.rewriteDownShortCircuit(expr -> {
-                    if (expr instanceof AggregateFunction) {
-                        return expr.rewriteDownShortCircuit(e -> {
-                            if (e instanceof Slot && 
slotMapping.containsKey(e)) {
-                                return slotMapping.get(e).toSlot();
-                            }
-                            return e;
-                        });
-                    }
-                    return expr;
-                })
-                ).collect(Collectors.toList());
+                .map(e -> (NamedExpression) 
e.accept(RewriteAggFuncWithoutWindowAggFunc.INSTANCE,
+                        slotMapping))
+                .collect(Collectors.toList());
         return aggregate.withAggOutput(newOutputExpressions);
     }
+
+    /**
+     * This class uses the map (slotMapping) to rewrite all slots in trivial-agg.
+     * The purpose of this class is to rewrite only the slots in trivial-agg 
and not to rewrite the slots in window-agg.
+     */
+    private static class RewriteAggFuncWithoutWindowAggFunc
+            extends DefaultExpressionRewriter<Map<Slot, Alias>> {
+
+        private static final RewriteAggFuncWithoutWindowAggFunc
+                INSTANCE = new RewriteAggFuncWithoutWindowAggFunc();
+
+        private RewriteAggFuncWithoutWindowAggFunc() {}
+
+        @Override
+        public Expression visitAggregateFunction(AggregateFunction 
aggregateFunction, Map<Slot, Alias> slotMapping) {
+            return aggregateFunction.rewriteDownShortCircuit(e -> {
+                if (e instanceof Slot && slotMapping.containsKey(e)) {
+                    return slotMapping.get(e).toSlot();
+                }
+                return e;
+            });
+        }
+
+        @Override
+        public Expression visitWindow(WindowExpression windowExpression, 
Map<Slot, Alias> slotMapping) {
+            List<Expression> newChildren = new ArrayList<>();
+            Expression function = windowExpression.getFunction();
+            Expression oldFuncChild = function.child(0);
+            boolean hasNewChildren = false;
+            if (oldFuncChild != null) {
+                Expression newFuncChild;
+                newFuncChild = function.child(0).accept(this, slotMapping);
+                hasNewChildren = (newFuncChild != oldFuncChild);
+                newChildren.add(hasNewChildren
+                        ? 
function.withChildren(ImmutableList.of(newFuncChild)) : function);
+            } else {
+                newChildren.add(function);
+            }
+            for (Expression partitionKey : 
windowExpression.getPartitionKeys()) {
+                Expression newChild = partitionKey.accept(this, slotMapping);
+                if (newChild != partitionKey) {
+                    hasNewChildren = true;
+                }
+                newChildren.add(newChild);
+            }
+            for (Expression orderKey : windowExpression.getOrderKeys()) {
+                Expression newChild = orderKey.accept(this, slotMapping);
+                if (newChild != orderKey) {
+                    hasNewChildren = true;
+                }
+                newChildren.add(newChild);
+            }
+            if (windowExpression.getWindowFrame().isPresent()) {
+                newChildren.add(windowExpression.getWindowFrame().get());
+            }
+            return hasNewChildren ? windowExpression.withChildren(newChildren) 
: windowExpression;
+        }
+    }
 }
diff --git 
a/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out
 
b/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out
index 901226f8548..2c96648dac4 100644
--- 
a/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out
+++ 
b/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out
@@ -64,3 +64,64 @@ yeah
 4
 5
 
+-- !select6 --
+\N
+-86
+-48
+-12
+82
+89
+16054
+19196
+
+-- !select7 --
+\N
+-48
+-43
+82
+89
+35195
+
+-- !select8 --
+\N
+\N
+-86
+-86
+-48
+-12
+82
+89
+16054
+19196
+
+-- !select9 --
+\N
+\N
+\N
+-129
+-129
+-129
+-96
+-96
+-12
+164
+164
+178
+178
+16054
+19196
+35195
+35275
+
+-- !select10 --
+\N
+\N
+-172
+-172
+-96
+-24
+164
+178
+32108
+38392
+
diff --git 
a/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy
 
b/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy
index 865ce3b5f50..bee63217a9f 100644
--- 
a/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy
+++ 
b/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy
@@ -60,4 +60,34 @@ suite("slot_both_appear_in_agg_fun_and_grouping_sets") {
           select sum(rank() over (partition by col_text_undef_signed order by 
col_int_undef_signed)) 
           as col1 from table_10_undef_undef4 group by grouping 
sets((col_int_undef_signed)) order by 1;
     """
+
+    qt_select6 """
+        select sum(sum(col_int_undef_signed)) over (partition by 
sum(col_int_undef_signed) 
+        order by sum(col_int_undef_signed)) from table_10_undef_undef4 group 
by 
+        grouping sets ((col_int_undef_signed)) order by 1;
+    """
+
+    qt_select7 """
+        select sum(sum(col_int_undef_signed)) over (partition by 
sum(col_int_undef_signed) 
+        order by sum(col_int_undef_signed)) from table_10_undef_undef4 group 
by 
+        grouping sets ((col_text_undef_signed)) order by 1;
+    """
+
+    qt_select8 """
+        select sum(sum(col_int_undef_signed)) over (partition by 
sum(col_int_undef_signed)
+        order by sum(col_int_undef_signed)) from table_10_undef_undef4 group by
+        grouping sets ((col_text_undef_signed,col_int_undef_signed)) order by 
1;
+    """
+
+    qt_select9 """
+        select sum(sum(col_int_undef_signed)) over (partition by 
sum(col_int_undef_signed)
+        order by sum(col_int_undef_signed)) from table_10_undef_undef4 group by
+        grouping sets ((col_text_undef_signed,col_int_undef_signed), 
(col_text_undef_signed), ()) order by 1;
+    """
+    
+    qt_select10 """
+        select sum(col_int_undef_signed + sum(col_int_undef_signed)) over 
(partition by sum(col_int_undef_signed)
+        order by sum(col_int_undef_signed)) from table_10_undef_undef4 group by
+        grouping sets ((col_text_undef_signed,col_int_undef_signed)) order by 
1;
+    """
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to