This is an automated email from the ASF dual-hosted git repository.
comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new a165b7f579 Avoid copies in `CountWildcardRule` via TreeNode API (#10066)
a165b7f579 is described below
commit a165b7f57946c7c4e40259e982a2a0aad3ee456c
Author: Andrew Lamb <[email protected]>
AuthorDate: Mon Apr 15 10:57:19 2024 -0400
Avoid copies in `CountWildcardRule` via TreeNode API (#10066)
* Avoid copies in `CountWildcardRule` via TreeNode API
---
.../optimizer/src/analyzer/count_wildcard_rule.rs | 241 ++++++---------------
.../optimizer/src/analyzer/function_rewrite.rs | 4 +-
2 files changed, 66 insertions(+), 179 deletions(-)
diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
index 273766edac..080ec074d3 100644
--- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
+++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
@@ -15,23 +15,17 @@
// specific language governing permissions and limitations
// under the License.
-use std::sync::Arc;
-
use crate::analyzer::AnalyzerRule;
+use crate::utils::NamePreserver;
use datafusion_common::config::ConfigOptions;
-use datafusion_common::tree_node::{
- Transformed, TransformedResult, TreeNode, TreeNodeRewriter,
-};
+use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
-use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery};
-use datafusion_expr::expr_rewriter::rewrite_preserving_name;
-use datafusion_expr::utils::COUNT_STAR_EXPANSION;
-use datafusion_expr::Expr::ScalarSubquery;
-use datafusion_expr::{
- aggregate_function, expr, lit, Aggregate, Expr, Filter, LogicalPlan,
- LogicalPlanBuilder, Projection, Sort, Subquery,
+use datafusion_expr::expr::{
+ AggregateFunction, AggregateFunctionDefinition, WindowFunction,
};
+use datafusion_expr::utils::COUNT_STAR_EXPANSION;
+use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition};
/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
///
@@ -47,7 +41,8 @@ impl CountWildcardRule {
impl AnalyzerRule for CountWildcardRule {
fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
- plan.transform_down(&analyze_internal).data()
+ plan.transform_down_with_subqueries(&analyze_internal)
+ .data()
}
fn name(&self) -> &str {
@@ -55,173 +50,53 @@ impl AnalyzerRule for CountWildcardRule {
}
}
-fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
- let mut rewriter = CountWildcardRewriter {};
- match plan {
- LogicalPlan::Window(window) => {
- let window_expr = window
- .window_expr
- .iter()
- .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
- .collect::<Result<Vec<_>>>()?;
-
- Ok(Transformed::yes(
- LogicalPlanBuilder::from((*window.input).clone())
- .window(window_expr)?
- .build()?,
- ))
- }
- LogicalPlan::Aggregate(agg) => {
- let aggr_expr = agg
- .aggr_expr
- .iter()
- .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
- .collect::<Result<Vec<_>>>()?;
-
- Ok(Transformed::yes(LogicalPlan::Aggregate(
- Aggregate::try_new(agg.input.clone(), agg.group_expr, aggr_expr)?,
- )))
- }
- LogicalPlan::Sort(Sort { expr, input, fetch }) => {
- let sort_expr = expr
- .iter()
- .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
- .collect::<Result<Vec<_>>>()?;
- Ok(Transformed::yes(LogicalPlan::Sort(Sort {
- expr: sort_expr,
- input,
- fetch,
- })))
- }
- LogicalPlan::Projection(projection) => {
- let projection_expr = projection
- .expr
- .iter()
- .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
- .collect::<Result<Vec<_>>>()?;
- Ok(Transformed::yes(LogicalPlan::Projection(
- Projection::try_new(projection_expr, projection.input)?,
- )))
- }
- LogicalPlan::Filter(Filter {
- predicate, input, ..
- }) => {
- let predicate = rewrite_preserving_name(predicate, &mut rewriter)?;
- Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new(
- predicate, input,
- )?)))
- }
-
- _ => Ok(Transformed::no(plan)),
- }
+fn is_wildcard(expr: &Expr) -> bool {
+ matches!(expr, Expr::Wildcard { qualifier: None })
}
-struct CountWildcardRewriter {}
-
-impl TreeNodeRewriter for CountWildcardRewriter {
- type Node = Expr;
-
- fn f_up(&mut self, old_expr: Expr) -> Result<Transformed<Expr>> {
- Ok(match old_expr.clone() {
- Expr::WindowFunction(expr::WindowFunction {
- fun:
- expr::WindowFunctionDefinition::AggregateFunction(
- aggregate_function::AggregateFunction::Count,
- ),
- args,
- partition_by,
- order_by,
- window_frame,
- null_treatment,
- }) if args.len() == 1 => match args[0] {
- Expr::Wildcard { qualifier: None } => {
- Transformed::yes(Expr::WindowFunction(expr::WindowFunction {
- fun: expr::WindowFunctionDefinition::AggregateFunction(
- aggregate_function::AggregateFunction::Count,
- ),
- args: vec![lit(COUNT_STAR_EXPANSION)],
- partition_by,
- order_by,
- window_frame,
- null_treatment,
- }))
- }
+fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
+ matches!(
+ &aggregate_function.func_def,
+ AggregateFunctionDefinition::BuiltIn(
+ datafusion_expr::aggregate_function::AggregateFunction::Count,
+ )
+ ) && aggregate_function.args.len() == 1
+ && is_wildcard(&aggregate_function.args[0])
+}
- _ => Transformed::no(old_expr),
- },
- Expr::AggregateFunction(AggregateFunction {
- func_def:
- AggregateFunctionDefinition::BuiltIn(
- aggregate_function::AggregateFunction::Count,
- ),
- args,
- distinct,
- filter,
- order_by,
- null_treatment,
- }) if args.len() == 1 => match args[0] {
- Expr::Wildcard { qualifier: None } => {
- Transformed::yes(Expr::AggregateFunction(AggregateFunction::new(
- aggregate_function::AggregateFunction::Count,
- vec![lit(COUNT_STAR_EXPANSION)],
- distinct,
- filter,
- order_by,
- null_treatment,
- )))
- }
- _ => Transformed::no(old_expr),
- },
+fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
+ matches!(
+ &window_function.fun,
+ WindowFunctionDefinition::AggregateFunction(
+ datafusion_expr::aggregate_function::AggregateFunction::Count,
+ )
+ ) && window_function.args.len() == 1
+ && is_wildcard(&window_function.args[0])
+}
- ScalarSubquery(Subquery {
- subquery,
- outer_ref_columns,
- }) => subquery
- .as_ref()
- .clone()
- .transform_down(&analyze_internal)?
- .update_data(|new_plan| {
- ScalarSubquery(Subquery {
- subquery: Arc::new(new_plan),
- outer_ref_columns,
- })
- }),
- Expr::InSubquery(InSubquery {
- expr,
- subquery,
- negated,
- }) => subquery
- .subquery
- .as_ref()
- .clone()
- .transform_down(&analyze_internal)?
- .update_data(|new_plan| {
- Expr::InSubquery(InSubquery::new(
- expr,
- Subquery {
- subquery: Arc::new(new_plan),
- outer_ref_columns: subquery.outer_ref_columns,
- },
- negated,
- ))
- }),
- Expr::Exists(expr::Exists { subquery, negated }) => subquery
- .subquery
- .as_ref()
- .clone()
- .transform_down(&analyze_internal)?
- .update_data(|new_plan| {
- Expr::Exists(expr::Exists {
- subquery: Subquery {
- subquery: Arc::new(new_plan),
- outer_ref_columns: subquery.outer_ref_columns,
- },
- negated,
- })
- }),
- _ => Transformed::no(old_expr),
- })
- }
+fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
+ let name_preserver = NamePreserver::new(&plan);
+ plan.map_expressions(|expr| {
+ let original_name = name_preserver.save(&expr)?;
+ let transformed_expr = expr.transform_up(&|expr| match expr {
+ Expr::WindowFunction(mut window_function)
+ if is_count_star_window_aggregate(&window_function) =>
+ {
+ window_function.args = vec![lit(COUNT_STAR_EXPANSION)];
+ Ok(Transformed::yes(Expr::WindowFunction(window_function)))
+ }
+ Expr::AggregateFunction(mut aggregate_function)
+ if is_count_star_aggregate(&aggregate_function) =>
+ {
+ aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)];
+ Ok(Transformed::yes(Expr::AggregateFunction(
+ aggregate_function,
+ )))
+ }
+ _ => Ok(Transformed::no(expr)),
+ })?;
+ transformed_expr.map_data(|data| original_name.restore(data))
+ })
}
#[cfg(test)]
@@ -233,9 +108,10 @@ mod tests {
use datafusion_expr::expr::Sort;
use datafusion_expr::{
col, count, exists, expr, in_subquery, lit, logical_plan::LogicalPlanBuilder,
- max, out_ref_col, scalar_subquery, wildcard, AggregateFunction, Expr,
+ max, out_ref_col, scalar_subquery, sum, wildcard, AggregateFunction, Expr,
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
+ use std::sync::Arc;
fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
assert_analyzed_plan_eq_display_indent(
@@ -381,6 +257,17 @@ mod tests {
assert_plan_eq(&plan, expected)
}
+ #[test]
+ fn test_count_wildcard_on_non_count_aggregate() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let err = LogicalPlanBuilder::from(table_scan)
+ .aggregate(Vec::<Expr>::new(), vec![sum(wildcard())])
+ .unwrap_err()
+ .to_string();
+ assert!(err.contains("Error during planning: No function matches the given name and argument types 'SUM(Null)'."), "{err}");
+ Ok(())
+ }
+
#[test]
fn test_count_wildcard_on_nesting() -> Result<()> {
let table_scan = test_table_scan()?;
diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/optimizer/src/analyzer/function_rewrite.rs
index deb493e099..4dd3222a32 100644
--- a/datafusion/optimizer/src/analyzer/function_rewrite.rs
+++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs
@@ -64,7 +64,7 @@ impl ApplyFunctionRewrites {
let original_name = name_preserver.save(&expr)?;
// recursively transform the expression, applying the rewrites at each step
- let result = expr.transform_up(&|expr| {
+ let transformed_expr = expr.transform_up(&|expr| {
let mut result = Transformed::no(expr);
for rewriter in self.function_rewrites.iter() {
result = result.transform_data(|expr| {
@@ -74,7 +74,7 @@ impl ApplyFunctionRewrites {
Ok(result)
})?;
- result.map_data(|expr| original_name.restore(expr))
+ transformed_expr.map_data(|expr| original_name.restore(expr))
})
}
}