This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new d4b99f1bb19 [Fix](nereids) Only rewrite the slots that appear both in
trival-agg func and grouping sets (#31600)
d4b99f1bb19 is described below
commit d4b99f1bb197518b4609130f80abd26a25513714
Author: feiniaofeiafei <[email protected]>
AuthorDate: Thu Feb 29 22:17:43 2024 +0800
[Fix](nereids) Only rewrite the slots that appear both in trival-agg func
and grouping sets (#31600)
* [Fix](nereids) Only rewrite the slots that appear both in trival-agg func
and grouping sets
* [Fix](nereids) Only rewrite the slots that appear both in trival-agg func
and grouping sets
---------
Co-authored-by: feiniaofeiafei <[email protected]>
---
.../nereids/rules/analysis/NormalizeRepeat.java | 84 ++++++++++++++++++----
...ot_both_appear_in_agg_fun_and_grouping_sets.out | 61 ++++++++++++++++
...both_appear_in_agg_fun_and_grouping_sets.groovy | 30 ++++++++
3 files changed, 160 insertions(+), 15 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
index 3c893ce4bec..8437dc40b04 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
@@ -29,8 +29,10 @@ import
org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
+import org.apache.doris.nereids.trees.expressions.WindowExpression;
import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
+import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Repeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
@@ -47,6 +49,7 @@ import com.google.common.collect.Sets;
import com.google.common.collect.Sets.SetView;
import org.jetbrains.annotations.NotNull;
+import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
@@ -267,8 +270,11 @@ public class NormalizeRepeat extends
OneAnalysisRuleFactory {
private LogicalAggregate<Plan> dealSlotAppearBothInAggFuncAndGroupingSets(
@NotNull LogicalAggregate<Plan> aggregate) {
LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
- Set<Slot> aggUsedSlots = aggregate.getOutputExpressions().stream()
- .flatMap(e ->
e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream())
+
+ List<AggregateFunction> aggregateFunctions = Lists.newArrayList();
+ aggregate.getOutputExpressions().forEach(
+ o -> o.accept(PlanUtils.CollectNonWindowedAggFuncs.INSTANCE,
aggregateFunctions));
+ Set<Slot> aggUsedSlots = aggregateFunctions.stream()
.flatMap(e ->
e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
.collect(ImmutableSet.toImmutableSet());
Set<Slot> groupingSetsUsedSlot = repeat.getGroupingSets().stream()
@@ -308,20 +314,68 @@ public class NormalizeRepeat extends
OneAnalysisRuleFactory {
.build());
aggregate = aggregate.withChildren(ImmutableList.of(repeat));
- // modify aggregate functions' parameter slot reference to new copied
slots
List<NamedExpression> newOutputExpressions =
aggregate.getOutputExpressions().stream()
- .map(output -> (NamedExpression)
output.rewriteDownShortCircuit(expr -> {
- if (expr instanceof AggregateFunction) {
- return expr.rewriteDownShortCircuit(e -> {
- if (e instanceof Slot &&
slotMapping.containsKey(e)) {
- return slotMapping.get(e).toSlot();
- }
- return e;
- });
- }
- return expr;
- })
- ).collect(Collectors.toList());
+ .map(e -> (NamedExpression)
e.accept(RewriteAggFuncWithoutWindowAggFunc.INSTANCE,
+ slotMapping))
+ .collect(Collectors.toList());
return aggregate.withAggOutput(newOutputExpressions);
}
+
+ /**
+ * This class uses the map (slotMapping) to rewrite all slots in trivial-agg.
+ * The purpose of this class is to rewrite only the slots in trivial-agg
and to leave the slots in window-agg unchanged.
+ */
+ private static class RewriteAggFuncWithoutWindowAggFunc
+ extends DefaultExpressionRewriter<Map<Slot, Alias>> {
+
+ private static final RewriteAggFuncWithoutWindowAggFunc
+ INSTANCE = new RewriteAggFuncWithoutWindowAggFunc();
+
+ private RewriteAggFuncWithoutWindowAggFunc() {}
+
+ @Override
+ public Expression visitAggregateFunction(AggregateFunction
aggregateFunction, Map<Slot, Alias> slotMapping) {
+ return aggregateFunction.rewriteDownShortCircuit(e -> {
+ if (e instanceof Slot && slotMapping.containsKey(e)) {
+ return slotMapping.get(e).toSlot();
+ }
+ return e;
+ });
+ }
+
+ @Override
+ public Expression visitWindow(WindowExpression windowExpression,
Map<Slot, Alias> slotMapping) {
+ List<Expression> newChildren = new ArrayList<>();
+ Expression function = windowExpression.getFunction();
+ Expression oldFuncChild = function.child(0);
+ boolean hasNewChildren = false;
+ if (oldFuncChild != null) {
+ Expression newFuncChild;
+ newFuncChild = function.child(0).accept(this, slotMapping);
+ hasNewChildren = (newFuncChild != oldFuncChild);
+ newChildren.add(hasNewChildren
+ ?
function.withChildren(ImmutableList.of(newFuncChild)) : function);
+ } else {
+ newChildren.add(function);
+ }
+ for (Expression partitionKey :
windowExpression.getPartitionKeys()) {
+ Expression newChild = partitionKey.accept(this, slotMapping);
+ if (newChild != partitionKey) {
+ hasNewChildren = true;
+ }
+ newChildren.add(newChild);
+ }
+ for (Expression orderKey : windowExpression.getOrderKeys()) {
+ Expression newChild = orderKey.accept(this, slotMapping);
+ if (newChild != orderKey) {
+ hasNewChildren = true;
+ }
+ newChildren.add(newChild);
+ }
+ if (windowExpression.getWindowFrame().isPresent()) {
+ newChildren.add(windowExpression.getWindowFrame().get());
+ }
+ return hasNewChildren ? windowExpression.withChildren(newChildren)
: windowExpression;
+ }
+ }
}
diff --git
a/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out
b/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out
index 901226f8548..2c96648dac4 100644
---
a/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out
+++
b/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out
@@ -64,3 +64,64 @@ yeah
4
5
+-- !select6 --
+\N
+-86
+-48
+-12
+82
+89
+16054
+19196
+
+-- !select7 --
+\N
+-48
+-43
+82
+89
+35195
+
+-- !select8 --
+\N
+\N
+-86
+-86
+-48
+-12
+82
+89
+16054
+19196
+
+-- !select9 --
+\N
+\N
+\N
+-129
+-129
+-129
+-96
+-96
+-12
+164
+164
+178
+178
+16054
+19196
+35195
+35275
+
+-- !select10 --
+\N
+\N
+-172
+-172
+-96
+-24
+164
+178
+32108
+38392
+
diff --git
a/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy
b/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy
index 865ce3b5f50..bee63217a9f 100644
---
a/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy
+++
b/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy
@@ -60,4 +60,34 @@ suite("slot_both_appear_in_agg_fun_and_grouping_sets") {
select sum(rank() over (partition by col_text_undef_signed order by
col_int_undef_signed))
as col1 from table_10_undef_undef4 group by grouping
sets((col_int_undef_signed)) order by 1;
"""
+
+ qt_select6 """
+ select sum(sum(col_int_undef_signed)) over (partition by
sum(col_int_undef_signed)
+ order by sum(col_int_undef_signed)) from table_10_undef_undef4 group
by
+ grouping sets ((col_int_undef_signed)) order by 1;
+ """
+
+ qt_select7 """
+ select sum(sum(col_int_undef_signed)) over (partition by
sum(col_int_undef_signed)
+ order by sum(col_int_undef_signed)) from table_10_undef_undef4 group
by
+ grouping sets ((col_text_undef_signed)) order by 1;
+ """
+
+ qt_select8 """
+ select sum(sum(col_int_undef_signed)) over (partition by
sum(col_int_undef_signed)
+ order by sum(col_int_undef_signed)) from table_10_undef_undef4 group by
+ grouping sets ((col_text_undef_signed,col_int_undef_signed)) order by
1;
+ """
+
+ qt_select9 """
+ select sum(sum(col_int_undef_signed)) over (partition by
sum(col_int_undef_signed)
+ order by sum(col_int_undef_signed)) from table_10_undef_undef4 group by
+ grouping sets ((col_text_undef_signed,col_int_undef_signed),
(col_text_undef_signed), ()) order by 1;
+ """
+
+ qt_select10 """
+ select sum(col_int_undef_signed + sum(col_int_undef_signed)) over
(partition by sum(col_int_undef_signed)
+ order by sum(col_int_undef_signed)) from table_10_undef_undef4 group by
+ grouping sets ((col_text_undef_signed,col_int_undef_signed)) order by
1;
+ """
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]