This is an automated email from the ASF dual-hosted git repository.

remziy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 191eaef09 Add modulus ops into `ArrowNativeTypeOp` (#2756)
191eaef09 is described below

commit 191eaef0906f61dc0dc6ca6cea96a99f86e6c5a4
Author: Remzi Yang <[email protected]>
AuthorDate: Tue Oct 4 07:26:24 2022 +0800

    Add modulus ops into `ArrowNativeTypeOp` (#2756)
    
    * add 3 mod ops and tests
    
    Signed-off-by: remzi <[email protected]>
    
    * fix simd error
    
    Signed-off-by: remzi <[email protected]>
    
    * remove_mod_divide_by_zero
    
    Signed-off-by: remzi <[email protected]>
    
    * overflow panic simd
    
    Signed-off-by: remzi <[email protected]>
    
    * address comment
    
    Signed-off-by: remzi <[email protected]>
    
    Signed-off-by: remzi <[email protected]>
---
 arrow/src/compute/kernels/arithmetic.rs | 64 +++++++++++++++++++++++++++------
 arrow/src/datatypes/native.rs           | 32 ++++++++++++++++-
 2 files changed, 85 insertions(+), 11 deletions(-)

diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs
index b2e95ad5e..1e6e55248 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -22,7 +22,7 @@
 //! `RUSTFLAGS="-C target-feature=+avx2"` for example.  See the documentation
 //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
 
-use std::ops::{Div, Neg, Rem};
+use std::ops::{Div, Neg};
 
 use num::{One, Zero};
 
@@ -182,7 +182,7 @@ fn simd_checked_modulus<T: ArrowNumericType>(
     right: T::Simd,
 ) -> Result<T::Simd>
 where
-    T::Native: One + Zero,
+    T::Native: ArrowNativeTypeOp + One,
 {
     let zero = T::init(T::Native::zero());
     let one = T::init(T::Native::one());
@@ -305,7 +305,7 @@ fn simd_checked_divide_op<T, SI, SC>(
 ) -> Result<PrimitiveArray<T>>
 where
     T: ArrowNumericType,
-    T::Native: One + Zero,
+    T::Native: ArrowNativeTypeOp,
     SI: Fn(Option<u64>, T::Simd, T::Simd) -> Result<T::Simd>,
     SC: Fn(T::Native, T::Native) -> T::Native,
 {
@@ -1301,7 +1301,7 @@ pub fn modulus<T>(
 ) -> Result<PrimitiveArray<T>>
 where
     T: ArrowNumericType,
-    T::Native: Rem<Output = T::Native> + Zero + One,
+    T::Native: ArrowNativeTypeOp + One,
 {
     #[cfg(feature = "simd")]
     return simd_checked_divide_op(&left, &right, simd_checked_modulus::<T>, 
|a, b| {
@@ -1312,7 +1312,7 @@ where
         if b.is_zero() {
             Err(ArrowError::DivideByZero)
         } else {
-            Ok(a % b)
+            Ok(a.mod_wrapping(b))
         }
     });
 }
@@ -1507,13 +1507,13 @@ pub fn modulus_scalar<T>(
 ) -> Result<PrimitiveArray<T>>
 where
     T: ArrowNumericType,
-    T::Native: Rem<Output = T::Native> + Zero,
+    T::Native: ArrowNativeTypeOp,
 {
     if modulo.is_zero() {
         return Err(ArrowError::DivideByZero);
     }
 
-    Ok(unary(array, |a| a % modulo))
+    Ok(unary(array, |a| a.mod_wrapping(modulo)))
 }
 
 /// Divide every value in an array by a scalar. If any value in the array is 
null then the
@@ -2117,7 +2117,7 @@ mod tests {
     }
 
     #[test]
-    fn test_primitive_array_modulus() {
+    fn test_int_array_modulus() {
         let a = Int32Array::from(vec![15, 15, 8, 1, 9]);
         let b = Int32Array::from(vec![5, 6, 8, 9, 1]);
         let c = modulus(&a, &b).unwrap();
@@ -2128,6 +2128,34 @@ mod tests {
         assert_eq!(0, c.value(4));
     }
 
+    #[test]
+    #[should_panic(
+        expected = "called `Result::unwrap()` on an `Err` value: DivideByZero"
+    )]
+    fn test_int_array_modulus_divide_by_zero() {
+        let a = Int32Array::from(vec![1]);
+        let b = Int32Array::from(vec![0]);
+        modulus(&a, &b).unwrap();
+    }
+
+    #[test]
+    #[cfg(not(feature = "simd"))]
+    fn test_int_array_modulus_overflow_wrapping() {
+        let a = Int32Array::from(vec![i32::MIN]);
+        let b = Int32Array::from(vec![-1]);
+        let result = modulus(&a, &b).unwrap();
+        assert_eq!(0, result.value(0))
+    }
+
+    #[test]
+    #[cfg(feature = "simd")]
+    #[should_panic(expected = "attempt to calculate the remainder with 
overflow")]
+    fn test_int_array_modulus_overflow_panic() {
+        let a = Int32Array::from(vec![i32::MIN]);
+        let b = Int32Array::from(vec![-1]);
+        let _ = modulus(&a, &b).unwrap();
+    }
+
     #[test]
     fn test_primitive_array_divide_scalar() {
         let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
@@ -2190,7 +2218,7 @@ mod tests {
     }
 
     #[test]
-    fn test_primitive_array_modulus_scalar() {
+    fn test_int_array_modulus_scalar() {
         let a = Int32Array::from(vec![15, 14, 9, 8, 1]);
         let b = 3;
         let c = modulus_scalar(&a, b).unwrap();
@@ -2199,7 +2227,7 @@ mod tests {
     }
 
     #[test]
-    fn test_primitive_array_modulus_scalar_sliced() {
+    fn test_int_array_modulus_scalar_sliced() {
         let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]);
         let a = a.slice(1, 4);
         let a = as_primitive_array(&a);
@@ -2208,6 +2236,22 @@ mod tests {
         assert_eq!(actual, expected);
     }
 
+    #[test]
+    #[should_panic(
+        expected = "called `Result::unwrap()` on an `Err` value: DivideByZero"
+    )]
+    fn test_int_array_modulus_scalar_divide_by_zero() {
+        let a = Int32Array::from(vec![1]);
+        modulus_scalar(&a, 0).unwrap();
+    }
+
+    #[test]
+    fn test_int_array_modulus_scalar_overflow_wrapping() {
+        let a = Int32Array::from(vec![i32::MIN]);
+        let result = modulus_scalar(&a, -1).unwrap();
+        assert_eq!(0, result.value(0))
+    }
+
     #[test]
     fn test_primitive_array_divide_sliced() {
         let a = Int32Array::from(vec![0, 0, 0, 15, 15, 8, 1, 9, 0]);
diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs
index 6ab82688e..654b93950 100644
--- a/arrow/src/datatypes/native.rs
+++ b/arrow/src/datatypes/native.rs
@@ -26,7 +26,7 @@ pub(crate) mod native_op {
     use super::ArrowNativeType;
     use crate::error::{ArrowError, Result};
     use num::Zero;
-    use std::ops::{Add, Div, Mul, Sub};
+    use std::ops::{Add, Div, Mul, Rem, Sub};
 
     /// Trait for ArrowNativeType to provide overflow-checking and 
non-overflow-checking
     /// variants for arithmetic operations. For floating point types, this 
provides some
@@ -44,6 +44,7 @@ pub(crate) mod native_op {
         + Sub<Output = Self>
         + Mul<Output = Self>
         + Div<Output = Self>
+        + Rem<Output = Self>
         + Zero
     {
         fn add_checked(self, rhs: Self) -> Result<Self> {
@@ -81,6 +82,18 @@ pub(crate) mod native_op {
         fn div_wrapping(self, rhs: Self) -> Self {
             self / rhs
         }
+
+        fn mod_checked(self, rhs: Self) -> Result<Self> {
+            if rhs.is_zero() {
+                Err(ArrowError::DivideByZero)
+            } else {
+                Ok(self % rhs)
+            }
+        }
+
+        fn mod_wrapping(self, rhs: Self) -> Self {
+            self % rhs
+        }
     }
 }
 
@@ -142,6 +155,23 @@ macro_rules! native_type_op {
             fn div_wrapping(self, rhs: Self) -> Self {
                 self.wrapping_div(rhs)
             }
+
+            fn mod_checked(self, rhs: Self) -> Result<Self> {
+                if rhs.is_zero() {
+                    Err(ArrowError::DivideByZero)
+                } else {
+                    self.checked_rem(rhs).ok_or_else(|| {
+                        ArrowError::ComputeError(format!(
+                            "Overflow happened on: {:?} % {:?}",
+                            self, rhs
+                        ))
+                    })
+                }
+            }
+
+            fn mod_wrapping(self, rhs: Self) -> Self {
+                self.wrapping_rem(rhs)
+            }
         }
     };
 }

Reply via email to