This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch active_release
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/active_release by this push:
     new 7c748bf  support cast decimal to signed numeric (#1073) (#1089)
7c748bf is described below

commit 7c748bfccbc2eac0c1138378736b70dcb7e26a5b
Author: Andrew Lamb <[email protected]>
AuthorDate: Wed Dec 22 15:49:14 2021 -0500

    support cast decimal to signed numeric (#1073) (#1089)
    
    * add cast test macro function; refactor other type to decimal type; add decimal to signed numeric type
    support decimal to unsigned numeric
    
    * address the comments and fix the clippy
    
    Co-authored-by: Kun Liu <[email protected]>
---
 arrow/src/compute/kernels/cast.rs | 370 ++++++++++++++++++++++++++++++++------
 1 file changed, 312 insertions(+), 58 deletions(-)

diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs
index c3d20cc..42e5e7f 100644
--- a/arrow/src/compute/kernels/cast.rs
+++ b/arrow/src/compute/kernels/cast.rs
@@ -68,8 +68,12 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
     }
 
     match (from_type, to_type) {
-        // TODO now just support signed numeric to decimal, support decimal to 
numeric later
-        (Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal(_, _))
+        // TODO UTF8/unsigned numeric to decimal
+        // TODO decimal to decimal type
+        // signed numeric to decimal
+        (Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal(_, _)) |
+        // decimal to signed numeric
+        (Decimal(_, _), Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
         | (
             Null,
             Boolean
@@ -108,6 +112,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
             | Dictionary(_, _),
             Null,
         ) => true,
+        (Decimal(_, _), _) => false,
+        (_, Decimal(_, _)) => false,
         (Struct(_), _) => false,
         (_, Struct(_)) => false,
         (LargeList(list_from), LargeList(list_to)) => {
@@ -345,6 +351,56 @@ macro_rules! cast_floating_point_to_decimal {
     }};
 }
 
+// Cast a `DecimalArray` to a signed integer array.
+//
+// Each non-null unscaled value is divided by 10^scale with i128 integer
+// division, so fractional digits are truncated toward zero (with scale 2:
+// 199 -> 1, -150 -> -1). Null slots stay null.
+//
+// Arguments:
+// * `$ARRAY`         - input array; must downcast to `DecimalArray` (the
+//                      `unwrap` panics otherwise)
+// * `$SCALE`         - reference to the scale of the input decimal type
+// * `$VALUE_BUILDER` - output primitive builder type (e.g. `Int8Builder`)
+// * `$NATIVE_TYPE`   - output native integer type (e.g. `i8`)
+// * `$DATA_TYPE`     - output data type; only used in the error message
+//
+// Evaluates to `Ok(Arc<...>)` holding the built array. If a truncated value
+// does not fit in `$NATIVE_TYPE`, the `return Err(ArrowError::CastError(..))`
+// below exits the function the macro is expanded in.
+macro_rules! cast_decimal_to_integer {
+    ($ARRAY:expr, $SCALE : ident, $VALUE_BUILDER: ident, $NATIVE_TYPE : ident, $DATA_TYPE : expr) => {{
+        let array = $ARRAY.as_any().downcast_ref::<DecimalArray>().unwrap();
+        let mut value_builder = $VALUE_BUILDER::new(array.len());
+        // 10^scale; dividing the unscaled value by it drops the fractional digits.
+        // NOTE(review): 10^39 overflows i128, so this assumes scale <= 38 — confirm
+        // that decimal types can never carry a larger scale.
+        let div: i128 = 10_i128.pow(*$SCALE as u32);
+        // Target-type bounds, widened to i128 for the range check below.
+        let min_bound = ($NATIVE_TYPE::MIN) as i128;
+        let max_bound = ($NATIVE_TYPE::MAX) as i128;
+        for i in 0..array.len() {
+            if array.is_null(i) {
+                value_builder.append_null()?;
+            } else {
+                let v = array.value(i) / div;
+                // Check for overflow of the target type.
+                // For example: the value 128 in a Decimal(38, 0) array
+                // is out of range for i8.
+                if v <= max_bound && v >= min_bound {
+                    value_builder.append_value(v as $NATIVE_TYPE)?;
+                } else {
+                    return Err(ArrowError::CastError(format!(
+                        "value of {} is out of range {}",
+                        v, $DATA_TYPE
+                    )));
+                }
+            }
+        }
+        Ok(Arc::new(value_builder.finish()))
+    }};
+}
+
+// Cast a `DecimalArray` to a floating-point array.
+//
+// Each non-null unscaled value is converted to f64, divided by 10^scale in
+// f64, and the quotient is then cast to `$NATIVE_TYPE`. An f32 result is
+// therefore computed in f64 first and then narrowed. Null slots stay null.
+//
+// Arguments:
+// * `$ARRAY`         - input array; must downcast to `DecimalArray` (the
+//                      `unwrap` panics otherwise)
+// * `$SCALE`         - reference to the scale of the input decimal type
+// * `$VALUE_BUILDER` - output primitive builder type (e.g. `Float32Builder`)
+// * `$NATIVE_TYPE`   - output native float type (`f32` or `f64`)
+//
+// Evaluates to `Ok(Arc<...>)` holding the built array.
+macro_rules! cast_decimal_to_float {
+    ($ARRAY:expr, $SCALE : ident, $VALUE_BUILDER: ident, $NATIVE_TYPE : ty) => {{
+        let array = $ARRAY.as_any().downcast_ref::<DecimalArray>().unwrap();
+        let div = 10_f64.powi(*$SCALE as i32);
+        let mut value_builder = $VALUE_BUILDER::new(array.len());
+        for i in 0..array.len() {
+            if array.is_null(i) {
+                value_builder.append_null()?;
+            } else {
+                // No overflow check is needed: f32::MAX (~3.4e38) and f64::MAX both
+                // exceed i128::MAX (~1.7e38).
+                // Precision can still be lost when casting i128 to f64; for example
+                // `112345678901234568` becomes `112345678901234560`.
+                let v = (array.value(i) as f64 / div) as $NATIVE_TYPE;
+                value_builder.append_value(v)?;
+            }
+        }
+        Ok(Arc::new(value_builder.finish()))
+    }};
+}
+
 /// Cast `array` to the provided data type and return a new Array with
 /// type `to_type`, if possible. It accepts `CastOptions` to allow consumers
 /// to configure cast behavior.
@@ -379,6 +435,33 @@ pub fn cast_with_options(
         return Ok(array.clone());
     }
     match (from_type, to_type) {
+        (Decimal(_, scale), _) => {
+            // cast decimal to other type
+            match to_type {
+                Int8 => {
+                    cast_decimal_to_integer!(array, scale, Int8Builder, i8, 
Int8)
+                }
+                Int16 => {
+                    cast_decimal_to_integer!(array, scale, Int16Builder, i16, 
Int16)
+                }
+                Int32 => {
+                    cast_decimal_to_integer!(array, scale, Int32Builder, i32, 
Int32)
+                }
+                Int64 => {
+                    cast_decimal_to_integer!(array, scale, Int64Builder, i64, 
Int64)
+                }
+                Float32 => {
+                    cast_decimal_to_float!(array, scale, Float32Builder, f32)
+                }
+                Float64 => {
+                    cast_decimal_to_float!(array, scale, Float64Builder, f64)
+                }
+                _ => Err(ArrowError::CastError(format!(
+                    "Casting from {:?} to {:?} not supported",
+                    from_type, to_type
+                ))),
+            }
+        }
         (_, Decimal(precision, scale)) => {
             // cast data to decimal
             match from_type {
@@ -1970,26 +2053,179 @@ where
 mod tests {
     use super::*;
     use crate::{buffer::Buffer, util::display::array_value_to_string};
-    use num::traits::Pow;
+
+    // Cast `$INPUT_ARRAY` to `$OUTPUT_TYPE` and assert on the result.
+    //
+    // Asserts that:
+    // * `can_cast_types` reports the cast as supported,
+    // * `cast` succeeds and the result downcasts to `$OUTPUT_TYPE_ARRAY`,
+    // * the result has data type `$OUTPUT_TYPE` and `$OUTPUT_VALUES.len()` slots,
+    // * each slot matches `$OUTPUT_VALUES` (`Some(x)` means the value equals `x`,
+    //   `None` means the slot is null).
+    // Values are compared with `assert_eq!`, so float results must match exactly.
+    macro_rules! generate_cast_test_case {
+        ($INPUT_ARRAY: expr, $OUTPUT_TYPE_ARRAY: ident, $OUTPUT_TYPE: expr, $OUTPUT_VALUES: expr) => {
+            // assert cast type
+            let input_array_type = $INPUT_ARRAY.data_type();
+            assert!(can_cast_types(input_array_type, $OUTPUT_TYPE));
+            let casted_array = cast($INPUT_ARRAY, $OUTPUT_TYPE).unwrap();
+            let result_array = casted_array
+                .as_any()
+                .downcast_ref::<$OUTPUT_TYPE_ARRAY>()
+                .unwrap();
+            assert_eq!($OUTPUT_TYPE, result_array.data_type());
+            assert_eq!(result_array.len(), $OUTPUT_VALUES.len());
+            for (i, x) in $OUTPUT_VALUES.iter().enumerate() {
+                match x {
+                    Some(x) => {
+                        assert_eq!(result_array.value(i), *x);
+                    }
+                    None => {
+                        assert!(result_array.is_null(i));
+                    }
+                }
+            }
+        };
+    }
+
+    // Build a `DecimalArray` with the given `precision` and `scale` from unscaled
+    // i128 values (with scale 2, `Some(125)` represents 1.25). `None` entries
+    // become nulls. Builder errors are propagated to the caller.
+    // TODO: remove this helper once `DecimalArray` provides such a constructor.
+    fn create_decimal_array(
+        array: &[Option<i128>],
+        precision: usize,
+        scale: usize,
+    ) -> Result<DecimalArray> {
+        let mut decimal_builder = DecimalBuilder::new(array.len(), precision, scale);
+        for value in array {
+            match value {
+                None => {
+                    decimal_builder.append_null()?;
+                }
+                Some(v) => {
+                    decimal_builder.append_value(*v)?;
+                }
+            }
+        }
+        Ok(decimal_builder.finish())
+    }
 
     #[test]
-    fn test_cast_numeric_to_decimal() {
-        // test cast type
-        let data_types = vec![
-            DataType::Int8,
-            DataType::Int16,
-            DataType::Int32,
-            DataType::Int64,
-            DataType::Float32,
-            DataType::Float64,
+    fn test_cast_decimal_to_numeric() {
+        // Decimal -> signed integer / float casts. The unscaled values below use
+        // scale 2, e.g. 125 represents 1.25.
+        let decimal_type = DataType::Decimal(38, 2);
+        // negative test: decimal to unsigned integer is not supported
+        assert!(!can_cast_types(&decimal_type, &DataType::UInt8));
+        let value_array: Vec<Option<i128>> =
+            vec![Some(125), Some(225), Some(325), None, Some(525)];
+        let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap();
+        let array = Arc::new(decimal_array) as ArrayRef;
+        // i8 (fractional digits are truncated)
+        generate_cast_test_case!(
+            &array,
+            Int8Array,
+            &DataType::Int8,
+            vec![Some(1_i8), Some(2_i8), Some(3_i8), None, Some(5_i8)]
+        );
+        // i16
+        generate_cast_test_case!(
+            &array,
+            Int16Array,
+            &DataType::Int16,
+            vec![Some(1_i16), Some(2_i16), Some(3_i16), None, Some(5_i16)]
+        );
+        // i32
+        generate_cast_test_case!(
+            &array,
+            Int32Array,
+            &DataType::Int32,
+            vec![Some(1_i32), Some(2_i32), Some(3_i32), None, Some(5_i32)]
+        );
+        // i64
+        generate_cast_test_case!(
+            &array,
+            Int64Array,
+            &DataType::Int64,
+            vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)]
+        );
+        // f32 (1.25, 2.25, ... are exactly representable)
+        generate_cast_test_case!(
+            &array,
+            Float32Array,
+            &DataType::Float32,
+            vec![Some(1.25_f32), Some(2.25_f32), Some(3.25_f32), None, Some(5.25_f32)]
+        );
+        // f64
+        generate_cast_test_case!(
+            &array,
+            Float64Array,
+            &DataType::Float64,
+            vec![Some(1.25_f64), Some(2.25_f64), Some(3.25_f64), None, Some(5.25_f64)]
+        );
+
+        // overflow test: out of range of max i8
+        let value_array: Vec<Option<i128>> = vec![Some(24400)];
+        let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap();
+        let array = Arc::new(decimal_array) as ArrayRef;
+        let casted_array = cast(&array, &DataType::Int8);
+        assert_eq!(
+            "Cast error: value of 244 is out of range Int8".to_string(),
+            casted_array.unwrap_err().to_string()
+        );
+
+        // overflow test: out of range of min i8
+        let value_array: Vec<Option<i128>> = vec![Some(-12900)];
+        let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap();
+        let array = Arc::new(decimal_array) as ArrayRef;
+        let casted_array = cast(&array, &DataType::Int8);
+        assert_eq!(
+            "Cast error: value of -129 is out of range Int8".to_string(),
+            casted_array.unwrap_err().to_string()
+        );
+
+        // lose precision: convert decimal to f32, f64
+        // f32
+        // 112345678 and 112345679 (1123456.78 and 1123456.79 at scale 2) round to
+        // the same f32 value, so the cast loses precision.
+        let value_array: Vec<Option<i128>> = vec![
+            Some(125),
+            Some(225),
+            Some(325),
+            None,
+            Some(525),
+            Some(112345678),
+            Some(112345679),
+        ];
+        let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap();
+        let array = Arc::new(decimal_array) as ArrayRef;
+        generate_cast_test_case!(
+            &array,
+            Float32Array,
+            &DataType::Float32,
+            vec![
+                Some(1.25_f32),
+                Some(2.25_f32),
+                Some(3.25_f32),
+                None,
+                Some(5.25_f32),
+                Some(1_123_456.7_f32),
+                Some(1_123_456.7_f32)
+            ]
+        );
+
+        // f64
+        // 112345678901234568 and 112345678901234560 convert to the same f64 value,
+        // so the cast of 112345678901234568 loses precision.
+        let value_array: Vec<Option<i128>> = vec![
+            Some(125),
+            Some(225),
+            Some(325),
+            None,
+            Some(525),
+            Some(112345678901234568),
+            Some(112345678901234560),
         ];
+        let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap();
+        let array = Arc::new(decimal_array) as ArrayRef;
+        generate_cast_test_case!(
+            &array,
+            Float64Array,
+            &DataType::Float64,
+            vec![
+                Some(1.25_f64),
+                Some(2.25_f64),
+                Some(3.25_f64),
+                None,
+                Some(5.25_f64),
+                Some(1_123_456_789_012_345.6_f64),
+                Some(1_123_456_789_012_345.6_f64),
+            ]
+        );
+    }
+
+    #[test]
+    fn test_cast_numeric_to_decimal() {
+        // test negative cast type
         let decimal_type = DataType::Decimal(38, 6);
-        for data_type in data_types {
-            assert!(can_cast_types(&data_type, &decimal_type))
-        }
         assert!(!can_cast_types(&DataType::UInt64, &decimal_type));
 
-        // test cast data
+        // i8, i16, i32, i64
         let input_datas = vec![
             Arc::new(Int8Array::from(vec![
                 Some(1),
@@ -2020,25 +2256,19 @@ mod tests {
                 Some(5),
             ])) as ArrayRef, // i64
         ];
-
-        // i8, i16, i32, i64
         for array in input_datas {
-            let casted_array = cast(&array, &decimal_type).unwrap();
-            let decimal_array = casted_array
-                .as_any()
-                .downcast_ref::<DecimalArray>()
-                .unwrap();
-            assert_eq!(&decimal_type, decimal_array.data_type());
-            for i in 0..array.len() {
-                if i == 3 {
-                    assert!(decimal_array.is_null(i as usize));
-                } else {
-                    assert_eq!(
-                        10_i128.pow(6) * (i as i128 + 1),
-                        decimal_array.value(i as usize)
-                    );
-                }
-            }
+            generate_cast_test_case!(
+                &array,
+                DecimalArray,
+                &decimal_type,
+                vec![
+                    Some(1000000_i128),
+                    Some(2000000_i128),
+                    Some(3000000_i128),
+                    None,
+                    Some(5000000_i128)
+                ]
+            );
         }
 
         // test i8 to decimal type with overflow the result type
@@ -2050,34 +2280,54 @@ mod tests {
         assert_eq!("Invalid argument error: The value of 1000 i128 is not 
compatible with Decimal(3,1)", casted_array.unwrap_err().to_string());
 
         // test f32 to decimal type
-        let f_data: Vec<f32> = vec![1.1, 2.2, 4.4, 1.123_456_8];
-        let array = Float32Array::from(f_data.clone());
+        let array = Float32Array::from(vec![
+            Some(1.1),
+            Some(2.2),
+            Some(4.4),
+            None,
+            Some(1.123_456_7),
+            Some(1.123_456_7),
+        ]);
         let array = Arc::new(array) as ArrayRef;
-        let casted_array = cast(&array, &decimal_type).unwrap();
-        let decimal_array = casted_array
-            .as_any()
-            .downcast_ref::<DecimalArray>()
-            .unwrap();
-        assert_eq!(&decimal_type, decimal_array.data_type());
-        for (i, item) in f_data.iter().enumerate().take(array.len()) {
-            let left = (*item as f64) * 10_f64.pow(6);
-            assert_eq!(left as i128, decimal_array.value(i as usize));
-        }
+        generate_cast_test_case!(
+            &array,
+            DecimalArray,
+            &decimal_type,
+            vec![
+                Some(1100000_i128),
+                Some(2200000_i128),
+                Some(4400000_i128),
+                None,
+                Some(1123456_i128),
+                Some(1123456_i128),
+            ]
+        );
 
         // test f64 to decimal type
-        let f_data: Vec<f64> = vec![1.1, 2.2, 4.4, 1.123_456_789_123_4];
-        let array = Float64Array::from(f_data.clone());
+        let array = Float64Array::from(vec![
+            Some(1.1),
+            Some(2.2),
+            Some(4.4),
+            None,
+            Some(1.123_456_789_123_4),
+            Some(1.123_456_789_012_345_6),
+            Some(1.123_456_789_012_345_6),
+        ]);
         let array = Arc::new(array) as ArrayRef;
-        let casted_array = cast(&array, &decimal_type).unwrap();
-        let decimal_array = casted_array
-            .as_any()
-            .downcast_ref::<DecimalArray>()
-            .unwrap();
-        assert_eq!(&decimal_type, decimal_array.data_type());
-        for (i, item) in f_data.iter().enumerate().take(array.len()) {
-            let left = (*item as f64) * 10_f64.pow(6);
-            assert_eq!(left as i128, decimal_array.value(i as usize));
-        }
+        generate_cast_test_case!(
+            &array,
+            DecimalArray,
+            &decimal_type,
+            vec![
+                Some(1100000_i128),
+                Some(2200000_i128),
+                Some(4400000_i128),
+                None,
+                Some(1123456_i128),
+                Some(1123456_i128),
+                Some(1123456_i128),
+            ]
+        );
     }
 
     #[test]
@@ -4031,6 +4281,9 @@ mod tests {
             Arc::new(DurationMillisecondArray::from(vec![1000, 2000])),
             Arc::new(DurationMicrosecondArray::from(vec![1000, 2000])),
             Arc::new(DurationNanosecondArray::from(vec![1000, 2000])),
+            Arc::new(
+                create_decimal_array(&[Some(1), Some(2), Some(3), None], 38, 
0).unwrap(),
+            ),
         ]
     }
 
@@ -4204,6 +4457,7 @@ mod tests {
             Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
             Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8)),
             Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
+            Decimal(38, 0),
         ]
     }
 

Reply via email to