This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new c583563087b [feature](Nereids): double eager support mix function (#30468)
c583563087b is described below
commit c583563087bc5a0db9920aa88aafb63a5bd61e19
Author: jakevin <[email protected]>
AuthorDate: Mon Jan 29 13:08:09 2024 +0800
[feature](Nereids): double eager support mix function (#30468)
---
.../doris/nereids/jobs/executor/Rewriter.java | 6 +-
.../org/apache/doris/nereids/rules/RuleType.java | 3 +-
...hroughJoin.java => PushDownAggThroughJoin.java} | 107 +++++------
.../rules/rewrite/PushDownSumThroughJoin.java | 212 ---------------------
.../rewrite/PushDownCountThroughJoinTest.java | 13 +-
.../rules/rewrite/PushDownSumThroughJoinTest.java | 29 ++-
.../eager_aggregate/push_down_sum_through_join.out | 12 +-
.../nereids_rules_p0/eager_aggregate/basic.groovy | 3 +-
.../push_down_count_through_join.groovy | 2 +-
.../push_down_sum_through_join.groovy | 4 +-
10 files changed, 101 insertions(+), 290 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index 2c0e57b715e..34f7afe4995 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -98,14 +98,13 @@ import org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderTopN;
import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoEsScan;
import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoJdbcScan;
import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoOdbcScan;
+import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOneSide;
-import org.apache.doris.nereids.rules.rewrite.PushDownCountThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownDistinctThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughProject;
import org.apache.doris.nereids.rules.rewrite.PushDownLimit;
import org.apache.doris.nereids.rules.rewrite.PushDownLimitDistinctThroughJoin;
import
org.apache.doris.nereids.rules.rewrite.PushDownLimitDistinctThroughUnion;
-import org.apache.doris.nereids.rules.rewrite.PushDownSumThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownTopNDistinctThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownTopNDistinctThroughUnion;
import org.apache.doris.nereids.rules.rewrite.PushDownTopNThroughJoin;
@@ -288,9 +287,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
topic("Eager aggregation",
topDown(
- new PushDownSumThroughJoin(),
new PushDownAggThroughJoinOneSide(),
- new PushDownCountThroughJoin()
+ new PushDownAggThroughJoin()
),
custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN,
PushDownDistinctThroughJoin::new)
),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index b35c7e03b72..594f49a3b70 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -167,8 +167,7 @@ public enum RuleType {
ELIMINATE_SORT(RuleTypeClass.REWRITE),
PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE(RuleTypeClass.REWRITE),
- PUSH_DOWN_SUM_THROUGH_JOIN(RuleTypeClass.REWRITE),
- PUSH_DOWN_COUNT_THROUGH_JOIN(RuleTypeClass.REWRITE),
+ PUSH_DOWN_AGG_THROUGH_JOIN(RuleTypeClass.REWRITE),
TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN(RuleTypeClass.REWRITE),
TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN_PROJECT(RuleTypeClass.REWRITE),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoin.java
similarity index 69%
rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoin.java
rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoin.java
index 462180ab7a6..f003d2ac2cc 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoin.java
@@ -67,7 +67,7 @@ import java.util.Set;
* </pre>
* Notice: rule can't optimize condition that groupby is empty when Count(*)
exists.
*/
-public class PushDownCountThroughJoin implements RewriteRuleFactory {
+public class PushDownAggThroughJoin implements RewriteRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
@@ -78,19 +78,22 @@ public class PushDownCountThroughJoin implements RewriteRuleFactory {
.when(agg -> {
Set<AggregateFunction> funcs =
agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
- .allMatch(f -> f instanceof Count &&
!f.isDistinct()
- && (((Count) f).isCountStar() ||
f.child(0) instanceof Slot));
+ .allMatch(f -> !f.isDistinct()
+ && (f instanceof Count &&
(((Count) f).isCountStar() || f.child(
+ 0) instanceof Slot)
+ || (f instanceof Sum && f.child(0)
instanceof Slot))
+ );
})
.thenApply(ctx -> {
Set<Integer> enableNereidsRules =
ctx.cascadesContext.getConnectContext()
.getSessionVariable().getEnableNereidsRules();
- if
(!enableNereidsRules.contains(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN.type())) {
+ if
(!enableNereidsRules.contains(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN.type())) {
return null;
}
LogicalAggregate<LogicalJoin<Plan, Plan>> agg =
ctx.root;
- return pushCount(agg, agg.child(),
ImmutableList.of());
+ return pushAgg(agg, agg.child(),
ImmutableList.of());
})
- .toRule(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN),
+ .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN),
logicalAggregate(logicalProject(innerLogicalJoin()))
.when(agg -> agg.child().isAllSlots())
.when(agg ->
agg.child().child().getOtherJoinConjuncts().isEmpty())
@@ -99,40 +102,42 @@ public class PushDownCountThroughJoin implements RewriteRuleFactory {
.when(agg -> {
Set<AggregateFunction> funcs =
agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
- .allMatch(f -> f instanceof Count &&
!f.isDistinct()
- && (((Count) f).isCountStar() ||
f.child(0) instanceof Slot));
+ .allMatch(f -> !f.isDistinct()
+ && (f instanceof Count &&
(((Count) f).isCountStar() || f.child(
+ 0) instanceof Slot)
+ || (f instanceof Sum && f.child(0)
instanceof Slot))
+ );
})
.thenApply(ctx -> {
Set<Integer> enableNereidsRules =
ctx.cascadesContext.getConnectContext()
.getSessionVariable().getEnableNereidsRules();
- if
(!enableNereidsRules.contains(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN.type())) {
+ if
(!enableNereidsRules.contains(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN.type())) {
return null;
}
LogicalAggregate<LogicalProject<LogicalJoin<Plan,
Plan>>> agg = ctx.root;
- return pushCount(agg, agg.child().child(),
agg.child().getProjects());
+ return pushAgg(agg, agg.child().child(),
agg.child().getProjects());
})
- .toRule(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN)
+ .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN)
);
}
- private LogicalAggregate<Plan> pushCount(LogicalAggregate<? extends Plan>
agg,
+ private static LogicalAggregate<Plan> pushAgg(LogicalAggregate<? extends
Plan> agg,
LogicalJoin<Plan, Plan> join, List<NamedExpression> projects) {
List<Slot> leftOutput = join.left().getOutput();
List<Slot> rightOutput = join.right().getOutput();
- List<Count> leftCounts = new ArrayList<>();
- List<Count> rightCounts = new ArrayList<>();
+ List<AggregateFunction> leftAggs = new ArrayList<>();
+ List<AggregateFunction> rightAggs = new ArrayList<>();
List<Count> countStars = new ArrayList<>();
for (AggregateFunction f : agg.getAggregateFunctions()) {
- Count count = (Count) f;
- if (count.isCountStar()) {
- countStars.add(count);
+ if (f instanceof Count && ((Count) f).isCountStar()) {
+ countStars.add((Count) f);
} else {
- Slot slot = (Slot) count.child(0);
+ Slot slot = (Slot) f.child(0);
if (leftOutput.contains(slot)) {
- leftCounts.add(count);
+ leftAggs.add(f);
} else if (rightOutput.contains(slot)) {
- rightCounts.add(count);
+ rightAggs.add(f);
} else {
throw new IllegalStateException("Slot " + slot + " not
found in join output");
}
@@ -168,63 +173,59 @@ public class PushDownCountThroughJoin implements RewriteRuleFactory {
Alias leftCnt = null;
Alias rightCnt = null;
- // left Count agg
- Map<Slot, NamedExpression> leftCntSlotToOutput = new HashMap<>();
- Builder<NamedExpression> leftCntAggOutputBuilder =
ImmutableList.<NamedExpression>builder()
- .addAll(leftGroupBy);
- leftCounts.forEach(func -> {
+ // left agg
+ Map<Slot, NamedExpression> leftSlotToOutput = new HashMap<>();
+ Builder<NamedExpression> leftAggOutputBuilder =
ImmutableList.<NamedExpression>builder().addAll(leftGroupBy);
+ leftAggs.forEach(func -> {
Alias alias = func.alias(func.getName());
- leftCntSlotToOutput.put((Slot) func.child(0), alias);
- leftCntAggOutputBuilder.add(alias);
+ leftSlotToOutput.put((Slot) func.child(0), alias);
+ leftAggOutputBuilder.add(alias);
});
- if (!rightCounts.isEmpty() || !countStars.isEmpty()) {
+ if (!rightAggs.isEmpty() || !countStars.isEmpty()) {
leftCnt = new Count().alias("leftCntStar");
- leftCntAggOutputBuilder.add(leftCnt);
+ leftAggOutputBuilder.add(leftCnt);
}
- LogicalAggregate<Plan> leftCntAgg = new LogicalAggregate<>(
- ImmutableList.copyOf(leftGroupBy),
leftCntAggOutputBuilder.build(), join.left());
-
- // right Count agg
- Map<Slot, NamedExpression> rightCntSlotToOutput = new HashMap<>();
- Builder<NamedExpression> rightCntAggOutputBuilder =
ImmutableList.<NamedExpression>builder()
- .addAll(rightGroupBy);
- rightCounts.forEach(func -> {
+ LogicalAggregate<Plan> leftAgg = new LogicalAggregate<>(
+ ImmutableList.copyOf(leftGroupBy),
leftAggOutputBuilder.build(), join.left());
+ // right agg
+ Map<Slot, NamedExpression> rightSlotToOutput = new HashMap<>();
+ Builder<NamedExpression> rightAggOutputBuilder =
ImmutableList.<NamedExpression>builder().addAll(rightGroupBy);
+ rightAggs.forEach(func -> {
Alias alias = func.alias(func.getName());
- rightCntSlotToOutput.put((Slot) func.child(0), alias);
- rightCntAggOutputBuilder.add(alias);
+ rightSlotToOutput.put((Slot) func.child(0), alias);
+ rightAggOutputBuilder.add(alias);
});
-
- if (!leftCounts.isEmpty() || !countStars.isEmpty()) {
+ if (!leftAggs.isEmpty() || !countStars.isEmpty()) {
rightCnt = new Count().alias("rightCntStar");
- rightCntAggOutputBuilder.add(rightCnt);
+ rightAggOutputBuilder.add(rightCnt);
}
- LogicalAggregate<Plan> rightCntAgg = new LogicalAggregate<>(
- ImmutableList.copyOf(rightGroupBy),
rightCntAggOutputBuilder.build(), join.right());
+ LogicalAggregate<Plan> rightAgg = new LogicalAggregate<>(
+ ImmutableList.copyOf(rightGroupBy),
rightAggOutputBuilder.build(), join.right());
- Plan newJoin = join.withChildren(leftCntAgg, rightCntAgg);
+ Plan newJoin = join.withChildren(leftAgg, rightAgg);
// top Sum agg
// count(slot) -> sum( count(slot) * cntStar )
// count(*) -> sum( leftCntStar * leftCntStar )
List<NamedExpression> newOutputExprs = new ArrayList<>();
for (NamedExpression ne : agg.getOutputExpressions()) {
- if (ne instanceof Alias && ((Alias) ne).child() instanceof Count) {
- Count oldTopCnt = (Count) ((Alias) ne).child();
- if (oldTopCnt.isCountStar()) {
+ if (ne instanceof Alias && ((Alias) ne).child() instanceof
AggregateFunction) {
+ AggregateFunction func = (AggregateFunction) ((Alias)
ne).child();
+ if (func instanceof Count && ((Count) func).isCountStar()) {
Preconditions.checkState(rightCnt != null && leftCnt !=
null);
Expression expr = new Sum(new Multiply(leftCnt.toSlot(),
rightCnt.toSlot()));
newOutputExprs.add((NamedExpression)
ne.withChildren(expr));
} else {
- Slot slot = (Slot) oldTopCnt.child(0);
- if (leftCntSlotToOutput.containsKey(slot)) {
+ Slot slot = (Slot) func.child(0);
+ if (leftSlotToOutput.containsKey(slot)) {
Preconditions.checkState(rightCnt != null);
Expression expr = new Sum(
- new
Multiply(leftCntSlotToOutput.get(slot).toSlot(), rightCnt.toSlot()));
+ new
Multiply(leftSlotToOutput.get(slot).toSlot(), rightCnt.toSlot()));
newOutputExprs.add((NamedExpression)
ne.withChildren(expr));
- } else if (rightCntSlotToOutput.containsKey(slot)) {
+ } else if (rightSlotToOutput.containsKey(slot)) {
Preconditions.checkState(leftCnt != null);
Expression expr = new Sum(
- new
Multiply(rightCntSlotToOutput.get(slot).toSlot(), leftCnt.toSlot()));
+ new
Multiply(rightSlotToOutput.get(slot).toSlot(), leftCnt.toSlot()));
newOutputExprs.add((NamedExpression)
ne.withChildren(expr));
} else {
throw new IllegalStateException("Slot " + slot + " not
found in join output");
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java
deleted file mode 100644
index e8987e670a5..00000000000
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java
+++ /dev/null
@@ -1,212 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.rules.rewrite;
-
-import org.apache.doris.nereids.rules.Rule;
-import org.apache.doris.nereids.rules.RuleType;
-import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.Multiply;
-import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.Slot;
-import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
-import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
-import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
-import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
-
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableList.Builder;
-
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * TODO: distinct
- * Related paper "Eager aggregation and lazy aggregation".
- * <pre>
- * aggregate: Sum(x)
- * |
- * join
- * | \
- * | *
- * (x)
- * ->
- * aggregate: Sum(sum1)
- * |
- * join
- * | \
- * | *
- * aggregate: Sum(x) as sum1
- * </pre>
- */
-public class PushDownSumThroughJoin implements RewriteRuleFactory {
- @Override
- public List<Rule> buildRules() {
- return ImmutableList.of(
- logicalAggregate(innerLogicalJoin())
- .when(agg ->
agg.child().getOtherJoinConjuncts().isEmpty())
- .whenNot(agg ->
agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
- .when(agg -> {
- Set<AggregateFunction> funcs =
agg.getAggregateFunctions();
- return !funcs.isEmpty() && funcs.stream()
- .allMatch(f -> f instanceof Sum &&
!f.isDistinct() && f.child(0) instanceof Slot);
- })
- .thenApply(ctx -> {
- Set<Integer> enableNereidsRules =
ctx.cascadesContext.getConnectContext()
-
.getSessionVariable().getEnableNereidsRules();
- if
(!enableNereidsRules.contains(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN.type())) {
- return null;
- }
- LogicalAggregate<LogicalJoin<Plan, Plan>> agg =
ctx.root;
- return pushSum(agg, agg.child(),
ImmutableList.of());
- })
- .toRule(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN),
- logicalAggregate(logicalProject(innerLogicalJoin()))
- .when(agg -> agg.child().isAllSlots())
- .when(agg ->
agg.child().child().getOtherJoinConjuncts().isEmpty())
- .whenNot(agg ->
agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
- .when(agg -> {
- Set<AggregateFunction> funcs =
agg.getAggregateFunctions();
- return !funcs.isEmpty() && funcs.stream()
- .allMatch(f -> f instanceof Sum &&
!f.isDistinct() && f.child(0) instanceof Slot);
- })
- .thenApply(ctx -> {
- Set<Integer> enableNereidsRules =
ctx.cascadesContext.getConnectContext()
-
.getSessionVariable().getEnableNereidsRules();
- if
(!enableNereidsRules.contains(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN.type())) {
- return null;
- }
- LogicalAggregate<LogicalProject<LogicalJoin<Plan,
Plan>>> agg = ctx.root;
- return pushSum(agg, agg.child().child(),
agg.child().getProjects());
- })
- .toRule(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN)
- );
- }
-
- private LogicalAggregate<Plan> pushSum(LogicalAggregate<? extends Plan>
agg,
- LogicalJoin<Plan, Plan> join, List<NamedExpression> projects) {
- List<Slot> leftOutput = join.left().getOutput();
- List<Slot> rightOutput = join.right().getOutput();
-
- List<Sum> leftSums = new ArrayList<>();
- List<Sum> rightSums = new ArrayList<>();
- for (AggregateFunction f : agg.getAggregateFunctions()) {
- Sum sum = (Sum) f;
- Slot slot = (Slot) sum.child();
- if (leftOutput.contains(slot)) {
- leftSums.add(sum);
- } else if (rightOutput.contains(slot)) {
- rightSums.add(sum);
- } else {
- throw new IllegalStateException("Slot " + slot + " not found
in join output");
- }
- }
- if (leftSums.isEmpty() && rightSums.isEmpty()
- || (!leftSums.isEmpty() && !rightSums.isEmpty())) {
- return null;
- }
-
- Set<Slot> leftGroupBy = new HashSet<>();
- Set<Slot> rightGroupBy = new HashSet<>();
- for (Expression e : agg.getGroupByExpressions()) {
- Slot slot = (Slot) e;
- if (leftOutput.contains(slot)) {
- leftGroupBy.add(slot);
- } else if (rightOutput.contains(slot)) {
- rightGroupBy.add(slot);
- } else {
- return null;
- }
- }
- join.getHashJoinConjuncts().forEach(e ->
e.getInputSlots().forEach(slot -> {
- if (leftOutput.contains(slot)) {
- leftGroupBy.add(slot);
- } else if (rightOutput.contains(slot)) {
- rightGroupBy.add(slot);
- } else {
- throw new IllegalStateException("Slot " + slot + " not found
in join output");
- }
- }));
-
- List<Sum> sums;
- Set<Slot> sumGroupBy;
- Set<Slot> cntGroupBy;
- Plan sumChild;
- Plan cntChild;
- if (!leftSums.isEmpty()) {
- sums = leftSums;
- sumGroupBy = leftGroupBy;
- cntGroupBy = rightGroupBy;
- sumChild = join.left();
- cntChild = join.right();
- } else {
- sums = rightSums;
- sumGroupBy = rightGroupBy;
- cntGroupBy = leftGroupBy;
- sumChild = join.right();
- cntChild = join.left();
- }
-
- // Sum agg
- Map<Slot, NamedExpression> sumSlotToOutput = new HashMap<>();
- Builder<NamedExpression> sumAggOutputBuilder =
ImmutableList.<NamedExpression>builder().addAll(sumGroupBy);
- sums.forEach(func -> {
- Alias alias = func.alias(func.getName());
- sumSlotToOutput.put((Slot) func.child(0), alias);
- sumAggOutputBuilder.add(alias);
- });
- LogicalAggregate<Plan> sumAgg = new LogicalAggregate<>(
- ImmutableList.copyOf(sumGroupBy), sumAggOutputBuilder.build(),
sumChild);
-
- // Count agg
- Alias cnt = new Count().alias("cnt");
- List<NamedExpression> cntAggOutput =
ImmutableList.<NamedExpression>builder()
- .addAll(cntGroupBy).add(cnt)
- .build();
- LogicalAggregate<Plan> cntAgg = new LogicalAggregate<>(
- ImmutableList.copyOf(cntGroupBy), cntAggOutput, cntChild);
-
- Plan newJoin = !leftSums.isEmpty() ? join.withChildren(sumAgg, cntAgg)
: join.withChildren(cntAgg, sumAgg);
-
- // top Sum agg
- // replace sum(x) -> sum(sum# * cnt)
- List<NamedExpression> newOutputExprs = new ArrayList<>();
- for (NamedExpression ne : agg.getOutputExpressions()) {
- if (ne instanceof Alias && ((Alias) ne).child() instanceof
AggregateFunction) {
- AggregateFunction func = (AggregateFunction) ((Alias)
ne).child();
- Slot slot = (Slot) func.child(0);
- if (sumSlotToOutput.containsKey(slot)) {
- Expression expr = func.withChildren(new
Multiply(sumSlotToOutput.get(slot).toSlot(), cnt.toSlot()));
- newOutputExprs.add((NamedExpression)
ne.withChildren(expr));
- } else {
- throw new IllegalStateException("Slot " + slot + " not
found in join output");
- }
- } else {
- newOutputExprs.add(ne);
- }
- }
- return agg.withAggOutputChild(newOutputExprs, newJoin);
- }
-}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinTest.java
index 34ccfe70f70..8e0e0e15df3 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinTest.java
@@ -45,7 +45,7 @@ class PushDownCountThroughJoinTest implements MemoPatternMatchSupported {
private MockUp<SessionVariable> mockUp = new MockUp<SessionVariable>() {
@Mock
public Set<Integer> getEnableNereidsRules() {
- return
ImmutableSet.of(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN.type());
+ return ImmutableSet.of(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN.type());
}
};
@@ -58,7 +58,8 @@ class PushDownCountThroughJoinTest implements MemoPatternMatchSupported {
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownCountThroughJoin())
+ .applyTopDown(new PushDownAggThroughJoin())
+ .printlnTree()
.matches(
logicalAggregate(
logicalJoin(
@@ -81,7 +82,7 @@ class PushDownCountThroughJoinTest implements MemoPatternMatchSupported {
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownCountThroughJoin())
+ .applyTopDown(new PushDownAggThroughJoin())
.matches(
logicalAggregate(
logicalJoin(
@@ -101,7 +102,7 @@ class PushDownCountThroughJoinTest implements MemoPatternMatchSupported {
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownCountThroughJoin())
+ .applyTopDown(new PushDownAggThroughJoin())
.matches(
logicalAggregate(
logicalJoin(
@@ -122,7 +123,7 @@ class PushDownCountThroughJoinTest implements MemoPatternMatchSupported {
// shouldn't rewrite.
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownCountThroughJoin())
+ .applyTopDown(new PushDownAggThroughJoin())
.matches(
logicalAggregate(
logicalJoin(
@@ -145,7 +146,7 @@ class PushDownCountThroughJoinTest implements MemoPatternMatchSupported {
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownCountThroughJoin())
+ .applyTopDown(new PushDownAggThroughJoin())
.matches(
logicalAggregate(
logicalJoin(
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinTest.java
index 088372b0d76..29a745b379f 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinTest.java
@@ -45,7 +45,7 @@ class PushDownSumThroughJoinTest implements MemoPatternMatchSupported {
private MockUp<SessionVariable> mockUp = new MockUp<SessionVariable>() {
@Mock
public Set<Integer> getEnableNereidsRules() {
- return ImmutableSet.of(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN.type());
+ return ImmutableSet.of(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN.type());
}
};
@@ -58,7 +58,7 @@ class PushDownSumThroughJoinTest implements MemoPatternMatchSupported {
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownSumThroughJoin())
+ .applyTopDown(new PushDownAggThroughJoin())
.matches(
logicalAggregate(
logicalJoin(
@@ -78,7 +78,28 @@ class PushDownSumThroughJoinTest implements MemoPatternMatchSupported {
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownSumThroughJoin())
+ .applyTopDown(new PushDownAggThroughJoin())
+ .matches(
+ logicalAggregate(
+ logicalJoin(
+ logicalAggregate(),
+ logicalAggregate()
+ )
+ )
+ );
+ }
+
+ @Test
+ void testSingleJoinBothSum() {
+ Alias leftSum = new Sum(scan1.getOutput().get(1)).alias("leftSum");
+ Alias rightSum = new Sum(scan2.getOutput().get(1)).alias("rightSum");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(0),
ImmutableList.of(scan1.getOutput().get(0), leftSum, rightSum))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new PushDownAggThroughJoin())
.matches(
logicalAggregate(
logicalJoin(
@@ -99,7 +120,7 @@ class PushDownSumThroughJoinTest implements MemoPatternMatchSupported {
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
- .applyTopDown(new PushDownSumThroughJoin())
+ .applyTopDown(new PushDownAggThroughJoin())
.matches(
logicalAggregate(
logicalJoin(
diff --git a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.out b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.out
index da05df5419d..106d8882079 100644
--- a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.out
+++ b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.out
@@ -176,8 +176,10 @@ PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id) and (t1.name =
t2.name)) otherCondition=()
---------PhysicalOlapScan[sum_t]
---------PhysicalOlapScan[sum_t]
+--------hashAgg[LOCAL]
+----------PhysicalOlapScan[sum_t]
+--------hashAgg[LOCAL]
+----------PhysicalOlapScan[sum_t]
-- !groupby_pushdown_with_where_clause --
PhysicalResultSink
@@ -195,8 +197,10 @@ PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id)) otherCondition=()
---------PhysicalOlapScan[sum_t]
---------PhysicalOlapScan[sum_t]
+--------hashAgg[LOCAL]
+----------PhysicalOlapScan[sum_t]
+--------hashAgg[LOCAL]
+----------PhysicalOlapScan[sum_t]
-- !groupby_pushdown_with_order_by_limit --
PhysicalResultSink
diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy
index 58d50b3add4..249e7af4bb4 100644
--- a/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy
+++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy
@@ -22,8 +22,7 @@ suite("eager_aggregate_basic") {
sql "SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'"
sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join_one_side"
- sql "SET ENABLE_NEREIDS_RULES=push_down_sum_through_join"
- sql "SET ENABLE_NEREIDS_RULES=push_down_count_through_join"
+ sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join"
sql """
DROP TABLE IF EXISTS shunt_log_com_dd_library;
diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join.groovy
index f5f4bf53b45..37cd6000941 100644
--- a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join.groovy
+++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join.groovy
@@ -48,7 +48,7 @@ suite("push_down_count_through_join") {
sql "insert into count_t values (9, 3, null)"
sql "insert into count_t values (10, null, null)"
- sql "SET ENABLE_NEREIDS_RULES=push_down_count_through_join"
+ sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join"
qt_groupby_pushdown_basic """
explain shape plan select count(t1.score) from count_t t1, count_t t2
where t1.id = t2.id group by t1.name;
diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.groovy
index e51899dcc3d..95736d26475 100644
--- a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.groovy
+++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.groovy
@@ -48,7 +48,7 @@ suite("push_down_sum_through_join") {
sql "insert into sum_t values (9, 3, null)"
sql "insert into sum_t values (10, null, null)"
- sql "SET ENABLE_NEREIDS_RULES=push_down_sum_through_join"
+ sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join"
qt_groupby_pushdown_basic """
explain shape plan select sum(t1.score) from sum_t t1, sum_t t2 where
t1.id = t2.id group by t1.name;
@@ -131,7 +131,7 @@ suite("push_down_sum_through_join") {
"""
qt_groupby_pushdown_varied_aggregates """
- explain shape plan select sum(t1.score), avg(t1.id), count(t2.name)
from sum_t t1 join sum_t t2 on t1.id = t2.id group by t1.name;
+ explain shape plan select sum(t1.score), count(t2.name) from sum_t t1
join sum_t t2 on t1.id = t2.id group by t1.name;
"""
qt_groupby_pushdown_with_order_by_limit """
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]