This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new be3a7e69cd [refactor](Nereids): polish code SemiJoinLogicalJoinTranspose. (#17740)
be3a7e69cd is described below

commit be3a7e69cd9b74f1a79569348a99c93101fa04e0
Author: jakevin <[email protected]>
AuthorDate: Tue Mar 14 12:48:58 2023 +0800

    [refactor](Nereids): polish code SemiJoinLogicalJoinTranspose. (#17740)
---
 .../org/apache/doris/nereids/rules/RuleType.java   |  8 ++-
 .../rules/exploration/join/JoinCommute.java        |  2 +-
 .../join/SemiJoinLogicalJoinTranspose.java         |  1 -
 .../join/SemiJoinLogicalJoinTransposeProject.java  | 41 +++---------
 .../join/SemiJoinSemiJoinTranspose.java            |  2 +-
 .../SemiJoinLogicalJoinTransposeProjectTest.java   | 73 +++++++++-------------
 6 files changed, 43 insertions(+), 84 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 4f8d96a0d9..9b696f8235 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -216,7 +216,7 @@ public enum RuleType {
 
     // exploration rules
     TEST_EXPLORATION(RuleTypeClass.EXPLORATION),
-    LOGICAL_JOIN_COMMUTATE(RuleTypeClass.EXPLORATION),
+    LOGICAL_JOIN_COMMUTE(RuleTypeClass.EXPLORATION),
     LOGICAL_INNER_JOIN_LASSCOM(RuleTypeClass.EXPLORATION),
     LOGICAL_INNER_JOIN_LASSCOM_PROJECT(RuleTypeClass.EXPLORATION),
     LOGICAL_OUTER_JOIN_LASSCOM(RuleTypeClass.EXPLORATION),
@@ -225,7 +225,10 @@ public enum RuleType {
     LOGICAL_OUTER_JOIN_ASSOC_PROJECT(RuleTypeClass.EXPLORATION),
     LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION),
     
LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
-    LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE(RuleTypeClass.EXPLORATION),
+    LOGICAL_JOIN_LOGICAL_SEMI_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION),
+    
LOGICAL_JOIN_LOGICAL_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
+    LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION),
+    LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
     LOGICAL_JOIN_EXCHANGE(RuleTypeClass.EXPLORATION),
     LOGICAL_JOIN_EXCHANGE_LEFT_PROJECT(RuleTypeClass.EXPLORATION),
     LOGICAL_JOIN_EXCHANGE_RIGHT_PROJECT(RuleTypeClass.EXPLORATION),
@@ -234,7 +237,6 @@ public enum RuleType {
     LOGICAL_INNER_JOIN_LEFT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION),
     LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE(RuleTypeClass.EXPLORATION),
     LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION),
-    LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
     PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN(RuleTypeClass.EXPLORATION),
     PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN(RuleTypeClass.EXPLORATION),
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java
index 367b00530e..11f9955a76 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java
@@ -67,7 +67,7 @@ public class JoinCommute extends OneExplorationRuleFactory {
                     }
 
                     return newJoin;
-                }).toRule(RuleType.LOGICAL_JOIN_COMMUTATE);
+                }).toRule(RuleType.LOGICAL_JOIN_COMMUTE);
     }
 
     enum SwapType {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java
index 0b57ed1b70..73f202cd26 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java
@@ -58,7 +58,6 @@ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory {
                         && (topJoin.left().getJoinType().isInnerJoin()
                                 || 
topJoin.left().getJoinType().isLeftOuterJoin()
                                 || 
topJoin.left().getJoinType().isRightOuterJoin())))
-                .whenNot(topJoin -> 
topJoin.left().getJoinType().isSemiOrAntiJoin())
                 .when(this::conditionChecker)
                 .whenNot(topJoin -> topJoin.hasJoinHint() || 
topJoin.left().hasJoinHint())
                 .whenNot(LogicalJoin::isMarkJoin)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java
index afa88aad98..a645ca3062 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java
@@ -17,14 +17,10 @@
 
 package org.apache.doris.nereids.rules.exploration.join;
 
-import org.apache.doris.common.Pair;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
 import org.apache.doris.nereids.trees.expressions.ExprId;
-import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.plans.GroupPlan;
 import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.trees.plans.Plan;
@@ -34,9 +30,6 @@ import org.apache.doris.nereids.util.Utils;
 
 import com.google.common.base.Preconditions;
 
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -65,7 +58,6 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
                         && (topJoin.left().child().getJoinType().isInnerJoin()
                         || 
topJoin.left().child().getJoinType().isLeftOuterJoin()
                         || 
topJoin.left().child().getJoinType().isRightOuterJoin())))
-                .whenNot(topJoin -> 
topJoin.left().child().getJoinType().isSemiOrAntiJoin())
                 .whenNot(join -> join.hasJoinHint() || 
join.left().child().hasJoinHint())
                 .whenNot(join -> join.isMarkJoin() || 
join.left().child().isMarkJoin())
                 .when(join -> JoinReorderUtils.isAllSlotProject(join.left()))
@@ -76,9 +68,8 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
                     GroupPlan b = bottomJoin.right();
                     GroupPlan c = topSemiJoin.right();
 
-                    // push topSemiJoin down project, so we need replace 
conjuncts by project.
-                    Pair<List<Expression>, List<Expression>> conjuncts = 
replaceConjuncts(topSemiJoin, project);
-                    Set<ExprId> conjunctsIds = 
Stream.concat(conjuncts.first.stream(), conjuncts.second.stream())
+                    Set<ExprId> conjunctsIds = 
Stream.concat(topSemiJoin.getHashJoinConjuncts().stream(),
+                                    
topSemiJoin.getOtherJoinConjuncts().stream())
                             .flatMap(expr -> 
expr.getInputSlotExprIds().stream()).collect(Collectors.toSet());
                     ContainsType containsType = containsChildren(conjunctsIds, 
a.getOutputExprIdSet(),
                             b.getOutputExprIdSet());
@@ -99,8 +90,7 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
                         // RIGHT_OUTER_JOIN should be eliminated in rewrite 
phase
                         Preconditions.checkState(bottomJoin.getJoinType() != 
JoinType.RIGHT_OUTER_JOIN);
 
-                        Plan newBottomSemiJoin = 
topSemiJoin.withConjunctsChildren(conjuncts.first, conjuncts.second,
-                                a, c);
+                        Plan newBottomSemiJoin = topSemiJoin.withChildren(a, 
c);
                         Plan newTopJoin = 
bottomJoin.withChildren(newBottomSemiJoin, b);
                         return project.withChildren(newTopJoin);
                     } else {
@@ -112,37 +102,20 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
                          *       /     \                       |
                          *    project   C                  newTopJoin
                          *       |                        /         \
-                         *  bottomJoin  C     -->       A     newBottomSemiJoin
-                         *   /    \                               /      \
-                         *  A      B                             B       C
+                         *  bottomJoin  C     -->        A    newBottomSemiJoin
+                         *    /    \                              /      \
+                         *   A      B                             B       C
                          */
                         // LEFT_OUTER_JOIN should be eliminated in rewrite 
phase
                         Preconditions.checkState(bottomJoin.getJoinType() != 
JoinType.LEFT_OUTER_JOIN);
 
-                        Plan newBottomSemiJoin = 
topSemiJoin.withConjunctsChildren(conjuncts.first, conjuncts.second,
-                                b, c);
+                        Plan newBottomSemiJoin = topSemiJoin.withChildren(b, 
c);
                         Plan newTopJoin = bottomJoin.withChildren(a, 
newBottomSemiJoin);
                         return project.withChildren(newTopJoin);
                     }
                 
}).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT);
     }
 
-    private Pair<List<Expression>, List<Expression>> 
replaceConjuncts(LogicalJoin<? extends Plan, ? extends Plan> join,
-            LogicalProject<? extends Plan> project) {
-        Map<ExprId, Slot> outputToInput = new HashMap<>();
-        for (NamedExpression outputExpr : project.getProjects()) {
-            Set<Slot> usedSlots = outputExpr.getInputSlots();
-            Preconditions.checkState(usedSlots.size() == 1);
-            Slot inputSlot = usedSlots.iterator().next();
-            outputToInput.put(outputExpr.getExprId(), inputSlot);
-        }
-        List<Expression> topHashConjuncts =
-                
JoinReorderUtils.replaceJoinConjuncts(join.getHashJoinConjuncts(), 
outputToInput);
-        List<Expression> topOtherConjuncts =
-                
JoinReorderUtils.replaceJoinConjuncts(join.getOtherJoinConjuncts(), 
outputToInput);
-        return Pair.of(topHashConjuncts, topOtherConjuncts);
-    }
-
     enum ContainsType {
         LEFT, RIGHT, ALL
     }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java
index 6977dd9d62..67c97e5788 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java
@@ -73,7 +73,7 @@ public class SemiJoinSemiJoinTranspose extends OneExplorationRuleFactory {
                     Plan newBottomJoin = topJoin.withChildren(a, c);
                     Plan newTopJoin = bottomJoin.withChildren(newBottomJoin, 
b);
                     return newTopJoin;
-                }).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE);
+                }).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE);
     }
 
     private boolean typeChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, 
GroupPlan> topJoin) {
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProjectTest.java
index be00e49dc1..7b42fe4e5d 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProjectTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProjectTest.java
@@ -20,11 +20,10 @@ package org.apache.doris.nereids.rules.exploration.join;
 import org.apache.doris.common.Pair;
 import org.apache.doris.nereids.memo.Group;
 import org.apache.doris.nereids.trees.plans.JoinType;
-import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
 import org.apache.doris.nereids.util.MemoTestUtils;
 import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
@@ -33,13 +32,13 @@ import com.google.common.collect.ImmutableList;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
-public class SemiJoinLogicalJoinTransposeProjectTest {
+class SemiJoinLogicalJoinTransposeProjectTest implements 
MemoPatternMatchSupported {
     private static final LogicalOlapScan scan1 = 
PlanConstructor.newLogicalOlapScan(0, "t1", 0);
     private static final LogicalOlapScan scan2 = 
PlanConstructor.newLogicalOlapScan(1, "t2", 0);
     private static final LogicalOlapScan scan3 = 
PlanConstructor.newLogicalOlapScan(2, "t3", 0);
 
     @Test
-    public void testSemiJoinLogicalTransposeProjectLAsscom() {
+    void testSemiJoinLogicalTransposeProjectLAsscom() {
         /*-
          *     topSemiJoin                    project
          *      /     \                         |
@@ -57,28 +56,21 @@ public class SemiJoinLogicalJoinTransposeProjectTest {
 
         PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
                 
.applyExploration(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP.build())
-                .checkMemo(memo -> {
-                    Group root = memo.getRoot();
-                    Assertions.assertEquals(2, 
root.getLogicalExpressions().size());
-                    Plan plan = 
memo.copyOut(root.getLogicalExpressions().get(1), false);
-
-                    LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>) 
plan.child(0);
-                    LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>) 
newTopJoin.left();
-                    Assertions.assertEquals(JoinType.INNER_JOIN, 
newTopJoin.getJoinType());
-                    Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, 
newBottomJoin.getJoinType());
-
-                    LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan) 
newBottomJoin.left();
-                    LogicalOlapScan newBottomJoinRight = (LogicalOlapScan) 
newBottomJoin.right();
-                    LogicalOlapScan newTopJoinRight = (LogicalOlapScan) 
newTopJoin.right();
-
-                    Assertions.assertEquals("t1", 
newBottomJoinLeft.getTable().getName());
-                    Assertions.assertEquals("t3", 
newBottomJoinRight.getTable().getName());
-                    Assertions.assertEquals("t2", 
newTopJoinRight.getTable().getName());
-                });
+                .matchesExploration(
+                        logicalProject(
+                                innerLogicalJoin(
+                                        leftSemiLogicalJoin(
+                                                logicalOlapScan().when(scan -> 
scan.getTable().getName().equals("t1")),
+                                                logicalOlapScan().when(scan -> 
scan.getTable().getName().equals("t3"))
+                                        ),
+                                        logicalOlapScan().when(scan -> 
scan.getTable().getName().equals("t2"))
+                                )
+                        )
+                );
     }
 
     @Test
-    public void testSemiJoinLogicalTransposeProjectLAsscomFail() {
+    void testSemiJoinLogicalTransposeProjectLAsscomFail() {
         LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
                 .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = 
t2.id
                 .project(ImmutableList.of(0, 2)) // t1.id, t2.id
@@ -94,15 +86,15 @@ public class SemiJoinLogicalJoinTransposeProjectTest {
     }
 
     @Test
-    public void testSemiJoinLogicalTransposeProjectAll() {
+    void testSemiJoinLogicalTransposeProjectAll() {
         /*-
          *     topSemiJoin                  project
          *       /     \                       |
          *    project   C                  newTopJoin
          *       |                        /         \
          *  bottomJoin  C     -->       A     newBottomSemiJoin
-         *   /    \                               /      \
-         *  A      B                             B       C
+         *    /    \                              /      \
+         *   A      B                             B       C
          */
         LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
                 .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = 
t2.id
@@ -112,23 +104,16 @@ public class SemiJoinLogicalJoinTransposeProjectTest {
 
         PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
                 
.applyExploration(SemiJoinLogicalJoinTransposeProject.ALL.build())
-                .checkMemo(memo -> {
-                    Group root = memo.getRoot();
-                    Assertions.assertEquals(2, 
root.getLogicalExpressions().size());
-                    Plan plan = 
memo.copyOut(root.getLogicalExpressions().get(1), false);
-
-                    LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>) 
plan.child(0);
-                    LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>) 
newTopJoin.right();
-                    Assertions.assertEquals(JoinType.INNER_JOIN, 
newTopJoin.getJoinType());
-                    Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, 
newBottomJoin.getJoinType());
-
-                    LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan) 
newBottomJoin.left();
-                    LogicalOlapScan newBottomJoinRight = (LogicalOlapScan) 
newBottomJoin.right();
-                    LogicalOlapScan newTopJoinLeft = (LogicalOlapScan) 
newTopJoin.left();
-
-                    Assertions.assertEquals("t1", 
newTopJoinLeft.getTable().getName());
-                    Assertions.assertEquals("t2", 
newBottomJoinLeft.getTable().getName());
-                    Assertions.assertEquals("t3", 
newBottomJoinRight.getTable().getName());
-                });
+                .matchesExploration(
+                        logicalProject(
+                                logicalJoin(
+                                        logicalOlapScan().when(scan -> 
scan.getTable().getName().equals("t1")),
+                                        leftSemiLogicalJoin(
+                                                logicalOlapScan().when(scan -> 
scan.getTable().getName().equals("t2")),
+                                                logicalOlapScan().when(scan -> 
scan.getTable().getName().equals("t3"))
+                                        )
+                                )
+                        )
+                );
     }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to