This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new d698d9deb6 fix: Support Dict types in `in_list` physical plans (#10031)
d698d9deb6 is described below

commit d698d9deb6d06214ea36cc59739993ddf6441b6a
Author: advancedxy <[email protected]>
AuthorDate: Sun Apr 14 04:05:38 2024 +0800

    fix: Support Dict types in `in_list` physical plans (#10031)
    
    * fix: Relax type check with dict types in in_list
    
    * refine comments
    
    * fix style, refine comments and address reviewer's comments
    
    * refine comments
    
    * address comments
---
 .../physical-expr/src/expressions/in_list.rs       | 126 ++++++++++++++++++++-
 datafusion/sqllogictest/test_files/dictionary.slt  |  39 +++++++
 2 files changed, 161 insertions(+), 4 deletions(-)

diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs
index ecdb03e97e..07185b4d65 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -415,6 +415,18 @@ impl PartialEq<dyn Any> for InListExpr {
     }
 }
 
+/// Checks if two types are logically equal, dictionary types are compared by their value types.
+fn is_logically_eq(lhs: &DataType, rhs: &DataType) -> bool {
+    match (lhs, rhs) {
+        (DataType::Dictionary(_, v1), DataType::Dictionary(_, v2)) => {
+            v1.as_ref().eq(v2.as_ref())
+        }
+        (DataType::Dictionary(_, l), _) => l.as_ref().eq(rhs),
+        (_, DataType::Dictionary(_, r)) => lhs.eq(r.as_ref()),
+        _ => lhs.eq(rhs),
+    }
+}
+
 /// Creates a unary expression InList
 pub fn in_list(
     expr: Arc<dyn PhysicalExpr>,
@@ -426,7 +438,7 @@ pub fn in_list(
     let expr_data_type = expr.data_type(schema)?;
     for list_expr in list.iter() {
         let list_expr_data_type = list_expr.data_type(schema)?;
-        if !expr_data_type.eq(&list_expr_data_type) {
+        if !is_logically_eq(&expr_data_type, &list_expr_data_type) {
             return internal_err!(
                 "The data type inlist should be same, the value type is 
{expr_data_type}, one of list expr type is {list_expr_data_type}"
             );
@@ -499,7 +511,21 @@ mod tests {
     macro_rules! in_list {
         ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{
             let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?;
-            let expr = in_list(cast_expr, cast_list_exprs, $NEGATED, $SCHEMA).unwrap();
+            in_list_raw!(
+                $BATCH,
+                cast_list_exprs,
+                $NEGATED,
+                $EXPECTED,
+                cast_expr,
+                $SCHEMA
+            );
+        }};
+    }
+
+    // applies the in_list expr to an input batch and list without cast
+    macro_rules! in_list_raw {
+        ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{
+            let expr = in_list($COL, $LIST, $NEGATED, $SCHEMA).unwrap();
             let result = expr
                 .evaluate(&$BATCH)?
                 .into_array($BATCH.num_rows())
@@ -540,7 +566,7 @@ mod tests {
             &schema
         );
 
-        // expression: "a not in ("a", "b")"
+        // expression: "a in ("a", "b", null)"
         let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))];
         in_list!(
             batch,
@@ -551,7 +577,7 @@ mod tests {
             &schema
         );
 
-        // expression: "a not in ("a", "b")"
+        // expression: "a not in ("a", "b", null)"
         let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))];
         in_list!(
             batch,
@@ -1314,4 +1340,96 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn in_list_utf8_with_dict_types() -> Result<()> {
+        fn dict_lit(key_type: DataType, value: &str) -> Arc<dyn PhysicalExpr> {
+            lit(ScalarValue::Dictionary(
+                Box::new(key_type),
+                Box::new(ScalarValue::new_utf8(value.to_string())),
+            ))
+        }
+
+        fn null_dict_lit(key_type: DataType) -> Arc<dyn PhysicalExpr> {
+            lit(ScalarValue::Dictionary(
+                Box::new(key_type),
+                Box::new(ScalarValue::Utf8(None)),
+            ))
+        }
+
+        let schema = Schema::new(vec![Field::new(
+            "a",
+            DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
+            true,
+        )]);
+        let a: UInt16DictionaryArray =
+            vec![Some("a"), Some("d"), None].into_iter().collect();
+        let col_a = col("a", &schema)?;
+        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
+
+        // expression: "a in ("a", "b")"
+        let lists = [
+            vec![lit("a"), lit("b")],
+            vec![
+                dict_lit(DataType::Int8, "a"),
+                dict_lit(DataType::UInt16, "b"),
+            ],
+        ];
+        for list in lists.iter() {
+            in_list_raw!(
+                batch,
+                list.clone(),
+                &false,
+                vec![Some(true), Some(false), None],
+                col_a.clone(),
+                &schema
+            );
+        }
+
+        // expression: "a not in ("a", "b")"
+        for list in lists.iter() {
+            in_list_raw!(
+                batch,
+                list.clone(),
+                &true,
+                vec![Some(false), Some(true), None],
+                col_a.clone(),
+                &schema
+            );
+        }
+
+        // expression: "a in ("a", "b", null)"
+        let lists = [
+            vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))],
+            vec![
+                dict_lit(DataType::Int8, "a"),
+                dict_lit(DataType::UInt16, "b"),
+                null_dict_lit(DataType::UInt16),
+            ],
+        ];
+        for list in lists.iter() {
+            in_list_raw!(
+                batch,
+                list.clone(),
+                &false,
+                vec![Some(true), None, None],
+                col_a.clone(),
+                &schema
+            );
+        }
+
+        // expression: "a not in ("a", "b", null)"
+        for list in lists.iter() {
+            in_list_raw!(
+                batch,
+                list.clone(),
+                &true,
+                vec![Some(false), None, None],
+                col_a.clone(),
+                &schema
+            );
+        }
+
+        Ok(())
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt
index af7bf5cb16..891a09fbc1 100644
--- a/datafusion/sqllogictest/test_files/dictionary.slt
+++ b/datafusion/sqllogictest/test_files/dictionary.slt
@@ -87,6 +87,22 @@ f3 Utf8 YES
 f4 Float64 YES
 time Timestamp(Nanosecond, None) YES
 
+# in list with dictionary input
+query BBB
+SELECT
+    tag_id in ('1000'), '1000' in (tag_id, null), arrow_cast('999','Dictionary(Int32, Utf8)') in (tag_id, null)
+FROM m1
+----
+true true NULL
+true true NULL
+true true NULL
+true true NULL
+true true NULL
+true true NULL
+true true NULL
+true true NULL
+true true NULL
+true true NULL
 
 # Table m2 with a tag columns `tag_id` and `type`, a field column `f5`, and `time`
 statement ok
@@ -165,6 +181,29 @@ order by date_bin('30 minutes', time) DESC
 3 400 600 500 2023-12-04T00:30:00
 3 100 300 200 2023-12-04T00:00:00
 
+# query with in list
+query BBBBBBBB
+SELECT
+    type in ('active', 'passive')
+    , 'active' in (type)
+    , 'active' in (type, null)
+    , arrow_cast('passive','Dictionary(Int8, Utf8)') in (type, null)
+    , tag_id in ('1000', '2000')
+    , tag_id in ('999')
+    , '1000' in (tag_id, null)
+    , arrow_cast('999','Dictionary(Int16, Utf8)') in (tag_id, null)
+FROM m2
+----
+true true true NULL true false true NULL
+true true true NULL true false true NULL
+true true true NULL true false true NULL
+true true true NULL true false true NULL
+true true true NULL true false true NULL
+true true true NULL true false true NULL
+true false NULL true true false true NULL
+true false NULL true true false true NULL
+true false NULL true true false true NULL
+true false NULL true true false true NULL
 
 
 # Reproducer for https://github.com/apache/arrow-datafusion/issues/8738

Reply via email to