This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 783c45b892 Add case expr simplifiers for literal comparisons (#17743)
783c45b892 is described below

commit 783c45b8925ac72c7264f02fadbbd5a9d1490c06
Author: Jack Kleeman <[email protected]>
AuthorDate: Fri Sep 26 19:10:33 2025 +0100

    Add case expr simplifiers for literal comparisons (#17743)
    
    * Add case expr simplifiers for literal comparisons
    
    * Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * Avoid expr clones
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 .../src/simplify_expressions/expr_simplifier.rs    | 226 ++++++++++++++++++++-
 .../optimizer/src/simplify_expressions/utils.rs    |  27 ++-
 2 files changed, 251 insertions(+), 2 deletions(-)

diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 2d10c7fe22..b491a3529f 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -1399,6 +1399,41 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
             // Rules for Case
             //
 
+            // Inline a comparison to a literal with the case statement into 
the `THEN` clauses.
+            // which can enable further simplifications
+            // CASE WHEN X THEN "a" WHEN Y THEN "b" ... END = "a" --> CASE 
WHEN X THEN "a" = "a" WHEN Y THEN "b" = "a" END
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: op @ (Eq | NotEq),
+                right,
+            }) if is_case_with_literal_outputs(&left) && is_lit(&right) => {
+                let case = into_case(*left)?;
+                Transformed::yes(Expr::Case(Case {
+                    expr: None,
+                    when_then_expr: case
+                        .when_then_expr
+                        .into_iter()
+                        .map(|(when, then)| {
+                            (
+                                when,
+                                Box::new(Expr::BinaryExpr(BinaryExpr {
+                                    left: then,
+                                    op,
+                                    right: right.clone(),
+                                })),
+                            )
+                        })
+                        .collect(),
+                    else_expr: case.else_expr.map(|els| {
+                        Box::new(Expr::BinaryExpr(BinaryExpr {
+                            left: els,
+                            op,
+                            right,
+                        }))
+                    }),
+                }))
+            }
+
             // CASE WHEN true THEN A ... END --> A
             // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X 
THEN A ELSE B END
             Expr::Case(Case {
@@ -1447,7 +1482,11 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
                 when_then_expr,
                 else_expr,
             }) if !when_then_expr.is_empty()
-                && when_then_expr.len() < 3 // The rewrite is O(n²) so limit 
to small number
+                // The rewrite is O(n²) in general so limit to small number of 
when-thens that can be true
+                && (when_then_expr.len() < 3 // small number of input whens
+                    // or all thens are literal bools and a small number of 
them are true
+                    || (when_then_expr.iter().all(|(_, then)| 
is_bool_lit(then))
+                        && when_then_expr.iter().filter(|(_, then)| 
is_true(then)).count() < 3))
                 && info.is_boolean_type(&when_then_expr[0].1)? =>
             {
                 // String disjunction of all the when predicates encountered 
so far. Not nullable.
@@ -1471,6 +1510,55 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
                 // Do a first pass at simplification
                 out_expr.rewrite(self)?
             }
+            // CASE
+            //   WHEN X THEN true
+            //   WHEN Y THEN true
+            //   WHEN Z THEN false
+            //   ...
+            //   ELSE true
+            // END
+            //
+            // --->
+            //
+            // NOT(CASE
+            //   WHEN X THEN false
+            //   WHEN Y THEN false
+            //   WHEN Z THEN true
+            //   ...
+            //   ELSE false
+            // END)
+            //
+            // Note: the rationale for this rewrite is that the case can then 
be further
+            // simplified into a small number of ANDs and ORs
+            Expr::Case(Case {
+                expr: None,
+                when_then_expr,
+                else_expr,
+            }) if !when_then_expr.is_empty()
+                && when_then_expr
+                    .iter()
+                    .all(|(_, then)| is_bool_lit(then)) // all thens are 
literal bools
+                // This simplification is only helpful if we end up with a 
small number of true thens
+                && when_then_expr
+                    .iter()
+                    .filter(|(_, then)| is_false(then))
+                    .count()
+                    < 3
+                && else_expr.as_deref().is_none_or(is_bool_lit) =>
+            {
+                Transformed::yes(
+                    Expr::Case(Case {
+                        expr: None,
+                        when_then_expr: when_then_expr
+                            .into_iter()
+                            .map(|(when, then)| (when, 
Box::new(Expr::Not(then))))
+                            .collect(),
+                        else_expr: else_expr
+                            .map(|else_expr| Box::new(Expr::Not(else_expr))),
+                    })
+                    .not(),
+                )
+            }
             Expr::ScalarFunction(ScalarFunction { func: udf, args }) => {
                 match udf.simplify(args, info)? {
                     ExprSimplifyResult::Original(args) => {
@@ -3465,6 +3553,142 @@ mod tests {
         );
     }
 
+    #[test]
+    fn simplify_literal_case_equality() {
+        // CASE WHEN c2 != false THEN "ok" ELSE "not_ok" END
+        let simple_case = Expr::Case(Case::new(
+            None,
+            vec![(
+                Box::new(col("c2_non_null").not_eq(lit(false))),
+                Box::new(lit("ok")),
+            )],
+            Some(Box::new(lit("not_ok"))),
+        ));
+
+        // CASE WHEN c2 != false THEN "ok" ELSE "not_ok" END == "ok"
+        // -->
+        // CASE WHEN c2 != false THEN "ok" == "ok" ELSE "not_ok" == "ok" END
+        // -->
+        // CASE WHEN c2 != false THEN true ELSE false END
+        // -->
+        // c2
+        assert_eq!(
+            simplify(binary_expr(simple_case.clone(), Operator::Eq, lit("ok"),)),
+            col("c2_non_null"),
+        );
+
+        // CASE WHEN c2 != false THEN "ok" ELSE "not_ok" END != "ok"
+        // -->
+        // CASE WHEN c2 != false THEN "ok" != "ok" ELSE "not_ok" != "ok" END
+        // -->
+        // CASE WHEN c2 != false THEN false ELSE true END
+        // -->
+        // NOT(c2)
+        assert_eq!(
+            simplify(binary_expr(simple_case, Operator::NotEq, lit("ok"),)),
+            not(col("c2_non_null")),
+        );
+
+        let complex_case = Expr::Case(Case::new(
+            None,
+            vec![
+                (
+                    Box::new(col("c1").eq(lit("inboxed"))),
+                    Box::new(lit("pending")),
+                ),
+                (
+                    Box::new(col("c1").eq(lit("scheduled"))),
+                    Box::new(lit("pending")),
+                ),
+                (
+                    Box::new(col("c1").eq(lit("completed"))),
+                    Box::new(lit("completed")),
+                ),
+                (
+                    Box::new(col("c1").eq(lit("paused"))),
+                    Box::new(lit("paused")),
+                ),
+                (Box::new(col("c2")), Box::new(lit("running"))),
+                (
+                    Box::new(col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0)))),
+                    Box::new(lit("backing-off")),
+                ),
+            ],
+            Some(Box::new(lit("ready"))),
+        ));
+
+        assert_eq!(
+            simplify(binary_expr(
+                complex_case.clone(),
+                Operator::Eq,
+                lit("completed"),
+            )),
+            not_distinct_from(col("c1").eq(lit("completed")), lit(true)).and(
+                distinct_from(col("c1").eq(lit("inboxed")), lit(true))
+                    .and(distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
+            )
+        );
+
+        assert_eq!(
+            simplify(binary_expr(
+                complex_case.clone(),
+                Operator::NotEq,
+                lit("completed"),
+            )),
+            distinct_from(col("c1").eq(lit("completed")), lit(true))
+                .or(not_distinct_from(col("c1").eq(lit("inboxed")), lit(true))
+                    .or(not_distinct_from(col("c1").eq(lit("scheduled")), lit(true))))
+        );
+
+        assert_eq!(
+            simplify(binary_expr(
+                complex_case.clone(),
+                Operator::Eq,
+                lit("running"),
+            )),
+            not_distinct_from(col("c2"), lit(true)).and(
+                distinct_from(col("c1").eq(lit("inboxed")), lit(true))
+                    .and(distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
+                    .and(distinct_from(col("c1").eq(lit("completed")), lit(true)))
+                    .and(distinct_from(col("c1").eq(lit("paused")), lit(true)))
+            )
+        );
+
+        assert_eq!(
+            simplify(binary_expr(
+                complex_case.clone(),
+                Operator::Eq,
+                lit("ready"),
+            )),
+            distinct_from(col("c1").eq(lit("inboxed")), lit(true))
+                .and(distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
+                .and(distinct_from(col("c1").eq(lit("completed")), lit(true)))
+                .and(distinct_from(col("c1").eq(lit("paused")), lit(true)))
+                .and(distinct_from(col("c2"), lit(true)))
+                .and(distinct_from(
+                    col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0))),
+                    lit(true)
+                ))
+        );
+
+        assert_eq!(
+            simplify(binary_expr(
+                complex_case.clone(),
+                Operator::NotEq,
+                lit("ready"),
+            )),
+            not_distinct_from(col("c1").eq(lit("inboxed")), lit(true))
+                .or(not_distinct_from(col("c1").eq(lit("scheduled")), lit(true)))
+                .or(not_distinct_from(col("c1").eq(lit("completed")), lit(true)))
+                .or(not_distinct_from(col("c1").eq(lit("paused")), lit(true)))
+                .or(not_distinct_from(col("c2"), lit(true)))
+                .or(not_distinct_from(
+                    col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0))),
+                    lit(true)
+                ))
+        );
+    }
+
     #[test]
     fn simplify_expr_case_when_then_else() {
         // CASE WHEN c2 != false THEN "ok" == "not_ok" ELSE c2 == true
diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs
index 2f7dadceba..35e256f306 100644
--- a/datafusion/optimizer/src/simplify_expressions/utils.rs
+++ b/datafusion/optimizer/src/simplify_expressions/utils.rs
@@ -22,7 +22,7 @@ use datafusion_common::{internal_err, Result, ScalarValue};
 use datafusion_expr::{
     expr::{Between, BinaryExpr, InList},
     expr_fn::{and, bitwise_and, bitwise_or, or},
-    Expr, Like, Operator,
+    Case, Expr, Like, Operator,
 };
 
 pub static POWS_OF_TEN: [i128; 38] = [
@@ -265,6 +265,38 @@ pub fn as_bool_lit(expr: &Expr) -> Result<Option<bool>> {
     }
 }
 
+/// Returns `true` if `expr` is a searched `CASE` expression (one without a
+/// base expression, i.e. `CASE WHEN ... THEN ... END`) whose `THEN` results
+/// and optional `ELSE` result are all literals.
+pub fn is_case_with_literal_outputs(expr: &Expr) -> bool {
+    match expr {
+        Expr::Case(Case {
+            expr: None,
+            when_then_expr,
+            else_expr,
+        }) => {
+            when_then_expr.iter().all(|(_, then)| is_lit(then))
+                && else_expr.as_deref().is_none_or(is_lit)
+        }
+        _ => false,
+    }
+}
+
+/// Consumes `expr` and returns the [`Case`] it wraps.
+///
+/// Returns an internal error if `expr` is not an [`Expr::Case`].
+pub fn into_case(expr: Expr) -> Result<Case> {
+    match expr {
+        Expr::Case(case) => Ok(case),
+        _ => internal_err!("Expected case, got {expr:?}"),
+    }
+}
+
+/// Returns `true` if `expr` is an [`Expr::Literal`].
+pub fn is_lit(expr: &Expr) -> bool {
+    matches!(expr, Expr::Literal(_, _))
+}
+
 /// negate a Not clause
 /// input is the clause to be negated.(args of Not clause)
 /// For BinaryExpr, use the negation of op instead.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to