alamb commented on code in PR #4447:
URL: https://github.com/apache/arrow-datafusion/pull/4447#discussion_r1038761919


##########
datafusion/optimizer/src/push_down_filter.rs:
##########
@@ -881,40 +880,30 @@ mod tests {
     }
 
     #[test]
-    fn filter_keep_agg() -> Result<()> {
-        let table_scan = test_table_scan()?;
-        let plan = LogicalPlanBuilder::from(table_scan)
-            .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
-            .filter(col("b").gt(lit(10i64)))?
+    fn push_agg_need_replace_expr() -> Result<()> {
+        let plan = LogicalPlanBuilder::from(test_table_scan()?)
+            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), 
col("b")])?
+            .filter(col("test.b + test.a").gt(lit(10i64)))?
             .build()?;
-        // filter of aggregate is after aggregation since they are 
non-commutative
-        let expected = "\
-            Filter: b > Int64(10)\
-            \n  Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
-            \n    TableScan: test";
+        let expected =
+            "Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), 
test.b]]\
+        \n  Filter: test.b + test.a > Int64(10)\
+        \n    TableScan: test";
         assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
-    fn filter_keep_partial_agg() -> Result<()> {
+    fn filter_keep_agg() -> Result<()> {
         let table_scan = test_table_scan()?;
-        let f1 = col("c").eq(lit(1i64)).and(col("b").gt(lit(2i64)));
-        let f2 = col("c").eq(lit(1i64)).and(col("b").gt(lit(3i64)));
-        let filter = f1.or(f2);
         let plan = LogicalPlanBuilder::from(table_scan)
             .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
-            .filter(filter)?
+            .filter(col("b").gt(lit(10i64)))?
             .build()?;
         // filter of aggregate is after aggregation since they are 
non-commutative
-        // (c =1 AND b > 2) OR (c = 1 AND b > 3)
-        // rewrite to CNF
-        // (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR 
C = 1) AND (b > 2 OR b > 3)
-
         let expected = "\
-        Filter: (test.c = Int64(1) OR b > Int64(3)) AND (b > Int64(2) OR 
test.c = Int64(1)) AND (b > Int64(2) OR b > Int64(3))\
-        \n  Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
-        \n    Filter: test.c = Int64(1) OR test.c = Int64(1)\
-        \n      TableScan: test";
+            Filter: b > Int64(10)\

Review Comment:
   I agree this new plan is correct.



##########
datafusion/optimizer/src/push_down_filter.rs:
##########
@@ -641,20 +633,27 @@ impl OptimizerRule for PushDownFilter {
                 let mut keep_predicates = vec![];
                 let mut push_predicates = vec![];
                 for expr in predicates {
-                    let columns = expr.to_columns()?;
-                    if columns.is_empty()
-                        || !columns
-                            .intersection(&used_columns)
-                            .collect::<HashSet<_>>()
-                            .is_empty()
-                    {
-                        keep_predicates.push(expr);
-                    } else {
+                    let cols = expr.to_columns()?;
+                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
                         push_predicates.push(expr);
+                    } else {
+                        keep_predicates.push(expr);
                     }
                 }
 
-                let child = match conjunction(push_predicates) {
+                // As for plan Filter: Column(a+b) > 0 -- Agg: 
groupby:[Column(a)+Column(b)]

Review Comment:
   Nice -- this is getting quite sophisticated. 



##########
datafusion/optimizer/src/push_down_filter.rs:
##########
@@ -881,40 +880,30 @@ mod tests {
     }
 
     #[test]
-    fn filter_keep_agg() -> Result<()> {
-        let table_scan = test_table_scan()?;
-        let plan = LogicalPlanBuilder::from(table_scan)
-            .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
-            .filter(col("b").gt(lit(10i64)))?
+    fn push_agg_need_replace_expr() -> Result<()> {
+        let plan = LogicalPlanBuilder::from(test_table_scan()?)
+            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), 
col("b")])?
+            .filter(col("test.b + test.a").gt(lit(10i64)))?
             .build()?;
-        // filter of aggregate is after aggregation since they are 
non-commutative
-        let expected = "\
-            Filter: b > Int64(10)\
-            \n  Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
-            \n    TableScan: test";
+        let expected =
+            "Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), 
test.b]]\
+        \n  Filter: test.b + test.a > Int64(10)\

Review Comment:
   👍  very nice



##########
datafusion/optimizer/src/push_down_filter.rs:
##########
@@ -641,20 +633,27 @@ impl OptimizerRule for PushDownFilter {
                 let mut keep_predicates = vec![];
                 let mut push_predicates = vec![];
                 for expr in predicates {
-                    let columns = expr.to_columns()?;
-                    if columns.is_empty()
-                        || !columns
-                            .intersection(&used_columns)
-                            .collect::<HashSet<_>>()
-                            .is_empty()
-                    {
-                        keep_predicates.push(expr);
-                    } else {
+                    let cols = expr.to_columns()?;
+                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
                         push_predicates.push(expr);
+                    } else {
+                        keep_predicates.push(expr);
                     }
                 }
 
-                let child = match conjunction(push_predicates) {
+                // As for plan Filter: Column(a+b) > 0 -- Agg: 
groupby:[Column(a)+Column(b)]
+                // After push, we need to replace `a+b` with 
Column(a)+Column(b)
+                // So we need create a replace_map, add {`a+b` --> 
Expr(Column(a)+Column(b))}
+                let mut replace_map = HashMap::new();
+                for expr in &agg.group_expr {
+                    replace_map.insert(expr.display_name()?, expr.clone());

Review Comment:
   Double-checked that `display_name` is the right one: 
https://docs.rs/datafusion/14.0.0/datafusion/prelude/enum.Expr.html#method.display_name
 👍 



##########
datafusion/optimizer/tests/integration-test.rs:
##########
@@ -304,6 +304,18 @@ fn join_keys_in_subquery_alias_1() {
     assert_eq!(expected, format!("{:?}", plan));
 }
 
+#[test]
+fn push_down_filter_groupby_expr_contains_alias() {
+    let sql = "SELECT * FROM (SELECT (col_int32 + col_uint32) AS c, count(*) 
FROM test GROUP BY 1) where c > 3";
+    let plan = test_sql(sql).unwrap();
+    let expected = "Projection: c, COUNT(UInt8(1))\
+    \n  Projection: test.col_int32 + test.col_uint32 AS c, COUNT(UInt8(1))\
+    \n    Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS 
Int32)]], aggr=[[COUNT(UInt8(1))]]\
+    \n      Filter: test.col_int32 + CAST(test.col_uint32 AS Int32) > Int32(3)\

Review Comment:
   👍 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to