This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 1e0c7607d5 feat: unwrap casts of string and dictionary columns (#10323)
1e0c7607d5 is described below

commit 1e0c7607d51faf698eef43a07de22a9578d83712
Author: Adam Curtis <[email protected]>
AuthorDate: Wed May 1 15:57:49 2024 -0400

    feat: unwrap casts of string and dictionary columns (#10323)
    
    * feat: unwrap casts of string and dictionary columns
    
    * feat: allow unwrapping casts for any dictionary type
    
    * docs: fix
    
    * add LargeUtf8
    
    * add explain test for integer cast
    
    * remove unnecessary equality check
    this should prevent returning Transformed in cases where nothing was
    changed
    
    * update comments
---
 .../optimizer/src/unwrap_cast_in_comparison.rs     | 233 ++++++++++++++++-----
 datafusion/sqllogictest/test_files/dictionary.slt  |  45 ++++
 2 files changed, 229 insertions(+), 49 deletions(-)

diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs 
b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
index 138769674d..293a694d68 100644
--- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
+++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
@@ -152,8 +152,8 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter {
                     let Ok(right_type) = right.get_type(&self.schema) else {
                         return Ok(Transformed::no(expr));
                     };
-                    is_support_data_type(&left_type)
-                        && is_support_data_type(&right_type)
+                    is_supported_type(&left_type)
+                        && is_supported_type(&right_type)
                         && is_comparison_op(op)
                 } =>
             {
@@ -172,7 +172,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter {
                         let Ok(expr_type) = right_expr.get_type(&self.schema) 
else {
                             return Ok(Transformed::no(expr));
                         };
-                        let Ok(Some(value)) =
+                        let Some(value) =
                             try_cast_literal_to_type(left_lit_value, 
&expr_type)
                         else {
                             return Ok(Transformed::no(expr));
@@ -196,7 +196,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter {
                         let Ok(expr_type) = left_expr.get_type(&self.schema) 
else {
                             return Ok(Transformed::no(expr));
                         };
-                        let Ok(Some(value)) =
+                        let Some(value) =
                             try_cast_literal_to_type(right_lit_value, 
&expr_type)
                         else {
                             return Ok(Transformed::no(expr));
@@ -226,14 +226,14 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter {
                 let Ok(expr_type) = left_expr.get_type(&self.schema) else {
                     return Ok(Transformed::no(expr));
                 };
-                if !is_support_data_type(&expr_type) {
+                if !is_supported_type(&expr_type) {
                     return Ok(Transformed::no(expr));
                 }
                 let Ok(right_exprs) = list
                     .iter()
                     .map(|right| {
                         let right_type = right.get_type(&self.schema)?;
-                        if !is_support_data_type(&right_type) {
+                        if !is_supported_type(&right_type) {
                             internal_err!(
                                 "The type of list expr {} is not supported",
                                 &right_type
@@ -243,7 +243,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter {
                             Expr::Literal(right_lit_value) => {
                                 // if the right_lit_value can be casted to the 
type of internal_left_expr
                                 // we need to unwrap the cast for 
cast/try_cast expr, and add cast to the literal
-                                let Ok(Some(value)) = 
try_cast_literal_to_type(right_lit_value, &expr_type) else {
+                                let Some(value) = 
try_cast_literal_to_type(right_lit_value, &expr_type) else {
                                     internal_err!(
                                         "Can't cast the list expr {:?} to type 
{:?}",
                                         right_lit_value, &expr_type
@@ -282,7 +282,15 @@ fn is_comparison_op(op: &Operator) -> bool {
     )
 }
 
-fn is_support_data_type(data_type: &DataType) -> bool {
+/// Returns true if [UnwrapCastExprRewriter] supports this data type
+fn is_supported_type(data_type: &DataType) -> bool {
+    is_supported_numeric_type(data_type)
+        || is_supported_string_type(data_type)
+        || is_supported_dictionary_type(data_type)
+}
+
+/// Returns true if [UnwrapCastExprRewriter] supports this numeric type
+fn is_supported_numeric_type(data_type: &DataType) -> bool {
     matches!(
         data_type,
         DataType::UInt8
@@ -298,19 +306,47 @@ fn is_support_data_type(data_type: &DataType) -> bool {
     )
 }
 
+/// Returns true if [UnwrapCastExprRewriter] supports casting this value as a 
string
+fn is_supported_string_type(data_type: &DataType) -> bool {
+    matches!(data_type, DataType::Utf8 | DataType::LargeUtf8)
+}
+
+/// Returns true if [UnwrapCastExprRewriter] supports casting this value as a 
dictionary
+fn is_supported_dictionary_type(data_type: &DataType) -> bool {
+    matches!(data_type,
+                    DataType::Dictionary(_, inner) if is_supported_type(inner))
+}
+
+/// Convert a literal value from one data type to another
 fn try_cast_literal_to_type(
     lit_value: &ScalarValue,
     target_type: &DataType,
-) -> Result<Option<ScalarValue>> {
+) -> Option<ScalarValue> {
     let lit_data_type = lit_value.data_type();
-    // the rule just support the signed numeric data type now
-    if !is_support_data_type(&lit_data_type) || 
!is_support_data_type(target_type) {
-        return Ok(None);
+    if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) {
+        return None;
     }
     if lit_value.is_null() {
         // null value can be cast to any type of null value
-        return Ok(Some(ScalarValue::try_from(target_type)?));
+        return ScalarValue::try_from(target_type).ok();
+    }
+    try_cast_numeric_literal(lit_value, target_type)
+        .or_else(|| try_cast_string_literal(lit_value, target_type))
+        .or_else(|| try_cast_dictionary(lit_value, target_type))
+}
+
+/// Convert a numeric value from one numeric data type to another
+fn try_cast_numeric_literal(
+    lit_value: &ScalarValue,
+    target_type: &DataType,
+) -> Option<ScalarValue> {
+    let lit_data_type = lit_value.data_type();
+    if !is_supported_numeric_type(&lit_data_type)
+        || !is_supported_numeric_type(target_type)
+    {
+        return None;
     }
+
     let mul = match target_type {
         DataType::UInt8
         | DataType::UInt16
@@ -322,9 +358,7 @@ fn try_cast_literal_to_type(
         | DataType::Int64 => 1_i128,
         DataType::Timestamp(_, _) => 1_i128,
         DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
-        other_type => {
-            return internal_err!("Error target data type {other_type:?}");
-        }
+        _ => return None,
     };
     let (target_min, target_max) = match target_type {
         DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128),
@@ -343,9 +377,7 @@ fn try_cast_literal_to_type(
             MIN_DECIMAL_FOR_EACH_PRECISION[*precision as usize - 1],
             MAX_DECIMAL_FOR_EACH_PRECISION[*precision as usize - 1],
         ),
-        other_type => {
-            return internal_err!("Error target data type {other_type:?}");
-        }
+        _ => return None,
     };
     let lit_value_target_type = match lit_value {
         ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul),
@@ -379,13 +411,11 @@ fn try_cast_literal_to_type(
                 None
             }
         }
-        other_value => {
-            return internal_err!("Invalid literal value {other_value:?}");
-        }
+        _ => None,
     };
 
     match lit_value_target_type {
-        None => Ok(None),
+        None => None,
         Some(value) => {
             if value >= target_min && value <= target_max {
                 // the value casted from lit to the target type is in the 
range of target type.
@@ -434,18 +464,60 @@ fn try_cast_literal_to_type(
                     DataType::Decimal128(p, s) => {
                         ScalarValue::Decimal128(Some(value), *p, *s)
                     }
-                    other_type => {
-                        return internal_err!("Error target data type 
{other_type:?}");
+                    _ => {
+                        return None;
                     }
                 };
-                Ok(Some(result_scalar))
+                Some(result_scalar)
             } else {
-                Ok(None)
+                None
             }
         }
     }
 }
 
+fn try_cast_string_literal(
+    lit_value: &ScalarValue,
+    target_type: &DataType,
+) -> Option<ScalarValue> {
+    let string_value = match lit_value {
+        ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) => s.clone(),
+        _ => return None,
+    };
+    let scalar_value = match target_type {
+        DataType::Utf8 => ScalarValue::Utf8(string_value),
+        DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value),
+        _ => return None,
+    };
+    Some(scalar_value)
+}
+
+/// Attempt to cast to/from a dictionary type by wrapping/unwrapping the 
dictionary
+fn try_cast_dictionary(
+    lit_value: &ScalarValue,
+    target_type: &DataType,
+) -> Option<ScalarValue> {
+    let lit_value_type = lit_value.data_type();
+    let result_scalar = match (lit_value, target_type) {
+        // Unwrap dictionary when inner type matches target type
+        (ScalarValue::Dictionary(_, inner_value), _)
+            if inner_value.data_type() == *target_type =>
+        {
+            (**inner_value).clone()
+        }
+        // Wrap type when target type is dictionary
+        (_, DataType::Dictionary(index_type, inner_type))
+            if **inner_type == lit_value_type =>
+        {
+            ScalarValue::Dictionary(index_type.clone(), 
Box::new(lit_value.clone()))
+        }
+        _ => {
+            return None;
+        }
+    };
+    Some(result_scalar)
+}
+
 /// Cast a timestamp value from one unit to another
 fn cast_between_timestamp(from: DataType, to: DataType, value: i128) -> 
Option<i64> {
     let value = value as i64;
@@ -536,6 +608,35 @@ mod tests {
         assert_eq!(optimize_test(expr_input, &schema), expected);
     }
 
+    #[test]
+    fn test_unwrap_cast_comparison_string() {
+        let schema = expr_test_schema();
+        let dict = ScalarValue::Dictionary(
+            Box::new(DataType::Int32),
+            Box::new(ScalarValue::from("value")),
+        );
+
+        // cast(str1 as Dictionary<Int32, Utf8>) = arrow_cast('value', 
'Dictionary<Int32, Utf8>') => str1 = Utf8('value')
+        let expr_input = cast(col("str1"), 
dict.data_type()).eq(lit(dict.clone()));
+        let expected = col("str1").eq(lit("value"));
+        assert_eq!(optimize_test(expr_input, &schema), expected);
+
+        // cast(tag as Utf8) = Utf8('value') => tag = arrow_cast('value', 
'Dictionary<Int32, Utf8>')
+        let expr_input = cast(col("tag"), DataType::Utf8).eq(lit("value"));
+        let expected = col("tag").eq(lit(dict));
+        assert_eq!(optimize_test(expr_input, &schema), expected);
+
+        // cast(largestr as Dictionary<Int32, LargeUtf8>) = 
arrow_cast('value', 'Dictionary<Int32, LargeUtf8>') => largestr = 
LargeUtf8('value')
+        let dict = ScalarValue::Dictionary(
+            Box::new(DataType::Int32),
+            Box::new(ScalarValue::LargeUtf8(Some("value".to_owned()))),
+        );
+        let expr_input = cast(col("largestr"), dict.data_type()).eq(lit(dict));
+        let expected =
+            
col("largestr").eq(lit(ScalarValue::LargeUtf8(Some("value".to_owned()))));
+        assert_eq!(optimize_test(expr_input, &schema), expected);
+    }
+
     #[test]
     fn test_not_unwrap_cast_with_decimal_comparison() {
         let schema = expr_test_schema();
@@ -746,6 +847,9 @@ mod tests {
                     Field::new("c6", DataType::UInt32, false),
                     Field::new("ts_nano_none", timestamp_nano_none_type(), 
false),
                     Field::new("ts_nano_utf", timestamp_nano_utc_type(), 
false),
+                    Field::new("str1", DataType::Utf8, false),
+                    Field::new("largestr", DataType::LargeUtf8, false),
+                    Field::new("tag", dictionary_tag_type(), false),
                 ]
                 .into(),
                 HashMap::new(),
@@ -793,6 +897,11 @@ mod tests {
         DataType::Timestamp(TimeUnit::Nanosecond, utc)
     }
 
+    // a dictionary type for storing string tags
+    fn dictionary_tag_type() -> DataType {
+        DataType::Dictionary(Box::new(DataType::Int32), 
Box::new(DataType::Utf8))
+    }
+
     #[test]
     fn test_try_cast_to_type_nulls() {
         // test that nulls can be cast to/from all integer types
@@ -807,6 +916,8 @@ mod tests {
             ScalarValue::UInt64(None),
             ScalarValue::Decimal128(None, 3, 0),
             ScalarValue::Decimal128(None, 8, 2),
+            ScalarValue::Utf8(None),
+            ScalarValue::LargeUtf8(None),
         ];
 
         for s1 in &scalars {
@@ -1061,18 +1172,17 @@ mod tests {
         target_type: DataType,
         expected_result: ExpectedCast,
     ) {
-        let actual_result = try_cast_literal_to_type(&literal, &target_type);
+        let actual_value = try_cast_literal_to_type(&literal, &target_type);
 
         println!("expect_cast: ");
         println!("  {literal:?} --> {target_type:?}");
         println!("  expected_result: {expected_result:?}");
-        println!("  actual_result:   {actual_result:?}");
+        println!("  actual_result:   {actual_value:?}");
 
         match expected_result {
             ExpectedCast::Value(expected_value) => {
-                let actual_value = actual_result
-                    .expect("Expected success but got error")
-                    .expect("Expected cast value but got None");
+                let actual_value =
+                    actual_value.expect("Expected cast value but got None");
 
                 assert_eq!(actual_value, expected_value);
 
@@ -1094,7 +1204,7 @@ mod tests {
 
                 assert_eq!(
                     &expected_array, &cast_array,
-                    "Result of casing {literal:?} with arrow was\n 
{cast_array:#?}\nbut expected\n{expected_array:#?}"
+                    "Result of casting {literal:?} with arrow was\n 
{cast_array:#?}\nbut expected\n{expected_array:#?}"
                 );
 
                 // Verify that for timestamp types the timezones are the same
@@ -1109,8 +1219,6 @@ mod tests {
                 }
             }
             ExpectedCast::NoValue => {
-                let actual_value = actual_result.expect("Expected success but 
got error");
-
                 assert!(
                     actual_value.is_none(),
                     "Expected no cast value, but got {actual_value:?}"
@@ -1126,7 +1234,6 @@ mod tests {
             &ScalarValue::TimestampNanosecond(Some(123456), None),
             &DataType::Timestamp(TimeUnit::Nanosecond, None),
         )
-        .unwrap()
         .unwrap();
 
         assert_eq!(
@@ -1139,7 +1246,6 @@ mod tests {
             &ScalarValue::TimestampNanosecond(Some(123456), None),
             &DataType::Timestamp(TimeUnit::Microsecond, None),
         )
-        .unwrap()
         .unwrap();
 
         assert_eq!(
@@ -1152,7 +1258,6 @@ mod tests {
             &ScalarValue::TimestampNanosecond(Some(123456), None),
             &DataType::Timestamp(TimeUnit::Millisecond, None),
         )
-        .unwrap()
         .unwrap();
 
         assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), 
None));
@@ -1162,7 +1267,6 @@ mod tests {
             &ScalarValue::TimestampNanosecond(Some(123456), None),
             &DataType::Timestamp(TimeUnit::Second, None),
         )
-        .unwrap()
         .unwrap();
 
         assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None));
@@ -1172,7 +1276,6 @@ mod tests {
             &ScalarValue::TimestampMicrosecond(Some(123), None),
             &DataType::Timestamp(TimeUnit::Nanosecond, None),
         )
-        .unwrap()
         .unwrap();
 
         assert_eq!(
@@ -1185,7 +1288,6 @@ mod tests {
             &ScalarValue::TimestampMicrosecond(Some(123), None),
             &DataType::Timestamp(TimeUnit::Millisecond, None),
         )
-        .unwrap()
         .unwrap();
 
         assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), 
None));
@@ -1195,7 +1297,6 @@ mod tests {
             &ScalarValue::TimestampMicrosecond(Some(123456789), None),
             &DataType::Timestamp(TimeUnit::Second, None),
         )
-        .unwrap()
         .unwrap();
         assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None));
 
@@ -1204,7 +1305,6 @@ mod tests {
             &ScalarValue::TimestampMillisecond(Some(123), None),
             &DataType::Timestamp(TimeUnit::Nanosecond, None),
         )
-        .unwrap()
         .unwrap();
         assert_eq!(
             new_scalar,
@@ -1216,7 +1316,6 @@ mod tests {
             &ScalarValue::TimestampMillisecond(Some(123), None),
             &DataType::Timestamp(TimeUnit::Microsecond, None),
         )
-        .unwrap()
         .unwrap();
         assert_eq!(
             new_scalar,
@@ -1227,7 +1326,6 @@ mod tests {
             &ScalarValue::TimestampMillisecond(Some(123456789), None),
             &DataType::Timestamp(TimeUnit::Second, None),
         )
-        .unwrap()
         .unwrap();
         assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), 
None));
 
@@ -1236,7 +1334,6 @@ mod tests {
             &ScalarValue::TimestampSecond(Some(123), None),
             &DataType::Timestamp(TimeUnit::Nanosecond, None),
         )
-        .unwrap()
         .unwrap();
         assert_eq!(
             new_scalar,
@@ -1248,7 +1345,6 @@ mod tests {
             &ScalarValue::TimestampSecond(Some(123), None),
             &DataType::Timestamp(TimeUnit::Microsecond, None),
         )
-        .unwrap()
         .unwrap();
         assert_eq!(
             new_scalar,
@@ -1260,7 +1356,6 @@ mod tests {
             &ScalarValue::TimestampSecond(Some(123), None),
             &DataType::Timestamp(TimeUnit::Millisecond, None),
         )
-        .unwrap()
         .unwrap();
         assert_eq!(
             new_scalar,
@@ -1272,8 +1367,48 @@ mod tests {
             &ScalarValue::TimestampSecond(Some(i64::MAX), None),
             &DataType::Timestamp(TimeUnit::Millisecond, None),
         )
-        .unwrap()
         .unwrap();
         assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None));
     }
+
+    #[test]
+    fn test_try_cast_to_string_type() {
+        let scalars = vec![
+            ScalarValue::from("string"),
+            ScalarValue::LargeUtf8(Some("string".to_owned())),
+        ];
+
+        for s1 in &scalars {
+            for s2 in &scalars {
+                let expected_value = ExpectedCast::Value(s2.clone());
+
+                expect_cast(s1.clone(), s2.data_type(), expected_value);
+            }
+        }
+    }
+    #[test]
+    fn test_try_cast_to_dictionary_type() {
+        fn dictionary_type(t: DataType) -> DataType {
+            DataType::Dictionary(Box::new(DataType::Int32), Box::new(t))
+        }
+        fn dictionary_value(value: ScalarValue) -> ScalarValue {
+            ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(value))
+        }
+        let scalars = vec![
+            ScalarValue::from("string"),
+            ScalarValue::LargeUtf8(Some("string".to_owned())),
+        ];
+        for s in &scalars {
+            expect_cast(
+                s.clone(),
+                dictionary_type(s.data_type()),
+                ExpectedCast::Value(dictionary_value(s.clone())),
+            );
+            expect_cast(
+                dictionary_value(s.clone()),
+                s.data_type(),
+                ExpectedCast::Value(s.clone()),
+            )
+        }
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/dictionary.slt 
b/datafusion/sqllogictest/test_files/dictionary.slt
index 06b7265026..7e45f5e444 100644
--- a/datafusion/sqllogictest/test_files/dictionary.slt
+++ b/datafusion/sqllogictest/test_files/dictionary.slt
@@ -386,3 +386,48 @@ drop table m3;
 
 statement ok
 drop table m3_source;
+
+
+## Test that filtering on dictionary columns coerces the filter value to the 
dictionary type
+statement ok
+create table test as values
+  ('row1', arrow_cast('1', 'Dictionary(Int32, Utf8)')),
+  ('row2', arrow_cast('2', 'Dictionary(Int32, Utf8)')),
+  ('row3', arrow_cast('3', 'Dictionary(Int32, Utf8)'))
+;
+
+# query using a string '1' which must be coerced into a dictionary string
+query T?
+SELECT * from test where column2 = '1';
+----
+row1 1
+
+# filter should not have a cast on column2
+query TT
+explain SELECT * from test where column2 = '1';
+----
+logical_plan
+01)Filter: test.column2 = Dictionary(Int32, Utf8("1"))
+02)--TableScan: test projection=[column1, column2]
+physical_plan
+01)CoalesceBatchesExec: target_batch_size=8192
+02)--FilterExec: column2@1 = 1
+03)----MemoryExec: partitions=1, partition_sizes=[1]
+
+
+# Now query using an integer which must be coerced into a dictionary string
+query T?
+SELECT * from test where column2 = 1;
+----
+row1 1
+
+query TT
+explain SELECT * from test where column2 = 1;
+----
+logical_plan
+01)Filter: test.column2 = Dictionary(Int32, Utf8("1"))
+02)--TableScan: test projection=[column1, column2]
+physical_plan
+01)CoalesceBatchesExec: target_batch_size=8192
+02)--FilterExec: column2@1 = 1
+03)----MemoryExec: partitions=1, partition_sizes=[1]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to