This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 7c3ea0540c feat: add AggregateMode::PartialReduce for tree-reduce 
aggregation (#20019)
7c3ea0540c is described below

commit 7c3ea0540ca8793eba44b16d2a62e9cff02e3a8f
Author: Nathaniel J. Smith <[email protected]>
AuthorDate: Fri Jan 30 01:54:50 2026 -0800

    feat: add AggregateMode::PartialReduce for tree-reduce aggregation (#20019)
    
    DataFusion's current `AggregateMode` enum has five variants covering
    three of the four cells in the input/output matrix:
    
    | | Input: raw data          | Input: partial state |
    | - | - | - |
    | Output: final values | `Single` / `SinglePartitioned` | `Final` / `FinalPartitioned` |
    | Output: partial state | `Partial`                   | ??? |
    
    This PR adds `AggregateMode::PartialReduce` to fill in the missing cell:
    it takes partially-reduced values as input, and reduces them further,
    but without finalizing.
    
    This is useful because it's the key component needed to implement
    distributed tree-reduction (as seen in e.g. the Scuba or Honeycomb
    papers): each worker node performs multithreaded `Partial`
    aggregations and feeds them into a `PartialReduce`, which combines all
    of that node's partial states into one partial-state row per group;
    a head node then collects the outputs of every node's `PartialReduce`
    and feeds them into a `Final` reduction.
    
    PR can be reviewed commit by commit: the first commit is a pure
    refactor/simplification. Most places that matched on `AggregateMode`
    were actually just checking which row of the above table we were in,
    or else which column. So now we have `input_mode()` (tells you which
    column: raw data vs. partial state) and `output_mode()` (tells you
    which row: final values vs. partial state), and we use them everywhere.
    
    The second commit adds `PartialReduce`, and is pretty small because
    `input_mode()`/`output_mode()` do most of the heavy lifting. It also
    adds a test demonstrating a minimal Partial -> PartialReduce -> Final
    tree-reduction.
---
 .../physical-optimizer/src/aggregate_statistics.rs |   6 +-
 .../physical-optimizer/src/update_aggr_exprs.rs    |   6 +-
 datafusion/physical-plan/src/aggregates/mod.rs     | 294 ++++++++++++++++++---
 .../physical-plan/src/aggregates/no_grouping.rs    |  25 +-
 .../physical-plan/src/aggregates/row_hash.rs       |  72 ++---
 datafusion/proto/proto/datafusion.proto            |   1 +
 datafusion/proto/src/generated/pbjson.rs           |   3 +
 datafusion/proto/src/generated/prost.rs            |   3 +
 datafusion/proto/src/physical_plan/mod.rs          |   2 +
 9 files changed, 304 insertions(+), 108 deletions(-)

diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs 
b/datafusion/physical-optimizer/src/aggregate_statistics.rs
index cf3c15509c..5caee8b047 100644
--- a/datafusion/physical-optimizer/src/aggregate_statistics.rs
+++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs
@@ -20,7 +20,7 @@ use datafusion_common::Result;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::scalar::ScalarValue;
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
-use datafusion_physical_plan::aggregates::AggregateExec;
+use datafusion_physical_plan::aggregates::{AggregateExec, AggregateInputMode};
 use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
 use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr};
 use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs};
@@ -116,13 +116,13 @@ impl PhysicalOptimizerRule for AggregateStatistics {
 /// the `ExecutionPlan.children()` method that returns an owned reference.
 fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn 
ExecutionPlan>> {
     if let Some(final_agg_exec) = node.as_any().downcast_ref::<AggregateExec>()
-        && !final_agg_exec.mode().is_first_stage()
+        && final_agg_exec.mode().input_mode() == AggregateInputMode::Partial
         && final_agg_exec.group_expr().is_empty()
     {
         let mut child = Arc::clone(final_agg_exec.input());
         loop {
             if let Some(partial_agg_exec) = 
child.as_any().downcast_ref::<AggregateExec>()
-                && partial_agg_exec.mode().is_first_stage()
+                && partial_agg_exec.mode().input_mode() == 
AggregateInputMode::Raw
                 && partial_agg_exec.group_expr().is_empty()
                 && partial_agg_exec.filter_expr().iter().all(|e| e.is_none())
             {
diff --git a/datafusion/physical-optimizer/src/update_aggr_exprs.rs 
b/datafusion/physical-optimizer/src/update_aggr_exprs.rs
index c0aab4080d..67127c2a23 100644
--- a/datafusion/physical-optimizer/src/update_aggr_exprs.rs
+++ b/datafusion/physical-optimizer/src/update_aggr_exprs.rs
@@ -25,7 +25,9 @@ use datafusion_common::tree_node::{Transformed, 
TransformedResult, TreeNode};
 use datafusion_common::{Result, plan_datafusion_err};
 use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
 use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement};
-use datafusion_physical_plan::aggregates::{AggregateExec, concat_slices};
+use datafusion_physical_plan::aggregates::{
+    AggregateExec, AggregateInputMode, concat_slices,
+};
 use datafusion_physical_plan::windows::get_ordered_partition_by_indices;
 use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};
 
@@ -81,7 +83,7 @@ impl PhysicalOptimizerRule for OptimizeAggregateOrder {
                 // ordering fields may be pruned out by first stage aggregates.
                 // Hence, necessary information for proper merge is added 
during
                 // the first stage to the state field, which the final stage 
uses.
-                if !aggr_exec.mode().is_first_stage() {
+                if aggr_exec.mode().input_mode() == 
AggregateInputMode::Partial {
                     return Ok(Transformed::no(plan));
                 }
                 let input = aggr_exec.input();
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs 
b/datafusion/physical-plan/src/aggregates/mod.rs
index d645f5c55d..aa0f5a236c 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -89,10 +89,54 @@ pub fn topk_types_supported(key_type: &DataType, 
value_type: &DataType) -> bool
 const AGGREGATION_HASH_SEED: ahash::RandomState =
     ahash::RandomState::with_seeds('A' as u64, 'G' as u64, 'G' as u64, 'R' as 
u64);
 
+/// Whether an aggregate stage consumes raw input data or intermediate
+/// accumulator state from a previous aggregation stage.
+///
+/// See the [table on 
`AggregateMode`](AggregateMode#variants-and-their-inputoutput-modes)
+/// for how this relates to aggregate modes.
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum AggregateInputMode {
+    /// The stage consumes raw, unaggregated input data and calls
+    /// [`Accumulator::update_batch`].
+    Raw,
+    /// The stage consumes intermediate accumulator state from a previous
+    /// aggregation stage and calls [`Accumulator::merge_batch`].
+    Partial,
+}
+
+/// Whether an aggregate stage produces intermediate accumulator state
+/// or final output values.
+///
+/// See the [table on 
`AggregateMode`](AggregateMode#variants-and-their-inputoutput-modes)
+/// for how this relates to aggregate modes.
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum AggregateOutputMode {
+    /// The stage produces intermediate accumulator state, serialized via
+    /// [`Accumulator::state`].
+    Partial,
+    /// The stage produces final output values via
+    /// [`Accumulator::evaluate`].
+    Final,
+}
+
 /// Aggregation modes
 ///
 /// See [`Accumulator::state`] for background information on multi-phase
 /// aggregation and how these modes are used.
+///
+/// # Variants and their input/output modes
+///
+/// Each variant can be characterized by its [`AggregateInputMode`] and
+/// [`AggregateOutputMode`]:
+///
+/// ```text
+///                       | Input: Raw data           | Input: Partial state
+/// Output: Final values  | Single, SinglePartitioned | Final, FinalPartitioned
+/// Output: Partial state | Partial                   | PartialReduce
+/// ```
+///
+/// Use [`AggregateMode::input_mode`] and [`AggregateMode::output_mode`]
+/// to query these properties.
 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
 pub enum AggregateMode {
     /// One of multiple layers of aggregation, any input partitioning
@@ -144,18 +188,56 @@ pub enum AggregateMode {
     /// This mode requires that the input has more than one partition, and is
     /// partitioned by group key (like FinalPartitioned).
     SinglePartitioned,
+    /// Combine multiple partial aggregations to produce a new partial
+    /// aggregation.
+    ///
+    /// Input is intermediate accumulator state (like Final), but output is
+    /// also intermediate accumulator state (like Partial). This enables
+    /// tree-reduce aggregation strategies where partial results from
+    /// multiple workers are combined in multiple stages before a final
+    /// evaluation.
+    ///
+    /// ```text
+    ///               Final
+    ///            /        \
+    ///     PartialReduce   PartialReduce
+    ///     /         \      /         \
+    ///  Partial   Partial  Partial   Partial
+    /// ```
+    PartialReduce,
 }
 
 impl AggregateMode {
-    /// Checks whether this aggregation step describes a "first stage" 
calculation.
-    /// In other words, its input is not another aggregation result and the
-    /// `merge_batch` method will not be called for these modes.
-    pub fn is_first_stage(&self) -> bool {
+    /// Returns the [`AggregateInputMode`] for this mode: whether this
+    /// stage consumes raw input data or intermediate accumulator state.
+    ///
+    /// See the [table 
above](AggregateMode#variants-and-their-inputoutput-modes)
+    /// for details.
+    pub fn input_mode(&self) -> AggregateInputMode {
         match self {
             AggregateMode::Partial
             | AggregateMode::Single
-            | AggregateMode::SinglePartitioned => true,
-            AggregateMode::Final | AggregateMode::FinalPartitioned => false,
+            | AggregateMode::SinglePartitioned => AggregateInputMode::Raw,
+            AggregateMode::Final
+            | AggregateMode::FinalPartitioned
+            | AggregateMode::PartialReduce => AggregateInputMode::Partial,
+        }
+    }
+
+    /// Returns the [`AggregateOutputMode`] for this mode: whether this
+    /// stage produces intermediate accumulator state or final output values.
+    ///
+    /// See the [table 
above](AggregateMode#variants-and-their-inputoutput-modes)
+    /// for details.
+    pub fn output_mode(&self) -> AggregateOutputMode {
+        match self {
+            AggregateMode::Final
+            | AggregateMode::FinalPartitioned
+            | AggregateMode::Single
+            | AggregateMode::SinglePartitioned => AggregateOutputMode::Final,
+            AggregateMode::Partial | AggregateMode::PartialReduce => {
+                AggregateOutputMode::Partial
+            }
         }
     }
 }
@@ -917,14 +999,15 @@ impl AggregateExec {
 
         // Get output partitioning:
         let input_partitioning = input.output_partitioning().clone();
-        let output_partitioning = if mode.is_first_stage() {
-            // First stage aggregation will not change the output partitioning,
-            // but needs to respect aliases (e.g. mapping in the GROUP BY
-            // expression).
-            let input_eq_properties = input.equivalence_properties();
-            input_partitioning.project(group_expr_mapping, input_eq_properties)
-        } else {
-            input_partitioning.clone()
+        let output_partitioning = match mode.input_mode() {
+            AggregateInputMode::Raw => {
+                // First stage aggregation will not change the output 
partitioning,
+                // but needs to respect aliases (e.g. mapping in the GROUP BY
+                // expression).
+                let input_eq_properties = input.equivalence_properties();
+                input_partitioning.project(group_expr_mapping, 
input_eq_properties)
+            }
+            AggregateInputMode::Partial => input_partitioning.clone(),
         };
 
         // TODO: Emission type and boundedness information can be enhanced here
@@ -1248,7 +1331,7 @@ impl ExecutionPlan for AggregateExec {
 
     fn required_input_distribution(&self) -> Vec<Distribution> {
         match &self.mode {
-            AggregateMode::Partial => {
+            AggregateMode::Partial | AggregateMode::PartialReduce => {
                 vec![Distribution::UnspecifiedDistribution]
             }
             AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned 
=> {
@@ -1477,20 +1560,17 @@ fn create_schema(
     let mut fields = Vec::with_capacity(group_by.num_output_exprs() + 
aggr_expr.len());
     fields.extend(group_by.output_fields(input_schema)?);
 
-    match mode {
-        AggregateMode::Partial => {
-            // in partial mode, the fields of the accumulator's state
+    match mode.output_mode() {
+        AggregateOutputMode::Final => {
+            // in final mode, the field with the final result of the 
accumulator
             for expr in aggr_expr {
-                fields.extend(expr.state_fields()?.iter().cloned());
+                fields.push(expr.field())
             }
         }
-        AggregateMode::Final
-        | AggregateMode::FinalPartitioned
-        | AggregateMode::Single
-        | AggregateMode::SinglePartitioned => {
-            // in final mode, the field with the final result of the 
accumulator
+        AggregateOutputMode::Partial => {
+            // in partial mode, the fields of the accumulator's state
             for expr in aggr_expr {
-                fields.push(expr.field())
+                fields.extend(expr.state_fields()?.iter().cloned());
             }
         }
     }
@@ -1530,7 +1610,7 @@ fn get_aggregate_expr_req(
     // If the aggregation is performing a "second stage" calculation,
     // then ignore the ordering requirement. Ordering requirement applies
     // only to the aggregation input data.
-    if !agg_mode.is_first_stage() {
+    if agg_mode.input_mode() == AggregateInputMode::Partial {
         return None;
     }
 
@@ -1696,10 +1776,8 @@ pub fn aggregate_expressions(
     mode: &AggregateMode,
     col_idx_base: usize,
 ) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
-    match mode {
-        AggregateMode::Partial
-        | AggregateMode::Single
-        | AggregateMode::SinglePartitioned => Ok(aggr_expr
+    match mode.input_mode() {
+        AggregateInputMode::Raw => Ok(aggr_expr
             .iter()
             .map(|agg| {
                 let mut result = agg.expressions();
@@ -1710,8 +1788,8 @@ pub fn aggregate_expressions(
                 result
             })
             .collect()),
-        // In this mode, we build the merge expressions of the aggregation.
-        AggregateMode::Final | AggregateMode::FinalPartitioned => {
+        AggregateInputMode::Partial => {
+            // In merge mode, we build the merge expressions of the 
aggregation.
             let mut col_idx_base = col_idx_base;
             aggr_expr
                 .iter()
@@ -1759,8 +1837,15 @@ pub fn finalize_aggregation(
     accumulators: &mut [AccumulatorItem],
     mode: &AggregateMode,
 ) -> Result<Vec<ArrayRef>> {
-    match mode {
-        AggregateMode::Partial => {
+    match mode.output_mode() {
+        AggregateOutputMode::Final => {
+            // Merge the state to the final value
+            accumulators
+                .iter_mut()
+                .map(|accumulator| accumulator.evaluate().and_then(|v| 
v.to_array()))
+                .collect()
+        }
+        AggregateOutputMode::Partial => {
             // Build the vector of states
             accumulators
                 .iter_mut()
@@ -1774,16 +1859,6 @@ pub fn finalize_aggregation(
                 .flatten_ok()
                 .collect()
         }
-        AggregateMode::Final
-        | AggregateMode::FinalPartitioned
-        | AggregateMode::Single
-        | AggregateMode::SinglePartitioned => {
-            // Merge the state to the final value
-            accumulators
-                .iter_mut()
-                .map(|accumulator| accumulator.evaluate().and_then(|v| 
v.to_array()))
-                .collect()
-        }
     }
 }
 
@@ -3745,4 +3820,135 @@ mod tests {
         }
         Ok(())
     }
+
+    /// Tests that PartialReduce mode:
+    /// 1. Accepts state as input (like Final)
+    /// 2. Produces state as output (like Partial)
+    /// 3. Can be followed by a Final stage to get the correct result
+    ///
+    /// This simulates a tree-reduce pattern:
+    ///   Partial -> PartialReduce -> Final
+    #[tokio::test]
+    async fn test_partial_reduce_mode() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::UInt32, false),
+            Field::new("b", DataType::Float64, false),
+        ]));
+
+        // Produce two partitions of input data
+        let batch1 = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(UInt32Array::from(vec![1, 2, 3])),
+                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
+            ],
+        )?;
+        let batch2 = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(UInt32Array::from(vec![1, 2, 3])),
+                Arc::new(Float64Array::from(vec![40.0, 50.0, 60.0])),
+            ],
+        )?;
+
+        let groups =
+            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, 
"a".to_string())]);
+        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
+            AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?])
+                .schema(Arc::clone(&schema))
+                .alias("SUM(b)")
+                .build()?,
+        )];
+
+        // Step 1: Partial aggregation on partition 1
+        let input1 =
+            TestMemoryExec::try_new_exec(&[vec![batch1]], Arc::clone(&schema), 
None)?;
+        let partial1 = Arc::new(AggregateExec::try_new(
+            AggregateMode::Partial,
+            groups.clone(),
+            aggregates.clone(),
+            vec![None],
+            input1,
+            Arc::clone(&schema),
+        )?);
+
+        // Step 2: Partial aggregation on partition 2
+        let input2 =
+            TestMemoryExec::try_new_exec(&[vec![batch2]], Arc::clone(&schema), 
None)?;
+        let partial2 = Arc::new(AggregateExec::try_new(
+            AggregateMode::Partial,
+            groups.clone(),
+            aggregates.clone(),
+            vec![None],
+            input2,
+            Arc::clone(&schema),
+        )?);
+
+        // Collect partial results
+        let task_ctx = Arc::new(TaskContext::default());
+        let partial_result1 =
+            crate::collect(Arc::clone(&partial1) as _, 
Arc::clone(&task_ctx)).await?;
+        let partial_result2 =
+            crate::collect(Arc::clone(&partial2) as _, 
Arc::clone(&task_ctx)).await?;
+
+        // The partial results have state schema (group cols + accumulator 
state)
+        let partial_schema = partial1.schema();
+
+        // Step 3: PartialReduce — combine partial results, still producing 
state
+        let combined_input = TestMemoryExec::try_new_exec(
+            &[partial_result1, partial_result2],
+            Arc::clone(&partial_schema),
+            None,
+        )?;
+        // Coalesce into a single partition for the PartialReduce
+        let coalesced = Arc::new(CoalescePartitionsExec::new(combined_input));
+
+        let partial_reduce = Arc::new(AggregateExec::try_new(
+            AggregateMode::PartialReduce,
+            groups.clone(),
+            aggregates.clone(),
+            vec![None],
+            coalesced,
+            Arc::clone(&partial_schema),
+        )?);
+
+        // Verify PartialReduce output schema matches Partial output schema
+        // (both produce state, not final values)
+        assert_eq!(partial_reduce.schema(), partial_schema);
+
+        // Collect PartialReduce results
+        let reduce_result =
+            crate::collect(Arc::clone(&partial_reduce) as _, 
Arc::clone(&task_ctx))
+                .await?;
+
+        // Step 4: Final aggregation on the PartialReduce output
+        let final_input = TestMemoryExec::try_new_exec(
+            &[reduce_result],
+            Arc::clone(&partial_schema),
+            None,
+        )?;
+        let final_agg = Arc::new(AggregateExec::try_new(
+            AggregateMode::Final,
+            groups.clone(),
+            aggregates.clone(),
+            vec![None],
+            final_input,
+            Arc::clone(&partial_schema),
+        )?);
+
+        let result = crate::collect(final_agg, Arc::clone(&task_ctx)).await?;
+
+        // Expected: group 1 -> 10+40=50, group 2 -> 20+50=70, group 3 -> 
30+60=90
+        assert_snapshot!(batches_to_sort_string(&result), @r"
+            +---+--------+
+            | a | SUM(b) |
+            +---+--------+
+            | 1 | 50.0   |
+            | 2 | 70.0   |
+            | 3 | 90.0   |
+            +---+--------+
+        ");
+
+        Ok(())
+    }
 }
diff --git a/datafusion/physical-plan/src/aggregates/no_grouping.rs 
b/datafusion/physical-plan/src/aggregates/no_grouping.rs
index a55d70ca6f..eb9b6766ab 100644
--- a/datafusion/physical-plan/src/aggregates/no_grouping.rs
+++ b/datafusion/physical-plan/src/aggregates/no_grouping.rs
@@ -18,8 +18,9 @@
 //! Aggregate without grouping columns
 
 use crate::aggregates::{
-    AccumulatorItem, AggrDynFilter, AggregateMode, DynamicFilterAggregateType,
-    aggregate_expressions, create_accumulators, finalize_aggregation,
+    AccumulatorItem, AggrDynFilter, AggregateInputMode, AggregateMode,
+    DynamicFilterAggregateType, aggregate_expressions, create_accumulators,
+    finalize_aggregation,
 };
 use crate::metrics::{BaselineMetrics, RecordOutput};
 use crate::{RecordBatchStream, SendableRecordBatchStream};
@@ -282,13 +283,9 @@ impl AggregateStream {
         let input = agg.input.execute(partition, Arc::clone(context))?;
 
         let aggregate_expressions = aggregate_expressions(&agg.aggr_expr, 
&agg.mode, 0)?;
-        let filter_expressions = match agg.mode {
-            AggregateMode::Partial
-            | AggregateMode::Single
-            | AggregateMode::SinglePartitioned => agg_filter_expr,
-            AggregateMode::Final | AggregateMode::FinalPartitioned => {
-                vec![None; agg.aggr_expr.len()]
-            }
+        let filter_expressions = match agg.mode.input_mode() {
+            AggregateInputMode::Raw => agg_filter_expr,
+            AggregateInputMode::Partial => vec![None; agg.aggr_expr.len()],
         };
         let accumulators = create_accumulators(&agg.aggr_expr)?;
 
@@ -455,13 +452,9 @@ fn aggregate_batch(
 
             // 1.4
             let size_pre = accum.size();
-            let res = match mode {
-                AggregateMode::Partial
-                | AggregateMode::Single
-                | AggregateMode::SinglePartitioned => 
accum.update_batch(&values),
-                AggregateMode::Final | AggregateMode::FinalPartitioned => {
-                    accum.merge_batch(&values)
-                }
+            let res = match mode.input_mode() {
+                AggregateInputMode::Raw => accum.update_batch(&values),
+                AggregateInputMode::Partial => accum.merge_batch(&values),
             };
             let size_post = accum.size();
             allocated += size_post.saturating_sub(size_pre);
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs 
b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 49ce125e73..b2cf396b15 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -26,8 +26,8 @@ use super::order::GroupOrdering;
 use crate::aggregates::group_values::{GroupByMetrics, GroupValues, 
new_group_values};
 use crate::aggregates::order::GroupOrderingFull;
 use crate::aggregates::{
-    AggregateMode, PhysicalGroupBy, create_schema, evaluate_group_by, 
evaluate_many,
-    evaluate_optional,
+    AggregateInputMode, AggregateMode, AggregateOutputMode, PhysicalGroupBy,
+    create_schema, evaluate_group_by, evaluate_many, evaluate_optional,
 };
 use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput};
 use crate::sorts::sort::sort_batch;
@@ -491,13 +491,9 @@ impl GroupedHashAggregateStream {
             agg_group_by.num_group_exprs(),
         )?;
 
-        let filter_expressions = match agg.mode {
-            AggregateMode::Partial
-            | AggregateMode::Single
-            | AggregateMode::SinglePartitioned => agg_filter_expr,
-            AggregateMode::Final | AggregateMode::FinalPartitioned => {
-                vec![None; agg.aggr_expr.len()]
-            }
+        let filter_expressions = match agg.mode.input_mode() {
+            AggregateInputMode::Raw => agg_filter_expr,
+            AggregateInputMode::Partial => vec![None; agg.aggr_expr.len()],
         };
 
         // Instantiate the accumulators
@@ -982,29 +978,24 @@ impl GroupedHashAggregateStream {
 
                 // Call the appropriate method on each aggregator with
                 // the entire input row and the relevant group indexes
-                match self.mode {
-                    AggregateMode::Partial
-                    | AggregateMode::Single
-                    | AggregateMode::SinglePartitioned
-                        if !self.spill_state.is_stream_merging =>
-                    {
-                        acc.update_batch(
-                            values,
-                            group_indices,
-                            opt_filter,
-                            total_num_groups,
-                        )?;
-                    }
-                    _ => {
-                        assert_or_internal_err!(
-                            opt_filter.is_none(),
-                            "aggregate filter should be applied in partial 
stage, there should be no filter in final stage"
-                        );
-
-                        // if aggregation is over intermediate states,
-                        // use merge
-                        acc.merge_batch(values, group_indices, None, 
total_num_groups)?;
-                    }
+                if self.mode.input_mode() == AggregateInputMode::Raw
+                    && !self.spill_state.is_stream_merging
+                {
+                    acc.update_batch(
+                        values,
+                        group_indices,
+                        opt_filter,
+                        total_num_groups,
+                    )?;
+                } else {
+                    assert_or_internal_err!(
+                        opt_filter.is_none(),
+                        "aggregate filter should be applied in partial stage, 
there should be no filter in final stage"
+                    );
+
+                    // if aggregation is over intermediate states,
+                    // use merge
+                    acc.merge_batch(values, group_indices, None, 
total_num_groups)?;
                 }
                 self.group_by_metrics
                     .aggregation_time
@@ -1092,17 +1083,12 @@ impl GroupedHashAggregateStream {
 
         // Next output each aggregate value
         for acc in self.accumulators.iter_mut() {
-            match self.mode {
-                AggregateMode::Partial => output.extend(acc.state(emit_to)?),
-                _ if spilling => {
-                    // If spilling, output partial state because the spilled 
data will be
-                    // merged and re-evaluated later.
-                    output.extend(acc.state(emit_to)?)
-                }
-                AggregateMode::Final
-                | AggregateMode::FinalPartitioned
-                | AggregateMode::Single
-                | AggregateMode::SinglePartitioned => 
output.push(acc.evaluate(emit_to)?),
+            if self.mode.output_mode() == AggregateOutputMode::Final && 
!spilling {
+                output.push(acc.evaluate(emit_to)?)
+            } else {
+                // Output partial state: either because we're in a non-final 
mode,
+                // or because we're spilling and will merge/re-evaluate later.
+                output.extend(acc.state(emit_to)?)
             }
         }
         drop(timer);
diff --git a/datafusion/proto/proto/datafusion.proto 
b/datafusion/proto/proto/datafusion.proto
index 810ec6d1f1..59be5c5787 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1205,6 +1205,7 @@ enum AggregateMode {
   FINAL_PARTITIONED = 2;
   SINGLE = 3;
   SINGLE_PARTITIONED = 4;
+  PARTIAL_REDUCE = 5;
 }
 
 message PartiallySortedInputOrderMode {
diff --git a/datafusion/proto/src/generated/pbjson.rs 
b/datafusion/proto/src/generated/pbjson.rs
index 7ed20785ab..3873afcdce 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -410,6 +410,7 @@ impl serde::Serialize for AggregateMode {
             Self::FinalPartitioned => "FINAL_PARTITIONED",
             Self::Single => "SINGLE",
             Self::SinglePartitioned => "SINGLE_PARTITIONED",
+            Self::PartialReduce => "PARTIAL_REDUCE",
         };
         serializer.serialize_str(variant)
     }
@@ -426,6 +427,7 @@ impl<'de> serde::Deserialize<'de> for AggregateMode {
             "FINAL_PARTITIONED",
             "SINGLE",
             "SINGLE_PARTITIONED",
+            "PARTIAL_REDUCE",
         ];
 
         struct GeneratedVisitor;
@@ -471,6 +473,7 @@ impl<'de> serde::Deserialize<'de> for AggregateMode {
                     "FINAL_PARTITIONED" => Ok(AggregateMode::FinalPartitioned),
                     "SINGLE" => Ok(AggregateMode::Single),
                     "SINGLE_PARTITIONED" => 
Ok(AggregateMode::SinglePartitioned),
+                    "PARTIAL_REDUCE" => Ok(AggregateMode::PartialReduce),
                     _ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
                 }
             }
diff --git a/datafusion/proto/src/generated/prost.rs 
b/datafusion/proto/src/generated/prost.rs
index 0c9320c778..3806e31a46 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2382,6 +2382,7 @@ pub enum AggregateMode {
     FinalPartitioned = 2,
     Single = 3,
     SinglePartitioned = 4,
+    PartialReduce = 5,
 }
 impl AggregateMode {
     /// String value of the enum field names used in the ProtoBuf definition.
@@ -2395,6 +2396,7 @@ impl AggregateMode {
             Self::FinalPartitioned => "FINAL_PARTITIONED",
             Self::Single => "SINGLE",
             Self::SinglePartitioned => "SINGLE_PARTITIONED",
+            Self::PartialReduce => "PARTIAL_REDUCE",
         }
     }
     /// Creates an enum from field names used in the ProtoBuf definition.
@@ -2405,6 +2407,7 @@ impl AggregateMode {
             "FINAL_PARTITIONED" => Some(Self::FinalPartitioned),
             "SINGLE" => Some(Self::Single),
             "SINGLE_PARTITIONED" => Some(Self::SinglePartitioned),
+            "PARTIAL_REDUCE" => Some(Self::PartialReduce),
             _ => None,
         }
     }
diff --git a/datafusion/proto/src/physical_plan/mod.rs 
b/datafusion/proto/src/physical_plan/mod.rs
index ca213bc722..e1f6381d1f 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -1086,6 +1086,7 @@ impl protobuf::PhysicalPlanNode {
             protobuf::AggregateMode::SinglePartitioned => {
                 AggregateMode::SinglePartitioned
             }
+            protobuf::AggregateMode::PartialReduce => 
AggregateMode::PartialReduce,
         };
 
         let num_expr = hash_agg.group_expr.len();
@@ -2677,6 +2678,7 @@ impl protobuf::PhysicalPlanNode {
             AggregateMode::SinglePartitioned => {
                 protobuf::AggregateMode::SinglePartitioned
             }
+            AggregateMode::PartialReduce => 
protobuf::AggregateMode::PartialReduce,
         };
         let input_schema = exec.input_schema();
         let input = 
protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to