This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new d7e5f97b74 [feature](Nereids): eliminate AssertNumRows (#23842)
d7e5f97b74 is described below
commit d7e5f97b74a388af0aaab695b07e74901bbb1ae1
Author: jakevin <[email protected]>
AuthorDate: Wed Sep 13 22:24:02 2023 +0800
[feature](Nereids): eliminate AssertNumRows (#23842)
---
.../doris/nereids/jobs/executor/Rewriter.java | 4 +-
.../org/apache/doris/nereids/rules/RuleType.java | 1 +
.../rules/rewrite/EliminateAssertNumRows.java | 93 ++++++++++++++++++
.../rules/rewrite/EliminateAssertNumRowsTest.java | 106 +++++++++++++++++++++
.../doris/nereids/util/LogicalPlanBuilder.java | 9 ++
.../data/nereids_tpch_shape_sf500_p0/shape/q11.out | 31 +++---
.../data/nereids_tpch_shape_sf500_p0/shape/q15.out | 21 ++--
.../data/nereids_tpch_shape_sf500_p0/shape/q22.out | 13 ++-
8 files changed, 243 insertions(+), 35 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index de8e92eff3..5e0fd69cbc 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -52,6 +52,7 @@ import
org.apache.doris.nereids.rules.rewrite.CountLiteralToCountStar;
import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow;
import org.apache.doris.nereids.rules.rewrite.DeferMaterializeTopNResult;
import org.apache.doris.nereids.rules.rewrite.EliminateAggregate;
+import org.apache.doris.nereids.rules.rewrite.EliminateAssertNumRows;
import org.apache.doris.nereids.rules.rewrite.EliminateDedupJoinCondition;
import org.apache.doris.nereids.rules.rewrite.EliminateEmptyRelation;
import org.apache.doris.nereids.rules.rewrite.EliminateFilter;
@@ -168,7 +169,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
bottomUp(
new EliminateLimit(),
new EliminateFilter(),
- new EliminateAggregate()
+ new EliminateAggregate(),
+ new EliminateAssertNumRows()
)
),
// please note: this rule must run before NormalizeAggregate
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 8dc865482b..27399ec088 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -201,6 +201,7 @@ public enum RuleType {
ELIMINATE_OUTER_JOIN(RuleTypeClass.REWRITE),
ELIMINATE_DEDUP_JOIN_CONDITION(RuleTypeClass.REWRITE),
ELIMINATE_NULL_AWARE_LEFT_ANTI_JOIN(RuleTypeClass.REWRITE),
+ ELIMINATE_ASSERT_NUM_ROWS(RuleTypeClass.REWRITE),
CONVERT_OUTER_JOIN_TO_ANTI(RuleTypeClass.REWRITE),
FIND_HASH_CONDITION_FOR_JOIN(RuleTypeClass.REWRITE),
MATERIALIZED_INDEX_AGG_SCAN(RuleTypeClass.REWRITE),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateAssertNumRows.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateAssertNumRows.java
new file mode 100644
index 0000000000..84d459d30d
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateAssertNumRows.java
@@ -0,0 +1,93 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement;
+import
org.apache.doris.nereids.trees.expressions.AssertNumRowsElement.Assertion;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAssertNumRows;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
+
+/** Rewrite rule that replaces a LogicalAssertNumRows node with its child when a row-count bound found below it already satisfies the assertion. */
+public class EliminateAssertNumRows extends OneRewriteRuleFactory {
+
+ @Override
+ public Rule build() { // builds the rule; it matches on logicalAssertNumRows()
+ return logicalAssertNumRows()
+ .then(assertNumRows -> {
+ Plan checkPlan = assertNumRows.child();
+ while (skipPlan(checkPlan) != checkPlan) { // descend until skipPlan returns its argument unchanged
+ checkPlan = skipPlan(checkPlan);
+ }
+ return canEliminate(assertNumRows, checkPlan) ?
assertNumRows.child() : null; // the direct child replaces the assert node; null when the check fails
+ }).toRule(RuleType.ELIMINATE_ASSERT_NUM_ROWS);
+ }
+
+ private Plan skipPlan(Plan plan) { // one step down through a skippable node; returns plan itself if none applies
+ if (plan instanceof LogicalProject || plan instanceof LogicalFilter ||
plan instanceof LogicalSort) {
+ plan = plan.child(0);
+ } else if (plan instanceof LogicalJoin) {
+ if (((LogicalJoin<?, ?>)
plan).getJoinType().isLeftSemiOrAntiJoin()) {
+ plan = plan.child(0); // left semi/anti join: continue into child 0
+ } else if (((LogicalJoin<?, ?>)
plan).getJoinType().isRightSemiOrAntiJoin()) {
+ plan = plan.child(1); // right semi/anti join: continue into child 1
+ }
+ }
+ return plan;
+ }
+
+ private boolean canEliminate(LogicalAssertNumRows<?> assertNumRows, Plan
plan) { // true when the bound taken from plan makes the assertion hold
+ long maxOutputRowcount;
+ // TopN is not handled here because, at this stage, it still appears as Sort + Limit.
+ if (plan instanceof LogicalLimit) {
+ maxOutputRowcount = ((LogicalLimit<?>) plan).getLimit(); // the limit value is used as the upper bound
+ } else if (plan instanceof LogicalAggregate && ((LogicalAggregate<?>)
plan).getGroupByExpressions().isEmpty()) {
+ maxOutputRowcount = 1; // aggregate without group-by keys: bound is 1
+ } else {
+ return false; // no bound is derived for any other node type
+ }
+
+ AssertNumRowsElement assertNumRowsElement =
assertNumRows.getAssertNumRowsElement();
+ Assertion assertion = assertNumRowsElement.getAssertion();
+ long assertNum = assertNumRowsElement.getDesiredNumOfRows();
+
+ switch (assertion) {
+ case NE: // shares the LT test: a count below assertNum also differs from it
+ case LT:
+ if (maxOutputRowcount < assertNum) {
+ return true;
+ }
+ break;
+ case LE:
+ if (maxOutputRowcount <= assertNum) {
+ return true;
+ }
+ break;
+ default: // every other assertion kind is never eliminated
+ return false;
+ }
+ return false;
+ }
+}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateAssertNumRowsTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateAssertNumRowsTest.java
new file mode 100644
index 0000000000..f4b647a54e
--- /dev/null
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateAssertNumRowsTest.java
@@ -0,0 +1,106 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import
org.apache.doris.nereids.trees.expressions.AssertNumRowsElement.Assertion;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Test;
+
+class EliminateAssertNumRowsTest implements MemoPatternMatchSupported { // plan-shape tests for EliminateAssertNumRows
+ @Test
+ void testFailedMatch() { // LT 10 over limit(10, 10): the assert node is expected to stay
+ LogicalPlan plan = new
LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
+ .limit(10, 10)
+ .assertNumRows(Assertion.LT, 10)
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new EliminateAssertNumRows())
+ .matchesFromRoot(
+ logicalAssertNumRows(logicalLimit(logicalOlapScan()))
+ );
+ }
+
+ @Test
+ void testLimit() { // LE 10 over limit(10, 10): the assert node is expected to be removed
+ LogicalPlan plan = new
LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
+ .limit(10, 10)
+ .assertNumRows(Assertion.LE, 10)
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new EliminateAssertNumRows())
+ .matchesFromRoot(
+ logicalLimit(logicalOlapScan())
+ );
+ }
+
+ @Test
+ void testScalarAgg() { // aggregate with no group-by keys under LE 10: the assert node is expected to be removed
+ LogicalPlan plan = new
LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
+ .agg(ImmutableList.of(), ImmutableList.of((new
Count()).alias("cnt")))
+ .assertNumRows(Assertion.LE, 10)
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new EliminateAssertNumRows())
+ .matchesFromRoot(
+ logicalAggregate(logicalOlapScan())
+ );
+ }
+
+ @Test
+ void skipProject() { // limit below two projects: the assert node is expected to be removed
+ LogicalPlan plan = new
LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
+ .limit(10, 10)
+ .project(ImmutableList.of(0, 1))
+ .project(ImmutableList.of(0, 1))
+ .assertNumRows(Assertion.LE, 10)
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new EliminateAssertNumRows())
+ .matchesFromRoot(
+
logicalProject(logicalProject(logicalLimit(logicalOlapScan())))
+ );
+ }
+
+ @Test
+ void skipSemiJoin() { // limit on the left input of a left semi join: the assert node is expected to be removed
+ LogicalPlan plan = new
LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0))
+ .limit(10, 10)
+ .joinEmptyOn(PlanConstructor.newLogicalOlapScan(1, "t2", 0),
JoinType.LEFT_SEMI_JOIN)
+ .assertNumRows(Assertion.LE, 10)
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyTopDown(new EliminateAssertNumRows())
+ .matchesFromRoot(
+ leftSemiLogicalJoin(logicalLimit(logicalOlapScan()),
logicalOlapScan())
+ );
+ }
+}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java
index 99f7884b2b..c5024ff931 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java
@@ -20,6 +20,8 @@ package org.apache.doris.nereids.util;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement;
+import
org.apache.doris.nereids.trees.expressions.AssertNumRowsElement.Assertion;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
@@ -28,6 +30,7 @@ import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.LimitPhase;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAssertNumRows;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
@@ -192,4 +195,10 @@ public class LogicalPlanBuilder {
LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys,
outputExprsList, this.plan);
return from(agg);
}
+
+ public LogicalPlanBuilder assertNumRows(Assertion assertion, long numRows)
{ // test helper: wraps this.plan in a LogicalAssertNumRows built from (numRows, assertion)
+ LogicalAssertNumRows<LogicalPlan> assertNumRows = new
LogicalAssertNumRows<>(
+ new AssertNumRowsElement(numRows, "", assertion), this.plan); // NOTE(review): meaning of the "" argument is not visible here; confirm against AssertNumRowsElement
+ return from(assertNumRows); // same from(...) hand-off as the agg helper above
+ }
}
diff --git a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q11.out
b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q11.out
index 8fd0e3341d..1d1e947be2 100644
--- a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q11.out
+++ b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q11.out
@@ -20,21 +20,20 @@ PhysicalResultSink
--------------------------filter((nation.n_name = 'GERMANY'))
----------------------------PhysicalOlapScan[nation]
------------PhysicalDistribute
---------------PhysicalAssertNumRows
-----------------PhysicalProject
-------------------hashAgg[GLOBAL]
---------------------PhysicalDistribute
-----------------------hashAgg[LOCAL]
-------------------------PhysicalProject
---------------------------hashJoin[INNER_JOIN](partsupp.ps_suppkey =
supplier.s_suppkey)
-----------------------------PhysicalProject
-------------------------------PhysicalOlapScan[partsupp]
-----------------------------PhysicalDistribute
-------------------------------hashJoin[INNER_JOIN](supplier.s_nationkey =
nation.n_nationkey)
+--------------PhysicalProject
+----------------hashAgg[GLOBAL]
+------------------PhysicalDistribute
+--------------------hashAgg[LOCAL]
+----------------------PhysicalProject
+------------------------hashJoin[INNER_JOIN](partsupp.ps_suppkey =
supplier.s_suppkey)
+--------------------------PhysicalProject
+----------------------------PhysicalOlapScan[partsupp]
+--------------------------PhysicalDistribute
+----------------------------hashJoin[INNER_JOIN](supplier.s_nationkey =
nation.n_nationkey)
+------------------------------PhysicalProject
+--------------------------------PhysicalOlapScan[supplier]
+------------------------------PhysicalDistribute
--------------------------------PhysicalProject
-----------------------------------PhysicalOlapScan[supplier]
---------------------------------PhysicalDistribute
-----------------------------------PhysicalProject
-------------------------------------filter((nation.n_name = 'GERMANY'))
---------------------------------------PhysicalOlapScan[nation]
+----------------------------------filter((nation.n_name = 'GERMANY'))
+------------------------------------PhysicalOlapScan[nation]
diff --git a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q15.out
b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q15.out
index 4106594748..ff4350e080 100644
--- a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q15.out
+++ b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q15.out
@@ -17,15 +17,14 @@ PhysicalResultSink
------------------------filter((lineitem.l_shipdate >=
1996-01-01)(lineitem.l_shipdate < 1996-04-01))
--------------------------PhysicalOlapScan[lineitem]
----------------PhysicalDistribute
-------------------PhysicalAssertNumRows
---------------------hashAgg[GLOBAL]
-----------------------PhysicalDistribute
-------------------------hashAgg[LOCAL]
---------------------------PhysicalProject
-----------------------------hashAgg[GLOBAL]
-------------------------------PhysicalDistribute
---------------------------------hashAgg[LOCAL]
-----------------------------------PhysicalProject
-------------------------------------filter((lineitem.l_shipdate >=
1996-01-01)(lineitem.l_shipdate < 1996-04-01))
---------------------------------------PhysicalOlapScan[lineitem]
+------------------hashAgg[GLOBAL]
+--------------------PhysicalDistribute
+----------------------hashAgg[LOCAL]
+------------------------PhysicalProject
+--------------------------hashAgg[GLOBAL]
+----------------------------PhysicalDistribute
+------------------------------hashAgg[LOCAL]
+--------------------------------PhysicalProject
+----------------------------------filter((lineitem.l_shipdate >=
1996-01-01)(lineitem.l_shipdate < 1996-04-01))
+------------------------------------PhysicalOlapScan[lineitem]
diff --git a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q22.out
b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q22.out
index c352e2508a..7845eba2ba 100644
--- a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q22.out
+++ b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q22.out
@@ -18,11 +18,10 @@ PhysicalResultSink
------------------------filter(substring(c_phone, 1, 2) IN ('13', '31', '23',
'29', '30', '18', '17'))
--------------------------PhysicalOlapScan[customer]
----------------------PhysicalDistribute
-------------------------PhysicalAssertNumRows
---------------------------hashAgg[GLOBAL]
-----------------------------PhysicalDistribute
-------------------------------hashAgg[LOCAL]
---------------------------------PhysicalProject
-----------------------------------filter((customer.c_acctbal >
0.00)substring(c_phone, 1, 2) IN ('13', '31', '23', '29', '30', '18', '17'))
-------------------------------------PhysicalOlapScan[customer]
+------------------------hashAgg[GLOBAL]
+--------------------------PhysicalDistribute
+----------------------------hashAgg[LOCAL]
+------------------------------PhysicalProject
+--------------------------------filter((customer.c_acctbal >
0.00)substring(c_phone, 1, 2) IN ('13', '31', '23', '29', '30', '18', '17'))
+----------------------------------PhysicalOlapScan[customer]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]