This is an automated email from the ASF dual-hosted git repository. dheres pushed a commit to branch hash_agg_spike in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
commit 892e440538b984877b82edcc96b4e2e5f80e70dd Author: Andrew Lamb <[email protected]> AuthorDate: Sat Jul 1 05:46:52 2023 -0400 split nullable/non nullable handling --- datafusion/physical-expr/src/aggregate/average.rs | 139 ++++++++++++---------- 1 file changed, 79 insertions(+), 60 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index 0dcff7ec9b..20ccadd7e8 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -417,9 +417,13 @@ impl RowAccumulator for AvgRowAccumulator { } } -/// This function is called once per row to update the accumulator, -/// for a `PrimitiveArray<T>` and is the inner loop for many -/// GroupsAccumulators and thus performance critical. +/// This function is called to update the accumulator state per row, +/// for a `PrimitiveArray<T>` with no nulls. It is the inner loop for +/// many GroupsAccumulators and thus performance critical. +/// +/// I couldn't find any way to combine this with +/// accumulate_all_nullable without having to pass in a is_null on +/// every row. /// /// * `values`: the input arguments to the accumulator /// * `group_indices`: To which groups do the rows in `values` belong, group id) @@ -427,80 +431,95 @@ impl RowAccumulator for AvgRowAccumulator { /// /// `F`: The function to invoke for a non null input row to update the /// accumulator state. Called like `value_fn(group_index, value) -/// -/// `FN`: The function to call for each null input row. 
Called like -/// `null_fn(group_index) fn accumulate_all<T, F, FN>( values: &PrimitiveArray<T>, group_indicies: &[usize], opt_filter: Option<&arrow_array::BooleanArray>, value_fn: F, - null_fn: FN, ) where T: ArrowNumericType + Send, F: Fn(usize, T::Native) + Send, - FN: Fn(usize) + Send, { + assert_eq!( + values.null_count(), 0, + "Called accumulate_all with nullable array (call accumulate_all_nullable instead)" + ); + // AAL TODO handle filter values + + let data: &[T::Native] = values.values(); + let iter = group_indicies.iter().zip(data.iter()); + for (&group_index, &new_value) in iter { + value_fn(group_index, new_value) + } +} + + +/// This function is called to update the accumulator state per row, +/// for a `PrimitiveArray<T>` with nulls. It is the inner loop for +/// many GroupsAccumulators and thus performance critical. +/// +/// * `values`: the input arguments to the accumulator +/// * `group_indices`: To which groups do the rows in `values` belong, group id) +/// * `opt_filter`: if present, only update aggregate state using values[i] if opt_filter[i] is true +/// +/// `F`: The function to invoke for an input row to update the +/// accumulator state. Called like `value_fn(group_index, value, +/// is_valid)`. NOTE the parameter is true when the value is VALID. 
+fn accumulate_all_nullable<T, F, FN>( + values: &PrimitiveArray<T>, + group_indicies: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + value_fn: F, +) where + T: ArrowNumericType + Send, + F: Fn(usize, T::Native, bool) + Send, +{ + // AAL TODO handle filter values // TODO combine the null mask from values and opt_filter - let valids = values.nulls(); + let valids = values + .nulls() + .expect("Called accumulate_all_nullable with non-nullable array (call accumulate_all instead)"); // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum let data: &[T::Native] = values.values(); - match valids { - // no nulls - None => { - let iter = group_indicies.iter().zip(data.iter()); - for (&group_index, &new_value) in iter { - value_fn(group_index, new_value) - } - } - // there are nulls, so handle them specially - Some(valids) => { - let group_indices_chunks = group_indicies.chunks_exact(64); - let data_chunks = data.chunks_exact(64); - let bit_chunks = valids.inner().bit_chunks(); - - let group_indices_remainder = group_indices_chunks.remainder(); - let data_remainder = data_chunks.remainder(); - - group_indices_chunks - .zip(data_chunks) - .zip(bit_chunks.iter()) - .for_each(|((group_index_chunk, data_chunk), mask)| { - // index_mask has value 1 << i in the loop - let mut index_mask = 1; - group_index_chunk.iter().zip(data_chunk.iter()).for_each( - |(&group_index, &new_value)| { - // valid bit was set, real vale - if (mask & index_mask) != 0 { - value_fn(group_index, new_value); - } else { - null_fn(group_index) - } - index_mask <<= 1; - }, - ) - }); - - // handle any remaining bits (after the intial 64) - let remainder_bits = bit_chunks.remainder_bits(); - group_indices_remainder - .iter() - .zip(data_remainder.iter()) - .enumerate() - .for_each(|(i, (&group_index, &new_value))| { - if remainder_bits & (1 << i) != 0 { - value_fn(group_index, new_value) - } else { - null_fn(group_index) - } - }); - } - } + let group_indices_chunks = 
group_indicies.chunks_exact(64); + let data_chunks = data.chunks_exact(64); + let bit_chunks = valids.inner().bit_chunks(); + + let group_indices_remainder = group_indices_chunks.remainder(); + let data_remainder = data_chunks.remainder(); + + group_indices_chunks + .zip(data_chunks) + .zip(bit_chunks.iter()) + .for_each(|((group_index_chunk, data_chunk), mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + group_index_chunk.iter().zip(data_chunk.iter()).for_each( + |(&group_index, &new_value)| { + // valid bit was set, real vale + let is_valid = (mask & index_mask) != 0; + value_fn(group_index, new_value, is_valid); + index_mask <<= 1; + }, + ) + }); + + // handle any remaining bits (after the intial 64) + let remainder_bits = bit_chunks.remainder_bits(); + group_indices_remainder + .iter() + .zip(data_remainder.iter()) + .enumerate() + .for_each(|(i, (&group_index, &new_value))| { + let is_valid = remainder_bits & (1 << i) != 0; + value_fn(group_index, new_value, is_valid) + }); } + /// An accumulator to compute the average of PrimitiveArray<T>. /// Stores values as native types, and does overflow checking ///
