This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 7db879b4f27 [fix](project push down) union and its children's output
are mis order (#58765)
7db879b4f27 is described below
commit 7db879b4f2793b05ca7e020110d58d5fdeee35a4
Author: morrySnow <[email protected]>
AuthorDate: Mon Dec 8 10:27:54 2025 +0800
[fix](project push down) union and its children's output are mis order
(#58765)
### What problem does this PR solve?
Related PR: #57204
Problem Summary:
This pull request refactors and improves the `PushDownProject` rule in
the Nereids optimizer, mainly focusing on the logic for pushing down
projections through `UNION` operations. It also introduces a
comprehensive unit test to verify the new logic, making the relevant
methods more testable and robust.
**Refactoring and Logic Improvements:**
* Refactored the `pushThroughUnion` logic by extracting it into a new
static method, making it easier to test and use independently. The main
logic now takes explicit arguments instead of relying on the context
object.
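* Improved the handling of projections and child outputs when pushing
down through `UNION`, ensuring correct mapping and replacement of slots.
Union output slots are now mapped by column index to each child's
regular child output (`getRegularChildOutput(i)`) rather than to the
child's own output, whose column order may differ from the union's;
constant expression rows are likewise mapped by column index. The slot
replacement logic is also made static for better testability.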
**Testing Enhancements:**
* Added a new unit test class `PushDownProjectTest` to rigorously test
the pushdown logic in various scenarios, including unions with and
without children. The tests verify both the structure and the
correctness of the rewritten plans.
**Code Quality Improvements:**
* Added the `@VisibleForTesting` annotation and imported necessary
dependencies to clarify method visibility and intent for testing.
* Changed the pushed-down projections local from a `Collection` to a
`List` (copied with `Lists.newArrayList`) for clearer, ordered
projection handling.
These changes make the projection pushdown logic more modular, testable,
and robust, and provide strong test coverage for future maintenance.
---
.../nereids/rules/rewrite/PushDownProject.java | 168 +++++++++---------
.../nereids/rules/rewrite/PushDownProjectTest.java | 192 +++++++++++++++++++++
2 files changed, 279 insertions(+), 81 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java
index 9fbc9413b29..832f9c25e77 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java
@@ -36,10 +36,12 @@ import
org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableCollection;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
import java.util.ArrayList;
@@ -198,101 +200,105 @@ public class PushDownProject implements
RewriteRuleFactory, NormalizeToSlot {
if (!ctx.connectContext.getSessionVariable().enablePruneNestedColumns)
{
return ctx.root;
}
- LogicalProject<LogicalUnion> project = ctx.root;
+ return pushThroughUnion(ctx.root, ctx.statementContext);
+ }
+
+ @VisibleForTesting
+ static Plan pushThroughUnion(LogicalProject<LogicalUnion> project,
StatementContext statementContext) {
LogicalUnion union = project.child();
PushdownProjectHelper pushdownProjectHelper
- = new PushdownProjectHelper(ctx.statementContext, project);
-
+ = new PushdownProjectHelper(statementContext, project);
Pair<Boolean, List<NamedExpression>> pushProjects
=
pushdownProjectHelper.pushDownExpressions(project.getProjects());
- if (pushProjects.first) {
- List<NamedExpression> unionOutputs = union.getOutputs();
- Map<Slot, Integer> slotToColumnIndex = new LinkedHashMap<>();
- for (int i = 0; i < unionOutputs.size(); i++) {
- NamedExpression output = unionOutputs.get(i);
- slotToColumnIndex.put(output.toSlot(), i);
- }
+ if (!pushProjects.first) {
+ return project;
+ }
+ List<NamedExpression> unionOutputs = union.getOutputs();
+ Map<Slot, Integer> slotToColumnIndex = new LinkedHashMap<>();
+ for (int i = 0; i < unionOutputs.size(); i++) {
+ NamedExpression output = unionOutputs.get(i);
+ slotToColumnIndex.put(output.toSlot(), i);
+ }
- Collection<NamedExpression> pushDownProjections
- = pushdownProjectHelper.childToPushDownProjects.values();
- List<Plan> newChildren = new ArrayList<>();
- List<List<SlotReference>> newChildrenOutputs = new ArrayList<>();
- for (Plan child : union.children()) {
- List<NamedExpression> pushedOutput = replaceSlot(
- ctx.statementContext,
- pushDownProjections,
- slot -> {
- Integer sourceColumnIndex =
slotToColumnIndex.get(slot);
- if (sourceColumnIndex != null) {
- return
child.getOutput().get(sourceColumnIndex).toSlot();
- }
- return slot;
+ List<NamedExpression> pushDownProjections
+ =
Lists.newArrayList(pushdownProjectHelper.childToPushDownProjects.values());
+ List<Plan> newChildren = new ArrayList<>();
+ List<List<SlotReference>> newChildrenOutputs = new ArrayList<>();
+ for (int i = 0; i < union.arity(); i++) {
+ List<SlotReference> regulatorOutput =
union.getRegularChildOutput(i);
+ List<NamedExpression> pushedOutput = replaceSlot(
+ statementContext,
+ pushDownProjections,
+ slot -> {
+ Integer sourceColumnIndex =
slotToColumnIndex.get(slot);
+ if (sourceColumnIndex != null) {
+ return
regulatorOutput.get(sourceColumnIndex).toSlot();
}
- );
-
- LogicalProject<Plan> newChild = new LogicalProject<>(
- ImmutableList.<NamedExpression>builder()
- .addAll(child.getOutput())
- .addAll(pushedOutput)
- .build(),
- child
- );
-
- newChildrenOutputs.add((List) newChild.getOutput());
- newChildren.add(newChild);
- }
+ return slot;
+ }
+ );
- for (List<NamedExpression> originConstantExprs :
union.getConstantExprsList()) {
- List<NamedExpression> pushedOutput = replaceSlot(
- ctx.statementContext,
- pushDownProjections,
- slot -> {
- Integer sourceColumnIndex =
slotToColumnIndex.get(slot);
- if (sourceColumnIndex != null) {
- return
originConstantExprs.get(sourceColumnIndex).toSlot();
- }
- return slot;
+ LogicalProject<Plan> newChild = new LogicalProject<>(
+ ImmutableList.<NamedExpression>builder()
+ .addAll(regulatorOutput)
+ .addAll(pushedOutput)
+ .build(),
+ union.child(i)
+ );
+
+ newChildrenOutputs.add((List) newChild.getOutput());
+ newChildren.add(newChild);
+ }
+
+ for (List<NamedExpression> originConstantExprs :
union.getConstantExprsList()) {
+ List<NamedExpression> pushedOutput = replaceSlot(
+ statementContext,
+ pushDownProjections,
+ slot -> {
+ Integer sourceColumnIndex =
slotToColumnIndex.get(slot);
+ if (sourceColumnIndex != null) {
+ return
originConstantExprs.get(sourceColumnIndex).toSlot();
}
- );
-
- LogicalOneRowRelation originOneRowRelation = new
LogicalOneRowRelation(
- ctx.statementContext.getNextRelationId(),
- originConstantExprs
- );
-
- LogicalProject<Plan> newChild = new LogicalProject<>(
- ImmutableList.<NamedExpression>builder()
- .addAll(originOneRowRelation.getOutput())
- .addAll(pushedOutput)
- .build(),
- originOneRowRelation
- );
-
- newChildrenOutputs.add((List) newChild.getOutput());
- newChildren.add(newChild);
- }
+ return slot;
+ }
+ );
- List<NamedExpression> newUnionOutputs = new
ArrayList<>(union.getOutputs());
- for (NamedExpression projection : pushDownProjections) {
- newUnionOutputs.add(projection.toSlot());
- }
+ LogicalOneRowRelation originOneRowRelation = new
LogicalOneRowRelation(
+ statementContext.getNextRelationId(),
+ originConstantExprs
+ );
- return new LogicalProject<>(
- pushProjects.second,
- new LogicalUnion(
- union.getQualifier(),
- newUnionOutputs,
- newChildrenOutputs,
- ImmutableList.of(),
- union.hasPushedFilter(),
- newChildren
- )
+ LogicalProject<Plan> newChild = new LogicalProject<>(
+ ImmutableList.<NamedExpression>builder()
+ .addAll(originOneRowRelation.getOutput())
+ .addAll(pushedOutput)
+ .build(),
+ originOneRowRelation
);
+
+ newChildrenOutputs.add((List) newChild.getOutput());
+ newChildren.add(newChild);
}
- return project;
+
+ List<NamedExpression> newUnionOutputs = new
ArrayList<>(union.getOutputs());
+ for (NamedExpression projection : pushDownProjections) {
+ newUnionOutputs.add(projection.toSlot());
+ }
+
+ return new LogicalProject<>(
+ pushProjects.second,
+ new LogicalUnion(
+ union.getQualifier(),
+ newUnionOutputs,
+ newChildrenOutputs,
+ ImmutableList.of(),
+ union.hasPushedFilter(),
+ newChildren
+ )
+ );
}
- private List<NamedExpression> replaceSlot(
+ private static List<NamedExpression> replaceSlot(
StatementContext statementContext,
Collection<NamedExpression> pushDownProjections,
Function<Slot, Slot> slotReplace) {
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownProjectTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownProjectTest.java
new file mode 100644
index 00000000000..47398e3ef9a
--- /dev/null
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownProjectTest.java
@@ -0,0 +1,192 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.StatementContext;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.ExprId;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt;
+import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.RelationId;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.types.TinyIntType;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
+
+import com.google.common.collect.Lists;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+
+public class PushDownProjectTest implements MemoPatternMatchSupported {
+
+ private final List rel1Output = Lists.newArrayList(
+ new SlotReference(new ExprId(1), "c1", TinyIntType.INSTANCE, true,
Lists.newArrayList()),
+ new SlotReference(new ExprId(2), "c2", TinyIntType.INSTANCE, true,
Lists.newArrayList())
+ );
+ private final List regulatorRel1Output = Lists.newArrayList(
+ new SlotReference(new ExprId(2), "c2", TinyIntType.INSTANCE, true,
Lists.newArrayList()),
+ new SlotReference(new ExprId(1), "c1", TinyIntType.INSTANCE, true,
Lists.newArrayList()),
+ new SlotReference(new ExprId(2), "c2", TinyIntType.INSTANCE, true,
Lists.newArrayList())
+ );
+ private final List rel2Output = Lists.newArrayList(
+ new SlotReference(new ExprId(3), "c3", TinyIntType.INSTANCE, true,
Lists.newArrayList()),
+ new SlotReference(new ExprId(4), "c4", TinyIntType.INSTANCE, true,
Lists.newArrayList()),
+ new SlotReference(new ExprId(5), "c5", TinyIntType.INSTANCE, true,
Lists.newArrayList())
+ );
+ private final List regulatorRel2Output = Lists.newArrayList(
+ new SlotReference(new ExprId(3), "c3", TinyIntType.INSTANCE, true,
Lists.newArrayList()),
+ new SlotReference(new ExprId(5), "c5", TinyIntType.INSTANCE, true,
Lists.newArrayList()),
+ new SlotReference(new ExprId(4), "c4", TinyIntType.INSTANCE, true,
Lists.newArrayList())
+ );
+ private final List<NamedExpression> unionOutput = Lists.newArrayList(
+ new SlotReference(new ExprId(10), "c10", TinyIntType.INSTANCE,
true, Lists.newArrayList()),
+ new SlotReference(new ExprId(11), "c11", TinyIntType.INSTANCE,
true, Lists.newArrayList()),
+ new SlotReference(new ExprId(12), "c12", TinyIntType.INSTANCE,
true, Lists.newArrayList())
+ );
+ private final List<NamedExpression> pushDownProjections =
Lists.newArrayList(
+ new Alias(new ExprId(100), new ElementAt(
+ new SlotReference(new ExprId(10), "c10",
TinyIntType.INSTANCE, true, Lists.newArrayList()),
+ new StringLiteral("a"))),
+ new Alias(new ExprId(101), new ElementAt(
+ new SlotReference(new ExprId(10), "c10",
TinyIntType.INSTANCE, true, Lists.newArrayList()),
+ new StringLiteral("b"))),
+ new Alias(new ExprId(102), new ElementAt(
+ new SlotReference(new ExprId(12), "c10",
TinyIntType.INSTANCE, true, Lists.newArrayList()),
+ new StringLiteral("a"))),
+ new SlotReference(new ExprId(11), "c11", TinyIntType.INSTANCE,
true, Lists.newArrayList())
+ );
+
+ private final LogicalOneRowRelation rel1 = new LogicalOneRowRelation(new
RelationId(1), rel1Output);
+ private final LogicalOneRowRelation rel2 = new LogicalOneRowRelation(new
RelationId(2), rel2Output);
+ private final List<Plan> children = Lists.newArrayList(rel1, rel2);
+
+ @Test
+ public void testPushDownProjectThroughUnionOnlyHasChildren() {
+ List<List<SlotReference>> regulatorOutputs =
Lists.newArrayList(regulatorRel1Output, regulatorRel2Output);
+ LogicalUnion union = new LogicalUnion(Qualifier.ALL, unionOutput,
+ regulatorOutputs, Lists.newArrayList(), true, children);
+ LogicalProject<LogicalUnion> project = new
LogicalProject<>(pushDownProjections, union);
+ StatementContext context = new StatementContext();
+ LogicalProject<LogicalUnion> resProject
+ = (LogicalProject<LogicalUnion>)
PushDownProject.pushThroughUnion(project, context);
+ PlanChecker.from(MemoTestUtils.createConnectContext(), resProject)
+ .matchesFromRoot(
+ logicalProject(
+ logicalUnion(
+ logicalProject(
+ logicalOneRowRelation()
+ .when(r ->
r.getRelationId().asInt() == 1)
+ .when(r ->
r.getOutputs().size() == 2)
+ ).when(p -> p.getOutputs().size() == 6)
+ .when(p ->
p.getProjects().get(0).getExprId().asInt() == 2)
+ .when(p ->
p.getProjects().get(1).getExprId().asInt() == 1)
+ .when(p ->
p.getProjects().get(2).getExprId().asInt() == 2)
+ .when(p ->
p.getProjects().get(3).child(0) instanceof ElementAt)
+ .when(p ->
p.getProjects().get(4).child(0) instanceof ElementAt)
+ .when(p ->
p.getProjects().get(5).child(0) instanceof ElementAt)
+ .when(p -> ((SlotReference)
(p.getProjects().get(3).child(0).child(0))).getExprId().asInt() == 2)
+ .when(p -> ((StringLiteral)
(p.getProjects().get(3).child(0).child(1))).getValue().equals("a"))
+ .when(p -> ((SlotReference)
(p.getProjects().get(4).child(0).child(0))).getExprId().asInt() == 2)
+ .when(p -> ((StringLiteral)
(p.getProjects().get(4).child(0).child(1))).getValue().equals("b"))
+ .when(p -> ((SlotReference)
(p.getProjects().get(5).child(0).child(0))).getExprId().asInt() == 2)
+ .when(p -> ((StringLiteral)
(p.getProjects().get(5).child(0).child(1))).getValue().equals("a")),
+ logicalProject(
+ logicalOneRowRelation()
+ .when(r ->
r.getRelationId().asInt() == 2)
+ .when(r ->
r.getOutputs().size() == 3)
+ ).when(p -> p.getOutputs().size() == 6)
+ .when(p ->
p.getProjects().get(0).getExprId().asInt() == 3)
+ .when(p ->
p.getProjects().get(1).getExprId().asInt() == 5)
+ .when(p ->
p.getProjects().get(2).getExprId().asInt() == 4)
+ .when(p ->
p.getProjects().get(3).child(0) instanceof ElementAt)
+ .when(p ->
p.getProjects().get(4).child(0) instanceof ElementAt)
+ .when(p ->
p.getProjects().get(5).child(0) instanceof ElementAt)
+ .when(p -> ((SlotReference)
(p.getProjects().get(3).child(0).child(0))).getExprId().asInt() == 3)
+ .when(p -> ((StringLiteral)
(p.getProjects().get(3).child(0).child(1))).getValue().equals("a"))
+ .when(p -> ((SlotReference)
(p.getProjects().get(4).child(0).child(0))).getExprId().asInt() == 3)
+ .when(p -> ((StringLiteral)
(p.getProjects().get(4).child(0).child(1))).getValue().equals("b"))
+ .when(p -> ((SlotReference)
(p.getProjects().get(5).child(0).child(0))).getExprId().asInt() == 4)
+ .when(p -> ((StringLiteral)
(p.getProjects().get(5).child(0).child(1))).getValue().equals("a"))
+ ).when(u -> u.getOutput().size() == 6)
+ .when(u ->
u.getOutput().get(0).getExprId().asInt() == 10)
+ .when(u ->
u.getOutput().get(1).getExprId().asInt() == 11)
+ .when(u ->
u.getOutput().get(2).getExprId().asInt() == 12)
+ ).when(p -> p.getProjects().stream().noneMatch(ne ->
ne.containsType(ElementAt.class)))
+ );
+ }
+
+ @Test
+ public void testPushDownProjectThroughUnionHasNoChildren() {
+ LogicalUnion union = new LogicalUnion(Qualifier.ALL, unionOutput,
Lists.newArrayList(),
+ Lists.newArrayList(regulatorRel1Output, regulatorRel2Output),
true, Lists.newArrayList());
+ LogicalProject<LogicalUnion> project = new
LogicalProject<>(pushDownProjections, union);
+ StatementContext context = new StatementContext();
+ LogicalProject<LogicalUnion> resProject
+ = (LogicalProject<LogicalUnion>)
PushDownProject.pushThroughUnion(project, context);
+ PlanChecker.from(MemoTestUtils.createConnectContext(), resProject)
+ .matchesFromRoot(
+ logicalProject(
+ logicalUnion(
+ logicalProject(
+ logicalOneRowRelation()
+ .when(r ->
r.getOutputs().size() == 3)
+ ).when(p -> p.getOutputs().size() == 6)
+ .when(p ->
p.getProjects().get(0).getExprId().asInt() == 2)
+ .when(p ->
p.getProjects().get(1).getExprId().asInt() == 1)
+ .when(p ->
p.getProjects().get(2).getExprId().asInt() == 2)
+ .when(p ->
p.getProjects().get(3).child(0) instanceof ElementAt)
+ .when(p ->
p.getProjects().get(4).child(0) instanceof ElementAt)
+ .when(p ->
p.getProjects().get(5).child(0) instanceof ElementAt)
+ .when(p -> ((SlotReference)
(p.getProjects().get(3).child(0).child(0))).getExprId().asInt() == 2)
+ .when(p -> ((StringLiteral)
(p.getProjects().get(3).child(0).child(1))).getValue().equals("a"))
+ .when(p -> ((SlotReference)
(p.getProjects().get(4).child(0).child(0))).getExprId().asInt() == 2)
+ .when(p -> ((StringLiteral)
(p.getProjects().get(4).child(0).child(1))).getValue().equals("b"))
+ .when(p -> ((SlotReference)
(p.getProjects().get(5).child(0).child(0))).getExprId().asInt() == 2)
+ .when(p -> ((StringLiteral)
(p.getProjects().get(5).child(0).child(1))).getValue().equals("a")),
+ logicalProject(
+ logicalOneRowRelation()
+ .when(r ->
r.getOutputs().size() == 3)
+ ).when(p -> p.getOutputs().size() == 6)
+ .when(p ->
p.getProjects().get(0).getExprId().asInt() == 3)
+ .when(p ->
p.getProjects().get(1).getExprId().asInt() == 5)
+ .when(p ->
p.getProjects().get(2).getExprId().asInt() == 4)
+ .when(p ->
p.getProjects().get(3).child(0) instanceof ElementAt)
+ .when(p ->
p.getProjects().get(4).child(0) instanceof ElementAt)
+ .when(p ->
p.getProjects().get(5).child(0) instanceof ElementAt)
+ .when(p -> ((SlotReference)
(p.getProjects().get(3).child(0).child(0))).getExprId().asInt() == 3)
+ .when(p -> ((StringLiteral)
(p.getProjects().get(3).child(0).child(1))).getValue().equals("a"))
+ .when(p -> ((SlotReference)
(p.getProjects().get(4).child(0).child(0))).getExprId().asInt() == 3)
+ .when(p -> ((StringLiteral)
(p.getProjects().get(4).child(0).child(1))).getValue().equals("b"))
+ .when(p -> ((SlotReference)
(p.getProjects().get(5).child(0).child(0))).getExprId().asInt() == 4)
+ .when(p -> ((StringLiteral)
(p.getProjects().get(5).child(0).child(1))).getValue().equals("a"))
+ ).when(u -> u.getOutput().size() == 6)
+ .when(u ->
u.getOutput().get(0).getExprId().asInt() == 10)
+ .when(u ->
u.getOutput().get(1).getExprId().asInt() == 11)
+ .when(u ->
u.getOutput().get(2).getExprId().asInt() == 12)
+ ).when(p -> p.getProjects().stream().noneMatch(ne ->
ne.containsType(ElementAt.class)))
+ );
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]