This is an automated email from the ASF dual-hosted git repository. dataroaring pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-3.0 by this push: new 9472f9db15e branch-3.0: [fix](Nereids) we should also push down expr in join's mark conjuncts #50886 (#50954) 9472f9db15e is described below commit 9472f9db15ea634e9b0a1590314692d4d5633d3e Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> AuthorDate: Wed Jun 11 10:46:47 2025 +0800 branch-3.0: [fix](Nereids) we should also push down expr in join's mark conjuncts #50886 (#50954) Cherry-picked from #50886 Co-authored-by: morrySnow <zhangwen...@selectdb.com> --- .../PushDownExpressionsInHashCondition.java | 59 +++++++++++---- .../nereids/trees/plans/logical/LogicalJoin.java | 17 ++++- .../PushDownExpressionsInHashConditionTest.java | 83 ++++++++++++++++++++++ 3 files changed, 144 insertions(+), 15 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashCondition.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashCondition.java index 731fea0497d..10fdfc3ff23 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashCondition.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashCondition.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Alias; @@ -62,8 +63,19 @@ public class PushDownExpressionsInHashCondition extends OneRewriteRuleFactory { @Override public Rule build() { return logicalJoin() - .when(join -> join.getHashJoinConjuncts().stream().anyMatch(equalTo -> - equalTo.children().stream().anyMatch(e -> !(e instanceof Slot)))) + .when(join -> { + boolean needProcessHashConjuncts = join.getHashJoinConjuncts().stream() + .anyMatch(equalTo -> equalTo.children().stream() + .anyMatch(e -> !(e instanceof Slot))); + List<Slot> leftSlots = join.left().getOutput(); + List<Slot> rightSlots = join.right().getOutput(); + Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable( + leftSlots, rightSlots, join.getMarkJoinConjuncts()); + boolean needProcessMarkConjuncts = pair.first.stream() + .anyMatch(equalTo -> equalTo.children().stream() + .anyMatch(e -> !(e instanceof Slot))); + return needProcessHashConjuncts || needProcessMarkConjuncts; + }) .then(PushDownExpressionsInHashCondition::pushDownHashExpression) .toRule(RuleType.PUSH_DOWN_EXPRESSIONS_IN_HASH_CONDITIONS); } @@ -75,15 +87,20 @@ public class PushDownExpressionsInHashCondition extends OneRewriteRuleFactory { LogicalJoin<? extends Plan, ? extends Plan> join) { Set<NamedExpression> leftProjectExprs = Sets.newHashSet(); Set<NamedExpression> rightProjectExprs = Sets.newHashSet(); - Map<Expression, NamedExpression> exprReplaceMap = Maps.newHashMap(); + Map<Expression, NamedExpression> replaceMap = Maps.newHashMap(); join.getHashJoinConjuncts().forEach(conjunct -> { Preconditions.checkArgument(conjunct instanceof EqualPredicate); // sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the situation, but actually it // doesn't swap the two sides. conjunct = JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) conjunct, join.left().getOutputSet()); - generateReplaceMapAndProjectExprs(conjunct.child(0), exprReplaceMap, leftProjectExprs); - generateReplaceMapAndProjectExprs(conjunct.child(1), exprReplaceMap, rightProjectExprs); + generateReplaceMapAndProjectExprs(conjunct.child(0), replaceMap, leftProjectExprs); + generateReplaceMapAndProjectExprs(conjunct.child(1), replaceMap, rightProjectExprs); }); + List<Expression> newHashConjuncts = join.getHashJoinConjuncts().stream() + .map(equalTo -> equalTo.withChildren(equalTo.children() + .stream().map(expr -> replaceMap.get(expr).toSlot()) + .collect(ImmutableList.toImmutableList()))) + .collect(ImmutableList.toImmutableList()); // add other conjuncts used slots to project exprs Set<ExprId> leftExprIdSet = join.left().getOutputExprIdSet(); @@ -100,7 +117,28 @@ public class PushDownExpressionsInHashCondition extends OneRewriteRuleFactory { }); // add mark conjuncts used slots to project exprs - join.getMarkJoinConjuncts().stream().flatMap(conjunct -> + // if mark conjuncts could be hash condition, normalize it + List<Slot> leftSlots = join.left().getOutput(); + List<Slot> rightSlots = join.right().getOutput(); + Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable(leftSlots, + rightSlots, join.getMarkJoinConjuncts()); + pair.first.forEach(conjunct -> { + Preconditions.checkArgument(conjunct instanceof EqualPredicate); + // sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the situation, but actually it + // doesn't swap the two sides. + conjunct = JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) conjunct, join.left().getOutputSet()); + generateReplaceMapAndProjectExprs(conjunct.child(0), replaceMap, leftProjectExprs); + generateReplaceMapAndProjectExprs(conjunct.child(1), replaceMap, rightProjectExprs); + }); + ImmutableList.Builder<Expression> newMarkConjunctsBuilder = ImmutableList.builder(); + pair.first.stream() + .map(equalTo -> equalTo.withChildren(equalTo.children() + .stream().map(expr -> replaceMap.get(expr).toSlot()) + .collect(ImmutableList.toImmutableList()))) + .forEach(newMarkConjunctsBuilder::add); + newMarkConjunctsBuilder.addAll(pair.second); + + pair.second.stream().flatMap(conjunct -> conjunct.getInputSlots().stream() ).forEach(slot -> { if (leftExprIdSet.contains(slot.getExprId())) { @@ -111,14 +149,9 @@ public class PushDownExpressionsInHashCondition extends OneRewriteRuleFactory { rightProjectExprs.add(slot); } }); - - List<Expression> newHashConjuncts = join.getHashJoinConjuncts().stream() - .map(equalTo -> equalTo.withChildren(equalTo.children() - .stream().map(expr -> exprReplaceMap.get(expr).toSlot()) - .collect(ImmutableList.toImmutableList()))) - .collect(ImmutableList.toImmutableList()); - return join.withHashJoinConjunctsAndChildren( + return join.withHashAndMarkJoinConjunctsAndChildren( newHashConjuncts, + newMarkConjunctsBuilder.build(), createChildProjectPlan(join.left(), join, leftProjectExprs), createChildProjectPlan(join.right(), join, rightProjectExprs), join.getJoinReorderContext()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java index c583360c3d8..3cae93cd512 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java @@ -380,8 +380,21 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends children, otherJoinReorderContext); } - public LogicalJoin<Plan, Plan> withHashJoinConjunctsAndChildren( - List<Expression> hashJoinConjuncts, Plan left, Plan right, JoinReorderContext otherJoinReorderContext) { + /** + * Creates a new LogicalJoin with updated hash join conjuncts, mark join conjuncts, and child plans. + * + * @param hashJoinConjuncts the list of hash join conjuncts used for hash-based join conditions. + * @param markJoinConjuncts the list of mark join conjuncts used for marking specific join conditions. + * These are typically used in semi-join or anti-join scenarios to track + * whether a condition is satisfied. + * @param left the left child plan. + * @param right the right child plan. + * @param otherJoinReorderContext the context for join reordering. + * @return a new LogicalJoin instance with the specified parameters. + */ + public LogicalJoin<Plan, Plan> withHashAndMarkJoinConjunctsAndChildren( + List<Expression> hashJoinConjuncts, List<Expression> markJoinConjuncts, + Plan left, Plan right, JoinReorderContext otherJoinReorderContext) { Preconditions.checkArgument(children.size() == 2); return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts, hint, markJoinSlotReference, Optional.empty(), Optional.empty(), diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashConditionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashConditionTest.java index f9907e2088a..7f860062b0e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashConditionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashConditionTest.java @@ -20,7 +20,26 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.NereidsPlanner; import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.properties.PhysicalProperties; +import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext; +import org.apache.doris.nereids.trees.UnaryNode; +import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Positive; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.RelationId; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.StringType; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.utframe.TestWithFeService; @@ -184,4 +203,68 @@ public class PushDownExpressionsInHashConditionTest extends TestWithFeService im ) ); } + + @Test + public void testPushDownMarkConjuncts() { + Plan left = new LogicalOneRowRelation(new RelationId(1), + ImmutableList.of(new Alias(new ExprId(1), new IntegerLiteral(1), "a"))); + Plan right = new LogicalOneRowRelation(new RelationId(2), + ImmutableList.of(new Alias(new ExprId(2), new IntegerLiteral(2), "b"))); + Expression sameLeft = new Abs(left.getOutput().get(0)); + Expression sameRight = new Positive(right.getOutput().get(0)); + Expression hashLeft = new Cast(sameLeft, StringType.INSTANCE); + Expression hashRight = new Cast(sameRight, StringType.INSTANCE); + Expression markLeft = new Cast(sameLeft, BigIntType.INSTANCE); + Expression markRight = new Cast(sameRight, BigIntType.INSTANCE); + + LogicalJoin<?, ?> plan = new LogicalJoin<>( + JoinType.INNER_JOIN, + left, + right, + new JoinReorderContext() + ); + + Expression sameConjuncts = new EqualTo(sameLeft, sameRight); + Expression hashConjuncts = new EqualTo(hashLeft, hashRight); + Expression markConjuncts = new EqualTo(markLeft, markRight); + Expression otherConjuncts = new Add(left.getOutput().get(0), new IntegerLiteral(1)); + + plan = plan.withJoinConjuncts(ImmutableList.of(sameConjuncts, hashConjuncts), ImmutableList.of(otherConjuncts), + ImmutableList.of(sameConjuncts, markConjuncts, otherConjuncts), + new JoinReorderContext()); + + PlanChecker.from(connectContext, plan).applyTopDown(new PushDownExpressionsInHashCondition()) + .matches(logicalJoin(logicalProject(logicalOneRowRelation()) + .when(p -> p.getProjects().size() == 4 + && p.getProjects().stream().filter(Alias.class::isInstance) + .map(Alias.class::cast).map(UnaryNode::child) + .filter(sameLeft::equals).count() == 1 + && p.getProjects().stream().filter(Alias.class::isInstance) + .map(Alias.class::cast).map(UnaryNode::child) + .filter(markLeft::equals).count() == 1 + && p.getProjects().stream().filter(Alias.class::isInstance) + .map(Alias.class::cast).map(UnaryNode::child) + .filter(hashLeft::equals).count() == 1), + logicalProject(logicalOneRowRelation()) + .when(p -> p.getProjects().size() == 4 + && p.getProjects().stream().filter(Alias.class::isInstance) + .map(Alias.class::cast).map(UnaryNode::child) + .filter(sameRight::equals).count() == 1 + && p.getProjects().stream().filter(Alias.class::isInstance) + .map(Alias.class::cast).map(UnaryNode::child) + .filter(markRight::equals).count() == 1 + && p.getProjects().stream().filter(Alias.class::isInstance) + .map(Alias.class::cast).map(UnaryNode::child) + .filter(hashRight::equals).count() == 1) + ).when(j -> j.getMarkJoinConjuncts().size() == 3 + && j.getMarkJoinConjuncts().stream().filter(EqualTo.class::isInstance) + .allMatch(e -> ((EqualTo) e).left() instanceof SlotReference + && ((EqualTo) e).right() instanceof SlotReference) + && j.getMarkJoinConjuncts().stream().filter(EqualTo.class::isInstance).count() == 2) + .when(j -> j.getHashJoinConjuncts().size() == 2 + && j.getHashJoinConjuncts().stream().filter(EqualTo.class::isInstance) + .allMatch(e -> ((EqualTo) e).left() instanceof SlotReference + && ((EqualTo) e).right() instanceof SlotReference) + && j.getHashJoinConjuncts().stream().filter(EqualTo.class::isInstance).count() == 2)); + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org