This is an automated email from the ASF dual-hosted git repository. dheres pushed a commit to branch hash_agg_spike in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
commit 79b1bc9041798d7a0ec97e6e4c32f75bdc9b6eca Author: Andrew Lamb <[email protected]> AuthorDate: Sat Jul 1 06:27:34 2023 -0400 Move accumulator to their own function --- datafusion/physical-expr/src/aggregate/average.rs | 102 +---------------- .../src/aggregate/groups_accumulator/accumulate.rs | 121 +++++++++++++++++++++ .../mod.rs} | 2 + 3 files changed, 124 insertions(+), 101 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index 2d9a627a5f..3f3c7820be 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -45,6 +45,7 @@ use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; use datafusion_row::accessor::RowAccessor; +use super::groups_accumulator::accumulate::{accumulate_all, accumulate_all_nullable}; use super::utils::Decimal128Averager; /// AVG aggregate expression @@ -417,107 +418,6 @@ impl RowAccumulator for AvgRowAccumulator { } } -/// This function is called to update the accumulator state per row, -/// for a `PrimitiveArray<T>` with no nulls. It is the inner loop for -/// many GroupsAccumulators and thus performance critical. -/// -/// I couldn't find any way to combine this with -/// accumulate_all_nullable without having to pass in a is_null on -/// every row. -/// -/// * `values`: the input arguments to the accumulator -/// * `group_indices`: To which groups do the rows in `values` belong, group id) -/// * `opt_filter`: if present, only update aggregate state using values[i] if opt_filter[i] is true -/// -/// `F`: The function to invoke for a non null input row to update the -/// accumulator state. Called like `value_fn(group_index, value) -fn accumulate_all<T, F>( - values: &PrimitiveArray<T>, - group_indicies: &[usize], - opt_filter: Option<&arrow_array::BooleanArray>, - mut value_fn: F, -) where - T: ArrowNumericType + Send, - F: FnMut(usize, T::Native) + Send, -{ - assert_eq!( - values.null_count(), 0, - "Called accumulate_all with nullable array (call accumulate_all_nullable instead)" - ); - - // AAL TODO handle filter values - - let data: &[T::Native] = values.values(); - let iter = group_indicies.iter().zip(data.iter()); - for (&group_index, &new_value) in iter { - value_fn(group_index, new_value) - } -} - -/// This function is called to update the accumulator state per row, -/// for a `PrimitiveArray<T>` with no nulls. It is the inner loop for -/// many GroupsAccumulators and thus performance critical. -/// -/// * `values`: the input arguments to the accumulator -/// * `group_indices`: To which groups do the rows in `values` belong, group id) -/// * `opt_filter`: if present, only update aggregate state using values[i] if opt_filter[i] is true -/// -/// `F`: The function to invoke for an input row to update the -/// accumulator state. Called like `value_fn(group_index, value, -/// is_valid). NOTE the parameter is true when the value is VALID. -fn accumulate_all_nullable<T, F>( - values: &PrimitiveArray<T>, - group_indicies: &[usize], - opt_filter: Option<&arrow_array::BooleanArray>, - mut value_fn: F, -) where - T: ArrowNumericType + Send, - F: FnMut(usize, T::Native, bool) + Send, -{ - // AAL TODO handle filter values - // TODO combine the null mask from values and opt_filter - let valids = values - .nulls() - .expect("Called accumulate_all_nullable with non-nullable array (call accumulate_all instead)"); - - // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum - let data: &[T::Native] = values.values(); - - let group_indices_chunks = group_indicies.chunks_exact(64); - let data_chunks = data.chunks_exact(64); - let bit_chunks = valids.inner().bit_chunks(); - - let group_indices_remainder = group_indices_chunks.remainder(); - let data_remainder = data_chunks.remainder(); - - group_indices_chunks - .zip(data_chunks) - .zip(bit_chunks.iter()) - .for_each(|((group_index_chunk, data_chunk), mask)| { - // index_mask has value 1 << i in the loop - let mut index_mask = 1; - group_index_chunk.iter().zip(data_chunk.iter()).for_each( - |(&group_index, &new_value)| { - // valid bit was set, real vale - let is_valid = (mask & index_mask) != 0; - value_fn(group_index, new_value, is_valid); - index_mask <<= 1; - }, - ) - }); - - // handle any remaining bits (after the intial 64) - let remainder_bits = bit_chunks.remainder_bits(); - group_indices_remainder - .iter() - .zip(data_remainder.iter()) - .enumerate() - .for_each(|(i, (&group_index, &new_value))| { - let is_valid = remainder_bits & (1 << i) != 0; - value_fn(group_index, new_value, is_valid) - }); -} - /// An accumulator to compute the average of PrimitiveArray<T>. /// Stores values as native types, and does overflow checking /// diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs new file mode 100644 index 0000000000..5d72328763 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Vectorized [`accumulate`] and [`accumulate_nullable`] functions + +use arrow_array::{Array, ArrowNumericType, PrimitiveArray}; + +/// This function is called to update the accumulator state per row, +/// for a `PrimitiveArray<T>` with no nulls. It is the inner loop for +/// many GroupsAccumulators and thus performance critical. +/// +/// I couldn't find any way to combine this with +/// accumulate_all_nullable without having to pass in a is_null on +/// every row. +/// +/// * `values`: the input arguments to the accumulator +/// * `group_indices`: To which groups do the rows in `values` belong, group id) +/// * `opt_filter`: if present, only update aggregate state using values[i] if opt_filter[i] is true +/// +/// `F`: The function to invoke for a non null input row to update the +/// accumulator state. Called like `value_fn(group_index, value) +pub fn accumulate_all<T, F>( + values: &PrimitiveArray<T>, + group_indicies: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + mut value_fn: F, +) where + T: ArrowNumericType + Send, + F: FnMut(usize, T::Native) + Send, +{ + assert_eq!( + values.null_count(), 0, + "Called accumulate_all with nullable array (call accumulate_all_nullable instead)" + ); + + // AAL TODO handle filter values + + let data: &[T::Native] = values.values(); + let iter = group_indicies.iter().zip(data.iter()); + for (&group_index, &new_value) in iter { + value_fn(group_index, new_value) + } +} + +/// This function is called to update the accumulator state per row, +/// for a `PrimitiveArray<T>` with no nulls. It is the inner loop for +/// many GroupsAccumulators and thus performance critical. +/// +/// * `values`: the input arguments to the accumulator +/// * `group_indices`: To which groups do the rows in `values` belong, group id) +/// * `opt_filter`: if present, only update aggregate state using values[i] if opt_filter[i] is true +/// +/// `F`: The function to invoke for an input row to update the +/// accumulator state. Called like `value_fn(group_index, value, +/// is_valid). NOTE the parameter is true when the value is VALID. +pub fn accumulate_all_nullable<T, F>( + values: &PrimitiveArray<T>, + group_indicies: &[usize], + opt_filter: Option<&arrow_array::BooleanArray>, + mut value_fn: F, +) where + T: ArrowNumericType + Send, + F: FnMut(usize, T::Native, bool) + Send, +{ + // AAL TODO handle filter values + // TODO combine the null mask from values and opt_filter + let valids = values + .nulls() + .expect("Called accumulate_all_nullable with non-nullable array (call accumulate_all instead)"); + + // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum + let data: &[T::Native] = values.values(); + + let group_indices_chunks = group_indicies.chunks_exact(64); + let data_chunks = data.chunks_exact(64); + let bit_chunks = valids.inner().bit_chunks(); + + let group_indices_remainder = group_indices_chunks.remainder(); + let data_remainder = data_chunks.remainder(); + + group_indices_chunks + .zip(data_chunks) + .zip(bit_chunks.iter()) + .for_each(|((group_index_chunk, data_chunk), mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + group_index_chunk.iter().zip(data_chunk.iter()).for_each( + |(&group_index, &new_value)| { + // valid bit was set, real vale + let is_valid = (mask & index_mask) != 0; + value_fn(group_index, new_value, is_valid); + index_mask <<= 1; + }, + ) + }); + + // handle any remaining bits (after the intial 64) + let remainder_bits = bit_chunks.remainder_bits(); + group_indices_remainder + .iter() + .zip(data_remainder.iter()) + .enumerate() + .for_each(|(i, (&group_index, &new_value))| { + let is_valid = remainder_bits & (1 << i) != 0; + value_fn(group_index, new_value, is_valid) + }); +} diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs similarity index 99% rename from datafusion/physical-expr/src/aggregate/groups_accumulator.rs rename to datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs index 82cfbfaa31..680eb927a1 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs @@ -17,6 +17,8 @@ //! Vectorized [`GroupsAccumulator`] +pub mod accumulate; + use arrow_array::{ArrayRef, BooleanArray}; use datafusion_common::Result;
