This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 8190cb9721 Optimized push down filter #10291 (#10366)
8190cb9721 is described below

commit 8190cb97216e4f46faccbeddae57f6773587955f
Author: Dmitry Bugakov <[email protected]>
AuthorDate: Fri May 3 18:22:36 2024 +0200

    Optimized push down filter #10291 (#10366)
---
 datafusion/optimizer/src/push_down_filter.rs | 139 ++++++++++++++++-----------
 1 file changed, 81 insertions(+), 58 deletions(-)
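
For orientation before the diff: the change migrates PushDownFilter from the
older try_optimize entry point (Result<Option<LogicalPlan>>) to the owned-plan
rewrite API (Result<Transformed<LogicalPlan>>), so the optimizer passes the
plan by value and tracks changes through Transformed rather than Option. Below
is a minimal sketch of that trait shape, reusing the crate-internal imports
already present in push_down_filter.rs; NoOpRule is illustrative only and not
part of this commit.

    use datafusion_common::tree_node::Transformed;
    use datafusion_common::{internal_err, Result};
    use datafusion_expr::LogicalPlan;

    use crate::optimizer::ApplyOrder;
    use crate::{OptimizerConfig, OptimizerRule};

    #[derive(Debug, Default)]
    struct NoOpRule;

    impl OptimizerRule for NoOpRule {
        // The old entry point is stubbed out, mirroring PushDownFilter below.
        fn try_optimize(
            &self,
            _plan: &LogicalPlan,
            _config: &dyn OptimizerConfig,
        ) -> Result<Option<LogicalPlan>> {
            internal_err!("Should have called NoOpRule::rewrite")
        }

        fn name(&self) -> &str {
            "no_op_rule"
        }

        fn apply_order(&self) -> Option<ApplyOrder> {
            Some(ApplyOrder::TopDown)
        }

        // Opting in routes the optimizer to rewrite instead of try_optimize.
        fn supports_rewrite(&self) -> bool {
            true
        }

        // The plan is taken by value; Transformed records whether it changed.
        fn rewrite(
            &self,
            plan: LogicalPlan,
            _config: &dyn OptimizerConfig,
        ) -> Result<Transformed<LogicalPlan>> {
            // A real rule would build a new plan and return Transformed::yes(new_plan).
            Ok(Transformed::no(plan))
        }
    }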

diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index 8462cf86f1..2355ee604e 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -17,8 +17,7 @@
 use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
 
-use crate::optimizer::ApplyOrder;
-use crate::{OptimizerConfig, OptimizerRule};
+use itertools::Itertools;
 
 use datafusion_common::tree_node::{
     Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
@@ -29,6 +28,7 @@ use datafusion_common::{
 };
 use datafusion_expr::expr::Alias;
 use datafusion_expr::expr_rewriter::replace_col;
+use datafusion_expr::logical_plan::tree_node::unwrap_arc;
 use datafusion_expr::logical_plan::{
     CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union,
 };
@@ -38,7 +38,8 @@ use datafusion_expr::{
     ScalarFunctionDefinition, TableProviderFilterPushDown,
 };
 
-use itertools::Itertools;
+use crate::optimizer::ApplyOrder;
+use crate::{OptimizerConfig, OptimizerRule};
 
 /// Optimizer rule for pushing (moving) filter expressions down in a plan so
 /// they are applied as early as possible.
@@ -407,7 +408,7 @@ fn push_down_all_join(
     right: &LogicalPlan,
     on_filter: Vec<Expr>,
     is_inner_join: bool,
-) -> Result<LogicalPlan> {
+) -> Result<Transformed<LogicalPlan>> {
     let on_filter_empty = on_filter.is_empty();
     // Get pushable predicates from current optimizer state
     let (left_preserved, right_preserved) = lr_is_preserved(join_plan)?;
@@ -505,9 +506,10 @@ fn push_down_all_join(
     // wrap the join on the filter whose predicates must be kept
     match conjunction(keep_predicates) {
         Some(predicate) => {
-            Filter::try_new(predicate, Arc::new(plan)).map(LogicalPlan::Filter)
+            let new_filter_plan = Filter::try_new(predicate, Arc::new(plan))?;
+            Ok(Transformed::yes(LogicalPlan::Filter(new_filter_plan)))
         }
-        None => Ok(plan),
+        None => Ok(Transformed::no(plan)),
     }
 }
 
@@ -515,31 +517,32 @@ fn push_down_join(
     plan: &LogicalPlan,
     join: &Join,
     parent_predicate: Option<&Expr>,
-) -> Result<Option<LogicalPlan>> {
-    let predicates = match parent_predicate {
-        Some(parent_predicate) => split_conjunction_owned(parent_predicate.clone()),
-        None => vec![],
-    };
+) -> Result<Transformed<LogicalPlan>> {
+    // Split the parent predicate into individual conjunctive parts.
+    let predicates = parent_predicate
+        .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
 
-    // Convert JOIN ON predicate to Predicates
+    // Extract conjunctions from the JOIN's ON filter, if present.
     let on_filters = join
         .filter
         .as_ref()
-        .map(|e| split_conjunction_owned(e.clone()))
-        .unwrap_or_default();
+        .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
 
     let mut is_inner_join = false;
     let infer_predicates = if join.join_type == JoinType::Inner {
         is_inner_join = true;
+
         // Only allow both side key is column.
         let join_col_keys = join
             .on
             .iter()
-            .flat_map(|(l, r)| match (l.try_into_col(), r.try_into_col()) {
-                (Ok(l_col), Ok(r_col)) => Some((l_col, r_col)),
-                _ => None,
+            .filter_map(|(l, r)| {
+                let left_col = l.try_into_col().ok()?;
+                let right_col = r.try_into_col().ok()?;
+                Some((left_col, right_col))
             })
             .collect::<Vec<_>>();
+
         // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down
         // For inner joins, duplicate filters for joined columns so filters can be pushed down
         // to both sides. Take the following query as an example:
@@ -559,6 +562,7 @@ fn push_down_join(
             .chain(on_filters.iter())
             .filter_map(|predicate| {
                 let mut join_cols_to_replace = HashMap::new();
+
                 let columns = match predicate.to_columns() {
                     Ok(columns) => columns,
                     Err(e) => return Some(Err(e)),
@@ -596,9 +600,10 @@ fn push_down_join(
     };
 
     if on_filters.is_empty() && predicates.is_empty() && infer_predicates.is_empty() {
-        return Ok(None);
+        return Ok(Transformed::no(plan.clone()));
     }
-    Ok(Some(push_down_all_join(
+
+    match push_down_all_join(
         predicates,
         infer_predicates,
         plan,
@@ -606,10 +611,21 @@ fn push_down_join(
         &join.right,
         on_filters,
         is_inner_join,
-    )?))
+    ) {
+        Ok(plan) => Ok(Transformed::yes(plan.data)),
+        Err(e) => Err(e),
+    }
 }
 
 impl OptimizerRule for PushDownFilter {
+    fn try_optimize(
+        &self,
+        _plan: &LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> Result<Option<LogicalPlan>> {
+        internal_err!("Should have called PushDownFilter::rewrite")
+    }
+
     fn name(&self) -> &str {
         "push_down_filter"
     }
@@ -618,21 +634,24 @@ impl OptimizerRule for PushDownFilter {
         Some(ApplyOrder::TopDown)
     }
 
-    fn try_optimize(
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
+
+    fn rewrite(
         &self,
-        plan: &LogicalPlan,
+        plan: LogicalPlan,
         _config: &dyn OptimizerConfig,
-    ) -> Result<Option<LogicalPlan>> {
+    ) -> Result<Transformed<LogicalPlan>> {
         let filter = match plan {
-            LogicalPlan::Filter(filter) => filter,
-            // we also need to pushdown filter in Join.
-            LogicalPlan::Join(join) => return push_down_join(plan, join, None),
-            _ => return Ok(None),
+            LogicalPlan::Filter(ref filter) => filter,
+            LogicalPlan::Join(ref join) => return push_down_join(&plan, join, None),
+            _ => return Ok(Transformed::no(plan)),
         };
 
         let child_plan = filter.input.as_ref();
         let new_plan = match child_plan {
-            LogicalPlan::Filter(child_filter) => {
+            LogicalPlan::Filter(ref child_filter) => {
                 let parents_predicates = split_conjunction(&filter.predicate);
                 let set: HashSet<&&Expr> = parents_predicates.iter().collect();
 
@@ -652,20 +671,18 @@ impl OptimizerRule for PushDownFilter {
                     new_predicate,
                     child_filter.input.clone(),
                 )?);
-                self.try_optimize(&new_filter, _config)?
-                    .unwrap_or(new_filter)
+                self.rewrite(new_filter, _config)?.data
             }
             LogicalPlan::Repartition(_)
             | LogicalPlan::Distinct(_)
             | LogicalPlan::Sort(_) => {
-                // commutable
                 let new_filter = plan.with_new_exprs(
                     plan.expressions(),
                     vec![child_plan.inputs()[0].clone()],
                 )?;
                 child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])?
             }
-            LogicalPlan::SubqueryAlias(subquery_alias) => {
+            LogicalPlan::SubqueryAlias(ref subquery_alias) => {
                 let mut replace_map = HashMap::new();
                 for (i, (qualifier, field)) in
                     subquery_alias.input.schema().iter().enumerate()
@@ -685,7 +702,7 @@ impl OptimizerRule for PushDownFilter {
                 )?);
                 child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])?
             }
-            LogicalPlan::Projection(projection) => {
+            LogicalPlan::Projection(ref projection) => {
                 // A projection is filter-commutable if it do not contain volatile predicates or contain volatile
                 // predicates that are not used in the filter. However, we should re-writes all predicate expressions.
                 // collect projection.
@@ -742,10 +759,10 @@ impl OptimizerRule for PushDownFilter {
                             }
                         }
                     }
-                    None => return Ok(None),
+                    None => return Ok(Transformed::no(plan)),
                 }
             }
-            LogicalPlan::Union(union) => {
+            LogicalPlan::Union(ref union) => {
                 let mut inputs = Vec::with_capacity(union.inputs.len());
                 for input in &union.inputs {
                     let mut replace_map = HashMap::new();
@@ -770,7 +787,7 @@ impl OptimizerRule for PushDownFilter {
                     schema: plan.schema().clone(),
                 })
             }
-            LogicalPlan::Aggregate(agg) => {
+            LogicalPlan::Aggregate(ref agg) => {
                 // We can push down Predicate which in groupby_expr.
                 let group_expr_columns = agg
                     .group_expr
@@ -821,13 +838,15 @@ impl OptimizerRule for PushDownFilter {
                     None => new_agg,
                 }
             }
-            LogicalPlan::Join(join) => {
-                match push_down_join(&filter.input, join, Some(&filter.predicate))? {
-                    Some(optimized_plan) => optimized_plan,
-                    None => return Ok(None),
-                }
+            LogicalPlan::Join(ref join) => {
+                push_down_join(
+                    &unwrap_arc(filter.clone().input),
+                    join,
+                    Some(&filter.predicate),
+                )?
+                .data
             }
-            LogicalPlan::CrossJoin(cross_join) => {
+            LogicalPlan::CrossJoin(ref cross_join) => {
                 let predicates = split_conjunction_owned(filter.predicate.clone());
                 let join = convert_cross_join_to_inner_join(cross_join.clone())?;
                 let join_plan = LogicalPlan::Join(join);
@@ -843,9 +862,9 @@ impl OptimizerRule for PushDownFilter {
                     vec![],
                     true,
                 )?;
-                convert_to_cross_join_if_beneficial(plan)?
+                convert_to_cross_join_if_beneficial(plan.data)?
             }
-            LogicalPlan::TableScan(scan) => {
+            LogicalPlan::TableScan(ref scan) => {
                 let filter_predicates = split_conjunction(&filter.predicate);
                 let results = scan
                     .source
@@ -892,7 +911,7 @@ impl OptimizerRule for PushDownFilter {
                     None => new_scan,
                 }
             }
-            LogicalPlan::Extension(extension_plan) => {
+            LogicalPlan::Extension(ref extension_plan) => {
                 let prevent_cols =
                     extension_plan.node.prevent_predicate_push_down_columns();
 
@@ -935,9 +954,10 @@ impl OptimizerRule for PushDownFilter {
                     None => new_extension,
                 }
             }
-            _ => return Ok(None),
+            _ => return Ok(Transformed::no(plan)),
         };
-        Ok(Some(new_plan))
+
+        Ok(Transformed::yes(new_plan))
     }
 }
 
@@ -1024,16 +1044,12 @@ fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
 
 #[cfg(test)]
 mod tests {
-    use super::*;
     use std::any::Any;
     use std::fmt::{Debug, Formatter};
 
-    use crate::optimizer::Optimizer;
-    use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
-    use crate::test::*;
-    use crate::OptimizerContext;
-
     use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+    use async_trait::async_trait;
+
     use datafusion_common::ScalarValue;
     use datafusion_expr::expr::ScalarFunction;
     use datafusion_expr::logical_plan::table_scan;
@@ -1043,7 +1059,13 @@ mod tests {
         Volatility,
     };
 
-    use async_trait::async_trait;
+    use crate::optimizer::Optimizer;
+    use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
+    use crate::test::*;
+    use crate::OptimizerContext;
+
+    use super::*;
+
     fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
 
     fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
@@ -2298,9 +2320,9 @@ mod tests {
             table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
 
         let optimized_plan = PushDownFilter::new()
-            .try_optimize(&plan, &OptimizerContext::new())
+            .rewrite(plan, &OptimizerContext::new())
             .expect("failed to optimize plan")
-            .unwrap();
+            .data;
 
         let expected = "\
         Filter: a = Int64(1)\
@@ -2667,8 +2689,9 @@ Projection: a, b
         // Originally global state which can help to avoid duplicate Filters been generated and pushed down.
         // Now the global state is removed. Need to double confirm that avoid duplicate Filters.
         let optimized_plan = PushDownFilter::new()
-            .try_optimize(&plan, &OptimizerContext::new())?
-            .expect("failed to optimize plan");
+            .rewrite(plan, &OptimizerContext::new())
+            .expect("failed to optimize plan")
+            .data;
         assert_optimized_plan_eq(optimized_plan, expected)
     }
 


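A caller-side note on the same API change (the optimize_plan wrapper and the
external crate paths below are illustrative assumptions, not from this commit):
the updated tests read the rewritten plan from the Transformed wrapper's data
field, and its transformed flag records whether the rule changed anything.

    use datafusion_common::tree_node::Transformed;
    use datafusion_common::Result;
    use datafusion_expr::LogicalPlan;
    use datafusion_optimizer::push_down_filter::PushDownFilter;
    use datafusion_optimizer::{OptimizerContext, OptimizerRule};

    fn optimize_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
        // rewrite consumes the plan; the old try_optimize returned Option<LogicalPlan>.
        let result: Transformed<LogicalPlan> =
            PushDownFilter::new().rewrite(plan, &OptimizerContext::new())?;
        // transformed flags whether any filter was pushed down; data is the
        // resulting plan either way.
        let _changed: bool = result.transformed;
        Ok(result.data)
    }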
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
