This is an automated email from the ASF dual-hosted git repository.

comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new a165b7f579 Avoid copies in `CountWildcardRule` via TreeNode API (#10066)
a165b7f579 is described below

commit a165b7f57946c7c4e40259e982a2a0aad3ee456c
Author: Andrew Lamb <[email protected]>
AuthorDate: Mon Apr 15 10:57:19 2024 -0400

    Avoid copies in `CountWildcardRule` via TreeNode API (#10066)
    
    * Avoid copies in `CountWildcardRule` via TreeNode API
---
 .../optimizer/src/analyzer/count_wildcard_rule.rs  | 241 ++++++---------------
 .../optimizer/src/analyzer/function_rewrite.rs     |   4 +-
 2 files changed, 66 insertions(+), 179 deletions(-)

diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
index 273766edac..080ec074d3 100644
--- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
+++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
@@ -15,23 +15,17 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::sync::Arc;
-
 use crate::analyzer::AnalyzerRule;
 
+use crate::utils::NamePreserver;
 use datafusion_common::config::ConfigOptions;
-use datafusion_common::tree_node::{
-    Transformed, TransformedResult, TreeNode, TreeNodeRewriter,
-};
+use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_common::Result;
-use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery};
-use datafusion_expr::expr_rewriter::rewrite_preserving_name;
-use datafusion_expr::utils::COUNT_STAR_EXPANSION;
-use datafusion_expr::Expr::ScalarSubquery;
-use datafusion_expr::{
-    aggregate_function, expr, lit, Aggregate, Expr, Filter, LogicalPlan,
-    LogicalPlanBuilder, Projection, Sort, Subquery,
+use datafusion_expr::expr::{
+    AggregateFunction, AggregateFunctionDefinition, WindowFunction,
 };
+use datafusion_expr::utils::COUNT_STAR_EXPANSION;
+use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition};
 
 /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
 ///
@@ -47,7 +41,8 @@ impl CountWildcardRule {
 
 impl AnalyzerRule for CountWildcardRule {
     fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
-        plan.transform_down(&analyze_internal).data()
+        plan.transform_down_with_subqueries(&analyze_internal)
+            .data()
     }
 
     fn name(&self) -> &str {
@@ -55,173 +50,53 @@ impl AnalyzerRule for CountWildcardRule {
     }
 }
 
-fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
-    let mut rewriter = CountWildcardRewriter {};
-    match plan {
-        LogicalPlan::Window(window) => {
-            let window_expr = window
-                .window_expr
-                .iter()
-                .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
-                .collect::<Result<Vec<_>>>()?;
-
-            Ok(Transformed::yes(
-                LogicalPlanBuilder::from((*window.input).clone())
-                    .window(window_expr)?
-                    .build()?,
-            ))
-        }
-        LogicalPlan::Aggregate(agg) => {
-            let aggr_expr = agg
-                .aggr_expr
-                .iter()
-                .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
-                .collect::<Result<Vec<_>>>()?;
-
-            Ok(Transformed::yes(LogicalPlan::Aggregate(
-                Aggregate::try_new(agg.input.clone(), agg.group_expr, aggr_expr)?,
-            )))
-        }
-        LogicalPlan::Sort(Sort { expr, input, fetch }) => {
-            let sort_expr = expr
-                .iter()
-                .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
-                .collect::<Result<Vec<_>>>()?;
-            Ok(Transformed::yes(LogicalPlan::Sort(Sort {
-                expr: sort_expr,
-                input,
-                fetch,
-            })))
-        }
-        LogicalPlan::Projection(projection) => {
-            let projection_expr = projection
-                .expr
-                .iter()
-                .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter))
-                .collect::<Result<Vec<_>>>()?;
-            Ok(Transformed::yes(LogicalPlan::Projection(
-                Projection::try_new(projection_expr, projection.input)?,
-            )))
-        }
-        LogicalPlan::Filter(Filter {
-            predicate, input, ..
-        }) => {
-            let predicate = rewrite_preserving_name(predicate, &mut rewriter)?;
-            Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new(
-                predicate, input,
-            )?)))
-        }
-
-        _ => Ok(Transformed::no(plan)),
-    }
+fn is_wildcard(expr: &Expr) -> bool {
+    matches!(expr, Expr::Wildcard { qualifier: None })
 }
 
-struct CountWildcardRewriter {}
-
-impl TreeNodeRewriter for CountWildcardRewriter {
-    type Node = Expr;
-
-    fn f_up(&mut self, old_expr: Expr) -> Result<Transformed<Expr>> {
-        Ok(match old_expr.clone() {
-            Expr::WindowFunction(expr::WindowFunction {
-                fun:
-                    expr::WindowFunctionDefinition::AggregateFunction(
-                        aggregate_function::AggregateFunction::Count,
-                    ),
-                args,
-                partition_by,
-                order_by,
-                window_frame,
-                null_treatment,
-            }) if args.len() == 1 => match args[0] {
-                Expr::Wildcard { qualifier: None } => {
-                    Transformed::yes(Expr::WindowFunction(expr::WindowFunction {
-                        fun: expr::WindowFunctionDefinition::AggregateFunction(
-                            aggregate_function::AggregateFunction::Count,
-                        ),
-                        args: vec![lit(COUNT_STAR_EXPANSION)],
-                        partition_by,
-                        order_by,
-                        window_frame,
-                        null_treatment,
-                    }))
-                }
+fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
+    matches!(
+        &aggregate_function.func_def,
+        AggregateFunctionDefinition::BuiltIn(
+            datafusion_expr::aggregate_function::AggregateFunction::Count,
+        )
+    ) && aggregate_function.args.len() == 1
+        && is_wildcard(&aggregate_function.args[0])
+}
 
-                _ => Transformed::no(old_expr),
-            },
-            Expr::AggregateFunction(AggregateFunction {
-                func_def:
-                    AggregateFunctionDefinition::BuiltIn(
-                        aggregate_function::AggregateFunction::Count,
-                    ),
-                args,
-                distinct,
-                filter,
-                order_by,
-                null_treatment,
-            }) if args.len() == 1 => match args[0] {
-                Expr::Wildcard { qualifier: None } => {
-                    Transformed::yes(Expr::AggregateFunction(AggregateFunction::new(
-                        aggregate_function::AggregateFunction::Count,
-                        vec![lit(COUNT_STAR_EXPANSION)],
-                        distinct,
-                        filter,
-                        order_by,
-                        null_treatment,
-                    )))
-                }
-                _ => Transformed::no(old_expr),
-            },
+fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
+    matches!(
+        &window_function.fun,
+        WindowFunctionDefinition::AggregateFunction(
+            datafusion_expr::aggregate_function::AggregateFunction::Count,
+        )
+    ) && window_function.args.len() == 1
+        && is_wildcard(&window_function.args[0])
+}
 
-            ScalarSubquery(Subquery {
-                subquery,
-                outer_ref_columns,
-            }) => subquery
-                .as_ref()
-                .clone()
-                .transform_down(&analyze_internal)?
-                .update_data(|new_plan| {
-                    ScalarSubquery(Subquery {
-                        subquery: Arc::new(new_plan),
-                        outer_ref_columns,
-                    })
-                }),
-            Expr::InSubquery(InSubquery {
-                expr,
-                subquery,
-                negated,
-            }) => subquery
-                .subquery
-                .as_ref()
-                .clone()
-                .transform_down(&analyze_internal)?
-                .update_data(|new_plan| {
-                    Expr::InSubquery(InSubquery::new(
-                        expr,
-                        Subquery {
-                            subquery: Arc::new(new_plan),
-                            outer_ref_columns: subquery.outer_ref_columns,
-                        },
-                        negated,
-                    ))
-                }),
-            Expr::Exists(expr::Exists { subquery, negated }) => subquery
-                .subquery
-                .as_ref()
-                .clone()
-                .transform_down(&analyze_internal)?
-                .update_data(|new_plan| {
-                    Expr::Exists(expr::Exists {
-                        subquery: Subquery {
-                            subquery: Arc::new(new_plan),
-                            outer_ref_columns: subquery.outer_ref_columns,
-                        },
-                        negated,
-                    })
-                }),
-            _ => Transformed::no(old_expr),
-        })
-    }
+fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
+    let name_preserver = NamePreserver::new(&plan);
+    plan.map_expressions(|expr| {
+        let original_name = name_preserver.save(&expr)?;
+        let transformed_expr = expr.transform_up(&|expr| match expr {
+            Expr::WindowFunction(mut window_function)
+                if is_count_star_window_aggregate(&window_function) =>
+            {
+                window_function.args = vec![lit(COUNT_STAR_EXPANSION)];
+                Ok(Transformed::yes(Expr::WindowFunction(window_function)))
+            }
+            Expr::AggregateFunction(mut aggregate_function)
+                if is_count_star_aggregate(&aggregate_function) =>
+            {
+                aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)];
+                Ok(Transformed::yes(Expr::AggregateFunction(
+                    aggregate_function,
+                )))
+            }
+            _ => Ok(Transformed::no(expr)),
+        })?;
+        transformed_expr.map_data(|data| original_name.restore(data))
+    })
 }
 
 #[cfg(test)]
@@ -233,9 +108,10 @@ mod tests {
     use datafusion_expr::expr::Sort;
     use datafusion_expr::{
         col, count, exists, expr, in_subquery, lit, logical_plan::LogicalPlanBuilder,
-        max, out_ref_col, scalar_subquery, wildcard, AggregateFunction, Expr,
+        max, out_ref_col, scalar_subquery, sum, wildcard, AggregateFunction, Expr,
         WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
     };
+    use std::sync::Arc;
 
     fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
         assert_analyzed_plan_eq_display_indent(
@@ -381,6 +257,17 @@ mod tests {
         assert_plan_eq(&plan, expected)
     }
 
+    #[test]
+    fn test_count_wildcard_on_non_count_aggregate() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let err = LogicalPlanBuilder::from(table_scan)
+            .aggregate(Vec::<Expr>::new(), vec![sum(wildcard())])
+            .unwrap_err()
+            .to_string();
+        assert!(err.contains("Error during planning: No function matches the given name and argument types 'SUM(Null)'."), "{err}");
+        Ok(())
+    }
+
     #[test]
     fn test_count_wildcard_on_nesting() -> Result<()> {
         let table_scan = test_table_scan()?;
diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/optimizer/src/analyzer/function_rewrite.rs
index deb493e099..4dd3222a32 100644
--- a/datafusion/optimizer/src/analyzer/function_rewrite.rs
+++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs
@@ -64,7 +64,7 @@ impl ApplyFunctionRewrites {
             let original_name = name_preserver.save(&expr)?;
 
             // recursively transform the expression, applying the rewrites at each step
-            let result = expr.transform_up(&|expr| {
+            let transformed_expr = expr.transform_up(&|expr| {
                 let mut result = Transformed::no(expr);
                 for rewriter in self.function_rewrites.iter() {
                     result = result.transform_data(|expr| {
@@ -74,7 +74,7 @@ impl ApplyFunctionRewrites {
                 Ok(result)
             })?;
 
-            result.map_data(|expr| original_name.restore(expr))
+            transformed_expr.map_data(|expr| original_name.restore(expr))
         })
     }
 }

Reply via email to