This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 45fb91992 Add sum_dyn to calculate sum for dictionary array (#2566)
45fb91992 is described below

commit 45fb919928e1903be0eb4de3af1966c56c6f6c71
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Aug 24 14:42:08 2022 -0700

    Add sum_dyn to calculate sum for dictionary array (#2566)
    
    * Add sum_dyn
    
    * Add null values test case
---
 arrow/src/compute/kernels/aggregate.rs | 63 +++++++++++++++++++++++++++++++---
 1 file changed, 59 insertions(+), 4 deletions(-)

diff --git a/arrow/src/compute/kernels/aggregate.rs b/arrow/src/compute/kernels/aggregate.rs
index 12ead669f..c8d0443c4 100644
--- a/arrow/src/compute/kernels/aggregate.rs
+++ b/arrow/src/compute/kernels/aggregate.rs
@@ -21,10 +21,10 @@ use multiversion::multiversion;
 use std::ops::Add;
 
 use crate::array::{
-    Array, BooleanArray, GenericBinaryArray, GenericStringArray, OffsetSizeTrait,
-    PrimitiveArray,
+    as_primitive_array, Array, ArrayAccessor, ArrayIter, BooleanArray,
+    GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
 };
-use crate::datatypes::{ArrowNativeType, ArrowNumericType};
+use crate::datatypes::{ArrowNativeType, ArrowNumericType, DataType};
 
 /// Generic test for NaN, the optimizer should be able to remove this for integer types.
 #[inline]
@@ -185,6 +185,37 @@ pub fn min_string<T: OffsetSizeTrait>(array: &GenericStringArray<T>) -> Option<&
 }
 
 /// Returns the sum of values in the array.
+pub fn sum_dyn<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
+where
+    T: ArrowNumericType,
+    T::Native: Add<Output = T::Native>,
+{
+    match array.data_type() {
+        DataType::Dictionary(_, _) => {
+            let null_count = array.null_count();
+
+            if null_count == array.len() {
+                return None;
+            }
+
+            let iter = ArrayIter::new(array);
+            let sum = iter
+                .into_iter()
+                .fold(T::default_value(), |accumulator, value| {
+                    if let Some(value) = value {
+                        accumulator + value
+                    } else {
+                        accumulator
+                    }
+                });
+
+            Some(sum)
+        }
+        _ => sum::<T>(as_primitive_array(&array)),
+    }
+}
+
+/// Returns the sum of values in the primitive array.
 ///
 /// Returns `None` if the array is empty or only contains null values.
 #[cfg(not(feature = "simd"))]
@@ -583,7 +614,7 @@ mod simd {
     }
 }
 
-/// Returns the sum of values in the array.
+/// Returns the sum of values in the primitive array.
 ///
 /// Returns `None` if the array is empty or only contains null values.
 #[cfg(feature = "simd")]
@@ -625,6 +656,7 @@ mod tests {
     use super::*;
     use crate::array::*;
     use crate::compute::add;
+    use crate::datatypes::{Int32Type, Int8Type};
 
     #[test]
     fn test_primitive_array_sum() {
@@ -1003,4 +1035,27 @@ mod tests {
         assert_eq!(Some(true), min_boolean(&a));
         assert_eq!(Some(true), max_boolean(&a));
     }
+
+    #[test]
+    fn test_sum_dyn() {
+        let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);
+        let keys = Int8Array::from_iter_values([2_i8, 3, 4]);
+
+        let dict_array = DictionaryArray::try_new(&keys, &values).unwrap();
+        let array = dict_array.downcast_dict::<Int8Array>().unwrap();
+        assert_eq!(39, sum_dyn::<Int8Type, _>(array).unwrap());
+
+        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
+        assert_eq!(15, sum_dyn::<Int32Type, _>(&a).unwrap());
+
+        let keys = Int8Array::from(vec![Some(2_i8), None, Some(4)]);
+        let dict_array = DictionaryArray::try_new(&keys, &values).unwrap();
+        let array = dict_array.downcast_dict::<Int8Array>().unwrap();
+        assert_eq!(26, sum_dyn::<Int8Type, _>(array).unwrap());
+
+        let keys = Int8Array::from(vec![None, None, None]);
+        let dict_array = DictionaryArray::try_new(&keys, &values).unwrap();
+        let array = dict_array.downcast_dict::<Int8Array>().unwrap();
+        assert!(sum_dyn::<Int8Type, _>(array).is_none());
+    }
 }

Reply via email to