neilconway commented on code in PR #20180: URL: https://github.com/apache/datafusion/pull/20180#discussion_r2799823469
########## datafusion/optimizer/src/rewrite_aggregate_with_constant.rs: ########## @@ -0,0 +1,609 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`RewriteAggregateWithConstant`] rewrites `SUM(column ± constant)` to `SUM(column) ± constant * COUNT(column)` + +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; + +use std::collections::HashMap; + +use datafusion_common::tree_node::Transformed; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::expr::AggregateFunctionParams; +use datafusion_expr::{ + Aggregate, BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator, binary_expr, + col, lit, +}; +use datafusion_functions_aggregate::expr_fn::{count, sum}; +use indexmap::IndexMap; + +/// Optimizer rule that rewrites `SUM(column ± constant)` expressions +/// into `SUM(column) ± constant * COUNT(column)` when multiple such expressions +/// exist for the same base column. +/// +/// This reduces computation by calculating SUM once and deriving other values. +/// +/// # Example +/// ```sql +/// SELECT SUM(a), SUM(a + 1), SUM(a + 2) FROM t; +/// ``` +/// is rewritten into a Projection on top of an Aggregate: +/// ```sql +/// -- New Projection Node +/// SELECT sum_a, sum_a + 1 * count_a, sum_a + 2 * count_a +/// -- New Aggregate Node +/// FROM (SELECT SUM(a) as sum_a, COUNT(a) as count_a FROM t); +/// ``` +#[derive(Default, Debug)] +pub struct RewriteAggregateWithConstant {} + +impl RewriteAggregateWithConstant { + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for RewriteAggregateWithConstant { + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result<Transformed<LogicalPlan>> { + match plan { + // This rule specifically targets Aggregate nodes + LogicalPlan::Aggregate(aggregate) => { + // Step 1: Identify which expressions can be rewritten and group them by base column + let rewrite_info = analyze_aggregate(&aggregate)?; + + if rewrite_info.is_empty() { + // No groups found with 2+ matching SUM expressions, return original plan + return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate))); + } + + // Step 2: Perform the actual transformation into Aggregate + Projection + transform_aggregate(aggregate, &rewrite_info) + } + // Non-aggregate plans are passed through unchanged + _ => Ok(Transformed::no(plan)), + } + } + + fn name(&self) -> &str { + "rewrite_aggregate_with_constant" + } + + fn apply_order(&self) -> Option<ApplyOrder> { + // Bottom-up ensures we optimize subqueries before the outer query + Some(ApplyOrder::BottomUp) + } +} + +/// Internal structure to track metadata for a SUM expression that qualifies for rewrite. +#[derive(Debug, Clone)] +struct SumWithConstant { + /// The inner expression being summed (e.g., the `a` in `SUM(a + 1)`) + base_expr: Expr, + /// The constant value being added/subtracted (e.g., `1` in `SUM(a + 1)`) + constant: ScalarValue, + /// The operator (`+` or `-`) + operator: Operator, + /// The index in the original Aggregate's `aggr_expr` list, used to maintain output order + original_index: usize, + // Note: ORDER BY inside SUM is irrelevant because SUM is commutative — + // the order of addition doesn't change the result. If this rule is ever + // extended to non-commutative aggregates, ORDER BY handling would need + // to be added back. +} + +/// Maps a base expression's schema name to all its SUM(base ± const) variants. +/// We use IndexMap to preserve insertion order, ensuring deterministic output +/// in the rewritten plan (important for stable EXPLAIN output in tests). +type RewriteGroups = IndexMap<String, Vec<SumWithConstant>>; + +/// Scans the aggregate expressions to find candidates for the rewrite. +fn analyze_aggregate(aggregate: &Aggregate) -> Result<RewriteGroups> { + let mut groups: RewriteGroups = IndexMap::new(); + + for (idx, expr) in aggregate.aggr_expr.iter().enumerate() { + // Try to match the pattern SUM(col ± lit) + if let Some(sum_info) = extract_sum_with_constant(expr, idx)? { + let key = sum_info.base_expr.schema_name().to_string(); + groups.entry(key).or_default().push(sum_info); + } + } + + // Optimization: Only rewrite if we have at least 2 expressions for the same column. + // If there's only one SUM(a + 1), rewriting it to SUM(a) + 1*COUNT(a) + // actually increases the work (1 agg -> 2 aggs). + groups.retain(|_, v| v.len() >= 2); + + Ok(groups) +} + +/// Extract SUM(base_expr ± constant) pattern from an expression. +/// Handles both `Expr::AggregateFunction(...)` and `Expr::Alias(Expr::AggregateFunction(...))` +/// so the rule works regardless of whether aggregate expressions carry aliases +/// (e.g., when plans are built via the LogicalPlanBuilder API). +fn extract_sum_with_constant(expr: &Expr, idx: usize) -> Result<Option<SumWithConstant>> { + // Unwrap Expr::Alias if present — the SQL planner puts aliases in a + // Projection above the Aggregate, but the builder API allows aliases + // directly inside aggr_expr. + let inner = match expr { + Expr::Alias(alias) => alias.expr.as_ref(), + other => other, + }; + + match inner { + Expr::AggregateFunction(agg_fn) => { + // Rule only applies to SUM + if agg_fn.func.name().to_lowercase() != "sum" { + return Ok(None); + } + + let AggregateFunctionParams { + args, + distinct, + filter, + order_by: _, + null_treatment: _, + } = &agg_fn.params; + + // We cannot easily rewrite SUM(DISTINCT a + 1) or SUM(a + 1) FILTER (...) + // as the math SUM(a) + k*COUNT(a) wouldn't hold correctly with these modifiers. + if *distinct || filter.is_some() { + return Ok(None); + } + + // SUM must have exactly one argument (e.g. SUM(a + 1)). + // This rejects invalid calls like SUM() or non-standard multi-argument variations. + if args.len() != 1 { + return Ok(None); + } + + let arg = &args[0]; + + // Try to match: base_expr +/- constant + // Note: If the base_expr is complex (e.g., SUM(a + b + 1)), base_expr will be "a + b". + // The rule will still work if multiple SUMs have the exact same complex base_expr, + // as they will be grouped by the string representation of that expression. + if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = arg + && matches!(op, Operator::Plus | Operator::Minus) + { + // Check if right side is a literal constant Review Comment: Remove duplicate comment ########## datafusion/optimizer/src/rewrite_aggregate_with_constant.rs: ########## @@ -0,0 +1,609 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`RewriteAggregateWithConstant`] rewrites `SUM(column ± constant)` to `SUM(column) ± constant * COUNT(column)` + +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; + +use std::collections::HashMap; + +use datafusion_common::tree_node::Transformed; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::expr::AggregateFunctionParams; +use datafusion_expr::{ + Aggregate, BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator, binary_expr, + col, lit, +}; +use datafusion_functions_aggregate::expr_fn::{count, sum}; +use indexmap::IndexMap; + +/// Optimizer rule that rewrites `SUM(column ± constant)` expressions +/// into `SUM(column) ± constant * COUNT(column)` when multiple such expressions +/// exist for the same base column. +/// +/// This reduces computation by calculating SUM once and deriving other values. +/// +/// # Example +/// ```sql +/// SELECT SUM(a), SUM(a + 1), SUM(a + 2) FROM t; +/// ``` +/// is rewritten into a Projection on top of an Aggregate: +/// ```sql +/// -- New Projection Node +/// SELECT sum_a, sum_a + 1 * count_a, sum_a + 2 * count_a +/// -- New Aggregate Node +/// FROM (SELECT SUM(a) as sum_a, COUNT(a) as count_a FROM t); +/// ``` +#[derive(Default, Debug)] +pub struct RewriteAggregateWithConstant {} + +impl RewriteAggregateWithConstant { + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for RewriteAggregateWithConstant { + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result<Transformed<LogicalPlan>> { + match plan { + // This rule specifically targets Aggregate nodes + LogicalPlan::Aggregate(aggregate) => { + // Step 1: Identify which expressions can be rewritten and group them by base column + let rewrite_info = analyze_aggregate(&aggregate)?; + + if rewrite_info.is_empty() { + // No groups found with 2+ matching SUM expressions, return original plan + return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate))); + } + + // Step 2: Perform the actual transformation into Aggregate + Projection + transform_aggregate(aggregate, &rewrite_info) + } + // Non-aggregate plans are passed through unchanged + _ => Ok(Transformed::no(plan)), + } + } + + fn name(&self) -> &str { + "rewrite_aggregate_with_constant" + } + + fn apply_order(&self) -> Option<ApplyOrder> { + // Bottom-up ensures we optimize subqueries before the outer query + Some(ApplyOrder::BottomUp) + } +} + +/// Internal structure to track metadata for a SUM expression that qualifies for rewrite. +#[derive(Debug, Clone)] +struct SumWithConstant { + /// The inner expression being summed (e.g., the `a` in `SUM(a + 1)`) + base_expr: Expr, + /// The constant value being added/subtracted (e.g., `1` in `SUM(a + 1)`) + constant: ScalarValue, + /// The operator (`+` or `-`) + operator: Operator, + /// The index in the original Aggregate's `aggr_expr` list, used to maintain output order + original_index: usize, + // Note: ORDER BY inside SUM is irrelevant because SUM is commutative — + // the order of addition doesn't change the result. If this rule is ever + // extended to non-commutative aggregates, ORDER BY handling would need + // to be added back. +} + +/// Maps a base expression's schema name to all its SUM(base ± const) variants. +/// We use IndexMap to preserve insertion order, ensuring deterministic output +/// in the rewritten plan (important for stable EXPLAIN output in tests). +type RewriteGroups = IndexMap<String, Vec<SumWithConstant>>; + +/// Scans the aggregate expressions to find candidates for the rewrite. +fn analyze_aggregate(aggregate: &Aggregate) -> Result<RewriteGroups> { + let mut groups: RewriteGroups = IndexMap::new(); + + for (idx, expr) in aggregate.aggr_expr.iter().enumerate() { + // Try to match the pattern SUM(col ± lit) + if let Some(sum_info) = extract_sum_with_constant(expr, idx)? { + let key = sum_info.base_expr.schema_name().to_string(); + groups.entry(key).or_default().push(sum_info); + } + } + + // Optimization: Only rewrite if we have at least 2 expressions for the same column. + // If there's only one SUM(a + 1), rewriting it to SUM(a) + 1*COUNT(a) + // actually increases the work (1 agg -> 2 aggs). + groups.retain(|_, v| v.len() >= 2); + + Ok(groups) +} + +/// Extract SUM(base_expr ± constant) pattern from an expression. +/// Handles both `Expr::AggregateFunction(...)` and `Expr::Alias(Expr::AggregateFunction(...))` +/// so the rule works regardless of whether aggregate expressions carry aliases +/// (e.g., when plans are built via the LogicalPlanBuilder API). +fn extract_sum_with_constant(expr: &Expr, idx: usize) -> Result<Option<SumWithConstant>> { + // Unwrap Expr::Alias if present — the SQL planner puts aliases in a + // Projection above the Aggregate, but the builder API allows aliases + // directly inside aggr_expr. + let inner = match expr { + Expr::Alias(alias) => alias.expr.as_ref(), + other => other, + }; + + match inner { + Expr::AggregateFunction(agg_fn) => { + // Rule only applies to SUM + if agg_fn.func.name().to_lowercase() != "sum" { + return Ok(None); + } + + let AggregateFunctionParams { + args, + distinct, + filter, + order_by: _, + null_treatment: _, + } = &agg_fn.params; + + // We cannot easily rewrite SUM(DISTINCT a + 1) or SUM(a + 1) FILTER (...) + // as the math SUM(a) + k*COUNT(a) wouldn't hold correctly with these modifiers. + if *distinct || filter.is_some() { + return Ok(None); + } + + // SUM must have exactly one argument (e.g. SUM(a + 1)). + // This rejects invalid calls like SUM() or non-standard multi-argument variations. + if args.len() != 1 { + return Ok(None); + } + + let arg = &args[0]; + + // Try to match: base_expr +/- constant + // Note: If the base_expr is complex (e.g., SUM(a + b + 1)), base_expr will be "a + b". + // The rule will still work if multiple SUMs have the exact same complex base_expr, + // as they will be grouped by the string representation of that expression. + if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = arg + && matches!(op, Operator::Plus | Operator::Minus) + { + // Check if right side is a literal constant + // Check if right side is a literal constant (e.g., SUM(a + 1)) + if let Expr::Literal(constant, _) = right.as_ref() + && is_numeric_constant(constant) + { + return Ok(Some(SumWithConstant { + base_expr: (**left).clone(), + constant: constant.clone(), + operator: *op, + original_index: idx, + })); + } + + // Also check left side for commutative addition (e.g., SUM(1 + a)) + // Does NOT apply to subtraction: SUM(5 - a) ≠ SUM(a - 5) + if let Expr::Literal(constant, _) = left.as_ref() + && is_numeric_constant(constant) + && *op == Operator::Plus + { + return Ok(Some(SumWithConstant { + base_expr: (**right).clone(), + constant: constant.clone(), + operator: Operator::Plus, + original_index: idx, + })); + } + } + + Ok(None) + } + _ => Ok(None), + } +} + +/// Check if a scalar value is a numeric constant +/// (guards against non-arithmetic types like strings, booleans, dates, etc.) +fn is_numeric_constant(value: &ScalarValue) -> bool { + matches!( + value, + ScalarValue::Int8(_) + | ScalarValue::Int16(_) + | ScalarValue::Int32(_) + | ScalarValue::Int64(_) + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) + | ScalarValue::Float32(_) Review Comment: float16? ########## datafusion/optimizer/src/rewrite_aggregate_with_constant.rs: ########## @@ -0,0 +1,609 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`RewriteAggregateWithConstant`] rewrites `SUM(column ± constant)` to `SUM(column) ± constant * COUNT(column)` + +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; + +use std::collections::HashMap; + +use datafusion_common::tree_node::Transformed; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::expr::AggregateFunctionParams; +use datafusion_expr::{ + Aggregate, BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator, binary_expr, + col, lit, +}; +use datafusion_functions_aggregate::expr_fn::{count, sum}; +use indexmap::IndexMap; + +/// Optimizer rule that rewrites `SUM(column ± constant)` expressions +/// into `SUM(column) ± constant * COUNT(column)` when multiple such expressions +/// exist for the same base column. +/// +/// This reduces computation by calculating SUM once and deriving other values. +/// +/// # Example +/// ```sql +/// SELECT SUM(a), SUM(a + 1), SUM(a + 2) FROM t; +/// ``` +/// is rewritten into a Projection on top of an Aggregate: +/// ```sql +/// -- New Projection Node +/// SELECT sum_a, sum_a + 1 * count_a, sum_a + 2 * count_a +/// -- New Aggregate Node +/// FROM (SELECT SUM(a) as sum_a, COUNT(a) as count_a FROM t); +/// ``` +#[derive(Default, Debug)] +pub struct RewriteAggregateWithConstant {} + +impl RewriteAggregateWithConstant { + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for RewriteAggregateWithConstant { + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result<Transformed<LogicalPlan>> { + match plan { + // This rule specifically targets Aggregate nodes + LogicalPlan::Aggregate(aggregate) => { + // Step 1: Identify which expressions can be rewritten and group them by base column + let rewrite_info = analyze_aggregate(&aggregate)?; + + if rewrite_info.is_empty() { + // No groups found with 2+ matching SUM expressions, return original plan + return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate))); + } + + // Step 2: Perform the actual transformation into Aggregate + Projection + transform_aggregate(aggregate, &rewrite_info) + } + // Non-aggregate plans are passed through unchanged + _ => Ok(Transformed::no(plan)), + } + } + + fn name(&self) -> &str { + "rewrite_aggregate_with_constant" + } + + fn apply_order(&self) -> Option<ApplyOrder> { + // Bottom-up ensures we optimize subqueries before the outer query + Some(ApplyOrder::BottomUp) + } +} + +/// Internal structure to track metadata for a SUM expression that qualifies for rewrite. +#[derive(Debug, Clone)] +struct SumWithConstant { + /// The inner expression being summed (e.g., the `a` in `SUM(a + 1)`) + base_expr: Expr, + /// The constant value being added/subtracted (e.g., `1` in `SUM(a + 1)`) + constant: ScalarValue, + /// The operator (`+` or `-`) + operator: Operator, + /// The index in the original Aggregate's `aggr_expr` list, used to maintain output order + original_index: usize, + // Note: ORDER BY inside SUM is irrelevant because SUM is commutative — + // the order of addition doesn't change the result. If this rule is ever + // extended to non-commutative aggregates, ORDER BY handling would need + // to be added back. +} + +/// Maps a base expression's schema name to all its SUM(base ± const) variants. +/// We use IndexMap to preserve insertion order, ensuring deterministic output +/// in the rewritten plan (important for stable EXPLAIN output in tests). +type RewriteGroups = IndexMap<String, Vec<SumWithConstant>>; + +/// Scans the aggregate expressions to find candidates for the rewrite. +fn analyze_aggregate(aggregate: &Aggregate) -> Result<RewriteGroups> { + let mut groups: RewriteGroups = IndexMap::new(); + + for (idx, expr) in aggregate.aggr_expr.iter().enumerate() { + // Try to match the pattern SUM(col ± lit) + if let Some(sum_info) = extract_sum_with_constant(expr, idx)? { + let key = sum_info.base_expr.schema_name().to_string(); + groups.entry(key).or_default().push(sum_info); + } + } + + // Optimization: Only rewrite if we have at least 2 expressions for the same column. + // If there's only one SUM(a + 1), rewriting it to SUM(a) + 1*COUNT(a) + // actually increases the work (1 agg -> 2 aggs). + groups.retain(|_, v| v.len() >= 2); + + Ok(groups) +} + +/// Extract SUM(base_expr ± constant) pattern from an expression. +/// Handles both `Expr::AggregateFunction(...)` and `Expr::Alias(Expr::AggregateFunction(...))` +/// so the rule works regardless of whether aggregate expressions carry aliases +/// (e.g., when plans are built via the LogicalPlanBuilder API). +fn extract_sum_with_constant(expr: &Expr, idx: usize) -> Result<Option<SumWithConstant>> { + // Unwrap Expr::Alias if present — the SQL planner puts aliases in a + // Projection above the Aggregate, but the builder API allows aliases + // directly inside aggr_expr. + let inner = match expr { + Expr::Alias(alias) => alias.expr.as_ref(), + other => other, + }; + + match inner { + Expr::AggregateFunction(agg_fn) => { + // Rule only applies to SUM + if agg_fn.func.name().to_lowercase() != "sum" { + return Ok(None); + } + + let AggregateFunctionParams { + args, + distinct, + filter, + order_by: _, + null_treatment: _, + } = &agg_fn.params; + + // We cannot easily rewrite SUM(DISTINCT a + 1) or SUM(a + 1) FILTER (...) + // as the math SUM(a) + k*COUNT(a) wouldn't hold correctly with these modifiers. + if *distinct || filter.is_some() { + return Ok(None); + } + + // SUM must have exactly one argument (e.g. SUM(a + 1)). + // This rejects invalid calls like SUM() or non-standard multi-argument variations. + if args.len() != 1 { + return Ok(None); + } + + let arg = &args[0]; + + // Try to match: base_expr +/- constant + // Note: If the base_expr is complex (e.g., SUM(a + b + 1)), base_expr will be "a + b". + // The rule will still work if multiple SUMs have the exact same complex base_expr, + // as they will be grouped by the string representation of that expression. + if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = arg + && matches!(op, Operator::Plus | Operator::Minus) + { + // Check if right side is a literal constant + // Check if right side is a literal constant (e.g., SUM(a + 1)) + if let Expr::Literal(constant, _) = right.as_ref() + && is_numeric_constant(constant) + { + return Ok(Some(SumWithConstant { + base_expr: (**left).clone(), + constant: constant.clone(), + operator: *op, + original_index: idx, + })); + } + + // Also check left side for commutative addition (e.g., SUM(1 + a)) + // Does NOT apply to subtraction: SUM(5 - a) ≠ SUM(a - 5) + if let Expr::Literal(constant, _) = left.as_ref() + && is_numeric_constant(constant) + && *op == Operator::Plus + { + return Ok(Some(SumWithConstant { + base_expr: (**right).clone(), + constant: constant.clone(), + operator: Operator::Plus, + original_index: idx, + })); + } + } + + Ok(None) + } + _ => Ok(None), + } +} + +/// Check if a scalar value is a numeric constant +/// (guards against non-arithmetic types like strings, booleans, dates, etc.) +fn is_numeric_constant(value: &ScalarValue) -> bool { Review Comment: How should `NULL` values be handled by this function? I suppose we want to return false for them? ########## datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt: ########## @@ -0,0 +1,692 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Aggregate Rewrite With Constant Optimizer Tests +## Tests for the optimizer rule that rewrites SUM(col ± constant) to SUM(col) ± constant * COUNT(*) +## Rule only applies when there are 2+ SUM expressions on the SAME base column +########## + +# ==== Test 1: Basic addition with multiple sum expressions ==== + +statement ok +CREATE TABLE test_table ( + a INT, + b INT, + c INT +) AS VALUES + (1, 10, 100), + (2, 20, 200), + (3, 30, 300), + (4, 40, 400), + (5, 50, 500); + +# Test: Multiple SUM expressions with constants should be rewritten +query TT +EXPLAIN SELECT + SUM(a) as sum_a, + SUM(a + 1) as sum_a_plus_1, + SUM(a + 2) as sum_a_plus_2, + SUM(a + 3) as sum_a_plus_3 +FROM test_table; +---- +logical_plan +01)Projection: sum(test_table.a) AS sum_a, sum(test_table.a) + count(test_table.a) AS sum_a_plus_1, sum(test_table.a) + Int64(2) * count(test_table.a) AS sum_a_plus_2, sum(test_table.a) + Int64(3) * count(test_table.a) AS sum_a_plus_3 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), count(__common_expr_1 AS test_table.a)]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1 +04)------TableScan: test_table projection=[a] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 as sum_a, sum(test_table.a)@0 + count(test_table.a)@1 as sum_a_plus_1, sum(test_table.a)@0 + 2 * count(test_table.a)@1 as sum_a_plus_2, sum(test_table.a)@0 + 3 * count(test_table.a)@1 as sum_a_plus_3] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a), count(test_table.a)] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IIII +SELECT + SUM(a) as sum_a, + SUM(a + 1) as sum_a_plus_1, + SUM(a + 2) as sum_a_plus_2, + SUM(a + 3) as sum_a_plus_3 +FROM test_table; +---- +15 20 25 30 + +# ==== Test 2: Subtraction operations ==== + +query TT +EXPLAIN SELECT + SUM(a) as sum_a, + SUM(a - 1) as sum_a_minus_1, + SUM(a - 2) as sum_a_minus_2, + SUM(a - 3) as sum_a_minus_3 +FROM test_table; +---- +logical_plan +01)Projection: sum(test_table.a) AS sum_a, sum(test_table.a) - count(test_table.a) AS sum_a_minus_1, sum(test_table.a) - Int64(2) * count(test_table.a) AS sum_a_minus_2, sum(test_table.a) - Int64(3) * count(test_table.a) AS sum_a_minus_3 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), count(__common_expr_1 AS test_table.a)]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1 +04)------TableScan: test_table projection=[a] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 as sum_a, sum(test_table.a)@0 - count(test_table.a)@1 as sum_a_minus_1, sum(test_table.a)@0 - 2 * count(test_table.a)@1 as sum_a_minus_2, sum(test_table.a)@0 - 3 * count(test_table.a)@1 as sum_a_minus_3] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a), count(test_table.a)] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IIII +SELECT + SUM(a) as sum_a, + SUM(a - 1) as sum_a_minus_1, + SUM(a - 2) as sum_a_minus_2, + SUM(a - 3) as sum_a_minus_3 +FROM test_table; +---- +15 10 5 0 + +# ==== Test 3: With GROUP BY ==== + +statement ok +CREATE TABLE group_test ( + category VARCHAR, + value INT +) AS VALUES + ('A', 1), + ('A', 2), + ('A', 3), + ('B', 4), + ('B', 5), + ('B', 6); + +query TT +EXPLAIN SELECT + category, + SUM(value) as sum_val, + SUM(value + 1) as sum_val_plus_1, + SUM(value - 2) as sum_val_minus_2 +FROM group_test +GROUP BY category; +---- +logical_plan +01)Projection: group_test.category, sum(group_test.value) AS sum_val, sum(group_test.value) + count(group_test.value) AS sum_val_plus_1, sum(group_test.value) - Int64(2) * count(group_test.value) AS sum_val_minus_2 +02)--Aggregate: groupBy=[[group_test.category]], aggr=[[sum(__common_expr_1 AS group_test.value), count(__common_expr_1 AS group_test.value)]] +03)----Projection: CAST(group_test.value AS Int64) AS __common_expr_1, group_test.category +04)------TableScan: group_test projection=[category, value] +physical_plan +01)ProjectionExec: expr=[category@0 as category, sum(group_test.value)@1 as sum_val, sum(group_test.value)@1 + count(group_test.value)@2 as sum_val_plus_1, sum(group_test.value)@1 - 2 * count(group_test.value)@2 as sum_val_minus_2] +02)--AggregateExec: mode=FinalPartitioned, gby=[category@0 as category], aggr=[sum(group_test.value), count(group_test.value)] +03)----RepartitionExec: partitioning=Hash([category@0], 4), input_partitions=1 +04)------AggregateExec: mode=Partial, gby=[category@1 as category], aggr=[sum(group_test.value), count(group_test.value)] +05)--------ProjectionExec: expr=[CAST(value@1 AS Int64) as __common_expr_1, category@0 as category] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] + +query TIII rowsort +SELECT + category, + SUM(value) as sum_val, + SUM(value + 1) as sum_val_plus_1, + SUM(value - 2) as sum_val_minus_2 +FROM group_test +GROUP BY category; +---- +A 6 9 0 +B 15 18 9 + +# ==== Test 4: With nullable columns (SHOULD NOT rewrite - only 1 SUM per column) ==== + +statement ok +CREATE TABLE nullable_test ( + id INT, + a INT, + b INT +) AS VALUES + (1, 10, NULL), + (2, 20, 200), + (3, NULL, 300), + (4, 40, 400), + (5, 50, NULL); + +# This should NOT be rewritten because each column has only 1 SUM with a constant +query TT +EXPLAIN SELECT + SUM(a) as sum_a, + SUM(a + 5) as sum_a_plus_5, + SUM(b) as sum_b, + SUM(b - 10) as sum_b_minus_10 +FROM nullable_test; +---- +logical_plan +01)Projection: sum(nullable_test.a) AS sum_a, sum(nullable_test.a + Int64(5)) AS sum_a_plus_5, sum(nullable_test.b) AS sum_b, sum(nullable_test.b - Int64(10)) AS sum_b_minus_10 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS nullable_test.a), sum(__common_expr_1 AS nullable_test.a + Int64(5)), sum(__common_expr_2 AS nullable_test.b), sum(__common_expr_2 AS nullable_test.b - Int64(10))]] +03)----Projection: CAST(nullable_test.a AS Int64) AS __common_expr_1, CAST(nullable_test.b AS Int64) AS __common_expr_2 +04)------TableScan: nullable_test projection=[a, b] +physical_plan +01)ProjectionExec: expr=[sum(nullable_test.a)@0 as sum_a, sum(nullable_test.a + Int64(5))@1 as sum_a_plus_5, sum(nullable_test.b)@2 as sum_b, sum(nullable_test.b - Int64(10))@3 as sum_b_minus_10] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(nullable_test.a), sum(nullable_test.a + Int64(5)), sum(nullable_test.b), sum(nullable_test.b - Int64(10))] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1, CAST(b@1 AS Int64) as __common_expr_2] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IIII +SELECT + SUM(a) as sum_a, + SUM(a + 5) as sum_a_plus_5, + SUM(b) as sum_b, + SUM(b - 10) as sum_b_minus_10 +FROM nullable_test; +---- +120 140 900 870 + +# Test with multiple SUMs on nullable column +query TT +EXPLAIN SELECT + SUM(a) as sum_a, + SUM(a + 5) as sum_a_plus_5, + SUM(a + 10) as sum_a_plus_10 +FROM nullable_test; +---- +logical_plan +01)Projection: sum(nullable_test.a) AS sum_a, sum(nullable_test.a) + Int64(5) * count(nullable_test.a) AS sum_a_plus_5, sum(nullable_test.a) + Int64(10) * count(nullable_test.a) AS sum_a_plus_10 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS nullable_test.a), count(__common_expr_1 AS nullable_test.a)]] +03)----Projection: CAST(nullable_test.a AS Int64) AS __common_expr_1 +04)------TableScan: nullable_test projection=[a] +physical_plan +01)ProjectionExec: expr=[sum(nullable_test.a)@0 as sum_a, sum(nullable_test.a)@0 + 5 * count(nullable_test.a)@1 as sum_a_plus_5, sum(nullable_test.a)@0 + 10 * count(nullable_test.a)@1 as sum_a_plus_10] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(nullable_test.a), count(nullable_test.a)] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query III +SELECT + SUM(a) as sum_a, + SUM(a + 5) as sum_a_plus_5, + SUM(a + 10) as sum_a_plus_10 +FROM nullable_test; +---- +120 140 160 + +# ==== Test 5: Negative constants ==== + +query TT +EXPLAIN SELECT + SUM(a) as sum_a, + SUM(a + (-1)) as sum_a_minus_1, + SUM(a - (-2)) as sum_a_plus_2, + SUM(a + (-3)) as sum_a_minus_3 +FROM test_table; +---- +logical_plan +01)Projection: sum(test_table.a) AS sum_a, sum(test_table.a) + Int64(-1) * count(test_table.a) AS sum_a_minus_1, sum(test_table.a) - Int64(-2) * count(test_table.a) AS sum_a_plus_2, sum(test_table.a) + Int64(-3) * count(test_table.a) AS sum_a_minus_3 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), count(__common_expr_1 AS test_table.a)]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1 +04)------TableScan: test_table projection=[a] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 as sum_a, sum(test_table.a)@0 + -1 * count(test_table.a)@1 as sum_a_minus_1, sum(test_table.a)@0 - -2 * count(test_table.a)@1 as sum_a_plus_2, sum(test_table.a)@0 + -3 * count(test_table.a)@1 as sum_a_minus_3] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a), count(test_table.a)] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IIII +SELECT + SUM(a) as sum_a, + SUM(a + (-1)) as sum_a_minus_1, + SUM(a - (-2)) as sum_a_plus_2, + SUM(a + (-3)) as sum_a_minus_3 +FROM test_table; +---- +15 10 25 0 + +# ==== Test 6: No matching rewrite patterns ==== + +# Should not rewrite - only one sum with constant +query TT +EXPLAIN SELECT SUM(a + 1) FROM test_table; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[sum(CAST(test_table.a AS Int64) + Int64(1))]] +02)--TableScan: test_table projection=[a] +physical_plan +01)AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a + Int64(1))] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +# Should not rewrite - different base columns +query TT +EXPLAIN SELECT SUM(a + 1), SUM(b + 2) FROM test_table; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[sum(CAST(test_table.a AS Int64) + Int64(1)), sum(CAST(test_table.b AS Int64) + Int64(2))]] +02)--TableScan: test_table projection=[a, b] +physical_plan +01)AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a + Int64(1)), sum(test_table.b + Int64(2))] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +# ==== Test 7: Mixed sum types (rewrites a and b, not c) ==== + +query TT +EXPLAIN SELECT + SUM(a) as sum_a, + SUM(a + 1) as sum_a_plus_1, + SUM(b) as sum_b, + SUM(b + 2) as sum_b_plus_2, + SUM(c + 3) as sum_c_plus_3 +FROM test_table; +---- +logical_plan +01)Projection: sum(test_table.a) AS sum_a, sum(test_table.a + Int64(1)) AS sum_a_plus_1, sum(test_table.b) AS sum_b, sum(test_table.b + Int64(2)) AS sum_b_plus_2, sum(test_table.c + Int64(3)) AS sum_c_plus_3 Review Comment: It looks to me like the rewrite is not actually being applied here? (Contra the comment above: "Test 7: Mixed sum types (rewrites a and b, not c)"). -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
