This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 69d05aa0c support cast/try_cast in prune with signed integer and
decimal (#3422)
69d05aa0c is described below
commit 69d05aa0c563a478e28502c0a4e7822095859b28
Author: Kun Liu <[email protected]>
AuthorDate: Tue Sep 13 01:40:41 2022 +0800
support cast/try_cast in prune with signed integer and decimal (#3422)
* support cast/try_cast in prune
* add bound for supported data type in the cast/try_cast prune
---
datafusion/core/src/physical_optimizer/pruning.rs | 287 ++++++++++++++++++++--
datafusion/expr/src/expr_fn.rs | 13 +
2 files changed, 278 insertions(+), 22 deletions(-)
diff --git a/datafusion/core/src/physical_optimizer/pruning.rs
b/datafusion/core/src/physical_optimizer/pruning.rs
index 896b5ad9a..73f6c795c 100644
--- a/datafusion/core/src/physical_optimizer/pruning.rs
+++ b/datafusion/core/src/physical_optimizer/pruning.rs
@@ -43,9 +43,9 @@ use arrow::{
datatypes::{DataType, Field, Schema, SchemaRef},
record_batch::RecordBatch,
};
-use datafusion_expr::binary_expr;
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
use datafusion_expr::utils::expr_to_columns;
+use datafusion_expr::{binary_expr, cast, try_cast, ExprSchemable};
use datafusion_physical_expr::create_physical_expr;
/// Interface to pass statistics information to [`PruningPredicate`]
@@ -429,11 +429,13 @@ impl<'a> PruningExpressionBuilder<'a> {
}
};
- let (column_expr, correct_operator, scalar_expr) =
- match rewrite_expr_to_prunable(column_expr, correct_operator,
scalar_expr) {
- Ok(ret) => ret,
- Err(e) => return Err(e),
- };
+ let df_schema = DFSchema::try_from(schema.clone())?;
+ let (column_expr, correct_operator, scalar_expr) =
rewrite_expr_to_prunable(
+ column_expr,
+ correct_operator,
+ scalar_expr,
+ df_schema,
+ )?;
let column = columns.iter().next().unwrap().clone();
let field = match schema.column_with_name(&column.flat_name()) {
Some((_, f)) => f,
@@ -481,12 +483,15 @@ impl<'a> PruningExpressionBuilder<'a> {
/// 2. `-col > 10` should be rewritten to `col < -10`
/// 3. `!col = true` would be rewritten to `col = !true`
/// 4. `abs(a - 10) > 0` not supported
+/// 5. `cast(prunable_expr) > 10` is supported when `prunable_expr` is itself prunable
+/// 6. `try_cast(prunable_expr) > 10` is supported when `prunable_expr` is itself prunable
///
/// More rewrite rules are still in progress.
fn rewrite_expr_to_prunable(
column_expr: &Expr,
op: Operator,
scalar_expr: &Expr,
+ schema: DFSchema,
) -> Result<(Expr, Operator, Expr)> {
if !is_compare_op(op) {
return Err(DataFusionError::Plan(
@@ -495,22 +500,29 @@ fn rewrite_expr_to_prunable(
}
match column_expr {
- // `col > lit()`
+ // `col op lit()`
Expr::Column(_) => Ok((column_expr.clone(), op, scalar_expr.clone())),
-
+ // `cast(col) op lit()`
+ Expr::Cast { expr, data_type } => {
+ let from_type = expr.get_type(&schema)?;
+ verify_support_type_for_prune(&from_type, data_type)?;
+ let (left, op, right) =
+ rewrite_expr_to_prunable(expr, op, scalar_expr, schema)?;
+ Ok((cast(left, data_type.clone()), op, right))
+ }
+ // `try_cast(col) op lit()`
+ Expr::TryCast { expr, data_type } => {
+ let from_type = expr.get_type(&schema)?;
+ verify_support_type_for_prune(&from_type, data_type)?;
+ let (left, op, right) =
+ rewrite_expr_to_prunable(expr, op, scalar_expr, schema)?;
+ Ok((try_cast(left, data_type.clone()), op, right))
+ }
// `-col > lit()` --> `col < -lit()`
- Expr::Negative(c) => match c.as_ref() {
- Expr::Column(_) => Ok((
- c.as_ref().clone(),
- reverse_operator(op),
- Expr::Negative(Box::new(scalar_expr.clone())),
- )),
- _ => Err(DataFusionError::Plan(format!(
- "negative with complex expression {:?} is not supported",
- column_expr
- ))),
- },
-
+ Expr::Negative(c) => {
+ let (left, op, right) = rewrite_expr_to_prunable(c, op,
scalar_expr, schema)?;
+ Ok((left, reverse_operator(op), Expr::Negative(Box::new(right))))
+ }
// `!col = true` --> `col = !true`
Expr::Not(c) => {
if op != Operator::Eq && op != Operator::NotEq {
@@ -551,6 +563,32 @@ fn is_compare_op(op: Operator) -> bool {
)
}
+// The pruning logic is based on comparing the min/max bounds, so the cast
+// must preserve ordering between the source type and the target type.
+// For example, a cast from string to number is not supported,
+// because "13" is less than "3" under UTF8 comparison order.
+fn verify_support_type_for_prune(from_type: &DataType, to_type: &DataType) ->
Result<()> {
+ // TODO: support other data type for prunable cast or try cast
+ if matches!(
+ from_type,
+ DataType::Int8
+ | DataType::Int16
+ | DataType::Int32
+ | DataType::Int64
+ | DataType::Decimal128(_, _)
+ ) && matches!(
+ to_type,
+ DataType::Int8 | DataType::Int32 | DataType::Int64 |
DataType::Decimal128(_, _)
+ ) {
+ Ok(())
+ } else {
+ Err(DataFusionError::Plan(format!(
+ "Try Cast/Cast with from type {} to type {} is not supported",
+ from_type, to_type
+ )))
+ }
+}
+
/// replaces a column with an old name with a new name in an expression
fn rewrite_column_expr(
e: Expr,
@@ -804,10 +842,10 @@ mod tests {
datatypes::{DataType, TimeUnit},
};
use datafusion_common::ScalarValue;
+ use datafusion_expr::{cast, is_null};
use std::collections::HashMap;
#[derive(Debug)]
-
/// Mock statistic provider for tests
///
/// Each row represents the statistics for a "container" (which
@@ -1508,6 +1546,78 @@ mod tests {
Ok(())
}
+ #[test]
+ fn row_group_predicate_cast() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("c1", DataType::Int32,
false)]);
+ let expected_expr =
+ "CAST(#c1_min AS Int64) <= Int64(1) AND Int64(1) <= CAST(#c1_max
AS Int64)";
+
+ // test column on the left
+ let expr = cast(col("c1"),
DataType::Int64).eq(lit(ScalarValue::Int64(Some(1))));
+ let predicate_expr =
+ build_predicate_expression(&expr, &schema, &mut
RequiredStatColumns::new())?;
+ assert_eq!(format!("{:?}", predicate_expr), expected_expr);
+
+ // test column on the right
+ let expr = lit(ScalarValue::Int64(Some(1))).eq(cast(col("c1"),
DataType::Int64));
+ let predicate_expr =
+ build_predicate_expression(&expr, &schema, &mut
RequiredStatColumns::new())?;
+ assert_eq!(format!("{:?}", predicate_expr), expected_expr);
+
+ let expected_expr = "TRY_CAST(#c1_max AS Int64) > Int64(1)";
+
+ // test column on the left
+ let expr =
+ try_cast(col("c1"),
DataType::Int64).gt(lit(ScalarValue::Int64(Some(1))));
+ let predicate_expr =
+ build_predicate_expression(&expr, &schema, &mut
RequiredStatColumns::new())?;
+ assert_eq!(format!("{:?}", predicate_expr), expected_expr);
+
+ // test column on the right
+ let expr =
+ lit(ScalarValue::Int64(Some(1))).lt(try_cast(col("c1"),
DataType::Int64));
+ let predicate_expr =
+ build_predicate_expression(&expr, &schema, &mut
RequiredStatColumns::new())?;
+ assert_eq!(format!("{:?}", predicate_expr), expected_expr);
+
+ Ok(())
+ }
+
+ #[test]
+ fn row_group_predicate_cast_list() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("c1", DataType::Int32,
false)]);
+ // test cast(c1 as int64) in int64(1, 2, 3)
+ let expr = Expr::InList {
+ expr: Box::new(cast(col("c1"), DataType::Int64)),
+ list: vec![
+ lit(ScalarValue::Int64(Some(1))),
+ lit(ScalarValue::Int64(Some(2))),
+ lit(ScalarValue::Int64(Some(3))),
+ ],
+ negated: false,
+ };
+ let expected_expr = "CAST(#c1_min AS Int64) <= Int64(1) AND Int64(1)
<= CAST(#c1_max AS Int64) OR CAST(#c1_min AS Int64) <= Int64(2) AND Int64(2) <=
CAST(#c1_max AS Int64) OR CAST(#c1_min AS Int64) <= Int64(3) AND Int64(3) <=
CAST(#c1_max AS Int64)";
+ let predicate_expr =
+ build_predicate_expression(&expr, &schema, &mut
RequiredStatColumns::new())?;
+ assert_eq!(format!("{:?}", predicate_expr), expected_expr);
+
+ let expr = Expr::InList {
+ expr: Box::new(cast(col("c1"), DataType::Int64)),
+ list: vec![
+ lit(ScalarValue::Int64(Some(1))),
+ lit(ScalarValue::Int64(Some(2))),
+ lit(ScalarValue::Int64(Some(3))),
+ ],
+ negated: true,
+ };
+ let expected_expr = "CAST(#c1_min AS Int64) != Int64(1) OR Int64(1) !=
CAST(#c1_max AS Int64) AND CAST(#c1_min AS Int64) != Int64(2) OR Int64(2) !=
CAST(#c1_max AS Int64) AND CAST(#c1_min AS Int64) != Int64(3) OR Int64(3) !=
CAST(#c1_max AS Int64)";
+ let predicate_expr =
+ build_predicate_expression(&expr, &schema, &mut
RequiredStatColumns::new())?;
+ assert_eq!(format!("{:?}", predicate_expr), expected_expr);
+
+ Ok(())
+ }
+
#[test]
fn prune_decimal_data() {
// decimal(9,2)
@@ -1527,6 +1637,36 @@ mod tests {
vec![Some(5), Some(6), Some(4), None], // max
),
);
+ let p = PruningPredicate::try_new(expr, schema.clone()).unwrap();
+ let result = p.prune(&statistics).unwrap();
+ let expected = vec![false, true, false, true];
+ assert_eq!(result, expected);
+
+ // with cast column to other type
+ let expr = cast(col("s1"), DataType::Decimal128(14, 3))
+ .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3)));
+ let statistics = TestStatistics::new().with(
+ "s1",
+ ContainerStats::new_i32(
+ vec![Some(0), Some(4), None, Some(3)], // min
+ vec![Some(5), Some(6), Some(4), None], // max
+ ),
+ );
+ let p = PruningPredicate::try_new(expr, schema.clone()).unwrap();
+ let result = p.prune(&statistics).unwrap();
+ let expected = vec![false, true, false, true];
+ assert_eq!(result, expected);
+
+ // with try cast column to other type
+ let expr = try_cast(col("s1"), DataType::Decimal128(14, 3))
+ .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3)));
+ let statistics = TestStatistics::new().with(
+ "s1",
+ ContainerStats::new_i32(
+ vec![Some(0), Some(4), None, Some(3)], // min
+ vec![Some(5), Some(6), Some(4), None], // max
+ ),
+ );
let p = PruningPredicate::try_new(expr, schema).unwrap();
let result = p.prune(&statistics).unwrap();
let expected = vec![false, true, false, true];
@@ -1576,6 +1716,7 @@ mod tests {
let expected = vec![false, true, false, true];
assert_eq!(result, expected);
}
+
#[test]
fn prune_api() {
let schema = Arc::new(Schema::new(vec![
@@ -1599,10 +1740,16 @@ mod tests {
// No stats for s2 ==> some rows could pass
// s2 [3, None] (null max) ==> some rows could pass
- let p = PruningPredicate::try_new(expr, schema).unwrap();
+ let p = PruningPredicate::try_new(expr, schema.clone()).unwrap();
let result = p.prune(&statistics).unwrap();
let expected = vec![false, true, true, true];
+ assert_eq!(result, expected);
+ // filter with cast
+ let expr = cast(col("s2"),
DataType::Int64).gt(lit(ScalarValue::Int64(Some(5))));
+ let p = PruningPredicate::try_new(expr, schema).unwrap();
+ let result = p.prune(&statistics).unwrap();
+ let expected = vec![false, true, true, true];
assert_eq!(result, expected);
}
@@ -1852,4 +1999,100 @@ mod tests {
let result = p.prune(&statistics).unwrap();
assert_eq!(result, expected_ret);
}
+
+ #[test]
+ fn prune_cast_column_scalar() {
+ // The data type of column i is INT32
+ let (schema, statistics) = int32_setup();
+ let expected_ret = vec![true, true, false, true, true];
+
+ // i > cast(int64(0) as int32)
+ let expr = col("i").gt(cast(lit(ScalarValue::Int64(Some(0))),
DataType::Int32));
+ let p = PruningPredicate::try_new(expr, schema.clone()).unwrap();
+ let result = p.prune(&statistics).unwrap();
+ assert_eq!(result, expected_ret);
+
+ // cast(i as int64) > int64(0)
+ let expr = cast(col("i"),
DataType::Int64).gt(lit(ScalarValue::Int64(Some(0))));
+ let p = PruningPredicate::try_new(expr, schema.clone()).unwrap();
+ let result = p.prune(&statistics).unwrap();
+ assert_eq!(result, expected_ret);
+
+ // try_cast(i as int64) > int64(0)
+ let expr =
+ try_cast(col("i"),
DataType::Int64).gt(lit(ScalarValue::Int64(Some(0))));
+ let p = PruningPredicate::try_new(expr, schema.clone()).unwrap();
+ let result = p.prune(&statistics).unwrap();
+ assert_eq!(result, expected_ret);
+
+ // `-cast(i as int64) < 0` is rewritten to `cast(i as int64) > -0`
+ let expr = Expr::Negative(Box::new(cast(col("i"), DataType::Int64)))
+ .lt(lit(ScalarValue::Int64(Some(0))));
+ let p = PruningPredicate::try_new(expr, schema).unwrap();
+ let result = p.prune(&statistics).unwrap();
+ assert_eq!(result, expected_ret);
+ }
+
+ #[test]
+ fn test_rewrite_expr_to_prunable() {
+ let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
+ let df_schema = DFSchema::try_from(schema).unwrap();
+ // column op lit
+ let left_input = col("a");
+ let right_input = lit(ScalarValue::Int32(Some(12)));
+ let (result_left, _, result_right) = rewrite_expr_to_prunable(
+ &left_input,
+ Operator::Eq,
+ &right_input,
+ df_schema.clone(),
+ )
+ .unwrap();
+ assert_eq!(result_left, left_input);
+ assert_eq!(result_right, right_input);
+ // cast op lit
+ let left_input = cast(col("a"), DataType::Decimal128(20, 3));
+ let right_input = lit(ScalarValue::Decimal128(Some(12), 20, 3));
+ let (result_left, _, result_right) = rewrite_expr_to_prunable(
+ &left_input,
+ Operator::Gt,
+ &right_input,
+ df_schema.clone(),
+ )
+ .unwrap();
+ assert_eq!(result_left, left_input);
+ assert_eq!(result_right, right_input);
+ // try_cast op lit
+ let left_input = try_cast(col("a"), DataType::Int64);
+ let right_input = lit(ScalarValue::Int64(Some(12)));
+ let (result_left, _, result_right) =
+ rewrite_expr_to_prunable(&left_input, Operator::Gt, &right_input,
df_schema)
+ .unwrap();
+ assert_eq!(result_left, left_input);
+ assert_eq!(result_right, right_input);
+ // TODO: add test for other case and op
+ }
+
+ #[test]
+ fn test_rewrite_expr_to_prunable_error() {
+ // cast string value to numeric value
+ // this cast is not supported
+ let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
+ let df_schema = DFSchema::try_from(schema).unwrap();
+ let left_input = cast(col("a"), DataType::Int64);
+ let right_input = lit(ScalarValue::Int64(Some(12)));
+ let result = rewrite_expr_to_prunable(
+ &left_input,
+ Operator::Gt,
+ &right_input,
+ df_schema.clone(),
+ );
+ assert!(result.is_err());
+ // other expr
+ let left_input = is_null(col("a"));
+ let right_input = lit(ScalarValue::Int64(Some(12)));
+ let result =
+ rewrite_expr_to_prunable(&left_input, Operator::Gt, &right_input,
df_schema);
+ assert!(result.is_err());
+ // TODO: add other negative test for other case and op
+ }
}
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 8b0f16466..f7eaec39b 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -259,6 +259,19 @@ pub fn cast(expr: Expr, data_type: DataType) -> Expr {
}
}
+/// Create a try cast expression
+pub fn try_cast(expr: Expr, data_type: DataType) -> Expr {
+ Expr::TryCast {
+ expr: Box::new(expr),
+ data_type,
+ }
+}
+
+/// Create an `IS NULL` expression
+pub fn is_null(expr: Expr) -> Expr {
+ Expr::IsNull(Box::new(expr))
+}
+
/// Create an convenience function representing a unary scalar function
macro_rules! unary_scalar_expr {
($ENUM:ident, $FUNC:ident, $DOC:expr) => {