This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 3c95281974 feat: improve LiteralGuarantee for the case like `(a=1 AND b=1) OR (a=2 AND b=3)` (#16762)
3c95281974 is described below

commit 3c95281974b4ef1677ff63fd379453c4b2df4bb5
Author: Huaijin <[email protected]>
AuthorDate: Wed Jul 23 22:08:32 2025 +0800

    feat: improve LiteralGuarantee for the case like `(a=1 AND b=1) OR (a=2 AND b=3)` (#16762)
    
    * feat: improve LiteralGuarantee for the case like (a=1 AND b=1) OR (a=2 AND b=3)
    
    * support inlist
    
    * fmt and clippy
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/physical-expr/src/utils/guarantee.rs | 339 ++++++++++++++++++++----
 1 file changed, 290 insertions(+), 49 deletions(-)

diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs
index 8092dc3c1a..8a57cc7b7c 100644
--- a/datafusion/physical-expr/src/utils/guarantee.rs
+++ b/datafusion/physical-expr/src/utils/guarantee.rs
@@ -129,35 +129,15 @@ impl LiteralGuarantee {
                     .as_any()
                     .downcast_ref::<crate::expressions::InListExpr>()
                 {
-                    // Only support single-column inlist currently, 
multi-column inlist is not supported
-                    let col = inlist
-                        .expr()
-                        .as_any()
-                        .downcast_ref::<crate::expressions::Column>();
-                    let Some(col) = col else {
-                        return builder;
-                    };
-
-                    let literals = inlist
-                        .list()
-                        .iter()
-                        .map(|e| 
e.as_any().downcast_ref::<crate::expressions::Literal>())
-                        .collect::<Option<Vec<_>>>();
-                    let Some(literals) = literals else {
-                        return builder;
-                    };
-
-                    let guarantee = if inlist.negated() {
-                        Guarantee::NotIn
+                    if let Some(inlist) = ColInList::try_new(inlist) {
+                        builder.aggregate_multi_conjunct(
+                            inlist.col,
+                            inlist.guarantee,
+                            inlist.list.iter().map(|lit| lit.value()),
+                        )
                     } else {
-                        Guarantee::In
-                    };
-
-                    builder.aggregate_multi_conjunct(
-                        col,
-                        guarantee,
-                        literals.iter().map(|e| e.value()),
-                    )
+                        builder
+                    }
                 } else {
                     // split disjunction: <expr> OR <expr> OR ...
                     let disjunctions = split_disjunction(expr);
@@ -184,16 +164,6 @@ impl LiteralGuarantee {
                         .filter_map(|expr| ColOpLit::try_new(expr))
                         .collect::<Vec<_>>();
 
-                    if terms.is_empty() {
-                        return builder;
-                    }
-
-                    // if not all terms are of the form (col <op> literal),
-                    // can't infer any guarantees
-                    if terms.len() != disjunctions.len() {
-                        return builder;
-                    }
-
                     // if all terms are 'col <op> literal' with the same column
                     // and operation we can infer any guarantees
                     //
@@ -203,18 +173,70 @@ impl LiteralGuarantee {
                     // foo is required for the expression to be true.
                     // So we can only create a multi value guarantee for `=`
                     // (or a single value). (e.g. ignore `a != foo OR a != 
bar`)
-                    let first_term = &terms[0];
-                    if terms.iter().all(|term| {
-                        term.col.name() == first_term.col.name()
-                            && term.guarantee == Guarantee::In
-                    }) {
+                    let first_term = terms.first();
+                    if !terms.is_empty()
+                        && terms.len() == disjunctions.len()
+                        && terms.iter().all(|term| {
+                            term.col.name() == first_term.unwrap().col.name()
+                                && term.guarantee == Guarantee::In
+                        })
+                    {
                         builder.aggregate_multi_conjunct(
-                            first_term.col,
+                            first_term.unwrap().col,
                             Guarantee::In,
                             terms.iter().map(|term| term.lit.value()),
                         )
                     } else {
-                        // can't infer anything
+                        // Handle disjunctions with conjunctions like (a = 1 
AND b = 2) OR (a = 2 AND b = 3)
+                        // Extract termsets from each disjunction
+                        // if in each termset, they have same column, and the 
guarantee is In,
+                        // we can infer a guarantee for the column
+                        // e.g. (a = 1 AND b = 2) OR (a = 2 AND b = 3) is `a 
IN (1, 2) AND b IN (2, 3)`
+                        // otherwise, we can't infer a guarantee
+                        let termsets: Vec<Vec<ColOpLitOrInList>> = disjunctions
+                            .iter()
+                            .map(|expr| {
+                                split_conjunction(expr)
+                                    .into_iter()
+                                    .filter_map(ColOpLitOrInList::try_new)
+                                    .filter(|term| term.guarantee() == 
Guarantee::In)
+                                    .collect()
+                            })
+                            .collect();
+
+                        // Early return if any termset is empty (can't infer 
guarantees)
+                        if termsets.iter().any(|terms| terms.is_empty()) {
+                            return builder;
+                        }
+
+                        // Find columns that appear in all termsets
+                        let common_cols = find_common_columns(&termsets);
+                        if common_cols.is_empty() {
+                            return builder;
+                        }
+
+                        // Build guarantees for common columns
+                        let mut builder = builder;
+                        for col in common_cols {
+                            let literals: Vec<_> = termsets
+                                .iter()
+                                .filter_map(|terms| {
+                                    terms.iter().find(|term| term.col() == 
col).map(
+                                        |term| {
+                                            term.lits().into_iter().map(|lit| 
lit.value())
+                                        },
+                                    )
+                                })
+                                .flatten()
+                                .collect();
+
+                            builder = builder.aggregate_multi_conjunct(
+                                col,
+                                Guarantee::In,
+                                literals.into_iter(),
+                            );
+                        }
+
                         builder
                     }
                 }
@@ -362,7 +384,7 @@ struct ColOpLit<'a> {
 }
 
 impl<'a> ColOpLit<'a> {
-    /// Returns Some(ColEqLit) if the expression is either:
+    /// Returns Some(ColOpLit) if the expression is either:
     /// 1. `col <op> literal`
     /// 2. `literal <op> col`
     /// 3. operator is `=` or `!=`
@@ -410,6 +432,115 @@ impl<'a> ColOpLit<'a> {
     }
 }
 
+/// A `col [NOT] IN (literal1, literal2, ...)` expression whose list holds only literals
+struct ColInList<'a> {
+    col: &'a crate::expressions::Column, // the column tested by `IN`
+    guarantee: Guarantee, // `In` for `IN`, `NotIn` for `NOT IN`
+    list: Vec<&'a crate::expressions::Literal>, // the list items, in source order
+}
+
+impl<'a> ColInList<'a> {
+    /// Returns `Some(ColInList)` if `inlist` is `col [NOT] IN (literal1, literal2, ...)`:
+    /// 1. the expression being tested is a bare column, and
+    /// 2. every element of the list is a literal.
+    ///
+    /// Returns `None` otherwise (e.g. `a + 1 IN (...)` or `a IN (b, 1)`)
+    fn try_new(inlist: &'a crate::expressions::InListExpr) -> Option<Self> {
+        // Only a bare column is supported on the left-hand side; anything else yields None
+        let col = inlist
+            .expr()
+            .as_any()
+            .downcast_ref::<crate::expressions::Column>()?;
+        // A single non-literal list element makes the whole expression unsupported
+        let literals = inlist
+            .list()
+            .iter()
+            .map(|e| e.as_any().downcast_ref::<crate::expressions::Literal>())
+            .collect::<Option<Vec<_>>>()?;
+
+        let guarantee = if inlist.negated() {
+            Guarantee::NotIn
+        } else {
+            Guarantee::In
+        };
+
+        Some(Self {
+            col,
+            guarantee,
+            list: literals,
+        })
+    }
+}
+
+/// A conjunct that is either `col [NOT] IN (literal, ...)` or `col <op> literal`; `try_new` returns `None` otherwise
+enum ColOpLitOrInList<'a> {
+    ColOpLit(ColOpLit<'a>), // `col = lit` / `col != lit`, in either operand order
+    ColInList(ColInList<'a>), // `col [NOT] IN (...)` with an all-literal list
+}
+
+impl<'a> ColOpLitOrInList<'a> {
+    fn try_new(expr: &'a Arc<dyn PhysicalExpr>) -> Option<Self> {
+        match expr
+            .as_any()
+            .downcast_ref::<crate::expressions::InListExpr>()
+        {
+            Some(inlist) => Some(Self::ColInList(ColInList::try_new(inlist)?)),
+            None => ColOpLit::try_new(expr).map(Self::ColOpLit),
+        }
+    }
+    /// The guarantee (`In` or `NotIn`) recorded by the wrapped term
+    fn guarantee(&self) -> Guarantee {
+        match self {
+            Self::ColOpLit(col_op_lit) => col_op_lit.guarantee,
+            Self::ColInList(col_in_list) => col_in_list.guarantee,
+        }
+    }
+    /// The column the wrapped term constrains
+    fn col(&self) -> &'a crate::expressions::Column {
+        match self {
+            Self::ColOpLit(col_op_lit) => col_op_lit.col,
+            Self::ColInList(col_in_list) => col_in_list.col,
+        }
+    }
+    /// The term's literals: the single literal of a `ColOpLit`, or a copy of the `IN` list
+    fn lits(&self) -> Vec<&'a crate::expressions::Literal> {
+        match self {
+            Self::ColOpLit(col_op_lit) => vec![col_op_lit.lit],
+            Self::ColInList(col_in_list) => col_in_list.list.clone(),
+        }
+    }
+}
+
+/// Returns the columns present in every termset; empty if none, or if a termset repeats a column
+fn find_common_columns<'a>(
+    termsets: &[Vec<ColOpLitOrInList<'a>>],
+) -> Vec<&'a crate::expressions::Column> {
+    if termsets.is_empty() {
+        return Vec::new();
+    }
+
+    // Start with the distinct columns of the first termset
+    let mut common_cols: HashSet<_> = termsets[0].iter().map(|term| term.col()).collect();
+
+    // If the set is smaller than the termset, some column appears twice: return empty.
+    // e.g. (a = 1 AND a = 2) OR (a = 2 AND b = 3) currently yields no guarantee
+    // TODO: for the case above we could infer a IN (2) AND b IN (3)
+    if common_cols.len() != termsets[0].len() {
+        return Vec::new();
+    }
+
+    // Intersect with the columns of each remaining termset, applying the same repeat check
+    for termset in termsets.iter().skip(1) {
+        let termset_cols: HashSet<_> = termset.iter().map(|term| term.col()).collect();
+        if termset_cols.len() != termset.len() {
+            return Vec::new();
+        }
+        common_cols = common_cols.intersection(&termset_cols).cloned().collect();
+    }
+    // NOTE: collected from a HashSet, so the order of the returned columns is unspecified
+    common_cols.into_iter().collect()
+}
+
 #[cfg(test)]
 mod test {
     use std::sync::LazyLock;
@@ -808,12 +939,11 @@ mod test {
             vec![not_in_guarantee("b", [1, 2, 3]), in_guarantee("b", [3, 4])],
         );
         // b IN (1, 2, 3) OR b = 2
-        // TODO this should be in_guarantee("b", [1, 2, 3]) but currently we 
don't support to analyze this kind of disjunction. Only `ColOpLit OR ColOpLit` 
is supported.
         test_analyze(
             col("b")
                 .in_list(vec![lit(1), lit(2), lit(3)], false)
                 .or(col("b").eq(lit(2))),
-            vec![],
+            vec![in_guarantee("b", [1, 2, 3])],
         );
         // b IN (1, 2, 3) OR b != 3
         test_analyze(
@@ -824,13 +954,123 @@ mod test {
         );
     }
 
+    #[test]
+    fn test_disjunction_and_conjunction_multi_column() {
+        // (a = "foo" AND b = 1) OR (a = "bar" AND b = 2)
+        test_analyze(
+            (col("a").eq(lit("foo")).and(col("b").eq(lit(1))))
+                .or(col("a").eq(lit("bar")).and(col("b").eq(lit(2)))),
+            vec![in_guarantee("a", ["foo", "bar"]), in_guarantee("b", [1, 2])],
+        );
+        // (a = "foo" AND b = 1) OR (a = "bar" AND b = 2) OR (b = 3)
+        test_analyze(
+            (col("a").eq(lit("foo")).and(col("b").eq(lit(1))))
+                .or(col("a").eq(lit("bar")).and(col("b").eq(lit(2))))
+                .or(col("b").eq(lit(3))),
+            vec![in_guarantee("b", [1, 2, 3])],
+        );
+        // (a = "foo" AND b = 1) OR (a = "bar" AND b = 2) OR (c = 3)
+        test_analyze(
+            (col("a").eq(lit("foo")).and(col("b").eq(lit(1))))
+                .or(col("a").eq(lit("bar")).and(col("b").eq(lit(2))))
+                .or(col("c").eq(lit(3))),
+            vec![],
+        );
+        // (a = "foo" AND b > 1) OR (a = "bar" AND b = 2)
+        test_analyze(
+            (col("a").eq(lit("foo")).and(col("b").gt(lit(1))))
+                .or(col("a").eq(lit("bar")).and(col("b").eq(lit(2)))),
+            vec![in_guarantee("a", ["foo", "bar"])],
+        );
+        // (a = "foo" AND b = 1) OR (b = 1 AND c = 2) OR (c = 3 AND a = "bar")
+        test_analyze(
+            (col("a").eq(lit("foo")).and(col("b").eq(lit(1))))
+                .or(col("b").eq(lit(1)).and(col("c").eq(lit(2))))
+                .or(col("c").eq(lit(3)).and(col("a").eq(lit("bar")))),
+            vec![],
+        );
+        // (a = "foo" AND a = "bar") OR (a = "good" AND b = 1)
+        // TODO: this should be `a IN ("good") AND b IN (1)`
+        test_analyze(
+            (col("a").eq(lit("foo")).and(col("a").eq(lit("bar"))))
+                .or(col("a").eq(lit("good")).and(col("b").eq(lit(1)))),
+            vec![],
+        );
+        // (a = "foo" AND a = "foo") OR (a = "good" AND b = 1)
+        // TODO: this should be `a IN ("foo", "good")`
+        test_analyze(
+            (col("a").eq(lit("foo")).and(col("a").eq(lit("foo"))))
+                .or(col("a").eq(lit("good")).and(col("b").eq(lit(1)))),
+            vec![],
+        );
+        // (a = "foo" AND b = 3) OR (b = 4 AND b = 1) OR (b = 2 AND a = "bar")
+        test_analyze(
+            (col("a").eq(lit("foo")).and(col("b").eq(lit(3))))
+                .or(col("b").eq(lit(4)).and(col("b").eq(lit(1))))
+                .or(col("b").eq(lit(2)).and(col("a").eq(lit("bar")))),
+            vec![],
+        );
+        // (b = 1 AND b > 3) OR (a = "foo" AND b = 4)
+        test_analyze(
+            (col("b").eq(lit(1)).and(col("b").gt(lit(3))))
+                .or(col("a").eq(lit("foo")).and(col("b").eq(lit(4)))),
+            // if b isn't 1 or 4, it can not be true (though the expression 
actually can never be true)
+            vec![in_guarantee("b", [1, 4])],
+        );
+        // (a = "foo" AND b = 1) OR (a != "bar" AND b = 2)
+        test_analyze(
+            (col("a").eq(lit("foo")).and(col("b").eq(lit(1))))
+                .or(col("a").not_eq(lit("bar")).and(col("b").eq(lit(2)))),
+            vec![in_guarantee("b", [1, 2])],
+        );
+        // (a = "foo" AND b = 1) OR (a LIKE "%bar" AND b = 2)
+        test_analyze(
+            (col("a").eq(lit("foo")).and(col("b").eq(lit(1))))
+                .or(col("a").like(lit("%bar")).and(col("b").eq(lit(2)))),
+            vec![in_guarantee("b", [1, 2])],
+        );
+        // (a IN ("foo", "bar") AND b = 5) OR (a IN ("foo", "bar") AND b = 6)
+        test_analyze(
+            (col("a")
+                .in_list(vec![lit("foo"), lit("bar")], false)
+                .and(col("b").eq(lit(5))))
+            .or(col("a")
+                .in_list(vec![lit("foo"), lit("bar")], false)
+                .and(col("b").eq(lit(6)))),
+            vec![in_guarantee("a", ["foo", "bar"]), in_guarantee("b", [5, 6])],
+        );
+        // (a IN ("foo", "bar") AND b = 5) OR (a IN ("foo") AND b = 6)
+        test_analyze(
+            (col("a")
+                .in_list(vec![lit("foo"), lit("bar")], false)
+                .and(col("b").eq(lit(5))))
+            .or(col("a")
+                .in_list(vec![lit("foo")], false)
+                .and(col("b").eq(lit(6)))),
+            vec![in_guarantee("a", ["foo", "bar"]), in_guarantee("b", [5, 6])],
+        );
+        // (a NOT IN ("foo", "bar") AND b = 5) OR (a NOT IN ("foo") AND b = 6)
+        test_analyze(
+            (col("a")
+                .in_list(vec![lit("foo"), lit("bar")], true)
+                .and(col("b").eq(lit(5))))
+            .or(col("a")
+                .in_list(vec![lit("foo")], true)
+                .and(col("b").eq(lit(6)))),
+            vec![in_guarantee("b", [5, 6])],
+        );
+    }
+
     /// Tests that analyzing expr results in the expected guarantees
     fn test_analyze(expr: Expr, expected: Vec<LiteralGuarantee>) {
         println!("Begin analyze of {expr}");
         let schema = schema();
         let physical_expr = logical2physical(&expr, &schema);
 
-        let actual = LiteralGuarantee::analyze(&physical_expr);
+        let actual = LiteralGuarantee::analyze(&physical_expr)
+            .into_iter()
+            .sorted_by_key(|g| g.column.name().to_string())
+            .collect::<Vec<_>>();
         assert_eq!(
             expected, actual,
             "expr: {expr}\
@@ -867,6 +1107,7 @@ mod test {
             Arc::new(Schema::new(vec![
                 Field::new("a", DataType::Utf8, false),
                 Field::new("b", DataType::Int32, false),
+                Field::new("c", DataType::Int32, false),
             ]))
         });
         Arc::clone(&SCHEMA)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to