This is an automated email from the ASF dual-hosted git repository.

comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 1fe856bece predicate pruning: support cast and try_cast for more types 
(#15764)
1fe856bece is described below

commit 1fe856bece733a2ceaba42803e3d3888b437024f
Author: Adrian Garcia Badaracco <1755071+adria...@users.noreply.github.com>
AuthorDate: Thu Apr 24 07:45:07 2025 -0700

    predicate pruning: support cast and try_cast for more types (#15764)
    
    * predicate pruning: support dictionaries
    
    * more types
    
    * clippy
    
    * add tests
    
    * add tests
    
    * simplify to dicts
    
    * revert most changes
    
    * just check for strings, more tests
    
    * more tests
    
    * remove unnecessary, now-confusing clause
---
 datafusion/physical-optimizer/src/pruning.rs | 328 +++++++++++++++++++++++++--
 1 file changed, 314 insertions(+), 14 deletions(-)

diff --git a/datafusion/physical-optimizer/src/pruning.rs 
b/datafusion/physical-optimizer/src/pruning.rs
index 1dd168f181..b62b72ac6d 100644
--- a/datafusion/physical-optimizer/src/pruning.rs
+++ b/datafusion/physical-optimizer/src/pruning.rs
@@ -1210,23 +1210,35 @@ fn is_compare_op(op: Operator) -> bool {
     )
 }
 
+fn is_string_type(data_type: &DataType) -> bool {
+    matches!(
+        data_type,
+        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View
+    )
+}
+
 // The pruning logic is based on the comparing the min/max bounds.
 // Must make sure the two type has order.
 // For example, casts from string to numbers is not correct.
 // Because the "13" is less than "3" with UTF8 comparison order.
 fn verify_support_type_for_prune(from_type: &DataType, to_type: &DataType) -> 
Result<()> {
-    // TODO: support other data type for prunable cast or try cast
-    if matches!(
-        from_type,
-        DataType::Int8
-            | DataType::Int16
-            | DataType::Int32
-            | DataType::Int64
-            | DataType::Decimal128(_, _)
-    ) && matches!(
-        to_type,
-        DataType::Int8 | DataType::Int32 | DataType::Int64 | 
DataType::Decimal128(_, _)
-    ) {
+    // Dictionary casts are always supported as long as the value types are 
supported
+    let from_type = match from_type {
+        DataType::Dictionary(_, t) => {
+            return verify_support_type_for_prune(t.as_ref(), to_type)
+        }
+        _ => from_type,
+    };
+    let to_type = match to_type {
+        DataType::Dictionary(_, t) => {
+            return verify_support_type_for_prune(from_type, t.as_ref())
+        }
+        _ => to_type,
+    };
+    // If both types are strings or both are not strings (number, timestamp, 
etc)
+    // then we can compare them.
+    // PruningPredicate does not support casting of strings to numbers and 
such.
+    if is_string_type(from_type) == is_string_type(to_type) {
         Ok(())
     } else {
         plan_err!(
@@ -1544,7 +1556,10 @@ fn build_predicate_expression(
         Ok(builder) => builder,
         // allow partial failure in predicate expression generation
         // this can still produce a useful predicate when multiple conditions 
are joined using AND
-        Err(_) => return unhandled_hook.handle(expr),
+        Err(e) => {
+            dbg!(format!("Error building pruning expression: {e}"));
+            return unhandled_hook.handle(expr);
+        }
     };
 
     build_statistics_expr(&mut expr_builder)
@@ -3006,7 +3021,7 @@ mod tests {
     }
 
     #[test]
-    fn row_group_predicate_cast() -> Result<()> {
+    fn row_group_predicate_cast_int_int() -> Result<()> {
         let schema = Schema::new(vec![Field::new("c1", DataType::Int32, 
false)]);
         let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 
AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64)";
 
@@ -3043,6 +3058,291 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn row_group_predicate_cast_string_string() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, 
false)]);
+        let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 
AS Utf8) <= 1 AND 1 <= CAST(c1_max@1 AS Utf8)";
+
+        // test column on the left
+        let expr = cast(col("c1"), DataType::Utf8)
+            .eq(lit(ScalarValue::Utf8(Some("1".to_string()))));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        // test column on the right
+        let expr = lit(ScalarValue::Utf8(Some("1".to_string())))
+            .eq(cast(col("c1"), DataType::Utf8));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        Ok(())
+    }
+
+    #[test]
+    fn row_group_predicate_cast_string_int() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, 
false)]);
+        let expected_expr = "true";
+
+        // test column on the left
+        let expr = cast(col("c1"), 
DataType::Int32).eq(lit(ScalarValue::Int32(Some(1))));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        // test column on the right
+        let expr = lit(ScalarValue::Int32(Some(1))).eq(cast(col("c1"), 
DataType::Int32));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        Ok(())
+    }
+
+    #[test]
+    fn row_group_predicate_cast_int_string() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("c1", DataType::Int32, 
false)]);
+        let expected_expr = "true";
+
+        // test column on the left
+        let expr = cast(col("c1"), DataType::Utf8)
+            .eq(lit(ScalarValue::Utf8(Some("1".to_string()))));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        // test column on the right
+        let expr = lit(ScalarValue::Utf8(Some("1".to_string())))
+            .eq(cast(col("c1"), DataType::Utf8));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        Ok(())
+    }
+
+    #[test]
+    fn row_group_predicate_date_date() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("c1", DataType::Date32, 
false)]);
+        let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 
AS Date64) <= 1970-01-01 AND 1970-01-01 <= CAST(c1_max@1 AS Date64)";
+
+        // test column on the left
+        let expr =
+            cast(col("c1"), 
DataType::Date64).eq(lit(ScalarValue::Date64(Some(123))));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        // test column on the right
+        let expr =
+            lit(ScalarValue::Date64(Some(123))).eq(cast(col("c1"), 
DataType::Date64));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        Ok(())
+    }
+
+    #[test]
+    fn row_group_predicate_dict_string_date() -> Result<()> {
+        // Test with Dictionary<UInt8, Utf8> for the literal
+        let schema = Schema::new(vec![Field::new("c1", DataType::Date32, 
false)]);
+        let expected_expr = "true";
+
+        // test column on the left
+        let expr = cast(
+            col("c1"),
+            DataType::Dictionary(Box::new(DataType::UInt8), 
Box::new(DataType::Utf8)),
+        )
+        .eq(lit(ScalarValue::Utf8(Some("2024-01-01".to_string()))));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        // test column on the right
+        let expr = 
lit(ScalarValue::Utf8(Some("2024-01-01".to_string()))).eq(cast(
+            col("c1"),
+            DataType::Dictionary(Box::new(DataType::UInt8), 
Box::new(DataType::Utf8)),
+        ));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        Ok(())
+    }
+
+    #[test]
+    fn row_group_predicate_date_dict_string() -> Result<()> {
+        // Test with Dictionary<UInt8, Utf8> for the column
+        let schema = Schema::new(vec![Field::new(
+            "c1",
+            DataType::Dictionary(Box::new(DataType::UInt8), 
Box::new(DataType::Utf8)),
+            false,
+        )]);
+        let expected_expr = "true";
+
+        // test column on the left
+        let expr =
+            cast(col("c1"), 
DataType::Date32).eq(lit(ScalarValue::Date32(Some(123))));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        // test column on the right
+        let expr =
+            lit(ScalarValue::Date32(Some(123))).eq(cast(col("c1"), 
DataType::Date32));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        Ok(())
+    }
+
+    #[test]
+    fn row_group_predicate_dict_dict_same_value_type() -> Result<()> {
+        // Test with Dictionary types that have the same value type but 
different key types
+        let schema = Schema::new(vec![Field::new(
+            "c1",
+            DataType::Dictionary(Box::new(DataType::UInt8), 
Box::new(DataType::Utf8)),
+            false,
+        )]);
+
+        // Direct comparison with no cast
+        let expr = 
col("c1").eq(lit(ScalarValue::Utf8(Some("test".to_string()))));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        let expected_expr =
+            "c1_null_count@2 != row_count@3 AND c1_min@0 <= test AND test <= 
c1_max@1";
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        // Test with column cast to a dictionary with different key type
+        let expr = cast(
+            col("c1"),
+            DataType::Dictionary(Box::new(DataType::UInt16), 
Box::new(DataType::Utf8)),
+        )
+        .eq(lit(ScalarValue::Utf8(Some("test".to_string()))));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 
AS Dictionary(UInt16, Utf8)) <= test AND test <= CAST(c1_max@1 AS 
Dictionary(UInt16, Utf8))";
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        Ok(())
+    }
+
+    #[test]
+    fn row_group_predicate_dict_dict_different_value_type() -> Result<()> {
+        // Test with Dictionary types that have different value types
+        let schema = Schema::new(vec![Field::new(
+            "c1",
+            DataType::Dictionary(Box::new(DataType::UInt8), 
Box::new(DataType::Int32)),
+            false,
+        )]);
+        let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 
AS Int64) <= 123 AND 123 <= CAST(c1_max@1 AS Int64)";
+
+        // Test with literal of a different type
+        let expr =
+            cast(col("c1"), 
DataType::Int64).eq(lit(ScalarValue::Int64(Some(123))));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        Ok(())
+    }
+
+    #[test]
+    fn row_group_predicate_nested_dict() -> Result<()> {
+        // Test with nested Dictionary types
+        let schema = Schema::new(vec![Field::new(
+            "c1",
+            DataType::Dictionary(
+                Box::new(DataType::UInt8),
+                Box::new(DataType::Dictionary(
+                    Box::new(DataType::UInt16),
+                    Box::new(DataType::Utf8),
+                )),
+            ),
+            false,
+        )]);
+        let expected_expr =
+            "c1_null_count@2 != row_count@3 AND c1_min@0 <= test AND test <= 
c1_max@1";
+
+        // Test with a simple literal
+        let expr = 
col("c1").eq(lit(ScalarValue::Utf8(Some("test".to_string()))));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        Ok(())
+    }
+
+    #[test]
+    fn row_group_predicate_dict_date_dict_date() -> Result<()> {
+        // Test with dictionary-wrapped date types for both sides
+        let schema = Schema::new(vec![Field::new(
+            "c1",
+            DataType::Dictionary(Box::new(DataType::UInt8), 
Box::new(DataType::Date32)),
+            false,
+        )]);
+        let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 
AS Dictionary(UInt16, Date64)) <= 1970-01-01 AND 1970-01-01 <= CAST(c1_max@1 AS 
Dictionary(UInt16, Date64))";
+
+        // Test with a cast to a different date type
+        let expr = cast(
+            col("c1"),
+            DataType::Dictionary(Box::new(DataType::UInt16), 
Box::new(DataType::Date64)),
+        )
+        .eq(lit(ScalarValue::Date64(Some(123))));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        Ok(())
+    }
+
+    #[test]
+    fn row_group_predicate_date_string() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("c1", DataType::Utf8, 
false)]);
+        let expected_expr = "true";
+
+        // test column on the left
+        let expr =
+            cast(col("c1"), 
DataType::Date32).eq(lit(ScalarValue::Date32(Some(123))));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        // test column on the right
+        let expr =
+            lit(ScalarValue::Date32(Some(123))).eq(cast(col("c1"), 
DataType::Date32));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        Ok(())
+    }
+
+    #[test]
+    fn row_group_predicate_string_date() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("c1", DataType::Date32, 
false)]);
+        let expected_expr = "true";
+
+        // test column on the left
+        let expr = cast(col("c1"), DataType::Utf8)
+            .eq(lit(ScalarValue::Utf8(Some("2024-01-01".to_string()))));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        // test column on the right
+        let expr = lit(ScalarValue::Utf8(Some("2024-01-01".to_string())))
+            .eq(cast(col("c1"), DataType::Utf8));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut 
RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), expected_expr);
+
+        Ok(())
+    }
+
     #[test]
     fn row_group_predicate_cast_list() -> Result<()> {
         let schema = Schema::new(vec![Field::new("c1", DataType::Int32, 
false)]);


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to