This is an automated email from the ASF dual-hosted git repository. dheres pushed a commit to branch hash_agg_spike in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
commit bcd9e8ed568f41ff6a7555ee8fbedaa890fbefb5 Author: Daniël Heres <[email protected]> AuthorDate: Mon Jul 3 09:57:24 2023 +0200 WIP count --- datafusion/physical-expr/src/aggregate/count.rs | 238 +++++++++++++++++++++++- 1 file changed, 236 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 22cb2512fc..c3ad7767b1 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -19,17 +19,23 @@ use std::any::Any; use std::fmt::Debug; +use std::marker::PhantomData; use std::ops::BitAnd; use std::sync::Arc; use crate::aggregate::row_accumulator::RowAccumulator; use crate::aggregate::utils::down_cast_any_ref; -use crate::{AggregateExpr, PhysicalExpr}; +use crate::{AggregateExpr, PhysicalExpr, GroupsAccumulator}; use arrow::array::{Array, Int64Array}; use arrow::compute; +use arrow::compute::kernels::cast; use arrow::datatypes::DataType; use arrow::{array::ArrayRef, datatypes::Field}; -use arrow_buffer::BooleanBuffer; +use arrow_array::builder::PrimitiveBuilder; +use arrow_array::cast::AsArray; +use arrow_array::types::{UInt64Type, Int64Type, UInt32Type, Int32Type}; +use arrow_array::{PrimitiveArray, UInt64Array, ArrowNumericType}; +use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer}; use datafusion_common::{downcast_value, ScalarValue}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; @@ -37,6 +43,8 @@ use datafusion_row::accessor::RowAccessor; use crate::expressions::format_state_name; +use super::groups_accumulator::accumulate::{accumulate_all, accumulate_all_nullable}; + /// COUNT aggregate expression /// Returns the amount of non-null values of the given expression. #[derive(Debug, Clone)] @@ -76,6 +84,200 @@ impl Count { } } +/// An accumulator to compute the average of PrimitiveArray<T>. 
+/// (Review correction: despite the summary line, this computes COUNT, not an average. For
+/// each group it counts the input rows that are valid (non-null); values are never read.)
+/// Counts are accumulated as `u64` and cast to `return_data_type` in `evaluate`; `T` is the
+/// array type that `update_batch` downcasts its input to via `as_primitive::<T>()`.
+#[derive(Debug)]
+struct CountGroupsAccumulator<T>
+where T: ArrowNumericType + Send,
+{
+    /// The type of the returned count
+    return_data_type: DataType,
+
+    /// Count per group (use u64 to make UInt64Array)
+    counts: Vec<u64>,
+
+    /// Per-group "saw a null input" flags read by `build_nulls`. NOTE(review): nothing in this patch appends to this builder, so `build_nulls` appears to always return None — confirm whether it is needed (COUNT presumably never yields NULL).
+    null_inputs: BooleanBufferBuilder,
+
+    // `T` is only a type marker: no values of type `T` are stored
+    phantom: PhantomData<T>
+}
+
+
+impl<T> CountGroupsAccumulator<T>
+where T: ArrowNumericType + Send,
+{
+    pub fn new(return_data_type: &DataType) -> Self {
+        Self {
+            return_data_type: return_data_type.clone(),
+            counts: vec![],
+            null_inputs: BooleanBufferBuilder::new(0),
+            phantom: PhantomData {}
+        }
+    }
+
+    /// Adds one to the counter of each group whose input row is valid (non-null); the input values are ignored (row filtering presumably happens in the accumulate helpers via `opt_filter`)
+    fn increment_counts(
+        &mut self,
+        group_indices: &[usize],
+        values: &PrimitiveArray<T>,
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) {
+        self.counts.resize(total_num_groups, 0);
+
+        if values.null_count() == 0 {
+            accumulate_all(
+                group_indices,
+                values,
+                opt_filter,
+                |group_index, _new_value| {
+                    self.counts[group_index] += 1;
+                }
+            )
+        }else {
+            accumulate_all_nullable(
+                group_indices,
+                values,
+                opt_filter,
+                |group_index, _new_value, is_valid| {
+                    if is_valid {
+                        self.counts[group_index] += 1;
+                    }
+                },
+            )
+        }
+    }
+
+    /// Adds partial counts (the state column received by `merge_batch`) into each group's counter; null partial counts are skipped
+    fn update_counts_with_partial_counts(
+        &mut self,
+        group_indices: &[usize],
+        partial_counts: &UInt64Array,
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) {
+        self.counts.resize(total_num_groups, 0);
+
+        if partial_counts.null_count() == 0 {
+            accumulate_all(
+                group_indices,
+                partial_counts,
+                opt_filter,
+                |group_index, partial_count| {
+                    self.counts[group_index] += partial_count;
+                },
+            )
+        } else {
+            accumulate_all_nullable(
+                group_indices,
+                partial_counts,
+                opt_filter,
+                |group_index, partial_count, is_valid| {
+                    if is_valid {
+                        self.counts[group_index] += partial_count;
+                    }
+                },
+            )
+        }
+    }
+
+    /// Returns a NullBuffer representing which group_indices have
+    /// null values (if they saw a null input), or None when no group is null
+    /// Resets `self.null_inputs` (`finish` empties the builder);
+    fn build_nulls(&mut self) -> Option<NullBuffer> {
+        let nulls = NullBuffer::new(self.null_inputs.finish());
+        if nulls.null_count() > 0 {
+            Some(nulls)
+        } else {
+            None
+        }
+    }
+}
+
+impl <T> GroupsAccumulator for CountGroupsAccumulator<T>
+where T: ArrowNumericType + Send
+{
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        assert_eq!(values.len(), 1, "single argument to update_batch");
+        let values = values.get(0).unwrap().as_primitive::<T>();
+
+        self.increment_counts(group_indices, values, opt_filter, total_num_groups);
+
+        Ok(())
+    }
+
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        assert_eq!(values.len(), 1, "one argument to merge_batch");
+        // the single state column holds the UInt64 partial counts produced by `state`
+        let partial_counts = values.get(0).unwrap().as_primitive::<UInt64Type>();
+        self.update_counts_with_partial_counts(
+            group_indices,
+            partial_counts,
+            opt_filter,
+            total_num_groups,
+        );
+
+        Ok(())
+    }
+
+    fn evaluate(&mut self) -> Result<ArrayRef> {
+        let counts = std::mem::take(&mut self.counts);
+        let nulls = self.build_nulls();
+
+        // emit NULL for groups flagged in `nulls`; NOTE(review): `nulls` is currently always None (`null_inputs` is never appended to), so only the else branch runs
+        let array: PrimitiveArray<UInt64Type> = if let Some(nulls) = nulls.as_ref() {
+            let mut builder = PrimitiveBuilder::<UInt64Type>::with_capacity(nulls.len());
+            let iter = counts.into_iter().zip(nulls.iter());
+
+            for (count, is_valid) in iter {
+                if is_valid {
+                    builder.append_value(count)
+                } else {
+                    builder.append_null();
+                }
+            }
+            builder.finish()
+        } else {
+            PrimitiveArray::<UInt64Type>::new(counts.into(), nulls) // no copy
+        };
+        // TODO remove cast: counts are built as UInt64 and converted to `return_data_type` here
+        let array = cast(&array, &self.return_data_type)?;
+
+        Ok(array)
+    }
+
+    // return the partial counts as the single UInt64 state column (NOTE(review): confirm COUNT's declared state field uses this type)
+    fn state(&mut self) -> Result<Vec<ArrayRef>> {
+        // TODO nulls (the `nulls` below is unused; the call only resets `null_inputs`)
+        let nulls = self.build_nulls();
+        let counts = std::mem::take(&mut self.counts);
+        let counts = UInt64Array::from(counts); // zero copy
+        Ok(vec![
+            Arc::new(counts) as ArrayRef,
+        ])
+    }
+
+    fn size(&self) -> usize {
+        self.counts.capacity() * std::mem::size_of::<usize>() // NOTE(review): `counts` holds u64, so size_of::<u64>() is the exact element width; `null_inputs` is not included
+    }
+}
+
 /// count null values for multiple columns
 /// for each row if one column value is null, then null_count + 1
 fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
@@ -147,6 +349,38 @@ impl AggregateExpr for Count {
     fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
         Ok(Box::new(CountAccumulator::new()))
     }
+
+    fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+        // NOTE(review): dispatches on `self.data_type`, which is also passed as the return type, while `T` is the type `update_batch` downcasts the *input* to — other input types would panic in `as_primitive::<T>()`; since only validity is read, `T` may be unnecessary — confirm
+        match &self.data_type {
+            DataType::UInt64 => {
+                Ok(Box::new(CountGroupsAccumulator::<UInt64Type>::new(
+                    &self.data_type,
+                )))
+            },
+            DataType::Int64 => {
+                Ok(Box::new(CountGroupsAccumulator::<Int64Type>::new(
+                    &self.data_type,
+                )))
+            },
+            DataType::UInt32 => {
+                Ok(Box::new(CountGroupsAccumulator::<UInt32Type>::new(
+                    &self.data_type,
+                )))
+            },
+            DataType::Int32 => {
+                Ok(Box::new(CountGroupsAccumulator::<Int32Type>::new(
+                    &self.data_type,
+                )))
+            }
+
+            _ => Err(DataFusionError::NotImplemented(format!(
+                "CountGroupsAccumulator not supported for {}",
+                self.data_type
+            ))),
+        }
+
+    }
 }
 
 impl PartialEq<dyn Any> for Count {
