This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch hash_agg_spike
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git

commit 668063689250cd9a0a7883af8299c4d3bc17f1f1
Author: Andrew Lamb <[email protected]>
AuthorDate: Sat Jul 1 05:24:40 2023 -0400

    factor out accumulate
---
 datafusion/physical-expr/src/aggregate/average.rs | 93 ++++++++++++++++++++++-
 1 file changed, 90 insertions(+), 3 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/average.rs 
b/datafusion/physical-expr/src/aggregate/average.rs
index 7043ed9ce1..0dcff7ec9b 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -417,8 +417,95 @@ impl RowAccumulator for AvgRowAccumulator {
     }
 }
 
+/// Invokes `value_fn` or `null_fn` once per row of a `PrimitiveArray<T>`
+/// to update accumulator state. This is the inner loop for many
+/// GroupsAccumulators and thus performance critical.
+///
+/// * `values`: the input arguments to the accumulator
+/// * `group_indicies`: the group index (group id) of each row in `values`
+/// * `opt_filter`: if present, only rows where `opt_filter[i]` is true count (TODO: currently ignored)
+///
+/// `F`: The function to invoke for a non-null input row to update the
+/// accumulator state. Called like `value_fn(group_index, value)`.
+///
+/// `FN`: The function to call for each null input row. Called like
+/// `null_fn(group_index)`.
+fn accumulate_all<T, F, FN>(
+    values: &PrimitiveArray<T>,
+    group_indicies: &[usize],
+    opt_filter: Option<&arrow_array::BooleanArray>,
+    value_fn: F,
+    null_fn: FN,
+) where
+    T: ArrowNumericType + Send,
+    F: Fn(usize, T::Native) + Send,
+    FN: Fn(usize) + Send,
+{
+    // TODO(AAL): handle filter values; `opt_filter` is not read anywhere below yet
+    // TODO: combine the null mask from `values` with `opt_filter`
+    let valids = values.nulls();
+
+    // The chunked loop below is based on (copied from) arrow::compute::aggregate::sum
+    let data: &[T::Native] = values.values();
+
+    match valids {
+        // no null mask: every row is passed to `value_fn`
+        None => {
+            let iter = group_indicies.iter().zip(data.iter());
+            for (&group_index, &new_value) in iter {
+                value_fn(group_index, new_value)
+            }
+        }
+        // a null mask is present: walk rows in chunks of 64, one mask word per chunk
+        Some(valids) => {
+            let group_indices_chunks = group_indicies.chunks_exact(64);
+            let data_chunks = data.chunks_exact(64);
+            let bit_chunks = valids.inner().bit_chunks();
+
+            let group_indices_remainder = group_indices_chunks.remainder();
+            let data_remainder = data_chunks.remainder();
+
+            group_indices_chunks
+                .zip(data_chunks)
+                .zip(bit_chunks.iter())
+                .for_each(|((group_index_chunk, data_chunk), mask)| {
+                    // index_mask is `1 << i` for the i-th row of this chunk, selecting its bit in `mask`
+                    let mut index_mask = 1;
+                    group_index_chunk.iter().zip(data_chunk.iter()).for_each(
+                        |(&group_index, &new_value)| {
+                            // valid bit is set: this row holds a real (non-null) value
+                            if (mask & index_mask) != 0 {
+                                value_fn(group_index, new_value);
+                            } else {
+                                null_fn(group_index)
+                            }
+                            index_mask <<= 1;
+                        },
+                    )
+                });
+
+            // handle the trailing rows that did not fill a complete 64-row chunk
+            let remainder_bits = bit_chunks.remainder_bits();
+            group_indices_remainder
+                .iter()
+                .zip(data_remainder.iter())
+                .enumerate()
+                .for_each(|(i, (&group_index, &new_value))| {
+                    if remainder_bits & (1 << i) != 0 {
+                        value_fn(group_index, new_value)
+                    } else {
+                        null_fn(group_index)
+                    }
+                });
+        }
+    }
+}
+
 /// An accumulator to compute the average of PrimitiveArray<T>.
 /// Stores values as native types, and does overflow checking
+///
+/// F: Function that calculates the average value from a sum of
+/// T::Native and a total count
 #[derive(Debug)]
 struct AvgGroupsAccumulator<T, F>
 where
@@ -597,7 +684,7 @@ where
         let array = PrimitiveArray::<T>::from_iter_values(averages);
 
         // fix up decimal precision and scale for decimals
-        let array = set_decimal_precision(&self.return_data_type, 
Arc::new(array))?;
+        let array = adjust_output_array(&self.return_data_type, 
Arc::new(array))?;
 
         Ok(array)
     }
@@ -614,7 +701,7 @@ where
         let sums: PrimitiveArray<T> = PrimitiveArray::from_iter_values(sums);
 
         // fix up decimal precision and scale for decimals
-        let sums = set_decimal_precision(&self.sum_data_type, Arc::new(sums))?;
+        let sums = adjust_output_array(&self.sum_data_type, Arc::new(sums))?;
 
         Ok(vec![
             Arc::new(counts) as ArrayRef,
@@ -631,7 +718,7 @@ where
 ///
 /// Decimal128Arrays are are are created from Vec<NativeType> with default
 /// precision and scale. This function adjusts them down.
-fn set_decimal_precision(sum_data_type: &DataType, array: ArrayRef) -> 
Result<ArrayRef> {
+fn adjust_output_array(sum_data_type: &DataType, array: ArrayRef) -> 
Result<ArrayRef> {
     let array = match sum_data_type {
         DataType::Decimal128(p, s) => Arc::new(
             array

Reply via email to