This is an automated email from the ASF dual-hosted git repository.
berkay pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new aef232b1ac Add `Container` trait and to simplify `Expr` and
`LogicalPlan` apply and map methods (#13467)
aef232b1ac is described below
commit aef232b1ac559fc1597a64d9f5f75c2f29f4c286
Author: Peter Toth <[email protected]>
AuthorDate: Wed Nov 20 08:29:54 2024 +0100
Add `Container` trait and to simplify `Expr` and `LogicalPlan` apply and
map methods (#13467)
* Add `Container` trait and its blanket implementations, remove
`map_until_stop_and_collect` macro, simplify apply and map logic with
`Container`s where possible
* fix clippy
* rename `Container` to `TreeNodeContainer`
* add docs to containers
* clarify when we need a temporary `TreeNodeRefContainer`
* code and docs cleanup
---
datafusion/common/src/tree_node.rs | 363 +++++++++++++++++---
datafusion/expr/src/expr.rs | 36 +-
datafusion/expr/src/logical_plan/ddl.rs | 50 ++-
datafusion/expr/src/logical_plan/plan.rs | 20 +-
datafusion/expr/src/logical_plan/statement.rs | 51 +--
datafusion/expr/src/logical_plan/tree_node.rs | 347 ++++++++-----------
datafusion/expr/src/tree_node.rs | 372 +++++++--------------
.../optimizer/src/optimize_projections/mod.rs | 4 +-
datafusion/sql/src/unparser/rewrite.rs | 24 +-
9 files changed, 687 insertions(+), 580 deletions(-)
diff --git a/datafusion/common/src/tree_node.rs
b/datafusion/common/src/tree_node.rs
index c8ec7f1833..0c153583e3 100644
--- a/datafusion/common/src/tree_node.rs
+++ b/datafusion/common/src/tree_node.rs
@@ -17,11 +17,12 @@
//! [`TreeNode`] for visiting and rewriting expression and plan trees
+use crate::Result;
use recursive::recursive;
+use std::collections::HashMap;
+use std::hash::Hash;
use std::sync::Arc;
-use crate::Result;
-
/// These macros are used to determine continuation during transforming
traversals.
macro_rules! handle_transform_recursion {
($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{
@@ -769,6 +770,297 @@ impl<T> Transformed<T> {
}
}
+/// [`TreeNodeContainer`] contains elements that a function can be applied on or mapped.
+/// The elements of the container are siblings so the continuation rules are similar to
+/// [`TreeNodeRecursion::visit_sibling`] / [`Transformed::transform_sibling`].
+pub trait TreeNodeContainer<'a, T: 'a>: Sized {
+    /// Applies `f` to all elements of the container.
+    /// This method is usually called from [`TreeNode::apply_children`] implementations as
+    /// a node is actually a container of the node's children.
+    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        f: F,
+    ) -> Result<TreeNodeRecursion>;
+
+    /// Maps all elements of the container with `f`.
+    /// This method is usually called from [`TreeNode::map_children`] implementations as
+    /// a node is actually a container of the node's children.
+    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
+        self,
+        f: F,
+    ) -> Result<Transformed<Self>>;
+}
+
+impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Box<C> {
+    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        f: F,
+    ) -> Result<TreeNodeRecursion> {
+        self.as_ref().apply_elements(f)
+    }
+
+    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
+        self,
+        f: F,
+    ) -> Result<Transformed<Self>> {
+        (*self).map_elements(f)?.map_data(|c| Ok(Self::new(c)))
+    }
+}
+
+impl<'a, T: 'a, C: TreeNodeContainer<'a, T> + Clone> TreeNodeContainer<'a, T> for Arc<C> {
+    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        f: F,
+    ) -> Result<TreeNodeRecursion> {
+        self.as_ref().apply_elements(f)
+    }
+
+    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
+        self,
+        f: F,
+    ) -> Result<Transformed<Self>> {
+        Arc::unwrap_or_clone(self)
+            .map_elements(f)?
+            .map_data(|c| Ok(Arc::new(c)))
+    }
+}
+
+impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Option<C> {
+    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        f: F,
+    ) -> Result<TreeNodeRecursion> {
+        match self {
+            Some(t) => t.apply_elements(f),
+            None => Ok(TreeNodeRecursion::Continue),
+        }
+    }
+
+    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
+        self,
+        f: F,
+    ) -> Result<Transformed<Self>> {
+        self.map_or(Ok(Transformed::no(None)), |c| {
+            c.map_elements(f)?.map_data(|c| Ok(Some(c)))
+        })
+    }
+}
+
+impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Vec<C> {
+    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        mut f: F,
+    ) -> Result<TreeNodeRecursion> {
+        let mut tnr = TreeNodeRecursion::Continue;
+        for c in self {
+            tnr = c.apply_elements(&mut f)?;
+            match tnr {
+                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {}
+                TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
+            }
+        }
+        Ok(tnr)
+    }
+
+    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
+        self,
+        mut f: F,
+    ) -> Result<Transformed<Self>> {
+        let mut tnr = TreeNodeRecursion::Continue;
+        let mut transformed = false;
+        self.into_iter()
+            .map(|c| match tnr {
+                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
+                    c.map_elements(&mut f).map(|result| {
+                        tnr = result.tnr;
+                        transformed |= result.transformed;
+                        result.data
+                    })
+                }
+                TreeNodeRecursion::Stop => Ok(c),
+            })
+            .collect::<Result<Vec<_>>>()
+            .map(|data| Transformed::new(data, transformed, tnr))
+    }
+}
+
+impl<'a, T: 'a, K: Eq + Hash, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T>
+    for HashMap<K, C>
+{
+    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        mut f: F,
+    ) -> Result<TreeNodeRecursion> {
+        let mut tnr = TreeNodeRecursion::Continue;
+        for c in self.values() {
+            tnr = c.apply_elements(&mut f)?;
+            match tnr {
+                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {}
+                TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
+            }
+        }
+        Ok(tnr)
+    }
+
+    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
+        self,
+        mut f: F,
+    ) -> Result<Transformed<Self>> {
+        let mut tnr = TreeNodeRecursion::Continue;
+        let mut transformed = false;
+        self.into_iter()
+            .map(|(k, c)| match tnr {
+                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
+                    c.map_elements(&mut f).map(|result| {
+                        tnr = result.tnr;
+                        transformed |= result.transformed;
+                        (k, result.data)
+                    })
+                }
+                TreeNodeRecursion::Stop => Ok((k, c)),
+            })
+            .collect::<Result<HashMap<_, _>>>()
+            .map(|data| Transformed::new(data, transformed, tnr))
+    }
+}
+
+impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>>
+    TreeNodeContainer<'a, T> for (C0, C1)
+{
+    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        mut f: F,
+    ) -> Result<TreeNodeRecursion> {
+        self.0
+            .apply_elements(&mut f)?
+            .visit_sibling(|| self.1.apply_elements(&mut f))
+    }
+
+    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
+        self,
+        mut f: F,
+    ) -> Result<Transformed<Self>> {
+        self.0
+            .map_elements(&mut f)?
+            .map_data(|new_c0| Ok((new_c0, self.1)))?
+            .transform_sibling(|(new_c0, c1)| {
+                c1.map_elements(&mut f)?
+                    .map_data(|new_c1| Ok((new_c0, new_c1)))
+            })
+    }
+}
+
+impl<
+        'a,
+        T: 'a,
+        C0: TreeNodeContainer<'a, T>,
+        C1: TreeNodeContainer<'a, T>,
+        C2: TreeNodeContainer<'a, T>,
+    > TreeNodeContainer<'a, T> for (C0, C1, C2)
+{
+    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        mut f: F,
+    ) -> Result<TreeNodeRecursion> {
+        self.0
+            .apply_elements(&mut f)?
+            .visit_sibling(|| self.1.apply_elements(&mut f))?
+            .visit_sibling(|| self.2.apply_elements(&mut f))
+    }
+
+    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
+        self,
+        mut f: F,
+    ) -> Result<Transformed<Self>> {
+        self.0
+            .map_elements(&mut f)?
+            .map_data(|new_c0| Ok((new_c0, self.1, self.2)))?
+            .transform_sibling(|(new_c0, c1, c2)| {
+                c1.map_elements(&mut f)?
+                    .map_data(|new_c1| Ok((new_c0, new_c1, c2)))
+            })?
+            .transform_sibling(|(new_c0, new_c1, c2)| {
+                c2.map_elements(&mut f)?
+                    .map_data(|new_c2| Ok((new_c0, new_c1, new_c2)))
+            })
+    }
+}
+
+/// [`TreeNodeRefContainer`] contains references to elements that a function can be
+/// applied on. The elements of the container are siblings so the continuation rules are
+/// similar to [`TreeNodeRecursion::visit_sibling`].
+///
+/// This container is similar to [`TreeNodeContainer`], but the lifetime of the reference
+/// elements (`T`) are not derived from the container's lifetime.
+/// A typical usage of this container is in `Expr::apply_children` when we need to
+/// construct a temporary container to be able to call `apply_ref_elements` on a
+/// collection of tree node references. But in that case the container's temporary
+/// lifetime is different to the lifetime of tree nodes that we put into it.
+/// Please find an example usecase in `Expr::apply_children` with the `Expr::Case` case.
+///
+/// Most of the cases we don't need to create a temporary container with
+/// `TreeNodeRefContainer`, but we can just call `TreeNodeContainer::apply_elements`.
+/// Please find an example usecase in `Expr::apply_children` with the `Expr::GroupingSet`
+/// case.
+pub trait TreeNodeRefContainer<'a, T: 'a>: Sized {
+    /// Applies `f` to all elements of the container.
+    /// This method is usually called from [`TreeNode::apply_children`] implementations as
+    /// a node is actually a container of the node's children.
+    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &self,
+        f: F,
+    ) -> Result<TreeNodeRecursion>;
+}
+
+impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeRefContainer<'a, T> for Vec<&'a C> {
+    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &self,
+        mut f: F,
+    ) -> Result<TreeNodeRecursion> {
+        let mut tnr = TreeNodeRecursion::Continue;
+        for c in self {
+            tnr = c.apply_elements(&mut f)?;
+            match tnr {
+                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {}
+                TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
+            }
+        }
+        Ok(tnr)
+    }
+}
+
+impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>>
+    TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1)
+{
+    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &self,
+        mut f: F,
+    ) -> Result<TreeNodeRecursion> {
+        self.0
+            .apply_elements(&mut f)?
+            .visit_sibling(|| self.1.apply_elements(&mut f))
+    }
+}
+
+impl<
+        'a,
+        T: 'a,
+        C0: TreeNodeContainer<'a, T>,
+        C1: TreeNodeContainer<'a, T>,
+        C2: TreeNodeContainer<'a, T>,
+    > TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2)
+{
+    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &self,
+        mut f: F,
+    ) -> Result<TreeNodeRecursion> {
+        self.0
+            .apply_elements(&mut f)?
+            .visit_sibling(|| self.1.apply_elements(&mut f))?
+            .visit_sibling(|| self.2.apply_elements(&mut f))
+    }
+}
+
/// Transformation helper to process a sequence of iterable tree nodes that
are siblings.
pub trait TreeNodeIterator: Iterator {
/// Apples `f` to each item in this iterator
@@ -843,50 +1135,6 @@ impl<I: Iterator> TreeNodeIterator for I {
}
}
-/// Transformation helper to process a heterogeneous sequence of tree node containing
-/// expressions.
-///
-/// This macro is very similar to [TreeNodeIterator::map_until_stop_and_collect] to
-/// process nodes that are siblings, but it accepts an initial transformation (`F0`) and
-/// a sequence of pairs. Each pair is made of an expression (`EXPR`) and its
-/// transformation (`F`).
-///
-/// The macro builds up a tuple that contains `Transformed.data` result of `F0` as the
-/// first element and further elements from the sequence of pairs. An element from a pair
-/// is either the value of `EXPR` or the `Transformed.data` result of `F`, depending on
-/// the `Transformed.tnr` result of previous `F`s (`F0` initially).
-///
-/// # Returns
-/// Error if any of the transformations returns an error
-///
-/// Ok(Transformed<(data0, ..., dataN)>) such that:
-/// 1. `transformed` is true if any of the transformations had transformed true
-/// 2. `(data0, ..., dataN)`, where `data0` is the `Transformed.data` from `F0` and
-/// `data1` ... `dataN` are from either `EXPR` or the `Transformed.data` of `F`
-/// 3. `tnr` from `F0` or the last invocation of `F`
-#[macro_export]
-macro_rules! map_until_stop_and_collect {
-    ($F0:expr, $($EXPR:expr, $F:expr),*) => {{
-        $F0.and_then(|Transformed { data: data0, mut transformed, mut tnr }| {
-            let all_datas = (
-                data0,
-                $(
-                    if tnr == TreeNodeRecursion::Continue || tnr == TreeNodeRecursion::Jump {
-                        $F.map(|result| {
-                            tnr = result.tnr;
-                            transformed |= result.transformed;
-                            result.data
-                        })?
-                    } else {
-                        $EXPR
-                    },
-                )*
-            );
-            Ok(Transformed::new(all_datas, transformed, tnr))
-        })
-    }}
-}
-
/// Transformation helper to access [`Transformed`] fields in a [`Result`]
easily.
///
/// # Example
@@ -1021,7 +1269,7 @@ pub(crate) mod tests {
use std::fmt::Display;
use crate::tree_node::{
- Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion,
TreeNodeRewriter,
+ Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
TreeNodeRewriter,
TreeNodeVisitor,
};
use crate::Result;
@@ -1054,7 +1302,7 @@ pub(crate) mod tests {
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
- self.children.iter().apply_until_stop(f)
+ self.children.apply_elements(f)
}
fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
@@ -1063,8 +1311,7 @@ pub(crate) mod tests {
) -> Result<Transformed<Self>> {
Ok(self
.children
- .into_iter()
- .map_until_stop_and_collect(f)?
+ .map_elements(f)?
.update_data(|new_children| Self {
children: new_children,
..self
@@ -1072,6 +1319,22 @@ pub(crate) mod tests {
}
}
+    impl<'a, T: 'a> TreeNodeContainer<'a, Self> for TestTreeNode<T> {
+        fn apply_elements<F: FnMut(&'a Self) -> Result<TreeNodeRecursion>>(
+            &'a self,
+            mut f: F,
+        ) -> Result<TreeNodeRecursion> {
+            f(self)
+        }
+
+        fn map_elements<F: FnMut(Self) -> Result<Transformed<Self>>>(
+            self,
+            mut f: F,
+        ) -> Result<Transformed<Self>> {
+            f(self)
+        }
+    }
+
// J
// |
// I
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 83d35c3d25..8490c08a70 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -32,7 +32,7 @@ use crate::{udaf, ExprSchemable, Operator, Signature,
WindowFrame, WindowUDF};
use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::cse::HashNode;
use datafusion_common::tree_node::{
- Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
+ Transformed, TransformedResult, TreeNode, TreeNodeContainer,
TreeNodeRecursion,
};
use datafusion_common::{
plan_err, Column, DFSchema, HashMap, Result, ScalarValue, TableReference,
@@ -351,6 +351,22 @@ impl<'a> From<(Option<&'a TableReference>, &'a FieldRef)>
for Expr {
}
}
+impl<'a> TreeNodeContainer<'a, Self> for Expr {
+    fn apply_elements<F: FnMut(&'a Self) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        mut f: F,
+    ) -> Result<TreeNodeRecursion> {
+        f(self)
+    }
+
+    fn map_elements<F: FnMut(Self) -> Result<Transformed<Self>>>(
+        self,
+        mut f: F,
+    ) -> Result<Transformed<Self>> {
+        f(self)
+    }
+}
+
/// UNNEST expression.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct Unnest {
@@ -653,6 +669,24 @@ impl Display for Sort {
}
}
+impl<'a> TreeNodeContainer<'a, Expr> for Sort {
+    fn apply_elements<F: FnMut(&'a Expr) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        f: F,
+    ) -> Result<TreeNodeRecursion> {
+        self.expr.apply_elements(f)
+    }
+
+    fn map_elements<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
+        self,
+        f: F,
+    ) -> Result<Transformed<Self>> {
+        self.expr
+            .map_elements(f)?
+            .map_data(|expr| Ok(Self { expr, ..self }))
+    }
+}
+
/// Aggregate function
///
/// See also [`ExprFunctionExt`] to set these fields on `Expr`
diff --git a/datafusion/expr/src/logical_plan/ddl.rs
b/datafusion/expr/src/logical_plan/ddl.rs
index 93e8b5fd04..8c64a01798 100644
--- a/datafusion/expr/src/logical_plan/ddl.rs
+++ b/datafusion/expr/src/logical_plan/ddl.rs
@@ -26,7 +26,10 @@ use std::{
use crate::expr::Sort;
use arrow::datatypes::DataType;
-use datafusion_common::{Constraints, DFSchemaRef, SchemaReference,
TableReference};
+use datafusion_common::tree_node::{Transformed, TreeNodeContainer,
TreeNodeRecursion};
+use datafusion_common::{
+ Constraints, DFSchemaRef, Result, SchemaReference, TableReference,
+};
use sqlparser::ast::Ident;
/// Various types of DDL (CREATE / DROP) catalog manipulation
@@ -487,6 +490,28 @@ pub struct OperateFunctionArg {
pub data_type: DataType,
pub default_expr: Option<Expr>,
}
+
+impl<'a> TreeNodeContainer<'a, Expr> for OperateFunctionArg {
+    fn apply_elements<F: FnMut(&'a Expr) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        f: F,
+    ) -> Result<TreeNodeRecursion> {
+        self.default_expr.apply_elements(f)
+    }
+
+    fn map_elements<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
+        self,
+        f: F,
+    ) -> Result<Transformed<Self>> {
+        self.default_expr.map_elements(f)?.map_data(|default_expr| {
+            Ok(Self {
+                default_expr,
+                ..self
+            })
+        })
+    }
+}
+
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct CreateFunctionBody {
/// LANGUAGE lang_name
@@ -497,6 +522,29 @@ pub struct CreateFunctionBody {
pub function_body: Option<Expr>,
}
+impl<'a> TreeNodeContainer<'a, Expr> for CreateFunctionBody {
+    fn apply_elements<F: FnMut(&'a Expr) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        f: F,
+    ) -> Result<TreeNodeRecursion> {
+        self.function_body.apply_elements(f)
+    }
+
+    fn map_elements<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
+        self,
+        f: F,
+    ) -> Result<Transformed<Self>> {
+        self.function_body
+            .map_elements(f)?
+            .map_data(|function_body| {
+                Ok(Self {
+                    function_body,
+                    ..self
+                })
+            })
+    }
+}
+
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct DropFunction {
pub name: String,
diff --git a/datafusion/expr/src/logical_plan/plan.rs
b/datafusion/expr/src/logical_plan/plan.rs
index 6ee99b22c7..e9f4f1f809 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -45,7 +45,9 @@ use crate::{
};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
+use datafusion_common::tree_node::{
+ Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
+};
use datafusion_common::{
aggregate_functional_dependencies, internal_err, plan_err, Column,
Constraints,
DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence,
@@ -287,6 +289,22 @@ impl Default for LogicalPlan {
}
}
+impl<'a> TreeNodeContainer<'a, Self> for LogicalPlan {
+    fn apply_elements<F: FnMut(&'a Self) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        mut f: F,
+    ) -> Result<TreeNodeRecursion> {
+        f(self)
+    }
+
+    fn map_elements<F: FnMut(Self) -> Result<Transformed<Self>>>(
+        self,
+        mut f: F,
+    ) -> Result<Transformed<Self>> {
+        f(self)
+    }
+}
+
impl LogicalPlan {
/// Get a reference to the logical plan's schema
pub fn schema(&self) -> &DFSchemaRef {
diff --git a/datafusion/expr/src/logical_plan/statement.rs
b/datafusion/expr/src/logical_plan/statement.rs
index 05e2b1af14..26df379f5e 100644
--- a/datafusion/expr/src/logical_plan/statement.rs
+++ b/datafusion/expr/src/logical_plan/statement.rs
@@ -16,12 +16,10 @@
// under the License.
use arrow::datatypes::DataType;
-use datafusion_common::tree_node::{Transformed, TreeNodeIterator};
-use datafusion_common::{DFSchema, DFSchemaRef, Result};
+use datafusion_common::{DFSchema, DFSchemaRef};
use std::fmt::{self, Display};
use std::sync::{Arc, OnceLock};
-use super::tree_node::rewrite_arc;
use crate::{expr_vec_fmt, Expr, LogicalPlan};
/// Statements have a unchanging empty schema.
@@ -80,53 +78,6 @@ impl Statement {
}
}
-    /// Rewrites input LogicalPlans in the current `Statement` using `f`.
-    pub(super) fn map_inputs<
-        F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
-    >(
-        self,
-        f: F,
-    ) -> Result<Transformed<Self>> {
-        match self {
-            Statement::Prepare(Prepare {
-                input,
-                name,
-                data_types,
-            }) => Ok(rewrite_arc(input, f)?.update_data(|input| {
-                Statement::Prepare(Prepare {
-                    input,
-                    name,
-                    data_types,
-                })
-            })),
-            _ => Ok(Transformed::no(self)),
-        }
-    }
-
-    /// Returns a iterator over all expressions in the current `Statement`.
-    pub(super) fn expression_iter(&self) -> impl Iterator<Item = &Expr> {
-        match self {
-            Statement::Execute(Execute { parameters, .. }) => parameters.iter(),
-            _ => [].iter(),
-        }
-    }
-
-    /// Rewrites all expressions in the current `Statement` using `f`.
-    pub(super) fn map_expressions<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
-        self,
-        f: F,
-    ) -> Result<Transformed<Self>> {
-        match self {
-            Statement::Execute(Execute { name, parameters }) => Ok(parameters
-                .into_iter()
-                .map_until_stop_and_collect(f)?
-                .update_data(|parameters| {
-                    Statement::Execute(Execute { parameters, name })
-                })),
-            _ => Ok(Transformed::no(self)),
-        }
-    }
-
/// Return a `format`able structure with the a human readable
/// description of this LogicalPlan node per node, not including
/// children.
diff --git a/datafusion/expr/src/logical_plan/tree_node.rs
b/datafusion/expr/src/logical_plan/tree_node.rs
index e7dfe87919..6850c30f4f 100644
--- a/datafusion/expr/src/logical_plan/tree_node.rs
+++ b/datafusion/expr/src/logical_plan/tree_node.rs
@@ -36,32 +36,30 @@
//! (Re)creation APIs (these require substantial cloning and thus are slow):
//! * [`LogicalPlan::with_new_exprs`]: Create a new plan with different
expressions
//! * [`LogicalPlan::expressions`]: Return a copy of the plan's expressions
+
use crate::{
dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView,
DdlStatement,
- Distinct, DistinctOn, DmlStatement, Explain, Expr, Extension, Filter,
Join, Limit,
- LogicalPlan, Partitioning, Projection, RecursiveQuery, Repartition, Sort,
Subquery,
- SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode, Values,
Window,
+ Distinct, DistinctOn, DmlStatement, Execute, Explain, Expr, Extension,
Filter, Join,
+ Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery,
Repartition,
+ Sort, Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest,
+ UserDefinedLogicalNode, Values, Window,
};
+use datafusion_common::tree_node::TreeNodeRefContainer;
use recursive::recursive;
-use std::ops::Deref;
-use std::sync::Arc;
use crate::expr::{Exists, InSubquery};
-use crate::tree_node::{transform_sort_option_vec, transform_sort_vec};
use datafusion_common::tree_node::{
- Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion,
TreeNodeRewriter,
- TreeNodeVisitor,
-};
-use datafusion_common::{
- internal_err, map_until_stop_and_collect, DataFusionError, Result,
+ Transformed, TreeNode, TreeNodeContainer, TreeNodeIterator,
TreeNodeRecursion,
+ TreeNodeRewriter, TreeNodeVisitor,
};
+use datafusion_common::{internal_err, Result};
impl TreeNode for LogicalPlan {
fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
- self.inputs().into_iter().apply_until_stop(f)
+ self.inputs().apply_ref_elements(f)
}
/// Applies `f` to each child (input) of this plan node, rewriting them
*in place.*
@@ -74,14 +72,14 @@ impl TreeNode for LogicalPlan {
/// [`Expr::Exists`]: crate::Expr::Exists
fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
- mut f: F,
+ f: F,
) -> Result<Transformed<Self>> {
Ok(match self {
LogicalPlan::Projection(Projection {
expr,
input,
schema,
- }) => rewrite_arc(input, f)?.update_data(|input| {
+ }) => input.map_elements(f)?.update_data(|input| {
LogicalPlan::Projection(Projection {
expr,
input,
@@ -92,7 +90,7 @@ impl TreeNode for LogicalPlan {
predicate,
input,
having,
- }) => rewrite_arc(input, f)?.update_data(|input| {
+ }) => input.map_elements(f)?.update_data(|input| {
LogicalPlan::Filter(Filter {
predicate,
input,
@@ -102,7 +100,7 @@ impl TreeNode for LogicalPlan {
LogicalPlan::Repartition(Repartition {
input,
partitioning_scheme,
- }) => rewrite_arc(input, f)?.update_data(|input| {
+ }) => input.map_elements(f)?.update_data(|input| {
LogicalPlan::Repartition(Repartition {
input,
partitioning_scheme,
@@ -112,7 +110,7 @@ impl TreeNode for LogicalPlan {
input,
window_expr,
schema,
- }) => rewrite_arc(input, f)?.update_data(|input| {
+ }) => input.map_elements(f)?.update_data(|input| {
LogicalPlan::Window(Window {
input,
window_expr,
@@ -124,7 +122,7 @@ impl TreeNode for LogicalPlan {
group_expr,
aggr_expr,
schema,
- }) => rewrite_arc(input, f)?.update_data(|input| {
+ }) => input.map_elements(f)?.update_data(|input| {
LogicalPlan::Aggregate(Aggregate {
input,
group_expr,
@@ -132,7 +130,8 @@ impl TreeNode for LogicalPlan {
schema,
})
}),
- LogicalPlan::Sort(Sort { expr, input, fetch }) =>
rewrite_arc(input, f)?
+ LogicalPlan::Sort(Sort { expr, input, fetch }) => input
+ .map_elements(f)?
.update_data(|input| LogicalPlan::Sort(Sort { expr, input,
fetch })),
LogicalPlan::Join(Join {
left,
@@ -143,12 +142,7 @@ impl TreeNode for LogicalPlan {
join_constraint,
schema,
null_equals_null,
- }) => map_until_stop_and_collect!(
- rewrite_arc(left, &mut f),
- right,
- rewrite_arc(right, &mut f)
- )?
- .update_data(|(left, right)| {
+ }) => (left, right).map_elements(f)?.update_data(|(left, right)| {
LogicalPlan::Join(Join {
left,
right,
@@ -160,12 +154,13 @@ impl TreeNode for LogicalPlan {
null_equals_null,
})
}),
- LogicalPlan::Limit(Limit { skip, fetch, input }) =>
rewrite_arc(input, f)?
+ LogicalPlan::Limit(Limit { skip, fetch, input }) => input
+ .map_elements(f)?
.update_data(|input| LogicalPlan::Limit(Limit { skip, fetch,
input })),
LogicalPlan::Subquery(Subquery {
subquery,
outer_ref_columns,
- }) => rewrite_arc(subquery, f)?.update_data(|subquery| {
+ }) => subquery.map_elements(f)?.update_data(|subquery| {
LogicalPlan::Subquery(Subquery {
subquery,
outer_ref_columns,
@@ -175,7 +170,7 @@ impl TreeNode for LogicalPlan {
input,
alias,
schema,
- }) => rewrite_arc(input, f)?.update_data(|input| {
+ }) => input.map_elements(f)?.update_data(|input| {
LogicalPlan::SubqueryAlias(SubqueryAlias {
input,
alias,
@@ -184,17 +179,18 @@ impl TreeNode for LogicalPlan {
}),
LogicalPlan::Extension(extension) =>
rewrite_extension_inputs(extension, f)?
.update_data(LogicalPlan::Extension),
- LogicalPlan::Union(Union { inputs, schema }) =>
rewrite_arcs(inputs, f)?
+ LogicalPlan::Union(Union { inputs, schema }) => inputs
+ .map_elements(f)?
.update_data(|inputs| LogicalPlan::Union(Union { inputs,
schema })),
LogicalPlan::Distinct(distinct) => match distinct {
- Distinct::All(input) => rewrite_arc(input,
f)?.update_data(Distinct::All),
+ Distinct::All(input) =>
input.map_elements(f)?.update_data(Distinct::All),
Distinct::On(DistinctOn {
on_expr,
select_expr,
sort_expr,
input,
schema,
- }) => rewrite_arc(input, f)?.update_data(|input| {
+ }) => input.map_elements(f)?.update_data(|input| {
Distinct::On(DistinctOn {
on_expr,
select_expr,
@@ -211,7 +207,7 @@ impl TreeNode for LogicalPlan {
stringified_plans,
schema,
logical_optimization_succeeded,
- }) => rewrite_arc(plan, f)?.update_data(|plan| {
+ }) => plan.map_elements(f)?.update_data(|plan| {
LogicalPlan::Explain(Explain {
verbose,
plan,
@@ -224,7 +220,7 @@ impl TreeNode for LogicalPlan {
verbose,
input,
schema,
- }) => rewrite_arc(input, f)?.update_data(|input| {
+ }) => input.map_elements(f)?.update_data(|input| {
LogicalPlan::Analyze(Analyze {
verbose,
input,
@@ -237,7 +233,7 @@ impl TreeNode for LogicalPlan {
op,
input,
output_schema,
- }) => rewrite_arc(input, f)?.update_data(|input| {
+ }) => input.map_elements(f)?.update_data(|input| {
LogicalPlan::Dml(DmlStatement {
table_name,
table_schema,
@@ -252,7 +248,7 @@ impl TreeNode for LogicalPlan {
partition_by,
file_type,
options,
- }) => rewrite_arc(input, f)?.update_data(|input| {
+ }) => input.map_elements(f)?.update_data(|input| {
LogicalPlan::Copy(CopyTo {
input,
output_url,
@@ -271,7 +267,7 @@ impl TreeNode for LogicalPlan {
or_replace,
column_defaults,
temporary,
- }) => rewrite_arc(input, f)?.update_data(|input| {
+ }) => input.map_elements(f)?.update_data(|input| {
DdlStatement::CreateMemoryTable(CreateMemoryTable {
name,
constraints,
@@ -288,7 +284,7 @@ impl TreeNode for LogicalPlan {
or_replace,
definition,
temporary,
- }) => rewrite_arc(input, f)?.update_data(|input| {
+ }) => input.map_elements(f)?.update_data(|input| {
DdlStatement::CreateView(CreateView {
name,
input,
@@ -318,7 +314,7 @@ impl TreeNode for LogicalPlan {
dependency_indices,
schema,
options,
- }) => rewrite_arc(input, f)?.update_data(|input| {
+ }) => input.map_elements(f)?.update_data(|input| {
LogicalPlan::Unnest(Unnest {
input,
exec_columns: input_columns,
@@ -334,22 +330,24 @@ impl TreeNode for LogicalPlan {
static_term,
recursive_term,
is_distinct,
- }) => map_until_stop_and_collect!(
- rewrite_arc(static_term, &mut f),
- recursive_term,
- rewrite_arc(recursive_term, &mut f)
- )?
- .update_data(|(static_term, recursive_term)| {
- LogicalPlan::RecursiveQuery(RecursiveQuery {
- name,
- static_term,
- recursive_term,
- is_distinct,
- })
- }),
- LogicalPlan::Statement(stmt) => {
- stmt.map_inputs(f)?.update_data(LogicalPlan::Statement)
+ }) => (static_term, recursive_term).map_elements(f)?.update_data(
+ |(static_term, recursive_term)| {
+ LogicalPlan::RecursiveQuery(RecursiveQuery {
+ name,
+ static_term,
+ recursive_term,
+ is_distinct,
+ })
+ },
+ ),
+ LogicalPlan::Statement(stmt) => match stmt {
+ Statement::Prepare(p) => p
+ .input
+ .map_elements(f)?
+ .update_data(|input| Statement::Prepare(Prepare { input,
..p })),
+ _ => Transformed::no(stmt),
}
+ .update_data(LogicalPlan::Statement),
// plans without inputs
LogicalPlan::TableScan { .. }
| LogicalPlan::EmptyRelation { .. }
@@ -359,24 +357,6 @@ impl TreeNode for LogicalPlan {
}
}
-/// Applies `f` to rewrite a `Arc<LogicalPlan>` without copying, if possible
-pub(super) fn rewrite_arc<F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>>(
-    plan: Arc<LogicalPlan>,
-    mut f: F,
-) -> Result<Transformed<Arc<LogicalPlan>>> {
-    f(Arc::unwrap_or_clone(plan))?.map_data(|new_plan| Ok(Arc::new(new_plan)))
-}
-
-/// rewrite a `Vec` of `Arc<LogicalPlan>` without copying, if possible
-fn rewrite_arcs<F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>>(
-    input_plans: Vec<Arc<LogicalPlan>>,
-    mut f: F,
-) -> Result<Transformed<Vec<Arc<LogicalPlan>>>> {
-    input_plans
-        .into_iter()
-        .map_until_stop_and_collect(|plan| rewrite_arc(plan, &mut f))
-}
-
/// Rewrites all inputs for an Extension node "in place"
/// (it currently has to copy values because there are no APIs for in place
modification)
///
@@ -423,54 +403,40 @@ impl LogicalPlan {
mut f: F,
) -> Result<TreeNodeRecursion> {
match self {
- LogicalPlan::Projection(Projection { expr, .. }) => {
- expr.iter().apply_until_stop(f)
- }
- LogicalPlan::Values(Values { values, .. }) => values
- .iter()
- .apply_until_stop(|value| value.iter().apply_until_stop(&mut
f)),
+ LogicalPlan::Projection(Projection { expr, .. }) =>
expr.apply_elements(f),
+ LogicalPlan::Values(Values { values, .. }) =>
values.apply_elements(f),
LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate),
LogicalPlan::Repartition(Repartition {
partitioning_scheme,
..
}) => match partitioning_scheme {
Partitioning::Hash(expr, _) | Partitioning::DistributeBy(expr)
=> {
- expr.iter().apply_until_stop(f)
+ expr.apply_elements(f)
}
Partitioning::RoundRobinBatch(_) =>
Ok(TreeNodeRecursion::Continue),
},
LogicalPlan::Window(Window { window_expr, .. }) => {
- window_expr.iter().apply_until_stop(f)
+ window_expr.apply_elements(f)
}
LogicalPlan::Aggregate(Aggregate {
group_expr,
aggr_expr,
..
- }) => group_expr
- .iter()
- .chain(aggr_expr.iter())
- .apply_until_stop(f),
+ }) => (group_expr, aggr_expr).apply_ref_elements(f),
// There are two part of expression for join, equijoin(on) and
non-equijoin(filter).
// 1. the first part is `on.len()` equijoin expressions, and the
struct of each expr is `left-on = right-on`.
// 2. the second part is non-equijoin(filter).
LogicalPlan::Join(Join { on, filter, .. }) => {
- on.iter()
- // TODO: why we need to create an `Expr::eq`? Cloning
`Expr` is costly...
- // it not ideal to create an expr here to analyze them,
but could cache it on the Join itself
- .map(|(l, r)| Expr::eq(l.clone(), r.clone()))
- .apply_until_stop(|e| f(&e))?
- .visit_sibling(|| filter.iter().apply_until_stop(f))
- }
- LogicalPlan::Sort(Sort { expr, .. }) => {
- expr.iter().apply_until_stop(|sort| f(&sort.expr))
+ (on, filter).apply_ref_elements(f)
}
+ LogicalPlan::Sort(Sort { expr, .. }) => expr.apply_elements(f),
LogicalPlan::Extension(extension) => {
// would be nice to avoid this copy -- maybe can
// update extension to just observer Exprs
- extension.node.expressions().iter().apply_until_stop(f)
+ extension.node.expressions().apply_elements(f)
}
LogicalPlan::TableScan(TableScan { filters, .. }) => {
- filters.iter().apply_until_stop(f)
+ filters.apply_elements(f)
}
LogicalPlan::Unnest(unnest) => {
let columns = unnest.exec_columns.clone();
@@ -479,24 +445,23 @@ impl LogicalPlan {
.iter()
.map(|c| Expr::Column(c.clone()))
.collect::<Vec<_>>();
- exprs.iter().apply_until_stop(f)
+ exprs.apply_elements(f)
}
LogicalPlan::Distinct(Distinct::On(DistinctOn {
on_expr,
select_expr,
sort_expr,
..
- })) => on_expr
- .iter()
- .chain(select_expr.iter())
- .chain(sort_expr.iter().flatten().map(|sort| &sort.expr))
- .apply_until_stop(f),
- LogicalPlan::Limit(Limit { skip, fetch, .. }) => skip
- .iter()
- .chain(fetch.iter())
- .map(|e| e.deref())
- .apply_until_stop(f),
- LogicalPlan::Statement(stmt) =>
stmt.expression_iter().apply_until_stop(f),
+ })) => (on_expr, select_expr, sort_expr).apply_ref_elements(f),
+ LogicalPlan::Limit(Limit { skip, fetch, .. }) => {
+ (skip, fetch).apply_ref_elements(f)
+ }
+ LogicalPlan::Statement(stmt) => match stmt {
+ Statement::Execute(Execute { parameters, .. }) => {
+ parameters.apply_elements(f)
+ }
+ _ => Ok(TreeNodeRecursion::Continue),
+ },
// plans without expressions
LogicalPlan::EmptyRelation(_)
| LogicalPlan::RecursiveQuery(_)
@@ -529,21 +494,15 @@ impl LogicalPlan {
expr,
input,
schema,
- }) => expr
- .into_iter()
- .map_until_stop_and_collect(f)?
- .update_data(|expr| {
- LogicalPlan::Projection(Projection {
- expr,
- input,
- schema,
- })
- }),
+ }) => expr.map_elements(f)?.update_data(|expr| {
+ LogicalPlan::Projection(Projection {
+ expr,
+ input,
+ schema,
+ })
+ }),
LogicalPlan::Values(Values { schema, values }) => values
- .into_iter()
- .map_until_stop_and_collect(|value| {
- value.into_iter().map_until_stop_and_collect(&mut f)
- })?
+ .map_elements(f)?
.update_data(|values| LogicalPlan::Values(Values { schema,
values })),
LogicalPlan::Filter(Filter {
predicate,
@@ -561,12 +520,10 @@ impl LogicalPlan {
partitioning_scheme,
}) => match partitioning_scheme {
Partitioning::Hash(expr, usize) => expr
- .into_iter()
- .map_until_stop_and_collect(f)?
+ .map_elements(f)?
.update_data(|expr| Partitioning::Hash(expr, usize)),
Partitioning::DistributeBy(expr) => expr
- .into_iter()
- .map_until_stop_and_collect(f)?
+ .map_elements(f)?
.update_data(Partitioning::DistributeBy),
Partitioning::RoundRobinBatch(_) =>
Transformed::no(partitioning_scheme),
}
@@ -580,34 +537,28 @@ impl LogicalPlan {
input,
window_expr,
schema,
- }) => window_expr
- .into_iter()
- .map_until_stop_and_collect(f)?
- .update_data(|window_expr| {
- LogicalPlan::Window(Window {
- input,
- window_expr,
- schema,
- })
- }),
+ }) => window_expr.map_elements(f)?.update_data(|window_expr| {
+ LogicalPlan::Window(Window {
+ input,
+ window_expr,
+ schema,
+ })
+ }),
LogicalPlan::Aggregate(Aggregate {
input,
group_expr,
aggr_expr,
schema,
- }) => map_until_stop_and_collect!(
- group_expr.into_iter().map_until_stop_and_collect(&mut f),
- aggr_expr,
- aggr_expr.into_iter().map_until_stop_and_collect(&mut f)
- )?
- .update_data(|(group_expr, aggr_expr)| {
- LogicalPlan::Aggregate(Aggregate {
- input,
- group_expr,
- aggr_expr,
- schema,
- })
- }),
+ }) => (group_expr, aggr_expr).map_elements(f)?.update_data(
+ |(group_expr, aggr_expr)| {
+ LogicalPlan::Aggregate(Aggregate {
+ input,
+ group_expr,
+ aggr_expr,
+ schema,
+ })
+ },
+ ),
// There are two part of expression for join, equijoin(on) and
non-equijoin(filter).
// 1. the first part is `on.len()` equijoin expressions, and the
struct of each expr is `left-on = right-on`.
@@ -621,16 +572,7 @@ impl LogicalPlan {
join_constraint,
schema,
null_equals_null,
- }) => map_until_stop_and_collect!(
- on.into_iter().map_until_stop_and_collect(
- |on| map_until_stop_and_collect!(f(on.0), on.1, f(on.1))
- ),
- filter,
- filter.map_or(Ok::<_, DataFusionError>(Transformed::no(None)),
|e| {
- Ok(f(e)?.update_data(Some))
- })
- )?
- .update_data(|(on, filter)| {
+ }) => (on, filter).map_elements(f)?.update_data(|(on, filter)| {
LogicalPlan::Join(Join {
left,
right,
@@ -642,17 +584,13 @@ impl LogicalPlan {
null_equals_null,
})
}),
- LogicalPlan::Sort(Sort { expr, input, fetch }) => {
- transform_sort_vec(expr, &mut f)?
- .update_data(|expr| LogicalPlan::Sort(Sort { expr, input,
fetch }))
- }
+ LogicalPlan::Sort(Sort { expr, input, fetch }) => expr
+ .map_elements(f)?
+ .update_data(|expr| LogicalPlan::Sort(Sort { expr, input,
fetch })),
LogicalPlan::Extension(Extension { node }) => {
// would be nice to avoid this copy -- maybe can
// update extension to just observer Exprs
- let exprs = node
- .expressions()
- .into_iter()
- .map_until_stop_and_collect(f)?;
+ let exprs = node.expressions().map_elements(f)?;
let plan = LogicalPlan::Extension(Extension {
node: UserDefinedLogicalNode::with_exprs_and_inputs(
node.as_ref(),
@@ -669,64 +607,47 @@ impl LogicalPlan {
projected_schema,
filters,
fetch,
- }) => filters
- .into_iter()
- .map_until_stop_and_collect(f)?
- .update_data(|filters| {
- LogicalPlan::TableScan(TableScan {
- table_name,
- source,
- projection,
- projected_schema,
- filters,
- fetch,
- })
- }),
+ }) => filters.map_elements(f)?.update_data(|filters| {
+ LogicalPlan::TableScan(TableScan {
+ table_name,
+ source,
+ projection,
+ projected_schema,
+ filters,
+ fetch,
+ })
+ }),
LogicalPlan::Distinct(Distinct::On(DistinctOn {
on_expr,
select_expr,
sort_expr,
input,
schema,
- })) => map_until_stop_and_collect!(
- on_expr.into_iter().map_until_stop_and_collect(&mut f),
- select_expr,
- select_expr.into_iter().map_until_stop_and_collect(&mut f),
- sort_expr,
- transform_sort_option_vec(sort_expr, &mut f)
- )?
- .update_data(|(on_expr, select_expr, sort_expr)| {
- LogicalPlan::Distinct(Distinct::On(DistinctOn {
- on_expr,
- select_expr,
- sort_expr,
- input,
- schema,
- }))
- }),
- LogicalPlan::Limit(Limit { skip, fetch, input }) => {
- let skip = skip.map(|e| *e);
- let fetch = fetch.map(|e| *e);
- map_until_stop_and_collect!(
- skip.map_or(Ok::<_,
DataFusionError>(Transformed::no(None)), |e| {
- Ok(f(e)?.update_data(Some))
- }),
- fetch,
- fetch.map_or(Ok::<_,
DataFusionError>(Transformed::no(None)), |e| {
- Ok(f(e)?.update_data(Some))
- })
- )?
- .update_data(|(skip, fetch)| {
- LogicalPlan::Limit(Limit {
- skip: skip.map(Box::new),
- fetch: fetch.map(Box::new),
+ })) => (on_expr, select_expr, sort_expr)
+ .map_elements(f)?
+ .update_data(|(on_expr, select_expr, sort_expr)| {
+ LogicalPlan::Distinct(Distinct::On(DistinctOn {
+ on_expr,
+ select_expr,
+ sort_expr,
input,
- })
+ schema,
+ }))
+ }),
+ LogicalPlan::Limit(Limit { skip, fetch, input }) => {
+ (skip, fetch).map_elements(f)?.update_data(|(skip, fetch)| {
+ LogicalPlan::Limit(Limit { skip, fetch, input })
})
}
- LogicalPlan::Statement(stmt) => {
- stmt.map_expressions(f)?.update_data(LogicalPlan::Statement)
+ LogicalPlan::Statement(stmt) => match stmt {
+ Statement::Execute(e) => {
+ e.parameters.map_elements(f)?.update_data(|parameters| {
+ Statement::Execute(Execute { parameters, ..e })
+ })
+ }
+ _ => Transformed::no(stmt),
}
+ .update_data(LogicalPlan::Statement),
// plans without expressions
LogicalPlan::EmptyRelation(_)
| LogicalPlan::Unnest(_)
diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs
index e964091aae..eacace5ed0 100644
--- a/datafusion/expr/src/tree_node.rs
+++ b/datafusion/expr/src/tree_node.rs
@@ -19,14 +19,14 @@
use crate::expr::{
AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet,
InList,
- InSubquery, Like, Placeholder, ScalarFunction, Sort, TryCast, Unnest,
WindowFunction,
+ InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest,
WindowFunction,
};
use crate::{Expr, ExprFunctionExt};
use datafusion_common::tree_node::{
- Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion,
+ Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
TreeNodeRefContainer,
};
-use datafusion_common::{map_until_stop_and_collect, Result};
+use datafusion_common::Result;
/// Implementation of the [`TreeNode`] trait
///
@@ -42,9 +42,9 @@ impl TreeNode for Expr {
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
- let children = match self {
- Expr::Alias(Alias{expr,..})
- | Expr::Unnest(Unnest{expr})
+ match self {
+ Expr::Alias(Alias { expr, .. })
+ | Expr::Unnest(Unnest { expr })
| Expr::Not(expr)
| Expr::IsNotNull(expr)
| Expr::IsTrue(expr)
@@ -57,78 +57,50 @@ impl TreeNode for Expr {
| Expr::Negative(expr)
| Expr::Cast(Cast { expr, .. })
| Expr::TryCast(TryCast { expr, .. })
- | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref()],
+ | Expr::InSubquery(InSubquery { expr, .. }) =>
expr.apply_elements(f),
Expr::GroupingSet(GroupingSet::Rollup(exprs))
- | Expr::GroupingSet(GroupingSet::Cube(exprs)) =>
exprs.iter().collect(),
- Expr::ScalarFunction (ScalarFunction{ args, .. } ) => {
- args.iter().collect()
+ | Expr::GroupingSet(GroupingSet::Cube(exprs)) =>
exprs.apply_elements(f),
+ Expr::ScalarFunction(ScalarFunction { args, .. }) => {
+ args.apply_elements(f)
}
Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
- lists_of_exprs.iter().flatten().collect()
+ lists_of_exprs.apply_elements(f)
}
Expr::Column(_)
// Treat OuterReferenceColumn as a leaf expression
| Expr::OuterReferenceColumn(_, _)
| Expr::ScalarVariable(_, _)
| Expr::Literal(_)
- | Expr::Exists {..}
+ | Expr::Exists { .. }
| Expr::ScalarSubquery(_)
- | Expr::Wildcard {..}
- | Expr::Placeholder (_) => vec![],
+ | Expr::Wildcard { .. }
+ | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue),
Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
- vec![left.as_ref(), right.as_ref()]
+ (left, right).apply_ref_elements(f)
}
Expr::Like(Like { expr, pattern, .. })
| Expr::SimilarTo(Like { expr, pattern, .. }) => {
- vec![expr.as_ref(), pattern.as_ref()]
+ (expr, pattern).apply_ref_elements(f)
}
Expr::Between(Between {
- expr, low, high, ..
- }) => vec![expr.as_ref(), low.as_ref(), high.as_ref()],
- Expr::Case(case) => {
- let mut expr_vec = vec![];
- if let Some(expr) = case.expr.as_ref() {
- expr_vec.push(expr.as_ref());
- };
- for (when, then) in case.when_then_expr.iter() {
- expr_vec.push(when.as_ref());
- expr_vec.push(then.as_ref());
- }
- if let Some(else_expr) = case.else_expr.as_ref() {
- expr_vec.push(else_expr.as_ref());
- }
- expr_vec
- }
- Expr::AggregateFunction(AggregateFunction { args, filter,
order_by, .. })
- => {
- let mut expr_vec = args.iter().collect::<Vec<_>>();
- if let Some(f) = filter {
- expr_vec.push(f.as_ref());
- }
- if let Some(order_by) = order_by {
- expr_vec.extend(order_by.iter().map(|sort| &sort.expr));
- }
- expr_vec
- }
+ expr, low, high, ..
+ }) => (expr, low, high).apply_ref_elements(f),
+ Expr::Case(Case { expr, when_then_expr, else_expr }) =>
+ (expr, when_then_expr, else_expr).apply_ref_elements(f),
+ Expr::AggregateFunction(AggregateFunction { args, filter,
order_by, .. }) =>
+ (args, filter, order_by).apply_ref_elements(f),
Expr::WindowFunction(WindowFunction {
- args,
- partition_by,
- order_by,
- ..
- }) => {
- let mut expr_vec = args.iter().collect::<Vec<_>>();
- expr_vec.extend(partition_by);
- expr_vec.extend(order_by.iter().map(|sort| &sort.expr));
- expr_vec
+ args,
+ partition_by,
+ order_by,
+ ..
+ }) => {
+ (args, partition_by, order_by).apply_ref_elements(f)
}
Expr::InList(InList { expr, list, .. }) => {
- let mut expr_vec = vec![expr.as_ref()];
- expr_vec.extend(list);
- expr_vec
+ (expr, list).apply_ref_elements(f)
}
- };
-
- children.into_iter().apply_until_stop(f)
+ }
}
/// Maps each child of `self` using the provided closure `f`.
@@ -148,137 +120,103 @@ impl TreeNode for Expr {
| Expr::ScalarSubquery(_)
| Expr::ScalarVariable(_, _)
| Expr::Literal(_) => Transformed::no(self),
- Expr::Unnest(Unnest { expr, .. }) => transform_box(expr, &mut f)?
- .update_data(|be| Expr::Unnest(Unnest::new_boxed(be))),
+ Expr::Unnest(Unnest { expr, .. }) => expr
+ .map_elements(f)?
+ .update_data(|expr| Expr::Unnest(Unnest { expr })),
Expr::Alias(Alias {
expr,
relation,
name,
- }) => f(*expr)?.update_data(|e| Expr::Alias(Alias::new(e,
relation, name))),
+ }) => f(*expr)?.update_data(|e| e.alias_qualified(relation, name)),
Expr::InSubquery(InSubquery {
expr,
subquery,
negated,
- }) => transform_box(expr, &mut f)?.update_data(|be| {
+ }) => expr.map_elements(f)?.update_data(|be| {
Expr::InSubquery(InSubquery::new(be, subquery, negated))
}),
- Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
- map_until_stop_and_collect!(
- transform_box(left, &mut f),
- right,
- transform_box(right, &mut f)
- )?
+ Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, right)
+ .map_elements(f)?
.update_data(|(new_left, new_right)| {
Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
- })
- }
+ }),
Expr::Like(Like {
negated,
expr,
pattern,
escape_char,
case_insensitive,
- }) => map_until_stop_and_collect!(
- transform_box(expr, &mut f),
- pattern,
- transform_box(pattern, &mut f)
- )?
- .update_data(|(new_expr, new_pattern)| {
- Expr::Like(Like::new(
- negated,
- new_expr,
- new_pattern,
- escape_char,
- case_insensitive,
- ))
- }),
+ }) => {
+ (expr, pattern)
+ .map_elements(f)?
+ .update_data(|(new_expr, new_pattern)| {
+ Expr::Like(Like::new(
+ negated,
+ new_expr,
+ new_pattern,
+ escape_char,
+ case_insensitive,
+ ))
+ })
+ }
Expr::SimilarTo(Like {
negated,
expr,
pattern,
escape_char,
case_insensitive,
- }) => map_until_stop_and_collect!(
- transform_box(expr, &mut f),
- pattern,
- transform_box(pattern, &mut f)
- )?
- .update_data(|(new_expr, new_pattern)| {
- Expr::SimilarTo(Like::new(
- negated,
- new_expr,
- new_pattern,
- escape_char,
- case_insensitive,
- ))
- }),
- Expr::Not(expr) => transform_box(expr, &mut
f)?.update_data(Expr::Not),
- Expr::IsNotNull(expr) => {
- transform_box(expr, &mut f)?.update_data(Expr::IsNotNull)
- }
- Expr::IsNull(expr) => transform_box(expr, &mut
f)?.update_data(Expr::IsNull),
- Expr::IsTrue(expr) => transform_box(expr, &mut
f)?.update_data(Expr::IsTrue),
- Expr::IsFalse(expr) => {
- transform_box(expr, &mut f)?.update_data(Expr::IsFalse)
- }
- Expr::IsUnknown(expr) => {
- transform_box(expr, &mut f)?.update_data(Expr::IsUnknown)
- }
- Expr::IsNotTrue(expr) => {
- transform_box(expr, &mut f)?.update_data(Expr::IsNotTrue)
- }
- Expr::IsNotFalse(expr) => {
- transform_box(expr, &mut f)?.update_data(Expr::IsNotFalse)
+ }) => {
+ (expr, pattern)
+ .map_elements(f)?
+ .update_data(|(new_expr, new_pattern)| {
+ Expr::SimilarTo(Like::new(
+ negated,
+ new_expr,
+ new_pattern,
+ escape_char,
+ case_insensitive,
+ ))
+ })
}
+ Expr::Not(expr) => expr.map_elements(f)?.update_data(Expr::Not),
+ Expr::IsNotNull(expr) =>
expr.map_elements(f)?.update_data(Expr::IsNotNull),
+ Expr::IsNull(expr) =>
expr.map_elements(f)?.update_data(Expr::IsNull),
+ Expr::IsTrue(expr) =>
expr.map_elements(f)?.update_data(Expr::IsTrue),
+ Expr::IsFalse(expr) =>
expr.map_elements(f)?.update_data(Expr::IsFalse),
+ Expr::IsUnknown(expr) =>
expr.map_elements(f)?.update_data(Expr::IsUnknown),
+ Expr::IsNotTrue(expr) =>
expr.map_elements(f)?.update_data(Expr::IsNotTrue),
+ Expr::IsNotFalse(expr) =>
expr.map_elements(f)?.update_data(Expr::IsNotFalse),
Expr::IsNotUnknown(expr) => {
- transform_box(expr, &mut f)?.update_data(Expr::IsNotUnknown)
- }
- Expr::Negative(expr) => {
- transform_box(expr, &mut f)?.update_data(Expr::Negative)
+ expr.map_elements(f)?.update_data(Expr::IsNotUnknown)
}
+ Expr::Negative(expr) =>
expr.map_elements(f)?.update_data(Expr::Negative),
Expr::Between(Between {
expr,
negated,
low,
high,
- }) => map_until_stop_and_collect!(
- transform_box(expr, &mut f),
- low,
- transform_box(low, &mut f),
- high,
- transform_box(high, &mut f)
- )?
- .update_data(|(new_expr, new_low, new_high)| {
- Expr::Between(Between::new(new_expr, negated, new_low,
new_high))
- }),
+ }) => (expr, low, high).map_elements(f)?.update_data(
+ |(new_expr, new_low, new_high)| {
+ Expr::Between(Between::new(new_expr, negated, new_low,
new_high))
+ },
+ ),
Expr::Case(Case {
expr,
when_then_expr,
else_expr,
- }) => map_until_stop_and_collect!(
- transform_option_box(expr, &mut f),
- when_then_expr,
- when_then_expr
- .into_iter()
- .map_until_stop_and_collect(|(when, then)| {
- map_until_stop_and_collect!(
- transform_box(when, &mut f),
- then,
- transform_box(then, &mut f)
- )
- }),
- else_expr,
- transform_option_box(else_expr, &mut f)
- )?
- .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
- Expr::Case(Case::new(new_expr, new_when_then_expr,
new_else_expr))
- }),
- Expr::Cast(Cast { expr, data_type }) => transform_box(expr, &mut
f)?
+ }) => (expr, when_then_expr, else_expr)
+ .map_elements(f)?
+ .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
+ Expr::Case(Case::new(new_expr, new_when_then_expr,
new_else_expr))
+ }),
+ Expr::Cast(Cast { expr, data_type }) => expr
+ .map_elements(f)?
.update_data(|be| Expr::Cast(Cast::new(be, data_type))),
- Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr,
&mut f)?
+ Expr::TryCast(TryCast { expr, data_type }) => expr
+ .map_elements(f)?
.update_data(|be| Expr::TryCast(TryCast::new(be, data_type))),
Expr::ScalarFunction(ScalarFunction { func, args }) => {
- transform_vec(args, &mut f)?.map_data(|new_args| {
+ args.map_elements(f)?.map_data(|new_args| {
Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
func, new_args,
)))
@@ -291,22 +229,17 @@ impl TreeNode for Expr {
order_by,
window_frame,
null_treatment,
- }) => map_until_stop_and_collect!(
- transform_vec(args, &mut f),
- partition_by,
- transform_vec(partition_by, &mut f),
- order_by,
- transform_sort_vec(order_by, &mut f)
- )?
- .update_data(|(new_args, new_partition_by, new_order_by)| {
- Expr::WindowFunction(WindowFunction::new(fun, new_args))
- .partition_by(new_partition_by)
- .order_by(new_order_by)
- .window_frame(window_frame)
- .null_treatment(null_treatment)
- .build()
- .unwrap()
- }),
+ }) => (args, partition_by, order_by).map_elements(f)?.update_data(
+ |(new_args, new_partition_by, new_order_by)| {
+ Expr::WindowFunction(WindowFunction::new(fun, new_args))
+ .partition_by(new_partition_by)
+ .order_by(new_order_by)
+ .window_frame(window_frame)
+ .null_treatment(null_treatment)
+ .build()
+ .unwrap()
+ },
+ ),
Expr::AggregateFunction(AggregateFunction {
args,
func,
@@ -314,31 +247,27 @@ impl TreeNode for Expr {
filter,
order_by,
null_treatment,
- }) => map_until_stop_and_collect!(
- transform_vec(args, &mut f),
- filter,
- transform_option_box(filter, &mut f),
- order_by,
- transform_sort_option_vec(order_by, &mut f)
- )?
- .map_data(|(new_args, new_filter, new_order_by)| {
- Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
- func,
- new_args,
- distinct,
- new_filter,
- new_order_by,
- null_treatment,
- )))
- })?,
+ }) => (args, filter, order_by).map_elements(f)?.map_data(
+ |(new_args, new_filter, new_order_by)| {
+ Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
+ func,
+ new_args,
+ distinct,
+ new_filter,
+ new_order_by,
+ null_treatment,
+ )))
+ },
+ )?,
Expr::GroupingSet(grouping_set) => match grouping_set {
- GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)?
+ GroupingSet::Rollup(exprs) => exprs
+ .map_elements(f)?
.update_data(|ve|
Expr::GroupingSet(GroupingSet::Rollup(ve))),
- GroupingSet::Cube(exprs) => transform_vec(exprs, &mut f)?
+ GroupingSet::Cube(exprs) => exprs
+ .map_elements(f)?
.update_data(|ve|
Expr::GroupingSet(GroupingSet::Cube(ve))),
GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs
- .into_iter()
- .map_until_stop_and_collect(|exprs| transform_vec(exprs,
&mut f))?
+ .map_elements(f)?
.update_data(|new_lists_of_exprs| {
Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs))
}),
@@ -347,70 +276,11 @@ impl TreeNode for Expr {
expr,
list,
negated,
- }) => map_until_stop_and_collect!(
- transform_box(expr, &mut f),
- list,
- transform_vec(list, &mut f)
- )?
- .update_data(|(new_expr, new_list)| {
- Expr::InList(InList::new(new_expr, new_list, negated))
- }),
+ }) => (expr, list)
+ .map_elements(f)?
+ .update_data(|(new_expr, new_list)| {
+ Expr::InList(InList::new(new_expr, new_list, negated))
+ }),
})
}
}
-
-/// Transforms a boxed expression by applying the provided closure `f`.
-fn transform_box<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
- be: Box<Expr>,
- f: &mut F,
-) -> Result<Transformed<Box<Expr>>> {
- Ok(f(*be)?.update_data(Box::new))
-}
-
-/// Transforms an optional boxed expression by applying the provided closure
`f`.
-fn transform_option_box<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
- obe: Option<Box<Expr>>,
- f: &mut F,
-) -> Result<Transformed<Option<Box<Expr>>>> {
- obe.map_or(Ok(Transformed::no(None)), |be| {
- Ok(transform_box(be, f)?.update_data(Some))
- })
-}
-
-/// &mut transform a Option<`Vec` of `Expr`s>
-pub fn transform_option_vec<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
- ove: Option<Vec<Expr>>,
- f: &mut F,
-) -> Result<Transformed<Option<Vec<Expr>>>> {
- ove.map_or(Ok(Transformed::no(None)), |ve| {
- Ok(transform_vec(ve, f)?.update_data(Some))
- })
-}
-
-/// &mut transform a `Vec` of `Expr`s
-fn transform_vec<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
- ve: Vec<Expr>,
- f: &mut F,
-) -> Result<Transformed<Vec<Expr>>> {
- ve.into_iter().map_until_stop_and_collect(f)
-}
-
-/// Transforms an optional vector of sort expressions by applying the provided
closure `f`.
-pub fn transform_sort_option_vec<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
- sorts_option: Option<Vec<Sort>>,
- f: &mut F,
-) -> Result<Transformed<Option<Vec<Sort>>>> {
- sorts_option.map_or(Ok(Transformed::no(None)), |sorts| {
- Ok(transform_sort_vec(sorts, f)?.update_data(Some))
- })
-}
-
-/// Transforms an vector of sort expressions by applying the provided closure
`f`.
-pub fn transform_sort_vec<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
- sorts: Vec<Sort>,
- f: &mut F,
-) -> Result<Transformed<Vec<Sort>>> {
- sorts.into_iter().map_until_stop_and_collect(|s| {
- Ok(f(s.expr)?.update_data(|e| Sort { expr: e, ..s }))
- })
-}
diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs
b/datafusion/optimizer/src/optimize_projections/mod.rs
index b659e477f6..1519c54dbf 100644
--- a/datafusion/optimizer/src/optimize_projections/mod.rs
+++ b/datafusion/optimizer/src/optimize_projections/mod.rs
@@ -39,7 +39,7 @@ use datafusion_expr::{
use crate::optimize_projections::required_indices::RequiredIndicies;
use crate::utils::NamePreserver;
use datafusion_common::tree_node::{
- Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion,
+ Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
};
/// Optimizer rule to prune unnecessary columns from intermediate schemas
@@ -484,7 +484,7 @@ fn merge_consecutive_projections(proj: Projection) ->
Result<Transformed<Project
// previous projection as input:
let name_preserver = NamePreserver::new_for_projection();
let mut original_names = vec![];
- let new_exprs = expr.into_iter().map_until_stop_and_collect(|expr| {
+ let new_exprs = expr.map_elements(|expr| {
original_names.push(name_preserver.save(&expr));
// do not rewrite top level Aliases (rewriter will remove all aliases
within exprs)
diff --git a/datafusion/sql/src/unparser/rewrite.rs
b/datafusion/sql/src/unparser/rewrite.rs
index 6b3b999ba0..68af121a41 100644
--- a/datafusion/sql/src/unparser/rewrite.rs
+++ b/datafusion/sql/src/unparser/rewrite.rs
@@ -18,11 +18,12 @@
use std::{collections::HashSet, sync::Arc};
use arrow_schema::Schema;
+use datafusion_common::tree_node::TreeNodeContainer;
use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
Column, HashMap, Result, TableReference,
};
-use datafusion_expr::{expr::Alias, tree_node::transform_sort_vec};
+use datafusion_expr::expr::Alias;
use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr};
use sqlparser::ast::Ident;
@@ -83,17 +84,18 @@ pub(super) fn normalize_union_schema(plan: &LogicalPlan) ->
Result<LogicalPlan>
/// Rewrite sort expressions that have a UNION plan as their input to remove
the table reference.
fn rewrite_sort_expr_for_union(exprs: Vec<SortExpr>) -> Result<Vec<SortExpr>> {
- let sort_exprs = transform_sort_vec(exprs, &mut |expr| {
- expr.transform_up(|expr| {
- if let Expr::Column(mut col) = expr {
- col.relation = None;
- Ok(Transformed::yes(Expr::Column(col)))
- } else {
- Ok(Transformed::no(expr))
- }
+ let sort_exprs = exprs
+ .map_elements(&mut |expr: Expr| {
+ expr.transform_up(|expr| {
+ if let Expr::Column(mut col) = expr {
+ col.relation = None;
+ Ok(Transformed::yes(Expr::Column(col)))
+ } else {
+ Ok(Transformed::no(expr))
+ }
+ })
})
- })
- .data()?;
+ .data()?;
Ok(sort_exprs)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]