This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new eb30c19b30 Implement disk spilling for all grouping ordering modes in GroupedHashAggregateStream (#19287)
eb30c19b30 is described below

commit eb30c19b303e22ab2ab104e0fa67ff8f511261f2
Author: Pepijn Van Eeckhoudt <[email protected]>
AuthorDate: Sat Dec 20 12:22:25 2025 +0100

    Implement disk spilling for all grouping ordering modes in GroupedHashAggregateStream (#19287)
    
    ## Which issue does this PR close?
    
    - Closes #19286.
    - Related to #13123
    
    ## Rationale for this change
    
    GroupedHashAggregateStream currently always reports that it can spill to
    the memory tracking subsystem even though this is dependent on the
    aggregation mode and the grouping order.
    The optimistic logic in `group_aggregate_batch` does not correctly take
    the spilling preconditions into account, which can lead to excessive
    memory use.
    
    In order to resolve this, this PR implements disk spilling for all
    grouping modes.
    
    ## What changes are included in this PR?
    
    - Correctly set `MemoryConsumer::can_spill` to reflect actual spilling
    behaviour
    - Ensure optimistic out-of-memory tolerance in `group_aggregate_batch`
    is aligned with disk spilling or early emission logic
    - Implement disk spilling that respects the output order for partially
    and fully sorted inputs.
    
    ## Are these changes tested?
    
    Added an additional test case to demonstrate the problem.
    Added a test case to check that the output order is respected after spilling.
    
    ## Are there any user-facing changes?
    
    Yes, memory exhaustion may be reported much earlier in the query
    pipeline than is currently the case. In my local tests with a per-consumer
    memory limit of 32 MiB, grouped aggregation would consume 480 MiB
    in practice. The exhaustion was then only reported by ExternalSortExec,
    which choked when trying to reserve that much memory.
---
 datafusion/physical-expr-common/src/sort_expr.rs   |  11 +
 .../src/aggregates/group_values/mod.rs             |   4 +-
 .../aggregates/group_values/multi_group_by/mod.rs  |   9 +-
 .../src/aggregates/group_values/row.rs             |   9 +-
 .../group_values/single_group_by/boolean.rs        |   3 +-
 .../group_values/single_group_by/bytes.rs          |   4 +-
 .../group_values/single_group_by/bytes_view.rs     |   4 +-
 .../group_values/single_group_by/primitive.rs      |   8 +-
 datafusion/physical-plan/src/aggregates/mod.rs     | 214 ++++++++++++-
 .../physical-plan/src/aggregates/row_hash.rs       | 353 +++++++++++----------
 datafusion/physical-plan/src/test.rs               |   1 +
 11 files changed, 423 insertions(+), 197 deletions(-)

diff --git a/datafusion/physical-expr-common/src/sort_expr.rs 
b/datafusion/physical-expr-common/src/sort_expr.rs
index e8558c7643..db30dd6ed2 100644
--- a/datafusion/physical-expr-common/src/sort_expr.rs
+++ b/datafusion/physical-expr-common/src/sort_expr.rs
@@ -457,6 +457,17 @@ impl LexOrdering {
             req.expr.eq(&cur.expr) && is_reversed_sort_options(&req.options, 
&cur.options)
         })
     }
+
+    /// Returns the sort options for the given expression if one is defined in 
this `LexOrdering`.
+    pub fn get_sort_options(&self, expr: &dyn PhysicalExpr) -> 
Option<SortOptions> {
+        for e in self {
+            if e.expr.as_ref().dyn_eq(expr) {
+                return Some(e.options);
+            }
+        }
+
+        None
+    }
 }
 
 /// Check if two SortOptions represent reversed orderings.
diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs 
b/datafusion/physical-plan/src/aggregates/group_values/mod.rs
index f419328d11..2f3b1a19e7 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs
@@ -22,7 +22,7 @@ use arrow::array::types::{
     Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType,
     TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
 };
-use arrow::array::{ArrayRef, RecordBatch, downcast_primitive};
+use arrow::array::{ArrayRef, downcast_primitive};
 use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
 use datafusion_common::Result;
 
@@ -112,7 +112,7 @@ pub trait GroupValues: Send {
     fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;
 
     /// Clear the contents and shrink the capacity to the size of the batch 
(free up memory usage)
-    fn clear_shrink(&mut self, batch: &RecordBatch);
+    fn clear_shrink(&mut self, num_rows: usize);
 }
 
 /// Return a specialized implementation of [`GroupValues`] for the given 
schema.
diff --git 
a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs 
b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs
index b62bc11aff..4c9e376fc4 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs
@@ -30,7 +30,7 @@ use crate::aggregates::group_values::multi_group_by::{
     bytes_view::ByteViewGroupValueBuilder, 
primitive::PrimitiveGroupValueBuilder,
 };
 use ahash::RandomState;
-use arrow::array::{Array, ArrayRef, RecordBatch};
+use arrow::array::{Array, ArrayRef};
 use arrow::compute::cast;
 use arrow::datatypes::{
     BinaryViewType, DataType, Date32Type, Date64Type, Decimal128Type, 
Float32Type,
@@ -1181,14 +1181,13 @@ impl<const STREAMING: bool> GroupValues for 
GroupValuesColumn<STREAMING> {
         Ok(output)
     }
 
-    fn clear_shrink(&mut self, batch: &RecordBatch) {
-        let count = batch.num_rows();
+    fn clear_shrink(&mut self, num_rows: usize) {
         self.group_values.clear();
         self.map.clear();
-        self.map.shrink_to(count, |_| 0); // hasher does not matter since the 
map is cleared
+        self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since 
the map is cleared
         self.map_size = self.map.capacity() * size_of::<(u64, usize)>();
         self.hashes_buffer.clear();
-        self.hashes_buffer.shrink_to(count);
+        self.hashes_buffer.shrink_to(num_rows);
 
         // Such structures are only used in `non-streaming` case
         if !STREAMING {
diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs 
b/datafusion/physical-plan/src/aggregates/group_values/row.rs
index a5e5c16006..dd794c9573 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/row.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs
@@ -17,7 +17,7 @@
 
 use crate::aggregates::group_values::GroupValues;
 use ahash::RandomState;
-use arrow::array::{Array, ArrayRef, ListArray, RecordBatch, StructArray};
+use arrow::array::{Array, ArrayRef, ListArray, StructArray};
 use arrow::compute::cast;
 use arrow::datatypes::{DataType, SchemaRef};
 use arrow::row::{RowConverter, Rows, SortField};
@@ -243,17 +243,16 @@ impl GroupValues for GroupValuesRows {
         Ok(output)
     }
 
-    fn clear_shrink(&mut self, batch: &RecordBatch) {
-        let count = batch.num_rows();
+    fn clear_shrink(&mut self, num_rows: usize) {
         self.group_values = self.group_values.take().map(|mut rows| {
             rows.clear();
             rows
         });
         self.map.clear();
-        self.map.shrink_to(count, |_| 0); // hasher does not matter since the 
map is cleared
+        self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since 
the map is cleared
         self.map_size = self.map.capacity() * size_of::<(u64, usize)>();
         self.hashes_buffer.clear();
-        self.hashes_buffer.shrink_to(count);
+        self.hashes_buffer.shrink_to(num_rows);
     }
 }
 
diff --git 
a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs
 
b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs
index 44b763a91f..e993c0c53d 100644
--- 
a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs
+++ 
b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs
@@ -19,7 +19,6 @@ use crate::aggregates::group_values::GroupValues;
 
 use arrow::array::{
     ArrayRef, AsArray as _, BooleanArray, BooleanBufferBuilder, 
NullBufferBuilder,
-    RecordBatch,
 };
 use datafusion_common::Result;
 use datafusion_expr::EmitTo;
@@ -146,7 +145,7 @@ impl GroupValues for GroupValuesBoolean {
         Ok(vec![Arc::new(BooleanArray::new(values, nulls)) as _])
     }
 
-    fn clear_shrink(&mut self, _batch: &RecordBatch) {
+    fn clear_shrink(&mut self, _num_rows: usize) {
         self.false_group = None;
         self.true_group = None;
         self.null_group = None;
diff --git 
a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs 
b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs
index b901aee313..b881a51b25 100644
--- 
a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs
+++ 
b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs
@@ -19,7 +19,7 @@ use std::mem::size_of;
 
 use crate::aggregates::group_values::GroupValues;
 
-use arrow::array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch};
+use arrow::array::{Array, ArrayRef, OffsetSizeTrait};
 use datafusion_common::Result;
 use datafusion_expr::EmitTo;
 use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType};
@@ -120,7 +120,7 @@ impl<O: OffsetSizeTrait> GroupValues for 
GroupValuesBytes<O> {
         Ok(vec![group_values])
     }
 
-    fn clear_shrink(&mut self, _batch: &RecordBatch) {
+    fn clear_shrink(&mut self, _num_rows: usize) {
         // in theory we could potentially avoid this reallocation and clear the
         // contents of the maps, but for now we just reset the map from the 
beginning
         self.map.take();
diff --git 
a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs
 
b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs
index be9a0334e3..7a56f7c52c 100644
--- 
a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs
+++ 
b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use crate::aggregates::group_values::GroupValues;
-use arrow::array::{Array, ArrayRef, RecordBatch};
+use arrow::array::{Array, ArrayRef};
 use datafusion_expr::EmitTo;
 use datafusion_physical_expr::binary_map::OutputType;
 use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap;
@@ -122,7 +122,7 @@ impl GroupValues for GroupValuesBytesView {
         Ok(vec![group_values])
     }
 
-    fn clear_shrink(&mut self, _batch: &RecordBatch) {
+    fn clear_shrink(&mut self, _num_rows: usize) {
         // in theory we could potentially avoid this reallocation and clear the
         // contents of the maps, but for now we just reset the map from the 
beginning
         self.map.take();
diff --git 
a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs
 
b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs
index 41d34218f6..c46cde8786 100644
--- 
a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs
+++ 
b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs
@@ -23,7 +23,6 @@ use arrow::array::{
     cast::AsArray,
 };
 use arrow::datatypes::{DataType, i256};
-use arrow::record_batch::RecordBatch;
 use datafusion_common::Result;
 use datafusion_execution::memory_pool::proxy::VecAllocExt;
 use datafusion_expr::EmitTo;
@@ -213,11 +212,10 @@ where
         Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))])
     }
 
-    fn clear_shrink(&mut self, batch: &RecordBatch) {
-        let count = batch.num_rows();
+    fn clear_shrink(&mut self, num_rows: usize) {
         self.values.clear();
-        self.values.shrink_to(count);
+        self.values.shrink_to(num_rows);
         self.map.clear();
-        self.map.shrink_to(count, |_| 0); // hasher does not matter since the 
map is cleared
+        self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since 
the map is cleared
     }
 }
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs 
b/datafusion/physical-plan/src/aggregates/mod.rs
index b0d432a9de..06f12a9019 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -31,7 +31,6 @@ use crate::filter_pushdown::{
     FilterPushdownPropagation, PushedDownPredicate,
 };
 use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
-use crate::windows::get_ordered_partition_by_indices;
 use crate::{
     DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode,
     SendableRecordBatchStream, Statistics,
@@ -613,12 +612,13 @@ impl AggregateExec {
         // If existing ordering satisfies a prefix of the GROUP BY expressions,
         // prefix requirements with this section. In this case, aggregation 
will
         // work more efficiently.
-        let indices = get_ordered_partition_by_indices(&groupby_exprs, 
&input)?;
-        let mut new_requirements = indices
-            .iter()
-            .map(|&idx| {
-                PhysicalSortRequirement::new(Arc::clone(&groupby_exprs[idx]), 
None)
-            })
+        // Copy the `PhysicalSortExpr`s to retain the sort options.
+        let (new_sort_exprs, indices) =
+            input_eq_properties.find_longest_permutation(&groupby_exprs)?;
+
+        let mut new_requirements = new_sort_exprs
+            .into_iter()
+            .map(PhysicalSortRequirement::from)
             .collect::<Vec<_>>();
 
         let req = get_finer_aggregate_exprs_requirement(
@@ -1815,7 +1815,7 @@ mod tests {
     use crate::test::exec::{BlockingExec, 
assert_strong_count_converges_to_zero};
 
     use arrow::array::{
-        DictionaryArray, Float32Array, Float64Array, Int32Array, StructArray,
+        DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Array, 
StructArray,
         UInt32Array, UInt64Array,
     };
     use arrow::compute::{SortOptions, concat_batches};
@@ -1837,6 +1837,8 @@ mod tests {
     use datafusion_physical_expr::expressions::Literal;
     use datafusion_physical_expr::expressions::lit;
 
+    use crate::projection::ProjectionExec;
+    use datafusion_physical_expr::projection::ProjectionExpr;
     use futures::{FutureExt, Stream};
     use insta::{allow_duplicates, assert_snapshot};
 
@@ -2484,7 +2486,7 @@ mod tests {
         ] {
             let n_aggr = aggregates.len();
             let partial_aggregate = Arc::new(AggregateExec::try_new(
-                AggregateMode::Partial,
+                AggregateMode::Single,
                 groups,
                 aggregates,
                 vec![None; n_aggr],
@@ -3420,6 +3422,117 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_grouped_aggregation_respects_memory_limit() -> Result<()> {
+        // test with spill
+        fn create_record_batch(
+            schema: &Arc<Schema>,
+            data: (Vec<u32>, Vec<f64>),
+        ) -> Result<RecordBatch> {
+            Ok(RecordBatch::try_new(
+                Arc::clone(schema),
+                vec![
+                    Arc::new(UInt32Array::from(data.0)),
+                    Arc::new(Float64Array::from(data.1)),
+                ],
+            )?)
+        }
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::UInt32, false),
+            Field::new("b", DataType::Float64, false),
+        ]));
+
+        let batches = vec![
+            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 
3.0, 4.0]))?,
+            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 
3.0, 4.0]))?,
+        ];
+        let plan: Arc<dyn ExecutionPlan> =
+            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), 
None)?;
+        let proj = ProjectionExec::try_new(
+            vec![
+                ProjectionExpr::new(lit("0"), "l".to_string()),
+                ProjectionExpr::new_from_expression(col("a", &schema)?, 
&schema)?,
+                ProjectionExpr::new_from_expression(col("b", &schema)?, 
&schema)?,
+            ],
+            plan,
+        )?;
+        let plan: Arc<dyn ExecutionPlan> = Arc::new(proj);
+        let schema = plan.schema();
+
+        let grouping_set = PhysicalGroupBy::new(
+            vec![
+                (col("l", &schema)?, "l".to_string()),
+                (col("a", &schema)?, "a".to_string()),
+            ],
+            vec![],
+            vec![vec![false, false]],
+            false,
+        );
+
+        // Test with MIN for simple intermediate state (min) and AVG for 
multiple intermediate states (partial sum, partial count).
+        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
+            Arc::new(
+                AggregateExprBuilder::new(
+                    datafusion_functions_aggregate::min_max::min_udaf(),
+                    vec![col("b", &schema)?],
+                )
+                .schema(Arc::clone(&schema))
+                .alias("MIN(b)")
+                .build()?,
+            ),
+            Arc::new(
+                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
+                    .schema(Arc::clone(&schema))
+                    .alias("AVG(b)")
+                    .build()?,
+            ),
+        ];
+
+        let single_aggregate = Arc::new(AggregateExec::try_new(
+            AggregateMode::Single,
+            grouping_set,
+            aggregates,
+            vec![None, None],
+            plan,
+            Arc::clone(&schema),
+        )?);
+
+        let batch_size = 2;
+        let memory_pool = Arc::new(FairSpillPool::new(2000));
+        let task_ctx = Arc::new(
+            TaskContext::default()
+                
.with_session_config(SessionConfig::new().with_batch_size(batch_size))
+                .with_runtime(Arc::new(
+                    RuntimeEnvBuilder::new()
+                        .with_memory_pool(memory_pool)
+                        .build()?,
+                )),
+        );
+
+        let result = collect(single_aggregate.execute(0, 
Arc::clone(&task_ctx))?).await;
+        match result {
+            Ok(result) => {
+                assert_spill_count_metric(true, single_aggregate);
+
+                allow_duplicates! {
+                    assert_snapshot!(batches_to_string(&result), @r"
+                +---+---+--------+--------+
+                | l | a | MIN(b) | AVG(b) |
+                +---+---+--------+--------+
+                | 0 | 2 | 1.0    | 1.0    |
+                | 0 | 3 | 2.0    | 2.0    |
+                | 0 | 4 | 3.0    | 3.5    |
+                +---+---+--------+--------+
+            ");
+                }
+            }
+            Err(e) => assert!(matches!(e, 
DataFusionError::ResourcesExhausted(_))),
+        }
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_aggregate_statistics_edge_cases() -> Result<()> {
         use crate::test::exec::StatisticsExec;
@@ -3492,4 +3605,87 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_order_is_retained_when_spilling() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int64, false),
+            Field::new("b", DataType::Int64, false),
+            Field::new("c", DataType::Int64, false),
+        ]));
+
+        let batches = vec![vec![
+            RecordBatch::try_new(
+                Arc::clone(&schema),
+                vec![
+                    Arc::new(Int64Array::from(vec![2])),
+                    Arc::new(Int64Array::from(vec![2])),
+                    Arc::new(Int64Array::from(vec![1])),
+                ],
+            )?,
+            RecordBatch::try_new(
+                Arc::clone(&schema),
+                vec![
+                    Arc::new(Int64Array::from(vec![1])),
+                    Arc::new(Int64Array::from(vec![1])),
+                    Arc::new(Int64Array::from(vec![1])),
+                ],
+            )?,
+            RecordBatch::try_new(
+                Arc::clone(&schema),
+                vec![
+                    Arc::new(Int64Array::from(vec![0])),
+                    Arc::new(Int64Array::from(vec![0])),
+                    Arc::new(Int64Array::from(vec![1])),
+                ],
+            )?,
+        ]];
+        let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), 
None)?;
+        let scan = scan.try_with_sort_information(vec![
+            LexOrdering::new([PhysicalSortExpr::new(
+                col("b", schema.as_ref())?,
+                SortOptions::default().desc(),
+            )])
+            .unwrap(),
+        ])?;
+
+        let aggr = Arc::new(AggregateExec::try_new(
+            AggregateMode::Single,
+            PhysicalGroupBy::new(
+                vec![
+                    (col("b", schema.as_ref())?, "b".to_string()),
+                    (col("c", schema.as_ref())?, "c".to_string()),
+                ],
+                vec![],
+                vec![vec![false, false]],
+                false,
+            ),
+            vec![Arc::new(
+                AggregateExprBuilder::new(sum_udaf(), vec![col("c", 
schema.as_ref())?])
+                    .schema(Arc::clone(&schema))
+                    .alias("SUM(c)")
+                    .build()?,
+            )],
+            vec![None],
+            Arc::new(scan) as Arc<dyn ExecutionPlan>,
+            Arc::clone(&schema),
+        )?);
+
+        let task_ctx = new_spill_ctx(1, 600);
+        let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await?;
+        assert_spill_count_metric(true, aggr);
+
+        allow_duplicates! {
+            assert_snapshot!(batches_to_string(&result), @r"
+            +---+---+--------+
+            | b | c | SUM(c) |
+            +---+---+--------+
+            | 2 | 1 | 1      |
+            | 1 | 1 | 1      |
+            | 0 | 1 | 1      |
+            +---+---+--------+
+        ");
+        }
+        Ok(())
+    }
 }
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs 
b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 1e7757de4a..cb22fbf9a0 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -33,7 +33,6 @@ use crate::metrics::{BaselineMetrics, MetricBuilder, 
RecordOutput};
 use crate::sorts::sort::sort_batch;
 use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
 use crate::spill::spill_manager::SpillManager;
-use crate::stream::RecordBatchStreamAdapter;
 use crate::{PhysicalExpr, aggregates, metrics};
 use crate::{RecordBatchStream, SendableRecordBatchStream};
 
@@ -210,6 +209,17 @@ impl SkipAggregationProbe {
     }
 }
 
+/// Controls the behavior when an out-of-memory condition occurs.
+#[derive(PartialEq, Debug)]
+enum OutOfMemoryMode {
+    /// When out of memory occurs, spill state to disk
+    Spill,
+    /// When out of memory occurs, attempt to emit group values early
+    EmitEarly,
+    /// When out of memory occurs, immediately report the error
+    ReportError,
+}
+
 /// HashTable based Grouping Aggregator
 ///
 /// # Design Goals
@@ -433,6 +443,9 @@ pub(crate) struct GroupedHashAggregateStream {
     /// The memory reservation for this grouping
     reservation: MemoryReservation,
 
+    /// The behavior to trigger when out of memory occurs
+    oom_mode: OutOfMemoryMode,
+
     /// Execution metrics
     baseline_metrics: BaselineMetrics,
 
@@ -510,12 +523,12 @@ impl GroupedHashAggregateStream {
         // Therefore, when we spill these intermediate states or pass them to 
another
         // aggregation operator, we must use a schema that includes both the 
group
         // columns **and** the partial-state columns.
-        let partial_agg_schema = create_schema(
+        let spill_schema = Arc::new(create_schema(
             &agg.input().schema(),
             &agg_group_by,
             &aggregate_exprs,
             AggregateMode::Partial,
-        )?;
+        )?);
 
         // Need to update the GROUP BY expressions to point to the correct 
column after schema change
         let merging_group_by_expr = agg_group_by
@@ -527,20 +540,25 @@ impl GroupedHashAggregateStream {
             })
             .collect();
 
-        let partial_agg_schema = Arc::new(partial_agg_schema);
+        let output_ordering = agg.cache.output_ordering();
 
-        let spill_expr =
+        let spill_sort_exprs =
             group_schema
                 .fields
                 .into_iter()
                 .enumerate()
                 .map(|(idx, field)| {
-                    PhysicalSortExpr::new_default(Arc::new(Column::new(
-                        field.name().as_str(),
-                        idx,
-                    )) as _)
+                    let output_expr = Column::new(field.name().as_str(), idx);
+
+                    // Try to use the sort options from the output ordering, 
if available.
+                    // This ensures that spilled state is sorted in the 
required order as well.
+                    let sort_options = output_ordering
+                        .and_then(|o| o.get_sort_options(&output_expr))
+                        .unwrap_or_default();
+
+                    PhysicalSortExpr::new(Arc::new(output_expr), sort_options)
                 });
-        let Some(spill_expr) = LexOrdering::new(spill_expr) else {
+        let Some(spill_ordering) = LexOrdering::new(spill_sort_exprs) else {
             return internal_err!("Spill expression is empty");
         };
 
@@ -550,11 +568,35 @@ impl GroupedHashAggregateStream {
             .collect::<Vec<_>>()
             .join(", ");
         let name = format!("GroupedHashAggregateStream[{partition}] 
({agg_fn_names})");
-        let reservation = MemoryConsumer::new(name)
-            .with_can_spill(true)
-            .register(context.memory_pool());
         let group_ordering = GroupOrdering::try_new(&agg.input_order_mode)?;
+        let oom_mode = match (agg.mode, &group_ordering) {
+            // In partial aggregation mode, always prefer to emit incomplete 
results early.
+            (AggregateMode::Partial, _) => OutOfMemoryMode::EmitEarly,
+            // For non-partial aggregation modes, emitting incomplete results 
is not an option.
+            // Instead, use disk spilling to store sorted, incomplete results, 
and merge them
+            // afterwards.
+            (_, GroupOrdering::None | GroupOrdering::Partial(_))
+                if context.runtime_env().disk_manager.tmp_files_enabled() =>
+            {
+                OutOfMemoryMode::Spill
+            }
+            // For `GroupOrdering::Full`, the incoming stream is already 
sorted. This ensures the
+            // number of incomplete groups can be kept small at all times. If 
we still hit
+            // an out-of-memory condition, spilling to disk would not be 
beneficial since the same
+            // situation is likely to reoccur when reading back the spilled 
data.
+            // Therefore, we fall back to simply reporting the error 
immediately.
+            // This mode will also be used if the `DiskManager` is not 
configured to allow spilling
+            // to disk.
+            _ => OutOfMemoryMode::ReportError,
+        };
+
         let group_values = new_group_values(group_schema, &group_ordering)?;
+        let reservation = MemoryConsumer::new(name)
+            // We interpret 'can spill' as 'can handle memory back pressure'.
+            // This value needs to be set to true for the default memory pool 
implementations
+            // to ensure fair application of back pressure amongst the memory 
consumers.
+            .with_can_spill(oom_mode != OutOfMemoryMode::ReportError)
+            .register(context.memory_pool());
         timer.done();
 
         let exec_state = ExecutionState::ReadingInput;
@@ -562,14 +604,14 @@ impl GroupedHashAggregateStream {
         let spill_manager = SpillManager::new(
             context.runtime_env(),
             metrics::SpillMetrics::new(&agg.metrics, partition),
-            Arc::clone(&partial_agg_schema),
+            Arc::clone(&spill_schema),
         )
         .with_compression_type(context.session_config().spill_compression());
 
         let spill_state = SpillState {
             spills: vec![],
-            spill_expr,
-            spill_schema: partial_agg_schema,
+            spill_expr: spill_ordering,
+            spill_schema,
             is_stream_merging: false,
             merging_aggregate_arguments,
             merging_group_by: 
PhysicalGroupBy::new_single(merging_group_by_expr),
@@ -627,6 +669,7 @@ impl GroupedHashAggregateStream {
             filter_expressions,
             group_by: agg_group_by,
             reservation,
+            oom_mode,
             group_values,
             current_group_indices: Default::default(),
             exec_state,
@@ -676,21 +719,24 @@ impl Stream for GroupedHashAggregateStream {
             match &self.exec_state {
                 ExecutionState::ReadingInput => 'reading_input: {
                     match ready!(self.input.poll_next_unpin(cx)) {
-                        // New batch to aggregate in partial aggregation 
operator
-                        Some(Ok(batch)) if self.mode == AggregateMode::Partial 
=> {
+                        // New batch to aggregate
+                        Some(Ok(batch)) => {
                             let timer = elapsed_compute.timer();
                             let input_rows = batch.num_rows();
 
-                            if let Some(reduction_factor) = 
self.reduction_factor.as_ref()
+                            if self.mode == AggregateMode::Partial
+                                && let Some(reduction_factor) =
+                                    self.reduction_factor.as_ref()
                             {
                                 reduction_factor.add_total(input_rows);
                             }
 
-                            // Do the grouping
+                            // Do the grouping.
+                            // `group_aggregate_batch` will _not_ have updated 
the memory reservation yet.
+                            // The rest of the code will first try to reduce 
memory usage by
+                            // already emitting results.
                             self.group_aggregate_batch(&batch)?;
 
-                            // If we can begin emitting rows, do so,
-                            // otherwise keep consuming input
                             assert!(!self.input_done);
 
                             // If the number of group values equals or exceeds 
the soft limit,
@@ -702,7 +748,13 @@ impl Stream for GroupedHashAggregateStream {
                                 break 'reading_input;
                             }
 
-                            if let Some(to_emit) = 
self.group_ordering.emit_to() {
+                            // Try to emit completed groups if possible.
+                            // If we already started spilling, we can no 
longer emit since
+                            // this might lead to incorrect output ordering
+                            if (self.spill_state.spills.is_empty()
+                                || self.spill_state.is_stream_merging)
+                                && let Some(to_emit) = 
self.group_ordering.emit_to()
+                            {
                                 timer.done();
                                 if let Some(batch) = self.emit(to_emit, 
false)? {
                                     self.exec_state =
@@ -712,18 +764,28 @@ impl Stream for GroupedHashAggregateStream {
                                 break 'reading_input;
                             }
 
-                            // Check if we should switch to skip aggregation 
mode
-                            // It's important that we do this before we early 
emit since we've
-                            // already updated the probe.
-                            self.update_skip_aggregation_probe(input_rows);
-                            if let Some(new_state) = 
self.switch_to_skip_aggregation()? {
-                                timer.done();
-                                self.exec_state = new_state;
-                                break 'reading_input;
+                            if self.mode == AggregateMode::Partial {
+                                // Spilling should never be activated in 
partial aggregation mode.
+                                assert!(!self.spill_state.is_stream_merging);
+
+                                // Check if we should switch to skip 
aggregation mode
+                                // It's important that we do this before we 
early emit since we've
+                                // already updated the probe.
+                                self.update_skip_aggregation_probe(input_rows);
+                                if let Some(new_state) =
+                                    self.switch_to_skip_aggregation()?
+                                {
+                                    timer.done();
+                                    self.exec_state = new_state;
+                                    break 'reading_input;
+                                }
                             }
 
-                            // Check if we need to emit early due to memory 
pressure
-                            if let Some(new_state) = 
self.emit_early_if_necessary()? {
+                            // If we reach this point, try to update the 
memory reservation
+                            // handling out-of-memory conditions as determined 
by the OOM mode.
+                            if let Some(new_state) =
+                                self.try_update_memory_reservation()?
+                            {
                                 timer.done();
                                 self.exec_state = new_state;
                                 break 'reading_input;
@@ -732,43 +794,6 @@ impl Stream for GroupedHashAggregateStream {
                             timer.done();
                         }
 
-                        // New batch to aggregate in terminal aggregation 
operator
-                        // (Final/FinalPartitioned/Single/SinglePartitioned)
-                        Some(Ok(batch)) => {
-                            let timer = elapsed_compute.timer();
-
-                            // Make sure we have enough capacity for `batch`, 
otherwise spill
-                            self.spill_previous_if_necessary(&batch)?;
-
-                            // Do the grouping
-                            self.group_aggregate_batch(&batch)?;
-
-                            // If we can begin emitting rows, do so,
-                            // otherwise keep consuming input
-                            assert!(!self.input_done);
-
-                            // If the number of group values equals or exceeds 
the soft limit,
-                            // emit all groups and switch to producing output
-                            if self.hit_soft_group_limit() {
-                                timer.done();
-                                self.set_input_done_and_produce_output()?;
-                                // make sure the exec_state just set is not 
overwritten below
-                                break 'reading_input;
-                            }
-
-                            if let Some(to_emit) = 
self.group_ordering.emit_to() {
-                                timer.done();
-                                if let Some(batch) = self.emit(to_emit, 
false)? {
-                                    self.exec_state =
-                                        ExecutionState::ProducingOutput(batch);
-                                };
-                                // make sure the exec_state just set is not 
overwritten below
-                                break 'reading_input;
-                            }
-
-                            timer.done();
-                        }
-
                         // Found error from input stream
                         Some(Err(e)) => {
                             // inner had error, return to caller
@@ -987,25 +1012,56 @@ impl GroupedHashAggregateStream {
             }
         }
 
-        match self.update_memory_reservation() {
-            // Here we can ignore `insufficient_capacity_err` because we will 
spill later,
-            // but at least one batch should fit in the memory
-            Err(DataFusionError::ResourcesExhausted(_))
-                if self.group_values.len() >= self.batch_size =>
-            {
-                Ok(())
+        Ok(())
+    }
+
+    /// Attempts to update the memory reservation. If that fails due to a
+    /// [DataFusionError::ResourcesExhausted] error, an attempt will be made 
to resolve
+    /// the out-of-memory condition based on the [out-of-memory handling 
mode](OutOfMemoryMode).
+    ///
+    /// If the out-of-memory condition can not be resolved, an `Err` value 
will be returned
+    ///
+    /// Returns `Ok(Some(ExecutionState))` if the state should be changed, 
`Ok(None)` otherwise.
+    fn try_update_memory_reservation(&mut self) -> 
Result<Option<ExecutionState>> {
+        let oom = match self.update_memory_reservation() {
+            Err(e @ DataFusionError::ResourcesExhausted(_)) => e,
+            Err(e) => return Err(e),
+            Ok(_) => return Ok(None),
+        };
+
+        match self.oom_mode {
+            OutOfMemoryMode::Spill if !self.group_values.is_empty() => {
+                self.spill()?;
+                self.clear_shrink(self.batch_size);
+                self.update_memory_reservation()?;
+                Ok(None)
             }
-            other => other,
+            OutOfMemoryMode::EmitEarly if self.group_values.len() > 1 => {
+                let n = if self.group_values.len() >= self.batch_size {
+                    // Try to emit an integer multiple of batch size if 
possible
+                    self.group_values.len() / self.batch_size * self.batch_size
+                } else {
+                    // Otherwise emit whatever we can
+                    self.group_values.len()
+                };
+
+                if let Some(batch) = self.emit(EmitTo::First(n), false)? {
+                    Ok(Some(ExecutionState::ProducingOutput(batch)))
+                } else {
+                    Err(oom)
+                }
+            }
+            _ => Err(oom),
         }
     }
 
     fn update_memory_reservation(&mut self) -> Result<()> {
         let acc = self.accumulators.iter().map(|x| x.size()).sum::<usize>();
-        let reservation_result = self.reservation.try_resize(
-            acc + self.group_values.size()
-                + self.group_ordering.size()
-                + self.current_group_indices.allocated_size(),
-        );
+        let new_size = acc
+            + self.group_values.size()
+            + self.group_ordering.size()
+            + self.current_group_indices.allocated_size();
+        let reservation_result = self.reservation.try_resize(new_size);
 
         if reservation_result.is_ok() {
             self.spill_state
@@ -1060,24 +1116,6 @@ impl GroupedHashAggregateStream {
         Ok(Some(batch))
     }
 
-    /// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the 
memory target slightly
-    /// (~ 1 [`RecordBatch`]) for simplicity. In such cases, spill the data to 
disk and clear the
-    /// memory. Currently only [`GroupOrdering::None`] is supported for 
spilling.
-    fn spill_previous_if_necessary(&mut self, batch: &RecordBatch) -> 
Result<()> {
-        // TODO: support group_ordering for spilling
-        if !self.group_values.is_empty()
-            && batch.num_rows() > 0
-            && matches!(self.group_ordering, GroupOrdering::None)
-            && !self.spill_state.is_stream_merging
-            && self.update_memory_reservation().is_err()
-        {
-            assert_ne!(self.mode, AggregateMode::Partial);
-            self.spill()?;
-            self.clear_shrink(batch);
-        }
-        Ok(())
-    }
-
     /// Emit all intermediate aggregation states, sort them, and store them on 
disk.
     /// This process helps in reducing memory pressure by allowing the data to 
be
     /// read back with streaming merge.
@@ -1115,72 +1153,15 @@ impl GroupedHashAggregateStream {
     }
 
     /// Clear memory and shirk capacities to the size of the batch.
-    fn clear_shrink(&mut self, batch: &RecordBatch) {
-        self.group_values.clear_shrink(batch);
+    fn clear_shrink(&mut self, num_rows: usize) {
+        self.group_values.clear_shrink(num_rows);
         self.current_group_indices.clear();
-        self.current_group_indices.shrink_to(batch.num_rows());
+        self.current_group_indices.shrink_to(num_rows);
     }
 
     /// Clear memory and shirk capacities to zero.
     fn clear_all(&mut self) {
-        let s = self.schema();
-        self.clear_shrink(&RecordBatch::new_empty(s));
-    }
-
-    /// Emit if the used memory exceeds the target for partial aggregation.
-    /// Currently only [`GroupOrdering::None`] is supported for early emitting.
-    /// TODO: support group_ordering for early emitting
-    ///
-    /// Returns `Some(ExecutionState)` if the state should be changed, None 
otherwise.
-    fn emit_early_if_necessary(&mut self) -> Result<Option<ExecutionState>> {
-        if self.group_values.len() >= self.batch_size
-            && matches!(self.group_ordering, GroupOrdering::None)
-            && self.update_memory_reservation().is_err()
-        {
-            assert_eq!(self.mode, AggregateMode::Partial);
-            let n = self.group_values.len() / self.batch_size * 
self.batch_size;
-            if let Some(batch) = self.emit(EmitTo::First(n), false)? {
-                return Ok(Some(ExecutionState::ProducingOutput(batch)));
-            };
-        }
-        Ok(None)
-    }
-
-    /// At this point, all the inputs are read and there are some spills.
-    /// Emit the remaining rows and create a batch.
-    /// Conduct a streaming merge sort between the batch and spilled data. 
Since the stream is fully
-    /// sorted, set `self.group_ordering` to Full, then later we can read with 
[`EmitTo::First`].
-    fn update_merged_stream(&mut self) -> Result<()> {
-        let Some(batch) = self.emit(EmitTo::All, true)? else {
-            return Ok(());
-        };
-        // clear up memory for streaming_merge
-        self.clear_all();
-        self.update_memory_reservation()?;
-        let mut streams: Vec<SendableRecordBatchStream> = vec![];
-        let expr = self.spill_state.spill_expr.clone();
-        let schema = batch.schema();
-        streams.push(Box::pin(RecordBatchStreamAdapter::new(
-            Arc::clone(&schema),
-            futures::stream::once(futures::future::lazy(move |_| {
-                sort_batch(&batch, &expr, None)
-            })),
-        )));
-
-        self.spill_state.is_stream_merging = true;
-        self.input = StreamingMergeBuilder::new()
-            .with_streams(streams)
-            .with_schema(schema)
-            .with_spill_manager(self.spill_state.spill_manager.clone())
-            .with_sorted_spill_files(std::mem::take(&mut 
self.spill_state.spills))
-            .with_expressions(&self.spill_state.spill_expr)
-            .with_metrics(self.baseline_metrics.clone())
-            .with_batch_size(self.batch_size)
-            .with_reservation(self.reservation.new_empty())
-            .build()?;
-        self.input_done = false;
-        self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new());
-        Ok(())
+        self.clear_shrink(0);
     }
 
     /// returns true if there is a soft groups limit and the number of distinct
@@ -1192,18 +1173,60 @@ impl GroupedHashAggregateStream {
         group_values_soft_limit <= self.group_values.len()
     }
 
-    /// common function for signalling end of processing of the input stream
+    /// Finalizes reading of the input stream and prepares for producing 
output values.
+    ///
+    /// This method is called both when the original input stream and,
+    /// in case of disk spilling, the SPM stream have been drained.
     fn set_input_done_and_produce_output(&mut self) -> Result<()> {
         self.input_done = true;
         self.group_ordering.input_done();
         let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
         let timer = elapsed_compute.timer();
         self.exec_state = if self.spill_state.spills.is_empty() {
+            // Input has been entirely processed without spilling to disk.
+
+            // Flush any remaining group values.
             let batch = self.emit(EmitTo::All, false)?;
+
+            // If there are none, we're done; otherwise switch to emitting them
             batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput)
         } else {
-            // If spill files exist, stream-merge them.
-            self.update_merged_stream()?;
+            // Spill any remaining data to disk. There is some performance 
overhead in
+            // writing out this last chunk of data and reading it back. The 
benefit of
+            // doing this is that memory usage for this stream is reduced, and 
the more
+            // sophisticated memory handling in `MultiLevelMergeBuilder` can 
take over
+            // instead.
+            // Spilling to disk and reading back also ensures batch size is 
consistent
+            // rather than potentially having one significantly larger last 
batch.
+            self.spill()?;
+
+            // Mark that we're switching to stream merging mode.
+            self.spill_state.is_stream_merging = true;
+
+            self.input = StreamingMergeBuilder::new()
+                .with_schema(Arc::clone(&self.spill_state.spill_schema))
+                .with_spill_manager(self.spill_state.spill_manager.clone())
+                .with_sorted_spill_files(std::mem::take(&mut 
self.spill_state.spills))
+                .with_expressions(&self.spill_state.spill_expr)
+                .with_metrics(self.baseline_metrics.clone())
+                .with_batch_size(self.batch_size)
+                .with_reservation(self.reservation.new_empty())
+                .build()?;
+            self.input_done = false;
+
+            // Reset the group values collectors.
+            self.clear_all();
+
+            // We can now use `GroupOrdering::Full` since the spill files are 
sorted
+            // on the grouping columns.
+            self.group_ordering = 
GroupOrdering::Full(GroupOrderingFull::new());
+
+            // Use `OutOfMemoryMode::ReportError` from this point on
+            // to ensure we don't spill the spilled data to disk again.
+            self.oom_mode = OutOfMemoryMode::ReportError;
+
+            self.update_memory_reservation()?;
+
             ExecutionState::ReadingInput
         };
         timer.done();
diff --git a/datafusion/physical-plan/src/test.rs 
b/datafusion/physical-plan/src/test.rs
index f2336920b3..c94b5a4131 100644
--- a/datafusion/physical-plan/src/test.rs
+++ b/datafusion/physical-plan/src/test.rs
@@ -342,6 +342,7 @@ impl TestMemoryExec {
         }
 
         self.sort_information = sort_information;
+        self.cache = self.compute_properties();
         Ok(self)
     }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to