liukun4515 commented on a change in pull request #1073: URL: https://github.com/apache/arrow-rs/pull/1073#discussion_r773657049
########## File path: arrow/src/compute/kernels/cast.rs ########## @@ -1906,26 +1989,179 @@ where mod tests { use super::*; use crate::{buffer::Buffer, util::display::array_value_to_string}; - use num::traits::Pow; + + macro_rules! generate_cast_test_case { + ($INPUT_ARRAY: expr, $OUTPUT_TYPE_ARRAY: ident, $OUTPUT_TYPE: expr, $OUTPUT_VALUES: expr) => { + // assert cast type + let input_array_type = $INPUT_ARRAY.data_type(); + assert!(can_cast_types(input_array_type, $OUTPUT_TYPE)); + let casted_array = cast($INPUT_ARRAY, $OUTPUT_TYPE).unwrap(); + let result_array = casted_array + .as_any() + .downcast_ref::<$OUTPUT_TYPE_ARRAY>() + .unwrap(); + assert_eq!($OUTPUT_TYPE, result_array.data_type()); + assert_eq!(result_array.len(), $OUTPUT_VALUES.len()); + for (i, x) in $OUTPUT_VALUES.iter().enumerate() { + match x { + Some(x) => { + assert_eq!(result_array.value(i), *x); + } + None => { + assert!(result_array.is_null(i)); + } + } + } + }; + } + + // TODO remove this function if the decimal array has the creator function + fn create_decimal_array( + array: &[Option<i128>], + precision: usize, + scale: usize, + ) -> Result<DecimalArray> { + let mut decimal_builder = DecimalBuilder::new(array.len(), precision, scale); + for value in array { + match value { + None => { + decimal_builder.append_null()?; + } + Some(v) => { + decimal_builder.append_value(*v)?; + } + } + } + Ok(decimal_builder.finish()) + } #[test] - fn test_cast_numeric_to_decimal() { - // test cast type - let data_types = vec![ - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, + fn test_cast_decimal_to_numeric() { + let decimal_type = DataType::Decimal(38, 2); + // negative test + assert!(!can_cast_types(&decimal_type, &DataType::UInt8)); + let value_array: Vec<Option<i128>> = + vec![Some(125), Some(225), Some(325), None, Some(525)]; + let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap(); + let array = 
Arc::new(decimal_array) as ArrayRef; + // i8 + generate_cast_test_case!( + &array, + Int8Array, + &DataType::Int8, + vec![Some(1_i8), Some(2_i8), Some(3_i8), None, Some(5_i8)] + ); + // i16 + generate_cast_test_case!( + &array, + Int16Array, + &DataType::Int16, + vec![Some(1_i16), Some(2_i16), Some(3_i16), None, Some(5_i16)] + ); + // i32 + generate_cast_test_case!( + &array, + Int32Array, + &DataType::Int32, + vec![Some(1_i32), Some(2_i32), Some(3_i32), None, Some(5_i32)] + ); + // i64 + generate_cast_test_case!( + &array, + Int64Array, + &DataType::Int64, + vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)] + ); + // f32 + generate_cast_test_case!( + &array, + Int64Array, + &DataType::Int64, + vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)] + ); + // f64 + generate_cast_test_case!( + &array, + Int64Array, + &DataType::Int64, + vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)] + ); + + // overflow test: out of range of max i8 + let value_array: Vec<Option<i128>> = vec![Some(24400)]; + let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap(); + let array = Arc::new(decimal_array) as ArrayRef; + let casted_array = cast(&array, &DataType::Int8); + assert_eq!( + "Cast error: value of 244 is out of range Int8".to_string(), + casted_array.unwrap_err().to_string() + ); + + // loss the precision: convert decimal to f32、f64 + // f32 + // 112345678_f32 and 112345679_f32 are same, so the 112345679_f32 will lose precision. 
+ let value_array: Vec<Option<i128>> = vec![ + Some(125), + Some(225), + Some(325), + None, + Some(525), + Some(112345678), + Some(112345679), + ]; + let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap(); + let array = Arc::new(decimal_array) as ArrayRef; + generate_cast_test_case!( + &array, + Float32Array, + &DataType::Float32, + vec![ + Some(1.25_f32), + Some(2.25_f32), + Some(3.25_f32), + None, + Some(5.25_f32), + Some(1_123_456.7_f32), + Some(1_123_456.7_f32) + ] + ); + + // f64 + // 112345678901234568_f64 and 112345678901234560_f64 are same, so the 112345678901234568_f64 will lose precision. + let value_array: Vec<Option<i128>> = vec![ + Some(125), + Some(225), + Some(325), + None, + Some(525), + Some(112345678901234568), Review comment: This case tests the loss of precision that occurs when a decimal value has more significant digits than `float64` can represent exactly. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org