This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch hash_agg_spike
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git

commit bcd9e8ed568f41ff6a7555ee8fbedaa890fbefb5
Author: Daniël Heres <[email protected]>
AuthorDate: Mon Jul 3 09:57:24 2023 +0200

    WIP count
---
 datafusion/physical-expr/src/aggregate/count.rs | 238 +++++++++++++++++++++++-
 1 file changed, 236 insertions(+), 2 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs
index 22cb2512fc..c3ad7767b1 100644
--- a/datafusion/physical-expr/src/aggregate/count.rs
+++ b/datafusion/physical-expr/src/aggregate/count.rs
@@ -19,17 +19,23 @@
 
 use std::any::Any;
 use std::fmt::Debug;
+use std::marker::PhantomData;
 use std::ops::BitAnd;
 use std::sync::Arc;
 
 use crate::aggregate::row_accumulator::RowAccumulator;
 use crate::aggregate::utils::down_cast_any_ref;
-use crate::{AggregateExpr, PhysicalExpr};
+use crate::{AggregateExpr, PhysicalExpr, GroupsAccumulator};
 use arrow::array::{Array, Int64Array};
 use arrow::compute;
+use arrow::compute::kernels::cast;
 use arrow::datatypes::DataType;
 use arrow::{array::ArrayRef, datatypes::Field};
-use arrow_buffer::BooleanBuffer;
+use arrow_array::builder::PrimitiveBuilder;
+use arrow_array::cast::AsArray;
+use arrow_array::types::{UInt64Type, Int64Type, UInt32Type, Int32Type};
+use arrow_array::{PrimitiveArray, UInt64Array, ArrowNumericType};
+use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer};
 use datafusion_common::{downcast_value, ScalarValue};
 use datafusion_common::{DataFusionError, Result};
 use datafusion_expr::Accumulator;
@@ -37,6 +43,8 @@ use datafusion_row::accessor::RowAccessor;
 
 use crate::expressions::format_state_name;
 
+use super::groups_accumulator::accumulate::{accumulate_all, 
accumulate_all_nullable};
+
 /// COUNT aggregate expression
 /// Returns the amount of non-null values of the given expression.
 #[derive(Debug, Clone)]
@@ -76,6 +84,200 @@ impl Count {
     }
 }
 
+/// A [`GroupsAccumulator`] that computes `COUNT` for many groups at once.
+///
+/// Keeps one `u64` counter per group index; the final counts are cast to
+/// `return_data_type` when the accumulator is evaluated.
+///
+/// `T` is the Arrow primitive type selected from the aggregate's return
+/// type; it is only tied to the struct through `phantom`.
+#[derive(Debug)]
+struct CountGroupsAccumulator<T>
+where
+    T: ArrowNumericType + Send,
+{
+    /// The data type the final counts are cast to
+    return_data_type: DataType,
+
+    /// Count per group index (stored as `u64` so it converts directly into
+    /// a `UInt64Array`)
+    counts: Vec<u64>,
+
+    /// Per-group bits turned into a `NullBuffer` by `build_nulls`, which
+    /// treats a SET bit as VALID (not as "saw a null input").
+    /// NOTE(review): nothing appends to this builder yet, so it is always
+    /// empty; COUNT itself never yields NULL — TODO populate or remove.
+    null_inputs: BooleanBufferBuilder,
+
+    /// Binds the type parameter `T` to the struct
+    phantom: PhantomData<T>,
+}
+
+
+impl<T> CountGroupsAccumulator<T>
+where
+    T: ArrowNumericType + Send,
+{
+    /// Creates an empty accumulator whose output will be cast to
+    /// `return_data_type`.
+    pub fn new(return_data_type: &DataType) -> Self {
+        Self {
+            return_data_type: return_data_type.clone(),
+            counts: Vec::new(),
+            null_inputs: BooleanBufferBuilder::new(0),
+            phantom: PhantomData,
+        }
+    }
+
+    /// Increments a group's counter once for every non-null entry of
+    /// `values` that belongs to it.
+    ///
+    /// `counts` is first grown (zero-filled) to `total_num_groups`; row
+    /// selection by `opt_filter` is delegated to the `accumulate_all*`
+    /// helpers.
+    fn increment_counts(
+        &mut self,
+        group_indices: &[usize],
+        values: &PrimitiveArray<T>,
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) {
+        self.counts.resize(total_num_groups, 0);
+        let counts = &mut self.counts;
+
+        if values.null_count() > 0 {
+            // Some entries are null: only count the valid ones.
+            accumulate_all_nullable(group_indices, values, opt_filter, |group, _value, valid| {
+                if valid {
+                    counts[group] += 1;
+                }
+            });
+        } else {
+            accumulate_all(group_indices, values, opt_filter, |group, _value| {
+                counts[group] += 1;
+            });
+        }
+    }
+
+    /// Adds each valid entry of `partial_counts` to the counter of its
+    /// group.
+    ///
+    /// `counts` is first grown (zero-filled) to `total_num_groups`; row
+    /// selection by `opt_filter` is delegated to the `accumulate_all*`
+    /// helpers.
+    fn update_counts_with_partial_counts(
+        &mut self,
+        group_indices: &[usize],
+        partial_counts: &UInt64Array,
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) {
+        self.counts.resize(total_num_groups, 0);
+        let counts = &mut self.counts;
+
+        if partial_counts.null_count() > 0 {
+            // Null partial counts contribute nothing.
+            accumulate_all_nullable(
+                group_indices,
+                partial_counts,
+                opt_filter,
+                |group, partial, valid| {
+                    if valid {
+                        counts[group] += partial;
+                    }
+                },
+            );
+        } else {
+            accumulate_all(group_indices, partial_counts, opt_filter, |group, partial| {
+                counts[group] += partial;
+            });
+        }
+    }
+
+    /// Drains `self.null_inputs` into a [`NullBuffer`] and returns it only
+    /// when it marks at least one group as null.
+    ///
+    /// `BooleanBufferBuilder::finish` also resets the builder.
+    fn build_nulls(&mut self) -> Option<NullBuffer> {
+        let nulls = NullBuffer::new(self.null_inputs.finish());
+        (nulls.null_count() > 0).then_some(nulls)
+    }
+}
+
+impl<T> GroupsAccumulator for CountGroupsAccumulator<T>
+where
+    T: ArrowNumericType + Send,
+{
+    /// Counts, per group, the selected rows whose arguments are all non-null.
+    ///
+    /// COUNT only looks at validity, so the argument columns may have any
+    /// data type. They are deliberately NOT downcast to `T`: `T` is chosen
+    /// from the aggregate's return type (see `create_groups_accumulator`),
+    /// so downcasting would panic for e.g. `COUNT(utf8_column)`.
+    ///
+    /// With several arguments (`COUNT(a, b, ...)`) a row is counted only
+    /// when every argument is non-null, matching
+    /// `null_count_for_multiple_cols`. Rows whose `opt_filter` entry is
+    /// false or null are skipped.
+    ///
+    /// # Panics
+    ///
+    /// If `values` is empty, or if a group index is `>= total_num_groups`.
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        assert!(!values.is_empty(), "COUNT requires at least one argument");
+        self.counts.resize(total_num_groups, 0);
+
+        // Skip per-row validity lookups when no argument contains nulls.
+        let any_nulls = values.iter().any(|array| array.null_count() > 0);
+
+        for (row, &group_index) in group_indices.iter().enumerate() {
+            let selected =
+                opt_filter.map_or(true, |filter| filter.is_valid(row) && filter.value(row));
+            let all_valid = !any_nulls || values.iter().all(|array| array.is_valid(row));
+            if selected && all_valid {
+                self.counts[group_index] += 1;
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Merges partial counts (a single `UInt64` column, as emitted by
+    /// `state`) into the per-group counters.
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        assert_eq!(values.len(), 1, "one argument to merge_batch");
+        // The only state column is the partial counts.
+        let partial_counts = values[0].as_primitive::<UInt64Type>();
+        self.update_counts_with_partial_counts(
+            group_indices,
+            partial_counts,
+            opt_filter,
+            total_num_groups,
+        );
+
+        Ok(())
+    }
+
+    /// Returns the final count of every group, cast to `return_data_type`,
+    /// and resets the accumulator.
+    fn evaluate(&mut self) -> Result<ArrayRef> {
+        let counts = std::mem::take(&mut self.counts);
+        // Also resets `null_inputs`. That builder is never populated, so this
+        // is `None` in practice; COUNT never yields NULL.
+        let nulls = self.build_nulls();
+        // Zero-copy conversion of `counts`; `try_new` reports a length
+        // mismatch between counts and nulls as an error instead of silently
+        // truncating.
+        let array = PrimitiveArray::<UInt64Type>::try_new(counts.into(), nulls)?;
+        // TODO remove cast
+        let array = cast(&array, &self.return_data_type)?;
+        Ok(array)
+    }
+
+    /// Returns the per-group partial counts as a single `UInt64` column and
+    /// resets the accumulator.
+    fn state(&mut self) -> Result<Vec<ArrayRef>> {
+        // Reset `null_inputs` alongside `counts`; partial counts are never null.
+        let _ = self.build_nulls();
+        let counts = UInt64Array::from(std::mem::take(&mut self.counts)); // zero copy
+        Ok(vec![Arc::new(counts) as ArrayRef])
+    }
+
+    /// Memory used by the per-group `u64` counters, in bytes.
+    fn size(&self) -> usize {
+        self.counts.capacity() * std::mem::size_of::<u64>()
+    }
+}
+
 /// count null values for multiple columns
 /// for each row if one column value is null, then null_count + 1
 fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
@@ -147,6 +349,38 @@ impl AggregateExpr for Count {
     fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
         Ok(Box::new(CountAccumulator::new()))
     }
+
+    fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+        // Pick the accumulator specialisation from the aggregate's return
+        // type (`self.data_type`); unsupported types are reported as
+        // not implemented.
+        let data_type = &self.data_type;
+        let accumulator: Box<dyn GroupsAccumulator> = match data_type {
+            DataType::UInt64 => {
+                Box::new(CountGroupsAccumulator::<UInt64Type>::new(data_type))
+            }
+            DataType::Int64 => {
+                Box::new(CountGroupsAccumulator::<Int64Type>::new(data_type))
+            }
+            DataType::UInt32 => {
+                Box::new(CountGroupsAccumulator::<UInt32Type>::new(data_type))
+            }
+            DataType::Int32 => {
+                Box::new(CountGroupsAccumulator::<Int32Type>::new(data_type))
+            }
+            other => {
+                return Err(DataFusionError::NotImplemented(format!(
+                    "CountGroupsAccumulator not supported for {}",
+                    other
+                )))
+            }
+        };
+        Ok(accumulator)
+    }
 }
 
 impl PartialEq<dyn Any> for Count {

Reply via email to