This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new f386f7a73 Simplify expressions with `NOT` clause (#2958)
f386f7a73 is described below

commit f386f7a7344d54455fe04d92248e373fac990e6d
Author: AssHero <[email protected]>
AuthorDate: Mon Jul 25 20:41:03 2022 +0800

    Simplify expressions with `NOT` clause (#2958)
    
    * simplify not clause
    
    * refine the code
---
 datafusion/core/tests/sql/joins.rs               |   2 +-
 datafusion/optimizer/src/simplify_expressions.rs | 390 +++++++++++++++++++++--
 2 files changed, 373 insertions(+), 19 deletions(-)

diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index d7d7a0cc6..525a7c9e0 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -1567,7 +1567,7 @@ async fn reduce_right_join_2() -> Result<()> {
     let expected = vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
         "  Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, #t2.t2_id, #t2.t2_name, #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
-        "    Filter: NOT #t1.t1_int = #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "    Filter: #t1.t1_int != #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
         "      Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
         "        TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
         "        TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs
index 14b881cc0..0b865238f 100644
--- a/datafusion/optimizer/src/simplify_expressions.rs
+++ b/datafusion/optimizer/src/simplify_expressions.rs
@@ -24,6 +24,7 @@ use arrow::datatypes::{DataType, Field, Schema};
 use arrow::record_batch::RecordBatch;
 use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue};
 use datafusion_expr::{
+    expr_fn::{and, or},
     expr_rewriter::RewriteRecursion,
     expr_rewriter::{ExprRewritable, ExprRewriter},
     lit,
@@ -142,11 +143,6 @@ fn lit_bool_null() -> Expr {
     Expr::Literal(ScalarValue::Boolean(None))
 }
 
-/// returns true if expr is a `Not(_)`, false otherwise
-fn is_not(expr: &Expr) -> bool {
-    matches!(expr, Expr::Not(_))
-}
-
 fn is_null(expr: &Expr) -> bool {
     match expr {
         Expr::Literal(v) => v.is_null(),
@@ -185,6 +181,133 @@ fn as_bool_lit(expr: Expr) -> Option<bool> {
     }
 }
 
+/// negate a Not clause
+/// input is the clause to be negated.(args of Not clause)
+/// For BinaryExpr, use the negator of op instead.
+///    not ( A > B) ===> (A <= B)
+/// For BoolExpr, not (A and B) ===> (not A) or (not B)
+///     not (A or B) ===> (not A) and (not B)
+///     not (not A) ===> A
+/// For NullExpr, not (A is not null) ===> A is null
+///     not (A is null) ===> A is not null
+/// For InList, not (A not in (..)) ===> A in (..)
+///     not (A in (..)) ===> A not in (..)
+/// For Between, not (A between B and C) ===> (A not between B and C)
+///     not (A not between B and C) ===> (A between B and C)
+/// For others, use Not clause
+fn negate_clause(expr: Expr) -> Expr {
+    match expr {
+        Expr::BinaryExpr { left, op, right } => {
+            match op {
+                // not (A = B) ===> (A <> B)
+                Operator::Eq => Expr::BinaryExpr {
+                    left,
+                    op: Operator::NotEq,
+                    right,
+                },
+                // not (A <> B) ===> (A = B)
+                Operator::NotEq => Expr::BinaryExpr {
+                    left,
+                    op: Operator::Eq,
+                    right,
+                },
+                // not (A < B) ===> (A >= B)
+                Operator::Lt => Expr::BinaryExpr {
+                    left,
+                    op: Operator::GtEq,
+                    right,
+                },
+                // not (A <= B) ===> (A > B)
+                Operator::LtEq => Expr::BinaryExpr {
+                    left,
+                    op: Operator::Gt,
+                    right,
+                },
+                // not (A > B) ===> (A <= B)
+                Operator::Gt => Expr::BinaryExpr {
+                    left,
+                    op: Operator::LtEq,
+                    right,
+                },
+                // not (A >= B) ===> (A < B)
+                Operator::GtEq => Expr::BinaryExpr {
+                    left,
+                    op: Operator::Lt,
+                    right,
+                },
+                // not (A like 'B') ===> (A not like 'B')
+                Operator::Like => Expr::BinaryExpr {
+                    left,
+                    op: Operator::NotLike,
+                    right,
+                },
+                // not (A not like 'B') ===> (A like 'B')
+                Operator::NotLike => Expr::BinaryExpr {
+                    left,
+                    op: Operator::Like,
+                    right,
+                },
+                // not (A is distinct from B) ===> (A is not distinct from B)
+                Operator::IsDistinctFrom => Expr::BinaryExpr {
+                    left,
+                    op: Operator::IsNotDistinctFrom,
+                    right,
+                },
+                // not (A is not distinct from B) ===> (A is distinct from B)
+                Operator::IsNotDistinctFrom => Expr::BinaryExpr {
+                    left,
+                    op: Operator::IsDistinctFrom,
+                    right,
+                },
+                // not (A and B) ===> (not A) or (not B)
+                Operator::And => {
+                    let left = negate_clause(*left);
+                    let right = negate_clause(*right);
+
+                    or(left, right)
+                }
+                // not (A or B) ===> (not A) and (not B)
+                Operator::Or => {
+                    let left = negate_clause(*left);
+                    let right = negate_clause(*right);
+
+                    and(left, right)
+                }
+                // use not clause
+                _ => Expr::Not(Box::new(Expr::BinaryExpr { left, op, right })),
+            }
+        }
+        // not (not A) ===> A
+        Expr::Not(expr) => *expr,
+        // not (A is not null) ===> A is null
+        Expr::IsNotNull(expr) => expr.is_null(),
+        // not (A is null) ===> A is not null
+        Expr::IsNull(expr) => expr.is_not_null(),
+        // not (A not in (..)) ===> A in (..)
+        // not (A in (..)) ===> A not in (..)
+        Expr::InList {
+            expr,
+            list,
+            negated,
+        } => expr.in_list(list, !negated),
+        // not (A between B and C) ===> (A not between B and C)
+        // not (A not between B and C) ===> (A between B and C)
+        Expr::Between {
+            expr,
+            negated,
+            low,
+            high,
+        } => Expr::Between {
+            expr,
+            negated: !negated,
+            low,
+            high,
+        },
+        // use not clause
+        _ => Expr::Not(Box::new(expr)),
+    }
+}
+
 impl OptimizerRule for SimplifyExpressions {
     fn name(&self) -> &str {
         "simplify_expressions"
@@ -680,12 +803,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
             //
             // Rules for Not
             //
-
-            // !(!A) --> A
-            Not(inner) if is_not(&inner) => match *inner {
-                Not(negated_inner) => *negated_inner,
-                _ => unreachable!(),
-            },
+            Not(inner) => negate_clause(*inner),
 
             //
             // Rules for Case
@@ -899,13 +1017,13 @@ mod tests {
 
     #[test]
     fn test_simplify_negated_and() {
-        // (c > 5) AND !(c > 5) -- can't remove
+        // (c > 5) AND !(c > 5) -- > (c > 5) AND (c <= 5)
         let expr = binary_expr(
             col("c2").gt(lit(5)),
             Operator::And,
             Expr::not(col("c2").gt(lit(5))),
         );
-        let expected = expr.clone();
+        let expected = col("c2").gt(lit(5)).and(col("c2").lt_eq(lit(5)));
 
         assert_eq!(simplify(expr), expected);
     }
@@ -1415,13 +1533,13 @@ mod tests {
             })),
             col("c2")
                 .is_null()
-                .or(col("c2").is_null().not().and(col("c2")))
+                .or(col("c2").is_not_null().and(col("c2")))
         );
 
         // CASE WHERE c1 then true WHERE c2 then false ELSE true
         // --> c1 OR (NOT(c1) AND c2 AND FALSE) OR (NOT(c1 OR c2) AND TRUE)
-        // --> c1 OR (NOT(c1 OR c2))
-        // --> NOT(c1) AND c2
+        // --> c1 OR (NOT(c1) AND NOT(c2))
+        // --> c1 OR NOT(c2)
         //
         // Need to call simplify 2x due to
         // https://github.com/apache/arrow-datafusion/issues/1160
@@ -1434,7 +1552,7 @@ mod tests {
                 ],
                 else_expr: Some(Box::new(lit(true))),
             })),
-            col("c1").or(col("c1").or(col("c2")).not())
+            col("c1").or(col("c1").not().and(col("c2").not()))
         );
 
         // CASE WHERE c1 then true WHERE c2 then true ELSE false
@@ -1453,7 +1571,7 @@ mod tests {
                 ],
                 else_expr: Some(Box::new(lit(true))),
             })),
-            col("c1").or(col("c1").or(col("c2")).not())
+            col("c1").or(col("c1").not().and(col("c2").not()))
         );
     }
 
@@ -1964,4 +2082,240 @@ mod tests {
 
         assert_eq!(actual, expected);
     }
+
+    #[test]
+    fn simplify_not_binary() {
+        let table_scan = test_table_scan();
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(col("d").gt(lit(10)).not())
+            .unwrap()
+            .build()
+            .unwrap();
+        let expected = "Filter: #test.d <= Int32(10) AS NOT test.d > Int32(10)\
+            \n  TableScan: test";
+
+        assert_optimized_plan_eq(&plan, expected);
+    }
+
+    #[test]
+    fn simplify_not_bool_and() {
+        let table_scan = test_table_scan();
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(col("d").gt(lit(10)).and(col("d").lt(lit(100))).not())
+            .unwrap()
+            .build()
+            .unwrap();
+        let expected = "Filter: #test.d <= Int32(10) OR #test.d >= Int32(100) AS NOT test.d > Int32(10) AND test.d < Int32(100)\
+        \n  TableScan: test";
+
+        assert_optimized_plan_eq(&plan, expected);
+    }
+
+    #[test]
+    fn simplify_not_bool_or() {
+        let table_scan = test_table_scan();
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(col("d").gt(lit(10)).or(col("d").lt(lit(100))).not())
+            .unwrap()
+            .build()
+            .unwrap();
+        let expected = "Filter: #test.d <= Int32(10) AND #test.d >= Int32(100) AS NOT test.d > Int32(10) OR test.d < Int32(100)\
+        \n  TableScan: test";
+
+        assert_optimized_plan_eq(&plan, expected);
+    }
+
+    #[test]
+    fn simplify_not_not() {
+        let table_scan = test_table_scan();
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(col("d").gt(lit(10)).not().not())
+            .unwrap()
+            .build()
+            .unwrap();
+        let expected = "Filter: #test.d > Int32(10) AS NOT NOT test.d > Int32(10)\
+        \n  TableScan: test";
+
+        assert_optimized_plan_eq(&plan, expected);
+    }
+
+    #[test]
+    fn simplify_not_null() {
+        let table_scan = test_table_scan();
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(col("d").is_null().not())
+            .unwrap()
+            .build()
+            .unwrap();
+        let expected = "Filter: #test.d IS NOT NULL AS NOT test.d IS NULL\
+        \n  TableScan: test";
+
+        assert_optimized_plan_eq(&plan, expected);
+    }
+
+    #[test]
+    fn simplify_not_not_null() {
+        let table_scan = test_table_scan();
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(col("d").is_not_null().not())
+            .unwrap()
+            .build()
+            .unwrap();
+        let expected = "Filter: #test.d IS NULL AS NOT test.d IS NOT NULL\
+        \n  TableScan: test";
+
+        assert_optimized_plan_eq(&plan, expected);
+    }
+
+    #[test]
+    fn simplify_not_in() {
+        let table_scan = test_table_scan();
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], false).not())
+            .unwrap()
+            .build()
+            .unwrap();
+        let expected = "Filter: #test.d NOT IN ([Int32(1), Int32(2), Int32(3)]) AS NOT test.d IN (Map { iter: Iter([Int32(1), Int32(2), Int32(3)]) })\
+        \n  TableScan: test";
+
+        assert_optimized_plan_eq(&plan, expected);
+    }
+
+    #[test]
+    fn simplify_not_not_in() {
+        let table_scan = test_table_scan();
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], true).not())
+            .unwrap()
+            .build()
+            .unwrap();
+        let expected = "Filter: #test.d IN ([Int32(1), Int32(2), Int32(3)]) AS NOT test.d NOT IN (Map { iter: Iter([Int32(1), Int32(2), Int32(3)]) })\
+        \n  TableScan: test";
+
+        assert_optimized_plan_eq(&plan, expected);
+    }
+
+    #[test]
+    fn simplify_not_between() {
+        let table_scan = test_table_scan();
+        let qual = Expr::Between {
+            expr: Box::new(col("d")),
+            negated: false,
+            low: Box::new(lit(1)),
+            high: Box::new(lit(10)),
+        };
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(qual.not())
+            .unwrap()
+            .build()
+            .unwrap();
+        let expected = "Filter: #test.d NOT BETWEEN Int32(1) AND Int32(10) AS NOT test.d BETWEEN Int32(1) AND Int32(10)\
+        \n  TableScan: test";
+
+        assert_optimized_plan_eq(&plan, expected);
+    }
+
+    #[test]
+    fn simplify_not_not_between() {
+        let table_scan = test_table_scan();
+        let qual = Expr::Between {
+            expr: Box::new(col("d")),
+            negated: true,
+            low: Box::new(lit(1)),
+            high: Box::new(lit(10)),
+        };
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(qual.not())
+            .unwrap()
+            .build()
+            .unwrap();
+        let expected = "Filter: #test.d BETWEEN Int32(1) AND Int32(10) AS NOT test.d NOT BETWEEN Int32(1) AND Int32(10)\
+        \n  TableScan: test";
+
+        assert_optimized_plan_eq(&plan, expected);
+    }
+
+    #[test]
+    fn simplify_not_like() {
+        let schema = Schema::new(vec![
+            Field::new("a", DataType::Utf8, false),
+            Field::new("b", DataType::Utf8, false),
+        ]);
+        let table_scan = table_scan(Some("test"), &schema, None)
+            .expect("creating scan")
+            .build()
+            .expect("building plan");
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(col("a").like(col("b")).not())
+            .unwrap()
+            .build()
+            .unwrap();
+        let expected = "Filter: #test.a NOT LIKE #test.b AS NOT test.a LIKE test.b\
+        \n  TableScan: test";
+
+        assert_optimized_plan_eq(&plan, expected);
+    }
+
+    #[test]
+    fn simplify_not_not_like() {
+        let schema = Schema::new(vec![
+            Field::new("a", DataType::Utf8, false),
+            Field::new("b", DataType::Utf8, false),
+        ]);
+        let table_scan = table_scan(Some("test"), &schema, None)
+            .expect("creating scan")
+            .build()
+            .expect("building plan");
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(col("a").not_like(col("b")).not())
+            .unwrap()
+            .build()
+            .unwrap();
+        let expected = "Filter: #test.a LIKE #test.b AS NOT test.a NOT LIKE test.b\
+        \n  TableScan: test";
+
+        assert_optimized_plan_eq(&plan, expected);
+    }
+
+    #[test]
+    fn simplify_not_distinct_from() {
+        let table_scan = test_table_scan();
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(binary_expr(col("d"), Operator::IsDistinctFrom, lit(10)).not())
+            .unwrap()
+            .build()
+            .unwrap();
+        let expected = "Filter: #test.d IS NOT DISTINCT FROM Int32(10) AS NOT test.d IS DISTINCT FROM Int32(10)\
+        \n  TableScan: test";
+
+        assert_optimized_plan_eq(&plan, expected);
+    }
+
+    #[test]
+    fn simplify_not_not_distinct_from() {
+        let table_scan = test_table_scan();
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(binary_expr(col("d"), Operator::IsNotDistinctFrom, lit(10)).not())
+            .unwrap()
+            .build()
+            .unwrap();
+        let expected = "Filter: #test.d IS DISTINCT FROM Int32(10) AS NOT test.d IS NOT DISTINCT FROM Int32(10)\
+        \n  TableScan: test";
+
+        assert_optimized_plan_eq(&plan, expected);
+    }
 }

Reply via email to