This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 35c2a529fac [fix](Nereids): when predicate contains right output,
don't convert outer to anti join (#30276)
35c2a529fac is described below
commit 35c2a529fac1b3aa805cb7fd365565c86f327844
Author: 谢健 <[email protected]>
AuthorDate: Thu Jan 25 14:01:27 2024 +0800
[fix](Nereids): when predicate contains right output, don't convert outer
to anti join (#30276)
---
.../rules/rewrite/ConvertOuterJoinToAntiJoin.java | 6 +-
.../rewrite/ConvertOuterJoinToAntiJoinTest.java | 82 +++++++++++++++++++++-
2 files changed, 84 insertions(+), 4 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java
index ebc4630d78e..74bd7e29142 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java
@@ -81,7 +81,8 @@ public class ConvertOuterJoinToAntiJoin extends OneRewriteRuleFactory {
&&
rightAlwaysNullSlots.containsAll(p.getInputSlots())))
.collect(ImmutableSet.toImmutableSet());
boolean containRightSlot = predicates.stream()
- .anyMatch(s ->
join.right().getOutputSet().containsAll(s.getInputSlots()));
+ .flatMap(p -> p.getInputSlots().stream())
+ .anyMatch(join.right().getOutputSet()::contains);
if (!containRightSlot) {
res = join.withJoinType(JoinType.LEFT_ANTI_JOIN);
res = predicates.isEmpty() ? res :
filter.withConjuncts(predicates).withChildren(res);
@@ -94,7 +95,8 @@ public class ConvertOuterJoinToAntiJoin extends OneRewriteRuleFactory {
&&
leftAlwaysNullSlots.containsAll(p.getInputSlots())))
.collect(ImmutableSet.toImmutableSet());
boolean containLeftSlot = predicates.stream()
- .anyMatch(s ->
join.left().getOutputSet().containsAll(s.getInputSlots()));
+ .flatMap(p -> p.getInputSlots().stream())
+ .anyMatch(join.left().getOutputSet()::contains);
if (!containLeftSlot) {
res = join.withJoinType(JoinType.RIGHT_ANTI_JOIN);
res = predicates.isEmpty() ? res :
filter.withConjuncts(predicates).withChildren(res);
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java
index 49960c77ee4..20b36d3272e 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java
@@ -18,8 +18,12 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.trees.expressions.And;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.IsNull;
+import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
@@ -30,6 +34,7 @@ import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Sets;
import org.junit.jupiter.api.Test;
class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
@@ -48,7 +53,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
.filter(new IsNull(scan2.getOutput().get(0)))
- .project(ImmutableList.of(0, 1))
+ .projectExprs(ImmutableList.copyOf(scan1.getOutput()))
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
@@ -63,7 +68,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
.filter(new IsNull(scan1.getOutput().get(0)))
- .project(ImmutableList.of(2, 3))
+ .projectExprs(ImmutableList.copyOf(scan2.getOutput()))
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
@@ -72,4 +77,77 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
.printlnTree()
.matches(logicalJoin().when(join ->
join.getJoinType().isRightAntiJoin()));
}
+
+ @Test
+ void testEliminateLeftWithLeftPredicate() {
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
+ .filter(Sets.newHashSet(
+ new IsNull(scan2.getOutput().get(0)),
+ new EqualTo(scan1.getOutput().get(0), new
IntegerLiteral(1)))
+ )
+ .projectExprs(ImmutableList.copyOf(scan1.getOutput()))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new InferFilterNotNull())
+ .applyTopDown(new ConvertOuterJoinToAntiJoin())
+ .printlnTree()
+ .matches(logicalJoin().when(join ->
join.getJoinType().isLeftAntiJoin()));
+ }
+
+ @Test
+ void testEliminateLeftWithRightPredicate() {
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
+ .filter(Sets.newHashSet(
+ new IsNull(scan2.getOutput().get(0)),
+ new EqualTo(scan2.getOutput().get(0), new
IntegerLiteral(1)))
+ )
+ .projectExprs(ImmutableList.copyOf(scan1.getOutput()))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new InferFilterNotNull())
+ .applyTopDown(new ConvertOuterJoinToAntiJoin())
+ .printlnTree()
+ .matches(logicalJoin().when(join ->
join.getJoinType().isLeftOuterJoin()));
+ }
+
+ @Test
+ void testEliminateLeftWithOrPredicate() {
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
+ .filter(Sets.newHashSet(
+ new IsNull(scan1.getOutput().get(0)),
+ new Or(new IsNull(scan1.getOutput().get(0)), new
IsNull(scan2.getOutput().get(0))))
+ )
+ .projectExprs(ImmutableList.copyOf(scan1.getOutput()))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new InferFilterNotNull())
+ .applyTopDown(new ConvertOuterJoinToAntiJoin())
+ .printlnTree()
+ .matches(logicalJoin().when(join ->
join.getJoinType().isLeftOuterJoin()));
+ }
+
+ @Test
+ void testEliminateLeftWithAndPredicate() {
+ LogicalPlan plan = new LogicalPlanBuilder(scan1)
+ .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
+ .filter(Sets.newHashSet(
+ new IsNull(scan1.getOutput().get(0)),
+ new And(new EqualTo(scan1.getOutput().get(0), new
IntegerLiteral(1)),
+ new EqualTo(scan1.getOutput().get(0), new
IntegerLiteral(1))))
+ )
+ .project(ImmutableList.of(2, 3))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new InferFilterNotNull())
+ .applyTopDown(new ConvertOuterJoinToAntiJoin())
+ .printlnTree()
+ .matches(logicalJoin().when(join ->
join.getJoinType().isLeftOuterJoin()));
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]