andygrove commented on code in PR #3422:
URL: https://github.com/apache/arrow-datafusion/pull/3422#discussion_r967425770
##########
datafusion/core/src/physical_optimizer/pruning.rs:
##########
@@ -1508,6 +1510,78 @@ mod tests {
Ok(())
}
+ #[test]
+ fn row_group_predicate_cast() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("c1", DataType::Int32,
false)]);
+ let expected_expr =
+ "CAST(#c1_min AS Int64) <= Int64(1) AND Int64(1) <= CAST(#c1_max
AS Int64)";
+
+ // test column on the left
+ let expr = cast(col("c1"),
DataType::Int64).eq(lit(ScalarValue::Int64(Some(1))));
+ let predicate_expr =
+ build_predicate_expression(&expr, &schema, &mut
RequiredStatColumns::new())?;
+ assert_eq!(format!("{:?}", predicate_expr), expected_expr);
+
+ // test column on the right
+ let expr = lit(ScalarValue::Int64(Some(1))).eq(cast(col("c1"),
DataType::Int64));
+ let predicate_expr =
+ build_predicate_expression(&expr, &schema, &mut
RequiredStatColumns::new())?;
+ assert_eq!(format!("{:?}", predicate_expr), expected_expr);
+
+ let expected_expr = "TRY_CAST(#c1_max AS Int64) > Int64(1)";
+
+ // test column on the left
+ let expr =
+ try_cast(col("c1"),
DataType::Int64).gt(lit(ScalarValue::Int64(Some(1))));
+ let predicate_expr =
+ build_predicate_expression(&expr, &schema, &mut
RequiredStatColumns::new())?;
+ assert_eq!(format!("{:?}", predicate_expr), expected_expr);
+
+ // test column on the right
+ let expr =
+ lit(ScalarValue::Int64(Some(1))).lt(try_cast(col("c1"),
DataType::Int64));
+ let predicate_expr =
+ build_predicate_expression(&expr, &schema, &mut
RequiredStatColumns::new())?;
+ assert_eq!(format!("{:?}", predicate_expr), expected_expr);
+
+ Ok(())
+ }
+
+ #[test]
+ fn row_group_predicate_cast_list() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("c1", DataType::Int32,
false)]);
+ // test cast(c1 as int64) in int64(1, 2, 3)
+ let expr = Expr::InList {
+ expr: Box::new(cast(col("c1"), DataType::Int64)),
+ list: vec![
+ lit(ScalarValue::Int64(Some(1))),
+ lit(ScalarValue::Int64(Some(2))),
+ lit(ScalarValue::Int64(Some(3))),
+ ],
+ negated: false,
+ };
+ let expected_expr = "CAST(#c1_min AS Int64) <= Int64(1) AND Int64(1)
<= CAST(#c1_max AS Int64) OR CAST(#c1_min AS Int64) <= Int64(2) AND Int64(2) <=
CAST(#c1_max AS Int64) OR CAST(#c1_min AS Int64) <= Int64(3) AND Int64(3) <=
CAST(#c1_max AS Int64)";
+ let predicate_expr =
+ build_predicate_expression(&expr, &schema, &mut
RequiredStatColumns::new())?;
+ assert_eq!(format!("{:?}", predicate_expr), expected_expr);
+
+ let expr = Expr::InList {
+ expr: Box::new(cast(col("c1"), DataType::Int64)),
+ list: vec![
+ lit(ScalarValue::Int64(Some(1))),
+ lit(ScalarValue::Int64(Some(2))),
+ lit(ScalarValue::Int64(Some(3))),
+ ],
+ negated: true,
+ };
+ let expected_expr = "CAST(#c1_min AS Int64) != Int64(1) OR Int64(1) !=
CAST(#c1_max AS Int64) AND CAST(#c1_min AS Int64) != Int64(2) OR Int64(2) !=
CAST(#c1_max AS Int64) AND CAST(#c1_min AS Int64) != Int64(3) OR Int64(3) !=
CAST(#c1_max AS Int64)";
Review Comment:
I'm not sure I understand what is happening in the negated case. Could you
add a comment here explaining it? Why does the expected expression use `!=`
rather than `<=` or `>=`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]