This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new c6b2efccf6 Stop copying LogicalPlan and Exprs in 
`CommonSubexprEliminate` (2-3% planning speed improvement) (#10835)
c6b2efccf6 is described below

commit c6b2efccf6238cc87f2414efb28ae3b263ed27af
Author: Andrew Lamb <[email protected]>
AuthorDate: Wed Jun 19 10:13:27 2024 -0400

    Stop copying LogicalPlan and Exprs in `CommonSubexprEliminate` (2-3% 
planning speed improvement) (#10835)
    
    * Stop copying LogicalPlan and Exprs in `CommonSubexprEliminate`
    
    * thread transformed
    
    * Update unary to report transformed correctly
    
    * Preserve through window transforms
    
    * track aggregate
    
    * Avoid re-computing Aggregate schema
    
    * Update datafusion/optimizer/src/common_subexpr_eliminate.rs
    
    * Avoid unnecessarily setting the transformed flag
    
    * Cleanup unaliasing
---
 datafusion/common/src/tree_node.rs                 |   5 +
 datafusion/expr/src/logical_plan/plan.rs           |  64 +--
 .../optimizer/src/common_subexpr_eliminate.rs      | 594 ++++++++++++++-------
 3 files changed, 439 insertions(+), 224 deletions(-)

diff --git a/datafusion/common/src/tree_node.rs 
b/datafusion/common/src/tree_node.rs
index d0dd24621d..276a1cc4c5 100644
--- a/datafusion/common/src/tree_node.rs
+++ b/datafusion/common/src/tree_node.rs
@@ -615,6 +615,11 @@ impl<T> Transformed<T> {
         }
     }
 
+    /// Create a `Transformed` with the given `transformed` flag and 
[`TreeNodeRecursion::Continue`].
+    pub fn new_transformed(data: T, transformed: bool) -> Self {
+        Self::new(data, transformed, TreeNodeRecursion::Continue)
+    }
+
     /// Wrapper for transformed data with [`TreeNodeRecursion::Continue`] 
statement.
     pub fn yes(data: T) -> Self {
         Self::new(data, true, TreeNodeRecursion::Continue)
diff --git a/datafusion/expr/src/logical_plan/plan.rs 
b/datafusion/expr/src/logical_plan/plan.rs
index 02378ab3fc..85958223ac 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -870,37 +870,7 @@ impl LogicalPlan {
             LogicalPlan::Filter { .. } => {
                 assert_eq!(1, expr.len());
                 let predicate = expr.pop().unwrap();
-
-                // filter predicates should not contain aliased expressions so 
we remove any aliases
-                // before this logic was added we would have aliases within 
filters such as for
-                // benchmark q6:
-                //
-                // lineitem.l_shipdate >= Date32(\"8766\")
-                // AND lineitem.l_shipdate < Date32(\"9131\")
-                // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS 
lineitem.l_discount >=
-                // Decimal128(Some(49999999999999),30,15)
-                // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS 
lineitem.l_discount <=
-                // Decimal128(Some(69999999999999),30,15)
-                // AND lineitem.l_quantity < Decimal128(Some(2400),15,2)
-
-                let predicate = predicate
-                    .transform_down(|expr| {
-                        match expr {
-                            Expr::Exists { .. }
-                            | Expr::ScalarSubquery(_)
-                            | Expr::InSubquery(_) => {
-                                // subqueries could contain aliases so we 
don't recurse into those
-                                Ok(Transformed::new(expr, false, 
TreeNodeRecursion::Jump))
-                            }
-                            Expr::Alias(_) => Ok(Transformed::new(
-                                expr.unalias(),
-                                true,
-                                TreeNodeRecursion::Jump,
-                            )),
-                            _ => Ok(Transformed::no(expr)),
-                        }
-                    })
-                    .data()?;
+                let predicate = Filter::remove_aliases(predicate)?.data;
 
                 Filter::try_new(predicate, Arc::new(inputs.swap_remove(0)))
                     .map(LogicalPlan::Filter)
@@ -2230,6 +2200,38 @@ impl Filter {
         }
         false
     }
+
+    /// Remove aliases from a predicate for use in a `Filter`
+    ///
+    /// Filter predicates should not contain aliased expressions, so we remove
+    /// any aliases.
+    ///
+    /// Before this logic was added, we would have aliases within filters such 
as
+    /// for benchmark q6:
+    ///
+    /// ```sql
+    /// lineitem.l_shipdate >= Date32("8766")
+    /// AND lineitem.l_shipdate < Date32("9131")
+    /// AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS 
lineitem.l_discount >=
+    /// Decimal128(Some(49999999999999),30,15)
+    /// AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS 
lineitem.l_discount <=
+    /// Decimal128(Some(69999999999999),30,15)
+    /// AND lineitem.l_quantity < Decimal128(Some(2400),15,2)
+    /// ```
+    pub fn remove_aliases(predicate: Expr) -> Result<Transformed<Expr>> {
+        predicate.transform_down(|expr| {
+            match expr {
+                Expr::Exists { .. } | Expr::ScalarSubquery(_) | 
Expr::InSubquery(_) => {
+                    // subqueries could contain aliases so we don't recurse 
into those
+                    Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump))
+                }
+                Expr::Alias(Alias { expr, .. }) => {
+                    Ok(Transformed::new(*expr, true, TreeNodeRecursion::Jump))
+                }
+                _ => Ok(Transformed::no(expr)),
+            }
+        })
+    }
 }
 
 /// Window its input based on a set of window spec and window function (e.g. 
SUM or RANK)
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs 
b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index e150a957bf..7f4093ba11 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -20,16 +20,22 @@
 use std::collections::{BTreeSet, HashMap};
 use std::sync::Arc;
 
-use crate::{utils, OptimizerConfig, OptimizerRule};
+use crate::{OptimizerConfig, OptimizerRule};
 
+use crate::optimizer::ApplyOrder;
+use crate::utils::NamePreserver;
 use datafusion_common::alias::AliasGenerator;
 use datafusion_common::tree_node::{
-    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, 
TreeNodeRewriter,
-    TreeNodeVisitor,
+    Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, 
TreeNodeVisitor,
+};
+use datafusion_common::{
+    internal_datafusion_err, internal_err, qualified_name, Column, DFSchema, 
Result,
 };
-use datafusion_common::{qualified_name, Column, DFSchema, DataFusionError, 
Result};
 use datafusion_expr::expr::Alias;
-use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, 
Window};
+use datafusion_expr::logical_plan::tree_node::unwrap_arc;
+use datafusion_expr::logical_plan::{
+    Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
+};
 use datafusion_expr::{col, Expr, ExprSchemable};
 use indexmap::IndexMap;
 
@@ -123,32 +129,39 @@ impl CommonSubexprEliminate {
     /// Returns the rewritten expressions
     fn rewrite_exprs_list(
         &self,
-        exprs_list: &[&[Expr]],
+        exprs_list: Vec<Vec<Expr>>,
         arrays_list: &[&[IdArray]],
         expr_stats: &ExprStats,
         common_exprs: &mut CommonExprs,
         alias_generator: &AliasGenerator,
-    ) -> Result<Vec<Vec<Expr>>> {
+    ) -> Result<Transformed<Vec<Vec<Expr>>>> {
+        let mut transformed = false;
         exprs_list
-            .iter()
+            .into_iter()
             .zip(arrays_list.iter())
             .map(|(exprs, arrays)| {
                 exprs
-                    .iter()
-                    .cloned()
+                    .into_iter()
                     .zip(arrays.iter())
                     .map(|(expr, id_array)| {
-                        replace_common_expr(
+                        let replaced = replace_common_expr(
                             expr,
                             id_array,
                             expr_stats,
                             common_exprs,
                             alias_generator,
-                        )
+                        )?;
+                        // remember if this expression was actually replaced
+                        transformed |= replaced.transformed;
+                        Ok(replaced.data)
                     })
                     .collect::<Result<Vec<_>>>()
             })
             .collect::<Result<Vec<_>>>()
+            .map(|rewritten_exprs_list| {
+                // propagate back transformed information
+                Transformed::new_transformed(rewritten_exprs_list, transformed)
+            })
     }
 
     /// Rewrites the expression in `exprs_list` with common sub-expressions
@@ -161,13 +174,15 @@ impl CommonSubexprEliminate {
     ///    common sub-expressions that were used
     fn rewrite_expr(
         &self,
-        exprs_list: &[&[Expr]],
+        exprs_list: Vec<Vec<Expr>>,
         arrays_list: &[&[IdArray]],
-        input: &LogicalPlan,
+        input: LogicalPlan,
         expr_stats: &ExprStats,
         config: &dyn OptimizerConfig,
-    ) -> Result<(Vec<Vec<Expr>>, LogicalPlan)> {
+    ) -> Result<Transformed<(Vec<Vec<Expr>>, LogicalPlan)>> {
+        let mut transformed = false;
         let mut common_exprs = CommonExprs::new();
+
         let rewrite_exprs = self.rewrite_exprs_list(
             exprs_list,
             arrays_list,
@@ -175,115 +190,193 @@ impl CommonSubexprEliminate {
             &mut common_exprs,
             &config.alias_generator(),
         )?;
+        transformed |= rewrite_exprs.transformed;
 
-        let mut new_input = self
-            .try_optimize(input, config)?
-            .unwrap_or_else(|| input.clone());
+        let new_input = self.rewrite(input, config)?;
+        transformed |= new_input.transformed;
+        let mut new_input = new_input.data;
 
         if !common_exprs.is_empty() {
+            assert!(transformed);
             new_input = build_common_expr_project_plan(new_input, 
common_exprs)?;
         }
 
-        Ok((rewrite_exprs, new_input))
+        // return the transformed information
+
+        Ok(Transformed::new_transformed(
+            (rewrite_exprs.data, new_input),
+            transformed,
+        ))
     }
 
-    fn try_optimize_window(
+    fn try_optimize_proj(
         &self,
-        window: &Window,
+        projection: Projection,
         config: &dyn OptimizerConfig,
-    ) -> Result<LogicalPlan> {
-        let mut window_exprs = vec![];
-        let mut arrays_per_window = vec![];
-        let mut expr_stats = ExprStats::new();
-
-        // Get all window expressions inside the consecutive window operators.
-        // Consecutive window expressions may refer to same complex expression.
-        // If same complex expression is referred more than once by subsequent 
`WindowAggr`s,
-        // we can cache complex expression by evaluating it with a projection 
before the
-        // first WindowAggr.
-        // This enables us to cache complex expression "c3+c4" for following 
plan:
-        // WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN 
UNBOUNDED PRECEDING AND CURRENT ROW]]
-        // --WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN 
UNBOUNDED PRECEDING AND CURRENT ROW]]
-        // where, it is referred once by each `WindowAggr` (total of 2) in the 
plan.
-        let mut plan = LogicalPlan::Window(window.clone());
-        while let LogicalPlan::Window(window) = plan {
-            let Window {
-                input, window_expr, ..
-            } = window;
-            plan = input.as_ref().clone();
+    ) -> Result<Transformed<LogicalPlan>> {
+        let Projection {
+            expr,
+            input,
+            schema,
+            ..
+        } = projection;
+        let input = unwrap_arc(input);
+        self.try_unary_plan(expr, input, config)?
+            .map_data(|(new_expr, new_input)| {
+                Projection::try_new_with_schema(new_expr, Arc::new(new_input), 
schema)
+                    .map(LogicalPlan::Projection)
+            })
+    }
+    fn try_optimize_sort(
+        &self,
+        sort: Sort,
+        config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        let Sort { expr, input, fetch } = sort;
+        let input = unwrap_arc(input);
+        let new_sort = self.try_unary_plan(expr, input, config)?.update_data(
+            |(new_expr, new_input)| {
+                LogicalPlan::Sort(Sort {
+                    expr: new_expr,
+                    input: Arc::new(new_input),
+                    fetch,
+                })
+            },
+        );
+        Ok(new_sort)
+    }
 
-            let arrays = to_arrays(&window_expr, &mut expr_stats, 
ExprMask::Normal)?;
+    fn try_optimize_filter(
+        &self,
+        filter: Filter,
+        config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        let Filter {
+            predicate, input, ..
+        } = filter;
+        let input = unwrap_arc(input);
+        let expr = vec![predicate];
+        self.try_unary_plan(expr, input, config)?
+            .transform_data(|(mut new_expr, new_input)| {
+                assert_eq!(new_expr.len(), 1); // passed in vec![predicate]
+                let new_predicate = new_expr.pop().unwrap();
+                Ok(Filter::remove_aliases(new_predicate)?
+                    .update_data(|new_predicate| (new_predicate, new_input)))
+            })?
+            .map_data(|(new_predicate, new_input)| {
+                Filter::try_new(new_predicate, Arc::new(new_input))
+                    .map(LogicalPlan::Filter)
+            })
+    }
 
-            window_exprs.push(window_expr);
-            arrays_per_window.push(arrays);
-        }
+    fn try_optimize_window(
+        &self,
+        window: Window,
+        config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        // collect all window expressions from any number of consecutive `LogicalPlan::Window`s
+        let ConsecutiveWindowExprs {
+            window_exprs,
+            arrays_per_window,
+            expr_stats,
+            plan,
+        } = ConsecutiveWindowExprs::try_new(window)?;
 
-        let mut window_exprs = window_exprs
-            .iter()
-            .map(|expr| expr.as_slice())
-            .collect::<Vec<_>>();
         let arrays_per_window = arrays_per_window
             .iter()
             .map(|arrays| arrays.as_slice())
             .collect::<Vec<_>>();
 
+        // save the original names
+        let name_preserver = NamePreserver::new(&plan);
+        let mut saved_names = window_exprs
+            .iter()
+            .map(|exprs| {
+                exprs
+                    .iter()
+                    .map(|expr| name_preserver.save(expr))
+                    .collect::<Result<Vec<_>>>()
+            })
+            .collect::<Result<Vec<_>>>()?;
+
         assert_eq!(window_exprs.len(), arrays_per_window.len());
-        let (mut new_expr, new_input) = self.rewrite_expr(
-            &window_exprs,
+        let num_window_exprs = window_exprs.len();
+        let rewritten_window_exprs = self.rewrite_expr(
+            window_exprs,
             &arrays_per_window,
-            &plan,
+            plan,
             &expr_stats,
             config,
         )?;
-        assert_eq!(window_exprs.len(), new_expr.len());
+        let transformed = rewritten_window_exprs.transformed;
+
+        let (mut new_expr, new_input) = rewritten_window_exprs.data;
 
-        // Construct consecutive window operator, with their corresponding new 
window expressions.
-        plan = new_input;
-        while let Some(new_window_expr) = new_expr.pop() {
-            // Since `new_expr` and `window_exprs` length are same. We can 
safely `.unwrap` here.
-            let orig_window_expr = window_exprs.pop().unwrap();
-            assert_eq!(new_window_expr.len(), orig_window_expr.len());
+        let mut plan = new_input;
 
-            // Rename new re-written window expressions with original name (by 
giving alias)
-            // Otherwise we may receive schema error, in subsequent operators.
+        // Construct consecutive window operator, with their corresponding new
+        // window expressions.
+        //
+        // Note this iterates over `new_expr` and `saved_names`, which are the
+        // same length, in reverse order
+        assert_eq!(num_window_exprs, new_expr.len());
+        assert_eq!(num_window_exprs, saved_names.len());
+        while let (Some(new_window_expr), Some(saved_names)) =
+            (new_expr.pop(), saved_names.pop())
+        {
+            assert_eq!(new_window_expr.len(), saved_names.len());
+
+            // Rename re-written window expressions with original name, to
+            // preserve the output schema
             let new_window_expr = new_window_expr
                 .into_iter()
-                .zip(orig_window_expr.iter())
-                .map(|(new_window_expr, window_expr)| {
-                    let original_name = window_expr.name_for_alias()?;
-                    new_window_expr.alias_if_changed(original_name)
-                })
+                .zip(saved_names.into_iter())
+                .map(|(new_window_expr, saved_name)| 
saved_name.restore(new_window_expr))
                 .collect::<Result<Vec<_>>>()?;
             plan = LogicalPlan::Window(Window::try_new(new_window_expr, 
Arc::new(plan))?);
         }
 
-        Ok(plan)
+        Ok(Transformed::new_transformed(plan, transformed))
     }
 
     fn try_optimize_aggregate(
         &self,
-        aggregate: &Aggregate,
+        aggregate: Aggregate,
         config: &dyn OptimizerConfig,
-    ) -> Result<LogicalPlan> {
+    ) -> Result<Transformed<LogicalPlan>> {
         let Aggregate {
             group_expr,
             aggr_expr,
             input,
+            schema: orig_schema,
             ..
         } = aggregate;
         let mut expr_stats = ExprStats::new();
 
+        // track transformed information
+        let mut transformed = false;
+
         // rewrite inputs
-        let group_arrays = to_arrays(group_expr, &mut expr_stats, 
ExprMask::Normal)?;
-        let aggr_arrays = to_arrays(aggr_expr, &mut expr_stats, 
ExprMask::Normal)?;
+        let group_arrays = to_arrays(&group_expr, &mut expr_stats, 
ExprMask::Normal)?;
+        let aggr_arrays = to_arrays(&aggr_expr, &mut expr_stats, 
ExprMask::Normal)?;
+
+        let name_perserver = NamePreserver::new_for_projection();
+        let saved_names = aggr_expr
+            .iter()
+            .map(|expr| name_perserver.save(expr))
+            .collect::<Result<Vec<_>>>()?;
 
-        let (mut new_expr, new_input) = self.rewrite_expr(
-            &[group_expr, aggr_expr],
+        // rewrite both group exprs and aggr_expr
+        let rewritten = self.rewrite_expr(
+            vec![group_expr, aggr_expr],
             &[&group_arrays, &aggr_arrays],
-            input,
+            unwrap_arc(input),
             &expr_stats,
             config,
         )?;
+        transformed |= rewritten.transformed;
+        let (mut new_expr, new_input) = rewritten.data;
+
         // note the reversed pop order.
         let new_aggr_expr = pop_expr(&mut new_expr)?;
         let new_group_expr = pop_expr(&mut new_expr)?;
@@ -296,108 +389,208 @@ impl CommonSubexprEliminate {
             &mut expr_stats,
             ExprMask::NormalAndAggregates,
         )?;
-        let mut common_exprs = CommonExprs::new();
-        let mut rewritten = self.rewrite_exprs_list(
-            &[&new_aggr_expr],
+        let mut common_exprs = IndexMap::new();
+        let mut rewritten_exprs = self.rewrite_exprs_list(
+            vec![new_aggr_expr.clone()],
             &[&aggr_arrays],
             &expr_stats,
             &mut common_exprs,
             &config.alias_generator(),
         )?;
-        let rewritten = pop_expr(&mut rewritten)?;
+        transformed |= rewritten_exprs.transformed;
+        let rewritten = pop_expr(&mut rewritten_exprs.data)?;
 
         if common_exprs.is_empty() {
             // Alias aggregation expressions if they have changed
             let new_aggr_expr = new_aggr_expr
-                .iter()
-                .zip(aggr_expr.iter())
-                .map(|(new_expr, old_expr)| {
-                    new_expr.clone().alias_if_changed(old_expr.display_name()?)
-                })
+                .into_iter()
+                .zip(saved_names.into_iter())
+                .map(|(new_expr, saved_name)| saved_name.restore(new_expr))
                 .collect::<Result<Vec<Expr>>>()?;
-            // Since group_epxr changes, schema changes also. Use try_new 
method.
-            Aggregate::try_new(Arc::new(new_input), new_group_expr, 
new_aggr_expr)
-                .map(LogicalPlan::Aggregate)
-        } else {
-            let mut agg_exprs = common_exprs
-                .into_values()
-                .map(|(expr, expr_alias)| expr.alias(expr_alias))
-                .collect::<Vec<_>>();
-
-            let mut proj_exprs = vec![];
-            for expr in &new_group_expr {
-                extract_expressions(expr, &new_input_schema, &mut proj_exprs)?
-            }
-            for (expr_rewritten, expr_orig) in 
rewritten.into_iter().zip(new_aggr_expr) {
-                if expr_rewritten == expr_orig {
-                    if let Expr::Alias(Alias { expr, name, .. }) = 
expr_rewritten {
-                        agg_exprs.push(expr.alias(&name));
-                        proj_exprs.push(Expr::Column(Column::from_name(name)));
-                    } else {
-                        let expr_alias = 
config.alias_generator().next(CSE_PREFIX);
-                        let (qualifier, field) =
-                            expr_rewritten.to_field(&new_input_schema)?;
-                        let out_name = qualified_name(qualifier.as_ref(), 
field.name());
-
-                        agg_exprs.push(expr_rewritten.alias(&expr_alias));
-                        proj_exprs.push(
-                            
Expr::Column(Column::from_name(expr_alias)).alias(out_name),
-                        );
-                    }
+            // If anything was transformed, group_expr (and so the schema) may have 
changed: recompute it with try_new, otherwise reuse the original schema.
+            let new_agg = if transformed {
+                Aggregate::try_new(Arc::new(new_input), new_group_expr, 
new_aggr_expr)?
+            } else {
+                Aggregate::try_new_with_schema(
+                    Arc::new(new_input),
+                    new_group_expr,
+                    new_aggr_expr,
+                    orig_schema,
+                )?
+            };
+            let new_agg = LogicalPlan::Aggregate(new_agg);
+            return Ok(Transformed::new_transformed(new_agg, transformed));
+        }
+        let mut agg_exprs = common_exprs
+            .into_values()
+            .map(|(expr, expr_alias)| expr.alias(expr_alias))
+            .collect::<Vec<_>>();
+
+        let mut proj_exprs = vec![];
+        for expr in &new_group_expr {
+            extract_expressions(expr, &new_input_schema, &mut proj_exprs)?
+        }
+        for (expr_rewritten, expr_orig) in 
rewritten.into_iter().zip(new_aggr_expr) {
+            if expr_rewritten == expr_orig {
+                if let Expr::Alias(Alias { expr, name, .. }) = expr_rewritten {
+                    agg_exprs.push(expr.alias(&name));
+                    proj_exprs.push(Expr::Column(Column::from_name(name)));
                 } else {
-                    proj_exprs.push(expr_rewritten);
+                    let expr_alias = config.alias_generator().next(CSE_PREFIX);
+                    let (qualifier, field) =
+                        expr_rewritten.to_field(&new_input_schema)?;
+                    let out_name = qualified_name(qualifier.as_ref(), 
field.name());
+
+                    agg_exprs.push(expr_rewritten.alias(&expr_alias));
+                    proj_exprs.push(
+                        
Expr::Column(Column::from_name(expr_alias)).alias(out_name),
+                    );
                 }
+            } else {
+                proj_exprs.push(expr_rewritten);
             }
+        }
 
-            let agg = LogicalPlan::Aggregate(Aggregate::try_new(
-                Arc::new(new_input),
-                new_group_expr,
-                agg_exprs,
-            )?);
+        let agg = LogicalPlan::Aggregate(Aggregate::try_new(
+            Arc::new(new_input),
+            new_group_expr,
+            agg_exprs,
+        )?);
 
-            Ok(LogicalPlan::Projection(Projection::try_new(
-                proj_exprs,
-                Arc::new(agg),
-            )?))
-        }
+        Projection::try_new(proj_exprs, Arc::new(agg))
+            .map(LogicalPlan::Projection)
+            .map(Transformed::yes)
     }
 
+    /// Rewrites the expr list and input to remove common subexpressions
+    ///
+    /// # Parameters
+    ///
+    /// * `exprs`: List of expressions in the node
+    /// * `input`: input plan (that produces the columns referred to in 
`exprs`)
+    ///
+    /// # Return value
+    ///
+    /// Returns `(rewritten_exprs, new_input)`. `new_input` is either:
+    ///
+    /// 1. The original `input` if no common subexpressions were extracted
+    /// 2. A newly added projection on top of the original input
+    /// that computes the common subexpressions
     fn try_unary_plan(
         &self,
-        plan: &LogicalPlan,
+        expr: Vec<Expr>,
+        input: LogicalPlan,
         config: &dyn OptimizerConfig,
-    ) -> Result<LogicalPlan> {
-        let expr = plan.expressions();
-        let inputs = plan.inputs();
-        let input = inputs[0];
+    ) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
         let mut expr_stats = ExprStats::new();
-
-        // Visit expr list and build expr identifier to occuring count map 
(`expr_stats`).
         let arrays = to_arrays(&expr, &mut expr_stats, ExprMask::Normal)?;
 
-        let (mut new_expr, new_input) =
-            self.rewrite_expr(&[&expr], &[&arrays], input, &expr_stats, 
config)?;
+        self.rewrite_expr(vec![expr], &[&arrays], input, &expr_stats, config)?
+            .map_data(|(mut new_expr, new_input)| {
+                assert_eq!(new_expr.len(), 1);
+                Ok((new_expr.pop().unwrap(), new_input))
+            })
+    }
+}
 
-        plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input])
+/// Get all window expressions inside the consecutive window operators.
+///
+/// Returns the window expressions, and the input to the deepest child
+/// LogicalPlan.
+///
+/// For example, if the input window looks like
+///
+/// ```text
+///   LogicalPlan::Window(exprs=[a, b, c])
+///     LogicalPlan::Window(exprs=[d])
+///       InputPlan
+/// ```
+///
+/// Returns:
+/// * `window_exprs`: `[[a, b, c], [d]]` (one list per `Window` node, outermost first)
+/// * `plan`: InputPlan
+///
+/// Consecutive window expressions may refer to same complex expression.
+///
+/// If same complex expression is referred more than once by subsequent
+/// `WindowAggr`s, we can cache complex expression by evaluating it with a
+/// projection before the first WindowAggr.
+///
+/// This enables us to cache complex expression "c3+c4" for following plan:
+///
+/// ```text
+/// WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN 
UNBOUNDED PRECEDING AND CURRENT ROW]]
+/// --WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN 
UNBOUNDED PRECEDING AND CURRENT ROW]]
+/// ```
+///
+/// where, it is referred once by each `WindowAggr` (total of 2) in the plan.
+struct ConsecutiveWindowExprs {
+    window_exprs: Vec<Vec<Expr>>,
+    /// result of calling `to_arrays` on each set of window exprs
+    arrays_per_window: Vec<Vec<Vec<(usize, String)>>>,
+    expr_stats: ExprStats,
+    /// input plan to the window
+    plan: LogicalPlan,
+}
+
+impl ConsecutiveWindowExprs {
+    fn try_new(window: Window) -> Result<Self> {
+        let mut window_exprs = vec![];
+        let mut arrays_per_window = vec![];
+        let mut expr_stats = ExprStats::new();
+
+        let mut plan = LogicalPlan::Window(window);
+        while let LogicalPlan::Window(Window {
+            input, window_expr, ..
+        }) = plan
+        {
+            plan = unwrap_arc(input);
+
+            let arrays = to_arrays(&window_expr, &mut expr_stats, 
ExprMask::Normal)?;
+
+            window_exprs.push(window_expr);
+            arrays_per_window.push(arrays);
+        }
+
+        Ok(Self {
+            window_exprs,
+            arrays_per_window,
+            expr_stats,
+            plan,
+        })
     }
 }
 
 impl OptimizerRule for CommonSubexprEliminate {
     fn try_optimize(
         &self,
-        plan: &LogicalPlan,
-        config: &dyn OptimizerConfig,
+        _plan: &LogicalPlan,
+        _config: &dyn OptimizerConfig,
     ) -> Result<Option<LogicalPlan>> {
+        internal_err!("Should have called CommonSubexprEliminate::rewrite")
+    }
+
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
+
+    fn apply_order(&self) -> Option<ApplyOrder> {
+        Some(ApplyOrder::TopDown)
+    }
+
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        let original_schema = Arc::clone(plan.schema());
+
         let optimized_plan = match plan {
-            LogicalPlan::Projection(_)
-            | LogicalPlan::Sort(_)
-            | LogicalPlan::Filter(_) => Some(self.try_unary_plan(plan, 
config)?),
-            LogicalPlan::Window(window) => {
-                Some(self.try_optimize_window(window, config)?)
-            }
-            LogicalPlan::Aggregate(aggregate) => {
-                Some(self.try_optimize_aggregate(aggregate, config)?)
-            }
+            LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, 
config)?,
+            LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?,
+            LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, 
config)?,
+            LogicalPlan::Window(window) => self.try_optimize_window(window, 
config)?,
+            LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, 
config)?,
             LogicalPlan::Join(_)
             | LogicalPlan::CrossJoin(_)
             | LogicalPlan::Repartition(_)
@@ -420,21 +613,19 @@ impl OptimizerRule for CommonSubexprEliminate {
             | LogicalPlan::Unnest(_)
             | LogicalPlan::RecursiveQuery(_)
             | LogicalPlan::Prepare(_) => {
-                // apply the optimization to all inputs of the plan
-                utils::optimize_children(self, plan, config)?
+                // ApplyOrder::TopDown handles recursion
+                Transformed::no(plan)
             }
         };
 
-        let original_schema = plan.schema();
-        match optimized_plan {
-            Some(optimized_plan) if optimized_plan.schema() != original_schema 
=> {
-                // add an additional projection if the output schema changed.
-                Ok(Some(build_recover_project_plan(
-                    original_schema,
-                    optimized_plan,
-                )?))
-            }
-            plan => Ok(plan),
+        // If we rewrote the plan, ensure the schema stays the same
+        if optimized_plan.transformed && optimized_plan.data.schema() != 
&original_schema
+        {
+            optimized_plan.map_data(|optimized_plan| {
+                build_recover_project_plan(&original_schema, optimized_plan)
+            })
+        } else {
+            Ok(optimized_plan)
         }
     }
 
@@ -459,22 +650,29 @@ impl CommonSubexprEliminate {
 fn pop_expr(new_expr: &mut Vec<Vec<Expr>>) -> Result<Vec<Expr>> {
     new_expr
         .pop()
-        .ok_or_else(|| DataFusionError::Internal("Failed to pop 
expression".to_string()))
+        .ok_or_else(|| internal_datafusion_err!("Failed to pop expression"))
 }
 
+/// Returns the identifier list for each element in `exprs`
+///
+/// Returns an array with 1 element for each input expr in `exprs`
+///
+/// Each element is itself the result of [`expr_to_identifier`] for that expr
+/// (i.e. the identifiers for each node in the tree)
 fn to_arrays(
-    expr: &[Expr],
+    exprs: &[Expr],
     expr_stats: &mut ExprStats,
     expr_mask: ExprMask,
 ) -> Result<Vec<IdArray>> {
-    expr.iter()
+    exprs
+        .iter()
         .map(|e| {
             let mut id_array = vec![];
             expr_to_identifier(e, expr_stats, &mut id_array, expr_mask)?;
 
             Ok(id_array)
         })
-        .collect::<Result<Vec<_>>>()
+        .collect()
 }
 
 /// Build the "intermediate" projection plan that evaluates the extracted 
common
@@ -506,10 +704,7 @@ fn build_common_expr_project_plan(
         }
     }
 
-    Ok(LogicalPlan::Projection(Projection::try_new(
-        project_exprs,
-        Arc::new(input),
-    )?))
+    Projection::try_new(project_exprs, 
Arc::new(input)).map(LogicalPlan::Projection)
 }
 
 /// Build the projection plan to eliminate unnecessary columns produced by
@@ -522,10 +717,7 @@ fn build_recover_project_plan(
     input: LogicalPlan,
 ) -> Result<LogicalPlan> {
     let col_exprs = schema.iter().map(Expr::from).collect();
-    Ok(LogicalPlan::Projection(Projection::try_new(
-        col_exprs,
-        Arc::new(input),
-    )?))
+    Projection::try_new(col_exprs, 
Arc::new(input)).map(LogicalPlan::Projection)
 }
 
 fn extract_expressions(
@@ -807,7 +999,7 @@ fn replace_common_expr(
     expr_stats: &ExprStats,
     common_exprs: &mut CommonExprs,
     alias_generator: &AliasGenerator,
-) -> Result<Expr> {
+) -> Result<Transformed<Expr>> {
     expr.rewrite(&mut CommonSubexprRewriter {
         expr_stats,
         id_array,
@@ -816,7 +1008,6 @@ fn replace_common_expr(
         alias_counter: 0,
         alias_generator,
     })
-    .data()
 }
 
 #[cfg(test)]
@@ -839,18 +1030,36 @@ mod test {
 
     use super::*;
 
+    fn assert_non_optimized_plan_eq(
+        expected: &str,
+        plan: LogicalPlan,
+        config: Option<&dyn OptimizerConfig>,
+    ) {
+        assert_eq!(expected, format!("{plan:?}"), "Unexpected starting plan");
+        let optimizer = CommonSubexprEliminate {};
+        let default_config = OptimizerContext::new();
+        let config = config.unwrap_or(&default_config);
+        let optimized_plan = optimizer.rewrite(plan, config).unwrap();
+        assert!(!optimized_plan.transformed, "unexpectedly optimize plan");
+        let optimized_plan = optimized_plan.data;
+        assert_eq!(
+            expected,
+            format!("{optimized_plan:?}"),
+            "Unexpected optimized plan"
+        );
+    }
+
     fn assert_optimized_plan_eq(
         expected: &str,
-        plan: &LogicalPlan,
+        plan: LogicalPlan,
         config: Option<&dyn OptimizerConfig>,
     ) {
         let optimizer = CommonSubexprEliminate {};
         let default_config = OptimizerContext::new();
         let config = config.unwrap_or(&default_config);
-        let optimized_plan = optimizer
-            .try_optimize(plan, config)
-            .unwrap()
-            .expect("failed to optimize plan");
+        let optimized_plan = optimizer.rewrite(plan, config).unwrap();
+        assert!(optimized_plan.transformed, "failed to optimize plan");
+        let optimized_plan = optimized_plan.data;
         let formatted_plan = format!("{optimized_plan:?}");
         assert_eq!(expected, formatted_plan);
     }
@@ -933,7 +1142,7 @@ mod test {
         \n  Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, 
test.a, test.b, test.c\
         \n    TableScan: test";
 
-        assert_optimized_plan_eq(expected, &plan, None);
+        assert_optimized_plan_eq(expected, plan, None);
 
         Ok(())
     }
@@ -953,7 +1162,7 @@ mod test {
         \n  Projection: test.a + test.b AS __common_expr_1, test.a, test.b, 
test.c\
         \n    TableScan: test";
 
-        assert_optimized_plan_eq(expected, &plan, None);
+        assert_optimized_plan_eq(expected, plan, None);
 
         Ok(())
     }
@@ -1006,7 +1215,7 @@ mod test {
         \n  Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS __common_expr_1, 
my_agg(test.a) AS __common_expr_2, AVG(test.b) AS col3, AVG(test.c) AS 
__common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]\
         \n    TableScan: test";
 
-        assert_optimized_plan_eq(expected, &plan, None);
+        assert_optimized_plan_eq(expected, plan, None);
 
         // test: trafo after aggregate
         let plan = LogicalPlanBuilder::from(table_scan.clone())
@@ -1025,7 +1234,7 @@ mod test {
         \n  Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS __common_expr_1, 
my_agg(test.a) AS __common_expr_2]]\
         \n    TableScan: test";
 
-        assert_optimized_plan_eq(expected, &plan, None);
+        assert_optimized_plan_eq(expected, plan, None);
 
         // test: transformation before aggregate
         let plan = LogicalPlanBuilder::from(table_scan.clone())
@@ -1042,7 +1251,7 @@ mod test {
         \n  Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, 
test.c\
         \n    TableScan: test";
 
-        assert_optimized_plan_eq(expected, &plan, None);
+        assert_optimized_plan_eq(expected, plan, None);
 
         // test: common between agg and group
         let plan = LogicalPlanBuilder::from(table_scan.clone())
@@ -1059,7 +1268,7 @@ mod test {
         \n  Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, 
test.c\
         \n    TableScan: test";
 
-        assert_optimized_plan_eq(expected, &plan, None);
+        assert_optimized_plan_eq(expected, plan, None);
 
         // test: all mixed
         let plan = LogicalPlanBuilder::from(table_scan)
@@ -1081,7 +1290,7 @@ mod test {
         \n    Projection: UInt32(1) + test.a AS __common_expr_1, test.a, 
test.b, test.c\
         \n      TableScan: test";
 
-        assert_optimized_plan_eq(expected, &plan, None);
+        assert_optimized_plan_eq(expected, plan, None);
 
         Ok(())
     }
@@ -1108,7 +1317,7 @@ mod test {
         \n    Projection: UInt32(1) + table.test.col.a AS __common_expr_1, 
table.test.col.a\
         \n      TableScan: table.test";
 
-        assert_optimized_plan_eq(expected, &plan, None);
+        assert_optimized_plan_eq(expected, plan, None);
 
         Ok(())
     }
@@ -1128,7 +1337,7 @@ mod test {
         \n  Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, 
test.c\
         \n    TableScan: test";
 
-        assert_optimized_plan_eq(expected, &plan, None);
+        assert_optimized_plan_eq(expected, plan, None);
 
         Ok(())
     }
@@ -1144,7 +1353,7 @@ mod test {
         let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\
         \n  TableScan: test";
 
-        assert_optimized_plan_eq(expected, &plan, None);
+        assert_non_optimized_plan_eq(expected, plan, None);
 
         Ok(())
     }
@@ -1162,7 +1371,7 @@ mod test {
         \n  Projection: Int32(1) + test.a, test.a\
         \n    TableScan: test";
 
-        assert_optimized_plan_eq(expected, &plan, None);
+        assert_non_optimized_plan_eq(expected, plan, None);
         Ok(())
     }
 
@@ -1257,10 +1466,9 @@ mod test {
             .build()
             .unwrap();
         let rule = CommonSubexprEliminate {};
-        let optimized_plan = rule
-            .try_optimize(&plan, &OptimizerContext::new())
-            .unwrap()
-            .unwrap();
+        let optimized_plan = rule.rewrite(plan, 
&OptimizerContext::new()).unwrap();
+        assert!(!optimized_plan.transformed);
+        let optimized_plan = optimized_plan.data;
 
         let schema = optimized_plan.schema();
         let fields_with_datatypes: Vec<_> = schema
@@ -1299,7 +1507,7 @@ mod test {
         \n    Projection: Int32(1) + test.a AS __common_expr_1, test.a, 
test.b, test.c\
         \n      TableScan: test";
 
-        assert_optimized_plan_eq(expected, &plan, None);
+        assert_optimized_plan_eq(expected, plan, None);
 
         Ok(())
     }
@@ -1365,7 +1573,7 @@ mod test {
         \n    Projection: test.a + test.b AS __common_expr_1, test.c\
         \n      TableScan: test";
 
-        assert_optimized_plan_eq(expected, &plan, Some(config));
+        assert_optimized_plan_eq(expected, plan, Some(config));
 
         let config = &OptimizerContext::new();
         let _common_expr_1 = config.alias_generator().next(CSE_PREFIX);
@@ -1388,7 +1596,7 @@ mod test {
         \n    Projection: test.a + test.b AS __common_expr_2, test.c\
         \n      TableScan: test";
 
-        assert_optimized_plan_eq(expected, &plan, Some(config));
+        assert_optimized_plan_eq(expected, plan, Some(config));
 
         Ok(())
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to