This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 45fb91992 Add sum_dyn to calculate sum for dictionary array (#2566)
45fb91992 is described below
commit 45fb919928e1903be0eb4de3af1966c56c6f6c71
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Aug 24 14:42:08 2022 -0700
Add sum_dyn to calculate sum for dictionary array (#2566)
* Add sum_dyn
* Add null values test case
---
arrow/src/compute/kernels/aggregate.rs | 63 +++++++++++++++++++++++++++++++---
1 file changed, 59 insertions(+), 4 deletions(-)
diff --git a/arrow/src/compute/kernels/aggregate.rs b/arrow/src/compute/kernels/aggregate.rs
index 12ead669f..c8d0443c4 100644
--- a/arrow/src/compute/kernels/aggregate.rs
+++ b/arrow/src/compute/kernels/aggregate.rs
@@ -21,10 +21,10 @@ use multiversion::multiversion;
use std::ops::Add;
use crate::array::{
- Array, BooleanArray, GenericBinaryArray, GenericStringArray, OffsetSizeTrait,
- PrimitiveArray,
+ as_primitive_array, Array, ArrayAccessor, ArrayIter, BooleanArray,
+ GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
};
-use crate::datatypes::{ArrowNativeType, ArrowNumericType};
+use crate::datatypes::{ArrowNativeType, ArrowNumericType, DataType};
/// Generic test for NaN, the optimizer should be able to remove this for integer types.
#[inline]
@@ -185,6 +185,37 @@ pub fn min_string<T: OffsetSizeTrait>(array: &GenericStringArray<T>) -> Option<&
}
/// Returns the sum of values in the array.
+pub fn sum_dyn<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
+where
+ T: ArrowNumericType,
+ T::Native: Add<Output = T::Native>,
+{
+ match array.data_type() {
+ DataType::Dictionary(_, _) => {
+ let null_count = array.null_count();
+
+ if null_count == array.len() {
+ return None;
+ }
+
+ let iter = ArrayIter::new(array);
+ let sum = iter
+ .into_iter()
+ .fold(T::default_value(), |accumulator, value| {
+ if let Some(value) = value {
+ accumulator + value
+ } else {
+ accumulator
+ }
+ });
+
+ Some(sum)
+ }
+ _ => sum::<T>(as_primitive_array(&array)),
+ }
+}
+
+/// Returns the sum of values in the primitive array.
///
/// Returns `None` if the array is empty or only contains null values.
#[cfg(not(feature = "simd"))]
@@ -583,7 +614,7 @@ mod simd {
}
}
-/// Returns the sum of values in the array.
+/// Returns the sum of values in the primitive array.
///
/// Returns `None` if the array is empty or only contains null values.
#[cfg(feature = "simd")]
@@ -625,6 +656,7 @@ mod tests {
use super::*;
use crate::array::*;
use crate::compute::add;
+ use crate::datatypes::{Int32Type, Int8Type};
#[test]
fn test_primitive_array_sum() {
@@ -1003,4 +1035,27 @@ mod tests {
assert_eq!(Some(true), min_boolean(&a));
assert_eq!(Some(true), max_boolean(&a));
}
+
+ #[test]
+ fn test_sum_dyn() {
+ let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);
+ let keys = Int8Array::from_iter_values([2_i8, 3, 4]);
+
+ let dict_array = DictionaryArray::try_new(&keys, &values).unwrap();
+ let array = dict_array.downcast_dict::<Int8Array>().unwrap();
+ assert_eq!(39, sum_dyn::<Int8Type, _>(array).unwrap());
+
+ let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
+ assert_eq!(15, sum_dyn::<Int32Type, _>(&a).unwrap());
+
+ let keys = Int8Array::from(vec![Some(2_i8), None, Some(4)]);
+ let dict_array = DictionaryArray::try_new(&keys, &values).unwrap();
+ let array = dict_array.downcast_dict::<Int8Array>().unwrap();
+ assert_eq!(26, sum_dyn::<Int8Type, _>(array).unwrap());
+
+ let keys = Int8Array::from(vec![None, None, None]);
+ let dict_array = DictionaryArray::try_new(&keys, &values).unwrap();
+ let array = dict_array.downcast_dict::<Int8Array>().unwrap();
+ assert!(sum_dyn::<Int8Type, _>(array).is_none());
+ }
}