This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c91e80b0c [feature](Nereids): pushdown COUNT(*) through join (#22545)
9c91e80b0c is described below

commit 9c91e80b0c64e9b9266ed85c262e3bfe59dcbedf
Author: jakevin <[email protected]>
AuthorDate: Mon Aug 7 12:53:27 2023 +0800

    [feature](Nereids): pushdown COUNT(*) through join (#22545)
---
 .../rules/rewrite/PushdownCountThroughJoin.java    | 88 +++++++++++++++-------
 .../rules/rewrite/PushdownMinMaxThroughJoin.java   |  8 +-
 .../rules/rewrite/PushdownSumThroughJoin.java      |  5 +-
 .../rewrite/PushdownCountThroughJoinTest.java      | 42 +++++++++++
 4 files changed, 113 insertions(+), 30 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java
index 6d5bf8b75f..4f0f63e547 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoin.java
@@ -43,8 +43,28 @@ import java.util.Map;
 import java.util.Set;
 
 /**
- * Count(*)
- * Count(col)
+ * TODO: distinct | just push one level
+ * Support Pushdown Count(*)/Count(col).
+ * Count(col) -> Sum( cnt * cntStar )
+ * Count(*) -> Sum( leftCntStar * rightCntStar )
+ * <p>
+ * Related paper "Eager aggregation and lazy aggregation".
+ * <pre>
+ *  aggregate: count(x)
+ *  |
+ *  join
+ *  |   \
+ *  |    *
+ *  (x)
+ *  ->
+ *  aggregate: Sum( cnt * cntStar )
+ *  |
+ *  join
+ *  |   \
+ *  |    aggregate: count(*) as cntStar
+ *  aggregate: count(x) as cnt
+ *  </pre>
+ * Note: when Count(*) is present, the group-by must not be empty (otherwise the rule does not rewrite).
  */
 public class PushdownCountThroughJoin implements RewriteRuleFactory {
     @Override
@@ -57,7 +77,8 @@ public class PushdownCountThroughJoin implements 
RewriteRuleFactory {
                         .when(agg -> {
                             Set<AggregateFunction> funcs = 
agg.getAggregateFunctions();
                             return !funcs.isEmpty() && funcs.stream()
-                                    .allMatch(f -> f instanceof Count && 
f.child(0) instanceof Slot);
+                                    .allMatch(f -> f instanceof Count && 
!f.isDistinct()
+                                            && (((Count) f).isCountStar() || 
f.child(0) instanceof Slot));
                         })
                         .then(agg -> pushCount(agg, agg.child(), 
ImmutableList.of()))
                         .toRule(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN),
@@ -69,7 +90,8 @@ public class PushdownCountThroughJoin implements 
RewriteRuleFactory {
                         .when(agg -> {
                             Set<AggregateFunction> funcs = 
agg.getAggregateFunctions();
                             return !funcs.isEmpty() && funcs.stream()
-                                    .allMatch(f -> f instanceof Count && 
f.child(0) instanceof Slot);
+                                    .allMatch(f -> f instanceof Count && 
!f.isDistinct()
+                                            && (((Count) f).isCountStar() || 
f.child(0) instanceof Slot));
                         })
                         .then(agg -> pushCount(agg, agg.child().child(), 
agg.child().getProjects()))
                         .toRule(RuleType.PUSHDOWN_COUNT_THROUGH_JOIN)
@@ -83,23 +105,23 @@ public class PushdownCountThroughJoin implements 
RewriteRuleFactory {
 
         List<Count> leftCounts = new ArrayList<>();
         List<Count> rightCounts = new ArrayList<>();
+        List<Count> countStars = new ArrayList<>();
         for (AggregateFunction f : agg.getAggregateFunctions()) {
             Count count = (Count) f;
             if (count.isCountStar()) {
-                // TODO: handle Count(*)
-                return null;
-            }
-            Slot slot = (Slot) count.child(0);
-            if (leftOutput.contains(slot)) {
-                leftCounts.add(count);
-            } else if (rightOutput.contains(slot)) {
-                rightCounts.add(count);
+                countStars.add(count);
             } else {
-                throw new IllegalStateException("Slot " + slot + " not found 
in join output");
+                Slot slot = (Slot) count.child(0);
+                if (leftOutput.contains(slot)) {
+                    leftCounts.add(count);
+                } else if (rightOutput.contains(slot)) {
+                    rightCounts.add(count);
+                } else {
+                    throw new IllegalStateException("Slot " + slot + " not 
found in join output");
+                }
             }
         }
 
-        // TODO: empty GroupBy
         Set<Slot> leftGroupBy = new HashSet<>();
         Set<Slot> rightGroupBy = new HashSet<>();
         for (Expression e : agg.getGroupByExpressions()) {
@@ -112,6 +134,11 @@ public class PushdownCountThroughJoin implements 
RewriteRuleFactory {
                 return null;
             }
         }
+
+        if (!countStars.isEmpty() && leftGroupBy.isEmpty() && 
rightGroupBy.isEmpty()) {
+            return null;
+        }
+
         join.getHashJoinConjuncts().forEach(e -> 
e.getInputSlots().forEach(slot -> {
             if (leftOutput.contains(slot)) {
                 leftGroupBy.add(slot);
@@ -133,7 +160,7 @@ public class PushdownCountThroughJoin implements 
RewriteRuleFactory {
             leftCntSlotToOutput.put((Slot) func.child(0), alias);
             leftCntAggOutputBuilder.add(alias);
         });
-        if (!rightCounts.isEmpty()) {
+        if (!rightCounts.isEmpty() || !countStars.isEmpty()) {
             leftCnt = new Count().alias("leftCntStar");
             leftCntAggOutputBuilder.add(leftCnt);
         }
@@ -150,7 +177,7 @@ public class PushdownCountThroughJoin implements 
RewriteRuleFactory {
             rightCntAggOutputBuilder.add(alias);
         });
 
-        if (!leftCounts.isEmpty()) {
+        if (!leftCounts.isEmpty() || !countStars.isEmpty()) {
             rightCnt = new Count().alias("rightCntStar");
             rightCntAggOutputBuilder.add(rightCnt);
         }
@@ -160,22 +187,31 @@ public class PushdownCountThroughJoin implements 
RewriteRuleFactory {
         Plan newJoin = join.withChildren(leftCntAgg, rightCntAgg);
 
         // top Sum agg
-        // count(slot) -> sum( count(slot) * cnt )
+        // count(slot) -> sum( count(slot) * cntStar )
+        // count(*) -> sum( leftCntStar * rightCntStar )
         List<NamedExpression> newOutputExprs = new ArrayList<>();
         for (NamedExpression ne : agg.getOutputExpressions()) {
             if (ne instanceof Alias && ((Alias) ne).child() instanceof Count) {
                 Count oldTopCnt = (Count) ((Alias) ne).child();
-                Slot slot = (Slot) oldTopCnt.child(0);
-                if (leftCntSlotToOutput.containsKey(slot)) {
-                    Preconditions.checkState(rightCnt != null);
-                    Expression expr = new Sum(new 
Multiply(leftCntSlotToOutput.get(slot).toSlot(), rightCnt.toSlot()));
-                    newOutputExprs.add((NamedExpression) 
ne.withChildren(expr));
-                } else if (rightCntSlotToOutput.containsKey(slot)) {
-                    Preconditions.checkState(leftCnt != null);
-                    Expression expr = new Sum(new 
Multiply(rightCntSlotToOutput.get(slot).toSlot(), leftCnt.toSlot()));
+                if (oldTopCnt.isCountStar()) {
+                    Preconditions.checkState(rightCnt != null && leftCnt != 
null);
+                    Expression expr = new Sum(new Multiply(leftCnt.toSlot(), 
rightCnt.toSlot()));
                     newOutputExprs.add((NamedExpression) 
ne.withChildren(expr));
                 } else {
-                    throw new IllegalStateException("Slot " + slot + " not 
found in join output");
+                    Slot slot = (Slot) oldTopCnt.child(0);
+                    if (leftCntSlotToOutput.containsKey(slot)) {
+                        Preconditions.checkState(rightCnt != null);
+                        Expression expr = new Sum(
+                                new 
Multiply(leftCntSlotToOutput.get(slot).toSlot(), rightCnt.toSlot()));
+                        newOutputExprs.add((NamedExpression) 
ne.withChildren(expr));
+                    } else if (rightCntSlotToOutput.containsKey(slot)) {
+                        Preconditions.checkState(leftCnt != null);
+                        Expression expr = new Sum(
+                                new 
Multiply(rightCntSlotToOutput.get(slot).toSlot(), leftCnt.toSlot()));
+                        newOutputExprs.add((NamedExpression) 
ne.withChildren(expr));
+                    } else {
+                        throw new IllegalStateException("Slot " + slot + " not 
found in join output");
+                    }
                 }
             } else {
                 newOutputExprs.add(ne);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java
index bd61f4f0ac..9b728ad141 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownMinMaxThroughJoin.java
@@ -42,6 +42,7 @@ import java.util.Map;
 import java.util.Set;
 
 /**
+ * TODO: distinct
  * Related paper "Eager aggregation and lazy aggregation".
  * <pre>
  * aggregate: Min/Max(x)
@@ -69,7 +70,8 @@ public class PushdownMinMaxThroughJoin implements 
RewriteRuleFactory {
                         .when(agg -> {
                             Set<AggregateFunction> funcs = 
agg.getAggregateFunctions();
                             return !funcs.isEmpty() && funcs.stream()
-                                .allMatch(f -> (f instanceof Min || f 
instanceof Max) && f.child(0) instanceof Slot);
+                                    .allMatch(f -> (f instanceof Min || f 
instanceof Max) && !f.isDistinct() && f.child(
+                                            0) instanceof Slot);
                         })
                         .then(agg -> pushMinMax(agg, agg.child(), 
ImmutableList.of()))
                         .toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN),
@@ -80,7 +82,9 @@ public class PushdownMinMaxThroughJoin implements 
RewriteRuleFactory {
                         .when(agg -> {
                             Set<AggregateFunction> funcs = 
agg.getAggregateFunctions();
                             return !funcs.isEmpty() && funcs.stream()
-                                .allMatch(f -> (f instanceof Min || f 
instanceof Max) && f.child(0) instanceof Slot);
+                                    .allMatch(
+                                            f -> (f instanceof Min || f 
instanceof Max) && !f.isDistinct() && f.child(
+                                                    0) instanceof Slot);
                         })
                         .then(agg -> pushMinMax(agg, agg.child().child(), 
agg.child().getProjects()))
                         .toRule(RuleType.PUSHDOWN_MIN_MAX_THROUGH_JOIN)
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java
index 1319200220..81e655ab85 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownSumThroughJoin.java
@@ -42,6 +42,7 @@ import java.util.Map;
 import java.util.Set;
 
 /**
+ * TODO: distinct
  * Related paper "Eager aggregation and lazy aggregation".
  * <pre>
  * aggregate: Sum(x)
@@ -69,7 +70,7 @@ public class PushdownSumThroughJoin implements 
RewriteRuleFactory {
                         .when(agg -> {
                             Set<AggregateFunction> funcs = 
agg.getAggregateFunctions();
                             return !funcs.isEmpty() && funcs.stream()
-                                    .allMatch(f -> f instanceof Sum && 
f.child(0) instanceof Slot);
+                                    .allMatch(f -> f instanceof Sum && 
!f.isDistinct() && f.child(0) instanceof Slot);
                         })
                         .then(agg -> pushSum(agg, agg.child(), 
ImmutableList.of()))
                         .toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN),
@@ -80,7 +81,7 @@ public class PushdownSumThroughJoin implements 
RewriteRuleFactory {
                         .when(agg -> {
                             Set<AggregateFunction> funcs = 
agg.getAggregateFunctions();
                             return !funcs.isEmpty() && funcs.stream()
-                                    .allMatch(f -> f instanceof Sum && 
f.child(0) instanceof Slot);
+                                    .allMatch(f -> f instanceof Sum && 
!f.isDistinct() && f.child(0) instanceof Slot);
                         })
                         .then(agg -> pushSum(agg, agg.child().child(), 
agg.child().getProjects()))
                         .toRule(RuleType.PUSHDOWN_SUM_THROUGH_JOIN)
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java
index 1a39a7c5ff..10f814e9ba 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownCountThroughJoinTest.java
@@ -65,4 +65,46 @@ class PushdownCountThroughJoinTest implements 
MemoPatternMatchSupported {
                 .printlnTree();
     }
 
+    @Test
+    void testSingleCountStar() {
+        Alias count = new Count().alias("countStar");
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .aggGroupUsingIndex(ImmutableList.of(0), 
ImmutableList.of(scan1.getOutput().get(0), count))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new PushdownCountThroughJoin())
+                .printlnTree();
+    }
+
+    @Test
+    void testSingleCountStarEmptyGroupBy() {
+        Alias count = new Count().alias("countStar");
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .aggGroupUsingIndex(ImmutableList.of(), 
ImmutableList.of(count))
+                .build();
+
+        // shouldn't rewrite.
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new PushdownCountThroughJoin())
+                .printlnTree();
+    }
+
+    @Test
+    void testBothSideCountAndCountStar() {
+        Alias leftCnt = new Count(scan1.getOutput().get(0)).alias("leftCnt");
+        Alias rightCnt = new Count(scan2.getOutput().get(0)).alias("rightCnt");
+        Alias countStar = new Count().alias("countStar");
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .aggGroupUsingIndex(ImmutableList.of(0),
+                        ImmutableList.of(scan1.getOutput().get(0), leftCnt, 
rightCnt, countStar))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new PushdownCountThroughJoin())
+                .printlnTree();
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to