This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 69d05aa0c support cast/try_cast in prune with signed integer and decimal (#3422)
69d05aa0c is described below

commit 69d05aa0c563a478e28502c0a4e7822095859b28
Author: Kun Liu <[email protected]>
AuthorDate: Tue Sep 13 01:40:41 2022 +0800

    support cast/try_cast in prune with signed integer and decimal (#3422)
    
    * support cast/try_cast in prune
    
    * add bound for supported data type in the cast/try_cast prune
---
 datafusion/core/src/physical_optimizer/pruning.rs | 287 ++++++++++++++++++++--
 datafusion/expr/src/expr_fn.rs                    |  13 +
 2 files changed, 278 insertions(+), 22 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs
index 896b5ad9a..73f6c795c 100644
--- a/datafusion/core/src/physical_optimizer/pruning.rs
+++ b/datafusion/core/src/physical_optimizer/pruning.rs
@@ -43,9 +43,9 @@ use arrow::{
     datatypes::{DataType, Field, Schema, SchemaRef},
     record_batch::RecordBatch,
 };
-use datafusion_expr::binary_expr;
 use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
 use datafusion_expr::utils::expr_to_columns;
+use datafusion_expr::{binary_expr, cast, try_cast, ExprSchemable};
 use datafusion_physical_expr::create_physical_expr;
 
 /// Interface to pass statistics information to [`PruningPredicate`]
@@ -429,11 +429,13 @@ impl<'a> PruningExpressionBuilder<'a> {
                 }
             };
 
-        let (column_expr, correct_operator, scalar_expr) =
-            match rewrite_expr_to_prunable(column_expr, correct_operator, scalar_expr) {
-                Ok(ret) => ret,
-                Err(e) => return Err(e),
-            };
+        let df_schema = DFSchema::try_from(schema.clone())?;
+        let (column_expr, correct_operator, scalar_expr) = rewrite_expr_to_prunable(
+            column_expr,
+            correct_operator,
+            scalar_expr,
+            df_schema,
+        )?;
         let column = columns.iter().next().unwrap().clone();
         let field = match schema.column_with_name(&column.flat_name()) {
             Some((_, f)) => f,
@@ -481,12 +483,15 @@ impl<'a> PruningExpressionBuilder<'a> {
 /// 2. `-col > 10` should be rewritten to `col < -10`
 /// 3. `!col = true` would be rewritten to `col = !true`
 /// 4. `abs(a - 10) > 0` not supported
+/// 5. `cast(prunable_expr) > 10`
+/// 6. `try_cast(prunable_expr) > 10`
 ///
 /// More rewrite rules are still in progress.
 fn rewrite_expr_to_prunable(
     column_expr: &Expr,
     op: Operator,
     scalar_expr: &Expr,
+    schema: DFSchema,
 ) -> Result<(Expr, Operator, Expr)> {
     if !is_compare_op(op) {
         return Err(DataFusionError::Plan(
@@ -495,22 +500,29 @@ fn rewrite_expr_to_prunable(
     }
 
     match column_expr {
-        // `col > lit()`
+        // `col op lit()`
         Expr::Column(_) => Ok((column_expr.clone(), op, scalar_expr.clone())),
-
+        // `cast(col) op lit()`
+        Expr::Cast { expr, data_type } => {
+            let from_type = expr.get_type(&schema)?;
+            verify_support_type_for_prune(&from_type, data_type)?;
+            let (left, op, right) =
+                rewrite_expr_to_prunable(expr, op, scalar_expr, schema)?;
+            Ok((cast(left, data_type.clone()), op, right))
+        }
+        // `try_cast(col) op lit()`
+        Expr::TryCast { expr, data_type } => {
+            let from_type = expr.get_type(&schema)?;
+            verify_support_type_for_prune(&from_type, data_type)?;
+            let (left, op, right) =
+                rewrite_expr_to_prunable(expr, op, scalar_expr, schema)?;
+            Ok((try_cast(left, data_type.clone()), op, right))
+        }
         // `-col > lit()`  --> `col < -lit()`
-        Expr::Negative(c) => match c.as_ref() {
-            Expr::Column(_) => Ok((
-                c.as_ref().clone(),
-                reverse_operator(op),
-                Expr::Negative(Box::new(scalar_expr.clone())),
-            )),
-            _ => Err(DataFusionError::Plan(format!(
-                "negative with complex expression {:?} is not supported",
-                column_expr
-            ))),
-        },
-
+        Expr::Negative(c) => {
+            let (left, op, right) = rewrite_expr_to_prunable(c, op, scalar_expr, schema)?;
+            Ok((left, reverse_operator(op), Expr::Negative(Box::new(right))))
+        }
         // `!col = true` --> `col = !true`
         Expr::Not(c) => {
             if op != Operator::Eq && op != Operator::NotEq {
@@ -551,6 +563,32 @@ fn is_compare_op(op: Operator) -> bool {
     )
 }
 
+// The pruning logic is based on comparing the min/max bounds, so a cast is
+// only prunable when the source and target types share the same ordering.
+// For example, casting from string to number is not supported here,
+// because "13" is less than "3" under UTF8 comparison order.
+fn verify_support_type_for_prune(from_type: &DataType, to_type: &DataType) -> Result<()> {
+    // TODO: support other data type for prunable cast or try cast
+    if matches!(
+        from_type,
+        DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::Decimal128(_, _)
+    ) && matches!(
+        to_type,
+        DataType::Int8 | DataType::Int32 | DataType::Int64 | DataType::Decimal128(_, _)
+    ) {
+        Ok(())
+    } else {
+        Err(DataFusionError::Plan(format!(
+            "Try Cast/Cast with from type {} to type {} is not supported",
+            from_type, to_type
+        )))
+    }
+}
+
 /// replaces a column with an old name with a new name in an expression
 fn rewrite_column_expr(
     e: Expr,
@@ -804,10 +842,10 @@ mod tests {
         datatypes::{DataType, TimeUnit},
     };
     use datafusion_common::ScalarValue;
+    use datafusion_expr::{cast, is_null};
     use std::collections::HashMap;
 
     #[derive(Debug)]
-
     /// Mock statistic provider for tests
     ///
     /// Each row represents the statistics for a "container" (which
@@ -1508,6 +1546,78 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn row_group_predicate_cast() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
+        let expected_expr =
+            "CAST(#c1_min AS Int64) <= Int64(1) AND Int64(1) <= CAST(#c1_max AS Int64)";
+
+        // test column on the left
+        let expr = cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(1))));
+        let predicate_expr =
+            build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?;
+        assert_eq!(format!("{:?}", predicate_expr), expected_expr);
+
+        // test column on the right
+        let expr = lit(ScalarValue::Int64(Some(1))).eq(cast(col("c1"), DataType::Int64));
+        let predicate_expr =
+            build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?;
+        assert_eq!(format!("{:?}", predicate_expr), expected_expr);
+
+        let expected_expr = "TRY_CAST(#c1_max AS Int64) > Int64(1)";
+
+        // test column on the left
+        let expr =
+            try_cast(col("c1"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(1))));
+        let predicate_expr =
+            build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?;
+        assert_eq!(format!("{:?}", predicate_expr), expected_expr);
+
+        // test column on the right
+        let expr =
+            lit(ScalarValue::Int64(Some(1))).lt(try_cast(col("c1"), DataType::Int64));
+        let predicate_expr =
+            build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?;
+        assert_eq!(format!("{:?}", predicate_expr), expected_expr);
+
+        Ok(())
+    }
+
+    #[test]
+    fn row_group_predicate_cast_list() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
+        // test cast(c1 as int64) in int64(1, 2, 3)
+        let expr = Expr::InList {
+            expr: Box::new(cast(col("c1"), DataType::Int64)),
+            list: vec![
+                lit(ScalarValue::Int64(Some(1))),
+                lit(ScalarValue::Int64(Some(2))),
+                lit(ScalarValue::Int64(Some(3))),
+            ],
+            negated: false,
+        };
+        let expected_expr = "CAST(#c1_min AS Int64) <= Int64(1) AND Int64(1) <= CAST(#c1_max AS Int64) OR CAST(#c1_min AS Int64) <= Int64(2) AND Int64(2) <= CAST(#c1_max AS Int64) OR CAST(#c1_min AS Int64) <= Int64(3) AND Int64(3) <= CAST(#c1_max AS Int64)";
+        let predicate_expr =
+            build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?;
+        assert_eq!(format!("{:?}", predicate_expr), expected_expr);
+
+        let expr = Expr::InList {
+            expr: Box::new(cast(col("c1"), DataType::Int64)),
+            list: vec![
+                lit(ScalarValue::Int64(Some(1))),
+                lit(ScalarValue::Int64(Some(2))),
+                lit(ScalarValue::Int64(Some(3))),
+            ],
+            negated: true,
+        };
+        let expected_expr = "CAST(#c1_min AS Int64) != Int64(1) OR Int64(1) != CAST(#c1_max AS Int64) AND CAST(#c1_min AS Int64) != Int64(2) OR Int64(2) != CAST(#c1_max AS Int64) AND CAST(#c1_min AS Int64) != Int64(3) OR Int64(3) != CAST(#c1_max AS Int64)";
+        let predicate_expr =
+            build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?;
+        assert_eq!(format!("{:?}", predicate_expr), expected_expr);
+
+        Ok(())
+    }
+
     #[test]
     fn prune_decimal_data() {
         // decimal(9,2)
@@ -1527,6 +1637,36 @@ mod tests {
                 vec![Some(5), Some(6), Some(4), None], // max
             ),
         );
+        let p = PruningPredicate::try_new(expr, schema.clone()).unwrap();
+        let result = p.prune(&statistics).unwrap();
+        let expected = vec![false, true, false, true];
+        assert_eq!(result, expected);
+
+        // with cast column to other type
+        let expr = cast(col("s1"), DataType::Decimal128(14, 3))
+            .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3)));
+        let statistics = TestStatistics::new().with(
+            "s1",
+            ContainerStats::new_i32(
+                vec![Some(0), Some(4), None, Some(3)], // min
+                vec![Some(5), Some(6), Some(4), None], // max
+            ),
+        );
+        let p = PruningPredicate::try_new(expr, schema.clone()).unwrap();
+        let result = p.prune(&statistics).unwrap();
+        let expected = vec![false, true, false, true];
+        assert_eq!(result, expected);
+
+        // with try cast column to other type
+        let expr = try_cast(col("s1"), DataType::Decimal128(14, 3))
+            .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3)));
+        let statistics = TestStatistics::new().with(
+            "s1",
+            ContainerStats::new_i32(
+                vec![Some(0), Some(4), None, Some(3)], // min
+                vec![Some(5), Some(6), Some(4), None], // max
+            ),
+        );
         let p = PruningPredicate::try_new(expr, schema).unwrap();
         let result = p.prune(&statistics).unwrap();
         let expected = vec![false, true, false, true];
@@ -1576,6 +1716,7 @@ mod tests {
         let expected = vec![false, true, false, true];
         assert_eq!(result, expected);
     }
+
     #[test]
     fn prune_api() {
         let schema = Arc::new(Schema::new(vec![
@@ -1599,10 +1740,16 @@ mod tests {
         // No stats for s2 ==> some rows could pass
         // s2 [3, None] (null max) ==> some rows could pass
 
-        let p = PruningPredicate::try_new(expr, schema).unwrap();
+        let p = PruningPredicate::try_new(expr, schema.clone()).unwrap();
         let result = p.prune(&statistics).unwrap();
         let expected = vec![false, true, true, true];
+        assert_eq!(result, expected);
 
+        // filter with cast
+        let expr = cast(col("s2"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(5))));
+        let p = PruningPredicate::try_new(expr, schema).unwrap();
+        let result = p.prune(&statistics).unwrap();
+        let expected = vec![false, true, true, true];
         assert_eq!(result, expected);
     }
 
@@ -1852,4 +1999,100 @@ mod tests {
         let result = p.prune(&statistics).unwrap();
         assert_eq!(result, expected_ret);
     }
+
+    #[test]
+    fn prune_cast_column_scalar() {
+        // The data type of column i is INT32
+        let (schema, statistics) = int32_setup();
+        let expected_ret = vec![true, true, false, true, true];
+
+        // i > cast(int64(0) as int32)
+        let expr = col("i").gt(cast(lit(ScalarValue::Int64(Some(0))), DataType::Int32));
+        let p = PruningPredicate::try_new(expr, schema.clone()).unwrap();
+        let result = p.prune(&statistics).unwrap();
+        assert_eq!(result, expected_ret);
+
+        // cast(i as int64) > int64(0)
+        let expr = cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0))));
+        let p = PruningPredicate::try_new(expr, schema.clone()).unwrap();
+        let result = p.prune(&statistics).unwrap();
+        assert_eq!(result, expected_ret);
+
+        // try_cast(i as int64) > int64(0)
+        let expr =
+            try_cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0))));
+        let p = PruningPredicate::try_new(expr, schema.clone()).unwrap();
+        let result = p.prune(&statistics).unwrap();
+        assert_eq!(result, expected_ret);
+
+        // `-cast(i as int64) < 0` is converted to `cast(i as int64) > -0`
+        let expr = Expr::Negative(Box::new(cast(col("i"), DataType::Int64)))
+            .lt(lit(ScalarValue::Int64(Some(0))));
+        let p = PruningPredicate::try_new(expr, schema).unwrap();
+        let result = p.prune(&statistics).unwrap();
+        assert_eq!(result, expected_ret);
+    }
+
+    #[test]
+    fn test_rewrite_expr_to_prunable() {
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
+        let df_schema = DFSchema::try_from(schema).unwrap();
+        // column op lit
+        let left_input = col("a");
+        let right_input = lit(ScalarValue::Int32(Some(12)));
+        let (result_left, _, result_right) = rewrite_expr_to_prunable(
+            &left_input,
+            Operator::Eq,
+            &right_input,
+            df_schema.clone(),
+        )
+        .unwrap();
+        assert_eq!(result_left, left_input);
+        assert_eq!(result_right, right_input);
+        // cast op lit
+        let left_input = cast(col("a"), DataType::Decimal128(20, 3));
+        let right_input = lit(ScalarValue::Decimal128(Some(12), 20, 3));
+        let (result_left, _, result_right) = rewrite_expr_to_prunable(
+            &left_input,
+            Operator::Gt,
+            &right_input,
+            df_schema.clone(),
+        )
+        .unwrap();
+        assert_eq!(result_left, left_input);
+        assert_eq!(result_right, right_input);
+        // try_cast op lit
+        let left_input = try_cast(col("a"), DataType::Int64);
+        let right_input = lit(ScalarValue::Int64(Some(12)));
+        let (result_left, _, result_right) =
+            rewrite_expr_to_prunable(&left_input, Operator::Gt, &right_input, df_schema)
+                .unwrap();
+        assert_eq!(result_left, left_input);
+        assert_eq!(result_right, right_input);
+        // TODO: add test for other case and op
+    }
+
+    #[test]
+    fn test_rewrite_expr_to_prunable_error() {
+        // cast string value to numeric value
+        // this cast is not supported
+        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
+        let df_schema = DFSchema::try_from(schema).unwrap();
+        let left_input = cast(col("a"), DataType::Int64);
+        let right_input = lit(ScalarValue::Int64(Some(12)));
+        let result = rewrite_expr_to_prunable(
+            &left_input,
+            Operator::Gt,
+            &right_input,
+            df_schema.clone(),
+        );
+        assert!(result.is_err());
+        // other expr
+        let left_input = is_null(col("a"));
+        let right_input = lit(ScalarValue::Int64(Some(12)));
+        let result =
+            rewrite_expr_to_prunable(&left_input, Operator::Gt, &right_input, df_schema);
+        assert!(result.is_err());
+        // TODO: add other negative test for other case and op
+    }
 }
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 8b0f16466..f7eaec39b 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -259,6 +259,19 @@ pub fn cast(expr: Expr, data_type: DataType) -> Expr {
     }
 }
 
+/// Create a try cast expression
+pub fn try_cast(expr: Expr, data_type: DataType) -> Expr {
+    Expr::TryCast {
+        expr: Box::new(expr),
+        data_type,
+    }
+}
+
+/// Create an `IS NULL` expression
+pub fn is_null(expr: Expr) -> Expr {
+    Expr::IsNull(Box::new(expr))
+}
+
 /// Create an convenience function representing a unary scalar function
 macro_rules! unary_scalar_expr {
     ($ENUM:ident, $FUNC:ident, $DOC:expr) => {

Reply via email to