This is an automated email from the ASF dual-hosted git repository.

kxiao pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.0 by this push:
     new e4500e86a59 [feature](Nereids): push down topN through join #24720 
(#26634)
e4500e86a59 is described below

commit e4500e86a59189424b41c266c6802a3ba5a42c65
Author: jakevin <[email protected]>
AuthorDate: Thu Nov 9 18:49:22 2023 +0800

    [feature](Nereids): push down topN through join #24720 (#26634)
    
    Push TopN through Join.
    
    The join type must be a left/right outer join or a cross join, because
    these joins never filter out rows of one of their children.
    
    The pushed-down TopN uses (original limit + original offset) as its limit
    and 0 as its offset.
    
    (cherry picked from commit 3c9ff7af399eacf2ac77e7fd314ded740a385aae)
---
 .../doris/nereids/jobs/executor/Rewriter.java      |  35 +++---
 .../org/apache/doris/nereids/rules/RuleType.java   |  12 ++-
 .../nereids/rules/rewrite/EliminateLimit.java      |  31 ++++--
 ...{ReplaceLimitNode.java => LimitSortToTopN.java} |  19 +---
 .../rewrite/PushdownFilterThroughProject.java      |  39 ++++---
 .../rules/rewrite/PushdownProjectThroughLimit.java |   6 +-
 .../rules/rewrite/PushdownTopNThroughJoin.java     | 118 +++++++++++++++++++++
 .../rules/rewrite/PushdownTopNThroughWindow.java   |   4 +-
 .../apache/doris/nereids/util/ExpressionUtils.java |   4 -
 .../nereids/rules/rewrite/PushdownLimitTest.java   |  26 ++---
 .../doris/nereids/util/LogicalPlanBuilder.java     |   9 ++
 11 files changed, 217 insertions(+), 86 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index 016947ee553..d3c17e5e6c7 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -71,6 +71,7 @@ import org.apache.doris.nereids.rules.rewrite.InferAggNotNull;
 import org.apache.doris.nereids.rules.rewrite.InferFilterNotNull;
 import org.apache.doris.nereids.rules.rewrite.InferJoinNotNull;
 import org.apache.doris.nereids.rules.rewrite.InferPredicates;
+import org.apache.doris.nereids.rules.rewrite.LimitSortToTopN;
 import org.apache.doris.nereids.rules.rewrite.MergeFilters;
 import org.apache.doris.nereids.rules.rewrite.MergeOneRowRelationIntoUnion;
 import org.apache.doris.nereids.rules.rewrite.MergeProjects;
@@ -89,9 +90,9 @@ import 
org.apache.doris.nereids.rules.rewrite.PushProjectIntoOneRowRelation;
 import org.apache.doris.nereids.rules.rewrite.PushProjectThroughUnion;
 import org.apache.doris.nereids.rules.rewrite.PushdownFilterThroughProject;
 import org.apache.doris.nereids.rules.rewrite.PushdownLimit;
+import org.apache.doris.nereids.rules.rewrite.PushdownTopNThroughJoin;
 import org.apache.doris.nereids.rules.rewrite.PushdownTopNThroughWindow;
 import org.apache.doris.nereids.rules.rewrite.ReorderJoin;
-import org.apache.doris.nereids.rules.rewrite.ReplaceLimitNode;
 import org.apache.doris.nereids.rules.rewrite.RewriteCteChildren;
 import org.apache.doris.nereids.rules.rewrite.SemiJoinCommute;
 import org.apache.doris.nereids.rules.rewrite.SimplifyAggGroupBy;
@@ -263,14 +264,14 @@ public class Rewriter extends AbstractBatchJobExecutor {
             // ),
 
             topic("Limit optimization",
-                    topDown(
-                            // TODO: the logical plan should not contains any 
phase information,
-                            //       we should refactor like 
AggregateStrategies, e.g. LimitStrategies,
-                            //       generate one PhysicalLimit if current 
distribution is gather or two
-                            //       PhysicalLimits with gather exchange
-                            new ReplaceLimitNode(),
-                            new SplitLimit(),
-                            new PushdownLimit(),
+                    // TODO: the logical plan should not contains any phase 
information,
+                    //       we should refactor like AggregateStrategies, e.g. 
LimitStrategies,
+                    //       generate one PhysicalLimit if current 
distribution is gather or two
+                    //       PhysicalLimits with gather exchange
+                    topDown(new LimitSortToTopN()),
+                    topDown(new SplitLimit()),
+                    topDown(new PushdownLimit(),
+                            new PushdownTopNThroughJoin(),
                             new PushdownTopNThroughWindow(),
                             new CreatePartitionTopNFromWindow()
                     )
@@ -301,15 +302,15 @@ public class Rewriter extends AbstractBatchJobExecutor {
             topic("topn optimize",
                     topDown(new DeferMaterializeTopNResult())
             ),
+            topic("eliminate",
+                    // SORT_PRUNING should be applied after mergeLimit
+                    custom(RuleType.ELIMINATE_SORT, EliminateSort::new),
+                    bottomUp(new EliminateEmptyRelation())
+            ),
             // this rule batch must keep at the end of rewrite to do some plan 
check
             topic("Final rewrite and check",
                     custom(RuleType.ENSURE_PROJECT_ON_TOP_JOIN, 
EnsureProjectOnTopJoin::new),
-                    topDown(
-                            new PushdownFilterThroughProject(),
-                            new MergeProjects()
-                    ),
-                    // SORT_PRUNING should be applied after mergeLimit
-                    custom(RuleType.ELIMINATE_SORT, EliminateSort::new),
+                    topDown(new PushdownFilterThroughProject(), new 
MergeProjects()),
                     custom(RuleType.ADJUST_CONJUNCTS_RETURN_TYPE, 
AdjustConjunctsReturnType::new),
                     bottomUp(
                             new 
ExpressionRewrite(CheckLegalityAfterRewrite.INSTANCE),
@@ -323,10 +324,6 @@ public class Rewriter extends AbstractBatchJobExecutor {
                             new CollectFilterAboveConsumer(),
                             new CollectProjectAboveConsumer()
                     )
-            ),
-
-            topic("eliminate empty relation",
-                bottomUp(new EliminateEmptyRelation())
             )
     );
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 362a2312fd9..dcf13637c96 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -162,8 +162,9 @@ public enum RuleType {
     COLUMN_PRUNING(RuleTypeClass.REWRITE),
     ELIMINATE_SORT(RuleTypeClass.REWRITE),
 
-    PUSHDOWN_TOP_N_THROUGH_PROJECTION_WINDOW(RuleTypeClass.REWRITE),
-    PUSHDOWN_TOP_N_THROUGH_WINDOW(RuleTypeClass.REWRITE),
+    PUSHDOWN_MIN_MAX_THROUGH_JOIN(RuleTypeClass.REWRITE),
+    PUSHDOWN_SUM_THROUGH_JOIN(RuleTypeClass.REWRITE),
+    PUSHDOWN_COUNT_THROUGH_JOIN(RuleTypeClass.REWRITE),
 
     TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN(RuleTypeClass.REWRITE),
     TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN_PROJECT(RuleTypeClass.REWRITE),
@@ -243,7 +244,12 @@ public enum RuleType {
     PUSH_LIMIT_THROUGH_PROJECT_WINDOW(RuleTypeClass.REWRITE),
     PUSH_LIMIT_THROUGH_UNION(RuleTypeClass.REWRITE),
     PUSH_LIMIT_THROUGH_WINDOW(RuleTypeClass.REWRITE),
-    PUSH_LIMIT_INTO_SORT(RuleTypeClass.REWRITE),
+    LIMIT_SORT_TO_TOP_N(RuleTypeClass.REWRITE),
+    // topN push down
+    PUSH_TOP_N_THROUGH_JOIN(RuleTypeClass.REWRITE),
+    PUSH_TOP_N_THROUGH_PROJECT_JOIN(RuleTypeClass.REWRITE),
+    PUSH_TOP_N_THROUGH_PROJECT_WINDOW(RuleTypeClass.REWRITE),
+    PUSH_TOP_N_THROUGH_WINDOW(RuleTypeClass.REWRITE),
     // adjust nullable
     ADJUST_NULLABLE(RuleTypeClass.REWRITE),
     ADJUST_CONJUNCTS_RETURN_TYPE(RuleTypeClass.REWRITE),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateLimit.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateLimit.java
index 9cc19e47d8b..8fbfc13934d 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateLimit.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateLimit.java
@@ -19,18 +19,35 @@ package org.apache.doris.nereids.rules.rewrite;
 
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.UnaryNode;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
 import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
 
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
 /**
  * Eliminate limit = 0.
  */
-public class EliminateLimit extends OneRewriteRuleFactory {
+public class EliminateLimit implements RewriteRuleFactory {
+
     @Override
-    public Rule build() {
-        return logicalLimit()
-                .when(limit -> limit.getLimit() == 0)
-                .thenApply(ctx -> new 
LogicalEmptyRelation(ctx.statementContext.getNextRelationId(),
-                        ctx.root.getOutput()))
-                .toRule(RuleType.ELIMINATE_LIMIT);
+    public List<Rule> buildRules() {
+        return ImmutableList.of(
+                logicalLimit()
+                        .when(limit -> limit.getLimit() == 0)
+                        .thenApply(ctx -> new 
LogicalEmptyRelation(ctx.statementContext.getNextRelationId(),
+                                ctx.root.getOutput()))
+                        .toRule(RuleType.ELIMINATE_LIMIT),
+                logicalLimit(logicalOneRowRelation())
+                        .then(limit -> limit.getLimit() > 0 && 
limit.getOffset() == 0
+                                ? limit.child() : new 
LogicalEmptyRelation(StatementScopeIdGenerator.newRelationId(),
+                                limit.child().getOutput()))
+                        .toRule(RuleType.ELIMINATE_LIMIT_ON_ONE_ROW_RELATION),
+                logicalLimit(logicalEmptyRelation())
+                        .then(UnaryNode::child)
+                        .toRule(RuleType.ELIMINATE_LIMIT_ON_EMPTY_RELATION)
+        );
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ReplaceLimitNode.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitSortToTopN.java
similarity index 70%
rename from 
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ReplaceLimitNode.java
rename to 
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitSortToTopN.java
index 9ec2e69b26e..37d08d887a5 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ReplaceLimitNode.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitSortToTopN.java
@@ -19,10 +19,7 @@ package org.apache.doris.nereids.rules.rewrite;
 
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
-import org.apache.doris.nereids.trees.UnaryNode;
-import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
 import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
 import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
@@ -35,7 +32,7 @@ import java.util.List;
 /**
  * rule to eliminate limit node by replace to other nodes.
  */
-public class ReplaceLimitNode implements RewriteRuleFactory {
+public class LimitSortToTopN implements RewriteRuleFactory {
     @Override
     public List<Rule> buildRules() {
         return ImmutableList.of(
@@ -47,8 +44,8 @@ public class ReplaceLimitNode implements RewriteRuleFactory {
                                     limit.getLimit(),
                                     limit.getOffset(),
                                     sort.child(0));
-                        }).toRule(RuleType.PUSH_LIMIT_INTO_SORT),
-                //limit->proj->sort ==> proj->topN
+                        }).toRule(RuleType.LIMIT_SORT_TO_TOP_N),
+                // limit -> proj -> sort ==> proj -> topN
                 logicalLimit(logicalProject(logicalSort()))
                         .then(limit -> {
                             LogicalProject project = limit.child();
@@ -58,15 +55,7 @@ public class ReplaceLimitNode implements RewriteRuleFactory {
                                     limit.getOffset(),
                                     sort.child(0));
                             return 
project.withChildren(Lists.newArrayList(topN));
-                        }).toRule(RuleType.PUSH_LIMIT_INTO_SORT),
-                logicalLimit(logicalOneRowRelation())
-                        .then(limit -> limit.getLimit() > 0 && 
limit.getOffset() == 0
-                                ? limit.child() : new 
LogicalEmptyRelation(StatementScopeIdGenerator.newRelationId(),
-                                limit.child().getOutput()))
-                        .toRule(RuleType.ELIMINATE_LIMIT_ON_ONE_ROW_RELATION),
-                logicalLimit(logicalEmptyRelation())
-                        .then(UnaryNode::child)
-                        .toRule(RuleType.ELIMINATE_LIMIT_ON_EMPTY_RELATION)
+                        }).toRule(RuleType.LIMIT_SORT_TO_TOP_N)
         );
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownFilterThroughProject.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownFilterThroughProject.java
index 386f4a01198..a0d64b1a609 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownFilterThroughProject.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownFilterThroughProject.java
@@ -41,27 +41,26 @@ public class PushdownFilterThroughProject implements 
RewriteRuleFactory {
     @Override
     public List<Rule> buildRules() {
         return ImmutableList.of(
-            
RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT.build(logicalFilter(logicalProject())
-                    .whenNot(filter -> 
filter.child().getProjects().stream().anyMatch(
-                            expr -> 
expr.anyMatch(WindowExpression.class::isInstance)))
-                    
.then(PushdownFilterThroughProject::pushdownFilterThroughProject)),
-            // filter(project(limit)) will change to filter(limit(project)) by 
PushdownProjectThroughLimit,
-            // then we should change filter(limit(project)) to 
project(filter(limit))
-            RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT_UNDER_LIMIT
-                    .build(logicalFilter(logicalLimit(logicalProject()))
-                            .whenNot(filter -> 
filter.child().child().getProjects().stream()
-                                    .anyMatch(expr -> expr
-                                            
.anyMatch(WindowExpression.class::isInstance)))
-                            .then(filter -> {
-                                LogicalLimit<LogicalProject<Plan>> limit = 
filter.child();
-                                LogicalProject<Plan> project = limit.child();
+                logicalFilter(logicalProject())
+                        .whenNot(filter -> 
filter.child().getProjects().stream().anyMatch(
+                                expr -> 
expr.anyMatch(WindowExpression.class::isInstance)))
+                        
.then(PushdownFilterThroughProject::pushdownFilterThroughProject)
+                        .toRule(RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT),
+                // filter(project(limit)) will change to 
filter(limit(project)) by PushdownProjectThroughLimit,
+                // then we should change filter(limit(project)) to 
project(filter(limit))
+                logicalFilter(logicalLimit(logicalProject()))
+                        .whenNot(filter -> 
filter.child().child().getProjects().stream()
+                                .anyMatch(expr -> 
expr.anyMatch(WindowExpression.class::isInstance)))
+                        .then(filter -> {
+                            LogicalLimit<LogicalProject<Plan>> limit = 
filter.child();
+                            LogicalProject<Plan> project = limit.child();
 
-                                return 
project.withProjectsAndChild(project.getProjects(),
-                                        new LogicalFilter<>(
-                                                
ExpressionUtils.replace(filter.getConjuncts(),
-                                                        
project.getAliasToProducer()),
-                                                
limit.withChildren(project.child())));
-                            }))
+                            return 
project.withProjectsAndChild(project.getProjects(),
+                                    new LogicalFilter<>(
+                                            
ExpressionUtils.replace(filter.getConjuncts(),
+                                                    
project.getAliasToProducer()),
+                                            
limit.withChildren(project.child())));
+                        
}).toRule(RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT_UNDER_LIMIT)
         );
     }
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownProjectThroughLimit.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownProjectThroughLimit.java
index 652d0309106..8c4d2a93c56 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownProjectThroughLimit.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownProjectThroughLimit.java
@@ -24,6 +24,7 @@ import 
org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 
 /**
+ * <pre>
  * Before:
  *          project
  *             │
@@ -42,6 +43,7 @@ import 
org.apache.doris.nereids.trees.plans.logical.LogicalProject;
  *             │
  *             ▼
  *          plan node
+ * </pre>
  */
 public class PushdownProjectThroughLimit extends OneRewriteRuleFactory {
 
@@ -50,9 +52,7 @@ public class PushdownProjectThroughLimit extends 
OneRewriteRuleFactory {
         return logicalProject(logicalLimit()).thenApply(ctx -> {
             LogicalProject<LogicalLimit<Plan>> logicalProject = ctx.root;
             LogicalLimit<Plan> logicalLimit = logicalProject.child();
-            return new LogicalLimit<>(logicalLimit.getLimit(), 
logicalLimit.getOffset(),
-                    logicalLimit.getPhase(), 
logicalProject.withProjectsAndChild(logicalProject.getProjects(),
-                    logicalLimit.child()));
+            return 
logicalLimit.withChildren(logicalProject.withChildren(logicalLimit.child()));
         }).toRule(RuleType.PUSHDOWN_PROJECT_THROUGH_LIMIT);
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughJoin.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughJoin.java
new file mode 100644
index 00000000000..ac179864393
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughJoin.java
@@ -0,0 +1,118 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.properties.OrderKey;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Push down TopN through Join.
+ * <p>
+ * Handles {@code topN -> join} and {@code topN -> project -> join}. Only left outer, right outer and cross
+ * joins are rewritten, and only when every order key is a plain slot:
+ * <ul>
+ *   <li>left outer join: no order key may use a slot of the right child; the TopN goes to the left child</li>
+ *   <li>right outer join: no order key may use a slot of the left child; the TopN goes to the right child</li>
+ *   <li>cross join: all order-key slots must come from one child; the TopN goes to that child</li>
+ * </ul>
+ * The original TopN stays above the join. The TopN added below the join uses
+ * {@code original limit + original offset} as its limit and {@code 0} as its offset, so rows that the upper
+ * TopN still has to skip are not dropped early.
+ */
+public class PushdownTopNThroughJoin implements RewriteRuleFactory {
+
+    @Override
+    public List<Rule> buildRules() {
+        return ImmutableList.of(
+                // topN -> join
+                logicalTopN(logicalJoin())
+                        // TODO: complex orderby
+                        .when(topN -> topN.getOrderKeys().stream().map(OrderKey::getExpr)
+                                .allMatch(Slot.class::isInstance))
+                        .then(topN -> {
+                            LogicalJoin<Plan, Plan> join = topN.child();
+                            Plan newJoin = pushLimitThroughJoin(topN, join);
+                            // Skip when nothing can be pushed or the new join children equal the current ones.
+                            if (newJoin == null || topN.child().children().equals(newJoin.children())) {
+                                return null;
+                            }
+                            return topN.withChildren(newJoin);
+                        })
+                        .toRule(RuleType.PUSH_TOP_N_THROUGH_JOIN),
+
+                // topN -> project -> join
+                logicalTopN(logicalProject(logicalJoin()))
+                        .when(topN -> topN.getOrderKeys().stream().map(OrderKey::getExpr)
+                                .allMatch(Slot.class::isInstance))
+                        .then(topN -> {
+                            LogicalProject<LogicalJoin<Plan, Plan>> project = topN.child();
+                            LogicalJoin<Plan, Plan> join = project.child();
+
+                            // Every order-by slot must already be an output slot of the join. A slot that is
+                            // only produced by the project can't be pushed down, e.g.
+                            // topN(order by: slot(a+1))
+                            // - project(a+1, b)
+                            // TODO: in the future, we also can push down it.
+                            Set<Slot> outputSet = project.child().getOutputSet();
+                            if (!topN.getOrderKeys().stream().map(OrderKey::getExpr)
+                                    .flatMap(e -> e.getInputSlots().stream())
+                                    .allMatch(outputSet::contains)) {
+                                return null;
+                            }
+
+                            Plan newJoin = pushLimitThroughJoin(topN, join);
+                            // Skip when nothing can be pushed or the new join children equal the current ones.
+                            if (newJoin == null || join.children().equals(newJoin.children())) {
+                                return null;
+                            }
+                            return topN.withChildren(project.withChildren(newJoin));
+                        }).toRule(RuleType.PUSH_TOP_N_THROUGH_PROJECT_JOIN)
+        );
+    }
+
+    /**
+     * Try to add a TopN below {@code join}.
+     *
+     * @param topN the TopN above the join; supplies the order keys, limit and offset
+     * @param join the join to push through
+     * @return the join with a new TopN on one of its children, or {@code null} when the join type is not
+     *         supported or the order keys do not allow the push down
+     */
+    private Plan pushLimitThroughJoin(LogicalTopN<? extends Plan> topN, LogicalJoin<Plan, Plan> join) {
+        switch (join.getJoinType()) {
+            case LEFT_OUTER_JOIN:
+                Set<Slot> rightOutputSet = join.right().getOutputSet();
+                if (topN.getOrderKeys().stream().map(OrderKey::getExpr)
+                        .anyMatch(e -> Utils.isIntersecting(rightOutputSet, e.getInputSlots()))) {
+                    return null;
+                }
+                return join.withChildren(pushedTopN(topN, join.left()), join.right());
+            case RIGHT_OUTER_JOIN:
+                Set<Slot> leftOutputSet = join.left().getOutputSet();
+                if (topN.getOrderKeys().stream().map(OrderKey::getExpr)
+                        .anyMatch(e -> Utils.isIntersecting(leftOutputSet, e.getInputSlots()))) {
+                    return null;
+                }
+                return join.withChildren(join.left(), pushedTopN(topN, join.right()));
+            case CROSS_JOIN:
+                List<Slot> orderbySlots = topN.getOrderKeys().stream().map(OrderKey::getExpr)
+                        .flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toList());
+                if (join.left().getOutputSet().containsAll(orderbySlots)) {
+                    return join.withChildren(pushedTopN(topN, join.left()), join.right());
+                } else if (join.right().getOutputSet().containsAll(orderbySlots)) {
+                    return join.withChildren(join.left(), pushedTopN(topN, join.right()));
+                } else {
+                    return null;
+                }
+            default:
+                // don't push limit.
+                return null;
+        }
+    }
+
+    /**
+     * Build the TopN that is placed below the join on top of {@code child}.
+     * <p>
+     * The upper TopN is kept and still applies the offset, so the lower TopN must keep the first
+     * {@code limit + offset} rows and must not skip any row itself.
+     */
+    private LogicalTopN<Plan> pushedTopN(LogicalTopN<? extends Plan> topN, Plan child) {
+        return new LogicalTopN<>(topN.getOrderKeys(), topN.getLimit() + topN.getOffset(), 0, child);
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughWindow.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughWindow.java
index 755b71199cc..f1547d91089 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughWindow.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughWindow.java
@@ -59,7 +59,7 @@ public class PushdownTopNThroughWindow implements 
RewriteRuleFactory {
                     return topn;
                 }
                 return topn.withChildren(newWindow.get());
-            }).toRule(RuleType.PUSHDOWN_TOP_N_THROUGH_WINDOW),
+            }).toRule(RuleType.PUSH_TOP_N_THROUGH_WINDOW),
 
             // topn -> projection -> window
             logicalTopN(logicalProject(logicalWindow())).then(topn -> {
@@ -79,7 +79,7 @@ public class PushdownTopNThroughWindow implements 
RewriteRuleFactory {
                     return topn;
                 }
                 return 
topn.withChildren(project.withChildren(newWindow.get()));
-            }).toRule(RuleType.PUSHDOWN_TOP_N_THROUGH_PROJECTION_WINDOW)
+            }).toRule(RuleType.PUSH_TOP_N_THROUGH_PROJECT_WINDOW)
         );
     }
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index 8ddbd97d62b..1e67808c614 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -322,10 +322,6 @@ public class ExpressionUtils {
         return builder.build();
     }
 
-    public static boolean isAllLiteral(Expression... children) {
-        return Arrays.stream(children).allMatch(c -> c instanceof Literal);
-    }
-
     public static boolean isAllLiteral(List<Expression> children) {
         return children.stream().allMatch(c -> c instanceof Literal);
     }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownLimitTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownLimitTest.java
index f85882791e0..28e1a7fa468 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownLimitTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownLimitTest.java
@@ -65,7 +65,7 @@ import java.util.stream.Collectors;
 
 class PushdownLimitTest extends TestWithFeService implements 
MemoPatternMatchSupported {
     private final LogicalOlapScan scanScore = new 
LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), 
PlanConstructor.score);
-    private Plan scanStudent = new 
LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), 
PlanConstructor.student);
+    private final LogicalOlapScan scanStudent = new 
LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), 
PlanConstructor.student);
 
     @Override
     protected void runBeforeAll() throws Exception {
@@ -114,7 +114,7 @@ class PushdownLimitTest extends TestWithFeService 
implements MemoPatternMatchSup
     }
 
     @Test
-    public void testPushLimitThroughLeftJoin() {
+    void testPushLimitThroughLeftJoin() {
         test(JoinType.LEFT_OUTER_JOIN, true,
                 logicalLimit(
                         logicalProject(
@@ -136,7 +136,7 @@ class PushdownLimitTest extends TestWithFeService 
implements MemoPatternMatchSup
     }
 
     @Test
-    public void testPushLimitThroughRightJoin() {
+    void testPushLimitThroughRightJoin() {
         // after use RelationUtil to allocate relation id, the id will 
increase when getNextId() called.
         test(JoinType.RIGHT_OUTER_JOIN, true,
                 logicalLimit(
@@ -159,7 +159,7 @@ class PushdownLimitTest extends TestWithFeService 
implements MemoPatternMatchSup
     }
 
     @Test
-    public void testPushLimitThroughCrossJoin() {
+    void testPushLimitThroughCrossJoin() {
         test(JoinType.CROSS_JOIN, true,
                 logicalLimit(
                         logicalProject(
@@ -181,7 +181,7 @@ class PushdownLimitTest extends TestWithFeService 
implements MemoPatternMatchSup
     }
 
     @Test
-    public void testPushLimitThroughInnerJoin() {
+    void testPushLimitThroughInnerJoin() {
         test(JoinType.INNER_JOIN, true,
                 logicalLimit(
                         logicalProject(
@@ -203,7 +203,7 @@ class PushdownLimitTest extends TestWithFeService 
implements MemoPatternMatchSup
     }
 
     @Test
-    public void testTranslate() {
+    void testTranslate() {
         PlanChecker.from(connectContext).checkPlannerResult("select * from t1 
left join t2 on t1.k1=t2.k1 limit 5",
                 planner -> {
                     List<PlanFragment> fragments = planner.getFragments();
@@ -227,7 +227,7 @@ class PushdownLimitTest extends TestWithFeService 
implements MemoPatternMatchSup
     }
 
     @Test
-    public void testLimitPushSort() {
+    void testLimitPushSort() {
         PlanChecker.from(connectContext)
                 .analyze("select k1 from t1 order by k1 limit 1")
                 .rewrite()
@@ -235,7 +235,7 @@ class PushdownLimitTest extends TestWithFeService 
implements MemoPatternMatchSup
     }
 
     @Test
-    public void testLimitPushUnion() {
+    void testLimitPushUnion() {
         PlanChecker.from(connectContext)
                 .analyze("select k1 from t1 "
                         + "union all select k2 from t2 "
@@ -262,7 +262,7 @@ class PushdownLimitTest extends TestWithFeService 
implements MemoPatternMatchSup
     }
 
     @Test
-    public void testLimitPushWindow() {
+    void testLimitPushWindow() {
         ConnectContext context = MemoTestUtils.createConnectContext();
         context.getSessionVariable().setEnablePartitionTopN(true);
         NamedExpression grade = scanScore.getOutput().get(2).toSlot();
@@ -304,7 +304,7 @@ class PushdownLimitTest extends TestWithFeService 
implements MemoPatternMatchSup
     }
 
     @Test
-    public void testTopNPushWindow() {
+    void testTopNPushWindow() {
         ConnectContext context = MemoTestUtils.createConnectContext();
         context.getSessionVariable().setEnablePartitionTopN(true);
         NamedExpression grade = scanScore.getOutput().get(2).toSlot();
@@ -322,7 +322,7 @@ class PushdownLimitTest extends TestWithFeService 
implements MemoPatternMatchSup
         List<OrderKey> orderKey = ImmutableList.of(
                 new OrderKey(windowAlias1.toSlot(), true, true)
         );
-        LogicalSort<LogicalWindow> sort = new LogicalSort<>(orderKey, window);
+        LogicalSort<Plan> sort = new LogicalSort<>(orderKey, window);
 
         LogicalPlan plan = new LogicalPlanBuilder(sort)
                 .limit(100)
@@ -364,8 +364,8 @@ class PushdownLimitTest extends TestWithFeService 
implements MemoPatternMatchSup
         LogicalJoin<? extends Plan, ? extends Plan> join = new LogicalJoin<>(
                 joinType,
                 joinConditions,
-                new LogicalOlapScan(((LogicalOlapScan) 
scanScore).getRelationId(), PlanConstructor.score),
-                new LogicalOlapScan(((LogicalOlapScan) 
scanStudent).getRelationId(), PlanConstructor.student)
+                new LogicalOlapScan(scanScore.getRelationId(), 
PlanConstructor.score),
+                new LogicalOlapScan(scanStudent.getRelationId(), 
PlanConstructor.student)
         );
 
         if (hasProject) {
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java
index cbe1d7eb980..25d842418e5 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java
@@ -39,6 +39,7 @@ import 
org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
 import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
+import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
@@ -143,6 +144,14 @@ public class LogicalPlanBuilder {
         return limit(limit, 0);
     }
 
+    /**
+     * Put a LogicalTopN on top of the current plan.
+     *
+     * @param limit limit of the TopN
+     * @param offset offset of the TopN
+     * @param orderKeySlotsIndex positions in the current plan's output list of the slots to order by
+     * @return a builder wrapping the new TopN plan
+     */
+    public LogicalPlanBuilder topN(long limit, long offset, List<Integer> orderKeySlotsIndex) {
+        List<OrderKey> sortKeys = orderKeySlotsIndex.stream()
+                .map(index -> this.plan.getOutput().get(index))
+                .map(slot -> new OrderKey(slot, false, false))
+                .collect(Collectors.toList());
+        LogicalTopN<Plan> result = new LogicalTopN<>(sortKeys, limit, offset, this.plan);
+        return from(result);
+    }
+    }
+
     public LogicalPlanBuilder filter(Expression conjunct) {
         return 
filter(ImmutableSet.copyOf(ExpressionUtils.extractConjunction(conjunct)));
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to