This is an automated email from the ASF dual-hosted git repository.
englefly pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 0d46c10a6fe [opt](Nereids) strip redundant widening integer cast in
SumLiteralRewrite (#61224)
0d46c10a6fe is described below
commit 0d46c10a6feed5f4192d64bf3c5781909ac5661b
Author: minghong <[email protected]>
AuthorDate: Fri Mar 13 15:26:25 2026 +0800
[opt](Nereids) strip redundant widening integer cast in SumLiteralRewrite
(#61224)
### What problem does this PR solve?
SumLiteralRewrite transforms SUM(expr +/- literal) into SUM(expr) +/-
literal * COUNT(expr). When type coercion has introduced an implicit
widening cast (e.g. CAST(smallint_col AS INT)), the rewritten SUM/COUNT
still operates on the wider type, forcing unnecessary wider data reads.
This is redundant because SUM always returns BIGINT for any integer
input (TINYINT/SMALLINT/INT/BIGINT). Strip implicit widening integer
casts in extractSumLiteral() so the aggregate operates on the original
narrow column directly.
This benefits ClickBench Q29-style queries where SUM(col), SUM(col+1),
SUM(col+2) share a narrow integer column — after stripping the cast,
SUM(col+1) and SUM(col+2) reuse the existing SUM(col).
---
.../nereids/rules/rewrite/SumLiteralRewrite.java | 33 ++++++++++++
.../rules/rewrite/SumLiteralRewriteTest.java | 60 ++++++++++++++++++++++
.../data/nereids_rules_p0/sumRewrite.out | 8 +--
3 files changed, 97 insertions(+), 4 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java
index 09be00a5819..cf983c03133 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java
@@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
+import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
@@ -204,10 +205,42 @@ public class SumLiteralRewrite extends
OneRewriteRuleFactory {
// only support integer or float types
return null;
}
+ // Strip redundant widening integer cast introduced by type coercion.
+ // e.g. SUM(CAST(smallint_col AS INT) + 1) → after rewrite becomes
SUM(CAST(smallint_col AS INT)).
+ // Since SUM always returns BIGINT for any integer input,
CAST(smallint→int) is unnecessary
+ // and forces wider data reads. Strip it so we get SUM(smallint_col)
directly.
+ left = stripWideningIntegerCast(left);
SumInfo info = new SumInfo(left, ((Sum) func).isDistinct(), ((Sum)
func).isAlwaysNullable());
return Pair.of(namedExpression, Pair.of(info, (Literal) right));
}
+ /**
+ * Strip a widening integer cast that is redundant for SUM/COUNT.
+ * For example, CAST(smallint_col AS INT) → smallint_col.
+ *
+ * This is safe because:
+ * - SUM returns BIGINT for all integer inputs
(TINYINT/SMALLINT/INT/BIGINT),
+ * so widening the input before aggregation does not change the result.
+ * - COUNT just counts non-null values, unaffected by widening.
+ *
+ * Only implicit (type-coercion) casts between integer-like types are
stripped.
+ */
+ private static Expression stripWideningIntegerCast(Expression expr) {
+ if (!(expr instanceof Cast)) {
+ return expr;
+ }
+ Cast cast = (Cast) expr;
+ if (cast.isExplicitType()) {
+ return expr;
+ }
+ Expression inner = cast.child();
+ if (inner.getDataType().isIntegerLikeType() &&
cast.getDataType().isIntegerLikeType()
+ && inner.getDataType().width() <= cast.getDataType().width()) {
+ return inner;
+ }
+ return expr;
+ }
+
static class SumInfo {
Expression expr;
boolean isDistinct;
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java
index 19ea7b864fb..5b918c62a59 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java
@@ -19,12 +19,14 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
@@ -143,4 +145,62 @@ class SumLiteralRewriteTest implements
MemoPatternMatchSupported {
.matches(logicalAggregate().when(p -> p.getOutputs().size() ==
3));
}
+
+ @Test
+ void testStripWideningIntegerCast() {
+ Slot slot1 = scan1.getOutput().get(0);
+ // Simulate type coercion's implicit widening cast: CAST(int_col AS
BIGINT)
+ Cast castSlot = new Cast(slot1, BigIntType.INSTANCE);
+ Alias add1 = new Alias(new Sum(new Add(castSlot, Literal.of(1))));
+ Alias add2 = new Alias(new Sum(new Add(castSlot, Literal.of(2))));
+ LogicalAggregate<?> agg = new LogicalAggregate<>(
+ ImmutableList.of(), ImmutableList.of(add1, add2), scan1);
+ PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+ .applyTopDown(ImmutableList.of(new
SumLiteralRewrite().build()))
+ .printlnTree()
+ // After stripping the implicit widening cast, Sum and Count
should use
+ // slot1 directly (not Cast(slot1 AS BIGINT)), so no Cast in
aggregate outputs
+ .matches(logicalAggregate().when(a ->
+ a.getOutputExpressions().stream().noneMatch(
+ e -> e.anyMatch(expr -> expr instanceof
Cast))));
+
+ // Verify explicit cast is NOT stripped
+ Cast explicitCast = new Cast(slot1, BigIntType.INSTANCE, true);
+ Alias addExplicit1 = new Alias(new Sum(new Add(explicitCast,
Literal.of(1))));
+ Alias addExplicit2 = new Alias(new Sum(new Add(explicitCast,
Literal.of(2))));
+ agg = new LogicalAggregate<>(
+ ImmutableList.of(), ImmutableList.of(addExplicit1,
addExplicit2), scan1);
+ PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+ .applyTopDown(ImmutableList.of(new
SumLiteralRewrite().build()))
+ .printlnTree()
+ // Explicit cast should be preserved — aggregate outputs
should still contain Cast
+ .matches(logicalAggregate().when(a ->
+ a.getOutputExpressions().stream().anyMatch(
+ e -> e.anyMatch(expr -> expr instanceof
Cast))));
+ }
+
+ @Test
+ void testStripWideningCastWithExistingSum() {
+ // Simulates ClickBench Q29: SELECT SUM(col), SUM(col+1), SUM(col+2)
+ // where col is a narrow integer type and type coercion introduces
implicit widening cast.
+ Slot slot1 = scan1.getOutput().get(0);
+ // Pre-existing plain SUM(slot) — no cast, no literal
+ Alias sum = new Alias(new Sum(slot1));
+ // Simulate type coercion widening: SUM(CAST(int_col AS BIGINT) + 1)
etc.
+ Cast castSlot = new Cast(slot1, BigIntType.INSTANCE);
+ Alias add1 = new Alias(new Sum(new Add(castSlot, Literal.of(1))));
+ Alias add2 = new Alias(new Sum(new Add(castSlot, Literal.of(2))));
+ LogicalAggregate<?> agg = new LogicalAggregate<>(
+ ImmutableList.of(), ImmutableList.of(sum, add1, add2), scan1);
+ PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+ .applyTopDown(ImmutableList.of(new
SumLiteralRewrite().build()))
+ .printlnTree()
+ // After stripping widening cast, the base expr of
SUM(CAST(slot AS BIGINT) + n)
+ // becomes slot — matching the pre-existing SUM(slot). Rewrite
reuses it and only
+ // adds COUNT(slot). Aggregate outputs: sum(slot) +
count(slot) = 2.
+ .matches(logicalAggregate().when(a ->
+ a.getOutputExpressions().size() == 2
+ && a.getOutputExpressions().stream().noneMatch(
+ e -> e.anyMatch(expr -> expr instanceof
Cast))));
+ }
}
diff --git a/regression-test/data/nereids_rules_p0/sumRewrite.out
b/regression-test/data/nereids_rules_p0/sumRewrite.out
index 356fad37763..265ccae7dea 100644
--- a/regression-test/data/nereids_rules_p0/sumRewrite.out
+++ b/regression-test/data/nereids_rules_p0/sumRewrite.out
@@ -268,10 +268,10 @@ PhysicalResultSink
-- !sum_null_and_not_null_shape --
PhysicalResultSink
---PhysicalProject[(sum(cast(id as BIGINT)) + (count(cast(id as BIGINT)) * 1))
AS `sum(id + 1)`, (sum(cast(id as BIGINT)) - (count(cast(id as BIGINT)) * 1))
AS `sum(id - 1)`, (sum(cast(not_null_id as BIGINT)) + (count(cast(not_null_id
as BIGINT)) * 1)) AS `sum(not_null_id + 1)`, (sum(cast(not_null_id as BIGINT))
- (count(cast(not_null_id as BIGINT)) * 1)) AS `sum(not_null_id - 1)`, sum(id),
sum(not_null_id)]
-----hashAgg[GLOBAL, groupByExpr=(), outputExpr=(count(cast(id as BIGINT)) AS
`count(cast(id as BIGINT))`, count(cast(not_null_id as BIGINT)) AS
`count(cast(not_null_id as BIGINT))`, sum(cast(id as BIGINT)) AS `sum(cast(id
as BIGINT))`, sum(cast(not_null_id as BIGINT)) AS `sum(cast(not_null_id as
BIGINT))`, sum(id) AS `sum(id)`, sum(not_null_id) AS `sum(not_null_id)`)]
-------hashAgg[LOCAL, groupByExpr=(), outputExpr=(partial_count(cast(id as
BIGINT)) AS `partial_count(cast(id as BIGINT))`, partial_count(cast(not_null_id
as BIGINT)) AS `partial_count(cast(not_null_id as BIGINT))`,
partial_sum(cast(id as BIGINT)) AS `partial_sum(cast(id as BIGINT))`,
partial_sum(cast(not_null_id as BIGINT)) AS `partial_sum(cast(not_null_id as
BIGINT))`, partial_sum(id) AS `partial_sum(id)`, partial_sum(not_null_id) AS
`partial_sum(not_null_id)`)]
---------PhysicalProject[cast(id as BIGINT) AS `cast(id as BIGINT)`,
cast(not_null_id as BIGINT) AS `cast(not_null_id as BIGINT)`, sr.id,
sr.not_null_id]
+--PhysicalProject[(sum(id) + (count(id) * 1)) AS `sum(id + 1)`, (sum(id) -
(count(id) * 1)) AS `sum(id - 1)`, (sum(not_null_id) + (count(not_null_id) *
1)) AS `sum(not_null_id + 1)`, (sum(not_null_id) - (count(not_null_id) * 1)) AS
`sum(not_null_id - 1)`, sum(id), sum(not_null_id)]
+----hashAgg[GLOBAL, groupByExpr=(), outputExpr=(count(id) AS `count(id)`,
count(not_null_id) AS `count(not_null_id)`, sum(id) AS `sum(id)`,
sum(not_null_id) AS `sum(not_null_id)`)]
+------hashAgg[LOCAL, groupByExpr=(), outputExpr=(partial_count(id) AS
`partial_count(id)`, partial_count(not_null_id) AS
`partial_count(not_null_id)`, partial_sum(id) AS `partial_sum(id)`,
partial_sum(not_null_id) AS `partial_sum(not_null_id)`)]
+--------PhysicalProject[sr.id, sr.not_null_id]
----------PhysicalOlapScan[sr]
-- !sum_null_and_not_null_result --
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]