This is an automated email from the ASF dual-hosted git repository. alamb pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push: new 4653df465 Support unsigned integers in `unwrap_cast_in_comparison` Optimizer rule (#4149) 4653df465 is described below commit 4653df4652c8af5d4e8841c489c1ab2e85a54e69 Author: Andrew Lamb <and...@nerdnetworks.org> AuthorDate: Tue Nov 15 14:00:29 2022 -0500 Support unsigned integers in `unwrap_cast_in_comparison` Optimizer rule (#4149) * Support unsigned integers in `unwrap_cast_in_comparison` Optimizer rule * Update comment --- datafusion/core/tests/sql/joins.rs | 12 +- .../optimizer/src/unwrap_cast_in_comparison.rs | 126 ++++++++++++++++++--- datafusion/optimizer/tests/integration-test.rs | 6 +- 3 files changed, 120 insertions(+), 24 deletions(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 10d024025..324ccb4c7 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -1428,9 +1428,9 @@ async fn reduce_left_join_1() -> Result<()> { "Explain [plan_type:Utf8, plan:Utf8]", " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: CAST(t1.t1_id AS Int64) < Int64(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Filter: t1.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Filter: CAST(t2.t2_id AS Int64) < Int64(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: t2.t2_id < UInt32(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", ]; let formatted = plan.display_indent_schema().to_string(); @@ -1476,10 +1476,10 @@ async fn reduce_left_join_2() -> Result<()> { let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: t2.t2_int < UInt32(10) OR t1.t1_int > UInt32(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: t2.t2_int < UInt32(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", ]; let formatted = plan.display_indent_schema().to_string(); @@ -1524,9 +1524,9 @@ async fn reduce_left_join_3() -> Result<()> { " Projection: t1.t1_id, t1.t1_name, t1.t1_int, alias=t3 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: CAST(t1.t1_id AS Int64) < Int64(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Filter: t1.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Filter: CAST(t2.t2_int AS Int64) < Int64(3) AND CAST(t2.t2_id AS Int64) < Int64(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: t2.t2_int < UInt32(3) AND t2.t2_id < UInt32(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", ]; diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 28b085684..5f542d749 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -283,7 +283,11 @@ fn is_comparison_op(op: &Operator) -> bool { fn is_support_data_type(data_type: &DataType) -> bool { matches!( data_type, - DataType::Int8 + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 @@ -292,6 +296,25 @@ fn is_support_data_type(data_type: &DataType) -> bool { ) } +fn is_decimal_type(dt: &DataType) -> bool { + matches!(dt, DataType::Decimal128(_, _)) +} + +fn is_unsigned_type(dt: &DataType) -> bool { + matches!( + dt, + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 + ) +} + +/// Until https://github.com/apache/arrow-rs/issues/1043 is done +/// (support for unsigned <--> decimal casts) we also don't do that +/// kind of cast in this optimizer +fn is_unsupported_cast(dt1: &DataType, dt2: &DataType) -> bool { + (is_decimal_type(dt1) && is_unsigned_type(dt2)) + || (is_decimal_type(dt2) && is_unsigned_type(dt1)) +} + fn try_cast_literal_to_type( lit_value: &ScalarValue, target_type: &DataType, @@ -301,12 +324,22 @@ fn try_cast_literal_to_type( if !is_support_data_type(&lit_data_type) || !is_support_data_type(target_type) { return Ok(None); } + if is_unsupported_cast(&lit_data_type, target_type) { + return Ok(None); + } if lit_value.is_null() { // null value can be cast to any type of null value return Ok(Some(ScalarValue::try_from(target_type)?)); } let mul = match target_type { - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => 1_i128, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 => 1_i128, DataType::Timestamp(_, _) => 1_i128, DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32), other_type => { @@ -317,6 +350,10 @@ fn try_cast_literal_to_type( } }; let (target_min, target_max) = match target_type { + DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128), + DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128), + DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128), + DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128), DataType::Int8 => (i8::MIN as i128, i8::MAX as i128), DataType::Int16 => (i16::MIN as i128, i16::MAX as i128), DataType::Int32 => (i32::MIN as i128, i32::MAX as i128), @@ -341,6 +378,10 @@ fn try_cast_literal_to_type( ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul), ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul), ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul), ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul), ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul), ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul), @@ -383,6 +424,10 @@ fn try_cast_literal_to_type( DataType::Int16 => ScalarValue::Int16(Some(value as i16)), DataType::Int32 => ScalarValue::Int32(Some(value as i32)), DataType::Int64 => ScalarValue::Int64(Some(value as i64)), + DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)), + DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)), + DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)), + DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)), DataType::Timestamp(TimeUnit::Second, tz) => { ScalarValue::TimestampSecond(Some(value as i64), tz.clone()) } @@ -469,6 +514,15 @@ mod tests { assert_eq!(optimize_test(lit_lt_lit, &schema), expected); } + #[test] + fn test_unwrap_cast_comparison_unsigned() { + // "cast(c6, UINT64) = 0u64 => c6 = 0u32 + let schema = expr_test_schema(); + let expr_input = cast(col("c6"), DataType::UInt64).eq(lit(0u64)); + let expected = col("c6").eq(lit(0u32)); + assert_eq!(optimize_test(expr_input, &schema), expected); + } + #[test] fn test_not_unwrap_cast_with_decimal_comparison() { let schema = expr_test_schema(); @@ -635,16 +689,16 @@ mod tests { #[test] fn test_not_support_data_type() { - // "c6 > 0" will be cast to `cast(c6 as int64) > 0 + // "c6 > 0" will be cast to `cast(c6 as float) > 0 // but the type of c6 is uint32 // the rewriter will not throw error and just return the original expr let schema = expr_test_schema(); - let expr_input = cast(col("c6"), DataType::Int64).eq(lit(0i64)); + let expr_input = cast(col("c6"), DataType::Float64).eq(lit(0f64)); assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); // inlist for unsupported data type let expr_input = - in_list(cast(col("c6"), DataType::Int64), vec![lit(0i64)], false); + in_list(cast(col("c6"), DataType::Float64), vec![lit(0f64)], false); assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); } @@ -733,17 +787,24 @@ mod tests { ScalarValue::Int16(None), ScalarValue::Int32(None), ScalarValue::Int64(None), + ScalarValue::UInt8(None), + ScalarValue::UInt16(None), + ScalarValue::UInt32(None), + ScalarValue::UInt64(None), ScalarValue::Decimal128(None, 3, 0), ScalarValue::Decimal128(None, 8, 2), ]; for s1 in &scalars { for s2 in &scalars { - expect_cast( - s1.clone(), - s2.get_datatype(), - ExpectedCast::Value(s2.clone()), - ); + let expected_value = + if is_unsupported_cast(&s1.get_datatype(), &s2.get_datatype()) { + ExpectedCast::NoValue + } else { + ExpectedCast::Value(s2.clone()) + }; + + expect_cast(s1.clone(), s2.get_datatype(), expected_value); } } } @@ -756,25 +817,56 @@ mod tests { ScalarValue::Int16(Some(123)), ScalarValue::Int32(Some(123)), ScalarValue::Int64(Some(123)), + ScalarValue::UInt8(Some(123)), + ScalarValue::UInt16(Some(123)), + ScalarValue::UInt32(Some(123)), + ScalarValue::UInt64(Some(123)), ScalarValue::Decimal128(Some(123), 3, 0), ScalarValue::Decimal128(Some(12300), 8, 2), ]; for s1 in &scalars { for s2 in &scalars { - expect_cast( - s1.clone(), - s2.get_datatype(), - ExpectedCast::Value(s2.clone()), - ); + let expected_value = + if is_unsupported_cast(&s1.get_datatype(), &s2.get_datatype()) { + ExpectedCast::NoValue + } else { + ExpectedCast::Value(s2.clone()) + }; + + expect_cast(s1.clone(), s2.get_datatype(), expected_value); } } + + let max_i32 = ScalarValue::Int32(Some(i32::MAX)); + expect_cast( + max_i32, + DataType::UInt64, + ExpectedCast::Value(ScalarValue::UInt64(Some(i32::MAX as u64))), + ); + + let min_i32 = ScalarValue::Int32(Some(i32::MIN)); + expect_cast( + min_i32, + DataType::Int64, + ExpectedCast::Value(ScalarValue::Int64(Some(i32::MIN as i64))), + ); + + let max_i64 = ScalarValue::Int64(Some(i64::MAX)); + expect_cast( + max_i64, + DataType::UInt64, + ExpectedCast::Value(ScalarValue::UInt64(Some(i64::MAX as u64))), + ); } #[test] fn test_try_cast_to_type_int_out_of_range() { + let min_i32 = ScalarValue::Int32(Some(i32::MIN)); + let min_i64 = ScalarValue::Int64(Some(i64::MIN)); let max_i64 = ScalarValue::Int64(Some(i64::MAX)); let max_u64 = ScalarValue::UInt64(Some(u64::MAX)); + expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue); expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue); @@ -783,6 +875,10 @@ mod tests { expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue); + expect_cast(min_i64, DataType::UInt64, ExpectedCast::NoValue); + + expect_cast(min_i32, DataType::UInt64, ExpectedCast::NoValue); + // decimal out of range expect_cast( ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0), diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 779f156c0..2fdec1f2a 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -47,8 +47,8 @@ fn case_when() -> Result<()> { let sql = "SELECT CASE WHEN col_uint32 > 0 THEN 1 ELSE 0 END FROM test"; let plan = test_sql(sql)?; - let expected = "Projection: CASE WHEN CAST(test.col_uint32 AS Int64) > Int64(0) THEN Int64(1) ELSE Int64(0) END\ - \n TableScan: test projection=[col_uint32]"; + let expected = "Projection: CASE WHEN test.col_uint32 > UInt32(0) THEN Int64(1) ELSE Int64(0) END AS CASE WHEN test.col_uint32 > Int64(0) THEN Int64(1) ELSE Int64(0) END\ + \n TableScan: test projection=[col_uint32]"; assert_eq!(expected, format!("{:?}", plan)); Ok(()) } @@ -91,7 +91,7 @@ fn unsigned_target_type() -> Result<()> { let sql = "SELECT col_utf8 FROM test WHERE col_uint32 > 0"; let plan = test_sql(sql)?; let expected = "Projection: test.col_utf8\ - \n Filter: CAST(test.col_uint32 AS Int64) > Int64(0)\ + \n Filter: test.col_uint32 > UInt32(0)\ \n TableScan: test projection=[col_uint32, col_utf8]"; assert_eq!(expected, format!("{:?}", plan)); Ok(())