This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 7c3ea0540c feat: add AggregateMode::PartialReduce for tree-reduce aggregation (#20019)
7c3ea0540c is described below
commit 7c3ea0540ca8793eba44b16d2a62e9cff02e3a8f
Author: Nathaniel J. Smith <[email protected]>
AuthorDate: Fri Jan 30 01:54:50 2026 -0800
feat: add AggregateMode::PartialReduce for tree-reduce aggregation (#20019)
DataFusion's current `AggregateMode` enum has five variants, covering
three of the four cells in the input/output matrix:
| | Input: raw data | Input: partial state |
| - | - | - |
| Output: final values | `Single` / `SinglePartitioned` | `Final` / `FinalPartitioned` |
| Output: partial state | `Partial` | ??? |
This PR adds `AggregateMode::PartialReduce` to fill in the missing cell:
it takes partially-reduced values as input, and reduces them further,
but without finalizing.
This is useful because it is the key component needed to implement
distributed tree-reduction (as seen in e.g. the Scuba or Honeycomb
papers): a set of worker nodes each perform multithreaded `Partial`
aggregations and feed them into a `PartialReduce` that reduces that
node's values into a single row; a head node then collects the outputs
from all nodes' `PartialReduce` stages and feeds them into a `Final`
reduction (see the sketch below).
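As a rough illustration, here is how a worker node might wrap its
`Partial` output in the new mode. The `reduce_on_worker` helper is
hypothetical (not part of this PR); the `AggregateExec::try_new` call
mirrors the test added below:

```rust
use std::sync::Arc;

use datafusion_common::Result;
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datafusion_physical_plan::ExecutionPlan;
use datafusion_physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};

/// Hypothetical helper: add the node-local PartialReduce stage on top of a
/// worker's Partial aggregation. Both its input and its output use the
/// partial *state* schema, so stages like this can be stacked into a tree.
fn reduce_on_worker(
    partial: Arc<dyn ExecutionPlan>, // output of an AggregateMode::Partial stage
    group_by: PhysicalGroupBy,
    aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
) -> Result<Arc<dyn ExecutionPlan>> {
    // No per-aggregate filters when merging intermediate state
    let filters = vec![None; aggr_expr.len()];
    let input_schema = partial.schema();
    Ok(Arc::new(AggregateExec::try_new(
        AggregateMode::PartialReduce,
        group_by,
        aggr_expr,
        filters,
        partial,
        input_schema,
    )?))
}
```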
The PR can be reviewed commit by commit. The first commit is a pure
refactor/simplification: in most places where we matched on
`AggregateMode`, we were really just checking which column or which row
of the above table we were in. So now we have `input_mode()` (tells you
which column) and `output_mode()` (tells you which row), and we use
them everywhere, as sketched below.
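For instance, the two axes determine which `Accumulator` methods a
stage calls. A minimal sketch (the `accumulator_methods` helper is
mine, not part of the PR; the mappings match the `match` arms in the
diff):

```rust
use datafusion_physical_plan::aggregates::{
    AggregateInputMode, AggregateMode, AggregateOutputMode,
};

/// Which Accumulator methods a given mode implies (illustrative only).
fn accumulator_methods(mode: AggregateMode) -> (&'static str, &'static str) {
    let consume = match mode.input_mode() {
        AggregateInputMode::Raw => "update_batch",    // raw rows in
        AggregateInputMode::Partial => "merge_batch", // partial state in
    };
    let produce = match mode.output_mode() {
        AggregateOutputMode::Partial => "state",  // partial state out
        AggregateOutputMode::Final => "evaluate", // final values out
    };
    (consume, produce)
}

fn main() {
    // The new mode merges state in and emits state back out.
    assert_eq!(
        accumulator_methods(AggregateMode::PartialReduce),
        ("merge_batch", "state")
    );
}
```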
The second commit adds `PartialReduce`, and is pretty small because
`input_mode()`/`output_mode()` do most of the heavy lifting. It also
adds a test demonstrating a minimal Partial -> PartialReduce -> Final
tree-reduction.
---
.../physical-optimizer/src/aggregate_statistics.rs | 6 +-
.../physical-optimizer/src/update_aggr_exprs.rs | 6 +-
datafusion/physical-plan/src/aggregates/mod.rs | 294 ++++++++++++++++++---
.../physical-plan/src/aggregates/no_grouping.rs | 25 +-
.../physical-plan/src/aggregates/row_hash.rs | 72 ++---
datafusion/proto/proto/datafusion.proto | 1 +
datafusion/proto/src/generated/pbjson.rs | 3 +
datafusion/proto/src/generated/prost.rs | 3 +
datafusion/proto/src/physical_plan/mod.rs | 2 +
9 files changed, 304 insertions(+), 108 deletions(-)
diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs b/datafusion/physical-optimizer/src/aggregate_statistics.rs
index cf3c15509c..5caee8b047 100644
--- a/datafusion/physical-optimizer/src/aggregate_statistics.rs
+++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs
@@ -20,7 +20,7 @@ use datafusion_common::Result;
use datafusion_common::config::ConfigOptions;
use datafusion_common::scalar::ScalarValue;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
-use datafusion_physical_plan::aggregates::AggregateExec;
+use datafusion_physical_plan::aggregates::{AggregateExec, AggregateInputMode};
use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr};
use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs};
@@ -116,13 +116,13 @@ impl PhysicalOptimizerRule for AggregateStatistics {
/// the `ExecutionPlan.children()` method that returns an owned reference.
fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>> {
if let Some(final_agg_exec) = node.as_any().downcast_ref::<AggregateExec>()
- && !final_agg_exec.mode().is_first_stage()
+ && final_agg_exec.mode().input_mode() == AggregateInputMode::Partial
&& final_agg_exec.group_expr().is_empty()
{
let mut child = Arc::clone(final_agg_exec.input());
loop {
if let Some(partial_agg_exec) =
child.as_any().downcast_ref::<AggregateExec>()
- && partial_agg_exec.mode().is_first_stage()
+ && partial_agg_exec.mode().input_mode() == AggregateInputMode::Raw
&& partial_agg_exec.group_expr().is_empty()
&& partial_agg_exec.filter_expr().iter().all(|e| e.is_none())
{
diff --git a/datafusion/physical-optimizer/src/update_aggr_exprs.rs b/datafusion/physical-optimizer/src/update_aggr_exprs.rs
index c0aab4080d..67127c2a23 100644
--- a/datafusion/physical-optimizer/src/update_aggr_exprs.rs
+++ b/datafusion/physical-optimizer/src/update_aggr_exprs.rs
@@ -25,7 +25,9 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{Result, plan_datafusion_err};
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement};
-use datafusion_physical_plan::aggregates::{AggregateExec, concat_slices};
+use datafusion_physical_plan::aggregates::{
+ AggregateExec, AggregateInputMode, concat_slices,
+};
use datafusion_physical_plan::windows::get_ordered_partition_by_indices;
use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};
@@ -81,7 +83,7 @@ impl PhysicalOptimizerRule for OptimizeAggregateOrder {
// ordering fields may be pruned out by first stage aggregates.
// Hence, necessary information for proper merge is added during
// the first stage to the state field, which the final stage uses.
- if !aggr_exec.mode().is_first_stage() {
+ if aggr_exec.mode().input_mode() == AggregateInputMode::Partial {
return Ok(Transformed::no(plan));
}
let input = aggr_exec.input();
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index d645f5c55d..aa0f5a236c 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -89,10 +89,54 @@ pub fn topk_types_supported(key_type: &DataType, value_type: &DataType) -> bool
const AGGREGATION_HASH_SEED: ahash::RandomState =
ahash::RandomState::with_seeds('A' as u64, 'G' as u64, 'G' as u64, 'R' as u64);
+/// Whether an aggregate stage consumes raw input data or intermediate
+/// accumulator state from a previous aggregation stage.
+///
+/// See the [table on `AggregateMode`](AggregateMode#variants-and-their-inputoutput-modes)
+/// for how this relates to aggregate modes.
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum AggregateInputMode {
+ /// The stage consumes raw, unaggregated input data and calls
+ /// [`Accumulator::update_batch`].
+ Raw,
+ /// The stage consumes intermediate accumulator state from a previous
+ /// aggregation stage and calls [`Accumulator::merge_batch`].
+ Partial,
+}
+
+/// Whether an aggregate stage produces intermediate accumulator state
+/// or final output values.
+///
+/// See the [table on `AggregateMode`](AggregateMode#variants-and-their-inputoutput-modes)
+/// for how this relates to aggregate modes.
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum AggregateOutputMode {
+ /// The stage produces intermediate accumulator state, serialized via
+ /// [`Accumulator::state`].
+ Partial,
+ /// The stage produces final output values via
+ /// [`Accumulator::evaluate`].
+ Final,
+}
+
/// Aggregation modes
///
/// See [`Accumulator::state`] for background information on multi-phase
/// aggregation and how these modes are used.
+///
+/// # Variants and their input/output modes
+///
+/// Each variant can be characterized by its [`AggregateInputMode`] and
+/// [`AggregateOutputMode`]:
+///
+/// ```text
+///                       | Input: Raw data           | Input: Partial state
+/// Output: Final values  | Single, SinglePartitioned | Final, FinalPartitioned
+/// Output: Partial state | Partial                   | PartialReduce
+/// ```
+///
+/// Use [`AggregateMode::input_mode`] and [`AggregateMode::output_mode`]
+/// to query these properties.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AggregateMode {
/// One of multiple layers of aggregation, any input partitioning
@@ -144,18 +188,56 @@ pub enum AggregateMode {
/// This mode requires that the input has more than one partition, and is
/// partitioned by group key (like FinalPartitioned).
SinglePartitioned,
+ /// Combine multiple partial aggregations to produce a new partial
+ /// aggregation.
+ ///
+ /// Input is intermediate accumulator state (like Final), but output is
+ /// also intermediate accumulator state (like Partial). This enables
+ /// tree-reduce aggregation strategies where partial results from
+ /// multiple workers are combined in multiple stages before a final
+ /// evaluation.
+ ///
+ /// ```text
+ ///                Final
+ ///               /     \
+ ///  PartialReduce       PartialReduce
+ ///    /       \           /       \
+ /// Partial  Partial   Partial  Partial
+ /// ```
+ PartialReduce,
}
impl AggregateMode {
- /// Checks whether this aggregation step describes a "first stage" calculation.
- /// In other words, its input is not another aggregation result and the
- /// `merge_batch` method will not be called for these modes.
- pub fn is_first_stage(&self) -> bool {
+ /// Returns the [`AggregateInputMode`] for this mode: whether this
+ /// stage consumes raw input data or intermediate accumulator state.
+ ///
+ /// See the [table above](AggregateMode#variants-and-their-inputoutput-modes)
+ /// for details.
+ pub fn input_mode(&self) -> AggregateInputMode {
match self {
AggregateMode::Partial
| AggregateMode::Single
- | AggregateMode::SinglePartitioned => true,
- AggregateMode::Final | AggregateMode::FinalPartitioned => false,
+ | AggregateMode::SinglePartitioned => AggregateInputMode::Raw,
+ AggregateMode::Final
+ | AggregateMode::FinalPartitioned
+ | AggregateMode::PartialReduce => AggregateInputMode::Partial,
+ }
+ }
+
+ /// Returns the [`AggregateOutputMode`] for this mode: whether this
+ /// stage produces intermediate accumulator state or final output values.
+ ///
+ /// See the [table above](AggregateMode#variants-and-their-inputoutput-modes)
+ /// for details.
+ pub fn output_mode(&self) -> AggregateOutputMode {
+ match self {
+ AggregateMode::Final
+ | AggregateMode::FinalPartitioned
+ | AggregateMode::Single
+ | AggregateMode::SinglePartitioned => AggregateOutputMode::Final,
+ AggregateMode::Partial | AggregateMode::PartialReduce => {
+ AggregateOutputMode::Partial
+ }
}
}
}
@@ -917,14 +999,15 @@ impl AggregateExec {
// Get output partitioning:
let input_partitioning = input.output_partitioning().clone();
- let output_partitioning = if mode.is_first_stage() {
- // First stage aggregation will not change the output partitioning,
- // but needs to respect aliases (e.g. mapping in the GROUP BY
- // expression).
- let input_eq_properties = input.equivalence_properties();
- input_partitioning.project(group_expr_mapping, input_eq_properties)
- } else {
- input_partitioning.clone()
+ let output_partitioning = match mode.input_mode() {
+ AggregateInputMode::Raw => {
+ // First stage aggregation will not change the output partitioning,
+ // but needs to respect aliases (e.g. mapping in the GROUP BY
+ // expression).
+ let input_eq_properties = input.equivalence_properties();
+ input_partitioning.project(group_expr_mapping, input_eq_properties)
+ }
+ AggregateInputMode::Partial => input_partitioning.clone(),
};
// TODO: Emission type and boundedness information can be enhanced here
@@ -1248,7 +1331,7 @@ impl ExecutionPlan for AggregateExec {
fn required_input_distribution(&self) -> Vec<Distribution> {
match &self.mode {
- AggregateMode::Partial => {
+ AggregateMode::Partial | AggregateMode::PartialReduce => {
vec![Distribution::UnspecifiedDistribution]
}
AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => {
@@ -1477,20 +1560,17 @@ fn create_schema(
let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len());
fields.extend(group_by.output_fields(input_schema)?);
- match mode {
- AggregateMode::Partial => {
- // in partial mode, the fields of the accumulator's state
+ match mode.output_mode() {
+ AggregateOutputMode::Final => {
+ // in final mode, the field with the final result of the accumulator
for expr in aggr_expr {
- fields.extend(expr.state_fields()?.iter().cloned());
+ fields.push(expr.field())
}
}
- AggregateMode::Final
- | AggregateMode::FinalPartitioned
- | AggregateMode::Single
- | AggregateMode::SinglePartitioned => {
- // in final mode, the field with the final result of the accumulator
+ AggregateOutputMode::Partial => {
+ // in partial mode, the fields of the accumulator's state
for expr in aggr_expr {
- fields.push(expr.field())
+ fields.extend(expr.state_fields()?.iter().cloned());
}
}
}
@@ -1530,7 +1610,7 @@ fn get_aggregate_expr_req(
// If the aggregation is performing a "second stage" calculation,
// then ignore the ordering requirement. Ordering requirement applies
// only to the aggregation input data.
- if !agg_mode.is_first_stage() {
+ if agg_mode.input_mode() == AggregateInputMode::Partial {
return None;
}
@@ -1696,10 +1776,8 @@ pub fn aggregate_expressions(
mode: &AggregateMode,
col_idx_base: usize,
) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
- match mode {
- AggregateMode::Partial
- | AggregateMode::Single
- | AggregateMode::SinglePartitioned => Ok(aggr_expr
+ match mode.input_mode() {
+ AggregateInputMode::Raw => Ok(aggr_expr
.iter()
.map(|agg| {
let mut result = agg.expressions();
@@ -1710,8 +1788,8 @@ pub fn aggregate_expressions(
result
})
.collect()),
- // In this mode, we build the merge expressions of the aggregation.
- AggregateMode::Final | AggregateMode::FinalPartitioned => {
+ AggregateInputMode::Partial => {
+ // In merge mode, we build the merge expressions of the aggregation.
let mut col_idx_base = col_idx_base;
aggr_expr
.iter()
@@ -1759,8 +1837,15 @@ pub fn finalize_aggregation(
accumulators: &mut [AccumulatorItem],
mode: &AggregateMode,
) -> Result<Vec<ArrayRef>> {
- match mode {
- AggregateMode::Partial => {
+ match mode.output_mode() {
+ AggregateOutputMode::Final => {
+ // Merge the state to the final value
+ accumulators
+ .iter_mut()
+ .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array()))
+ .collect()
+ }
+ AggregateOutputMode::Partial => {
// Build the vector of states
accumulators
.iter_mut()
@@ -1774,16 +1859,6 @@ pub fn finalize_aggregation(
.flatten_ok()
.collect()
}
- AggregateMode::Final
- | AggregateMode::FinalPartitioned
- | AggregateMode::Single
- | AggregateMode::SinglePartitioned => {
- // Merge the state to the final value
- accumulators
- .iter_mut()
- .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array()))
- .collect()
- }
}
}
@@ -3745,4 +3820,135 @@ mod tests {
}
Ok(())
}
+
+ /// Tests that PartialReduce mode:
+ /// 1. Accepts state as input (like Final)
+ /// 2. Produces state as output (like Partial)
+ /// 3. Can be followed by a Final stage to get the correct result
+ ///
+ /// This simulates a tree-reduce pattern:
+ /// Partial -> PartialReduce -> Final
+ #[tokio::test]
+ async fn test_partial_reduce_mode() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::UInt32, false),
+ Field::new("b", DataType::Float64, false),
+ ]));
+
+ // Produce two partitions of input data
+ let batch1 = RecordBatch::try_new(
+ Arc::clone(&schema),
+ vec![
+ Arc::new(UInt32Array::from(vec![1, 2, 3])),
+ Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
+ ],
+ )?;
+ let batch2 = RecordBatch::try_new(
+ Arc::clone(&schema),
+ vec![
+ Arc::new(UInt32Array::from(vec![1, 2, 3])),
+ Arc::new(Float64Array::from(vec![40.0, 50.0, 60.0])),
+ ],
+ )?;
+
+ let groups =
PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
+ let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
+ AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?])
+ .schema(Arc::clone(&schema))
+ .alias("SUM(b)")
+ .build()?,
+ )];
+
+ // Step 1: Partial aggregation on partition 1
+ let input1 =
TestMemoryExec::try_new_exec(&[vec![batch1]], Arc::clone(&schema), None)?;
+ let partial1 = Arc::new(AggregateExec::try_new(
+ AggregateMode::Partial,
+ groups.clone(),
+ aggregates.clone(),
+ vec![None],
+ input1,
+ Arc::clone(&schema),
+ )?);
+
+ // Step 2: Partial aggregation on partition 2
+ let input2 =
TestMemoryExec::try_new_exec(&[vec![batch2]], Arc::clone(&schema), None)?;
+ let partial2 = Arc::new(AggregateExec::try_new(
+ AggregateMode::Partial,
+ groups.clone(),
+ aggregates.clone(),
+ vec![None],
+ input2,
+ Arc::clone(&schema),
+ )?);
+
+ // Collect partial results
+ let task_ctx = Arc::new(TaskContext::default());
+ let partial_result1 =
crate::collect(Arc::clone(&partial1) as _, Arc::clone(&task_ctx)).await?;
+ let partial_result2 =
crate::collect(Arc::clone(&partial2) as _, Arc::clone(&task_ctx)).await?;
+
+ // The partial results have state schema (group cols + accumulator state)
+ let partial_schema = partial1.schema();
+
+ // Step 3: PartialReduce — combine partial results, still producing state
+ let combined_input = TestMemoryExec::try_new_exec(
+ &[partial_result1, partial_result2],
+ Arc::clone(&partial_schema),
+ None,
+ )?;
+ // Coalesce into a single partition for the PartialReduce
+ let coalesced = Arc::new(CoalescePartitionsExec::new(combined_input));
+
+ let partial_reduce = Arc::new(AggregateExec::try_new(
+ AggregateMode::PartialReduce,
+ groups.clone(),
+ aggregates.clone(),
+ vec![None],
+ coalesced,
+ Arc::clone(&partial_schema),
+ )?);
+
+ // Verify PartialReduce output schema matches Partial output schema
+ // (both produce state, not final values)
+ assert_eq!(partial_reduce.schema(), partial_schema);
+
+ // Collect PartialReduce results
+ let reduce_result =
+ crate::collect(Arc::clone(&partial_reduce) as _, Arc::clone(&task_ctx))
+ .await?;
+
+ // Step 4: Final aggregation on the PartialReduce output
+ let final_input = TestMemoryExec::try_new_exec(
+ &[reduce_result],
+ Arc::clone(&partial_schema),
+ None,
+ )?;
+ let final_agg = Arc::new(AggregateExec::try_new(
+ AggregateMode::Final,
+ groups.clone(),
+ aggregates.clone(),
+ vec![None],
+ final_input,
+ Arc::clone(&partial_schema),
+ )?);
+
+ let result = crate::collect(final_agg, Arc::clone(&task_ctx)).await?;
+
+ // Expected: group 1 -> 10+40=50, group 2 -> 20+50=70, group 3 -> 30+60=90
+ assert_snapshot!(batches_to_sort_string(&result), @r"
+ +---+--------+
+ | a | SUM(b) |
+ +---+--------+
+ | 1 | 50.0 |
+ | 2 | 70.0 |
+ | 3 | 90.0 |
+ +---+--------+
+ ");
+
+ Ok(())
+ }
}
diff --git a/datafusion/physical-plan/src/aggregates/no_grouping.rs b/datafusion/physical-plan/src/aggregates/no_grouping.rs
index a55d70ca6f..eb9b6766ab 100644
--- a/datafusion/physical-plan/src/aggregates/no_grouping.rs
+++ b/datafusion/physical-plan/src/aggregates/no_grouping.rs
@@ -18,8 +18,9 @@
//! Aggregate without grouping columns
use crate::aggregates::{
- AccumulatorItem, AggrDynFilter, AggregateMode, DynamicFilterAggregateType,
- aggregate_expressions, create_accumulators, finalize_aggregation,
+ AccumulatorItem, AggrDynFilter, AggregateInputMode, AggregateMode,
+ DynamicFilterAggregateType, aggregate_expressions, create_accumulators,
+ finalize_aggregation,
};
use crate::metrics::{BaselineMetrics, RecordOutput};
use crate::{RecordBatchStream, SendableRecordBatchStream};
@@ -282,13 +283,9 @@ impl AggregateStream {
let input = agg.input.execute(partition, Arc::clone(context))?;
let aggregate_expressions = aggregate_expressions(&agg.aggr_expr, &agg.mode, 0)?;
- let filter_expressions = match agg.mode {
- AggregateMode::Partial
- | AggregateMode::Single
- | AggregateMode::SinglePartitioned => agg_filter_expr,
- AggregateMode::Final | AggregateMode::FinalPartitioned => {
- vec![None; agg.aggr_expr.len()]
- }
+ let filter_expressions = match agg.mode.input_mode() {
+ AggregateInputMode::Raw => agg_filter_expr,
+ AggregateInputMode::Partial => vec![None; agg.aggr_expr.len()],
};
let accumulators = create_accumulators(&agg.aggr_expr)?;
@@ -455,13 +452,9 @@ fn aggregate_batch(
// 1.4
let size_pre = accum.size();
- let res = match mode {
- AggregateMode::Partial
- | AggregateMode::Single
- | AggregateMode::SinglePartitioned => accum.update_batch(&values),
- AggregateMode::Final | AggregateMode::FinalPartitioned => {
- accum.merge_batch(&values)
- }
+ let res = match mode.input_mode() {
+ AggregateInputMode::Raw => accum.update_batch(&values),
+ AggregateInputMode::Partial => accum.merge_batch(&values),
};
let size_post = accum.size();
allocated += size_post.saturating_sub(size_pre);
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 49ce125e73..b2cf396b15 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -26,8 +26,8 @@ use super::order::GroupOrdering;
use crate::aggregates::group_values::{GroupByMetrics, GroupValues, new_group_values};
use crate::aggregates::order::GroupOrderingFull;
use crate::aggregates::{
- AggregateMode, PhysicalGroupBy, create_schema, evaluate_group_by, evaluate_many,
- evaluate_optional,
+ AggregateInputMode, AggregateMode, AggregateOutputMode, PhysicalGroupBy,
+ create_schema, evaluate_group_by, evaluate_many, evaluate_optional,
};
use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput};
use crate::sorts::sort::sort_batch;
@@ -491,13 +491,9 @@ impl GroupedHashAggregateStream {
agg_group_by.num_group_exprs(),
)?;
- let filter_expressions = match agg.mode {
- AggregateMode::Partial
- | AggregateMode::Single
- | AggregateMode::SinglePartitioned => agg_filter_expr,
- AggregateMode::Final | AggregateMode::FinalPartitioned => {
- vec![None; agg.aggr_expr.len()]
- }
+ let filter_expressions = match agg.mode.input_mode() {
+ AggregateInputMode::Raw => agg_filter_expr,
+ AggregateInputMode::Partial => vec![None; agg.aggr_expr.len()],
};
// Instantiate the accumulators
@@ -982,29 +978,24 @@ impl GroupedHashAggregateStream {
// Call the appropriate method on each aggregator with
// the entire input row and the relevant group indexes
- match self.mode {
- AggregateMode::Partial
- | AggregateMode::Single
- | AggregateMode::SinglePartitioned
- if !self.spill_state.is_stream_merging =>
- {
- acc.update_batch(
- values,
- group_indices,
- opt_filter,
- total_num_groups,
- )?;
- }
- _ => {
- assert_or_internal_err!(
- opt_filter.is_none(),
- "aggregate filter should be applied in partial
stage, there should be no filter in final stage"
- );
-
- // if aggregation is over intermediate states,
- // use merge
- acc.merge_batch(values, group_indices, None, total_num_groups)?;
- }
+ if self.mode.input_mode() == AggregateInputMode::Raw
+ && !self.spill_state.is_stream_merging
+ {
+ acc.update_batch(
+ values,
+ group_indices,
+ opt_filter,
+ total_num_groups,
+ )?;
+ } else {
+ assert_or_internal_err!(
+ opt_filter.is_none(),
+ "aggregate filter should be applied in partial stage,
there should be no filter in final stage"
+ );
+
+ // if aggregation is over intermediate states,
+ // use merge
+ acc.merge_batch(values, group_indices, None, total_num_groups)?;
}
self.group_by_metrics
.aggregation_time
@@ -1092,17 +1083,12 @@ impl GroupedHashAggregateStream {
// Next output each aggregate value
for acc in self.accumulators.iter_mut() {
- match self.mode {
- AggregateMode::Partial => output.extend(acc.state(emit_to)?),
- _ if spilling => {
- // If spilling, output partial state because the spilled data will be
- // merged and re-evaluated later.
- output.extend(acc.state(emit_to)?)
- }
- AggregateMode::Final
- | AggregateMode::FinalPartitioned
- | AggregateMode::Single
- | AggregateMode::SinglePartitioned => output.push(acc.evaluate(emit_to)?),
+ if self.mode.output_mode() == AggregateOutputMode::Final && !spilling {
+ output.push(acc.evaluate(emit_to)?)
+ } else {
+ // Output partial state: either because we're in a non-final mode,
+ // or because we're spilling and will merge/re-evaluate later.
+ output.extend(acc.state(emit_to)?)
}
}
drop(timer);
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 810ec6d1f1..59be5c5787 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1205,6 +1205,7 @@ enum AggregateMode {
FINAL_PARTITIONED = 2;
SINGLE = 3;
SINGLE_PARTITIONED = 4;
+ PARTIAL_REDUCE = 5;
}
message PartiallySortedInputOrderMode {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 7ed20785ab..3873afcdce 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -410,6 +410,7 @@ impl serde::Serialize for AggregateMode {
Self::FinalPartitioned => "FINAL_PARTITIONED",
Self::Single => "SINGLE",
Self::SinglePartitioned => "SINGLE_PARTITIONED",
+ Self::PartialReduce => "PARTIAL_REDUCE",
};
serializer.serialize_str(variant)
}
@@ -426,6 +427,7 @@ impl<'de> serde::Deserialize<'de> for AggregateMode {
"FINAL_PARTITIONED",
"SINGLE",
"SINGLE_PARTITIONED",
+ "PARTIAL_REDUCE",
];
struct GeneratedVisitor;
@@ -471,6 +473,7 @@ impl<'de> serde::Deserialize<'de> for AggregateMode {
"FINAL_PARTITIONED" => Ok(AggregateMode::FinalPartitioned),
"SINGLE" => Ok(AggregateMode::Single),
"SINGLE_PARTITIONED" =>
Ok(AggregateMode::SinglePartitioned),
+ "PARTIAL_REDUCE" => Ok(AggregateMode::PartialReduce),
_ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
}
}
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 0c9320c778..3806e31a46 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2382,6 +2382,7 @@ pub enum AggregateMode {
FinalPartitioned = 2,
Single = 3,
SinglePartitioned = 4,
+ PartialReduce = 5,
}
impl AggregateMode {
/// String value of the enum field names used in the ProtoBuf definition.
@@ -2395,6 +2396,7 @@ impl AggregateMode {
Self::FinalPartitioned => "FINAL_PARTITIONED",
Self::Single => "SINGLE",
Self::SinglePartitioned => "SINGLE_PARTITIONED",
+ Self::PartialReduce => "PARTIAL_REDUCE",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
@@ -2405,6 +2407,7 @@ impl AggregateMode {
"FINAL_PARTITIONED" => Some(Self::FinalPartitioned),
"SINGLE" => Some(Self::Single),
"SINGLE_PARTITIONED" => Some(Self::SinglePartitioned),
+ "PARTIAL_REDUCE" => Some(Self::PartialReduce),
_ => None,
}
}
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index ca213bc722..e1f6381d1f 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -1086,6 +1086,7 @@ impl protobuf::PhysicalPlanNode {
protobuf::AggregateMode::SinglePartitioned => {
AggregateMode::SinglePartitioned
}
+ protobuf::AggregateMode::PartialReduce => AggregateMode::PartialReduce,
};
let num_expr = hash_agg.group_expr.len();
@@ -2677,6 +2678,7 @@ impl protobuf::PhysicalPlanNode {
AggregateMode::SinglePartitioned => {
protobuf::AggregateMode::SinglePartitioned
}
+ AggregateMode::PartialReduce => protobuf::AggregateMode::PartialReduce,
};
let input_schema = exec.input_schema();
let input =
protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter(