This is an automated email from the ASF dual-hosted git repository.
dataroaring pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new 9472f9db15e branch-3.0: [fix](Nereids) we should also push down expr in join's mark conjuncts #50886 (#50954)
9472f9db15e is described below
commit 9472f9db15ea634e9b0a1590314692d4d5633d3e
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Wed Jun 11 10:46:47 2025 +0800
branch-3.0: [fix](Nereids) we should also push down expr in join's mark conjuncts #50886 (#50954)
Cherry-picked from #50886
Co-authored-by: morrySnow <[email protected]>
---
.../PushDownExpressionsInHashCondition.java | 59 +++++++++++----
.../nereids/trees/plans/logical/LogicalJoin.java | 17 ++++-
.../PushDownExpressionsInHashConditionTest.java | 83 ++++++++++++++++++++++
3 files changed, 144 insertions(+), 15 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashCondition.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashCondition.java
index 731fea0497d..10fdfc3ff23 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashCondition.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashCondition.java
@@ -17,6 +17,7 @@
package org.apache.doris.nereids.rules.rewrite;
+import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
@@ -62,8 +63,19 @@ public class PushDownExpressionsInHashCondition extends
OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalJoin()
- .when(join ->
join.getHashJoinConjuncts().stream().anyMatch(equalTo ->
- equalTo.children().stream().anyMatch(e -> !(e
instanceof Slot))))
+ .when(join -> {
+ boolean needProcessHashConjuncts =
join.getHashJoinConjuncts().stream()
+ .anyMatch(equalTo -> equalTo.children().stream()
+ .anyMatch(e -> !(e instanceof Slot)));
+ List<Slot> leftSlots = join.left().getOutput();
+ List<Slot> rightSlots = join.right().getOutput();
+ Pair<List<Expression>, List<Expression>> pair =
JoinUtils.extractExpressionForHashTable(
+ leftSlots, rightSlots,
join.getMarkJoinConjuncts());
+ boolean needProcessMarkConjuncts = pair.first.stream()
+ .anyMatch(equalTo -> equalTo.children().stream()
+ .anyMatch(e -> !(e instanceof Slot)));
+ return needProcessHashConjuncts ||
needProcessMarkConjuncts;
+ })
.then(PushDownExpressionsInHashCondition::pushDownHashExpression)
.toRule(RuleType.PUSH_DOWN_EXPRESSIONS_IN_HASH_CONDITIONS);
}
@@ -75,15 +87,20 @@ public class PushDownExpressionsInHashCondition extends
OneRewriteRuleFactory {
LogicalJoin<? extends Plan, ? extends Plan> join) {
Set<NamedExpression> leftProjectExprs = Sets.newHashSet();
Set<NamedExpression> rightProjectExprs = Sets.newHashSet();
- Map<Expression, NamedExpression> exprReplaceMap = Maps.newHashMap();
+ Map<Expression, NamedExpression> replaceMap = Maps.newHashMap();
join.getHashJoinConjuncts().forEach(conjunct -> {
Preconditions.checkArgument(conjunct instanceof EqualPredicate);
// sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the
situation, but actually it
// doesn't swap the two sides.
conjunct = JoinUtils.swapEqualToForChildrenOrder((EqualPredicate)
conjunct, join.left().getOutputSet());
- generateReplaceMapAndProjectExprs(conjunct.child(0),
exprReplaceMap, leftProjectExprs);
- generateReplaceMapAndProjectExprs(conjunct.child(1),
exprReplaceMap, rightProjectExprs);
+ generateReplaceMapAndProjectExprs(conjunct.child(0), replaceMap,
leftProjectExprs);
+ generateReplaceMapAndProjectExprs(conjunct.child(1), replaceMap,
rightProjectExprs);
});
+ List<Expression> newHashConjuncts =
join.getHashJoinConjuncts().stream()
+ .map(equalTo -> equalTo.withChildren(equalTo.children()
+ .stream().map(expr -> replaceMap.get(expr).toSlot())
+ .collect(ImmutableList.toImmutableList())))
+ .collect(ImmutableList.toImmutableList());
// add other conjuncts used slots to project exprs
Set<ExprId> leftExprIdSet = join.left().getOutputExprIdSet();
@@ -100,7 +117,28 @@ public class PushDownExpressionsInHashCondition extends
OneRewriteRuleFactory {
});
// add mark conjuncts used slots to project exprs
- join.getMarkJoinConjuncts().stream().flatMap(conjunct ->
+ // if mark conjuncts could be hash condition, normalize it
+ List<Slot> leftSlots = join.left().getOutput();
+ List<Slot> rightSlots = join.right().getOutput();
+ Pair<List<Expression>, List<Expression>> pair =
JoinUtils.extractExpressionForHashTable(leftSlots,
+ rightSlots, join.getMarkJoinConjuncts());
+ pair.first.forEach(conjunct -> {
+ Preconditions.checkArgument(conjunct instanceof EqualPredicate);
+ // sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the
situation, but actually it
+ // doesn't swap the two sides.
+ conjunct = JoinUtils.swapEqualToForChildrenOrder((EqualPredicate)
conjunct, join.left().getOutputSet());
+ generateReplaceMapAndProjectExprs(conjunct.child(0), replaceMap,
leftProjectExprs);
+ generateReplaceMapAndProjectExprs(conjunct.child(1), replaceMap,
rightProjectExprs);
+ });
+ ImmutableList.Builder<Expression> newMarkConjunctsBuilder =
ImmutableList.builder();
+ pair.first.stream()
+ .map(equalTo -> equalTo.withChildren(equalTo.children()
+ .stream().map(expr -> replaceMap.get(expr).toSlot())
+ .collect(ImmutableList.toImmutableList())))
+ .forEach(newMarkConjunctsBuilder::add);
+ newMarkConjunctsBuilder.addAll(pair.second);
+
+ pair.second.stream().flatMap(conjunct ->
conjunct.getInputSlots().stream()
).forEach(slot -> {
if (leftExprIdSet.contains(slot.getExprId())) {
@@ -111,14 +149,9 @@ public class PushDownExpressionsInHashCondition extends
OneRewriteRuleFactory {
rightProjectExprs.add(slot);
}
});
-
- List<Expression> newHashConjuncts =
join.getHashJoinConjuncts().stream()
- .map(equalTo -> equalTo.withChildren(equalTo.children()
- .stream().map(expr ->
exprReplaceMap.get(expr).toSlot())
- .collect(ImmutableList.toImmutableList())))
- .collect(ImmutableList.toImmutableList());
- return join.withHashJoinConjunctsAndChildren(
+ return join.withHashAndMarkJoinConjunctsAndChildren(
newHashConjuncts,
+ newMarkConjunctsBuilder.build(),
createChildProjectPlan(join.left(), join, leftProjectExprs),
createChildProjectPlan(join.right(), join, rightProjectExprs),
join.getJoinReorderContext());
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
index c583360c3d8..3cae93cd512 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
@@ -380,8 +380,21 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan,
RIGHT_CHILD_TYPE extends
children, otherJoinReorderContext);
}
- public LogicalJoin<Plan, Plan> withHashJoinConjunctsAndChildren(
- List<Expression> hashJoinConjuncts, Plan left, Plan right,
JoinReorderContext otherJoinReorderContext) {
+ /**
+ * Creates a new LogicalJoin with updated hash join conjuncts, mark join
conjuncts, and child plans.
+ *
+ * @param hashJoinConjuncts the list of hash join conjuncts used for
hash-based join conditions.
+ * @param markJoinConjuncts the list of mark join conjuncts used for
marking specific join conditions.
+ * These are typically used in semi-join or
anti-join scenarios to track
+ * whether a condition is satisfied.
+ * @param left the left child plan.
+ * @param right the right child plan.
+ * @param otherJoinReorderContext the context for join reordering.
+ * @return a new LogicalJoin instance with the specified parameters.
+ */
+ public LogicalJoin<Plan, Plan> withHashAndMarkJoinConjunctsAndChildren(
+ List<Expression> hashJoinConjuncts, List<Expression>
markJoinConjuncts,
+ Plan left, Plan right, JoinReorderContext otherJoinReorderContext)
{
Preconditions.checkArgument(children.size() == 2);
return new LogicalJoin<>(joinType, hashJoinConjuncts,
otherJoinConjuncts, markJoinConjuncts,
hint, markJoinSlotReference, Optional.empty(),
Optional.empty(),
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashConditionTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashConditionTest.java
index f9907e2088a..7f860062b0e 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashConditionTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownExpressionsInHashConditionTest.java
@@ -20,7 +20,26 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.properties.PhysicalProperties;
+import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
+import org.apache.doris.nereids.trees.UnaryNode;
+import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Cast;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.ExprId;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Positive;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.RelationId;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;
@@ -184,4 +203,68 @@ public class PushDownExpressionsInHashConditionTest
extends TestWithFeService im
)
);
}
+
+ @Test
+ public void testPushDownMarkConjuncts() {
+ Plan left = new LogicalOneRowRelation(new RelationId(1),
+ ImmutableList.of(new Alias(new ExprId(1), new
IntegerLiteral(1), "a")));
+ Plan right = new LogicalOneRowRelation(new RelationId(2),
+ ImmutableList.of(new Alias(new ExprId(2), new
IntegerLiteral(2), "b")));
+ Expression sameLeft = new Abs(left.getOutput().get(0));
+ Expression sameRight = new Positive(right.getOutput().get(0));
+ Expression hashLeft = new Cast(sameLeft, StringType.INSTANCE);
+ Expression hashRight = new Cast(sameRight, StringType.INSTANCE);
+ Expression markLeft = new Cast(sameLeft, BigIntType.INSTANCE);
+ Expression markRight = new Cast(sameRight, BigIntType.INSTANCE);
+
+ LogicalJoin<?, ?> plan = new LogicalJoin<>(
+ JoinType.INNER_JOIN,
+ left,
+ right,
+ new JoinReorderContext()
+ );
+
+ Expression sameConjuncts = new EqualTo(sameLeft, sameRight);
+ Expression hashConjuncts = new EqualTo(hashLeft, hashRight);
+ Expression markConjuncts = new EqualTo(markLeft, markRight);
+ Expression otherConjuncts = new Add(left.getOutput().get(0), new
IntegerLiteral(1));
+
+ plan = plan.withJoinConjuncts(ImmutableList.of(sameConjuncts,
hashConjuncts), ImmutableList.of(otherConjuncts),
+ ImmutableList.of(sameConjuncts, markConjuncts, otherConjuncts),
+ new JoinReorderContext());
+
+ PlanChecker.from(connectContext, plan).applyTopDown(new
PushDownExpressionsInHashCondition())
+ .matches(logicalJoin(logicalProject(logicalOneRowRelation())
+ .when(p -> p.getProjects().size() == 4
+ &&
p.getProjects().stream().filter(Alias.class::isInstance)
+
.map(Alias.class::cast).map(UnaryNode::child)
+ .filter(sameLeft::equals).count() == 1
+ &&
p.getProjects().stream().filter(Alias.class::isInstance)
+
.map(Alias.class::cast).map(UnaryNode::child)
+ .filter(markLeft::equals).count() == 1
+ &&
p.getProjects().stream().filter(Alias.class::isInstance)
+
.map(Alias.class::cast).map(UnaryNode::child)
+ .filter(hashLeft::equals).count() ==
1),
+ logicalProject(logicalOneRowRelation())
+ .when(p -> p.getProjects().size() == 4
+ &&
p.getProjects().stream().filter(Alias.class::isInstance)
+
.map(Alias.class::cast).map(UnaryNode::child)
+ .filter(sameRight::equals).count() == 1
+ &&
p.getProjects().stream().filter(Alias.class::isInstance)
+
.map(Alias.class::cast).map(UnaryNode::child)
+ .filter(markRight::equals).count() == 1
+ &&
p.getProjects().stream().filter(Alias.class::isInstance)
+
.map(Alias.class::cast).map(UnaryNode::child)
+ .filter(hashRight::equals).count() ==
1)
+ ).when(j -> j.getMarkJoinConjuncts().size() == 3
+ &&
j.getMarkJoinConjuncts().stream().filter(EqualTo.class::isInstance)
+ .allMatch(e -> ((EqualTo) e).left() instanceof
SlotReference
+ && ((EqualTo) e).right() instanceof
SlotReference)
+ &&
j.getMarkJoinConjuncts().stream().filter(EqualTo.class::isInstance).count() ==
2)
+ .when(j -> j.getHashJoinConjuncts().size() == 2
+ &&
j.getHashJoinConjuncts().stream().filter(EqualTo.class::isInstance)
+ .allMatch(e -> ((EqualTo) e).left() instanceof
SlotReference
+ && ((EqualTo) e).right() instanceof
SlotReference)
+ &&
j.getHashJoinConjuncts().stream().filter(EqualTo.class::isInstance).count() ==
2));
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]