This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new be3a7e69cd [refactor](Nereids): polish code SemiJoinLogicalJoinTranspose. (#17740)
be3a7e69cd is described below
commit be3a7e69cd9b74f1a79569348a99c93101fa04e0
Author: jakevin <[email protected]>
AuthorDate: Tue Mar 14 12:48:58 2023 +0800
[refactor](Nereids): polish code SemiJoinLogicalJoinTranspose. (#17740)
---
.../org/apache/doris/nereids/rules/RuleType.java | 8 ++-
.../rules/exploration/join/JoinCommute.java | 2 +-
.../join/SemiJoinLogicalJoinTranspose.java | 1 -
.../join/SemiJoinLogicalJoinTransposeProject.java | 41 +++---------
.../join/SemiJoinSemiJoinTranspose.java | 2 +-
.../SemiJoinLogicalJoinTransposeProjectTest.java | 73 +++++++++-------------
6 files changed, 43 insertions(+), 84 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 4f8d96a0d9..9b696f8235 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -216,7 +216,7 @@ public enum RuleType {
// exploration rules
TEST_EXPLORATION(RuleTypeClass.EXPLORATION),
- LOGICAL_JOIN_COMMUTATE(RuleTypeClass.EXPLORATION),
+ LOGICAL_JOIN_COMMUTE(RuleTypeClass.EXPLORATION),
LOGICAL_INNER_JOIN_LASSCOM(RuleTypeClass.EXPLORATION),
LOGICAL_INNER_JOIN_LASSCOM_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_OUTER_JOIN_LASSCOM(RuleTypeClass.EXPLORATION),
@@ -225,7 +225,10 @@ public enum RuleType {
LOGICAL_OUTER_JOIN_ASSOC_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION),
LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
- LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE(RuleTypeClass.EXPLORATION),
+ LOGICAL_JOIN_LOGICAL_SEMI_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION),
+
LOGICAL_JOIN_LOGICAL_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
+ LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION),
+ LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_EXCHANGE(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_EXCHANGE_LEFT_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_EXCHANGE_RIGHT_PROJECT(RuleTypeClass.EXPLORATION),
@@ -234,7 +237,6 @@ public enum RuleType {
LOGICAL_INNER_JOIN_LEFT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE(RuleTypeClass.EXPLORATION),
LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION),
- LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN(RuleTypeClass.EXPLORATION),
PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN(RuleTypeClass.EXPLORATION),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java
index 367b00530e..11f9955a76 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java
@@ -67,7 +67,7 @@ public class JoinCommute extends OneExplorationRuleFactory {
}
return newJoin;
- }).toRule(RuleType.LOGICAL_JOIN_COMMUTATE);
+ }).toRule(RuleType.LOGICAL_JOIN_COMMUTE);
}
enum SwapType {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java
index 0b57ed1b70..73f202cd26 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java
@@ -58,7 +58,6 @@ public class SemiJoinLogicalJoinTranspose extends
OneExplorationRuleFactory {
&& (topJoin.left().getJoinType().isInnerJoin()
||
topJoin.left().getJoinType().isLeftOuterJoin()
||
topJoin.left().getJoinType().isRightOuterJoin())))
- .whenNot(topJoin ->
topJoin.left().getJoinType().isSemiOrAntiJoin())
.when(this::conditionChecker)
.whenNot(topJoin -> topJoin.hasJoinHint() ||
topJoin.left().hasJoinHint())
.whenNot(LogicalJoin::isMarkJoin)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java
index afa88aad98..a645ca3062 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java
@@ -17,14 +17,10 @@
package org.apache.doris.nereids.rules.exploration.join;
-import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
-import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
@@ -34,9 +30,6 @@ import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -65,7 +58,6 @@ public class SemiJoinLogicalJoinTransposeProject extends
OneExplorationRuleFacto
&& (topJoin.left().child().getJoinType().isInnerJoin()
||
topJoin.left().child().getJoinType().isLeftOuterJoin()
||
topJoin.left().child().getJoinType().isRightOuterJoin())))
- .whenNot(topJoin ->
topJoin.left().child().getJoinType().isSemiOrAntiJoin())
.whenNot(join -> join.hasJoinHint() ||
join.left().child().hasJoinHint())
.whenNot(join -> join.isMarkJoin() ||
join.left().child().isMarkJoin())
.when(join -> JoinReorderUtils.isAllSlotProject(join.left()))
@@ -76,9 +68,8 @@ public class SemiJoinLogicalJoinTransposeProject extends
OneExplorationRuleFacto
GroupPlan b = bottomJoin.right();
GroupPlan c = topSemiJoin.right();
- // push topSemiJoin down project, so we need replace
conjuncts by project.
- Pair<List<Expression>, List<Expression>> conjuncts =
replaceConjuncts(topSemiJoin, project);
- Set<ExprId> conjunctsIds =
Stream.concat(conjuncts.first.stream(), conjuncts.second.stream())
+ Set<ExprId> conjunctsIds =
Stream.concat(topSemiJoin.getHashJoinConjuncts().stream(),
+
topSemiJoin.getOtherJoinConjuncts().stream())
.flatMap(expr ->
expr.getInputSlotExprIds().stream()).collect(Collectors.toSet());
ContainsType containsType = containsChildren(conjunctsIds,
a.getOutputExprIdSet(),
b.getOutputExprIdSet());
@@ -99,8 +90,7 @@ public class SemiJoinLogicalJoinTransposeProject extends
OneExplorationRuleFacto
// RIGHT_OUTER_JOIN should be eliminated in rewrite
phase
Preconditions.checkState(bottomJoin.getJoinType() !=
JoinType.RIGHT_OUTER_JOIN);
- Plan newBottomSemiJoin =
topSemiJoin.withConjunctsChildren(conjuncts.first, conjuncts.second,
- a, c);
+ Plan newBottomSemiJoin = topSemiJoin.withChildren(a,
c);
Plan newTopJoin =
bottomJoin.withChildren(newBottomSemiJoin, b);
return project.withChildren(newTopJoin);
} else {
@@ -112,37 +102,20 @@ public class SemiJoinLogicalJoinTransposeProject extends
OneExplorationRuleFacto
* / \ |
* project C newTopJoin
* | / \
- * bottomJoin C --> A newBottomSemiJoin
- * / \ / \
- * A B B C
+ * bottomJoin C --> A newBottomSemiJoin
+ * / \ / \
+ * A B B C
*/
// LEFT_OUTER_JOIN should be eliminated in rewrite
phase
Preconditions.checkState(bottomJoin.getJoinType() !=
JoinType.LEFT_OUTER_JOIN);
- Plan newBottomSemiJoin =
topSemiJoin.withConjunctsChildren(conjuncts.first, conjuncts.second,
- b, c);
+ Plan newBottomSemiJoin = topSemiJoin.withChildren(b,
c);
Plan newTopJoin = bottomJoin.withChildren(a,
newBottomSemiJoin);
return project.withChildren(newTopJoin);
}
}).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT);
}
- private Pair<List<Expression>, List<Expression>>
replaceConjuncts(LogicalJoin<? extends Plan, ? extends Plan> join,
- LogicalProject<? extends Plan> project) {
- Map<ExprId, Slot> outputToInput = new HashMap<>();
- for (NamedExpression outputExpr : project.getProjects()) {
- Set<Slot> usedSlots = outputExpr.getInputSlots();
- Preconditions.checkState(usedSlots.size() == 1);
- Slot inputSlot = usedSlots.iterator().next();
- outputToInput.put(outputExpr.getExprId(), inputSlot);
- }
- List<Expression> topHashConjuncts =
-
JoinReorderUtils.replaceJoinConjuncts(join.getHashJoinConjuncts(),
outputToInput);
- List<Expression> topOtherConjuncts =
-
JoinReorderUtils.replaceJoinConjuncts(join.getOtherJoinConjuncts(),
outputToInput);
- return Pair.of(topHashConjuncts, topOtherConjuncts);
- }
-
enum ContainsType {
LEFT, RIGHT, ALL
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java
index 6977dd9d62..67c97e5788 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java
@@ -73,7 +73,7 @@ public class SemiJoinSemiJoinTranspose extends
OneExplorationRuleFactory {
Plan newBottomJoin = topJoin.withChildren(a, c);
Plan newTopJoin = bottomJoin.withChildren(newBottomJoin,
b);
return newTopJoin;
- }).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE);
+ }).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE);
}
private boolean typeChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>,
GroupPlan> topJoin) {
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProjectTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProjectTest.java
index be00e49dc1..7b42fe4e5d 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProjectTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProjectTest.java
@@ -20,11 +20,10 @@ package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.trees.plans.JoinType;
-import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
@@ -33,13 +32,13 @@ import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
-public class SemiJoinLogicalJoinTransposeProjectTest {
+class SemiJoinLogicalJoinTransposeProjectTest implements
MemoPatternMatchSupported {
private static final LogicalOlapScan scan1 =
PlanConstructor.newLogicalOlapScan(0, "t1", 0);
private static final LogicalOlapScan scan2 =
PlanConstructor.newLogicalOlapScan(1, "t2", 0);
private static final LogicalOlapScan scan3 =
PlanConstructor.newLogicalOlapScan(2, "t3", 0);
@Test
- public void testSemiJoinLogicalTransposeProjectLAsscom() {
+ void testSemiJoinLogicalTransposeProjectLAsscom() {
/*-
* topSemiJoin project
* / \ |
@@ -57,28 +56,21 @@ public class SemiJoinLogicalJoinTransposeProjectTest {
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
.applyExploration(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP.build())
- .checkMemo(memo -> {
- Group root = memo.getRoot();
- Assertions.assertEquals(2,
root.getLogicalExpressions().size());
- Plan plan =
memo.copyOut(root.getLogicalExpressions().get(1), false);
-
- LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>)
plan.child(0);
- LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>)
newTopJoin.left();
- Assertions.assertEquals(JoinType.INNER_JOIN,
newTopJoin.getJoinType());
- Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN,
newBottomJoin.getJoinType());
-
- LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan)
newBottomJoin.left();
- LogicalOlapScan newBottomJoinRight = (LogicalOlapScan)
newBottomJoin.right();
- LogicalOlapScan newTopJoinRight = (LogicalOlapScan)
newTopJoin.right();
-
- Assertions.assertEquals("t1",
newBottomJoinLeft.getTable().getName());
- Assertions.assertEquals("t3",
newBottomJoinRight.getTable().getName());
- Assertions.assertEquals("t2",
newTopJoinRight.getTable().getName());
- });
+ .matchesExploration(
+ logicalProject(
+ innerLogicalJoin(
+ leftSemiLogicalJoin(
+ logicalOlapScan().when(scan ->
scan.getTable().getName().equals("t1")),
+ logicalOlapScan().when(scan ->
scan.getTable().getName().equals("t3"))
+ ),
+ logicalOlapScan().when(scan ->
scan.getTable().getName().equals("t2"))
+ )
+ )
+ );
}
@Test
- public void testSemiJoinLogicalTransposeProjectLAsscomFail() {
+ void testSemiJoinLogicalTransposeProjectLAsscomFail() {
LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id =
t2.id
.project(ImmutableList.of(0, 2)) // t1.id, t2.id
@@ -94,15 +86,15 @@ public class SemiJoinLogicalJoinTransposeProjectTest {
}
@Test
- public void testSemiJoinLogicalTransposeProjectAll() {
+ void testSemiJoinLogicalTransposeProjectAll() {
/*-
* topSemiJoin project
* / \ |
* project C newTopJoin
* | / \
* bottomJoin C --> A newBottomSemiJoin
- * / \ / \
- * A B B C
+ * / \ / \
+ * A B B C
*/
LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id =
t2.id
@@ -112,23 +104,16 @@ public class SemiJoinLogicalJoinTransposeProjectTest {
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
.applyExploration(SemiJoinLogicalJoinTransposeProject.ALL.build())
- .checkMemo(memo -> {
- Group root = memo.getRoot();
- Assertions.assertEquals(2,
root.getLogicalExpressions().size());
- Plan plan =
memo.copyOut(root.getLogicalExpressions().get(1), false);
-
- LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>)
plan.child(0);
- LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>)
newTopJoin.right();
- Assertions.assertEquals(JoinType.INNER_JOIN,
newTopJoin.getJoinType());
- Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN,
newBottomJoin.getJoinType());
-
- LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan)
newBottomJoin.left();
- LogicalOlapScan newBottomJoinRight = (LogicalOlapScan)
newBottomJoin.right();
- LogicalOlapScan newTopJoinLeft = (LogicalOlapScan)
newTopJoin.left();
-
- Assertions.assertEquals("t1",
newTopJoinLeft.getTable().getName());
- Assertions.assertEquals("t2",
newBottomJoinLeft.getTable().getName());
- Assertions.assertEquals("t3",
newBottomJoinRight.getTable().getName());
- });
+ .matchesExploration(
+ logicalProject(
+ logicalJoin(
+ logicalOlapScan().when(scan ->
scan.getTable().getName().equals("t1")),
+ leftSemiLogicalJoin(
+ logicalOlapScan().when(scan ->
scan.getTable().getName().equals("t2")),
+ logicalOlapScan().when(scan ->
scan.getTable().getName().equals("t3"))
+ )
+ )
+ )
+ );
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]