alamb commented on code in PR #8891:
URL: https://github.com/apache/arrow-datafusion/pull/8891#discussion_r1504351732
##########
datafusion/common/src/tree_node.rs:
##########
@@ -224,100 +319,202 @@ pub trait TreeNode: Sized {
/// tree and makes it easier to add new types of tree node and
/// algorithms.
///
-/// When passed to[`TreeNode::visit`], [`TreeNodeVisitor::pre_visit`]
-/// and [`TreeNodeVisitor::post_visit`] are invoked recursively
+/// When passed to[`TreeNode::visit`], [`TreeNodeVisitor::f_down`]
+/// and [`TreeNodeVisitor::f_up`] are invoked recursively
/// on an node tree.
///
/// If an [`Err`] result is returned, recursion is stopped
/// immediately.
///
-/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no
+/// If [`TreeNodeRecursion::Stop`] is returned on a call to pre_visit, no
/// children of that tree node are visited, nor is post_visit
/// called on that tree node
///
-/// If [`VisitRecursion::Stop`] is returned on a call to post_visit, no
+/// If [`TreeNodeRecursion::Stop`] is returned on a call to post_visit, no
/// siblings of that tree node are visited, nor is post_visit
/// called on its parent tree node
///
-/// If [`VisitRecursion::Skip`] is returned on a call to pre_visit, no
+/// If [`TreeNodeRecursion::Jump`] is returned on a call to pre_visit, no
/// children of that tree node are visited.
pub trait TreeNodeVisitor: Sized {
/// The node type which is visitable.
- type N: TreeNode;
+ type Node: TreeNode;
/// Invoked before any children of `node` are visited.
- fn pre_visit(&mut self, node: &Self::N) -> Result<VisitRecursion>;
+ fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion>;
/// Invoked after all children of `node` are visited. Default
/// implementation does nothing.
- fn post_visit(&mut self, _node: &Self::N) -> Result<VisitRecursion> {
- Ok(VisitRecursion::Continue)
+ fn f_up(&mut self, _node: &Self::Node) -> Result<TreeNodeRecursion> {
+ Ok(TreeNodeRecursion::Continue)
}
}
-/// Trait for potentially recursively transform an [`TreeNode`] node
-/// tree. When passed to `TreeNode::rewrite`, `TreeNodeRewriter::mutate` is
-/// invoked recursively on all nodes of a tree.
+/// Trait for potentially recursively transform a [`TreeNode`] node tree.
pub trait TreeNodeRewriter: Sized {
/// The node type which is rewritable.
- type N: TreeNode;
+ type Node: TreeNode;
- /// Invoked before (Preorder) any children of `node` are rewritten /
- /// visited. Default implementation returns `Ok(Recursion::Continue)`
- fn pre_visit(&mut self, _node: &Self::N) -> Result<RewriteRecursion> {
- Ok(RewriteRecursion::Continue)
+ /// Invoked while traversing down the tree before any children are
rewritten /
+ /// visited.
+ /// Default implementation returns the node unmodified and continues
recursion.
+ fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
+ Ok(Transformed::no(node))
}
- /// Invoked after (Postorder) all children of `node` have been mutated and
- /// returns a potentially modified node.
- fn mutate(&mut self, node: Self::N) -> Result<Self::N>;
+ /// Invoked while traversing up the tree after all children have been
rewritten /
+ /// visited.
+ /// Default implementation returns the node unmodified.
+ fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
+ Ok(Transformed::no(node))
+ }
}
-/// Controls how the [`TreeNode`] recursion should proceed for
[`TreeNode::rewrite`].
-#[derive(Debug)]
-pub enum RewriteRecursion {
- /// Continue rewrite this node tree.
+/// Controls how [`TreeNode`] recursions should proceed.
+#[derive(Debug, PartialEq, Clone, Copy)]
+pub enum TreeNodeRecursion {
+ /// Continue recursion with the next node.
Continue,
- /// Call 'op' immediately and return.
- Mutate,
- /// Do not rewrite the children of this node.
- Stop,
- /// Keep recursive but skip apply op on this node
- Skip,
-}
-/// Controls how the [`TreeNode`] recursion should proceed for
[`TreeNode::visit`].
-#[derive(Debug)]
-pub enum VisitRecursion {
- /// Continue the visit to this node tree.
- Continue,
- /// Keep recursive but skip applying op on the children
- Skip,
- /// Stop the visit to this node tree.
+ /// In top-down traversals skip recursing into children but continue with
the next
+ /// node, which actually means pruning of the subtree.
+ /// In bottom-up traversals bypass calling bottom-up closures till the
next leaf node.
+ /// In combined traversals bypass calling bottom-up closures till the next
top-down
+ /// closure.
+ Jump,
+
+ /// Stop recursion.
Stop,
}
-pub enum Transformed<T> {
- /// The item was transformed / rewritten somehow
- Yes(T),
- /// The item was not transformed
- No(T),
+#[derive(PartialEq, Debug)]
+pub struct Transformed<T> {
+ pub data: T,
+ pub transformed: bool,
+ pub tnr: TreeNodeRecursion,
}
impl<T> Transformed<T> {
- pub fn into(self) -> T {
- match self {
- Transformed::Yes(t) => t,
- Transformed::No(t) => t,
+ pub fn new(data: T, transformed: bool, tnr: TreeNodeRecursion) -> Self {
+ Self {
+ data,
+ transformed,
+ tnr,
+ }
+ }
+
+ pub fn yes(data: T) -> Self {
+ Self {
+ data,
+ transformed: true,
+ tnr: TreeNodeRecursion::Continue,
+ }
+ }
+
+ pub fn no(data: T) -> Self {
+ Self {
+ data,
+ transformed: false,
+ tnr: TreeNodeRecursion::Continue,
}
}
- pub fn into_pair(self) -> (T, bool) {
- match self {
- Transformed::Yes(t) => (t, true),
- Transformed::No(t) => (t, false),
+ pub fn map_data<U, F: FnOnce(T) -> U>(self, f: F) -> Transformed<U> {
+ Transformed {
+ data: f(self.data),
+ transformed: self.transformed,
+ tnr: self.tnr,
}
}
+
+ pub fn flat_map_data<U, F: FnOnce(T) -> Result<U>>(
+ self,
+ f: F,
+ ) -> Result<Transformed<U>> {
+ Ok(Transformed {
+ data: f(self.data)?,
+ transformed: self.transformed,
+ tnr: self.tnr,
+ })
+ }
+
+ /// This is an important function to decide about recursion continuation
and
+ /// [`TreeNodeRecursion`] state propagation. Handling
[`TreeNodeRecursion::Continue`]
+ /// and [`TreeNodeRecursion::Stop`] is always straightforward, but
+ /// [`TreeNodeRecursion::Jump`] can behave differently when we are
traversing down or
+ /// up on a tree.
+ fn and_then<F: FnOnce(T) -> Result<Transformed<T>>>(
+ self,
+ f: F,
+ return_on_jump: Option<TreeNodeRecursion>,
+ ) -> Result<Transformed<T>> {
+ match self.tnr {
+ TreeNodeRecursion::Continue => {}
+ TreeNodeRecursion::Jump => {
+ if let Some(tnr) = return_on_jump {
+ return Ok(Transformed { tnr, ..self });
+ }
+ }
+ TreeNodeRecursion::Stop => return Ok(self),
+ };
+ let t = f(self.data)?;
+ Ok(Transformed {
+ transformed: t.transformed || self.transformed,
+ ..t
+ })
+ }
+
+ pub fn and_then_transform<F: FnOnce(T) -> Result<Transformed<T>>>(
+ self,
+ f: F,
+ ) -> Result<Transformed<T>> {
+ self.and_then(f, None)
+ }
+}
+
+pub trait TransformedIterator: Iterator {
Review Comment:
Perhaps we can document this trait as well
##########
datafusion/optimizer/src/analyzer/count_wildcard_rule.rs:
##########
@@ -43,7 +43,7 @@ impl CountWildcardRule {
impl AnalyzerRule for CountWildcardRule {
fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) ->
Result<LogicalPlan> {
- plan.transform_down(&analyze_internal)
+ plan.transform_down(&analyze_internal).map(|t| t.data)
Review Comment:
This is a super common pattern (extracting the data out of the transform
result).
What do you think about renaming the existing `transform_down` to
`transform_down_inner()`, which returns a `TransformedResult`, and adding a new
`transform_down()` convenience function? Something like:
```rust
fn transform_down(self) -> Result<Node> {
self.transform_down_inner(self).map(|t| t.data)
}
```
We could also provide a convenience function to access the data (via a trait
on `Result<..>`), like:
```rust
plan.transform_down(&analyze_internal).data()
```
Though that might just mask what is going on and be more confusing than
helpful.
##########
datafusion/common/src/tree_node.rs:
##########
@@ -88,132 +177,138 @@ pub trait TreeNode: Sized {
///
/// If an Err result is returned, recursion is stopped immediately
///
- /// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no
+ /// If [`TreeNodeRecursion::Stop`] is returned on a call to pre_visit, no
/// children of that node will be visited, nor is post_visit
/// called on that node. Details see [`TreeNodeVisitor`]
///
- /// If using the default [`TreeNodeVisitor::post_visit`] that does
+ /// If using the default [`TreeNodeVisitor::f_up`] that does
/// nothing, [`Self::apply`] should be preferred.
- fn visit<V: TreeNodeVisitor<N = Self>>(
+ fn visit<V: TreeNodeVisitor<Node = Self>>(
&self,
visitor: &mut V,
- ) -> Result<VisitRecursion> {
- handle_tree_recursion!(visitor.pre_visit(self)?);
- handle_tree_recursion!(self.apply_children(&mut |node|
node.visit(visitor))?);
- visitor.post_visit(self)
+ ) -> Result<TreeNodeRecursion> {
+ handle_visit_recursion_down!(visitor.f_down(self)?);
+ handle_visit_recursion_up!(self.apply_children(&mut |n|
n.visit(visitor))?);
+ visitor.f_up(self)
Review Comment:
I think both the behavior of the current PR and @berkaysynnada's suggestion
make sense to me, and I have no particular preference.
However, I recommend we document the behavior clearly in the doc comment,
for example:
```rust
/// If [`TreeNodeRecursion::Jump`] is returned by `f_down` then all children
/// are skipped and `f_up` is not applied
```
##########
datafusion/common/src/tree_node.rs:
##########
@@ -224,100 +319,202 @@ pub trait TreeNode: Sized {
/// tree and makes it easier to add new types of tree node and
/// algorithms.
///
-/// When passed to[`TreeNode::visit`], [`TreeNodeVisitor::pre_visit`]
-/// and [`TreeNodeVisitor::post_visit`] are invoked recursively
+/// When passed to[`TreeNode::visit`], [`TreeNodeVisitor::f_down`]
+/// and [`TreeNodeVisitor::f_up`] are invoked recursively
/// on an node tree.
///
/// If an [`Err`] result is returned, recursion is stopped
/// immediately.
///
-/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no
+/// If [`TreeNodeRecursion::Stop`] is returned on a call to pre_visit, no
/// children of that tree node are visited, nor is post_visit
/// called on that tree node
///
-/// If [`VisitRecursion::Stop`] is returned on a call to post_visit, no
+/// If [`TreeNodeRecursion::Stop`] is returned on a call to post_visit, no
/// siblings of that tree node are visited, nor is post_visit
/// called on its parent tree node
///
-/// If [`VisitRecursion::Skip`] is returned on a call to pre_visit, no
+/// If [`TreeNodeRecursion::Jump`] is returned on a call to pre_visit, no
/// children of that tree node are visited.
pub trait TreeNodeVisitor: Sized {
/// The node type which is visitable.
- type N: TreeNode;
+ type Node: TreeNode;
/// Invoked before any children of `node` are visited.
- fn pre_visit(&mut self, node: &Self::N) -> Result<VisitRecursion>;
+ fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion>;
/// Invoked after all children of `node` are visited. Default
/// implementation does nothing.
- fn post_visit(&mut self, _node: &Self::N) -> Result<VisitRecursion> {
- Ok(VisitRecursion::Continue)
+ fn f_up(&mut self, _node: &Self::Node) -> Result<TreeNodeRecursion> {
+ Ok(TreeNodeRecursion::Continue)
}
}
-/// Trait for potentially recursively transform an [`TreeNode`] node
-/// tree. When passed to `TreeNode::rewrite`, `TreeNodeRewriter::mutate` is
-/// invoked recursively on all nodes of a tree.
+/// Trait for potentially recursively transform a [`TreeNode`] node tree.
pub trait TreeNodeRewriter: Sized {
/// The node type which is rewritable.
- type N: TreeNode;
+ type Node: TreeNode;
- /// Invoked before (Preorder) any children of `node` are rewritten /
- /// visited. Default implementation returns `Ok(Recursion::Continue)`
- fn pre_visit(&mut self, _node: &Self::N) -> Result<RewriteRecursion> {
- Ok(RewriteRecursion::Continue)
+ /// Invoked while traversing down the tree before any children are
rewritten /
+ /// visited.
+ /// Default implementation returns the node unmodified and continues
recursion.
+ fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
+ Ok(Transformed::no(node))
}
- /// Invoked after (Postorder) all children of `node` have been mutated and
- /// returns a potentially modified node.
- fn mutate(&mut self, node: Self::N) -> Result<Self::N>;
+ /// Invoked while traversing up the tree after all children have been
rewritten /
+ /// visited.
+ /// Default implementation returns the node unmodified.
+ fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
+ Ok(Transformed::no(node))
+ }
}
-/// Controls how the [`TreeNode`] recursion should proceed for
[`TreeNode::rewrite`].
-#[derive(Debug)]
-pub enum RewriteRecursion {
- /// Continue rewrite this node tree.
+/// Controls how [`TreeNode`] recursions should proceed.
+#[derive(Debug, PartialEq, Clone, Copy)]
+pub enum TreeNodeRecursion {
+ /// Continue recursion with the next node.
Continue,
- /// Call 'op' immediately and return.
- Mutate,
- /// Do not rewrite the children of this node.
- Stop,
- /// Keep recursive but skip apply op on this node
- Skip,
-}
-/// Controls how the [`TreeNode`] recursion should proceed for
[`TreeNode::visit`].
-#[derive(Debug)]
-pub enum VisitRecursion {
- /// Continue the visit to this node tree.
- Continue,
- /// Keep recursive but skip applying op on the children
- Skip,
- /// Stop the visit to this node tree.
+ /// In top-down traversals skip recursing into children but continue with
the next
+ /// node, which actually means pruning of the subtree.
+ /// In bottom-up traversals bypass calling bottom-up closures till the
next leaf node.
+ /// In combined traversals bypass calling bottom-up closures till the next
top-down
+ /// closure.
+ Jump,
+
+ /// Stop recursion.
Stop,
}
-pub enum Transformed<T> {
- /// The item was transformed / rewritten somehow
- Yes(T),
- /// The item was not transformed
- No(T),
+#[derive(PartialEq, Debug)]
+pub struct Transformed<T> {
Review Comment:
Can we perhaps add some doc comments to this struct and its fields
explaining what it does (i.e. that it is the return value for
`TreeNodeRewriter::transform`)?
##########
datafusion/common/src/tree_node.rs:
##########
@@ -333,35 +530,45 @@ pub trait DynTreeNode {
&self,
arc_self: Arc<Self>,
new_children: Vec<Arc<Self>>,
- ) -> Result<Arc<Self>>;
+ ) -> Result<Transformed<Arc<Self>>>;
}
/// Blanket implementation for Arc for any tye that implements
/// [`DynTreeNode`] (such as [`Arc<dyn PhysicalExpr>`])
impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
/// Apply the closure `F` to the node's children
- fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
+ fn apply_children<F>(&self, f: &mut F) -> Result<TreeNodeRecursion>
where
- F: FnMut(&Self) -> Result<VisitRecursion>,
+ F: FnMut(&Self) -> Result<TreeNodeRecursion>,
{
+ let mut tnr = TreeNodeRecursion::Continue;
for child in self.arc_children() {
- handle_tree_recursion!(op(&child)?)
+ tnr = f(&child)?;
+ handle_visit_recursion!(tnr)
}
- Ok(VisitRecursion::Continue)
+ Ok(tnr)
}
- fn map_children<F>(self, transform: F) -> Result<Self>
+ fn map_children<F>(self, f: F) -> Result<Transformed<Self>>
where
- F: FnMut(Self) -> Result<Self>,
+ F: FnMut(Self) -> Result<Transformed<Self>>,
{
let children = self.arc_children();
if !children.is_empty() {
- let new_children =
- children.into_iter().map(transform).collect::<Result<_>>()?;
+ let t = children.into_iter().map_till_continue_and_collect(f)?;
+ // TODO: Currently `assert_eq!(t.transformed, t2.transformed)`
fails as
Review Comment:
I don't understand the end-user implications of this comment -- for example,
would this mask a bug? If so, what kind of bug would it be? 🤔
I am trying to understand if we should file a follow-on ticket to track this.
##########
datafusion/common/src/tree_node.rs:
##########
@@ -88,132 +177,138 @@ pub trait TreeNode: Sized {
///
/// If an Err result is returned, recursion is stopped immediately
///
- /// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no
+ /// If [`TreeNodeRecursion::Stop`] is returned on a call to pre_visit, no
Review Comment:
Shall we change the example above to use `f_down` and `f_up`?
##########
datafusion/expr/src/expr_rewriter/mod.rs:
##########
@@ -283,16 +289,16 @@ mod test {
}
impl TreeNodeRewriter for RecordingRewriter {
- type N = Expr;
+ type Node = Expr;
Review Comment:
I think this is a pretty good example of what type of migration users will
need to do for their rewrites
##########
datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs:
##########
@@ -147,13 +148,20 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
// simplifications can enable new constant evaluation)
// https://github.com/apache/arrow-datafusion/issues/1160
expr.rewrite(&mut const_evaluator)?
+ .data
Review Comment:
This is similarly understandable but strange -- almost all callers of
`rewrite` will want the data, not the details of the rewrite result.
I think the API would be nicer to use if we didn't have to extract the
`data` explicitly
##########
datafusion/common/src/tree_node.rs:
##########
@@ -377,32 +584,1045 @@ pub trait ConcreteTreeNode: Sized {
fn take_children(self) -> (Self, Vec<Self>);
/// Reattaches updated child nodes to the node, returning the updated node.
- fn with_new_children(self, children: Vec<Self>) -> Result<Self>;
+ fn with_new_children(self, children: Vec<Self>) ->
Result<Transformed<Self>>;
}
impl<T: ConcreteTreeNode> TreeNode for T {
/// Apply the closure `F` to the node's children
- fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
+ fn apply_children<F>(&self, f: &mut F) -> Result<TreeNodeRecursion>
where
- F: FnMut(&Self) -> Result<VisitRecursion>,
+ F: FnMut(&Self) -> Result<TreeNodeRecursion>,
{
+ let mut tnr = TreeNodeRecursion::Continue;
for child in self.children() {
- handle_tree_recursion!(op(child)?)
+ tnr = f(child)?;
+ handle_visit_recursion!(tnr)
}
- Ok(VisitRecursion::Continue)
+ Ok(tnr)
}
- fn map_children<F>(self, transform: F) -> Result<Self>
+ fn map_children<F>(self, f: F) -> Result<Transformed<Self>>
where
- F: FnMut(Self) -> Result<Self>,
+ F: FnMut(Self) -> Result<Transformed<Self>>,
{
let (new_self, children) = self.take_children();
if !children.is_empty() {
- let new_children =
- children.into_iter().map(transform).collect::<Result<_>>()?;
- new_self.with_new_children(new_children)
+ let t = children.into_iter().map_till_continue_and_collect(f)?;
+ // TODO: Currently `assert_eq!(t.transformed, t2.transformed)`
fails as
+ // `t.transformed` quality comes from if the transformation
closures fill the
+ // field correctly.
+ // Once we trust `t.transformed` we can remove the additional
check in
+ // `with_new_children()`.
+ let t2 = new_self.with_new_children(t.data)?;
+
+ // Propagate up `t.transformed` and `t.tnr` along with the node
containing
+ // transformed children.
+ Ok(Transformed::new(t2.data, t.transformed, t.tnr))
} else {
- Ok(new_self)
+ Ok(Transformed::no(new_self))
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::tree_node::{
+ Transformed, TransformedIterator, TreeNode, TreeNodeRecursion,
TreeNodeRewriter,
+ TreeNodeVisitor,
+ };
+ use crate::Result;
+ use std::fmt::Display;
+
+ #[derive(PartialEq, Debug)]
+ struct TestTreeNode<T> {
+ children: Vec<TestTreeNode<T>>,
+ data: T,
+ }
+
+ impl<T> TestTreeNode<T> {
+ fn new(children: Vec<TestTreeNode<T>>, data: T) -> Self {
+ Self { children, data }
}
}
+
+ impl<T> TreeNode for TestTreeNode<T> {
+ fn apply_children<F>(&self, f: &mut F) -> Result<TreeNodeRecursion>
+ where
+ F: FnMut(&Self) -> Result<TreeNodeRecursion>,
+ {
+ let mut tnr = TreeNodeRecursion::Continue;
+ for child in &self.children {
+ tnr = f(child)?;
+ handle_visit_recursion!(tnr);
+ }
+ Ok(tnr)
+ }
+
+ fn map_children<F>(self, f: F) -> Result<Transformed<Self>>
+ where
+ F: FnMut(Self) -> Result<Transformed<Self>>,
+ {
+ Ok(self
+ .children
+ .into_iter()
+ .map_till_continue_and_collect(f)?
+ .map_data(|new_children| Self {
+ children: new_children,
+ ..self
+ }))
+ }
+ }
+
+ // J
+ // |
+ // I
+ // |
+ // F
+ // / \
+ // E G
+ // | |
+ // C H
+ // / \
+ // B D
+ // |
+ // A
+ fn test_tree() -> TestTreeNode<String> {
+ let node_a = TestTreeNode::new(vec![], "a".to_string());
+ let node_b = TestTreeNode::new(vec![], "b".to_string());
+ let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
+ let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
+ let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
+ let node_h = TestTreeNode::new(vec![], "h".to_string());
+ let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
+ let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string());
+ let node_i = TestTreeNode::new(vec![node_f], "i".to_string());
+ TestTreeNode::new(vec![node_i], "j".to_string())
+ }
+
+ // Continue on all nodes
+
+ // Expected visits in a combined traversal
Review Comment:
😍
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]