This is an automated email from the ASF dual-hosted git repository.

dataroaring pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 9472f9db15e branch-3.0: [fix](Nereids) we should also push down expr in join's mark conjuncts #50886 (#50954)
9472f9db15e is described below

commit 9472f9db15ea634e9b0a1590314692d4d5633d3e
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Wed Jun 11 10:46:47 2025 +0800

    branch-3.0: [fix](Nereids) we should also push down expr in join's mark conjuncts #50886 (#50954)
    
    Cherry-picked from #50886
    
    Co-authored-by: morrySnow <zhangwen...@selectdb.com>
---
 .../PushDownExpressionsInHashCondition.java        | 59 +++++++++++----
 .../nereids/trees/plans/logical/LogicalJoin.java   | 17 ++++-
 .../PushDownExpressionsInHashConditionTest.java    | 83 ++++++++++++++++++++++
 3 files changed, 144 insertions(+), 15 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashCondition.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashCondition.java
index 731fea0497d..10fdfc3ff23 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashCondition.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashCondition.java
@@ -17,6 +17,7 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
+import org.apache.doris.common.Pair;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.trees.expressions.Alias;
@@ -62,8 +63,19 @@ public class PushDownExpressionsInHashCondition extends 
OneRewriteRuleFactory {
     @Override
     public Rule build() {
         return logicalJoin()
-                .when(join -> 
join.getHashJoinConjuncts().stream().anyMatch(equalTo ->
-                        equalTo.children().stream().anyMatch(e -> !(e 
instanceof Slot))))
+                .when(join -> {
+                    boolean needProcessHashConjuncts = 
join.getHashJoinConjuncts().stream()
+                            .anyMatch(equalTo -> equalTo.children().stream()
+                                    .anyMatch(e -> !(e instanceof Slot)));
+                    List<Slot> leftSlots = join.left().getOutput();
+                    List<Slot> rightSlots = join.right().getOutput();
+                    Pair<List<Expression>, List<Expression>> pair = 
JoinUtils.extractExpressionForHashTable(
+                            leftSlots, rightSlots, 
join.getMarkJoinConjuncts());
+                    boolean needProcessMarkConjuncts = pair.first.stream()
+                            .anyMatch(equalTo -> equalTo.children().stream()
+                                    .anyMatch(e -> !(e instanceof Slot)));
+                    return needProcessHashConjuncts || 
needProcessMarkConjuncts;
+                })
                 
.then(PushDownExpressionsInHashCondition::pushDownHashExpression)
                 .toRule(RuleType.PUSH_DOWN_EXPRESSIONS_IN_HASH_CONDITIONS);
     }
@@ -75,15 +87,20 @@ public class PushDownExpressionsInHashCondition extends 
OneRewriteRuleFactory {
             LogicalJoin<? extends Plan, ? extends Plan> join) {
         Set<NamedExpression> leftProjectExprs = Sets.newHashSet();
         Set<NamedExpression> rightProjectExprs = Sets.newHashSet();
-        Map<Expression, NamedExpression> exprReplaceMap = Maps.newHashMap();
+        Map<Expression, NamedExpression> replaceMap = Maps.newHashMap();
         join.getHashJoinConjuncts().forEach(conjunct -> {
             Preconditions.checkArgument(conjunct instanceof EqualPredicate);
             // sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the 
situation, but actually it
             // doesn't swap the two sides.
             conjunct = JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) 
conjunct, join.left().getOutputSet());
-            generateReplaceMapAndProjectExprs(conjunct.child(0), 
exprReplaceMap, leftProjectExprs);
-            generateReplaceMapAndProjectExprs(conjunct.child(1), 
exprReplaceMap, rightProjectExprs);
+            generateReplaceMapAndProjectExprs(conjunct.child(0), replaceMap, 
leftProjectExprs);
+            generateReplaceMapAndProjectExprs(conjunct.child(1), replaceMap, 
rightProjectExprs);
         });
+        List<Expression> newHashConjuncts = 
join.getHashJoinConjuncts().stream()
+                .map(equalTo -> equalTo.withChildren(equalTo.children()
+                        .stream().map(expr -> replaceMap.get(expr).toSlot())
+                        .collect(ImmutableList.toImmutableList())))
+                .collect(ImmutableList.toImmutableList());
 
         // add other conjuncts used slots to project exprs
         Set<ExprId> leftExprIdSet = join.left().getOutputExprIdSet();
@@ -100,7 +117,28 @@ public class PushDownExpressionsInHashCondition extends 
OneRewriteRuleFactory {
         });
 
         // add mark conjuncts used slots to project exprs
-        join.getMarkJoinConjuncts().stream().flatMap(conjunct ->
+        // if mark conjuncts could be hash condition, normalize it
+        List<Slot> leftSlots = join.left().getOutput();
+        List<Slot> rightSlots = join.right().getOutput();
+        Pair<List<Expression>, List<Expression>> pair = 
JoinUtils.extractExpressionForHashTable(leftSlots,
+                rightSlots, join.getMarkJoinConjuncts());
+        pair.first.forEach(conjunct -> {
+            Preconditions.checkArgument(conjunct instanceof EqualPredicate);
+            // sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the 
situation, but actually it
+            // doesn't swap the two sides.
+            conjunct = JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) 
conjunct, join.left().getOutputSet());
+            generateReplaceMapAndProjectExprs(conjunct.child(0), replaceMap, 
leftProjectExprs);
+            generateReplaceMapAndProjectExprs(conjunct.child(1), replaceMap, 
rightProjectExprs);
+        });
+        ImmutableList.Builder<Expression> newMarkConjunctsBuilder = 
ImmutableList.builder();
+        pair.first.stream()
+                .map(equalTo -> equalTo.withChildren(equalTo.children()
+                        .stream().map(expr -> replaceMap.get(expr).toSlot())
+                        .collect(ImmutableList.toImmutableList())))
+                .forEach(newMarkConjunctsBuilder::add);
+        newMarkConjunctsBuilder.addAll(pair.second);
+
+        pair.second.stream().flatMap(conjunct ->
                 conjunct.getInputSlots().stream()
         ).forEach(slot -> {
             if (leftExprIdSet.contains(slot.getExprId())) {
@@ -111,14 +149,9 @@ public class PushDownExpressionsInHashCondition extends 
OneRewriteRuleFactory {
                 rightProjectExprs.add(slot);
             }
         });
-
-        List<Expression> newHashConjuncts = 
join.getHashJoinConjuncts().stream()
-                .map(equalTo -> equalTo.withChildren(equalTo.children()
-                        .stream().map(expr -> 
exprReplaceMap.get(expr).toSlot())
-                        .collect(ImmutableList.toImmutableList())))
-                .collect(ImmutableList.toImmutableList());
-        return join.withHashJoinConjunctsAndChildren(
+        return join.withHashAndMarkJoinConjunctsAndChildren(
                 newHashConjuncts,
+                newMarkConjunctsBuilder.build(),
                 createChildProjectPlan(join.left(), join, leftProjectExprs),
                 createChildProjectPlan(join.right(), join, rightProjectExprs), 
join.getJoinReorderContext());
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
index c583360c3d8..3cae93cd512 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
@@ -380,8 +380,21 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, 
RIGHT_CHILD_TYPE extends
                 children, otherJoinReorderContext);
     }
 
-    public LogicalJoin<Plan, Plan> withHashJoinConjunctsAndChildren(
-            List<Expression> hashJoinConjuncts, Plan left, Plan right, 
JoinReorderContext otherJoinReorderContext) {
+    /**
+     * Creates a new LogicalJoin with updated hash join conjuncts, mark join 
conjuncts, and child plans.
+     *
+     * @param hashJoinConjuncts the list of hash join conjuncts used for 
hash-based join conditions.
+     * @param markJoinConjuncts the list of mark join conjuncts used for 
marking specific join conditions.
+     *                          These are typically used in semi-join or 
anti-join scenarios to track
+     *                          whether a condition is satisfied.
+     * @param left the left child plan.
+     * @param right the right child plan.
+     * @param otherJoinReorderContext the context for join reordering.
+     * @return a new LogicalJoin instance with the specified parameters.
+     */
+    public LogicalJoin<Plan, Plan> withHashAndMarkJoinConjunctsAndChildren(
+            List<Expression> hashJoinConjuncts, List<Expression> 
markJoinConjuncts,
+            Plan left, Plan right, JoinReorderContext otherJoinReorderContext) 
{
         Preconditions.checkArgument(children.size() == 2);
         return new LogicalJoin<>(joinType, hashJoinConjuncts, 
otherJoinConjuncts, markJoinConjuncts,
                 hint, markJoinSlotReference, Optional.empty(), 
Optional.empty(),
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashConditionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashConditionTest.java
index f9907e2088a..7f860062b0e 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashConditionTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashConditionTest.java
@@ -20,7 +20,26 @@ package org.apache.doris.nereids.rules.rewrite;
 import org.apache.doris.nereids.NereidsPlanner;
 import org.apache.doris.nereids.parser.NereidsParser;
 import org.apache.doris.nereids.properties.PhysicalProperties;
+import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
+import org.apache.doris.nereids.trees.UnaryNode;
+import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Cast;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.ExprId;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Positive;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.RelationId;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.StringType;
 import org.apache.doris.nereids.util.MemoPatternMatchSupported;
 import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.utframe.TestWithFeService;
@@ -184,4 +203,68 @@ public class PushDownExpressionsInHashConditionTest 
extends TestWithFeService im
                     )
                 );
     }
+
+    @Test
+    public void testPushDownMarkConjuncts() {
+        Plan left = new LogicalOneRowRelation(new RelationId(1),
+                ImmutableList.of(new Alias(new ExprId(1), new 
IntegerLiteral(1), "a")));
+        Plan right = new LogicalOneRowRelation(new RelationId(2),
+                ImmutableList.of(new Alias(new ExprId(2), new 
IntegerLiteral(2), "b")));
+        Expression sameLeft = new Abs(left.getOutput().get(0));
+        Expression sameRight = new Positive(right.getOutput().get(0));
+        Expression hashLeft = new Cast(sameLeft, StringType.INSTANCE);
+        Expression hashRight = new Cast(sameRight, StringType.INSTANCE);
+        Expression markLeft = new Cast(sameLeft, BigIntType.INSTANCE);
+        Expression markRight = new Cast(sameRight, BigIntType.INSTANCE);
+
+        LogicalJoin<?, ?> plan = new LogicalJoin<>(
+                JoinType.INNER_JOIN,
+                left,
+                right,
+                new JoinReorderContext()
+        );
+
+        Expression sameConjuncts = new EqualTo(sameLeft, sameRight);
+        Expression hashConjuncts = new EqualTo(hashLeft, hashRight);
+        Expression markConjuncts = new EqualTo(markLeft, markRight);
+        Expression otherConjuncts = new Add(left.getOutput().get(0), new 
IntegerLiteral(1));
+
+        plan = plan.withJoinConjuncts(ImmutableList.of(sameConjuncts, 
hashConjuncts), ImmutableList.of(otherConjuncts),
+                ImmutableList.of(sameConjuncts, markConjuncts, otherConjuncts),
+                new JoinReorderContext());
+
+        PlanChecker.from(connectContext, plan).applyTopDown(new 
PushDownExpressionsInHashCondition())
+                .matches(logicalJoin(logicalProject(logicalOneRowRelation())
+                                .when(p -> p.getProjects().size() == 4
+                                        && 
p.getProjects().stream().filter(Alias.class::isInstance)
+                                        
.map(Alias.class::cast).map(UnaryNode::child)
+                                        .filter(sameLeft::equals).count() == 1
+                                        && 
p.getProjects().stream().filter(Alias.class::isInstance)
+                                        
.map(Alias.class::cast).map(UnaryNode::child)
+                                        .filter(markLeft::equals).count() == 1
+                                        && 
p.getProjects().stream().filter(Alias.class::isInstance)
+                                        
.map(Alias.class::cast).map(UnaryNode::child)
+                                        .filter(hashLeft::equals).count() == 
1),
+                        logicalProject(logicalOneRowRelation())
+                                .when(p -> p.getProjects().size() == 4
+                                        && 
p.getProjects().stream().filter(Alias.class::isInstance)
+                                        
.map(Alias.class::cast).map(UnaryNode::child)
+                                        .filter(sameRight::equals).count() == 1
+                                        && 
p.getProjects().stream().filter(Alias.class::isInstance)
+                                        
.map(Alias.class::cast).map(UnaryNode::child)
+                                        .filter(markRight::equals).count() == 1
+                                        && 
p.getProjects().stream().filter(Alias.class::isInstance)
+                                        
.map(Alias.class::cast).map(UnaryNode::child)
+                                        .filter(hashRight::equals).count() == 
1)
+                        ).when(j -> j.getMarkJoinConjuncts().size() == 3
+                        && 
j.getMarkJoinConjuncts().stream().filter(EqualTo.class::isInstance)
+                                .allMatch(e -> ((EqualTo) e).left() instanceof 
SlotReference
+                                        && ((EqualTo) e).right() instanceof 
SlotReference)
+                        && 
j.getMarkJoinConjuncts().stream().filter(EqualTo.class::isInstance).count() == 
2)
+                        .when(j -> j.getHashJoinConjuncts().size() == 2
+                        && 
j.getHashJoinConjuncts().stream().filter(EqualTo.class::isInstance)
+                                .allMatch(e -> ((EqualTo) e).left() instanceof 
SlotReference
+                                        && ((EqualTo) e).right() instanceof 
SlotReference)
+                        && 
j.getHashJoinConjuncts().stream().filter(EqualTo.class::isInstance).count() == 
2));
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to