This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch hash_agg_spike
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git

commit e02c35d649ea999b4a79f44897b31eab8ce606db
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu Jun 29 09:42:13 2023 -0400

    POC: Demonstrate new GroupHashAggregate stream approach
---
 .../core/src/physical_plan/aggregates/mod.rs       |  18 +-
 .../core/src/physical_plan/aggregates/row_hash.rs  |   2 +
 .../core/src/physical_plan/aggregates/row_hash2.rs | 449 +++++++++++++++++++++
 datafusion/physical-expr/Cargo.toml                |   1 +
 datafusion/physical-expr/src/aggregate/average.rs  | 218 +++++++++-
 .../src/aggregate/groups_accumulator.rs            | 100 +++++
 datafusion/physical-expr/src/aggregate/mod.rs      |  15 +
 datafusion/physical-expr/src/lib.rs                |   2 +
 8 files changed, 801 insertions(+), 4 deletions(-)

diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs 
b/datafusion/core/src/physical_plan/aggregates/mod.rs
index 343f7628b7..e086b545b8 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -49,6 +49,7 @@ use std::sync::Arc;
 mod bounded_aggregate_stream;
 mod no_grouping;
 mod row_hash;
+mod row_hash2;
 mod utils;
 
 pub use datafusion_expr::AggregateFunction;
@@ -58,6 +59,8 @@ use datafusion_physical_expr::utils::{
     get_finer_ordering, ordering_satisfy_requirement_concrete,
 };
 
+use self::row_hash2::GroupedHashAggregateStream2;
+
 /// Hash aggregate modes
 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
 pub enum AggregateMode {
@@ -196,6 +199,7 @@ impl PartialEq for PhysicalGroupBy {
 enum StreamType {
     AggregateStream(AggregateStream),
     GroupedHashAggregateStream(GroupedHashAggregateStream),
+    GroupedHashAggregateStream2(GroupedHashAggregateStream2),
     BoundedAggregate(BoundedAggregateStream),
 }
 
@@ -204,6 +208,7 @@ impl From<StreamType> for SendableRecordBatchStream {
         match stream {
             StreamType::AggregateStream(stream) => Box::pin(stream),
             StreamType::GroupedHashAggregateStream(stream) => Box::pin(stream),
+            StreamType::GroupedHashAggregateStream2(stream) => 
Box::pin(stream),
             StreamType::BoundedAggregate(stream) => Box::pin(stream),
         }
     }
@@ -711,12 +716,23 @@ impl AggregateExec {
                 partition,
                 aggregation_ordering,
             )?))
+        } else if self.use_poc_group_by() {
+            Ok(StreamType::GroupedHashAggregateStream2(
+                GroupedHashAggregateStream2::new(self, context, partition)?,
+            ))
         } else {
             Ok(StreamType::GroupedHashAggregateStream(
                 GroupedHashAggregateStream::new(self, context, partition)?,
             ))
         }
     }
+
+    /// Returns true if we should use the POC group by stream
+    /// TODO: check for actually supported aggregates, etc
+    fn use_poc_group_by(&self) -> bool {
+        //info!("AAL Checking POC group by: {self:#?}");
+        true
+    }
 }
 
 impl ExecutionPlan for AggregateExec {
@@ -980,7 +996,7 @@ fn group_schema(schema: &Schema, group_count: usize) -> 
SchemaRef {
     Arc::new(Schema::new(group_fields))
 }
 
-/// returns physical expressions to evaluate against a batch
+/// returns physical expressions for arguments to evaluate against a batch
 /// The expressions are different depending on `mode`:
 /// * Partial: AggregateExpr::expressions
 /// * Final: columns of `AggregateExpr::state_fields()`
diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs 
b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
index ba02bc096b..5742c17c1d 100644
--- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -17,6 +17,7 @@
 
 //! Hash aggregation through row format
 
+use log::info;
 use std::cmp::min;
 use std::ops::Range;
 use std::sync::Arc;
@@ -119,6 +120,7 @@ impl GroupedHashAggregateStream {
         context: Arc<TaskContext>,
         partition: usize,
     ) -> Result<Self> {
+        info!("Creating GroupedHashAggregateStream");
         let agg_schema = Arc::clone(&agg.schema);
         let agg_group_by = agg.group_by.clone();
         let agg_filter_expr = agg.filter_expr.clone();
diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash2.rs 
b/datafusion/core/src/physical_plan/aggregates/row_hash2.rs
new file mode 100644
index 0000000000..90e7cd0724
--- /dev/null
+++ b/datafusion/core/src/physical_plan/aggregates/row_hash2.rs
@@ -0,0 +1,449 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Hash aggregation through row format
+//!
+//! POC demonstration of GroupByHashApproach
+
+use datafusion_physical_expr::GroupsAccumulator;
+use log::info;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use std::vec;
+
+use ahash::RandomState;
+use arrow::row::{OwnedRow, RowConverter, SortField};
+use datafusion_physical_expr::hash_utils::create_hashes;
+use futures::ready;
+use futures::stream::{Stream, StreamExt};
+
+use crate::physical_plan::aggregates::{
+    evaluate_group_by, evaluate_many, evaluate_optional, group_schema, 
AggregateMode,
+    PhysicalGroupBy,
+};
+use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput};
+use crate::physical_plan::{aggregates, AggregateExpr, PhysicalExpr};
+use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+use arrow::array::*;
+use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
+use datafusion_common::Result;
+use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
+use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
+use datafusion_execution::TaskContext;
+use hashbrown::raw::RawTable;
+
+#[derive(Debug, Clone)]
+/// This object tracks the aggregation phase (input/output)
+pub(crate) enum ExecutionState {
+    ReadingInput,
+    /// When producing output, the remaining rows to output are stored
+    /// here and are sliced off as needed in batch_size chunks
+    ProducingOutput(RecordBatch),
+    Done,
+}
+
+use super::AggregateExec;
+
+/// Grouping aggregate
+///
+/// For each aggregation entry, we use:
+/// - [Arrow-row] represents grouping keys for fast hash computation and 
comparison directly on raw bytes.
+/// - [GroupsAccumulator] to store per group aggregates
+///
+/// The architecture is the following:
+///
+/// TODO
+///
+/// [WordAligned]: datafusion_row::layout
+pub(crate) struct GroupedHashAggregateStream2 {
+    schema: SchemaRef,
+    input: SendableRecordBatchStream,
+    mode: AggregateMode,
+
+    /// Accumulators, one for each `AggregateExpr` in the query
+    accumulators: Vec<Box<dyn GroupsAccumulator>>,
+    /// Argument expressions for each accumulator
+    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+    /// Filter expression to evaluate for each aggregate
+    filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
+
+    /// Converter for each row
+    row_converter: RowConverter,
+    group_by: PhysicalGroupBy,
+
+    /// The memory reservation for this grouping
+    reservation: MemoryReservation,
+
+    /// Logically maps group values to a group_index `group_states`
+    ///
+    /// Uses the raw API of hashbrown to avoid actually storing the
+    /// keys in the table
+    ///
+    /// keys: u64 hashes of the GroupValue
+    /// values: (hash, index into `group_states`)
+    map: RawTable<(u64, usize)>,
+
+    /// The actual group by values, stored in arrow Row format
+    /// the index of group_by_values is the index
+    /// https://github.com/apache/arrow-rs/issues/4466
+    group_by_values: Vec<OwnedRow>,
+
+    /// scratch space for the current Batch / Aggregate being
+    /// processed. Saved here to avoid reallocations
+    current_group_indices: Vec<usize>,
+
+    /// generating input/output?
+    exec_state: ExecutionState,
+
+    baseline_metrics: BaselineMetrics,
+
+    random_state: RandomState,
+    /// size to be used for resulting RecordBatches
+    batch_size: usize,
+}
+
+impl GroupedHashAggregateStream2 {
+    /// Create a new GroupedHashAggregateStream
+    pub fn new(
+        agg: &AggregateExec,
+        context: Arc<TaskContext>,
+        partition: usize,
+    ) -> Result<Self> {
+        info!("Creating GroupedHashAggregateStream2");
+        let agg_schema = Arc::clone(&agg.schema);
+        let agg_group_by = agg.group_by.clone();
+        let agg_filter_expr = agg.filter_expr.clone();
+
+        let batch_size = context.session_config().batch_size();
+        let input = agg.input.execute(partition, Arc::clone(&context))?;
+        let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition);
+
+        let timer = baseline_metrics.elapsed_compute().timer();
+
+        let mut aggregate_exprs = vec![];
+        let mut aggregate_arguments = vec![];
+
+        // The expressions to evaluate the batch, one vec of expressions per 
aggregation.
+        // Assuming create_schema() always puts group columns in front of 
aggregation columns, we set
+        // col_idx_base to the group expression count.
+
+        let all_aggregate_expressions = aggregates::aggregate_expressions(
+            &agg.aggr_expr,
+            &agg.mode,
+            agg_group_by.expr.len(),
+        )?;
+        let filter_expressions = match agg.mode {
+            AggregateMode::Partial | AggregateMode::Single => agg_filter_expr,
+            AggregateMode::Final | AggregateMode::FinalPartitioned => {
+                vec![None; agg.aggr_expr.len()]
+            }
+        };
+
+        for (agg_expr, agg_args) in agg
+            .aggr_expr
+            .iter()
+            .zip(all_aggregate_expressions.into_iter())
+        {
+            aggregate_exprs.push(agg_expr.clone());
+            aggregate_arguments.push(agg_args);
+        }
+
+        let accumulators = create_accumulators(aggregate_exprs)?;
+
+        let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
+        let row_converter = RowConverter::new(
+            group_schema
+                .fields()
+                .iter()
+                .map(|f| SortField::new(f.data_type().clone()))
+                .collect(),
+        )?;
+
+        let name = format!("GroupedHashAggregateStream2[{partition}]");
+        let reservation = 
MemoryConsumer::new(name).register(context.memory_pool());
+        let map = RawTable::with_capacity(0);
+        let group_by_values = vec![];
+        let current_group_indices = vec![];
+
+        timer.done();
+
+        let exec_state = ExecutionState::ReadingInput;
+
+        Ok(GroupedHashAggregateStream2 {
+            schema: agg_schema,
+            input,
+            mode: agg.mode,
+            accumulators,
+            aggregate_arguments,
+            filter_expressions,
+            row_converter,
+            group_by: agg_group_by,
+            reservation,
+            map,
+            group_by_values,
+            current_group_indices,
+            exec_state,
+            baseline_metrics,
+            random_state: Default::default(),
+            batch_size,
+        })
+    }
+}
+
+/// Create a `GroupsAccumulator` for each of the aggregate_exprs to hold the 
aggregation state
+fn create_accumulators(
+    aggregate_exprs: Vec<Arc<dyn AggregateExpr>>,
+) -> Result<Vec<Box<dyn GroupsAccumulator>>> {
+    info!("Creating accumulator for {aggregate_exprs:#?}");
+    aggregate_exprs
+        .into_iter()
+        .map(|agg_expr| agg_expr.create_groups_accumulator())
+        .collect()
+}
+
+impl Stream for GroupedHashAggregateStream2 {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
+
+        loop {
+            let exec_state = self.exec_state.clone();
+            match exec_state {
+                ExecutionState::ReadingInput => {
+                    match ready!(self.input.poll_next_unpin(cx)) {
+                        // new batch to aggregate
+                        Some(Ok(batch)) => {
+                            let timer = elapsed_compute.timer();
+                            let result = self.group_aggregate_batch(batch);
+                            timer.done();
+
+                            // allocate memory
+                            // This happens AFTER we actually used the memory, 
but simplifies the whole accounting and we are OK with
+                            // overshooting a bit. Also this means we either 
store the whole record batch or not.
+                            let result = result.and_then(|allocated| {
+                                self.reservation.try_grow(allocated)
+                            });
+
+                            if let Err(e) = result {
+                                return Poll::Ready(Some(Err(e)));
+                            }
+                        }
+                        // inner had error, return to caller
+                        Some(Err(e)) => return Poll::Ready(Some(Err(e))),
+                        // inner is done, producing output
+                        None => {
+                            let timer = elapsed_compute.timer();
+                            match self.create_batch_from_map() {
+                                Ok(batch) => {
+                                    self.exec_state =
+                                        ExecutionState::ProducingOutput(batch)
+                                }
+                                Err(e) => return Poll::Ready(Some(Err(e))),
+                            }
+                            timer.done();
+                        }
+                    }
+                }
+
+                ExecutionState::ProducingOutput(batch) => {
+                    // slice off a part of the batch, if needed
+                    let output_batch = if batch.num_rows() <= self.batch_size {
+                        self.exec_state = ExecutionState::Done;
+                        batch
+                    } else {
+                        // output first batch_size rows
+                        let num_remaining = batch.num_rows() - self.batch_size;
+                        let remaining = batch.slice(self.batch_size, 
num_remaining);
+                        self.exec_state = 
ExecutionState::ProducingOutput(remaining);
+                        batch.slice(0, self.batch_size)
+                    };
+                    return Poll::Ready(Some(Ok(
+                        output_batch.record_output(&self.baseline_metrics)
+                    )));
+                }
+
+                ExecutionState::Done => return Poll::Ready(None),
+            }
+        }
+    }
+}
+
+impl RecordBatchStream for GroupedHashAggregateStream2 {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
+impl GroupedHashAggregateStream2 {
+    /// Update self.aggr_state based on the group_by values (result of 
evaluating the group_by_expressions)
+    ///
+    /// At the return of this function,
+    /// `self.aggr_state.current_group_indices` has the correct
+    /// group_index for each row in the group_values
+    fn update_group_state(
+        &mut self,
+        group_values: &[ArrayRef],
+        allocated: &mut usize,
+    ) -> Result<()> {
+        // Convert the group keys into the row format
+        let group_rows = self.row_converter.convert_columns(group_values)?;
+        let n_rows = group_rows.num_rows();
+        // 1.1 construct the key from the group values
+        // 1.2 construct the mapping key if it does not exist
+
+        // tracks to which group each of the input rows belongs
+        let group_indices = &mut self.current_group_indices;
+        group_indices.clear();
+
+        // 1.1 Calculate the group keys for the group values
+        let mut batch_hashes = vec![0; n_rows];
+        create_hashes(group_values, &self.random_state, &mut batch_hashes)?;
+
+        for (row, hash) in batch_hashes.into_iter().enumerate() {
+            let entry = self.map.get_mut(hash, |(_hash, group_idx)| {
+                // verify that a group that we are inserting with hash is
+                // actually the same key value as the group in
+                // existing_idx  (aka group_values @ row)
+
+                // TODO update *allocated based on size of the row
+                // that was just pushed into
+                // aggr_state.group_by_values
+                group_rows.row(row) == self.group_by_values[*group_idx].row()
+            });
+
+            let group_idx = match entry {
+                // Existing group_index for this group value
+                Some((_hash, group_idx)) => *group_idx,
+                //  1.2 Need to create new entry for the group
+                None => {
+                    // Add new entry to aggr_state and save newly created index
+                    let group_idx = self.group_by_values.len();
+                    self.group_by_values.push(group_rows.row(row).owned());
+
+                    // for hasher function, use precomputed hash value
+                    self.map.insert_accounted(
+                        (hash, group_idx),
+                        |(hash, _group_index)| *hash,
+                        allocated,
+                    );
+                    group_idx
+                }
+            };
+            group_indices.push_accounted(group_idx, allocated);
+        }
+        Ok(())
+    }
+
+    /// Perform group-by aggregation for the given [`RecordBatch`].
+    ///
+    /// If successful, returns the additional amount of memory, in
+    /// bytes, that were allocated during this process.
+    ///
+    fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<usize> {
+        // Evaluate the grouping expressions:
+        let group_by_values = evaluate_group_by(&self.group_by, &batch)?;
+
+        // Keep track of memory allocated:
+        let mut allocated = 0usize;
+
+        // Evaluate the aggregation expressions.
+        let input_values = evaluate_many(&self.aggregate_arguments, &batch)?;
+        // Evaluate the filter expressions, if any, against the inputs
+        let filter_values = evaluate_optional(&self.filter_expressions, 
&batch)?;
+
+        let row_converter_size_pre = self.row_converter.size();
+        for group_values in &group_by_values {
+            // calculate the group indices for each input row
+            self.update_group_state(group_values, &mut allocated)?;
+            let group_indices = &self.current_group_indices;
+
+            // Gather the inputs to call the actual aggregation
+            let t = self
+                .accumulators
+                .iter_mut()
+                .zip(input_values.iter())
+                .zip(filter_values.iter());
+
+            let total_num_groups = self.group_by_values.len();
+
+            for ((acc, values), opt_filter) in t {
+                let acc_size_pre = acc.size();
+                let opt_filter = opt_filter.as_ref().map(|filter| 
filter.as_boolean());
+
+                match self.mode {
+                    AggregateMode::Partial | AggregateMode::Single => {
+                        acc.update_batch(
+                            values,
+                            &group_indices,
+                            opt_filter,
+                            total_num_groups,
+                        )?;
+                    }
+                    AggregateMode::FinalPartitioned | AggregateMode::Final => {
+                        // if aggregation is over intermediate states,
+                        // use merge
+                        acc.merge_batch(
+                            values,
+                            &group_indices,
+                            opt_filter,
+                            total_num_groups,
+                        )?;
+                    }
+                }
+
+                allocated += acc.size().saturating_sub(acc_size_pre);
+            }
+        }
+        allocated += self
+            .row_converter
+            .size()
+            .saturating_sub(row_converter_size_pre);
+
+        Ok(allocated)
+    }
+}
+
+impl GroupedHashAggregateStream2 {
+    /// Create an output RecordBatch with all group keys and accumulator 
states/values
+    fn create_batch_from_map(&mut self) -> Result<RecordBatch> {
+        if self.group_by_values.is_empty() {
+            let schema = self.schema.clone();
+            return Ok(RecordBatch::new_empty(schema));
+        }
+
+        // First output rows are the groups
+        let groups_rows = self.group_by_values.iter().map(|owned_row| 
owned_row.row());
+
+        let mut output: Vec<ArrayRef> = 
self.row_converter.convert_rows(groups_rows)?;
+
+        // Next output the accumulators
+        for acc in self.accumulators.iter_mut() {
+            match self.mode {
+                AggregateMode::Partial => output.extend(acc.state()?),
+                AggregateMode::Final
+                | AggregateMode::FinalPartitioned
+                | AggregateMode::Single => output.push(acc.evaluate()?),
+            }
+        }
+
+        Ok(RecordBatch::try_new(self.schema.clone(), output)?)
+    }
+}
diff --git a/datafusion/physical-expr/Cargo.toml 
b/datafusion/physical-expr/Cargo.toml
index 04ba2b9e38..a8f82e60e4 100644
--- a/datafusion/physical-expr/Cargo.toml
+++ b/datafusion/physical-expr/Cargo.toml
@@ -59,6 +59,7 @@ indexmap = "2.0.0"
 itertools = { version = "0.11", features = ["use_std"] }
 lazy_static = { version = "^1.4.0" }
 libc = "0.2.140"
+log = "^0.4"
 md-5 = { version = "^0.10.0", optional = true }
 paste = "^1.0"
 petgraph = "0.6.2"
diff --git a/datafusion/physical-expr/src/aggregate/average.rs 
b/datafusion/physical-expr/src/aggregate/average.rs
index 3c76da51a9..f81c704d8b 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -17,6 +17,9 @@
 
 //! Defines physical expressions that can evaluated at runtime during query 
execution
 
+use arrow::array::AsArray;
+use log::info;
+
 use std::any::Any;
 use std::convert::TryFrom;
 use std::sync::Arc;
@@ -29,14 +32,14 @@ use crate::aggregate::sum::sum_batch;
 use crate::aggregate::utils::calculate_result_decimal_for_avg;
 use crate::aggregate::utils::down_cast_any_ref;
 use crate::expressions::format_state_name;
-use crate::{AggregateExpr, PhysicalExpr};
+use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr};
 use arrow::compute;
-use arrow::datatypes::DataType;
+use arrow::datatypes::{DataType, Decimal128Type, UInt64Type};
 use arrow::{
     array::{ArrayRef, UInt64Array},
     datatypes::Field,
 };
-use arrow_array::Array;
+use arrow_array::{Array, ArrowNativeTypeOp, ArrowNumericType, PrimitiveArray};
 use datafusion_common::{downcast_value, ScalarValue};
 use datafusion_common::{DataFusionError, Result};
 use datafusion_expr::Accumulator;
@@ -155,6 +158,22 @@ impl AggregateExpr for Avg {
             &self.rt_data_type,
         )?))
     }
+
+    fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+        // instantiate specialized accumulator
+        match self.sum_data_type {
+            DataType::Decimal128(_, _) => {
+                Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type>::new(
+                    &self.sum_data_type,
+                    &self.rt_data_type,
+                )))
+            }
+            _ => Err(DataFusionError::NotImplemented(format!(
+                "AvgGroupsAccumulator for {}",
+                self.sum_data_type
+            ))),
+        }
+    }
 }
 
 impl PartialEq<dyn Any> for Avg {
@@ -383,6 +402,199 @@ impl RowAccumulator for AvgRowAccumulator {
     }
 }
 
+/// An accumulator to compute the average of PrimitiveArray<T>.
+/// Stores values as native types
+#[derive(Debug)]
+struct AvgGroupsAccumulator<T: ArrowNumericType + Send> {
+    /// The type of the internal sum
+    sum_data_type: DataType,
+
+    /// The type of the returned sum
+    return_data_type: DataType,
+
+    /// Count per group (use u64 to make UInt64Array)
+    counts: Vec<u64>,
+
+    // Sums per group, stored as the native type
+    sums: Vec<T::Native>,
+}
+
+impl<T: ArrowNumericType + Send> AvgGroupsAccumulator<T> {
+    pub fn new(sum_data_type: &DataType, return_data_type: &DataType) -> Self {
+        info!(
+            "AvgGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> 
{return_data_type:?}",
+            std::any::type_name::<T>()
+        );
+        Self {
+            return_data_type: return_data_type.clone(),
+            sum_data_type: sum_data_type.clone(),
+            counts: vec![],
+            sums: vec![],
+        }
+    }
+
+    /// Adds the values in `values` to self.sums
+    fn update_sums(
+        &mut self,
+        values: &PrimitiveArray<T>,
+        group_indicies: &[usize],
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        self.sums
+            .resize_with(total_num_groups, || T::default_value());
+
+        // AAL TODO
+        // TODO combine the null mask from values and opt_filter
+        let valids = values.nulls();
+
+        // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum
+        let data: &[T::Native] = values.values();
+
+        match valids {
+            // use all values in group_index
+            None => {
+                let iter = group_indicies.iter().zip(data.iter());
+                for (group_index, new_value) in iter {
+                    self.sums[*group_index].add_wrapping(*new_value);
+                }
+            }
+            //
+            Some(valids) => {
+                let group_indices_chunks = group_indicies.chunks_exact(64);
+                let data_chunks = data.chunks_exact(64);
+                let bit_chunks = valids.inner().bit_chunks();
+
+                let group_indices_remainder = group_indices_chunks.remainder();
+                let data_remainder = data_chunks.remainder();
+
+                group_indices_chunks
+                    .zip(data_chunks)
+                    .zip(bit_chunks.iter())
+                    .for_each(|((group_index_chunk, data_chunk), mask)| {
+                        // index_mask has value 1 << i in the loop
+                        let mut index_mask = 1;
+                        
group_index_chunk.iter().zip(data_chunk.iter()).for_each(
+                            |(group_index, new_value)| {
+                                if (mask & index_mask) != 0 {
+                                    
self.sums[*group_index].add_wrapping(*new_value);
+                                }
+                                index_mask <<= 1;
+                            },
+                        )
+                    });
+
+                let remainder_bits = bit_chunks.remainder_bits();
+                group_indices_remainder
+                    .iter()
+                    .zip(data_remainder.iter())
+                    .enumerate()
+                    .for_each(|(i, (group_index, new_value))| {
+                        if remainder_bits & (1 << i) != 0 {
+                            self.sums[*group_index].add_wrapping(*new_value);
+                        }
+                    });
+            }
+        }
+        Ok(())
+    }
+}
+
+impl<T: ArrowNumericType + Send> GroupsAccumulator for AvgGroupsAccumulator<T> 
{
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indicies: &[usize],
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        assert_eq!(values.len(), 1, "single argument to update_batch");
+        let values = values.get(0).unwrap().as_primitive::<T>();
+
+        // update counts (TODO account for opt_filter)
+        self.counts.resize(total_num_groups, 0);
+        group_indicies.iter().for_each(|&group_idx| {
+            self.counts[group_idx] += 1;
+        });
+
+        // update values
+        self.update_sums(values, group_indicies, opt_filter, 
total_num_groups)?;
+        Ok(())
+    }
+
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indicies: &[usize],
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        assert_eq!(values.len(), 2, "two arguments to merge_batch");
+        // first batch is counts, second is partial sums
+        let counts = values.get(0).unwrap().as_primitive::<UInt64Type>();
+        let partial_sums = values.get(1).unwrap().as_primitive::<T>();
+
+        // update counts by summing the partial sums (TODO account for 
opt_filter)
+        self.counts.resize(total_num_groups, 0);
+        group_indicies.iter().zip(counts.values().iter()).for_each(
+            |(&group_idx, &count)| {
+                self.counts[group_idx] += count;
+            },
+        );
+
+        // update values
+        self.update_sums(partial_sums, group_indicies, opt_filter, 
total_num_groups)?;
+
+        Ok(())
+    }
+
+    fn evaluate(&mut self) -> Result<ArrayRef> {
+        todo!()
+    }
+
+    // return arrays for sums and counts
+    fn state(&mut self) -> Result<Vec<ArrayRef>> {
+        let counts = std::mem::take(&mut self.counts);
+        // create array from vec is zero copy
+        let counts = UInt64Array::from(counts);
+
+        let sums = std::mem::take(&mut self.sums);
+        // create array from vec is zero copy
+        // TODO figure out how to do this without the iter / copy
+        let sums: PrimitiveArray<T> = PrimitiveArray::from_iter_values(sums);
+
+        // fix up decimal precision and scale
+        let sums = set_decimal_precision(&self.sum_data_type, Arc::new(sums))?;
+
+        Ok(vec![
+            Arc::new(counts) as ArrayRef,
+            Arc::new(sums) as ArrayRef,
+        ])
+    }
+
+    fn size(&self) -> usize {
+        self.counts.capacity() * std::mem::size_of::<usize>()
+    }
+}
+
+/// Adjust array type metadata if needed
+///
+/// Decimal128Arrays are created from Vec<NativeType> with default
+/// precision and scale. This function adjusts them down.
+fn set_decimal_precision(sum_data_type: &DataType, array: ArrayRef) -> 
Result<ArrayRef> {
+    let array = match sum_data_type {
+        DataType::Decimal128(p, s) => Arc::new(
+            array
+                .as_primitive::<Decimal128Type>()
+                .clone()
+                .with_precision_and_scale(*p, *s)?,
+        ),
+        // no adjustment needed for other arrays
+        _ => array,
+    };
+    Ok(array)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator.rs 
b/datafusion/physical-expr/src/aggregate/groups_accumulator.rs
new file mode 100644
index 0000000000..82cfbfaa31
--- /dev/null
+++ b/datafusion/physical-expr/src/aggregate/groups_accumulator.rs
@@ -0,0 +1,100 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Vectorized [`GroupsAccumulator`]
+
+use arrow_array::{ArrayRef, BooleanArray};
+use datafusion_common::Result;
+
+/// An implementation of GroupAccumulator is for a single aggregate
+/// (e.g. AVG) and stores the state for *all* groups internally
+///
+/// The logical model is that each group is given a `group_index`
+/// assigned and maintained by the hash table.
+///
+/// group_indexes are contiguous (there aren't gaps), and thus it is
+/// expected that each GroupAccumulator will use something like `Vec<..>`
+/// to store the group states.
+pub trait GroupsAccumulator: Send {
+    /// updates the accumulator's state from a vector of arrays:
+    ///
+    /// * `values`: the input arguments to the accumulator
+    /// * `group_indices`:  To which groups do the rows in `values` belong, 
group id)
+    /// * `opt_filter`: if present, only update aggregate state using 
values[i] if opt_filter[i] is true
+    /// * `total_num_groups`: the number of groups (the largest group_index is 
total_num_groups - 1)
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indicies: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()>;
+
+    /// Returns the final aggregate value for each group as a single
+    /// `RecordBatch`
+    ///
+    /// OPEN QUESTION: Should this method take a "batch_size: usize"
+    /// and produce a Vec<RecordBatch> as output to avoid 1) requiring
+    /// one giant intermediate buffer?
+    ///
+    /// For example, the `SUM` accumulator maintains a running sum,
+    /// and `evaluate` will produce that running sum as its output for
+    /// all groups, in group_index order
+    ///
+    /// This call should be treated as consuming (takes `self`, but it
+    /// can not be due to keeping it object save) the accumulator is
+    /// free to release / reset it is internal state after this call
+    /// and error on any subsequent call.
+    fn evaluate(&mut self) -> Result<ArrayRef>;
+
+    /// Returns any intermediate aggregate state used for multi-phase grouping
+    ///
+    /// For example, AVG returns two arrays:  `SUM` and `COUNT`.
+    ///
+    /// This call should be treated as consuming (takes `self`, but it
+    /// can not be due to keeping it object save) the accumulator is
+    /// free to release / reset it is internal state after this call
+    /// and error on any subsequent call.
+    ///
+    /// TODO: consider returning a single Array (which could be a
+    /// StructArray) instead
+    fn state(&mut self) -> Result<Vec<ArrayRef>>;
+
+    /// merges intermediate state (from `state()`) into this accumulators 
values
+    ///
+    /// For some aggregates (such as `SUM`), merge_batch is the same
+    /// as `update_batch`, but for some aggregrates (such as `COUNT`)
+    /// the operations differ. See [`Self::state`] for more details on how
+    /// state is used and merged.
+    ///
+    /// * `values`: arrays produced from calling `state` previously to the 
accumulator
+    /// * `group_indices`:  To which groups do the rows in `values` belong, 
group id)
+    /// * `opt_filter`: if present, only update aggregate state using 
values[i] if opt_filter[i] is true
+    /// * `total_num_groups`: the number of groups (the largest group_index is 
total_num_groups - 1)
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indicies: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()>;
+
+    /// Amount of memory used to store the state of this
+    /// accumulator. This function is called once per batch, so it
+    /// should be O(n) to compute
+    fn size(&self) -> usize;
+}
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs 
b/datafusion/physical-expr/src/aggregate/mod.rs
index 9be6d5e1ba..4b613c8e9b 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -25,6 +25,8 @@ use std::any::Any;
 use std::fmt::Debug;
 use std::sync::Arc;
 
+use self::groups_accumulator::GroupsAccumulator;
+
 pub(crate) mod approx_distinct;
 pub(crate) mod approx_median;
 pub(crate) mod approx_percentile_cont;
@@ -45,6 +47,7 @@ pub(crate) mod median;
 #[macro_use]
 pub(crate) mod min_max;
 pub mod build_in;
+pub(crate) mod groups_accumulator;
 mod hyperloglog;
 pub mod moving_min_max;
 pub mod row_accumulator;
@@ -118,6 +121,18 @@ pub trait AggregateExpr: Send + Sync + Debug + 
PartialEq<dyn Any> {
         )))
     }
 
+    /// Return a specialized [`GroupsAccumulator`] that manages state
+    /// for all groups
+    ///
+    /// For maximum performance, [`GroupsAccumulator`] should be
+    /// implemented rather than [`Accumulator`].
+    fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+        // TODO: The default should implement a wrapper over
+        // `self.create_accumulator`
+        Err(DataFusionError::NotImplemented(format!(
+            "GroupsAccumulator hasn't been implemented for {self:?} yet"
+        )))
+    }
+
     /// Construct an expression that calculates the aggregate in reverse.
     /// Typically the "reverse" expression is itself (e.g. SUM, COUNT).
     /// For aggregates that do not support calculation in reverse,
diff --git a/datafusion/physical-expr/src/lib.rs 
b/datafusion/physical-expr/src/lib.rs
index 0a2e0e58df..6ea8dc9487 100644
--- a/datafusion/physical-expr/src/lib.rs
+++ b/datafusion/physical-expr/src/lib.rs
@@ -45,7 +45,9 @@ pub mod var_provider;
 pub mod window;
 
 // reexport this to maintain compatibility with anything that used from_slice 
previously
+pub use aggregate::groups_accumulator::GroupsAccumulator;
 pub use aggregate::AggregateExpr;
+
 pub use equivalence::{
     project_equivalence_properties, project_ordering_equivalence_properties,
     EquivalenceProperties, EquivalentClass, OrderingEquivalenceProperties,


Reply via email to