This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 9c91e80b0c [feature](Nereids): pushdown COUNT(*) through join (#22545)
9c91e80b0c is described below
commit 9c91e80b0c64e9b9266ed85c262e3bfe59dcbedf
Author: jakevin <[email protected]>
AuthorDate: Mon Aug 7 12:53:27 2023 +0800
[feature](Nereids): pushdown COUNT(*) through join (#22545)
---
.../rules/rewrite/PushdownCountThroughJoin.java | 88 +++++++++++++++-------
.../rules/rewrite/PushdownMinMaxThroughJoin.java | 8 +-
.../rules/rewrite/PushdownSumThroughJoin.java | 5 +-
.../rewrite/PushdownCountThroughJoinTest.java | 42 +++++++++++
4 files changed, 113 insertions(+), 30 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java
index 6d5bf8b75f..4f0f63e547 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java
@@ -43,8 +43,28 @@ import java.util.Map;
import java.util.Set;
/**
- * Count(*)
- * Count(col)
+ * TODO: support DISTINCT; currently pushes down through only one join level.
+ * Support Pushdown Count(*)/Count(col).
+ * Count(col) -> Sum( cnt * cntStar )
+ * Count(*) -> Sum( leftCntStar * rightCntStar )
+ * <p>
+ * Related paper "Eager aggregation and lazy aggregation".
+ * <pre>
+ * aggregate: count(x)
+ * |
+ * join
+ * | \
+ * | *
+ * (x)
+ * ->
+ * aggregate: Sum( cnt * cntStar )
+ * |
+ * join
+ * | \
+ * | aggregate: count(*) as cntStar
+ * aggregate: count(x) as cnt
+ * </pre>
+ * Notice: when Count(*) is present, the GROUP BY list must not be empty; otherwise the rule is not applied.
*/
public class PushdownCountThroughJoin implements RewriteRuleFactory {
@Override
@@ -57,7 +77,8 @@ public class PushdownCountThroughJoin implements
RewriteRuleFactory {
.when(agg -> {
Set<AggregateFunction> funcs =
agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
- .allMatch(f -> f instanceof Count &&
f.child(0) instanceof Slot);
+ .allMatch(f -> f instanceof Count &&
!f.isDistinct()
+ && (((Count) f).isCountStar() ||
f.child(0) instanceof Slot));
})
.then(agg -> pushCount(agg, agg.child(),
ImmutableList.of()))
.toRule(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN),
@@ -69,7 +90,8 @@ public class PushdownCountThroughJoin implements
RewriteRuleFactory {
.when(agg -> {
Set<AggregateFunction> funcs =
agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
- .allMatch(f -> f instanceof Count &&
f.child(0) instanceof Slot);
+ .allMatch(f -> f instanceof Count &&
!f.isDistinct()
+ && (((Count) f).isCountStar() ||
f.child(0) instanceof Slot));
})
.then(agg -> pushCount(agg, agg.child().child(),
agg.child().getProjects()))
.toRule(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN)
@@ -83,23 +105,23 @@ public class PushdownCountThroughJoin implements
RewriteRuleFactory {
List<Count> leftCounts = new ArrayList<>();
List<Count> rightCounts = new ArrayList<>();
+ List<Count> countStars = new ArrayList<>();
for (AggregateFunction f : agg.getAggregateFunctions()) {
Count count = (Count) f;
if (count.isCountStar()) {
- // TODO: handle Count(*)
- return null;
- }
- Slot slot = (Slot) count.child(0);
- if (leftOutput.contains(slot)) {
- leftCounts.add(count);
- } else if (rightOutput.contains(slot)) {
- rightCounts.add(count);
+ countStars.add(count);
} else {
- throw new IllegalStateException("Slot " + slot + " not found
in join output");
+ Slot slot = (Slot) count.child(0);
+ if (leftOutput.contains(slot)) {
+ leftCounts.add(count);
+ } else if (rightOutput.contains(slot)) {
+ rightCounts.add(count);
+ } else {
+ throw new IllegalStateException("Slot " + slot + " not
found in join output");
+ }
}
}
- // TODO: empty GroupBy
Set<Slot> leftGroupBy = new HashSet<>();
Set<Slot> rightGroupBy = new HashSet<>();
for (Expression e : agg.getGroupByExpressions()) {
@@ -112,6 +134,11 @@ public class PushdownCountThroughJoin implements
RewriteRuleFactory {
return null;
}
}
+
+ if (!countStars.isEmpty() && leftGroupBy.isEmpty() &&
rightGroupBy.isEmpty()) {
+ return null;
+ }
+
join.getHashJoinConjuncts().forEach(e ->
e.getInputSlots().forEach(slot -> {
if (leftOutput.contains(slot)) {
leftGroupBy.add(slot);
@@ -133,7 +160,7 @@ public class PushdownCountThroughJoin implements
RewriteRuleFactory {
leftCntSlotToOutput.put((Slot) func.child(0), alias);
leftCntAggOutputBuilder.add(alias);
});
- if (!rightCounts.isEmpty()) {
+ if (!rightCounts.isEmpty() || !countStars.isEmpty()) {
leftCnt = new Count().alias("leftCntStar");
leftCntAggOutputBuilder.add(leftCnt);
}
@@ -150,7 +177,7 @@ public class PushdownCountThroughJoin implements
RewriteRuleFactory {
rightCntAggOutputBuilder.add(alias);
});
- if (!leftCounts.isEmpty()) {
+ if (!leftCounts.isEmpty() || !countStars.isEmpty()) {
rightCnt = new Count().alias("rightCntStar");
rightCntAggOutputBuilder.add(rightCnt);
}
@@ -160,22 +187,31 @@ public class PushdownCountThroughJoin implements
RewriteRuleFactory {
Plan newJoin = join.withChildren(leftCntAgg, rightCntAgg);
// top Sum agg
- // count(slot) -> sum( count(slot) * cnt )
+ // count(slot) -> sum( count(slot) * cntStar )
+ // count(*) -> sum( leftCntStar * rightCntStar )
List<NamedExpression> newOutputExprs = new ArrayList<>();
for (NamedExpression ne : agg.getOutputExpressions()) {
if (ne instanceof Alias && ((Alias) ne).child() instanceof Count) {
Count oldTopCnt = (Count) ((Alias) ne).child();
- Slot slot = (Slot) oldTopCnt.child(0);
- if (leftCntSlotToOutput.containsKey(slot)) {
- Preconditions.checkState(rightCnt != null);
- Expression expr = new Sum(new
Multiply(leftCntSlotToOutput.get(slot).toSlot(), rightCnt.toSlot()));
- newOutputExprs.add((NamedExpression)
ne.withChildren(expr));
- } else if (rightCntSlotToOutput.containsKey(slot)) {
- Preconditions.checkState(leftCnt != null);
- Expression expr = new Sum(new
Multiply(rightCntSlotToOutput.get(slot).toSlot(), leftCnt.toSlot()));
+ if (oldTopCnt.isCountStar()) {
+ Preconditions.checkState(rightCnt != null && leftCnt !=
null);
+ Expression expr = new Sum(new Multiply(leftCnt.toSlot(),
rightCnt.toSlot()));
newOutputExprs.add((NamedExpression)
ne.withChildren(expr));
} else {
- throw new IllegalStateException("Slot " + slot + " not
found in join output");
+ Slot slot = (Slot) oldTopCnt.child(0);
+ if (leftCntSlotToOutput.containsKey(slot)) {
+ Preconditions.checkState(rightCnt != null);
+ Expression expr = new Sum(
+ new
Multiply(leftCntSlotToOutput.get(slot).toSlot(), rightCnt.toSlot()));
+ newOutputExprs.add((NamedExpression)
ne.withChildren(expr));
+ } else if (rightCntSlotToOutput.containsKey(slot)) {
+ Preconditions.checkState(leftCnt != null);
+ Expression expr = new Sum(
+ new
Multiply(rightCntSlotToOutput.get(slot).toSlot(), leftCnt.toSlot()));
+ newOutputExprs.add((NamedExpression)
ne.withChildren(expr));
+ } else {
+ throw new IllegalStateException("Slot " + slot + " not
found in join output");
+ }
}
} else {
newOutputExprs.add(ne);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java
index bd61f4f0ac..9b728ad141 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java
@@ -42,6 +42,7 @@ import java.util.Map;
import java.util.Set;
/**
+ * TODO: distinct
* Related paper "Eager aggregation and lazy aggregation".
* <pre>
* aggregate: Min/Max(x)
@@ -69,7 +70,8 @@ public class PushdownMinMaxThroughJoin implements
RewriteRuleFactory {
.when(agg -> {
Set<AggregateFunction> funcs =
agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
- .allMatch(f -> (f instanceof Min || f
instanceof Max) && f.child(0) instanceof Slot);
+ .allMatch(f -> (f instanceof Min || f
instanceof Max) && !f.isDistinct() && f.child(
+ 0) instanceof Slot);
})
.then(agg -> pushMinMax(agg, agg.child(),
ImmutableList.of()))
.toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN),
@@ -80,7 +82,9 @@ public class PushdownMinMaxThroughJoin implements
RewriteRuleFactory {
.when(agg -> {
Set<AggregateFunction> funcs =
agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
- .allMatch(f -> (f instanceof Min || f
instanceof Max) && f.child(0) instanceof Slot);
+ .allMatch(
+ f -> (f instanceof Min || f
instanceof Max) && !f.isDistinct() && f.child(
+ 0) instanceof Slot);
})
.then(agg -> pushMinMax(agg, agg.child().child(),
agg.child().getProjects()))
.toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java
index 1319200220..81e655ab85 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java
@@ -42,6 +42,7 @@ import java.util.Map;
import java.util.Set;
/**
+ * TODO: distinct
* Related paper "Eager aggregation and lazy aggregation".
* <pre>
* aggregate: Sum(x)
@@ -69,7 +70,7 @@ public class PushdownSumThroughJoin implements
RewriteRuleFactory {
.when(agg -> {
Set<AggregateFunction> funcs =
agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
- .allMatch(f -> f instanceof Sum &&
f.child(0) instanceof Slot);
+ .allMatch(f -> f instanceof Sum &&
!f.isDistinct() && f.child(0) instanceof Slot);
})
.then(agg -> pushSum(agg, agg.child(),
ImmutableList.of()))
.toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN),
@@ -80,7 +81,7 @@ public class PushdownSumThroughJoin implements
RewriteRuleFactory {
.when(agg -> {
Set<AggregateFunction> funcs =
agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
- .allMatch(f -> f instanceof Sum &&
f.child(0) instanceof Slot);
+ .allMatch(f -> f instanceof Sum &&
!f.isDistinct() && f.child(0) instanceof Slot);
})
.then(agg -> pushSum(agg, agg.child().child(),
agg.child().getProjects()))
.toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN)
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java
index 1a39a7c5ff..10f814e9ba 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java
@@ -65,4 +65,46 @@ class PushdownCountThroughJoinTest implements
MemoPatternMatchSupported {
.printlnTree();
}
+ @Test
+ void testSingleCountStar() {
+ Alias count = new Count().alias("countStar");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(0),
ImmutableList.of(scan1.getOutput().get(0), count))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new PushdownCountThroughJoin())
+ .printlnTree();
+ }
+
+ @Test
+ void testSingleCountStarEmptyGroupBy() {
+ Alias count = new Count().alias("countStar");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(),
ImmutableList.of(count))
+ .build();
+
+ // shouldn't rewrite.
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new PushdownCountThroughJoin())
+ .printlnTree();
+ }
+
+ @Test
+ void testBothSideCountAndCountStar() {
+ Alias leftCnt = new Count(scan1.getOutput().get(0)).alias("leftCnt");
+ Alias rightCnt = new Count(scan2.getOutput().get(0)).alias("rightCnt");
+ Alias countStar = new Count().alias("countStar");
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .aggGroupUsingIndex(ImmutableList.of(0),
+ ImmutableList.of(scan1.getOutput().get(0), leftCnt,
rightCnt, countStar))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new PushdownCountThroughJoin())
+ .printlnTree();
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]