This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new eb30c19b30 Implement disk spilling for all grouping ordering modes in
GroupedHashAggregateStream (#19287)
eb30c19b30 is described below
commit eb30c19b303e22ab2ab104e0fa67ff8f511261f2
Author: Pepijn Van Eeckhoudt <[email protected]>
AuthorDate: Sat Dec 20 12:22:25 2025 +0100
Implement disk spilling for all grouping ordering modes in
GroupedHashAggregateStream (#19287)
## Which issue does this PR close?
- Closes #19286.
- Related to #13123
## Rationale for this change
GroupedHashAggregateStream currently always reports to the memory tracking
subsystem that it can spill, even though whether it actually can depends on
the aggregation mode and the grouping order.
The optimistic logic in `group_aggregate_batch` does not correctly take
the spilling preconditions into account, which can lead to excessive
memory use.
To resolve this, this PR implements disk spilling for all grouping
ordering modes.
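A minimal sketch, assuming DataFusion's `MemoryConsumer` API, of what the corrected registration amounts to. The helper function and the hard-coded consumer label are made up for illustration; the real change is in the `row_hash.rs` hunks further down, where the flag was previously hard-coded to `true`.

```rust
use std::sync::Arc;

use datafusion_execution::memory_pool::{
    FairSpillPool, MemoryConsumer, MemoryPool, MemoryReservation,
};

/// Illustrative helper only. Before this PR the stream registered its consumer
/// with `.with_can_spill(true)` unconditionally; afterwards the flag reflects
/// whether the stream can actually react to memory pressure (by spilling to
/// disk or emitting early) instead of only reporting the error.
fn register_grouping_reservation(
    pool: &Arc<dyn MemoryPool>,
    can_react_to_memory_pressure: bool,
) -> MemoryReservation {
    MemoryConsumer::new("GroupedHashAggregateStream[0]")
        .with_can_spill(can_react_to_memory_pressure)
        .register(pool)
}

fn main() {
    let pool: Arc<dyn MemoryPool> = Arc::new(FairSpillPool::new(32 * 1024 * 1024));
    // e.g. a fully sorted, non-partial aggregation can neither spill nor emit
    // early, so it no longer claims spilling capability.
    let _reservation = register_grouping_reservation(&pool, false);
}
```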
## What changes are included in this PR?
- Correctly set `MemoryConsumer::can_spill` to reflect the actual spilling
behaviour
- Ensure the optimistic out-of-memory tolerance in `group_aggregate_batch` is
aligned with the disk spilling or early emission logic (see the sketch after
this list)
- Implement disk spilling that respects the output order for partially and
fully sorted inputs.
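For orientation, here is a minimal sketch of the out-of-memory mode selection this PR introduces, condensed from the `row_hash.rs` hunks further down. The `choose_oom_mode` function and the simplified stand-in enums are illustrative only; in the actual change the decision is made inline in `GroupedHashAggregateStream::new` against DataFusion's real `AggregateMode` and `GroupOrdering` types, and `MemoryConsumer::can_spill` is then set to `oom_mode != OutOfMemoryMode::ReportError`.

```rust
/// Controls the behavior when an out-of-memory condition occurs
/// (mirrors the enum added in row_hash.rs).
#[derive(PartialEq, Debug)]
enum OutOfMemoryMode {
    /// Spill sorted intermediate state to disk and merge it back later
    Spill,
    /// Emit (possibly incomplete) groups early; only used for partial aggregation
    EmitEarly,
    /// Surface the ResourcesExhausted error to the caller immediately
    ReportError,
}

/// Simplified stand-ins for DataFusion's `AggregateMode` and `GroupOrdering`;
/// the real types have more variants and carry payloads.
#[derive(Clone, Copy)]
enum AggregateMode { Partial, FinalOrSingle }
#[derive(Clone, Copy)]
enum GroupOrdering { None, Partial, Full }

/// Illustrative only: condenses the inline `match` in `GroupedHashAggregateStream::new`.
fn choose_oom_mode(
    mode: AggregateMode,
    ordering: GroupOrdering,
    tmp_files_enabled: bool,
) -> OutOfMemoryMode {
    match (mode, ordering) {
        // Partial aggregation can always hand incomplete results downstream.
        (AggregateMode::Partial, _) => OutOfMemoryMode::EmitEarly,
        // Other modes spill when the input is not fully sorted and the
        // DiskManager allows temporary files.
        (_, GroupOrdering::None | GroupOrdering::Partial) if tmp_files_enabled => {
            OutOfMemoryMode::Spill
        }
        // Fully sorted input keeps the in-flight group state small; if memory
        // still runs out, spilling would not help, so the error is reported.
        _ => OutOfMemoryMode::ReportError,
    }
}

fn main() {
    assert_eq!(
        choose_oom_mode(AggregateMode::FinalOrSingle, GroupOrdering::None, true),
        OutOfMemoryMode::Spill
    );
    assert_eq!(
        choose_oom_mode(AggregateMode::FinalOrSingle, GroupOrdering::Full, true),
        OutOfMemoryMode::ReportError
    );
}
```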
## Are these changes tested?
Added an additional test case to demonstrate the problem.
Added a test case to check that the output order is respected after spilling.
## Are there any user-facing changes?
Yes, memory exhaustion may be reported much earlier in the query
pipeline than is currently the case. In my local tests with a per-consumer
memory limit of 32MiB, grouped aggregation would consume 480MiB
in practice. The exhaustion was then reported by ExternalSortExec, which
choked on trying to reserve that much memory.
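For anyone wanting to reproduce this locally, the new test builds its memory-constrained runtime roughly as in the hedged sketch below. The builder calls mirror the test code in this diff; the crate paths and the free-standing helper are my own assumptions.

```rust
use std::sync::Arc;

use datafusion_common::Result;
use datafusion_execution::TaskContext;
use datafusion_execution::config::SessionConfig;
use datafusion_execution::memory_pool::FairSpillPool;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;

/// Illustrative helper: a task context whose fair-spill pool caps all
/// consumers at `limit` bytes, as the new aggregation test does with 2000 bytes.
fn memory_limited_ctx(limit: usize, batch_size: usize) -> Result<Arc<TaskContext>> {
    let memory_pool = Arc::new(FairSpillPool::new(limit));
    let runtime = RuntimeEnvBuilder::new()
        .with_memory_pool(memory_pool)
        .build()?;
    Ok(Arc::new(
        TaskContext::default()
            .with_session_config(SessionConfig::new().with_batch_size(batch_size))
            .with_runtime(Arc::new(runtime)),
    ))
}
```

With a budget this small, the single-mode aggregation in the test either spills to disk (and the spill count metric is asserted) or fails fast with `ResourcesExhausted`, rather than overshooting the limit as before.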
---
datafusion/physical-expr-common/src/sort_expr.rs | 11 +
.../src/aggregates/group_values/mod.rs | 4 +-
.../aggregates/group_values/multi_group_by/mod.rs | 9 +-
.../src/aggregates/group_values/row.rs | 9 +-
.../group_values/single_group_by/boolean.rs | 3 +-
.../group_values/single_group_by/bytes.rs | 4 +-
.../group_values/single_group_by/bytes_view.rs | 4 +-
.../group_values/single_group_by/primitive.rs | 8 +-
datafusion/physical-plan/src/aggregates/mod.rs | 214 ++++++++++++-
.../physical-plan/src/aggregates/row_hash.rs | 353 +++++++++++----------
datafusion/physical-plan/src/test.rs | 1 +
11 files changed, 423 insertions(+), 197 deletions(-)
diff --git a/datafusion/physical-expr-common/src/sort_expr.rs
b/datafusion/physical-expr-common/src/sort_expr.rs
index e8558c7643..db30dd6ed2 100644
--- a/datafusion/physical-expr-common/src/sort_expr.rs
+++ b/datafusion/physical-expr-common/src/sort_expr.rs
@@ -457,6 +457,17 @@ impl LexOrdering {
req.expr.eq(&cur.expr) && is_reversed_sort_options(&req.options,
&cur.options)
})
}
+
+ /// Returns the sort options for the given expression if one is defined in
this `LexOrdering`.
+ pub fn get_sort_options(&self, expr: &dyn PhysicalExpr) ->
Option<SortOptions> {
+ for e in self {
+ if e.expr.as_ref().dyn_eq(expr) {
+ return Some(e.options);
+ }
+ }
+
+ None
+ }
}
/// Check if two SortOptions represent reversed orderings.
diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs
b/datafusion/physical-plan/src/aggregates/group_values/mod.rs
index f419328d11..2f3b1a19e7 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs
@@ -22,7 +22,7 @@ use arrow::array::types::{
Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType,
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
};
-use arrow::array::{ArrayRef, RecordBatch, downcast_primitive};
+use arrow::array::{ArrayRef, downcast_primitive};
use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use datafusion_common::Result;
@@ -112,7 +112,7 @@ pub trait GroupValues: Send {
fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;
/// Clear the contents and shrink the capacity to the size of the batch
(free up memory usage)
- fn clear_shrink(&mut self, batch: &RecordBatch);
+ fn clear_shrink(&mut self, num_rows: usize);
}
/// Return a specialized implementation of [`GroupValues`] for the given
schema.
diff --git
a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs
b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs
index b62bc11aff..4c9e376fc4 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs
@@ -30,7 +30,7 @@ use crate::aggregates::group_values::multi_group_by::{
bytes_view::ByteViewGroupValueBuilder,
primitive::PrimitiveGroupValueBuilder,
};
use ahash::RandomState;
-use arrow::array::{Array, ArrayRef, RecordBatch};
+use arrow::array::{Array, ArrayRef};
use arrow::compute::cast;
use arrow::datatypes::{
BinaryViewType, DataType, Date32Type, Date64Type, Decimal128Type,
Float32Type,
@@ -1181,14 +1181,13 @@ impl<const STREAMING: bool> GroupValues for
GroupValuesColumn<STREAMING> {
Ok(output)
}
- fn clear_shrink(&mut self, batch: &RecordBatch) {
- let count = batch.num_rows();
+ fn clear_shrink(&mut self, num_rows: usize) {
self.group_values.clear();
self.map.clear();
- self.map.shrink_to(count, |_| 0); // hasher does not matter since the
map is cleared
+ self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since
the map is cleared
self.map_size = self.map.capacity() * size_of::<(u64, usize)>();
self.hashes_buffer.clear();
- self.hashes_buffer.shrink_to(count);
+ self.hashes_buffer.shrink_to(num_rows);
// Such structures are only used in `non-streaming` case
if !STREAMING {
diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs
b/datafusion/physical-plan/src/aggregates/group_values/row.rs
index a5e5c16006..dd794c9573 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/row.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs
@@ -17,7 +17,7 @@
use crate::aggregates::group_values::GroupValues;
use ahash::RandomState;
-use arrow::array::{Array, ArrayRef, ListArray, RecordBatch, StructArray};
+use arrow::array::{Array, ArrayRef, ListArray, StructArray};
use arrow::compute::cast;
use arrow::datatypes::{DataType, SchemaRef};
use arrow::row::{RowConverter, Rows, SortField};
@@ -243,17 +243,16 @@ impl GroupValues for GroupValuesRows {
Ok(output)
}
- fn clear_shrink(&mut self, batch: &RecordBatch) {
- let count = batch.num_rows();
+ fn clear_shrink(&mut self, num_rows: usize) {
self.group_values = self.group_values.take().map(|mut rows| {
rows.clear();
rows
});
self.map.clear();
- self.map.shrink_to(count, |_| 0); // hasher does not matter since the
map is cleared
+ self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since
the map is cleared
self.map_size = self.map.capacity() * size_of::<(u64, usize)>();
self.hashes_buffer.clear();
- self.hashes_buffer.shrink_to(count);
+ self.hashes_buffer.shrink_to(num_rows);
}
}
diff --git
a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs
b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs
index 44b763a91f..e993c0c53d 100644
---
a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs
+++
b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs
@@ -19,7 +19,6 @@ use crate::aggregates::group_values::GroupValues;
use arrow::array::{
ArrayRef, AsArray as _, BooleanArray, BooleanBufferBuilder,
NullBufferBuilder,
- RecordBatch,
};
use datafusion_common::Result;
use datafusion_expr::EmitTo;
@@ -146,7 +145,7 @@ impl GroupValues for GroupValuesBoolean {
Ok(vec![Arc::new(BooleanArray::new(values, nulls)) as _])
}
- fn clear_shrink(&mut self, _batch: &RecordBatch) {
+ fn clear_shrink(&mut self, _num_rows: usize) {
self.false_group = None;
self.true_group = None;
self.null_group = None;
diff --git
a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs
b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs
index b901aee313..b881a51b25 100644
---
a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs
+++
b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs
@@ -19,7 +19,7 @@ use std::mem::size_of;
use crate::aggregates::group_values::GroupValues;
-use arrow::array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch};
+use arrow::array::{Array, ArrayRef, OffsetSizeTrait};
use datafusion_common::Result;
use datafusion_expr::EmitTo;
use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType};
@@ -120,7 +120,7 @@ impl<O: OffsetSizeTrait> GroupValues for
GroupValuesBytes<O> {
Ok(vec![group_values])
}
- fn clear_shrink(&mut self, _batch: &RecordBatch) {
+ fn clear_shrink(&mut self, _num_rows: usize) {
// in theory we could potentially avoid this reallocation and clear the
// contents of the maps, but for now we just reset the map from the
beginning
self.map.take();
diff --git
a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs
b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs
index be9a0334e3..7a56f7c52c 100644
---
a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs
+++
b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs
@@ -16,7 +16,7 @@
// under the License.
use crate::aggregates::group_values::GroupValues;
-use arrow::array::{Array, ArrayRef, RecordBatch};
+use arrow::array::{Array, ArrayRef};
use datafusion_expr::EmitTo;
use datafusion_physical_expr::binary_map::OutputType;
use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap;
@@ -122,7 +122,7 @@ impl GroupValues for GroupValuesBytesView {
Ok(vec![group_values])
}
- fn clear_shrink(&mut self, _batch: &RecordBatch) {
+ fn clear_shrink(&mut self, _num_rows: usize) {
// in theory we could potentially avoid this reallocation and clear the
// contents of the maps, but for now we just reset the map from the
beginning
self.map.take();
diff --git
a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs
b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs
index 41d34218f6..c46cde8786 100644
---
a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs
+++
b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs
@@ -23,7 +23,6 @@ use arrow::array::{
cast::AsArray,
};
use arrow::datatypes::{DataType, i256};
-use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_execution::memory_pool::proxy::VecAllocExt;
use datafusion_expr::EmitTo;
@@ -213,11 +212,10 @@ where
Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))])
}
- fn clear_shrink(&mut self, batch: &RecordBatch) {
- let count = batch.num_rows();
+ fn clear_shrink(&mut self, num_rows: usize) {
self.values.clear();
- self.values.shrink_to(count);
+ self.values.shrink_to(num_rows);
self.map.clear();
- self.map.shrink_to(count, |_| 0); // hasher does not matter since the
map is cleared
+ self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since
the map is cleared
}
}
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs
b/datafusion/physical-plan/src/aggregates/mod.rs
index b0d432a9de..06f12a9019 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -31,7 +31,6 @@ use crate::filter_pushdown::{
FilterPushdownPropagation, PushedDownPredicate,
};
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
-use crate::windows::get_ordered_partition_by_indices;
use crate::{
DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode,
SendableRecordBatchStream, Statistics,
@@ -613,12 +612,13 @@ impl AggregateExec {
// If existing ordering satisfies a prefix of the GROUP BY expressions,
// prefix requirements with this section. In this case, aggregation
will
// work more efficiently.
- let indices = get_ordered_partition_by_indices(&groupby_exprs,
&input)?;
- let mut new_requirements = indices
- .iter()
- .map(|&idx| {
- PhysicalSortRequirement::new(Arc::clone(&groupby_exprs[idx]),
None)
- })
+ // Copy the `PhysicalSortExpr`s to retain the sort options.
+ let (new_sort_exprs, indices) =
+ input_eq_properties.find_longest_permutation(&groupby_exprs)?;
+
+ let mut new_requirements = new_sort_exprs
+ .into_iter()
+ .map(PhysicalSortRequirement::from)
.collect::<Vec<_>>();
let req = get_finer_aggregate_exprs_requirement(
@@ -1815,7 +1815,7 @@ mod tests {
use crate::test::exec::{BlockingExec,
assert_strong_count_converges_to_zero};
use arrow::array::{
- DictionaryArray, Float32Array, Float64Array, Int32Array, StructArray,
+ DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Array,
StructArray,
UInt32Array, UInt64Array,
};
use arrow::compute::{SortOptions, concat_batches};
@@ -1837,6 +1837,8 @@ mod tests {
use datafusion_physical_expr::expressions::Literal;
use datafusion_physical_expr::expressions::lit;
+ use crate::projection::ProjectionExec;
+ use datafusion_physical_expr::projection::ProjectionExpr;
use futures::{FutureExt, Stream};
use insta::{allow_duplicates, assert_snapshot};
@@ -2484,7 +2486,7 @@ mod tests {
] {
let n_aggr = aggregates.len();
let partial_aggregate = Arc::new(AggregateExec::try_new(
- AggregateMode::Partial,
+ AggregateMode::Single,
groups,
aggregates,
vec![None; n_aggr],
@@ -3420,6 +3422,117 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn test_grouped_aggregation_respects_memory_limit() -> Result<()> {
+ // test with spill
+ fn create_record_batch(
+ schema: &Arc<Schema>,
+ data: (Vec<u32>, Vec<f64>),
+ ) -> Result<RecordBatch> {
+ Ok(RecordBatch::try_new(
+ Arc::clone(schema),
+ vec![
+ Arc::new(UInt32Array::from(data.0)),
+ Arc::new(Float64Array::from(data.1)),
+ ],
+ )?)
+ }
+
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::UInt32, false),
+ Field::new("b", DataType::Float64, false),
+ ]));
+
+ let batches = vec![
+ create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0,
3.0, 4.0]))?,
+ create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0,
3.0, 4.0]))?,
+ ];
+ let plan: Arc<dyn ExecutionPlan> =
+ TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema),
None)?;
+ let proj = ProjectionExec::try_new(
+ vec![
+ ProjectionExpr::new(lit("0"), "l".to_string()),
+ ProjectionExpr::new_from_expression(col("a", &schema)?,
&schema)?,
+ ProjectionExpr::new_from_expression(col("b", &schema)?,
&schema)?,
+ ],
+ plan,
+ )?;
+ let plan: Arc<dyn ExecutionPlan> = Arc::new(proj);
+ let schema = plan.schema();
+
+ let grouping_set = PhysicalGroupBy::new(
+ vec![
+ (col("l", &schema)?, "l".to_string()),
+ (col("a", &schema)?, "a".to_string()),
+ ],
+ vec![],
+ vec![vec![false, false]],
+ false,
+ );
+
+ // Test with MIN for simple intermediate state (min) and AVG for
multiple intermediate states (partial sum, partial count).
+ let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
+ Arc::new(
+ AggregateExprBuilder::new(
+ datafusion_functions_aggregate::min_max::min_udaf(),
+ vec![col("b", &schema)?],
+ )
+ .schema(Arc::clone(&schema))
+ .alias("MIN(b)")
+ .build()?,
+ ),
+ Arc::new(
+ AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
+ .schema(Arc::clone(&schema))
+ .alias("AVG(b)")
+ .build()?,
+ ),
+ ];
+
+ let single_aggregate = Arc::new(AggregateExec::try_new(
+ AggregateMode::Single,
+ grouping_set,
+ aggregates,
+ vec![None, None],
+ plan,
+ Arc::clone(&schema),
+ )?);
+
+ let batch_size = 2;
+ let memory_pool = Arc::new(FairSpillPool::new(2000));
+ let task_ctx = Arc::new(
+ TaskContext::default()
+
.with_session_config(SessionConfig::new().with_batch_size(batch_size))
+ .with_runtime(Arc::new(
+ RuntimeEnvBuilder::new()
+ .with_memory_pool(memory_pool)
+ .build()?,
+ )),
+ );
+
+ let result = collect(single_aggregate.execute(0,
Arc::clone(&task_ctx))?).await;
+ match result {
+ Ok(result) => {
+ assert_spill_count_metric(true, single_aggregate);
+
+ allow_duplicates! {
+ assert_snapshot!(batches_to_string(&result), @r"
+ +---+---+--------+--------+
+ | l | a | MIN(b) | AVG(b) |
+ +---+---+--------+--------+
+ | 0 | 2 | 1.0 | 1.0 |
+ | 0 | 3 | 2.0 | 2.0 |
+ | 0 | 4 | 3.0 | 3.5 |
+ +---+---+--------+--------+
+ ");
+ }
+ }
+ Err(e) => assert!(matches!(e,
DataFusionError::ResourcesExhausted(_))),
+ }
+
+ Ok(())
+ }
+
#[tokio::test]
async fn test_aggregate_statistics_edge_cases() -> Result<()> {
use crate::test::exec::StatisticsExec;
@@ -3492,4 +3605,87 @@ mod tests {
Ok(())
}
+
+ #[tokio::test]
+ async fn test_order_is_retained_when_spilling() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Int64, false),
+ Field::new("b", DataType::Int64, false),
+ Field::new("c", DataType::Int64, false),
+ ]));
+
+ let batches = vec![vec![
+ RecordBatch::try_new(
+ Arc::clone(&schema),
+ vec![
+ Arc::new(Int64Array::from(vec![2])),
+ Arc::new(Int64Array::from(vec![2])),
+ Arc::new(Int64Array::from(vec![1])),
+ ],
+ )?,
+ RecordBatch::try_new(
+ Arc::clone(&schema),
+ vec![
+ Arc::new(Int64Array::from(vec![1])),
+ Arc::new(Int64Array::from(vec![1])),
+ Arc::new(Int64Array::from(vec![1])),
+ ],
+ )?,
+ RecordBatch::try_new(
+ Arc::clone(&schema),
+ vec![
+ Arc::new(Int64Array::from(vec![0])),
+ Arc::new(Int64Array::from(vec![0])),
+ Arc::new(Int64Array::from(vec![1])),
+ ],
+ )?,
+ ]];
+ let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema),
None)?;
+ let scan = scan.try_with_sort_information(vec![
+ LexOrdering::new([PhysicalSortExpr::new(
+ col("b", schema.as_ref())?,
+ SortOptions::default().desc(),
+ )])
+ .unwrap(),
+ ])?;
+
+ let aggr = Arc::new(AggregateExec::try_new(
+ AggregateMode::Single,
+ PhysicalGroupBy::new(
+ vec![
+ (col("b", schema.as_ref())?, "b".to_string()),
+ (col("c", schema.as_ref())?, "c".to_string()),
+ ],
+ vec![],
+ vec![vec![false, false]],
+ false,
+ ),
+ vec![Arc::new(
+ AggregateExprBuilder::new(sum_udaf(), vec![col("c",
schema.as_ref())?])
+ .schema(Arc::clone(&schema))
+ .alias("SUM(c)")
+ .build()?,
+ )],
+ vec![None],
+ Arc::new(scan) as Arc<dyn ExecutionPlan>,
+ Arc::clone(&schema),
+ )?);
+
+ let task_ctx = new_spill_ctx(1, 600);
+ let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await?;
+ assert_spill_count_metric(true, aggr);
+
+ allow_duplicates! {
+ assert_snapshot!(batches_to_string(&result), @r"
+ +---+---+--------+
+ | b | c | SUM(c) |
+ +---+---+--------+
+ | 2 | 1 | 1 |
+ | 1 | 1 | 1 |
+ | 0 | 1 | 1 |
+ +---+---+--------+
+ ");
+ }
+ Ok(())
+ }
}
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs
b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 1e7757de4a..cb22fbf9a0 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -33,7 +33,6 @@ use crate::metrics::{BaselineMetrics, MetricBuilder,
RecordOutput};
use crate::sorts::sort::sort_batch;
use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
use crate::spill::spill_manager::SpillManager;
-use crate::stream::RecordBatchStreamAdapter;
use crate::{PhysicalExpr, aggregates, metrics};
use crate::{RecordBatchStream, SendableRecordBatchStream};
@@ -210,6 +209,17 @@ impl SkipAggregationProbe {
}
}
+/// Controls the behavior when an out-of-memory condition occurs.
+#[derive(PartialEq, Debug)]
+enum OutOfMemoryMode {
+ /// When out of memory occurs, spill state to disk
+ Spill,
+ /// When out of memory occurs, attempt to emit group values early
+ EmitEarly,
+ /// When out of memory occurs, immediately report the error
+ ReportError,
+}
+
/// HashTable based Grouping Aggregator
///
/// # Design Goals
@@ -433,6 +443,9 @@ pub(crate) struct GroupedHashAggregateStream {
/// The memory reservation for this grouping
reservation: MemoryReservation,
+ /// The behavior to trigger when out of memory occurs
+ oom_mode: OutOfMemoryMode,
+
/// Execution metrics
baseline_metrics: BaselineMetrics,
@@ -510,12 +523,12 @@ impl GroupedHashAggregateStream {
// Therefore, when we spill these intermediate states or pass them to
another
// aggregation operator, we must use a schema that includes both the
group
// columns **and** the partial-state columns.
- let partial_agg_schema = create_schema(
+ let spill_schema = Arc::new(create_schema(
&agg.input().schema(),
&agg_group_by,
&aggregate_exprs,
AggregateMode::Partial,
- )?;
+ )?);
// Need to update the GROUP BY expressions to point to the correct
column after schema change
let merging_group_by_expr = agg_group_by
@@ -527,20 +540,25 @@ impl GroupedHashAggregateStream {
})
.collect();
- let partial_agg_schema = Arc::new(partial_agg_schema);
+ let output_ordering = agg.cache.output_ordering();
- let spill_expr =
+ let spill_sort_exprs =
group_schema
.fields
.into_iter()
.enumerate()
.map(|(idx, field)| {
- PhysicalSortExpr::new_default(Arc::new(Column::new(
- field.name().as_str(),
- idx,
- )) as _)
+ let output_expr = Column::new(field.name().as_str(), idx);
+
+ // Try to use the sort options from the output ordering,
if available.
+ // This ensures that spilled state is sorted in the
required order as well.
+ let sort_options = output_ordering
+ .and_then(|o| o.get_sort_options(&output_expr))
+ .unwrap_or_default();
+
+ PhysicalSortExpr::new(Arc::new(output_expr), sort_options)
});
- let Some(spill_expr) = LexOrdering::new(spill_expr) else {
+ let Some(spill_ordering) = LexOrdering::new(spill_sort_exprs) else {
return internal_err!("Spill expression is empty");
};
@@ -550,11 +568,35 @@ impl GroupedHashAggregateStream {
.collect::<Vec<_>>()
.join(", ");
let name = format!("GroupedHashAggregateStream[{partition}]
({agg_fn_names})");
- let reservation = MemoryConsumer::new(name)
- .with_can_spill(true)
- .register(context.memory_pool());
let group_ordering = GroupOrdering::try_new(&agg.input_order_mode)?;
+ let oom_mode = match (agg.mode, &group_ordering) {
+ // In partial aggregation mode, always prefer to emit incomplete
results early.
+ (AggregateMode::Partial, _) => OutOfMemoryMode::EmitEarly,
+ // For non-partial aggregation modes, emitting incomplete results
is not an option.
+ // Instead, use disk spilling to store sorted, incomplete results,
and merge them
+ // afterwards.
+ (_, GroupOrdering::None | GroupOrdering::Partial(_))
+ if context.runtime_env().disk_manager.tmp_files_enabled() =>
+ {
+ OutOfMemoryMode::Spill
+ }
+ // For `GroupOrdering::Full`, the incoming stream is already
sorted. This ensures the
+ // number of incomplete groups can be kept small at all times. If
we still hit
+ // an out-of-memory condition, spilling to disk would not be
beneficial since the same
+ // situation is likely to reoccur when reading back the spilled
data.
+ // Therefore, we fall back to simply reporting the error
immediately.
+ // This mode will also be used if the `DiskManager` is not
configured to allow spilling
+ // to disk.
+ _ => OutOfMemoryMode::ReportError,
+ };
+
let group_values = new_group_values(group_schema, &group_ordering)?;
+ let reservation = MemoryConsumer::new(name)
+ // We interpret 'can spill' as 'can handle memory back pressure'.
+ // This value needs to be set to true for the default memory pool
implementations
+ // to ensure fair application of back pressure amongst the memory
consumers.
+ .with_can_spill(oom_mode != OutOfMemoryMode::ReportError)
+ .register(context.memory_pool());
timer.done();
let exec_state = ExecutionState::ReadingInput;
@@ -562,14 +604,14 @@ impl GroupedHashAggregateStream {
let spill_manager = SpillManager::new(
context.runtime_env(),
metrics::SpillMetrics::new(&agg.metrics, partition),
- Arc::clone(&partial_agg_schema),
+ Arc::clone(&spill_schema),
)
.with_compression_type(context.session_config().spill_compression());
let spill_state = SpillState {
spills: vec![],
- spill_expr,
- spill_schema: partial_agg_schema,
+ spill_expr: spill_ordering,
+ spill_schema,
is_stream_merging: false,
merging_aggregate_arguments,
merging_group_by:
PhysicalGroupBy::new_single(merging_group_by_expr),
@@ -627,6 +669,7 @@ impl GroupedHashAggregateStream {
filter_expressions,
group_by: agg_group_by,
reservation,
+ oom_mode,
group_values,
current_group_indices: Default::default(),
exec_state,
@@ -676,21 +719,24 @@ impl Stream for GroupedHashAggregateStream {
match &self.exec_state {
ExecutionState::ReadingInput => 'reading_input: {
match ready!(self.input.poll_next_unpin(cx)) {
- // New batch to aggregate in partial aggregation
operator
- Some(Ok(batch)) if self.mode == AggregateMode::Partial
=> {
+ // New batch to aggregate
+ Some(Ok(batch)) => {
let timer = elapsed_compute.timer();
let input_rows = batch.num_rows();
- if let Some(reduction_factor) =
self.reduction_factor.as_ref()
+ if self.mode == AggregateMode::Partial
+ && let Some(reduction_factor) =
+ self.reduction_factor.as_ref()
{
reduction_factor.add_total(input_rows);
}
- // Do the grouping
+ // Do the grouping.
+ // `group_aggregate_batch` will _not_ have updated
the memory reservation yet.
+ // The rest of the code will first try to reduce
memory usage by
+ // already emitting results.
self.group_aggregate_batch(&batch)?;
- // If we can begin emitting rows, do so,
- // otherwise keep consuming input
assert!(!self.input_done);
// If the number of group values equals or exceeds
the soft limit,
@@ -702,7 +748,13 @@ impl Stream for GroupedHashAggregateStream {
break 'reading_input;
}
- if let Some(to_emit) =
self.group_ordering.emit_to() {
+ // Try to emit completed groups if possible.
+ // If we already started spilling, we can no
longer emit since
+ // this might lead to incorrect output ordering
+ if (self.spill_state.spills.is_empty()
+ || self.spill_state.is_stream_merging)
+ && let Some(to_emit) =
self.group_ordering.emit_to()
+ {
timer.done();
if let Some(batch) = self.emit(to_emit,
false)? {
self.exec_state =
@@ -712,18 +764,28 @@ impl Stream for GroupedHashAggregateStream {
break 'reading_input;
}
- // Check if we should switch to skip aggregation
mode
- // It's important that we do this before we early
emit since we've
- // already updated the probe.
- self.update_skip_aggregation_probe(input_rows);
- if let Some(new_state) =
self.switch_to_skip_aggregation()? {
- timer.done();
- self.exec_state = new_state;
- break 'reading_input;
+ if self.mode == AggregateMode::Partial {
+ // Spilling should never be activated in
partial aggregation mode.
+ assert!(!self.spill_state.is_stream_merging);
+
+ // Check if we should switch to skip
aggregation mode
+ // It's important that we do this before we
early emit since we've
+ // already updated the probe.
+ self.update_skip_aggregation_probe(input_rows);
+ if let Some(new_state) =
+ self.switch_to_skip_aggregation()?
+ {
+ timer.done();
+ self.exec_state = new_state;
+ break 'reading_input;
+ }
}
- // Check if we need to emit early due to memory
pressure
- if let Some(new_state) =
self.emit_early_if_necessary()? {
+ // If we reach this point, try to update the
memory reservation
+ // handling out-of-memory conditions as determined
by the OOM mode.
+ if let Some(new_state) =
+ self.try_update_memory_reservation()?
+ {
timer.done();
self.exec_state = new_state;
break 'reading_input;
@@ -732,43 +794,6 @@ impl Stream for GroupedHashAggregateStream {
timer.done();
}
- // New batch to aggregate in terminal aggregation
operator
- // (Final/FinalPartitioned/Single/SinglePartitioned)
- Some(Ok(batch)) => {
- let timer = elapsed_compute.timer();
-
- // Make sure we have enough capacity for `batch`,
otherwise spill
- self.spill_previous_if_necessary(&batch)?;
-
- // Do the grouping
- self.group_aggregate_batch(&batch)?;
-
- // If we can begin emitting rows, do so,
- // otherwise keep consuming input
- assert!(!self.input_done);
-
- // If the number of group values equals or exceeds
the soft limit,
- // emit all groups and switch to producing output
- if self.hit_soft_group_limit() {
- timer.done();
- self.set_input_done_and_produce_output()?;
- // make sure the exec_state just set is not
overwritten below
- break 'reading_input;
- }
-
- if let Some(to_emit) =
self.group_ordering.emit_to() {
- timer.done();
- if let Some(batch) = self.emit(to_emit,
false)? {
- self.exec_state =
- ExecutionState::ProducingOutput(batch);
- };
- // make sure the exec_state just set is not
overwritten below
- break 'reading_input;
- }
-
- timer.done();
- }
-
// Found error from input stream
Some(Err(e)) => {
// inner had error, return to caller
@@ -987,25 +1012,56 @@ impl GroupedHashAggregateStream {
}
}
- match self.update_memory_reservation() {
- // Here we can ignore `insufficient_capacity_err` because we will
spill later,
- // but at least one batch should fit in the memory
- Err(DataFusionError::ResourcesExhausted(_))
- if self.group_values.len() >= self.batch_size =>
- {
- Ok(())
+ Ok(())
+ }
+
+ /// Attempts to update the memory reservation. If that fails due to a
+ /// [DataFusionError::ResourcesExhausted] error, an attempt will be made
to resolve
+ /// the out-of-memory condition based on the [out-of-memory handling
mode](OutOfMemoryMode).
+ ///
+ /// If the out-of-memory condition can not be resolved, an `Err` value
will be returned
+ ///
+ /// Returns `Ok(Some(ExecutionState))` if the state should be changed,
`Ok(None)` otherwise.
+ fn try_update_memory_reservation(&mut self) ->
Result<Option<ExecutionState>> {
+ let oom = match self.update_memory_reservation() {
+ Err(e @ DataFusionError::ResourcesExhausted(_)) => e,
+ Err(e) => return Err(e),
+ Ok(_) => return Ok(None),
+ };
+
+ match self.oom_mode {
+ OutOfMemoryMode::Spill if !self.group_values.is_empty() => {
+ self.spill()?;
+ self.clear_shrink(self.batch_size);
+ self.update_memory_reservation()?;
+ Ok(None)
}
- other => other,
+ OutOfMemoryMode::EmitEarly if self.group_values.len() > 1 => {
+ let n = if self.group_values.len() >= self.batch_size {
+ // Try to emit an integer multiple of batch size if
possible
+ self.group_values.len() / self.batch_size * self.batch_size
+ } else {
+ // Otherwise emit whatever we can
+ self.group_values.len()
+ };
+
+ if let Some(batch) = self.emit(EmitTo::First(n), false)? {
+ Ok(Some(ExecutionState::ProducingOutput(batch)))
+ } else {
+ Err(oom)
+ }
+ }
+ _ => Err(oom),
}
}
fn update_memory_reservation(&mut self) -> Result<()> {
let acc = self.accumulators.iter().map(|x| x.size()).sum::<usize>();
- let reservation_result = self.reservation.try_resize(
- acc + self.group_values.size()
- + self.group_ordering.size()
- + self.current_group_indices.allocated_size(),
- );
+ let new_size = acc
+ + self.group_values.size()
+ + self.group_ordering.size()
+ + self.current_group_indices.allocated_size();
+ let reservation_result = self.reservation.try_resize(new_size);
if reservation_result.is_ok() {
self.spill_state
@@ -1060,24 +1116,6 @@ impl GroupedHashAggregateStream {
Ok(Some(batch))
}
- /// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the
memory target slightly
- /// (~ 1 [`RecordBatch`]) for simplicity. In such cases, spill the data to
disk and clear the
- /// memory. Currently only [`GroupOrdering::None`] is supported for
spilling.
- fn spill_previous_if_necessary(&mut self, batch: &RecordBatch) ->
Result<()> {
- // TODO: support group_ordering for spilling
- if !self.group_values.is_empty()
- && batch.num_rows() > 0
- && matches!(self.group_ordering, GroupOrdering::None)
- && !self.spill_state.is_stream_merging
- && self.update_memory_reservation().is_err()
- {
- assert_ne!(self.mode, AggregateMode::Partial);
- self.spill()?;
- self.clear_shrink(batch);
- }
- Ok(())
- }
-
/// Emit all intermediate aggregation states, sort them, and store them on
disk.
/// This process helps in reducing memory pressure by allowing the data to
be
/// read back with streaming merge.
@@ -1115,72 +1153,15 @@ impl GroupedHashAggregateStream {
}
/// Clear memory and shirk capacities to the size of the batch.
- fn clear_shrink(&mut self, batch: &RecordBatch) {
- self.group_values.clear_shrink(batch);
+ fn clear_shrink(&mut self, num_rows: usize) {
+ self.group_values.clear_shrink(num_rows);
self.current_group_indices.clear();
- self.current_group_indices.shrink_to(batch.num_rows());
+ self.current_group_indices.shrink_to(num_rows);
}
/// Clear memory and shirk capacities to zero.
fn clear_all(&mut self) {
- let s = self.schema();
- self.clear_shrink(&RecordBatch::new_empty(s));
- }
-
- /// Emit if the used memory exceeds the target for partial aggregation.
- /// Currently only [`GroupOrdering::None`] is supported for early emitting.
- /// TODO: support group_ordering for early emitting
- ///
- /// Returns `Some(ExecutionState)` if the state should be changed, None
otherwise.
- fn emit_early_if_necessary(&mut self) -> Result<Option<ExecutionState>> {
- if self.group_values.len() >= self.batch_size
- && matches!(self.group_ordering, GroupOrdering::None)
- && self.update_memory_reservation().is_err()
- {
- assert_eq!(self.mode, AggregateMode::Partial);
- let n = self.group_values.len() / self.batch_size *
self.batch_size;
- if let Some(batch) = self.emit(EmitTo::First(n), false)? {
- return Ok(Some(ExecutionState::ProducingOutput(batch)));
- };
- }
- Ok(None)
- }
-
- /// At this point, all the inputs are read and there are some spills.
- /// Emit the remaining rows and create a batch.
- /// Conduct a streaming merge sort between the batch and spilled data.
Since the stream is fully
- /// sorted, set `self.group_ordering` to Full, then later we can read with
[`EmitTo::First`].
- fn update_merged_stream(&mut self) -> Result<()> {
- let Some(batch) = self.emit(EmitTo::All, true)? else {
- return Ok(());
- };
- // clear up memory for streaming_merge
- self.clear_all();
- self.update_memory_reservation()?;
- let mut streams: Vec<SendableRecordBatchStream> = vec![];
- let expr = self.spill_state.spill_expr.clone();
- let schema = batch.schema();
- streams.push(Box::pin(RecordBatchStreamAdapter::new(
- Arc::clone(&schema),
- futures::stream::once(futures::future::lazy(move |_| {
- sort_batch(&batch, &expr, None)
- })),
- )));
-
- self.spill_state.is_stream_merging = true;
- self.input = StreamingMergeBuilder::new()
- .with_streams(streams)
- .with_schema(schema)
- .with_spill_manager(self.spill_state.spill_manager.clone())
- .with_sorted_spill_files(std::mem::take(&mut
self.spill_state.spills))
- .with_expressions(&self.spill_state.spill_expr)
- .with_metrics(self.baseline_metrics.clone())
- .with_batch_size(self.batch_size)
- .with_reservation(self.reservation.new_empty())
- .build()?;
- self.input_done = false;
- self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new());
- Ok(())
+ self.clear_shrink(0);
}
/// returns true if there is a soft groups limit and the number of distinct
@@ -1192,18 +1173,60 @@ impl GroupedHashAggregateStream {
group_values_soft_limit <= self.group_values.len()
}
- /// common function for signalling end of processing of the input stream
+ /// Finalizes reading of the input stream and prepares for producing
output values.
+ ///
+ /// This method is called both when the original input stream and,
+ /// in case of disk spilling, the SPM stream have been drained.
fn set_input_done_and_produce_output(&mut self) -> Result<()> {
self.input_done = true;
self.group_ordering.input_done();
let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
let timer = elapsed_compute.timer();
self.exec_state = if self.spill_state.spills.is_empty() {
+ // Input has been entirely processed without spilling to disk.
+
+ // Flush any remaining group values.
let batch = self.emit(EmitTo::All, false)?;
+
+ // If there are none, we're done; otherwise switch to emitting them
batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput)
} else {
- // If spill files exist, stream-merge them.
- self.update_merged_stream()?;
+ // Spill any remaining data to disk. There is some performance
overhead in
+ // writing out this last chunk of data and reading it back. The
benefit of
+ // doing this is that memory usage for this stream is reduced, and
the more
+ // sophisticated memory handling in `MultiLevelMergeBuilder` can
take over
+ // instead.
+ // Spilling to disk and reading back also ensures batch size is
consistent
+ // rather than potentially having one significantly larger last
batch.
+ self.spill()?;
+
+ // Mark that we're switching to stream merging mode.
+ self.spill_state.is_stream_merging = true;
+
+ self.input = StreamingMergeBuilder::new()
+ .with_schema(Arc::clone(&self.spill_state.spill_schema))
+ .with_spill_manager(self.spill_state.spill_manager.clone())
+ .with_sorted_spill_files(std::mem::take(&mut
self.spill_state.spills))
+ .with_expressions(&self.spill_state.spill_expr)
+ .with_metrics(self.baseline_metrics.clone())
+ .with_batch_size(self.batch_size)
+ .with_reservation(self.reservation.new_empty())
+ .build()?;
+ self.input_done = false;
+
+ // Reset the group values collectors.
+ self.clear_all();
+
+ // We can now use `GroupOrdering::Full` since the spill files are
sorted
+ // on the grouping columns.
+ self.group_ordering =
GroupOrdering::Full(GroupOrderingFull::new());
+
+ // Use `OutOfMemoryMode::ReportError` from this point on
+ // to ensure we don't spill the spilled data to disk again.
+ self.oom_mode = OutOfMemoryMode::ReportError;
+
+ self.update_memory_reservation()?;
+
ExecutionState::ReadingInput
};
timer.done();
diff --git a/datafusion/physical-plan/src/test.rs
b/datafusion/physical-plan/src/test.rs
index f2336920b3..c94b5a4131 100644
--- a/datafusion/physical-plan/src/test.rs
+++ b/datafusion/physical-plan/src/test.rs
@@ -342,6 +342,7 @@ impl TestMemoryExec {
}
self.sort_information = sort_information;
+ self.cache = self.compute_properties();
Ok(self)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]