This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.1 by this push:
new 0deb629d07a [fix](Nereids): clone the producer plan and put
logicalAnchor generated by `Or_Expansion` above `logicalSink` (#34771)
0deb629d07a is described below
commit 0deb629d07ab7148e45fc4ae997907c41d092bca
Author: 谢健 <[email protected]>
AuthorDate: Tue May 14 15:04:21 2024 +0800
[fix](Nereids): clone the producer plan and put logicalAnchor generated by
`Or_Expansion` above `logicalSink` (#34771)
* put the CTE anchor on the root
put the LogicalCTEAnchor on the root of the plan
clone the plan used by the CTE producer
* fix unit test
---
.../doris/nereids/jobs/executor/Rewriter.java | 2 +-
.../doris/nereids/rules/rewrite/OrExpansion.java | 197 ++++++++++++++-------
.../nereids/rules/rewrite/OrExpansionTest.java | 86 +++++++++
3 files changed, 219 insertions(+), 66 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index 422667af6cf..b18f8a67a3a 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -471,7 +471,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
custom(RuleType.REWRITE_CTE_CHILDREN, () -> new
RewriteCteChildren(jobs))
),
topic("or expansion",
- topDown(new OrExpansion())),
+ custom(RuleType.OR_EXPANSION, () ->
OrExpansion.INSTANCE)),
topic("whole plan check",
custom(RuleType.ADJUST_NULLABLE, AdjustNullable::new)
)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java
index ff531ffce38..9f9257f5f60 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java
@@ -19,10 +19,12 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
-import org.apache.doris.nereids.rules.Rule;
-import org.apache.doris.nereids.rules.RuleType;
-import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
+import org.apache.doris.nereids.StatementContext;
+import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
+import org.apache.doris.nereids.rules.rewrite.OrExpansion.OrExpandsionContext;
+import org.apache.doris.nereids.trees.copier.DeepCopierContext;
+import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
@@ -38,8 +40,11 @@ import
org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.qe.ConnectContext;
@@ -53,6 +58,7 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Map.Entry;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
@@ -62,7 +68,7 @@ import javax.annotation.Nullable;
* => / \
* HJ(cond1) HJ(cond2 and !cond1)
*/
-public class OrExpansion extends OneExplorationRuleFactory {
+public class OrExpansion extends DefaultPlanRewriter<OrExpandsionContext>
implements CustomRewriter {
public static final OrExpansion INSTANCE = new OrExpansion();
public static final ImmutableSet<JoinType> supportJoinType = new
ImmutableSet
.Builder<JoinType>()
@@ -73,63 +79,101 @@ public class OrExpansion extends OneExplorationRuleFactory
{
.build();
@Override
- public Rule build() {
- return logicalJoin(any(), any()).when(JoinUtils::shouldNestedLoopJoin)
- .whenNot(LogicalJoin::isMarkJoin)
- .when(join -> supportJoinType.contains(join.getJoinType())
- &&
ConnectContext.get().getSessionVariable().getEnablePipelineEngine())
- .thenApply(ctx -> {
- LogicalJoin<? extends Plan, ? extends Plan> join =
ctx.root;
-
Preconditions.checkArgument(join.getHashJoinConjuncts().isEmpty(),
- "Only Expansion nest loop join without hashCond");
+ public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+ OrExpandsionContext ctx = new OrExpandsionContext(
+ jobContext.getCascadesContext().getStatementContext(),
jobContext.getCascadesContext());
+ plan = plan.accept(this, ctx);
+ for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) {
+ LogicalCTEProducer<? extends Plan> producer =
ctx.cteProducerList.get(i);
+ plan = new LogicalCTEAnchor<>(producer.getCteId(), producer, plan);
+ }
+ return plan;
+ }
+
+ @Override
+ public Plan visit(Plan plan, OrExpandsionContext ctx) {
+ List<Plan> newChildren = new ArrayList<>();
+ boolean hasNewChildren = false;
+ for (Plan child : plan.children()) {
+ Plan newChild = child.accept(this, ctx);
+ if (newChild != child) {
+ hasNewChildren = true;
+ }
+ newChildren.add(newChild);
+ }
+ return hasNewChildren ? plan.withChildren(newChildren) : plan;
+ }
- //1. Try to split or conditions
- Pair<List<Expression>, List<Expression>>
hashOtherConditions = splitOrCondition(join);
- if (hashOtherConditions == null ||
hashOtherConditions.first.size() <= 1) {
- return join;
- }
+ @Override
+ public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan>
join, OrExpandsionContext ctx) {
+ join = (LogicalJoin<? extends Plan, ? extends Plan>) this.visit(join,
ctx);
+ if (join.isMarkJoin() || !JoinUtils.shouldNestedLoopJoin(join)) {
+ return join;
+ }
+ if (!(supportJoinType.contains(join.getJoinType())
+ &&
ConnectContext.get().getSessionVariable().getEnablePipelineEngine())) {
+ return join;
+ }
+ Preconditions.checkArgument(join.getHashJoinConjuncts().isEmpty(),
+ "Only Expansion nest loop join without hashCond");
- //2. Construct CTE with the children
- LogicalCTEProducer<? extends Plan> leftProducer = new
LogicalCTEProducer<>(
- ctx.statementContext.getNextCTEId(), join.left());
- LogicalCTEProducer<? extends Plan> rightProducer = new
LogicalCTEProducer<>(
- ctx.statementContext.getNextCTEId(), join.right());
- List<Plan> joins = new ArrayList<>();
+ //1. Try to split or conditions
+ Pair<List<Expression>, List<Expression>> hashOtherConditions =
splitOrCondition(join);
+ if (hashOtherConditions == null || hashOtherConditions.first.size() <=
1) {
+ return join;
+ }
- // 3. Expand join to hash join with CTE
- if (join.getJoinType().isInnerJoin()) {
- joins.addAll(expandInnerJoin(ctx.cascadesContext,
hashOtherConditions,
- join, leftProducer, rightProducer));
- } else if (join.getJoinType().isOuterJoin()) {
- // left outer join = inner join union left anti join
- joins.addAll(expandInnerJoin(ctx.cascadesContext,
hashOtherConditions,
- join, leftProducer, rightProducer));
- joins.add(expandLeftAntiJoin(ctx.cascadesContext,
- hashOtherConditions, join, leftProducer,
rightProducer));
- if
(join.getJoinType().equals(JoinType.FULL_OUTER_JOIN)) {
- // full outer join = inner join union left anti
join union right anti join
- joins.add(expandLeftAntiJoin(ctx.cascadesContext,
- hashOtherConditions, join, rightProducer,
leftProducer));
- }
- } else if
(join.getJoinType().equals(JoinType.LEFT_ANTI_JOIN)) {
- joins.add(expandLeftAntiJoin(ctx.cascadesContext,
- hashOtherConditions, join, leftProducer,
rightProducer));
- } else {
- throw new RuntimeException("or-expansion is not
supported for " + join);
- }
+ //2. Construct CTE with the children
+ LogicalPlan leftClone = LogicalPlanDeepCopier.INSTANCE
+ .deepCopy((LogicalPlan) join.left(), new DeepCopierContext());
+ LogicalCTEProducer<? extends Plan> leftProducer = new
LogicalCTEProducer<>(
+ ctx.statementContext.getNextCTEId(), leftClone);
+ LogicalPlan rightClone = LogicalPlanDeepCopier.INSTANCE
+ .deepCopy((LogicalPlan) join.right(), new DeepCopierContext());
+ LogicalCTEProducer<? extends Plan> rightProducer = new
LogicalCTEProducer<>(
+ ctx.statementContext.getNextCTEId(), rightClone);
+ Map<Slot, Slot> leftCloneToLeft = new HashMap<>();
+ for (int i = 0; i < leftClone.getOutput().size(); i++) {
+ leftCloneToLeft.put(leftClone.getOutput().get(i),
(join.left()).getOutput().get(i));
+ }
+ Map<Slot, Slot> rightCloneToRight = new HashMap<>();
+ for (int i = 0; i < rightClone.getOutput().size(); i++) {
+ rightCloneToRight.put(rightClone.getOutput().get(i),
(join.right()).getOutput().get(i));
+ }
- //4. union all joins and construct LogicalCTEAnchor with
CTEs
- List<List<SlotReference>> childrenOutputs = joins.stream()
- .map(j -> j.getOutput().stream()
- .map(SlotReference.class::cast)
- .collect(ImmutableList.toImmutableList()))
- .collect(ImmutableList.toImmutableList());
- LogicalUnion union = new LogicalUnion(Qualifier.ALL, new
ArrayList<>(join.getOutput()),
- childrenOutputs, ImmutableList.of(), false, joins);
- LogicalCTEAnchor<? extends Plan, ? extends Plan>
intermediateAnchor = new LogicalCTEAnchor<>(
- rightProducer.getCteId(), rightProducer, union);
- return new LogicalCTEAnchor<Plan,
Plan>(leftProducer.getCteId(), leftProducer, intermediateAnchor);
- }).toRule(RuleType.OR_EXPANSION);
+ // 3. Expand join to hash join with CTE
+ List<Plan> joins = new ArrayList<>();
+ if (join.getJoinType().isInnerJoin()) {
+ joins.addAll(expandInnerJoin(ctx.cascadesContext,
hashOtherConditions,
+ join, leftProducer, rightProducer, leftCloneToLeft,
rightCloneToRight));
+ } else if (join.getJoinType().isOuterJoin()) {
+ // left outer join = inner join union left anti join
+ joins.addAll(expandInnerJoin(ctx.cascadesContext,
hashOtherConditions,
+ join, leftProducer, rightProducer, leftCloneToLeft,
rightCloneToRight));
+ joins.add(expandLeftAntiJoin(ctx.cascadesContext,
+ hashOtherConditions, join, leftProducer, rightProducer,
leftCloneToLeft, rightCloneToRight));
+ if (join.getJoinType().equals(JoinType.FULL_OUTER_JOIN)) {
+ // full outer join = inner join union left anti join union
right anti join
+ joins.add(expandLeftAntiJoin(ctx.cascadesContext,
hashOtherConditions,
+ join, rightProducer, leftProducer, rightCloneToRight,
leftCloneToLeft));
+ }
+ } else if (join.getJoinType().equals(JoinType.LEFT_ANTI_JOIN)) {
+ joins.add(expandLeftAntiJoin(ctx.cascadesContext,
hashOtherConditions,
+ join, leftProducer, rightProducer, leftCloneToLeft,
rightCloneToRight));
+ } else {
+ throw new RuntimeException("or-expansion is not supported for " +
join);
+ }
+ //4. union all joins and put producers to context
+ List<List<SlotReference>> childrenOutputs = joins.stream()
+ .map(j -> j.getOutput().stream()
+ .map(SlotReference.class::cast)
+ .collect(ImmutableList.toImmutableList()))
+ .collect(ImmutableList.toImmutableList());
+ LogicalUnion union = new LogicalUnion(Qualifier.ALL, new
ArrayList<>(join.getOutput()),
+ childrenOutputs, ImmutableList.of(), false, joins);
+ ctx.cteProducerList.add(leftProducer);
+ ctx.cteProducerList.add(rightProducer);
+ return union;
}
// try to find a condition that can be split into hash conditions
@@ -150,6 +194,18 @@ public class OrExpansion extends OneExplorationRuleFactory
{
return null;
}
+ private Map<Slot, Slot> constructReplaceMap(LogicalCTEConsumer
leftConsumer, Map<Slot, Slot> leftCloneToLeft,
+ LogicalCTEConsumer rightConsumer, Map<Slot, Slot>
rightCloneToRight) {
+ Map<Slot, Slot> replaced = new HashMap<>();
+ for (Entry<Slot, Slot> entry :
leftConsumer.getProducerToConsumerOutputMap().entrySet()) {
+ replaced.put(leftCloneToLeft.get(entry.getKey()),
entry.getValue());
+ }
+ for (Entry<Slot, Slot> entry :
rightConsumer.getProducerToConsumerOutputMap().entrySet()) {
+ replaced.put(rightCloneToRight.get(entry.getKey()),
entry.getValue());
+ }
+ return replaced;
+ }
+
// expand Anti Join:
// Left Anti join cond1 or cond2, other Left Anti join cond1 and
other
// / \ /
\
@@ -160,7 +216,8 @@ public class OrExpansion extends OneExplorationRuleFactory {
Pair<List<Expression>, List<Expression>> hashOtherConditions,
LogicalJoin<? extends Plan, ? extends Plan> originJoin,
LogicalCTEProducer<? extends Plan> leftProducer,
- LogicalCTEProducer<? extends
org.apache.doris.nereids.trees.plans.Plan> rightProducer) {
+ LogicalCTEProducer<? extends
org.apache.doris.nereids.trees.plans.Plan> rightProducer,
+ Map<Slot, Slot> leftCloneToLeft, Map<Slot, Slot>
rightCloneToRight) {
LogicalCTEConsumer left = new
LogicalCTEConsumer(ctx.getStatementContext().getNextRelationId(),
leftProducer.getCteId(), "", leftProducer);
LogicalCTEConsumer right = new
LogicalCTEConsumer(ctx.getStatementContext().getNextRelationId(),
@@ -168,8 +225,7 @@ public class OrExpansion extends OneExplorationRuleFactory {
ctx.putCTEIdToConsumer(left);
ctx.putCTEIdToConsumer(right);
- Map<Slot, Slot> replaced = new
HashMap<>(left.getProducerToConsumerOutputMap());
- replaced.putAll(right.getProducerToConsumerOutputMap());
+ Map<Slot, Slot> replaced = constructReplaceMap(left, leftCloneToLeft,
right, rightCloneToRight);
List<Expression> disjunctions = hashOtherConditions.first;
List<Expression> otherConditions = hashOtherConditions.second;
List<Expression> newOtherConditions = otherConditions.stream()
@@ -191,8 +247,7 @@ public class OrExpansion extends OneExplorationRuleFactory {
LogicalCTEConsumer newRight = new LogicalCTEConsumer(
ctx.getStatementContext().getNextRelationId(),
rightProducer.getCteId(), "", rightProducer);
ctx.putCTEIdToConsumer(newRight);
- Map<Slot, Slot> newReplaced = new
HashMap<>(left.getProducerToConsumerOutputMap());
- newReplaced.putAll(newRight.getProducerToConsumerOutputMap());
+ Map<Slot, Slot> newReplaced = constructReplaceMap(left,
leftCloneToLeft, newRight, rightCloneToRight);
newOtherConditions = otherConditions.stream()
.map(e -> e.rewriteUp(s -> newReplaced.containsKey(s) ?
newReplaced.get(s) : s))
.collect(Collectors.toList());
@@ -224,7 +279,8 @@ public class OrExpansion extends OneExplorationRuleFactory {
private List<Plan> expandInnerJoin(CascadesContext ctx,
Pair<List<Expression>,
List<Expression>> hashOtherConditions,
LogicalJoin<? extends Plan, ? extends Plan> join,
LogicalCTEProducer<? extends Plan> leftProducer,
- LogicalCTEProducer<? extends Plan> rightProducer) {
+ LogicalCTEProducer<? extends Plan> rightProducer,
+ Map<Slot, Slot> leftCloneToLeft, Map<Slot, Slot>
rightCloneToRight) {
List<Expression> disjunctions = hashOtherConditions.first;
List<Expression> otherConditions = hashOtherConditions.second;
// For null values, equalTo and not equalTo both return false
@@ -248,8 +304,7 @@ public class OrExpansion extends OneExplorationRuleFactory {
ctx.putCTEIdToConsumer(right);
//rewrite conjuncts to replace the old slots with CTE slots
- Map<Slot, Slot> replaced = new
HashMap<>(left.getProducerToConsumerOutputMap());
- replaced.putAll(right.getProducerToConsumerOutputMap());
+ Map<Slot, Slot> replaced = constructReplaceMap(left,
leftCloneToLeft, right, rightCloneToRight);
List<Expression> hashCond = pair.first.stream()
.map(e -> e.rewriteUp(s -> replaced.containsKey(s) ?
replaced.get(s) : s))
.collect(Collectors.toList());
@@ -283,4 +338,16 @@ public class OrExpansion extends OneExplorationRuleFactory
{
}
return Pair.of(Lists.newArrayList(equal.get(hashCondIdx)), others);
}
+
+ class OrExpandsionContext {
+ List<LogicalCTEProducer<? extends Plan>> cteProducerList;
+ StatementContext statementContext;
+ CascadesContext cascadesContext;
+
+ public OrExpandsionContext(StatementContext statementContext,
CascadesContext cascadesContext) {
+ this.statementContext = statementContext;
+ this.cteProducerList = new ArrayList<>();
+ this.cascadesContext = cascadesContext;
+ }
+ }
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrExpansionTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrExpansionTest.java
new file mode 100644
index 00000000000..9f8bd8bcc55
--- /dev/null
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrExpansionTest.java
@@ -0,0 +1,86 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.utframe.TestWithFeService;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+class OrExpansionTest extends TestWithFeService implements
MemoPatternMatchSupported {
+ @Override
+ protected void runBeforeAll() throws Exception {
+ createDatabase("test");
+ connectContext.setDatabase("default_cluster:test");
+ createTables(
+ "CREATE TABLE IF NOT EXISTS t1 (\n"
+ + " id1 int not null,\n"
+ + " id2 int not null\n"
+ + ")\n"
+ + "DUPLICATE KEY(id1)\n"
+ + "DISTRIBUTED BY HASH(id1) BUCKETS 10\n"
+ + "PROPERTIES (\"replication_num\" = \"1\")\n",
+ "CREATE TABLE IF NOT EXISTS t2 (\n"
+ + " id1 int not null,\n"
+ + " id2 int not null\n"
+ + ")\n"
+ + "DUPLICATE KEY(id1)\n"
+ + "DISTRIBUTED BY HASH(id2) BUCKETS 10\n"
+ + "PROPERTIES (\"replication_num\" = \"1\")\n"
+ );
+ }
+
+ @Test
+ void testOrExpand() {
+
connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
+ String sql = "select t1.id1 + 1 as id from t1 join t2 on t1.id1 =
t2.id1 or t1.id2 = t2.id2";
+ Plan plan = PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .printlnTree()
+ .getPlan();
+ Assertions.assertTrue(plan instanceof LogicalCTEAnchor);
+ Assertions.assertTrue(plan.child(1) instanceof LogicalCTEAnchor);
+ }
+
+ @Test
+ void testOrExpandCTE() {
+
connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
+ connectContext.getSessionVariable().inlineCTEReferencedThreshold = 0;
+ String sql = "with t3 as (select t1.id1 + 1 as id1, t1.id2 + 2 as id2
from t1), "
+ + "t4 as (select t2.id1 + 1 as id1, t2.id2 + 2 as id2 from
t2) "
+ + "select t3.id1 from "
+ + "(select id1, id2 from t3 group by id1, id2) t3 "
+ + " join "
+ + "(select id1, id2 from t4 group by id1, id2) t4 "
+ + "on t3.id1 = t4.id1 or t3.id2 = t4.id2";
+ Plan plan = PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .printlnTree()
+ .getPlan();
+ Assertions.assertTrue(plan instanceof LogicalCTEAnchor);
+ Assertions.assertTrue(plan.child(1) instanceof LogicalCTEAnchor);
+ Assertions.assertTrue(plan.child(1).child(1) instanceof
LogicalCTEAnchor);
+ Assertions.assertTrue(plan.child(1).child(1).child(1) instanceof
LogicalCTEAnchor);
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]