This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new c6b2efccf6 Stop copying LogicalPlan and Exprs in
`CommonSubexprEliminate` (2-3% planning speed improvement) (#10835)
c6b2efccf6 is described below
commit c6b2efccf6238cc87f2414efb28ae3b263ed27af
Author: Andrew Lamb <[email protected]>
AuthorDate: Wed Jun 19 10:13:27 2024 -0400
Stop copying LogicalPlan and Exprs in `CommonSubexprEliminate` (2-3%
planning speed improvement) (#10835)
* Stop copying LogicalPlan and Exprs in `CommonSubexprEliminate`
* thread transformed
* Update unary to report transformed correctly
* Preserve through window transforms
* track aggregate
* Avoid re-computing Aggregate schema
* Update datafusion/optimizer/src/common_subexpr_eliminate.rs
* Avoid unnecessary setting of the transformed flag
* Cleanup unaliasing
---
datafusion/common/src/tree_node.rs | 5 +
datafusion/expr/src/logical_plan/plan.rs | 64 +--
.../optimizer/src/common_subexpr_eliminate.rs | 594 ++++++++++++++-------
3 files changed, 439 insertions(+), 224 deletions(-)
diff --git a/datafusion/common/src/tree_node.rs
b/datafusion/common/src/tree_node.rs
index d0dd24621d..276a1cc4c5 100644
--- a/datafusion/common/src/tree_node.rs
+++ b/datafusion/common/src/tree_node.rs
@@ -615,6 +615,11 @@ impl<T> Transformed<T> {
}
}
+ /// Create a `Transformed` with `transformed` and
[`TreeNodeRecursion::Continue`].
+ pub fn new_transformed(data: T, transformed: bool) -> Self {
+ Self::new(data, transformed, TreeNodeRecursion::Continue)
+ }
+
/// Wrapper for transformed data with [`TreeNodeRecursion::Continue`]
statement.
pub fn yes(data: T) -> Self {
Self::new(data, true, TreeNodeRecursion::Continue)
diff --git a/datafusion/expr/src/logical_plan/plan.rs
b/datafusion/expr/src/logical_plan/plan.rs
index 02378ab3fc..85958223ac 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -870,37 +870,7 @@ impl LogicalPlan {
LogicalPlan::Filter { .. } => {
assert_eq!(1, expr.len());
let predicate = expr.pop().unwrap();
-
- // filter predicates should not contain aliased expressions so
we remove any aliases
- // before this logic was added we would have aliases within
filters such as for
- // benchmark q6:
- //
- // lineitem.l_shipdate >= Date32(\"8766\")
- // AND lineitem.l_shipdate < Date32(\"9131\")
- // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS
lineitem.l_discount >=
- // Decimal128(Some(49999999999999),30,15)
- // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS
lineitem.l_discount <=
- // Decimal128(Some(69999999999999),30,15)
- // AND lineitem.l_quantity < Decimal128(Some(2400),15,2)
-
- let predicate = predicate
- .transform_down(|expr| {
- match expr {
- Expr::Exists { .. }
- | Expr::ScalarSubquery(_)
- | Expr::InSubquery(_) => {
- // subqueries could contain aliases so we
don't recurse into those
- Ok(Transformed::new(expr, false,
TreeNodeRecursion::Jump))
- }
- Expr::Alias(_) => Ok(Transformed::new(
- expr.unalias(),
- true,
- TreeNodeRecursion::Jump,
- )),
- _ => Ok(Transformed::no(expr)),
- }
- })
- .data()?;
+ let predicate = Filter::remove_aliases(predicate)?.data;
Filter::try_new(predicate, Arc::new(inputs.swap_remove(0)))
.map(LogicalPlan::Filter)
@@ -2230,6 +2200,38 @@ impl Filter {
}
false
}
+
+ /// Remove aliases from a predicate for use in a `Filter`
+ ///
+ /// filter predicates should not contain aliased expressions so we remove
+ /// any aliases.
+ ///
+ /// before this logic was added we would have aliases within filters such
as
+ /// for benchmark q6:
+ ///
+ /// ```sql
+ /// lineitem.l_shipdate >= Date32(\"8766\")
+ /// AND lineitem.l_shipdate < Date32(\"9131\")
+ /// AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS
lineitem.l_discount >=
+ /// Decimal128(Some(49999999999999),30,15)
+ /// AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS
lineitem.l_discount <=
+ /// Decimal128(Some(69999999999999),30,15)
+ /// AND lineitem.l_quantity < Decimal128(Some(2400),15,2)
+ /// ```
+ pub fn remove_aliases(predicate: Expr) -> Result<Transformed<Expr>> {
+ predicate.transform_down(|expr| {
+ match expr {
+ Expr::Exists { .. } | Expr::ScalarSubquery(_) |
Expr::InSubquery(_) => {
+ // subqueries could contain aliases so we don't recurse
into those
+ Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump))
+ }
+ Expr::Alias(Alias { expr, .. }) => {
+ Ok(Transformed::new(*expr, true, TreeNodeRecursion::Jump))
+ }
+ _ => Ok(Transformed::no(expr)),
+ }
+ })
+ }
}
/// Window its input based on a set of window spec and window function (e.g.
SUM or RANK)
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs
b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index e150a957bf..7f4093ba11 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -20,16 +20,22 @@
use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;
-use crate::{utils, OptimizerConfig, OptimizerRule};
+use crate::{OptimizerConfig, OptimizerRule};
+use crate::optimizer::ApplyOrder;
+use crate::utils::NamePreserver;
use datafusion_common::alias::AliasGenerator;
use datafusion_common::tree_node::{
- Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
TreeNodeRewriter,
- TreeNodeVisitor,
+ Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
TreeNodeVisitor,
+};
+use datafusion_common::{
+ internal_datafusion_err, internal_err, qualified_name, Column, DFSchema,
Result,
};
-use datafusion_common::{qualified_name, Column, DFSchema, DataFusionError,
Result};
use datafusion_expr::expr::Alias;
-use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection,
Window};
+use datafusion_expr::logical_plan::tree_node::unwrap_arc;
+use datafusion_expr::logical_plan::{
+ Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
+};
use datafusion_expr::{col, Expr, ExprSchemable};
use indexmap::IndexMap;
@@ -123,32 +129,39 @@ impl CommonSubexprEliminate {
/// Returns the rewritten expressions
fn rewrite_exprs_list(
&self,
- exprs_list: &[&[Expr]],
+ exprs_list: Vec<Vec<Expr>>,
arrays_list: &[&[IdArray]],
expr_stats: &ExprStats,
common_exprs: &mut CommonExprs,
alias_generator: &AliasGenerator,
- ) -> Result<Vec<Vec<Expr>>> {
+ ) -> Result<Transformed<Vec<Vec<Expr>>>> {
+ let mut transformed = false;
exprs_list
- .iter()
+ .into_iter()
.zip(arrays_list.iter())
.map(|(exprs, arrays)| {
exprs
- .iter()
- .cloned()
+ .into_iter()
.zip(arrays.iter())
.map(|(expr, id_array)| {
- replace_common_expr(
+ let replaced = replace_common_expr(
expr,
id_array,
expr_stats,
common_exprs,
alias_generator,
- )
+ )?;
+ // remember if this expression was actually replaced
+ transformed |= replaced.transformed;
+ Ok(replaced.data)
})
.collect::<Result<Vec<_>>>()
})
.collect::<Result<Vec<_>>>()
+ .map(|rewritten_exprs_list| {
+ // propagate back transformed information
+ Transformed::new_transformed(rewritten_exprs_list, transformed)
+ })
}
/// Rewrites the expression in `exprs_list` with common sub-expressions
@@ -161,13 +174,15 @@ impl CommonSubexprEliminate {
/// common sub-expressions that were used
fn rewrite_expr(
&self,
- exprs_list: &[&[Expr]],
+ exprs_list: Vec<Vec<Expr>>,
arrays_list: &[&[IdArray]],
- input: &LogicalPlan,
+ input: LogicalPlan,
expr_stats: &ExprStats,
config: &dyn OptimizerConfig,
- ) -> Result<(Vec<Vec<Expr>>, LogicalPlan)> {
+ ) -> Result<Transformed<(Vec<Vec<Expr>>, LogicalPlan)>> {
+ let mut transformed = false;
let mut common_exprs = CommonExprs::new();
+
let rewrite_exprs = self.rewrite_exprs_list(
exprs_list,
arrays_list,
@@ -175,115 +190,193 @@ impl CommonSubexprEliminate {
&mut common_exprs,
&config.alias_generator(),
)?;
+ transformed |= rewrite_exprs.transformed;
- let mut new_input = self
- .try_optimize(input, config)?
- .unwrap_or_else(|| input.clone());
+ let new_input = self.rewrite(input, config)?;
+ transformed |= new_input.transformed;
+ let mut new_input = new_input.data;
if !common_exprs.is_empty() {
+ assert!(transformed);
new_input = build_common_expr_project_plan(new_input,
common_exprs)?;
}
- Ok((rewrite_exprs, new_input))
+ // return the transformed information
+
+ Ok(Transformed::new_transformed(
+ (rewrite_exprs.data, new_input),
+ transformed,
+ ))
}
- fn try_optimize_window(
+ fn try_optimize_proj(
&self,
- window: &Window,
+ projection: Projection,
config: &dyn OptimizerConfig,
- ) -> Result<LogicalPlan> {
- let mut window_exprs = vec![];
- let mut arrays_per_window = vec![];
- let mut expr_stats = ExprStats::new();
-
- // Get all window expressions inside the consecutive window operators.
- // Consecutive window expressions may refer to same complex expression.
- // If same complex expression is referred more than once by subsequent
`WindowAggr`s,
- // we can cache complex expression by evaluating it with a projection
before the
- // first WindowAggr.
- // This enables us to cache complex expression "c3+c4" for following
plan:
- // WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN
UNBOUNDED PRECEDING AND CURRENT ROW]]
- // --WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN
UNBOUNDED PRECEDING AND CURRENT ROW]]
- // where, it is referred once by each `WindowAggr` (total of 2) in the
plan.
- let mut plan = LogicalPlan::Window(window.clone());
- while let LogicalPlan::Window(window) = plan {
- let Window {
- input, window_expr, ..
- } = window;
- plan = input.as_ref().clone();
+ ) -> Result<Transformed<LogicalPlan>> {
+ let Projection {
+ expr,
+ input,
+ schema,
+ ..
+ } = projection;
+ let input = unwrap_arc(input);
+ self.try_unary_plan(expr, input, config)?
+ .map_data(|(new_expr, new_input)| {
+ Projection::try_new_with_schema(new_expr, Arc::new(new_input),
schema)
+ .map(LogicalPlan::Projection)
+ })
+ }
+ fn try_optimize_sort(
+ &self,
+ sort: Sort,
+ config: &dyn OptimizerConfig,
+ ) -> Result<Transformed<LogicalPlan>> {
+ let Sort { expr, input, fetch } = sort;
+ let input = unwrap_arc(input);
+ let new_sort = self.try_unary_plan(expr, input, config)?.update_data(
+ |(new_expr, new_input)| {
+ LogicalPlan::Sort(Sort {
+ expr: new_expr,
+ input: Arc::new(new_input),
+ fetch,
+ })
+ },
+ );
+ Ok(new_sort)
+ }
- let arrays = to_arrays(&window_expr, &mut expr_stats,
ExprMask::Normal)?;
+ fn try_optimize_filter(
+ &self,
+ filter: Filter,
+ config: &dyn OptimizerConfig,
+ ) -> Result<Transformed<LogicalPlan>> {
+ let Filter {
+ predicate, input, ..
+ } = filter;
+ let input = unwrap_arc(input);
+ let expr = vec![predicate];
+ self.try_unary_plan(expr, input, config)?
+ .transform_data(|(mut new_expr, new_input)| {
+ assert_eq!(new_expr.len(), 1); // passed in vec![predicate]
+ let new_predicate = new_expr.pop().unwrap();
+ Ok(Filter::remove_aliases(new_predicate)?
+ .update_data(|new_predicate| (new_predicate, new_input)))
+ })?
+ .map_data(|(new_predicate, new_input)| {
+ Filter::try_new(new_predicate, Arc::new(new_input))
+ .map(LogicalPlan::Filter)
+ })
+ }
- window_exprs.push(window_expr);
- arrays_per_window.push(arrays);
- }
+ fn try_optimize_window(
+ &self,
+ window: Window,
+ config: &dyn OptimizerConfig,
+ ) -> Result<Transformed<LogicalPlan>> {
+ // collect all window expressions from any number of LogicalPlanWindow
+ let ConsecutiveWindowExprs {
+ window_exprs,
+ arrays_per_window,
+ expr_stats,
+ plan,
+ } = ConsecutiveWindowExprs::try_new(window)?;
- let mut window_exprs = window_exprs
- .iter()
- .map(|expr| expr.as_slice())
- .collect::<Vec<_>>();
let arrays_per_window = arrays_per_window
.iter()
.map(|arrays| arrays.as_slice())
.collect::<Vec<_>>();
+ // save the original names
+ let name_preserver = NamePreserver::new(&plan);
+ let mut saved_names = window_exprs
+ .iter()
+ .map(|exprs| {
+ exprs
+ .iter()
+ .map(|expr| name_preserver.save(expr))
+ .collect::<Result<Vec<_>>>()
+ })
+ .collect::<Result<Vec<_>>>()?;
+
assert_eq!(window_exprs.len(), arrays_per_window.len());
- let (mut new_expr, new_input) = self.rewrite_expr(
- &window_exprs,
+ let num_window_exprs = window_exprs.len();
+ let rewritten_window_exprs = self.rewrite_expr(
+ window_exprs,
&arrays_per_window,
- &plan,
+ plan,
&expr_stats,
config,
)?;
- assert_eq!(window_exprs.len(), new_expr.len());
+ let transformed = rewritten_window_exprs.transformed;
+
+ let (mut new_expr, new_input) = rewritten_window_exprs.data;
- // Construct consecutive window operator, with their corresponding new
window expressions.
- plan = new_input;
- while let Some(new_window_expr) = new_expr.pop() {
- // Since `new_expr` and `window_exprs` length are same. We can
safely `.unwrap` here.
- let orig_window_expr = window_exprs.pop().unwrap();
- assert_eq!(new_window_expr.len(), orig_window_expr.len());
+ let mut plan = new_input;
- // Rename new re-written window expressions with original name (by
giving alias)
- // Otherwise we may receive schema error, in subsequent operators.
+ // Construct consecutive window operator, with their corresponding new
+ // window expressions.
+ //
+ // Note this iterates over, `new_expr` and `saved_names` which are the
+ // same length, in reverse order
+ assert_eq!(num_window_exprs, new_expr.len());
+ assert_eq!(num_window_exprs, saved_names.len());
+ while let (Some(new_window_expr), Some(saved_names)) =
+ (new_expr.pop(), saved_names.pop())
+ {
+ assert_eq!(new_window_expr.len(), saved_names.len());
+
+ // Rename re-written window expressions with original name, to
+ // preserve the output schema
let new_window_expr = new_window_expr
.into_iter()
- .zip(orig_window_expr.iter())
- .map(|(new_window_expr, window_expr)| {
- let original_name = window_expr.name_for_alias()?;
- new_window_expr.alias_if_changed(original_name)
- })
+ .zip(saved_names.into_iter())
+ .map(|(new_window_expr, saved_name)|
saved_name.restore(new_window_expr))
.collect::<Result<Vec<_>>>()?;
plan = LogicalPlan::Window(Window::try_new(new_window_expr,
Arc::new(plan))?);
}
- Ok(plan)
+ Ok(Transformed::new_transformed(plan, transformed))
}
fn try_optimize_aggregate(
&self,
- aggregate: &Aggregate,
+ aggregate: Aggregate,
config: &dyn OptimizerConfig,
- ) -> Result<LogicalPlan> {
+ ) -> Result<Transformed<LogicalPlan>> {
let Aggregate {
group_expr,
aggr_expr,
input,
+ schema: orig_schema,
..
} = aggregate;
let mut expr_stats = ExprStats::new();
+ // track transformed information
+ let mut transformed = false;
+
// rewrite inputs
- let group_arrays = to_arrays(group_expr, &mut expr_stats,
ExprMask::Normal)?;
- let aggr_arrays = to_arrays(aggr_expr, &mut expr_stats,
ExprMask::Normal)?;
+ let group_arrays = to_arrays(&group_expr, &mut expr_stats,
ExprMask::Normal)?;
+ let aggr_arrays = to_arrays(&aggr_expr, &mut expr_stats,
ExprMask::Normal)?;
+
+ let name_perserver = NamePreserver::new_for_projection();
+ let saved_names = aggr_expr
+ .iter()
+ .map(|expr| name_perserver.save(expr))
+ .collect::<Result<Vec<_>>>()?;
- let (mut new_expr, new_input) = self.rewrite_expr(
- &[group_expr, aggr_expr],
+ // rewrite both group exprs and aggr_expr
+ let rewritten = self.rewrite_expr(
+ vec![group_expr, aggr_expr],
&[&group_arrays, &aggr_arrays],
- input,
+ unwrap_arc(input),
&expr_stats,
config,
)?;
+ transformed |= rewritten.transformed;
+ let (mut new_expr, new_input) = rewritten.data;
+
// note the reversed pop order.
let new_aggr_expr = pop_expr(&mut new_expr)?;
let new_group_expr = pop_expr(&mut new_expr)?;
@@ -296,108 +389,208 @@ impl CommonSubexprEliminate {
&mut expr_stats,
ExprMask::NormalAndAggregates,
)?;
- let mut common_exprs = CommonExprs::new();
- let mut rewritten = self.rewrite_exprs_list(
- &[&new_aggr_expr],
+ let mut common_exprs = IndexMap::new();
+ let mut rewritten_exprs = self.rewrite_exprs_list(
+ vec![new_aggr_expr.clone()],
&[&aggr_arrays],
&expr_stats,
&mut common_exprs,
&config.alias_generator(),
)?;
- let rewritten = pop_expr(&mut rewritten)?;
+ transformed |= rewritten_exprs.transformed;
+ let rewritten = pop_expr(&mut rewritten_exprs.data)?;
if common_exprs.is_empty() {
// Alias aggregation expressions if they have changed
let new_aggr_expr = new_aggr_expr
- .iter()
- .zip(aggr_expr.iter())
- .map(|(new_expr, old_expr)| {
- new_expr.clone().alias_if_changed(old_expr.display_name()?)
- })
+ .into_iter()
+ .zip(saved_names.into_iter())
+ .map(|(new_expr, saved_name)| saved_name.restore(new_expr))
.collect::<Result<Vec<Expr>>>()?;
- // Since group_epxr changes, schema changes also. Use try_new
method.
- Aggregate::try_new(Arc::new(new_input), new_group_expr,
new_aggr_expr)
- .map(LogicalPlan::Aggregate)
- } else {
- let mut agg_exprs = common_exprs
- .into_values()
- .map(|(expr, expr_alias)| expr.alias(expr_alias))
- .collect::<Vec<_>>();
-
- let mut proj_exprs = vec![];
- for expr in &new_group_expr {
- extract_expressions(expr, &new_input_schema, &mut proj_exprs)?
- }
- for (expr_rewritten, expr_orig) in
rewritten.into_iter().zip(new_aggr_expr) {
- if expr_rewritten == expr_orig {
- if let Expr::Alias(Alias { expr, name, .. }) =
expr_rewritten {
- agg_exprs.push(expr.alias(&name));
- proj_exprs.push(Expr::Column(Column::from_name(name)));
- } else {
- let expr_alias =
config.alias_generator().next(CSE_PREFIX);
- let (qualifier, field) =
- expr_rewritten.to_field(&new_input_schema)?;
- let out_name = qualified_name(qualifier.as_ref(),
field.name());
-
- agg_exprs.push(expr_rewritten.alias(&expr_alias));
- proj_exprs.push(
-
Expr::Column(Column::from_name(expr_alias)).alias(out_name),
- );
- }
+ // Since group_expr may have changed, schema may also. Use try_new
method.
+ let new_agg = if transformed {
+ Aggregate::try_new(Arc::new(new_input), new_group_expr,
new_aggr_expr)?
+ } else {
+ Aggregate::try_new_with_schema(
+ Arc::new(new_input),
+ new_group_expr,
+ new_aggr_expr,
+ orig_schema,
+ )?
+ };
+ let new_agg = LogicalPlan::Aggregate(new_agg);
+ return Ok(Transformed::new_transformed(new_agg, transformed));
+ }
+ let mut agg_exprs = common_exprs
+ .into_values()
+ .map(|(expr, expr_alias)| expr.alias(expr_alias))
+ .collect::<Vec<_>>();
+
+ let mut proj_exprs = vec![];
+ for expr in &new_group_expr {
+ extract_expressions(expr, &new_input_schema, &mut proj_exprs)?
+ }
+ for (expr_rewritten, expr_orig) in
rewritten.into_iter().zip(new_aggr_expr) {
+ if expr_rewritten == expr_orig {
+ if let Expr::Alias(Alias { expr, name, .. }) = expr_rewritten {
+ agg_exprs.push(expr.alias(&name));
+ proj_exprs.push(Expr::Column(Column::from_name(name)));
} else {
- proj_exprs.push(expr_rewritten);
+ let expr_alias = config.alias_generator().next(CSE_PREFIX);
+ let (qualifier, field) =
+ expr_rewritten.to_field(&new_input_schema)?;
+ let out_name = qualified_name(qualifier.as_ref(),
field.name());
+
+ agg_exprs.push(expr_rewritten.alias(&expr_alias));
+ proj_exprs.push(
+
Expr::Column(Column::from_name(expr_alias)).alias(out_name),
+ );
}
+ } else {
+ proj_exprs.push(expr_rewritten);
}
+ }
- let agg = LogicalPlan::Aggregate(Aggregate::try_new(
- Arc::new(new_input),
- new_group_expr,
- agg_exprs,
- )?);
+ let agg = LogicalPlan::Aggregate(Aggregate::try_new(
+ Arc::new(new_input),
+ new_group_expr,
+ agg_exprs,
+ )?);
- Ok(LogicalPlan::Projection(Projection::try_new(
- proj_exprs,
- Arc::new(agg),
- )?))
- }
+ Projection::try_new(proj_exprs, Arc::new(agg))
+ .map(LogicalPlan::Projection)
+ .map(Transformed::yes)
}
+ /// Rewrites the expr list and input to remove common subexpressions
+ ///
+ /// # Parameters
+ ///
+ /// * `exprs`: List of expressions in the node
+ /// * `input`: input plan (that produces the columns referred to in
`exprs`)
+ ///
+ /// # Return value
+ ///
+ /// Returns `(rewritten_exprs, new_input)`. `new_input` is either:
+ ///
+ /// 1. The original `input` if no common subexpressions were extracted
+ /// 2. A newly added projection on top of the original input
+ /// that computes the common subexpressions
fn try_unary_plan(
&self,
- plan: &LogicalPlan,
+ expr: Vec<Expr>,
+ input: LogicalPlan,
config: &dyn OptimizerConfig,
- ) -> Result<LogicalPlan> {
- let expr = plan.expressions();
- let inputs = plan.inputs();
- let input = inputs[0];
+ ) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
let mut expr_stats = ExprStats::new();
-
- // Visit expr list and build expr identifier to occuring count map
(`expr_stats`).
let arrays = to_arrays(&expr, &mut expr_stats, ExprMask::Normal)?;
- let (mut new_expr, new_input) =
- self.rewrite_expr(&[&expr], &[&arrays], input, &expr_stats,
config)?;
+ self.rewrite_expr(vec![expr], &[&arrays], input, &expr_stats, config)?
+ .map_data(|(mut new_expr, new_input)| {
+ assert_eq!(new_expr.len(), 1);
+ Ok((new_expr.pop().unwrap(), new_input))
+ })
+ }
+}
- plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input])
+/// Get all window expressions inside the consecutive window operators.
+///
+/// Returns the window expressions, and the input to the deepest child
+/// LogicalPlan.
+///
+/// For example, if the input window looks like
+///
+/// ```text
+/// LogicalPlan::Window(exprs=[a, b, c])
+/// LogicalPlan::Window(exprs=[d])
+/// InputPlan
+/// ```
+///
+/// Returns:
+/// * `window_exprs`: `[[a, b, c], [d]]` (one inner list per window operator)
+/// * InputPlan
+///
+/// Consecutive window expressions may refer to same complex expression.
+///
+/// If same complex expression is referred more than once by subsequent
+/// `WindowAggr`s, we can cache complex expression by evaluating it with a
+/// projection before the first WindowAggr.
+///
+/// This enables us to cache complex expression "c3+c4" for following plan:
+///
+/// ```text
+/// WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN
UNBOUNDED PRECEDING AND CURRENT ROW]]
+/// --WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN
UNBOUNDED PRECEDING AND CURRENT ROW]]
+/// ```
+///
+/// where, it is referred once by each `WindowAggr` (total of 2) in the plan.
+struct ConsecutiveWindowExprs {
+ window_exprs: Vec<Vec<Expr>>,
+ /// result of calling `to_arrays` on each set of window exprs
+ arrays_per_window: Vec<Vec<Vec<(usize, String)>>>,
+ expr_stats: ExprStats,
+ /// input plan to the window
+ plan: LogicalPlan,
+}
+
+impl ConsecutiveWindowExprs {
+ fn try_new(window: Window) -> Result<Self> {
+ let mut window_exprs = vec![];
+ let mut arrays_per_window = vec![];
+ let mut expr_stats = ExprStats::new();
+
+ let mut plan = LogicalPlan::Window(window);
+ while let LogicalPlan::Window(Window {
+ input, window_expr, ..
+ }) = plan
+ {
+ plan = unwrap_arc(input);
+
+ let arrays = to_arrays(&window_expr, &mut expr_stats,
ExprMask::Normal)?;
+
+ window_exprs.push(window_expr);
+ arrays_per_window.push(arrays);
+ }
+
+ Ok(Self {
+ window_exprs,
+ arrays_per_window,
+ expr_stats,
+ plan,
+ })
}
}
impl OptimizerRule for CommonSubexprEliminate {
fn try_optimize(
&self,
- plan: &LogicalPlan,
- config: &dyn OptimizerConfig,
+ _plan: &LogicalPlan,
+ _config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
+ internal_err!("Should have called CommonSubexprEliminate::rewrite")
+ }
+
+ fn supports_rewrite(&self) -> bool {
+ true
+ }
+
+ fn apply_order(&self) -> Option<ApplyOrder> {
+ Some(ApplyOrder::TopDown)
+ }
+
+ fn rewrite(
+ &self,
+ plan: LogicalPlan,
+ config: &dyn OptimizerConfig,
+ ) -> Result<Transformed<LogicalPlan>> {
+ let original_schema = Arc::clone(plan.schema());
+
let optimized_plan = match plan {
- LogicalPlan::Projection(_)
- | LogicalPlan::Sort(_)
- | LogicalPlan::Filter(_) => Some(self.try_unary_plan(plan,
config)?),
- LogicalPlan::Window(window) => {
- Some(self.try_optimize_window(window, config)?)
- }
- LogicalPlan::Aggregate(aggregate) => {
- Some(self.try_optimize_aggregate(aggregate, config)?)
- }
+ LogicalPlan::Projection(proj) => self.try_optimize_proj(proj,
config)?,
+ LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?,
+ LogicalPlan::Filter(filter) => self.try_optimize_filter(filter,
config)?,
+ LogicalPlan::Window(window) => self.try_optimize_window(window,
config)?,
+ LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg,
config)?,
LogicalPlan::Join(_)
| LogicalPlan::CrossJoin(_)
| LogicalPlan::Repartition(_)
@@ -420,21 +613,19 @@ impl OptimizerRule for CommonSubexprEliminate {
| LogicalPlan::Unnest(_)
| LogicalPlan::RecursiveQuery(_)
| LogicalPlan::Prepare(_) => {
- // apply the optimization to all inputs of the plan
- utils::optimize_children(self, plan, config)?
+ // ApplyOrder::TopDown handles recursion
+ Transformed::no(plan)
}
};
- let original_schema = plan.schema();
- match optimized_plan {
- Some(optimized_plan) if optimized_plan.schema() != original_schema
=> {
- // add an additional projection if the output schema changed.
- Ok(Some(build_recover_project_plan(
- original_schema,
- optimized_plan,
- )?))
- }
- plan => Ok(plan),
+ // If we rewrote the plan, ensure the schema stays the same
+ if optimized_plan.transformed && optimized_plan.data.schema() !=
&original_schema
+ {
+ optimized_plan.map_data(|optimized_plan| {
+ build_recover_project_plan(&original_schema, optimized_plan)
+ })
+ } else {
+ Ok(optimized_plan)
}
}
@@ -459,22 +650,29 @@ impl CommonSubexprEliminate {
fn pop_expr(new_expr: &mut Vec<Vec<Expr>>) -> Result<Vec<Expr>> {
new_expr
.pop()
- .ok_or_else(|| DataFusionError::Internal("Failed to pop
expression".to_string()))
+ .ok_or_else(|| internal_datafusion_err!("Failed to pop expression"))
}
+/// Returns the identifier list for each element in `exprs`
+///
+/// Returns an array with 1 element for each input expr in `exprs`
+///
+/// Each element is itself the result of [`expr_to_identifier`] for that expr
+/// (e.g. the identifiers for each node in the tree)
fn to_arrays(
- expr: &[Expr],
+ exprs: &[Expr],
expr_stats: &mut ExprStats,
expr_mask: ExprMask,
) -> Result<Vec<IdArray>> {
- expr.iter()
+ exprs
+ .iter()
.map(|e| {
let mut id_array = vec![];
expr_to_identifier(e, expr_stats, &mut id_array, expr_mask)?;
Ok(id_array)
})
- .collect::<Result<Vec<_>>>()
+ .collect()
}
/// Build the "intermediate" projection plan that evaluates the extracted
common
@@ -506,10 +704,7 @@ fn build_common_expr_project_plan(
}
}
- Ok(LogicalPlan::Projection(Projection::try_new(
- project_exprs,
- Arc::new(input),
- )?))
+ Projection::try_new(project_exprs,
Arc::new(input)).map(LogicalPlan::Projection)
}
/// Build the projection plan to eliminate unnecessary columns produced by
@@ -522,10 +717,7 @@ fn build_recover_project_plan(
input: LogicalPlan,
) -> Result<LogicalPlan> {
let col_exprs = schema.iter().map(Expr::from).collect();
- Ok(LogicalPlan::Projection(Projection::try_new(
- col_exprs,
- Arc::new(input),
- )?))
+ Projection::try_new(col_exprs,
Arc::new(input)).map(LogicalPlan::Projection)
}
fn extract_expressions(
@@ -807,7 +999,7 @@ fn replace_common_expr(
expr_stats: &ExprStats,
common_exprs: &mut CommonExprs,
alias_generator: &AliasGenerator,
-) -> Result<Expr> {
+) -> Result<Transformed<Expr>> {
expr.rewrite(&mut CommonSubexprRewriter {
expr_stats,
id_array,
@@ -816,7 +1008,6 @@ fn replace_common_expr(
alias_counter: 0,
alias_generator,
})
- .data()
}
#[cfg(test)]
@@ -839,18 +1030,36 @@ mod test {
use super::*;
+ fn assert_non_optimized_plan_eq(
+ expected: &str,
+ plan: LogicalPlan,
+ config: Option<&dyn OptimizerConfig>,
+ ) {
+ assert_eq!(expected, format!("{plan:?}"), "Unexpected starting plan");
+ let optimizer = CommonSubexprEliminate {};
+ let default_config = OptimizerContext::new();
+ let config = config.unwrap_or(&default_config);
+ let optimized_plan = optimizer.rewrite(plan, config).unwrap();
+ assert!(!optimized_plan.transformed, "unexpectedly optimize plan");
+ let optimized_plan = optimized_plan.data;
+ assert_eq!(
+ expected,
+ format!("{optimized_plan:?}"),
+ "Unexpected optimized plan"
+ );
+ }
+
fn assert_optimized_plan_eq(
expected: &str,
- plan: &LogicalPlan,
+ plan: LogicalPlan,
config: Option<&dyn OptimizerConfig>,
) {
let optimizer = CommonSubexprEliminate {};
let default_config = OptimizerContext::new();
let config = config.unwrap_or(&default_config);
- let optimized_plan = optimizer
- .try_optimize(plan, config)
- .unwrap()
- .expect("failed to optimize plan");
+ let optimized_plan = optimizer.rewrite(plan, config).unwrap();
+ assert!(optimized_plan.transformed, "failed to optimize plan");
+ let optimized_plan = optimized_plan.data;
let formatted_plan = format!("{optimized_plan:?}");
assert_eq!(expected, formatted_plan);
}
@@ -933,7 +1142,7 @@ mod test {
\n Projection: test.a * (Int32(1) - test.b) AS __common_expr_1,
test.a, test.b, test.c\
\n TableScan: test";
- assert_optimized_plan_eq(expected, &plan, None);
+ assert_optimized_plan_eq(expected, plan, None);
Ok(())
}
@@ -953,7 +1162,7 @@ mod test {
\n Projection: test.a + test.b AS __common_expr_1, test.a, test.b,
test.c\
\n TableScan: test";
- assert_optimized_plan_eq(expected, &plan, None);
+ assert_optimized_plan_eq(expected, plan, None);
Ok(())
}
@@ -1006,7 +1215,7 @@ mod test {
\n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS __common_expr_1,
my_agg(test.a) AS __common_expr_2, AVG(test.b) AS col3, AVG(test.c) AS
__common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]\
\n TableScan: test";
- assert_optimized_plan_eq(expected, &plan, None);
+ assert_optimized_plan_eq(expected, plan, None);
// test: trafo after aggregate
let plan = LogicalPlanBuilder::from(table_scan.clone())
@@ -1025,7 +1234,7 @@ mod test {
\n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS __common_expr_1,
my_agg(test.a) AS __common_expr_2]]\
\n TableScan: test";
- assert_optimized_plan_eq(expected, &plan, None);
+ assert_optimized_plan_eq(expected, plan, None);
// test: transformation before aggregate
let plan = LogicalPlanBuilder::from(table_scan.clone())
@@ -1042,7 +1251,7 @@ mod test {
\n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b,
test.c\
\n TableScan: test";
- assert_optimized_plan_eq(expected, &plan, None);
+ assert_optimized_plan_eq(expected, plan, None);
// test: common between agg and group
let plan = LogicalPlanBuilder::from(table_scan.clone())
@@ -1059,7 +1268,7 @@ mod test {
\n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b,
test.c\
\n TableScan: test";
- assert_optimized_plan_eq(expected, &plan, None);
+ assert_optimized_plan_eq(expected, plan, None);
// test: all mixed
let plan = LogicalPlanBuilder::from(table_scan)
@@ -1081,7 +1290,7 @@ mod test {
\n Projection: UInt32(1) + test.a AS __common_expr_1, test.a,
test.b, test.c\
\n TableScan: test";
- assert_optimized_plan_eq(expected, &plan, None);
+ assert_optimized_plan_eq(expected, plan, None);
Ok(())
}
@@ -1108,7 +1317,7 @@ mod test {
\n Projection: UInt32(1) + table.test.col.a AS __common_expr_1,
table.test.col.a\
\n TableScan: table.test";
- assert_optimized_plan_eq(expected, &plan, None);
+ assert_optimized_plan_eq(expected, plan, None);
Ok(())
}
@@ -1128,7 +1337,7 @@ mod test {
\n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b,
test.c\
\n TableScan: test";
- assert_optimized_plan_eq(expected, &plan, None);
+ assert_optimized_plan_eq(expected, plan, None);
Ok(())
}
@@ -1144,7 +1353,7 @@ mod test {
let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\
\n TableScan: test";
- assert_optimized_plan_eq(expected, &plan, None);
+ assert_non_optimized_plan_eq(expected, plan, None);
Ok(())
}
@@ -1162,7 +1371,7 @@ mod test {
\n Projection: Int32(1) + test.a, test.a\
\n TableScan: test";
- assert_optimized_plan_eq(expected, &plan, None);
+ assert_non_optimized_plan_eq(expected, plan, None);
Ok(())
}
@@ -1257,10 +1466,9 @@ mod test {
.build()
.unwrap();
let rule = CommonSubexprEliminate {};
- let optimized_plan = rule
- .try_optimize(&plan, &OptimizerContext::new())
- .unwrap()
- .unwrap();
+ let optimized_plan = rule.rewrite(plan,
&OptimizerContext::new()).unwrap();
+ assert!(!optimized_plan.transformed);
+ let optimized_plan = optimized_plan.data;
let schema = optimized_plan.schema();
let fields_with_datatypes: Vec<_> = schema
@@ -1299,7 +1507,7 @@ mod test {
\n Projection: Int32(1) + test.a AS __common_expr_1, test.a,
test.b, test.c\
\n TableScan: test";
- assert_optimized_plan_eq(expected, &plan, None);
+ assert_optimized_plan_eq(expected, plan, None);
Ok(())
}
@@ -1365,7 +1573,7 @@ mod test {
\n Projection: test.a + test.b AS __common_expr_1, test.c\
\n TableScan: test";
- assert_optimized_plan_eq(expected, &plan, Some(config));
+ assert_optimized_plan_eq(expected, plan, Some(config));
let config = &OptimizerContext::new();
let _common_expr_1 = config.alias_generator().next(CSE_PREFIX);
@@ -1388,7 +1596,7 @@ mod test {
\n Projection: test.a + test.b AS __common_expr_2, test.c\
\n TableScan: test";
- assert_optimized_plan_eq(expected, &plan, Some(config));
+ assert_optimized_plan_eq(expected, plan, Some(config));
Ok(())
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]