This is an automated email from the ASF dual-hosted git repository.

berkay pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new aef232b1ac Add `Container` trait and to simplify `Expr` and `LogicalPlan` apply and map methods (#13467)
aef232b1ac is described below

commit aef232b1ac559fc1597a64d9f5f75c2f29f4c286
Author: Peter Toth <[email protected]>
AuthorDate: Wed Nov 20 08:29:54 2024 +0100

    Add `Container` trait and to simplify `Expr` and `LogicalPlan` apply and map methods (#13467)
    
    * Add `Container` trait and its blanket implementations, remove `map_until_stop_and_collect` macro, simplify apply and map logic with `Container`s where possible
    
    * fix clippy
    
    * rename `Container` to `TreeNodeContainer`
    
    * add docs to containers
    
    * clarify when we need a temporary `TreeNodeRefContainer`
    
    * code and docs cleanup
---
 datafusion/common/src/tree_node.rs                 | 363 +++++++++++++++++---
 datafusion/expr/src/expr.rs                        |  36 +-
 datafusion/expr/src/logical_plan/ddl.rs            |  50 ++-
 datafusion/expr/src/logical_plan/plan.rs           |  20 +-
 datafusion/expr/src/logical_plan/statement.rs      |  51 +--
 datafusion/expr/src/logical_plan/tree_node.rs      | 347 ++++++++-----------
 datafusion/expr/src/tree_node.rs                   | 372 +++++++--------------
 .../optimizer/src/optimize_projections/mod.rs      |   4 +-
 datafusion/sql/src/unparser/rewrite.rs             |  24 +-
 9 files changed, 687 insertions(+), 580 deletions(-)

diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs
index c8ec7f1833..0c153583e3 100644
--- a/datafusion/common/src/tree_node.rs
+++ b/datafusion/common/src/tree_node.rs
@@ -17,11 +17,12 @@
 
 //! [`TreeNode`] for visiting and rewriting expression and plan trees
 
+use crate::Result;
 use recursive::recursive;
+use std::collections::HashMap;
+use std::hash::Hash;
 use std::sync::Arc;
 
-use crate::Result;
-
 /// These macros are used to determine continuation during transforming traversals.
 macro_rules! handle_transform_recursion {
     ($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{
@@ -769,6 +770,297 @@ impl<T> Transformed<T> {
     }
 }
 
+/// [`TreeNodeContainer`] contains elements that a function can be applied on or mapped.
+/// The elements of the container are siblings so the continuation rules are similar to
+/// [`TreeNodeRecursion::visit_sibling`] / [`Transformed::transform_sibling`].
+pub trait TreeNodeContainer<'a, T: 'a>: Sized {
+    /// Applies `f` to all elements of the container.
+    /// This method is usually called from [`TreeNode::apply_children`] implementations as
+    /// a node is actually a container of the node's children.
+    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        f: F,
+    ) -> Result<TreeNodeRecursion>;
+
+    /// Maps all elements of the container with `f`.
+    /// This method is usually called from [`TreeNode::map_children`] implementations as
+    /// a node is actually a container of the node's children.
+    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
+        self,
+        f: F,
+    ) -> Result<Transformed<Self>>;
+}
+
+impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Box<C> {
+    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        f: F,
+    ) -> Result<TreeNodeRecursion> {
+        self.as_ref().apply_elements(f)
+    }
+
+    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
+        self,
+        f: F,
+    ) -> Result<Transformed<Self>> {
+        (*self).map_elements(f)?.map_data(|c| Ok(Self::new(c)))
+    }
+}
+
+impl<'a, T: 'a, C: TreeNodeContainer<'a, T> + Clone> TreeNodeContainer<'a, T> for Arc<C> {
+    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        f: F,
+    ) -> Result<TreeNodeRecursion> {
+        self.as_ref().apply_elements(f)
+    }
+
+    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
+        self,
+        f: F,
+    ) -> Result<Transformed<Self>> {
+        Arc::unwrap_or_clone(self)
+            .map_elements(f)?
+            .map_data(|c| Ok(Arc::new(c)))
+    }
+}
+
+impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Option<C> {
+    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        f: F,
+    ) -> Result<TreeNodeRecursion> {
+        match self {
+            Some(t) => t.apply_elements(f),
+            None => Ok(TreeNodeRecursion::Continue),
+        }
+    }
+
+    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
+        self,
+        f: F,
+    ) -> Result<Transformed<Self>> {
+        self.map_or(Ok(Transformed::no(None)), |c| {
+            c.map_elements(f)?.map_data(|c| Ok(Some(c)))
+        })
+    }
+}
+
+impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Vec<C> {
+    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        mut f: F,
+    ) -> Result<TreeNodeRecursion> {
+        let mut tnr = TreeNodeRecursion::Continue;
+        for c in self {
+            tnr = c.apply_elements(&mut f)?;
+            match tnr {
+                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {}
+                TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
+            }
+        }
+        Ok(tnr)
+    }
+
+    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
+        self,
+        mut f: F,
+    ) -> Result<Transformed<Self>> {
+        let mut tnr = TreeNodeRecursion::Continue;
+        let mut transformed = false;
+        self.into_iter()
+            .map(|c| match tnr {
+                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
+                    c.map_elements(&mut f).map(|result| {
+                        tnr = result.tnr;
+                        transformed |= result.transformed;
+                        result.data
+                    })
+                }
+                TreeNodeRecursion::Stop => Ok(c),
+            })
+            .collect::<Result<Vec<_>>>()
+            .map(|data| Transformed::new(data, transformed, tnr))
+    }
+}
+
+impl<'a, T: 'a, K: Eq + Hash, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T>
+    for HashMap<K, C>
+{
+    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        mut f: F,
+    ) -> Result<TreeNodeRecursion> {
+        let mut tnr = TreeNodeRecursion::Continue;
+        for c in self.values() {
+            tnr = c.apply_elements(&mut f)?;
+            match tnr {
+                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {}
+                TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
+            }
+        }
+        Ok(tnr)
+    }
+
+    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
+        self,
+        mut f: F,
+    ) -> Result<Transformed<Self>> {
+        let mut tnr = TreeNodeRecursion::Continue;
+        let mut transformed = false;
+        self.into_iter()
+            .map(|(k, c)| match tnr {
+                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
+                    c.map_elements(&mut f).map(|result| {
+                        tnr = result.tnr;
+                        transformed |= result.transformed;
+                        (k, result.data)
+                    })
+                }
+                TreeNodeRecursion::Stop => Ok((k, c)),
+            })
+            .collect::<Result<HashMap<_, _>>>()
+            .map(|data| Transformed::new(data, transformed, tnr))
+    }
+}
+
+impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>>
+    TreeNodeContainer<'a, T> for (C0, C1)
+{
+    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        mut f: F,
+    ) -> Result<TreeNodeRecursion> {
+        self.0
+            .apply_elements(&mut f)?
+            .visit_sibling(|| self.1.apply_elements(&mut f))
+    }
+
+    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
+        self,
+        mut f: F,
+    ) -> Result<Transformed<Self>> {
+        self.0
+            .map_elements(&mut f)?
+            .map_data(|new_c0| Ok((new_c0, self.1)))?
+            .transform_sibling(|(new_c0, c1)| {
+                c1.map_elements(&mut f)?
+                    .map_data(|new_c1| Ok((new_c0, new_c1)))
+            })
+    }
+}
+
+impl<
+        'a,
+        T: 'a,
+        C0: TreeNodeContainer<'a, T>,
+        C1: TreeNodeContainer<'a, T>,
+        C2: TreeNodeContainer<'a, T>,
+    > TreeNodeContainer<'a, T> for (C0, C1, C2)
+{
+    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        mut f: F,
+    ) -> Result<TreeNodeRecursion> {
+        self.0
+            .apply_elements(&mut f)?
+            .visit_sibling(|| self.1.apply_elements(&mut f))?
+            .visit_sibling(|| self.2.apply_elements(&mut f))
+    }
+
+    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
+        self,
+        mut f: F,
+    ) -> Result<Transformed<Self>> {
+        self.0
+            .map_elements(&mut f)?
+            .map_data(|new_c0| Ok((new_c0, self.1, self.2)))?
+            .transform_sibling(|(new_c0, c1, c2)| {
+                c1.map_elements(&mut f)?
+                    .map_data(|new_c1| Ok((new_c0, new_c1, c2)))
+            })?
+            .transform_sibling(|(new_c0, new_c1, c2)| {
+                c2.map_elements(&mut f)?
+                    .map_data(|new_c2| Ok((new_c0, new_c1, new_c2)))
+            })
+    }
+}
+
+/// [`TreeNodeRefContainer`] contains references to elements that a function can be
+/// applied on. The elements of the container are siblings so the continuation rules are
+/// similar to [`TreeNodeRecursion::visit_sibling`].
+///
+/// This container is similar to [`TreeNodeContainer`], but the lifetime of the reference
+/// elements (`T`) are not derived from the container's lifetime.
+/// A typical usage of this container is in `Expr::apply_children` when we need to
+/// construct a temporary container to be able to call `apply_ref_elements` on a
+/// collection of tree node references. But in that case the container's temporary
+/// lifetime is different to the lifetime of tree nodes that we put into it.
+/// Please find an example usecase in `Expr::apply_children` with the `Expr::Case` case.
+///
+/// Most of the cases we don't need to create a temporary container with
+/// `TreeNodeRefContainer`, but we can just call `TreeNodeContainer::apply_elements`.
+/// Please find an example usecase in `Expr::apply_children` with the `Expr::GroupingSet`
+/// case.
+pub trait TreeNodeRefContainer<'a, T: 'a>: Sized {
+    /// Applies `f` to all elements of the container.
+    /// This method is usually called from [`TreeNode::apply_children`] implementations as
+    /// a node is actually a container of the node's children.
+    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &self,
+        f: F,
+    ) -> Result<TreeNodeRecursion>;
+}
+
+impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeRefContainer<'a, T> for Vec<&'a C> {
+    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &self,
+        mut f: F,
+    ) -> Result<TreeNodeRecursion> {
+        let mut tnr = TreeNodeRecursion::Continue;
+        for c in self {
+            tnr = c.apply_elements(&mut f)?;
+            match tnr {
+                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {}
+                TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
+            }
+        }
+        Ok(tnr)
+    }
+}
+
+impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>>
+    TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1)
+{
+    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &self,
+        mut f: F,
+    ) -> Result<TreeNodeRecursion> {
+        self.0
+            .apply_elements(&mut f)?
+            .visit_sibling(|| self.1.apply_elements(&mut f))
+    }
+}
+
+impl<
+        'a,
+        T: 'a,
+        C0: TreeNodeContainer<'a, T>,
+        C1: TreeNodeContainer<'a, T>,
+        C2: TreeNodeContainer<'a, T>,
+    > TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2)
+{
+    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
+        &self,
+        mut f: F,
+    ) -> Result<TreeNodeRecursion> {
+        self.0
+            .apply_elements(&mut f)?
+            .visit_sibling(|| self.1.apply_elements(&mut f))?
+            .visit_sibling(|| self.2.apply_elements(&mut f))
+    }
+}
+
 /// Transformation helper to process a sequence of iterable tree nodes that are siblings.
 pub trait TreeNodeIterator: Iterator {
     /// Apples `f` to each item in this iterator
@@ -843,50 +1135,6 @@ impl<I: Iterator> TreeNodeIterator for I {
     }
 }
 
-/// Transformation helper to process a heterogeneous sequence of tree node containing
-/// expressions.
-///
-/// This macro is very similar to [TreeNodeIterator::map_until_stop_and_collect] to
-/// process nodes that are siblings, but it accepts an initial transformation (`F0`) and
-/// a sequence of pairs. Each pair is made of an expression (`EXPR`) and its
-/// transformation (`F`).
-///
-/// The macro builds up a tuple that contains `Transformed.data` result of `F0` as the
-/// first element and further elements from the sequence of pairs. An element from a pair
-/// is either the value of `EXPR` or the `Transformed.data` result of `F`, depending on
-/// the `Transformed.tnr` result of previous `F`s (`F0` initially).
-///
-/// # Returns
-/// Error if any of the transformations returns an error
-///
-/// Ok(Transformed<(data0, ..., dataN)>) such that:
-/// 1. `transformed` is true if any of the transformations had transformed true
-/// 2. `(data0, ..., dataN)`, where `data0` is the `Transformed.data` from `F0` and
-///     `data1` ... `dataN` are from either `EXPR` or the `Transformed.data` of `F`
-/// 3. `tnr` from `F0` or the last invocation of `F`
-#[macro_export]
-macro_rules! map_until_stop_and_collect {
-    ($F0:expr, $($EXPR:expr, $F:expr),*) => {{
-        $F0.and_then(|Transformed { data: data0, mut transformed, mut tnr }| {
-            let all_datas = (
-                data0,
-                $(
-                    if tnr == TreeNodeRecursion::Continue || tnr == TreeNodeRecursion::Jump {
-                        $F.map(|result| {
-                            tnr = result.tnr;
-                            transformed |= result.transformed;
-                            result.data
-                        })?
-                    } else {
-                        $EXPR
-                    },
-                )*
-            );
-            Ok(Transformed::new(all_datas, transformed, tnr))
-        })
-    }}
-}
-
 /// Transformation helper to access [`Transformed`] fields in a [`Result`] easily.
 ///
 /// # Example
@@ -1021,7 +1269,7 @@ pub(crate) mod tests {
     use std::fmt::Display;
 
     use crate::tree_node::{
-        Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter,
+        Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRewriter,
         TreeNodeVisitor,
     };
     use crate::Result;
@@ -1054,7 +1302,7 @@ pub(crate) mod tests {
             &'n self,
             f: F,
         ) -> Result<TreeNodeRecursion> {
-            self.children.iter().apply_until_stop(f)
+            self.children.apply_elements(f)
         }
 
         fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
@@ -1063,8 +1311,7 @@ pub(crate) mod tests {
         ) -> Result<Transformed<Self>> {
             Ok(self
                 .children
-                .into_iter()
-                .map_until_stop_and_collect(f)?
+                .map_elements(f)?
                 .update_data(|new_children| Self {
                     children: new_children,
                     ..self
@@ -1072,6 +1319,22 @@ pub(crate) mod tests {
         }
     }
 
+    impl<'a, T: 'a> TreeNodeContainer<'a, Self> for TestTreeNode<T> {
+        fn apply_elements<F: FnMut(&'a Self) -> Result<TreeNodeRecursion>>(
+            &'a self,
+            mut f: F,
+        ) -> Result<TreeNodeRecursion> {
+            f(self)
+        }
+
+        fn map_elements<F: FnMut(Self) -> Result<Transformed<Self>>>(
+            self,
+            mut f: F,
+        ) -> Result<Transformed<Self>> {
+            f(self)
+        }
+    }
+
     //       J
     //       |
     //       I
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 83d35c3d25..8490c08a70 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -32,7 +32,7 @@ use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF};
 use arrow::datatypes::{DataType, FieldRef};
 use datafusion_common::cse::HashNode;
 use datafusion_common::tree_node::{
-    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
+    Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion,
 };
 use datafusion_common::{
     plan_err, Column, DFSchema, HashMap, Result, ScalarValue, TableReference,
@@ -351,6 +351,22 @@ impl<'a> From<(Option<&'a TableReference>, &'a FieldRef)> for Expr {
     }
 }
 
+impl<'a> TreeNodeContainer<'a, Self> for Expr {
+    fn apply_elements<F: FnMut(&'a Self) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        mut f: F,
+    ) -> Result<TreeNodeRecursion> {
+        f(self)
+    }
+
+    fn map_elements<F: FnMut(Self) -> Result<Transformed<Self>>>(
+        self,
+        mut f: F,
+    ) -> Result<Transformed<Self>> {
+        f(self)
+    }
+}
+
 /// UNNEST expression.
 #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
 pub struct Unnest {
@@ -653,6 +669,24 @@ impl Display for Sort {
     }
 }
 
+impl<'a> TreeNodeContainer<'a, Expr> for Sort {
+    fn apply_elements<F: FnMut(&'a Expr) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        f: F,
+    ) -> Result<TreeNodeRecursion> {
+        self.expr.apply_elements(f)
+    }
+
+    fn map_elements<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
+        self,
+        f: F,
+    ) -> Result<Transformed<Self>> {
+        self.expr
+            .map_elements(f)?
+            .map_data(|expr| Ok(Self { expr, ..self }))
+    }
+}
+
 /// Aggregate function
 ///
 /// See also  [`ExprFunctionExt`] to set these fields on `Expr`
diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs
index 93e8b5fd04..8c64a01798 100644
--- a/datafusion/expr/src/logical_plan/ddl.rs
+++ b/datafusion/expr/src/logical_plan/ddl.rs
@@ -26,7 +26,10 @@ use std::{
 
 use crate::expr::Sort;
 use arrow::datatypes::DataType;
-use datafusion_common::{Constraints, DFSchemaRef, SchemaReference, TableReference};
+use datafusion_common::tree_node::{Transformed, TreeNodeContainer, TreeNodeRecursion};
+use datafusion_common::{
+    Constraints, DFSchemaRef, Result, SchemaReference, TableReference,
+};
 use sqlparser::ast::Ident;
 
 /// Various types of DDL  (CREATE / DROP) catalog manipulation
@@ -487,6 +490,28 @@ pub struct OperateFunctionArg {
     pub data_type: DataType,
     pub default_expr: Option<Expr>,
 }
+
+impl<'a> TreeNodeContainer<'a, Expr> for OperateFunctionArg {
+    fn apply_elements<F: FnMut(&'a Expr) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        f: F,
+    ) -> Result<TreeNodeRecursion> {
+        self.default_expr.apply_elements(f)
+    }
+
+    fn map_elements<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
+        self,
+        f: F,
+    ) -> Result<Transformed<Self>> {
+        self.default_expr.map_elements(f)?.map_data(|default_expr| {
+            Ok(Self {
+                default_expr,
+                ..self
+            })
+        })
+    }
+}
+
 #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
 pub struct CreateFunctionBody {
     /// LANGUAGE lang_name
@@ -497,6 +522,29 @@ pub struct CreateFunctionBody {
     pub function_body: Option<Expr>,
 }
 
+impl<'a> TreeNodeContainer<'a, Expr> for CreateFunctionBody {
+    fn apply_elements<F: FnMut(&'a Expr) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        f: F,
+    ) -> Result<TreeNodeRecursion> {
+        self.function_body.apply_elements(f)
+    }
+
+    fn map_elements<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
+        self,
+        f: F,
+    ) -> Result<Transformed<Self>> {
+        self.function_body
+            .map_elements(f)?
+            .map_data(|function_body| {
+                Ok(Self {
+                    function_body,
+                    ..self
+                })
+            })
+    }
+}
+
 #[derive(Clone, PartialEq, Eq, Hash, Debug)]
 pub struct DropFunction {
     pub name: String,
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 6ee99b22c7..e9f4f1f809 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -45,7 +45,9 @@ use crate::{
 };
 
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
+use datafusion_common::tree_node::{
+    Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
+};
 use datafusion_common::{
     aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints,
     DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence,
@@ -287,6 +289,22 @@ impl Default for LogicalPlan {
     }
 }
 
+impl<'a> TreeNodeContainer<'a, Self> for LogicalPlan {
+    fn apply_elements<F: FnMut(&'a Self) -> Result<TreeNodeRecursion>>(
+        &'a self,
+        mut f: F,
+    ) -> Result<TreeNodeRecursion> {
+        f(self)
+    }
+
+    fn map_elements<F: FnMut(Self) -> Result<Transformed<Self>>>(
+        self,
+        mut f: F,
+    ) -> Result<Transformed<Self>> {
+        f(self)
+    }
+}
+
 impl LogicalPlan {
     /// Get a reference to the logical plan's schema
     pub fn schema(&self) -> &DFSchemaRef {
diff --git a/datafusion/expr/src/logical_plan/statement.rs b/datafusion/expr/src/logical_plan/statement.rs
index 05e2b1af14..26df379f5e 100644
--- a/datafusion/expr/src/logical_plan/statement.rs
+++ b/datafusion/expr/src/logical_plan/statement.rs
@@ -16,12 +16,10 @@
 // under the License.
 
 use arrow::datatypes::DataType;
-use datafusion_common::tree_node::{Transformed, TreeNodeIterator};
-use datafusion_common::{DFSchema, DFSchemaRef, Result};
+use datafusion_common::{DFSchema, DFSchemaRef};
 use std::fmt::{self, Display};
 use std::sync::{Arc, OnceLock};
 
-use super::tree_node::rewrite_arc;
 use crate::{expr_vec_fmt, Expr, LogicalPlan};
 
 /// Statements have a unchanging empty schema.
@@ -80,53 +78,6 @@ impl Statement {
         }
     }
 
-    /// Rewrites input LogicalPlans in the current `Statement` using `f`.
-    pub(super) fn map_inputs<
-        F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
-    >(
-        self,
-        f: F,
-    ) -> Result<Transformed<Self>> {
-        match self {
-            Statement::Prepare(Prepare {
-                input,
-                name,
-                data_types,
-            }) => Ok(rewrite_arc(input, f)?.update_data(|input| {
-                Statement::Prepare(Prepare {
-                    input,
-                    name,
-                    data_types,
-                })
-            })),
-            _ => Ok(Transformed::no(self)),
-        }
-    }
-
-    /// Returns a iterator over all expressions in the current `Statement`.
-    pub(super) fn expression_iter(&self) -> impl Iterator<Item = &Expr> {
-        match self {
-            Statement::Execute(Execute { parameters, .. }) => parameters.iter(),
-            _ => [].iter(),
-        }
-    }
-
-    /// Rewrites all expressions in the current `Statement` using `f`.
-    pub(super) fn map_expressions<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
-        self,
-        f: F,
-    ) -> Result<Transformed<Self>> {
-        match self {
-            Statement::Execute(Execute { name, parameters }) => Ok(parameters
-                .into_iter()
-                .map_until_stop_and_collect(f)?
-                .update_data(|parameters| {
-                    Statement::Execute(Execute { parameters, name })
-                })),
-            _ => Ok(Transformed::no(self)),
-        }
-    }
-
     /// Return a `format`able structure with the a human readable
     /// description of this LogicalPlan node per node, not including
     /// children.
diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs
index e7dfe87919..6850c30f4f 100644
--- a/datafusion/expr/src/logical_plan/tree_node.rs
+++ b/datafusion/expr/src/logical_plan/tree_node.rs
@@ -36,32 +36,30 @@
 //! (Re)creation APIs (these require substantial cloning and thus are slow):
 //! * [`LogicalPlan::with_new_exprs`]: Create a new plan with different expressions
 //! * [`LogicalPlan::expressions`]: Return a copy of the plan's expressions
+
 use crate::{
     dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement,
-    Distinct, DistinctOn, DmlStatement, Explain, Expr, Extension, Filter, Join, Limit,
-    LogicalPlan, Partitioning, Projection, RecursiveQuery, Repartition, Sort, Subquery,
-    SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode, Values, Window,
+    Distinct, DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, Filter, Join,
+    Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, Repartition,
+    Sort, Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest,
+    UserDefinedLogicalNode, Values, Window,
 };
+use datafusion_common::tree_node::TreeNodeRefContainer;
 use recursive::recursive;
-use std::ops::Deref;
-use std::sync::Arc;
 
 use crate::expr::{Exists, InSubquery};
-use crate::tree_node::{transform_sort_option_vec, transform_sort_vec};
 use datafusion_common::tree_node::{
-    Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter,
-    TreeNodeVisitor,
-};
-use datafusion_common::{
-    internal_err, map_until_stop_and_collect, DataFusionError, Result,
+    Transformed, TreeNode, TreeNodeContainer, TreeNodeIterator, TreeNodeRecursion,
+    TreeNodeRewriter, TreeNodeVisitor,
 };
+use datafusion_common::{internal_err, Result};
 
 impl TreeNode for LogicalPlan {
     fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
         &'n self,
         f: F,
     ) -> Result<TreeNodeRecursion> {
-        self.inputs().into_iter().apply_until_stop(f)
+        self.inputs().apply_ref_elements(f)
     }
 
     /// Applies `f` to each child (input) of this plan node, rewriting them *in place.*
@@ -74,14 +72,14 @@ impl TreeNode for LogicalPlan {
     /// [`Expr::Exists`]: crate::Expr::Exists
     fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
         self,
-        mut f: F,
+        f: F,
     ) -> Result<Transformed<Self>> {
         Ok(match self {
             LogicalPlan::Projection(Projection {
                 expr,
                 input,
                 schema,
-            }) => rewrite_arc(input, f)?.update_data(|input| {
+            }) => input.map_elements(f)?.update_data(|input| {
                 LogicalPlan::Projection(Projection {
                     expr,
                     input,
@@ -92,7 +90,7 @@ impl TreeNode for LogicalPlan {
                 predicate,
                 input,
                 having,
-            }) => rewrite_arc(input, f)?.update_data(|input| {
+            }) => input.map_elements(f)?.update_data(|input| {
                 LogicalPlan::Filter(Filter {
                     predicate,
                     input,
@@ -102,7 +100,7 @@ impl TreeNode for LogicalPlan {
             LogicalPlan::Repartition(Repartition {
                 input,
                 partitioning_scheme,
-            }) => rewrite_arc(input, f)?.update_data(|input| {
+            }) => input.map_elements(f)?.update_data(|input| {
                 LogicalPlan::Repartition(Repartition {
                     input,
                     partitioning_scheme,
@@ -112,7 +110,7 @@ impl TreeNode for LogicalPlan {
                 input,
                 window_expr,
                 schema,
-            }) => rewrite_arc(input, f)?.update_data(|input| {
+            }) => input.map_elements(f)?.update_data(|input| {
                 LogicalPlan::Window(Window {
                     input,
                     window_expr,
@@ -124,7 +122,7 @@ impl TreeNode for LogicalPlan {
                 group_expr,
                 aggr_expr,
                 schema,
-            }) => rewrite_arc(input, f)?.update_data(|input| {
+            }) => input.map_elements(f)?.update_data(|input| {
                 LogicalPlan::Aggregate(Aggregate {
                     input,
                     group_expr,
@@ -132,7 +130,8 @@ impl TreeNode for LogicalPlan {
                     schema,
                 })
             }),
-            LogicalPlan::Sort(Sort { expr, input, fetch }) => rewrite_arc(input, f)?
+            LogicalPlan::Sort(Sort { expr, input, fetch }) => input
+                .map_elements(f)?
                 .update_data(|input| LogicalPlan::Sort(Sort { expr, input, fetch })),
             LogicalPlan::Join(Join {
                 left,
@@ -143,12 +142,7 @@ impl TreeNode for LogicalPlan {
                 join_constraint,
                 schema,
                 null_equals_null,
-            }) => map_until_stop_and_collect!(
-                rewrite_arc(left, &mut f),
-                right,
-                rewrite_arc(right, &mut f)
-            )?
-            .update_data(|(left, right)| {
+            }) => (left, right).map_elements(f)?.update_data(|(left, right)| {
                 LogicalPlan::Join(Join {
                     left,
                     right,
@@ -160,12 +154,13 @@ impl TreeNode for LogicalPlan {
                     null_equals_null,
                 })
             }),
-            LogicalPlan::Limit(Limit { skip, fetch, input }) => rewrite_arc(input, f)?
+            LogicalPlan::Limit(Limit { skip, fetch, input }) => input
+                .map_elements(f)?
                 .update_data(|input| LogicalPlan::Limit(Limit { skip, fetch, input })),
             LogicalPlan::Subquery(Subquery {
                 subquery,
                 outer_ref_columns,
-            }) => rewrite_arc(subquery, f)?.update_data(|subquery| {
+            }) => subquery.map_elements(f)?.update_data(|subquery| {
                 LogicalPlan::Subquery(Subquery {
                     subquery,
                     outer_ref_columns,
@@ -175,7 +170,7 @@ impl TreeNode for LogicalPlan {
                 input,
                 alias,
                 schema,
-            }) => rewrite_arc(input, f)?.update_data(|input| {
+            }) => input.map_elements(f)?.update_data(|input| {
                 LogicalPlan::SubqueryAlias(SubqueryAlias {
                     input,
                     alias,
@@ -184,17 +179,18 @@ impl TreeNode for LogicalPlan {
             }),
             LogicalPlan::Extension(extension) => 
rewrite_extension_inputs(extension, f)?
                 .update_data(LogicalPlan::Extension),
-            LogicalPlan::Union(Union { inputs, schema }) => 
rewrite_arcs(inputs, f)?
+            LogicalPlan::Union(Union { inputs, schema }) => inputs
+                .map_elements(f)?
                 .update_data(|inputs| LogicalPlan::Union(Union { inputs, 
schema })),
             LogicalPlan::Distinct(distinct) => match distinct {
-                Distinct::All(input) => rewrite_arc(input, 
f)?.update_data(Distinct::All),
+                Distinct::All(input) => 
input.map_elements(f)?.update_data(Distinct::All),
                 Distinct::On(DistinctOn {
                     on_expr,
                     select_expr,
                     sort_expr,
                     input,
                     schema,
-                }) => rewrite_arc(input, f)?.update_data(|input| {
+                }) => input.map_elements(f)?.update_data(|input| {
                     Distinct::On(DistinctOn {
                         on_expr,
                         select_expr,
@@ -211,7 +207,7 @@ impl TreeNode for LogicalPlan {
                 stringified_plans,
                 schema,
                 logical_optimization_succeeded,
-            }) => rewrite_arc(plan, f)?.update_data(|plan| {
+            }) => plan.map_elements(f)?.update_data(|plan| {
                 LogicalPlan::Explain(Explain {
                     verbose,
                     plan,
@@ -224,7 +220,7 @@ impl TreeNode for LogicalPlan {
                 verbose,
                 input,
                 schema,
-            }) => rewrite_arc(input, f)?.update_data(|input| {
+            }) => input.map_elements(f)?.update_data(|input| {
                 LogicalPlan::Analyze(Analyze {
                     verbose,
                     input,
@@ -237,7 +233,7 @@ impl TreeNode for LogicalPlan {
                 op,
                 input,
                 output_schema,
-            }) => rewrite_arc(input, f)?.update_data(|input| {
+            }) => input.map_elements(f)?.update_data(|input| {
                 LogicalPlan::Dml(DmlStatement {
                     table_name,
                     table_schema,
@@ -252,7 +248,7 @@ impl TreeNode for LogicalPlan {
                 partition_by,
                 file_type,
                 options,
-            }) => rewrite_arc(input, f)?.update_data(|input| {
+            }) => input.map_elements(f)?.update_data(|input| {
                 LogicalPlan::Copy(CopyTo {
                     input,
                     output_url,
@@ -271,7 +267,7 @@ impl TreeNode for LogicalPlan {
                         or_replace,
                         column_defaults,
                         temporary,
-                    }) => rewrite_arc(input, f)?.update_data(|input| {
+                    }) => input.map_elements(f)?.update_data(|input| {
                         DdlStatement::CreateMemoryTable(CreateMemoryTable {
                             name,
                             constraints,
@@ -288,7 +284,7 @@ impl TreeNode for LogicalPlan {
                         or_replace,
                         definition,
                         temporary,
-                    }) => rewrite_arc(input, f)?.update_data(|input| {
+                    }) => input.map_elements(f)?.update_data(|input| {
                         DdlStatement::CreateView(CreateView {
                             name,
                             input,
@@ -318,7 +314,7 @@ impl TreeNode for LogicalPlan {
                 dependency_indices,
                 schema,
                 options,
-            }) => rewrite_arc(input, f)?.update_data(|input| {
+            }) => input.map_elements(f)?.update_data(|input| {
                 LogicalPlan::Unnest(Unnest {
                     input,
                     exec_columns: input_columns,
@@ -334,22 +330,24 @@ impl TreeNode for LogicalPlan {
                 static_term,
                 recursive_term,
                 is_distinct,
-            }) => map_until_stop_and_collect!(
-                rewrite_arc(static_term, &mut f),
-                recursive_term,
-                rewrite_arc(recursive_term, &mut f)
-            )?
-            .update_data(|(static_term, recursive_term)| {
-                LogicalPlan::RecursiveQuery(RecursiveQuery {
-                    name,
-                    static_term,
-                    recursive_term,
-                    is_distinct,
-                })
-            }),
-            LogicalPlan::Statement(stmt) => {
-                stmt.map_inputs(f)?.update_data(LogicalPlan::Statement)
+            }) => (static_term, recursive_term).map_elements(f)?.update_data(
+                |(static_term, recursive_term)| {
+                    LogicalPlan::RecursiveQuery(RecursiveQuery {
+                        name,
+                        static_term,
+                        recursive_term,
+                        is_distinct,
+                    })
+                },
+            ),
+            LogicalPlan::Statement(stmt) => match stmt {
+                Statement::Prepare(p) => p
+                    .input
+                    .map_elements(f)?
+                    .update_data(|input| Statement::Prepare(Prepare { input, 
..p })),
+                _ => Transformed::no(stmt),
             }
+            .update_data(LogicalPlan::Statement),
             // plans without inputs
             LogicalPlan::TableScan { .. }
             | LogicalPlan::EmptyRelation { .. }
@@ -359,24 +357,6 @@ impl TreeNode for LogicalPlan {
     }
 }
 
-/// Applies `f` to rewrite a `Arc<LogicalPlan>` without copying, if possible
-pub(super) fn rewrite_arc<F: FnMut(LogicalPlan) -> 
Result<Transformed<LogicalPlan>>>(
-    plan: Arc<LogicalPlan>,
-    mut f: F,
-) -> Result<Transformed<Arc<LogicalPlan>>> {
-    f(Arc::unwrap_or_clone(plan))?.map_data(|new_plan| Ok(Arc::new(new_plan)))
-}
-
-/// rewrite a `Vec` of `Arc<LogicalPlan>` without copying, if possible
-fn rewrite_arcs<F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>>(
-    input_plans: Vec<Arc<LogicalPlan>>,
-    mut f: F,
-) -> Result<Transformed<Vec<Arc<LogicalPlan>>>> {
-    input_plans
-        .into_iter()
-        .map_until_stop_and_collect(|plan| rewrite_arc(plan, &mut f))
-}
-
 /// Rewrites all inputs for an Extension node "in place"
 /// (it currently has to copy values because there are no APIs for in place 
modification)
 ///
@@ -423,54 +403,40 @@ impl LogicalPlan {
         mut f: F,
     ) -> Result<TreeNodeRecursion> {
         match self {
-            LogicalPlan::Projection(Projection { expr, .. }) => {
-                expr.iter().apply_until_stop(f)
-            }
-            LogicalPlan::Values(Values { values, .. }) => values
-                .iter()
-                .apply_until_stop(|value| value.iter().apply_until_stop(&mut 
f)),
+            LogicalPlan::Projection(Projection { expr, .. }) => 
expr.apply_elements(f),
+            LogicalPlan::Values(Values { values, .. }) => 
values.apply_elements(f),
             LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate),
             LogicalPlan::Repartition(Repartition {
                 partitioning_scheme,
                 ..
             }) => match partitioning_scheme {
                 Partitioning::Hash(expr, _) | Partitioning::DistributeBy(expr) 
=> {
-                    expr.iter().apply_until_stop(f)
+                    expr.apply_elements(f)
                 }
                 Partitioning::RoundRobinBatch(_) => 
Ok(TreeNodeRecursion::Continue),
             },
             LogicalPlan::Window(Window { window_expr, .. }) => {
-                window_expr.iter().apply_until_stop(f)
+                window_expr.apply_elements(f)
             }
             LogicalPlan::Aggregate(Aggregate {
                 group_expr,
                 aggr_expr,
                 ..
-            }) => group_expr
-                .iter()
-                .chain(aggr_expr.iter())
-                .apply_until_stop(f),
+            }) => (group_expr, aggr_expr).apply_ref_elements(f),
             // There are two part of expression for join, equijoin(on) and 
non-equijoin(filter).
             // 1. the first part is `on.len()` equijoin expressions, and the 
struct of each expr is `left-on = right-on`.
             // 2. the second part is non-equijoin(filter).
             LogicalPlan::Join(Join { on, filter, .. }) => {
-                on.iter()
-                    // TODO: why we need to create an `Expr::eq`? Cloning 
`Expr` is costly...
-                    // it not ideal to create an expr here to analyze them, 
but could cache it on the Join itself
-                    .map(|(l, r)| Expr::eq(l.clone(), r.clone()))
-                    .apply_until_stop(|e| f(&e))?
-                    .visit_sibling(|| filter.iter().apply_until_stop(f))
-            }
-            LogicalPlan::Sort(Sort { expr, .. }) => {
-                expr.iter().apply_until_stop(|sort| f(&sort.expr))
+                (on, filter).apply_ref_elements(f)
             }
+            LogicalPlan::Sort(Sort { expr, .. }) => expr.apply_elements(f),
             LogicalPlan::Extension(extension) => {
                 // would be nice to avoid this copy -- maybe can
                 // update extension to just observer Exprs
-                extension.node.expressions().iter().apply_until_stop(f)
+                extension.node.expressions().apply_elements(f)
             }
             LogicalPlan::TableScan(TableScan { filters, .. }) => {
-                filters.iter().apply_until_stop(f)
+                filters.apply_elements(f)
             }
             LogicalPlan::Unnest(unnest) => {
                 let columns = unnest.exec_columns.clone();
@@ -479,24 +445,23 @@ impl LogicalPlan {
                     .iter()
                     .map(|c| Expr::Column(c.clone()))
                     .collect::<Vec<_>>();
-                exprs.iter().apply_until_stop(f)
+                exprs.apply_elements(f)
             }
             LogicalPlan::Distinct(Distinct::On(DistinctOn {
                 on_expr,
                 select_expr,
                 sort_expr,
                 ..
-            })) => on_expr
-                .iter()
-                .chain(select_expr.iter())
-                .chain(sort_expr.iter().flatten().map(|sort| &sort.expr))
-                .apply_until_stop(f),
-            LogicalPlan::Limit(Limit { skip, fetch, .. }) => skip
-                .iter()
-                .chain(fetch.iter())
-                .map(|e| e.deref())
-                .apply_until_stop(f),
-            LogicalPlan::Statement(stmt) => 
stmt.expression_iter().apply_until_stop(f),
+            })) => (on_expr, select_expr, sort_expr).apply_ref_elements(f),
+            LogicalPlan::Limit(Limit { skip, fetch, .. }) => {
+                (skip, fetch).apply_ref_elements(f)
+            }
+            LogicalPlan::Statement(stmt) => match stmt {
+                Statement::Execute(Execute { parameters, .. }) => {
+                    parameters.apply_elements(f)
+                }
+                _ => Ok(TreeNodeRecursion::Continue),
+            },
             // plans without expressions
             LogicalPlan::EmptyRelation(_)
             | LogicalPlan::RecursiveQuery(_)
@@ -529,21 +494,15 @@ impl LogicalPlan {
                 expr,
                 input,
                 schema,
-            }) => expr
-                .into_iter()
-                .map_until_stop_and_collect(f)?
-                .update_data(|expr| {
-                    LogicalPlan::Projection(Projection {
-                        expr,
-                        input,
-                        schema,
-                    })
-                }),
+            }) => expr.map_elements(f)?.update_data(|expr| {
+                LogicalPlan::Projection(Projection {
+                    expr,
+                    input,
+                    schema,
+                })
+            }),
             LogicalPlan::Values(Values { schema, values }) => values
-                .into_iter()
-                .map_until_stop_and_collect(|value| {
-                    value.into_iter().map_until_stop_and_collect(&mut f)
-                })?
+                .map_elements(f)?
                 .update_data(|values| LogicalPlan::Values(Values { schema, 
values })),
             LogicalPlan::Filter(Filter {
                 predicate,
@@ -561,12 +520,10 @@ impl LogicalPlan {
                 partitioning_scheme,
             }) => match partitioning_scheme {
                 Partitioning::Hash(expr, usize) => expr
-                    .into_iter()
-                    .map_until_stop_and_collect(f)?
+                    .map_elements(f)?
                     .update_data(|expr| Partitioning::Hash(expr, usize)),
                 Partitioning::DistributeBy(expr) => expr
-                    .into_iter()
-                    .map_until_stop_and_collect(f)?
+                    .map_elements(f)?
                     .update_data(Partitioning::DistributeBy),
                 Partitioning::RoundRobinBatch(_) => 
Transformed::no(partitioning_scheme),
             }
@@ -580,34 +537,28 @@ impl LogicalPlan {
                 input,
                 window_expr,
                 schema,
-            }) => window_expr
-                .into_iter()
-                .map_until_stop_and_collect(f)?
-                .update_data(|window_expr| {
-                    LogicalPlan::Window(Window {
-                        input,
-                        window_expr,
-                        schema,
-                    })
-                }),
+            }) => window_expr.map_elements(f)?.update_data(|window_expr| {
+                LogicalPlan::Window(Window {
+                    input,
+                    window_expr,
+                    schema,
+                })
+            }),
             LogicalPlan::Aggregate(Aggregate {
                 input,
                 group_expr,
                 aggr_expr,
                 schema,
-            }) => map_until_stop_and_collect!(
-                group_expr.into_iter().map_until_stop_and_collect(&mut f),
-                aggr_expr,
-                aggr_expr.into_iter().map_until_stop_and_collect(&mut f)
-            )?
-            .update_data(|(group_expr, aggr_expr)| {
-                LogicalPlan::Aggregate(Aggregate {
-                    input,
-                    group_expr,
-                    aggr_expr,
-                    schema,
-                })
-            }),
+            }) => (group_expr, aggr_expr).map_elements(f)?.update_data(
+                |(group_expr, aggr_expr)| {
+                    LogicalPlan::Aggregate(Aggregate {
+                        input,
+                        group_expr,
+                        aggr_expr,
+                        schema,
+                    })
+                },
+            ),
 
             // There are two part of expression for join, equijoin(on) and 
non-equijoin(filter).
             // 1. the first part is `on.len()` equijoin expressions, and the 
struct of each expr is `left-on = right-on`.
@@ -621,16 +572,7 @@ impl LogicalPlan {
                 join_constraint,
                 schema,
                 null_equals_null,
-            }) => map_until_stop_and_collect!(
-                on.into_iter().map_until_stop_and_collect(
-                    |on| map_until_stop_and_collect!(f(on.0), on.1, f(on.1))
-                ),
-                filter,
-                filter.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), 
|e| {
-                    Ok(f(e)?.update_data(Some))
-                })
-            )?
-            .update_data(|(on, filter)| {
+            }) => (on, filter).map_elements(f)?.update_data(|(on, filter)| {
                 LogicalPlan::Join(Join {
                     left,
                     right,
@@ -642,17 +584,13 @@ impl LogicalPlan {
                     null_equals_null,
                 })
             }),
-            LogicalPlan::Sort(Sort { expr, input, fetch }) => {
-                transform_sort_vec(expr, &mut f)?
-                    .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, 
fetch }))
-            }
+            LogicalPlan::Sort(Sort { expr, input, fetch }) => expr
+                .map_elements(f)?
+                .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, 
fetch })),
             LogicalPlan::Extension(Extension { node }) => {
                 // would be nice to avoid this copy -- maybe can
                 // update extension to just observer Exprs
-                let exprs = node
-                    .expressions()
-                    .into_iter()
-                    .map_until_stop_and_collect(f)?;
+                let exprs = node.expressions().map_elements(f)?;
                 let plan = LogicalPlan::Extension(Extension {
                     node: UserDefinedLogicalNode::with_exprs_and_inputs(
                         node.as_ref(),
@@ -669,64 +607,47 @@ impl LogicalPlan {
                 projected_schema,
                 filters,
                 fetch,
-            }) => filters
-                .into_iter()
-                .map_until_stop_and_collect(f)?
-                .update_data(|filters| {
-                    LogicalPlan::TableScan(TableScan {
-                        table_name,
-                        source,
-                        projection,
-                        projected_schema,
-                        filters,
-                        fetch,
-                    })
-                }),
+            }) => filters.map_elements(f)?.update_data(|filters| {
+                LogicalPlan::TableScan(TableScan {
+                    table_name,
+                    source,
+                    projection,
+                    projected_schema,
+                    filters,
+                    fetch,
+                })
+            }),
             LogicalPlan::Distinct(Distinct::On(DistinctOn {
                 on_expr,
                 select_expr,
                 sort_expr,
                 input,
                 schema,
-            })) => map_until_stop_and_collect!(
-                on_expr.into_iter().map_until_stop_and_collect(&mut f),
-                select_expr,
-                select_expr.into_iter().map_until_stop_and_collect(&mut f),
-                sort_expr,
-                transform_sort_option_vec(sort_expr, &mut f)
-            )?
-            .update_data(|(on_expr, select_expr, sort_expr)| {
-                LogicalPlan::Distinct(Distinct::On(DistinctOn {
-                    on_expr,
-                    select_expr,
-                    sort_expr,
-                    input,
-                    schema,
-                }))
-            }),
-            LogicalPlan::Limit(Limit { skip, fetch, input }) => {
-                let skip = skip.map(|e| *e);
-                let fetch = fetch.map(|e| *e);
-                map_until_stop_and_collect!(
-                    skip.map_or(Ok::<_, 
DataFusionError>(Transformed::no(None)), |e| {
-                        Ok(f(e)?.update_data(Some))
-                    }),
-                    fetch,
-                    fetch.map_or(Ok::<_, 
DataFusionError>(Transformed::no(None)), |e| {
-                        Ok(f(e)?.update_data(Some))
-                    })
-                )?
-                .update_data(|(skip, fetch)| {
-                    LogicalPlan::Limit(Limit {
-                        skip: skip.map(Box::new),
-                        fetch: fetch.map(Box::new),
+            })) => (on_expr, select_expr, sort_expr)
+                .map_elements(f)?
+                .update_data(|(on_expr, select_expr, sort_expr)| {
+                    LogicalPlan::Distinct(Distinct::On(DistinctOn {
+                        on_expr,
+                        select_expr,
+                        sort_expr,
                         input,
-                    })
+                        schema,
+                    }))
+                }),
+            LogicalPlan::Limit(Limit { skip, fetch, input }) => {
+                (skip, fetch).map_elements(f)?.update_data(|(skip, fetch)| {
+                    LogicalPlan::Limit(Limit { skip, fetch, input })
                 })
             }
-            LogicalPlan::Statement(stmt) => {
-                stmt.map_expressions(f)?.update_data(LogicalPlan::Statement)
+            LogicalPlan::Statement(stmt) => match stmt {
+                Statement::Execute(e) => {
+                    e.parameters.map_elements(f)?.update_data(|parameters| {
+                        Statement::Execute(Execute { parameters, ..e })
+                    })
+                }
+                _ => Transformed::no(stmt),
             }
+            .update_data(LogicalPlan::Statement),
             // plans without expressions
             LogicalPlan::EmptyRelation(_)
             | LogicalPlan::Unnest(_)
diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs
index e964091aae..eacace5ed0 100644
--- a/datafusion/expr/src/tree_node.rs
+++ b/datafusion/expr/src/tree_node.rs
@@ -19,14 +19,14 @@
 
 use crate::expr::{
     AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, 
InList,
-    InSubquery, Like, Placeholder, ScalarFunction, Sort, TryCast, Unnest, 
WindowFunction,
+    InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, 
WindowFunction,
 };
 use crate::{Expr, ExprFunctionExt};
 
 use datafusion_common::tree_node::{
-    Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion,
+    Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, 
TreeNodeRefContainer,
 };
-use datafusion_common::{map_until_stop_and_collect, Result};
+use datafusion_common::Result;
 
 /// Implementation of the [`TreeNode`] trait
 ///
@@ -42,9 +42,9 @@ impl TreeNode for Expr {
         &'n self,
         f: F,
     ) -> Result<TreeNodeRecursion> {
-        let children = match self {
-            Expr::Alias(Alias{expr,..})
-            | Expr::Unnest(Unnest{expr})
+        match self {
+            Expr::Alias(Alias { expr, .. })
+            | Expr::Unnest(Unnest { expr })
             | Expr::Not(expr)
             | Expr::IsNotNull(expr)
             | Expr::IsTrue(expr)
@@ -57,78 +57,50 @@ impl TreeNode for Expr {
             | Expr::Negative(expr)
             | Expr::Cast(Cast { expr, .. })
             | Expr::TryCast(TryCast { expr, .. })
-            | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref()],
+            | Expr::InSubquery(InSubquery { expr, .. }) => 
expr.apply_elements(f),
             Expr::GroupingSet(GroupingSet::Rollup(exprs))
-            | Expr::GroupingSet(GroupingSet::Cube(exprs)) => 
exprs.iter().collect(),
-            Expr::ScalarFunction (ScalarFunction{ args, .. } )  => {
-                args.iter().collect()
+            | Expr::GroupingSet(GroupingSet::Cube(exprs)) => 
exprs.apply_elements(f),
+            Expr::ScalarFunction(ScalarFunction { args, .. }) => {
+                args.apply_elements(f)
             }
             Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
-                lists_of_exprs.iter().flatten().collect()
+                lists_of_exprs.apply_elements(f)
             }
             Expr::Column(_)
             // Treat OuterReferenceColumn as a leaf expression
             | Expr::OuterReferenceColumn(_, _)
             | Expr::ScalarVariable(_, _)
             | Expr::Literal(_)
-            | Expr::Exists {..}
+            | Expr::Exists { .. }
             | Expr::ScalarSubquery(_)
-            | Expr::Wildcard {..}
-            | Expr::Placeholder (_) => vec![],
+            | Expr::Wildcard { .. }
+            | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue),
             Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
-                vec![left.as_ref(), right.as_ref()]
+                (left, right).apply_ref_elements(f)
             }
             Expr::Like(Like { expr, pattern, .. })
             | Expr::SimilarTo(Like { expr, pattern, .. }) => {
-                vec![expr.as_ref(), pattern.as_ref()]
+                (expr, pattern).apply_ref_elements(f)
             }
             Expr::Between(Between {
-                expr, low, high, ..
-            }) => vec![expr.as_ref(), low.as_ref(), high.as_ref()],
-            Expr::Case(case) => {
-                let mut expr_vec = vec![];
-                if let Some(expr) = case.expr.as_ref() {
-                    expr_vec.push(expr.as_ref());
-                };
-                for (when, then) in case.when_then_expr.iter() {
-                    expr_vec.push(when.as_ref());
-                    expr_vec.push(then.as_ref());
-                }
-                if let Some(else_expr) = case.else_expr.as_ref() {
-                    expr_vec.push(else_expr.as_ref());
-                }
-                expr_vec
-            }
-            Expr::AggregateFunction(AggregateFunction { args, filter, 
order_by, .. })
-             => {
-                let mut expr_vec = args.iter().collect::<Vec<_>>();
-                if let Some(f) = filter {
-                    expr_vec.push(f.as_ref());
-                }
-                if let Some(order_by) = order_by {
-                    expr_vec.extend(order_by.iter().map(|sort| &sort.expr));
-                }
-                expr_vec
-            }
+                              expr, low, high, ..
+                          }) => (expr, low, high).apply_ref_elements(f),
+            Expr::Case(Case { expr, when_then_expr, else_expr }) =>
+                (expr, when_then_expr, else_expr).apply_ref_elements(f),
+            Expr::AggregateFunction(AggregateFunction { args, filter, 
order_by, .. }) =>
+                (args, filter, order_by).apply_ref_elements(f),
             Expr::WindowFunction(WindowFunction {
-                args,
-                partition_by,
-                order_by,
-                ..
-            }) => {
-                let mut expr_vec = args.iter().collect::<Vec<_>>();
-                expr_vec.extend(partition_by);
-                expr_vec.extend(order_by.iter().map(|sort| &sort.expr));
-                expr_vec
+                                     args,
+                                     partition_by,
+                                     order_by,
+                                     ..
+                                 }) => {
+                (args, partition_by, order_by).apply_ref_elements(f)
             }
             Expr::InList(InList { expr, list, .. }) => {
-                let mut expr_vec = vec![expr.as_ref()];
-                expr_vec.extend(list);
-                expr_vec
+                (expr, list).apply_ref_elements(f)
             }
-        };
-
-        children.into_iter().apply_until_stop(f)
+        }
     }
 
     /// Maps each child of `self` using the provided closure `f`.
@@ -148,137 +120,103 @@ impl TreeNode for Expr {
             | Expr::ScalarSubquery(_)
             | Expr::ScalarVariable(_, _)
             | Expr::Literal(_) => Transformed::no(self),
-            Expr::Unnest(Unnest { expr, .. }) => transform_box(expr, &mut f)?
-                .update_data(|be| Expr::Unnest(Unnest::new_boxed(be))),
+            Expr::Unnest(Unnest { expr, .. }) => expr
+                .map_elements(f)?
+                .update_data(|expr| Expr::Unnest(Unnest { expr })),
             Expr::Alias(Alias {
                 expr,
                 relation,
                 name,
-            }) => f(*expr)?.update_data(|e| Expr::Alias(Alias::new(e, 
relation, name))),
+            }) => f(*expr)?.update_data(|e| e.alias_qualified(relation, name)),
             Expr::InSubquery(InSubquery {
                 expr,
                 subquery,
                 negated,
-            }) => transform_box(expr, &mut f)?.update_data(|be| {
+            }) => expr.map_elements(f)?.update_data(|be| {
                 Expr::InSubquery(InSubquery::new(be, subquery, negated))
             }),
-            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
-                map_until_stop_and_collect!(
-                    transform_box(left, &mut f),
-                    right,
-                    transform_box(right, &mut f)
-                )?
+            Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, right)
+                .map_elements(f)?
                 .update_data(|(new_left, new_right)| {
                     Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
-                })
-            }
+                }),
             Expr::Like(Like {
                 negated,
                 expr,
                 pattern,
                 escape_char,
                 case_insensitive,
-            }) => map_until_stop_and_collect!(
-                transform_box(expr, &mut f),
-                pattern,
-                transform_box(pattern, &mut f)
-            )?
-            .update_data(|(new_expr, new_pattern)| {
-                Expr::Like(Like::new(
-                    negated,
-                    new_expr,
-                    new_pattern,
-                    escape_char,
-                    case_insensitive,
-                ))
-            }),
+            }) => {
+                (expr, pattern)
+                    .map_elements(f)?
+                    .update_data(|(new_expr, new_pattern)| {
+                        Expr::Like(Like::new(
+                            negated,
+                            new_expr,
+                            new_pattern,
+                            escape_char,
+                            case_insensitive,
+                        ))
+                    })
+            }
             Expr::SimilarTo(Like {
                 negated,
                 expr,
                 pattern,
                 escape_char,
                 case_insensitive,
-            }) => map_until_stop_and_collect!(
-                transform_box(expr, &mut f),
-                pattern,
-                transform_box(pattern, &mut f)
-            )?
-            .update_data(|(new_expr, new_pattern)| {
-                Expr::SimilarTo(Like::new(
-                    negated,
-                    new_expr,
-                    new_pattern,
-                    escape_char,
-                    case_insensitive,
-                ))
-            }),
-            Expr::Not(expr) => transform_box(expr, &mut 
f)?.update_data(Expr::Not),
-            Expr::IsNotNull(expr) => {
-                transform_box(expr, &mut f)?.update_data(Expr::IsNotNull)
-            }
-            Expr::IsNull(expr) => transform_box(expr, &mut 
f)?.update_data(Expr::IsNull),
-            Expr::IsTrue(expr) => transform_box(expr, &mut 
f)?.update_data(Expr::IsTrue),
-            Expr::IsFalse(expr) => {
-                transform_box(expr, &mut f)?.update_data(Expr::IsFalse)
-            }
-            Expr::IsUnknown(expr) => {
-                transform_box(expr, &mut f)?.update_data(Expr::IsUnknown)
-            }
-            Expr::IsNotTrue(expr) => {
-                transform_box(expr, &mut f)?.update_data(Expr::IsNotTrue)
-            }
-            Expr::IsNotFalse(expr) => {
-                transform_box(expr, &mut f)?.update_data(Expr::IsNotFalse)
+            }) => {
+                (expr, pattern)
+                    .map_elements(f)?
+                    .update_data(|(new_expr, new_pattern)| {
+                        Expr::SimilarTo(Like::new(
+                            negated,
+                            new_expr,
+                            new_pattern,
+                            escape_char,
+                            case_insensitive,
+                        ))
+                    })
             }
+            Expr::Not(expr) => expr.map_elements(f)?.update_data(Expr::Not),
+            Expr::IsNotNull(expr) => 
expr.map_elements(f)?.update_data(Expr::IsNotNull),
+            Expr::IsNull(expr) => 
expr.map_elements(f)?.update_data(Expr::IsNull),
+            Expr::IsTrue(expr) => 
expr.map_elements(f)?.update_data(Expr::IsTrue),
+            Expr::IsFalse(expr) => 
expr.map_elements(f)?.update_data(Expr::IsFalse),
+            Expr::IsUnknown(expr) => 
expr.map_elements(f)?.update_data(Expr::IsUnknown),
+            Expr::IsNotTrue(expr) => 
expr.map_elements(f)?.update_data(Expr::IsNotTrue),
+            Expr::IsNotFalse(expr) => 
expr.map_elements(f)?.update_data(Expr::IsNotFalse),
             Expr::IsNotUnknown(expr) => {
-                transform_box(expr, &mut f)?.update_data(Expr::IsNotUnknown)
-            }
-            Expr::Negative(expr) => {
-                transform_box(expr, &mut f)?.update_data(Expr::Negative)
+                expr.map_elements(f)?.update_data(Expr::IsNotUnknown)
             }
+            Expr::Negative(expr) => 
expr.map_elements(f)?.update_data(Expr::Negative),
             Expr::Between(Between {
                 expr,
                 negated,
                 low,
                 high,
-            }) => map_until_stop_and_collect!(
-                transform_box(expr, &mut f),
-                low,
-                transform_box(low, &mut f),
-                high,
-                transform_box(high, &mut f)
-            )?
-            .update_data(|(new_expr, new_low, new_high)| {
-                Expr::Between(Between::new(new_expr, negated, new_low, 
new_high))
-            }),
+            }) => (expr, low, high).map_elements(f)?.update_data(
+                |(new_expr, new_low, new_high)| {
+                    Expr::Between(Between::new(new_expr, negated, new_low, 
new_high))
+                },
+            ),
             Expr::Case(Case {
                 expr,
                 when_then_expr,
                 else_expr,
-            }) => map_until_stop_and_collect!(
-                transform_option_box(expr, &mut f),
-                when_then_expr,
-                when_then_expr
-                    .into_iter()
-                    .map_until_stop_and_collect(|(when, then)| {
-                        map_until_stop_and_collect!(
-                            transform_box(when, &mut f),
-                            then,
-                            transform_box(then, &mut f)
-                        )
-                    }),
-                else_expr,
-                transform_option_box(else_expr, &mut f)
-            )?
-            .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
-                Expr::Case(Case::new(new_expr, new_when_then_expr, 
new_else_expr))
-            }),
-            Expr::Cast(Cast { expr, data_type }) => transform_box(expr, &mut 
f)?
+            }) => (expr, when_then_expr, else_expr)
+                .map_elements(f)?
+                .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
+                    Expr::Case(Case::new(new_expr, new_when_then_expr, 
new_else_expr))
+                }),
+            Expr::Cast(Cast { expr, data_type }) => expr
+                .map_elements(f)?
                 .update_data(|be| Expr::Cast(Cast::new(be, data_type))),
-            Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, 
&mut f)?
+            Expr::TryCast(TryCast { expr, data_type }) => expr
+                .map_elements(f)?
                 .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))),
             Expr::ScalarFunction(ScalarFunction { func, args }) => {
-                transform_vec(args, &mut f)?.map_data(|new_args| {
+                args.map_elements(f)?.map_data(|new_args| {
                     Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
                         func, new_args,
                     )))
@@ -291,22 +229,17 @@ impl TreeNode for Expr {
                 order_by,
                 window_frame,
                 null_treatment,
-            }) => map_until_stop_and_collect!(
-                transform_vec(args, &mut f),
-                partition_by,
-                transform_vec(partition_by, &mut f),
-                order_by,
-                transform_sort_vec(order_by, &mut f)
-            )?
-            .update_data(|(new_args, new_partition_by, new_order_by)| {
-                Expr::WindowFunction(WindowFunction::new(fun, new_args))
-                    .partition_by(new_partition_by)
-                    .order_by(new_order_by)
-                    .window_frame(window_frame)
-                    .null_treatment(null_treatment)
-                    .build()
-                    .unwrap()
-            }),
+            }) => (args, partition_by, order_by).map_elements(f)?.update_data(
+                |(new_args, new_partition_by, new_order_by)| {
+                    Expr::WindowFunction(WindowFunction::new(fun, new_args))
+                        .partition_by(new_partition_by)
+                        .order_by(new_order_by)
+                        .window_frame(window_frame)
+                        .null_treatment(null_treatment)
+                        .build()
+                        .unwrap()
+                },
+            ),
             Expr::AggregateFunction(AggregateFunction {
                 args,
                 func,
@@ -314,31 +247,27 @@ impl TreeNode for Expr {
                 filter,
                 order_by,
                 null_treatment,
-            }) => map_until_stop_and_collect!(
-                transform_vec(args, &mut f),
-                filter,
-                transform_option_box(filter, &mut f),
-                order_by,
-                transform_sort_option_vec(order_by, &mut f)
-            )?
-            .map_data(|(new_args, new_filter, new_order_by)| {
-                Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
-                    func,
-                    new_args,
-                    distinct,
-                    new_filter,
-                    new_order_by,
-                    null_treatment,
-                )))
-            })?,
+            }) => (args, filter, order_by).map_elements(f)?.map_data(
+                |(new_args, new_filter, new_order_by)| {
+                    Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
+                        func,
+                        new_args,
+                        distinct,
+                        new_filter,
+                        new_order_by,
+                        null_treatment,
+                    )))
+                },
+            )?,
             Expr::GroupingSet(grouping_set) => match grouping_set {
-                GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)?
+                GroupingSet::Rollup(exprs) => exprs
+                    .map_elements(f)?
                     .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))),
-                GroupingSet::Cube(exprs) => transform_vec(exprs, &mut f)?
+                GroupingSet::Cube(exprs) => exprs
+                    .map_elements(f)?
                     .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))),
                 GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs
-                    .into_iter()
-                    .map_until_stop_and_collect(|exprs| transform_vec(exprs, &mut f))?
+                    .map_elements(f)?
                     .update_data(|new_lists_of_exprs| {
                         Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs))
                     }),
@@ -347,70 +276,11 @@ impl TreeNode for Expr {
                 expr,
                 list,
                 negated,
-            }) => map_until_stop_and_collect!(
-                transform_box(expr, &mut f),
-                list,
-                transform_vec(list, &mut f)
-            )?
-            .update_data(|(new_expr, new_list)| {
-                Expr::InList(InList::new(new_expr, new_list, negated))
-            }),
+            }) => (expr, list)
+                .map_elements(f)?
+                .update_data(|(new_expr, new_list)| {
+                    Expr::InList(InList::new(new_expr, new_list, negated))
+                }),
         })
     }
 }
-
-/// Transforms a boxed expression by applying the provided closure `f`.
-fn transform_box<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
-    be: Box<Expr>,
-    f: &mut F,
-) -> Result<Transformed<Box<Expr>>> {
-    Ok(f(*be)?.update_data(Box::new))
-}
-
-/// Transforms an optional boxed expression by applying the provided closure `f`.
-fn transform_option_box<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
-    obe: Option<Box<Expr>>,
-    f: &mut F,
-) -> Result<Transformed<Option<Box<Expr>>>> {
-    obe.map_or(Ok(Transformed::no(None)), |be| {
-        Ok(transform_box(be, f)?.update_data(Some))
-    })
-}
-
-/// &mut transform a Option<`Vec` of `Expr`s>
-pub fn transform_option_vec<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
-    ove: Option<Vec<Expr>>,
-    f: &mut F,
-) -> Result<Transformed<Option<Vec<Expr>>>> {
-    ove.map_or(Ok(Transformed::no(None)), |ve| {
-        Ok(transform_vec(ve, f)?.update_data(Some))
-    })
-}
-
-/// &mut transform a `Vec` of `Expr`s
-fn transform_vec<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
-    ve: Vec<Expr>,
-    f: &mut F,
-) -> Result<Transformed<Vec<Expr>>> {
-    ve.into_iter().map_until_stop_and_collect(f)
-}
-
-/// Transforms an optional vector of sort expressions by applying the provided closure `f`.
-pub fn transform_sort_option_vec<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
-    sorts_option: Option<Vec<Sort>>,
-    f: &mut F,
-) -> Result<Transformed<Option<Vec<Sort>>>> {
-    sorts_option.map_or(Ok(Transformed::no(None)), |sorts| {
-        Ok(transform_sort_vec(sorts, f)?.update_data(Some))
-    })
-}
-
-/// Transforms an vector of sort expressions by applying the provided closure `f`.
-pub fn transform_sort_vec<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
-    sorts: Vec<Sort>,
-    f: &mut F,
-) -> Result<Transformed<Vec<Sort>>> {
-    sorts.into_iter().map_until_stop_and_collect(|s| {
-        Ok(f(s.expr)?.update_data(|e| Sort { expr: e, ..s }))
-    })
-}
diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs
index b659e477f6..1519c54dbf 100644
--- a/datafusion/optimizer/src/optimize_projections/mod.rs
+++ b/datafusion/optimizer/src/optimize_projections/mod.rs
@@ -39,7 +39,7 @@ use datafusion_expr::{
 use crate::optimize_projections::required_indices::RequiredIndicies;
 use crate::utils::NamePreserver;
 use datafusion_common::tree_node::{
-    Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion,
+    Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
 };
 
 /// Optimizer rule to prune unnecessary columns from intermediate schemas
@@ -484,7 +484,7 @@ fn merge_consecutive_projections(proj: Projection) -> Result<Transformed<Project
     // previous projection as input:
     let name_preserver = NamePreserver::new_for_projection();
     let mut original_names = vec![];
-    let new_exprs = expr.into_iter().map_until_stop_and_collect(|expr| {
+    let new_exprs = expr.map_elements(|expr| {
         original_names.push(name_preserver.save(&expr));
 
         // do not rewrite top level Aliases (rewriter will remove all aliases within exprs)
diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs
index 6b3b999ba0..68af121a41 100644
--- a/datafusion/sql/src/unparser/rewrite.rs
+++ b/datafusion/sql/src/unparser/rewrite.rs
@@ -18,11 +18,12 @@
 use std::{collections::HashSet, sync::Arc};
 
 use arrow_schema::Schema;
+use datafusion_common::tree_node::TreeNodeContainer;
 use datafusion_common::{
     tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
     Column, HashMap, Result, TableReference,
 };
-use datafusion_expr::{expr::Alias, tree_node::transform_sort_vec};
+use datafusion_expr::expr::Alias;
 use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr};
 use sqlparser::ast::Ident;
 
@@ -83,17 +84,18 @@ pub(super) fn normalize_union_schema(plan: &LogicalPlan) -> Result<LogicalPlan>
 
 /// Rewrite sort expressions that have a UNION plan as their input to remove the table reference.
 fn rewrite_sort_expr_for_union(exprs: Vec<SortExpr>) -> Result<Vec<SortExpr>> {
-    let sort_exprs = transform_sort_vec(exprs, &mut |expr| {
-        expr.transform_up(|expr| {
-            if let Expr::Column(mut col) = expr {
-                col.relation = None;
-                Ok(Transformed::yes(Expr::Column(col)))
-            } else {
-                Ok(Transformed::no(expr))
-            }
+    let sort_exprs = exprs
+        .map_elements(&mut |expr: Expr| {
+            expr.transform_up(|expr| {
+                if let Expr::Column(mut col) = expr {
+                    col.relation = None;
+                    Ok(Transformed::yes(Expr::Column(col)))
+                } else {
+                    Ok(Transformed::no(expr))
+                }
+            })
         })
-    })
-    .data()?;
+        .data()?;
 
     Ok(sort_exprs)
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to