This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 3f0b3425cb feat: Improve sort memory resilience (#19494)
3f0b3425cb is described below
commit 3f0b3425cbe0af8a48db73b8b0fb9a647011171a
Author: Emily Matheys <[email protected]>
AuthorDate: Mon Dec 29 19:49:01 2025 +0200
feat: Improve sort memory resilience (#19494)
## Which issue does this PR close?
<!--
We generally require a GitHub issue to be filed for all bug fixes and
enhancements and this helps us generate change logs for our releases.
You can link an issue to this PR using the GitHub syntax. For example
`Closes #123` indicates that this PR will close issue #123.
-->
Closes #19493.
## Rationale for this change
<!--
Why are you proposing this change? If this is already explained clearly
in the issue then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand
your changes and offer better suggestions for fixes.
-->
Greatly reduces the memory requested by ExternalSorter to perform sorts,
makes the memory reservations much more granular, and tries to do this
with minimal overhead by merging the splitting and sorting steps into a
single pass.
## What changes are included in this PR?
<!--
There is no need to duplicate the description in the issue here but it
is sometimes worth providing a summary of the individual changes in this
PR.
-->
The sort stream computes the sort indices once, but performs the `take`
in chunks, producing RecordBatches of at most `batch_size` rows. For these
batches, `get_record_batch_memory_size` reports a value very close to (if
not identical to) their sliced size. There is therefore no longer any need
to conservatively reserve a large amount of memory up front for the merge
sort. As a result, more streams can be merged at the same time under the
same memory limit.
## Are these changes tested?
<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code
If tests are not included in your PR, please explain why (for example,
are they covered by existing tests)?
-->
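Yes. New unit tests cover `sort_batch_chunked` (basic, smaller than
`batch_size`, exact multiple, and empty input), `ReservationStream`
(shrinking on poll and behaviour on error), per-batch memory release in
`SortExec`, and reservation sizing for sliced batches.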
## Are there any user-facing changes?
<!--
If there are user-facing changes then we may require documentation to be
updated before approving the PR.
-->
<!--
If there are any breaking changes to public APIs, please add the `api
change` label.
-->
There is a new public `sort_batch_chunked` function, which sorts a batch
and returns the result as a `Vec<RecordBatch>`, each containing at most
the provided `batch_size` rows.
Some documentation comments have also been updated.
---------
Co-authored-by: Raz Luvaton <[email protected]>
---
.../physical-plan/src/aggregates/row_hash.rs | 2 +-
.../physical-plan/src/sorts/multi_level_merge.rs | 12 +-
datafusion/physical-plan/src/sorts/sort.rs | 465 +++++++++++++++++----
.../physical-plan/src/spill/spill_manager.rs | 5 +-
datafusion/physical-plan/src/stream.rs | 187 ++++++++-
5 files changed, 586 insertions(+), 85 deletions(-)
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs
b/datafusion/physical-plan/src/aggregates/row_hash.rs
index cb22fbf9a0..1ae7202711 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -1198,7 +1198,7 @@ impl GroupedHashAggregateStream {
// instead.
// Spilling to disk and reading back also ensures batch size is
consistent
// rather than potentially having one significantly larger last
batch.
- self.spill()?;
+ self.spill()?; // TODO: use sort_batch_chunked instead?
// Mark that we're switching to stream merging mode.
self.spill_state.is_stream_merging = true;
diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs
b/datafusion/physical-plan/src/sorts/multi_level_merge.rs
index 3540f1de3e..2e0d668a29 100644
--- a/datafusion/physical-plan/src/sorts/multi_level_merge.rs
+++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs
@@ -30,7 +30,7 @@ use arrow::datatypes::SchemaRef;
use datafusion_common::Result;
use datafusion_execution::memory_pool::MemoryReservation;
-use crate::sorts::sort::get_reserved_byte_for_record_batch_size;
+use crate::sorts::sort::get_reserved_bytes_for_record_batch_size;
use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
use crate::stream::RecordBatchStreamAdapter;
use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream};
@@ -360,9 +360,13 @@ impl MultiLevelMergeBuilder {
for spill in &self.sorted_spill_files {
// For memory pools that are not shared this is good, for other
this is not
// and there should be some upper limit to memory reservation so
we won't starve the system
- match reservation.try_grow(get_reserved_byte_for_record_batch_size(
- spill.max_record_batch_memory * buffer_len,
- )) {
+ match reservation.try_grow(
+ get_reserved_bytes_for_record_batch_size(
+ spill.max_record_batch_memory,
+ // Size will be the same as the sliced size, bc it is a
spilled batch.
+ spill.max_record_batch_memory,
+ ) * buffer_len,
+ ) {
Ok(_) => {
number_of_spills_to_read_for_current_phase += 1;
}
diff --git a/datafusion/physical-plan/src/sorts/sort.rs
b/datafusion/physical-plan/src/sorts/sort.rs
index 18cdcbe9de..3e8fdf1f3e 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -34,16 +34,15 @@ use crate::filter_pushdown::{
};
use crate::limit::LimitStream;
use crate::metrics::{
- BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
SpillMetrics,
- SplitMetrics,
+ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, SpillMetrics,
};
use crate::projection::{ProjectionExec, make_with_child, update_ordering};
use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
use crate::spill::get_record_batch_memory_size;
use crate::spill::in_progress_spill_file::InProgressSpillFile;
use crate::spill::spill_manager::{GetSlicedSize, SpillManager};
-use crate::stream::BatchSplitStream;
use crate::stream::RecordBatchStreamAdapter;
+use crate::stream::ReservationStream;
use crate::topk::TopK;
use crate::topk::TopKDynamicFilters;
use crate::{
@@ -75,8 +74,6 @@ struct ExternalSorterMetrics {
baseline: BaselineMetrics,
spill_metrics: SpillMetrics,
-
- split_metrics: SplitMetrics,
}
impl ExternalSorterMetrics {
@@ -84,7 +81,6 @@ impl ExternalSorterMetrics {
Self {
baseline: BaselineMetrics::new(metrics, partition),
spill_metrics: SpillMetrics::new(metrics, partition),
- split_metrics: SplitMetrics::new(metrics, partition),
}
}
}
@@ -545,7 +541,7 @@ impl ExternalSorter {
while let Some(batch) = sorted_stream.next().await {
let batch = batch?;
- let sorted_size = get_reserved_byte_for_record_batch(&batch);
+ let sorted_size = get_reserved_bytes_for_record_batch(&batch)?;
if self.reservation.try_grow(sorted_size).is_err() {
// Although the reservation is not enough, the batch is
// already in memory, so it's okay to combine it with
previously
@@ -662,7 +658,7 @@ impl ExternalSorter {
if self.in_mem_batches.len() == 1 {
let batch = self.in_mem_batches.swap_remove(0);
let reservation = self.reservation.take();
- return self.sort_batch_stream(batch, metrics, reservation, true);
+ return self.sort_batch_stream(batch, &metrics, reservation);
}
// If less than sort_in_place_threshold_bytes, concatenate and sort in
place
@@ -671,10 +667,10 @@ impl ExternalSorter {
let batch = concat_batches(&self.schema, &self.in_mem_batches)?;
self.in_mem_batches.clear();
self.reservation
- .try_resize(get_reserved_byte_for_record_batch(&batch))
+ .try_resize(get_reserved_bytes_for_record_batch(&batch)?)
.map_err(Self::err_with_oom_context)?;
let reservation = self.reservation.take();
- return self.sort_batch_stream(batch, metrics, reservation, true);
+ return self.sort_batch_stream(batch, &metrics, reservation);
}
let streams = std::mem::take(&mut self.in_mem_batches)
@@ -683,15 +679,8 @@ impl ExternalSorter {
let metrics = self.metrics.baseline.intermediate();
let reservation = self
.reservation
- .split(get_reserved_byte_for_record_batch(&batch));
- let input = self.sort_batch_stream(
- batch,
- metrics,
- reservation,
- // Passing false as `StreamingMergeBuilder` will split the
- // stream into batches of `self.batch_size` rows.
- false,
- )?;
+ .split(get_reserved_bytes_for_record_batch(&batch)?);
+ let input = self.sort_batch_stream(batch, &metrics,
reservation)?;
Ok(spawn_buffered(input, 1))
})
.collect::<Result<_>>()?;
@@ -709,52 +698,78 @@ impl ExternalSorter {
/// Sorts a single `RecordBatch` into a single stream.
///
- /// `reservation` accounts for the memory used by this batch and
- /// is released when the sort is complete
- ///
- /// passing `split` true will return a [`BatchSplitStream`] where each
batch maximum row count
- /// will be `self.batch_size`.
- /// If `split` is false, the stream will return a single batch
+ /// This may output multiple batches depending on the size of the
+ /// sorted data and the target batch size.
+ /// For single-batch output cases, `reservation` will be freed immediately
after sorting,
+ /// as the batch will be output and is expected to be reserved by the
consumer of the stream.
+ /// For multi-batch output cases, `reservation` will be grown to match the
actual
+ /// size of sorted output, and as each batch is output, its memory will be
freed from the reservation.
+ /// (This leads to the same behaviour, as futures are only evaluated when
polled by the consumer.)
fn sort_batch_stream(
&self,
batch: RecordBatch,
- metrics: BaselineMetrics,
- reservation: MemoryReservation,
- mut split: bool,
+ metrics: &BaselineMetrics,
+ mut reservation: MemoryReservation,
) -> Result<SendableRecordBatchStream> {
assert_eq!(
- get_reserved_byte_for_record_batch(&batch),
+ get_reserved_bytes_for_record_batch(&batch)?,
reservation.size()
);
- split = split && batch.num_rows() > self.batch_size;
-
let schema = batch.schema();
-
let expressions = self.expr.clone();
- let stream = futures::stream::once(async move {
- let _timer = metrics.elapsed_compute().timer();
+ let batch_size = self.batch_size;
+ let output_row_metrics = metrics.output_rows().clone();
- let sorted = sort_batch(&batch, &expressions, None)?;
+ let stream = futures::stream::once(async move {
+ let schema = batch.schema();
- (&sorted).record_output(&metrics);
+ // Sort the batch immediately and get all output batches
+ let sorted_batches = sort_batch_chunked(&batch, &expressions,
batch_size)?;
drop(batch);
- drop(reservation);
- Ok(sorted)
- });
- let mut output: SendableRecordBatchStream =
- Box::pin(RecordBatchStreamAdapter::new(schema, stream));
+ // Free the old reservation and grow it to match the actual sorted
output size
+ reservation.free();
- if split {
- output = Box::pin(BatchSplitStream::new(
- output,
- self.batch_size,
- self.metrics.split_metrics.clone(),
- ));
- }
+ Result::<_, DataFusionError>::Ok((schema, sorted_batches,
reservation))
+ })
+ .then({
+ move |batches| async move {
+ match batches {
+ Ok((schema, sorted_batches, mut reservation)) => {
+ // Calculate the total size of sorted batches
+ let total_sorted_size: usize = sorted_batches
+ .iter()
+ .map(get_record_batch_memory_size)
+ .sum();
+ reservation
+ .try_grow(total_sorted_size)
+ .map_err(Self::err_with_oom_context)?;
+
+ // Wrap in ReservationStream to hold the reservation
+ Ok(Box::pin(ReservationStream::new(
+ Arc::clone(&schema),
+ Box::pin(RecordBatchStreamAdapter::new(
+ schema,
+
futures::stream::iter(sorted_batches.into_iter().map(Ok)),
+ )),
+ reservation,
+ )) as SendableRecordBatchStream)
+ }
+ Err(e) => Err(e),
+ }
+ }
+ })
+ .try_flatten()
+ .map(move |batch| match batch {
+ Ok(batch) => {
+ output_row_metrics.add(batch.num_rows());
+ Ok(batch)
+ }
+ Err(e) => Err(e),
+ });
- Ok(output)
+ Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
}
/// If this sort may spill, pre-allocates
@@ -780,7 +795,7 @@ impl ExternalSorter {
&mut self,
input: &RecordBatch,
) -> Result<()> {
- let size = get_reserved_byte_for_record_batch(input);
+ let size = get_reserved_bytes_for_record_batch(input)?;
match self.reservation.try_grow(size) {
Ok(_) => Ok(()),
@@ -819,16 +834,27 @@ impl ExternalSorter {
/// in sorting and merging. The sorted copies are in either row format or
array format.
/// Please refer to cursor.rs and stream.rs for more details. No matter what
format the
/// sorted copies are, they will use more memory than the original record
batch.
-pub(crate) fn get_reserved_byte_for_record_batch_size(record_batch_size:
usize) -> usize {
- // 2x may not be enough for some cases, but it's a good start.
+///
+/// This can basically be calculated as the sum of the actual space it takes in
+/// memory (which would be larger for a sliced batch), and the size of the
actual data.
+pub(crate) fn get_reserved_bytes_for_record_batch_size(
+ record_batch_size: usize,
+ sliced_size: usize,
+) -> usize {
+ // Even 2x may not be enough for some cases, but it's a good enough
estimation as a baseline.
// If 2x is not enough, user can set a larger value for
`sort_spill_reservation_bytes`
// to compensate for the extra memory needed.
- record_batch_size * 2
+ record_batch_size + sliced_size
}
/// Estimate how much memory is needed to sort a `RecordBatch`.
-fn get_reserved_byte_for_record_batch(batch: &RecordBatch) -> usize {
-
get_reserved_byte_for_record_batch_size(get_record_batch_memory_size(batch))
+/// This will just call `get_reserved_bytes_for_record_batch_size` with the
+/// memory size of the record batch and its sliced size.
+pub(super) fn get_reserved_bytes_for_record_batch(batch: &RecordBatch) ->
Result<usize> {
+ Ok(get_reserved_bytes_for_record_batch_size(
+ get_record_batch_memory_size(batch),
+ batch.get_sliced_size()?,
+ ))
}
impl Debug for ExternalSorter {
@@ -853,15 +879,7 @@ pub fn sort_batch(
.collect::<Result<Vec<_>>>()?;
let indices = lexsort_to_indices(&sort_columns, fetch)?;
- let mut columns = take_arrays(batch.columns(), &indices, None)?;
-
- // The columns may be larger than the unsorted columns in `batch`
especially for variable length
- // data types due to exponential growth when building the sort columns. We
shrink the columns
- // to prevent memory reservation failures, as well as excessive memory
allocation when running
- // merges in `SortPreservingMergeStream`.
- columns.iter_mut().for_each(|c| {
- c.shrink_to_fit();
- });
+ let columns = take_arrays(batch.columns(), &indices, None)?;
let options =
RecordBatchOptions::new().with_row_count(Some(indices.len()));
Ok(RecordBatch::try_new_with_options(
@@ -871,6 +889,48 @@ pub fn sort_batch(
)?)
}
+/// Sort a batch and return the result as multiple batches of size
`batch_size`.
+/// This is useful when you want to avoid creating one large sorted batch in
memory,
+/// and instead want to process the sorted data in smaller chunks.
+pub fn sort_batch_chunked(
+ batch: &RecordBatch,
+ expressions: &LexOrdering,
+ batch_size: usize,
+) -> Result<Vec<RecordBatch>> {
+ let sort_columns = expressions
+ .iter()
+ .map(|expr| expr.evaluate_to_sort_column(batch))
+ .collect::<Result<Vec<_>>>()?;
+
+ let indices = lexsort_to_indices(&sort_columns, None)?;
+
+ // Split indices into chunks of batch_size
+ let num_rows = indices.len();
+ let num_chunks = num_rows.div_ceil(batch_size);
+
+ let result_batches = (0..num_chunks)
+ .map(|chunk_idx| {
+ let start = chunk_idx * batch_size;
+ let end = (start + batch_size).min(num_rows);
+ let chunk_len = end - start;
+
+ // Create a slice of indices for this chunk
+ let chunk_indices = indices.slice(start, chunk_len);
+
+ // Take the columns using this chunk of indices
+ let columns = take_arrays(batch.columns(), &chunk_indices, None)?;
+
+ let options =
RecordBatchOptions::new().with_row_count(Some(chunk_len));
+ let chunk_batch =
+ RecordBatch::try_new_with_options(batch.schema(), columns,
&options)?;
+
+ Ok(chunk_batch)
+ })
+ .collect::<Result<Vec<RecordBatch>>>()?;
+
+ Ok(result_batches)
+}
+
/// Sort execution plan.
///
/// Support sorting datasets that are larger than the memory allotted
@@ -1173,10 +1233,7 @@ impl ExecutionPlan for SortExec {
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
let mut new_sort = self.cloned();
- assert!(
- children.len() == 1,
- "SortExec should have exactly one child"
- );
+ assert_eq!(children.len(), 1, "SortExec should have exactly one
child");
new_sort.input = Arc::clone(&children[0]);
// Recompute the properties based on the new input since they may have
changed
let (cache, sort_prefix) = Self::compute_properties(
@@ -1623,13 +1680,24 @@ mod tests {
#[tokio::test]
async fn test_batch_reservation_error() -> Result<()> {
// Pick a memory limit and sort_spill_reservation that make the first
batch reservation fail.
- // These values assume that the ExternalSorter will reserve 800 bytes
for the first batch.
- let expected_batch_reservation = 800;
let merge_reservation: usize = 0; // Set to 0 for simplicity
- let memory_limit: usize = expected_batch_reservation +
merge_reservation - 1; // Just short of what we need
let session_config =
SessionConfig::new().with_sort_spill_reservation_bytes(merge_reservation);
+
+ let plan = test::scan_partitioned(1);
+
+ // Read the first record batch to determine the actual memory
requirement
+ let expected_batch_reservation = {
+ let temp_ctx = Arc::new(TaskContext::default());
+ let mut stream = plan.execute(0, Arc::clone(&temp_ctx))?;
+ let first_batch = stream.next().await.unwrap()?;
+ get_reserved_bytes_for_record_batch(&first_batch)?
+ };
+
+ // Set memory limit just short of what we need
+ let memory_limit: usize = expected_batch_reservation +
merge_reservation - 1;
+
let runtime = RuntimeEnvBuilder::new()
.with_memory_limit(memory_limit, 1.0)
.build_arc()?;
@@ -1639,14 +1707,11 @@ mod tests {
.with_runtime(runtime),
);
- let plan = test::scan_partitioned(1);
-
- // Read the first record batch to assert that our memory limit and
sort_spill_reservation
- // settings trigger the test scenario.
+ // Verify that our memory limit is insufficient
{
let mut stream = plan.execute(0, Arc::clone(&task_ctx))?;
let first_batch = stream.next().await.unwrap()?;
- let batch_reservation =
get_reserved_byte_for_record_batch(&first_batch);
+ let batch_reservation =
get_reserved_bytes_for_record_batch(&first_batch)?;
assert_eq!(batch_reservation, expected_batch_reservation);
assert!(memory_limit < (merge_reservation + batch_reservation));
@@ -1814,6 +1879,93 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn test_sort_memory_reduction_per_batch() -> Result<()> {
+ // This test verifies that memory reservation is reduced for every
batch emitted
+ // during the sort process. This is important to ensure we don't hold
onto
+ // memory longer than necessary.
+
+ // Create a large enough batch that will be split into multiple output
batches
+ let batch_size = 50; // Small batch size to force multiple output
batches
+ let num_rows = 1000; // Create enough data for multiple batches
+
+ let task_ctx = Arc::new(
+ TaskContext::default().with_session_config(
+ SessionConfig::new()
+ .with_batch_size(batch_size)
+ .with_sort_in_place_threshold_bytes(usize::MAX), // Ensure
we don't concat batches
+ ),
+ );
+
+ let schema = Arc::new(Schema::new(vec![Field::new("a",
DataType::Int32, false)]));
+
+ // Create unsorted data
+ let mut values: Vec<i32> = (0..num_rows).collect();
+ values.reverse();
+
+ let input_batch = RecordBatch::try_new(
+ Arc::clone(&schema),
+ vec![Arc::new(Int32Array::from(values))],
+ )?;
+
+ let batches = vec![input_batch];
+
+ let sort_exec = Arc::new(SortExec::new(
+ [PhysicalSortExpr {
+ expr: Arc::new(Column::new("a", 0)),
+ options: SortOptions::default(),
+ }]
+ .into(),
+ TestMemoryExec::try_new_exec(
+ std::slice::from_ref(&batches),
+ Arc::clone(&schema),
+ None,
+ )?,
+ ));
+
+ let mut stream = sort_exec.execute(0, Arc::clone(&task_ctx))?;
+
+ let mut previous_reserved =
task_ctx.runtime_env().memory_pool.reserved();
+ let mut batch_count = 0;
+
+ // Collect batches and verify memory is reduced with each batch
+ while let Some(result) = stream.next().await {
+ let batch = result?;
+ batch_count += 1;
+
+ // Verify we got a non-empty batch
+ assert!(batch.num_rows() > 0, "Batch should not be empty");
+
+ let current_reserved =
task_ctx.runtime_env().memory_pool.reserved();
+
+ // After the first batch, memory should be reducing or staying the
same
+ // (it should not increase as we emit batches)
+ if batch_count > 1 {
+ assert!(
+ current_reserved <= previous_reserved,
+ "Memory reservation should decrease or stay same as
batches are emitted. \
+ Batch {batch_count}: previous={previous_reserved},
current={current_reserved}"
+ );
+ }
+
+ previous_reserved = current_reserved;
+ }
+
+ assert!(
+ batch_count > 1,
+ "Expected multiple batches to be emitted, got {batch_count}"
+ );
+
+ // Verify all memory is returned at the end
+ assert_eq!(
+ task_ctx.runtime_env().memory_pool.reserved(),
+ 0,
+ "All memory should be returned after consuming all batches"
+ );
+
+ Ok(())
+ }
+
#[tokio::test]
async fn test_sort_metadata() -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
@@ -2402,4 +2554,165 @@ mod tests {
Ok((sorted_batches, metrics))
}
+
+ #[tokio::test]
+ async fn test_sort_batch_chunked_basic() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![Field::new("a",
DataType::Int32, false)]));
+
+ // Create a batch with 1000 rows
+ let mut values: Vec<i32> = (0..1000).collect();
+ // Shuffle to make it unsorted
+ values.reverse();
+
+ let batch = RecordBatch::try_new(
+ Arc::clone(&schema),
+ vec![Arc::new(Int32Array::from(values))],
+ )?;
+
+ let expressions: LexOrdering =
+ [PhysicalSortExpr::new_default(Arc::new(Column::new("a",
0)))].into();
+
+ // Sort with batch_size = 250
+ let result_batches = sort_batch_chunked(&batch, &expressions, 250)?;
+
+ // Verify 4 batches are returned
+ assert_eq!(result_batches.len(), 4);
+
+ // Verify each batch has <= 250 rows
+ let mut total_rows = 0;
+ for (i, batch) in result_batches.iter().enumerate() {
+ assert!(
+ batch.num_rows() <= 250,
+ "Batch {} has {} rows, expected <= 250",
+ i,
+ batch.num_rows()
+ );
+ total_rows += batch.num_rows();
+ }
+
+ // Verify total row count matches input
+ assert_eq!(total_rows, 1000);
+
+ // Verify data is correctly sorted across all chunks
+ let concatenated = concat_batches(&schema, &result_batches)?;
+ let array = as_primitive_array::<Int32Type>(concatenated.column(0))?;
+ for i in 0..array.len() - 1 {
+ assert!(
+ array.value(i) <= array.value(i + 1),
+ "Array not sorted at position {}: {} > {}",
+ i,
+ array.value(i),
+ array.value(i + 1)
+ );
+ }
+ assert_eq!(array.value(0), 0);
+ assert_eq!(array.value(array.len() - 1), 999);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_sort_batch_chunked_smaller_than_batch_size() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![Field::new("a",
DataType::Int32, false)]));
+
+ // Create a batch with 50 rows
+ let values: Vec<i32> = (0..50).rev().collect();
+ let batch = RecordBatch::try_new(
+ Arc::clone(&schema),
+ vec![Arc::new(Int32Array::from(values))],
+ )?;
+
+ let expressions: LexOrdering =
+ [PhysicalSortExpr::new_default(Arc::new(Column::new("a",
0)))].into();
+
+ // Sort with batch_size = 100
+ let result_batches = sort_batch_chunked(&batch, &expressions, 100)?;
+
+ // Should return exactly 1 batch
+ assert_eq!(result_batches.len(), 1);
+ assert_eq!(result_batches[0].num_rows(), 50);
+
+ // Verify it's correctly sorted
+ let array =
as_primitive_array::<Int32Type>(result_batches[0].column(0))?;
+ for i in 0..array.len() - 1 {
+ assert!(array.value(i) <= array.value(i + 1));
+ }
+ assert_eq!(array.value(0), 0);
+ assert_eq!(array.value(49), 49);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_sort_batch_chunked_exact_multiple() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![Field::new("a",
DataType::Int32, false)]));
+
+ // Create a batch with 1000 rows
+ let values: Vec<i32> = (0..1000).rev().collect();
+ let batch = RecordBatch::try_new(
+ Arc::clone(&schema),
+ vec![Arc::new(Int32Array::from(values))],
+ )?;
+
+ let expressions: LexOrdering =
+ [PhysicalSortExpr::new_default(Arc::new(Column::new("a",
0)))].into();
+
+ // Sort with batch_size = 100
+ let result_batches = sort_batch_chunked(&batch, &expressions, 100)?;
+
+ // Should return exactly 10 batches of 100 rows each
+ assert_eq!(result_batches.len(), 10);
+ for batch in &result_batches {
+ assert_eq!(batch.num_rows(), 100);
+ }
+
+ // Verify sorted correctly across all batches
+ let concatenated = concat_batches(&schema, &result_batches)?;
+ let array = as_primitive_array::<Int32Type>(concatenated.column(0))?;
+ for i in 0..array.len() - 1 {
+ assert!(array.value(i) <= array.value(i + 1));
+ }
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_sort_batch_chunked_empty_batch() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![Field::new("a",
DataType::Int32, false)]));
+
+ let batch = RecordBatch::new_empty(Arc::clone(&schema));
+
+ let expressions: LexOrdering =
+ [PhysicalSortExpr::new_default(Arc::new(Column::new("a",
0)))].into();
+
+ let result_batches = sort_batch_chunked(&batch, &expressions, 100)?;
+
+ // Empty input produces no output batches (0 chunks)
+ assert_eq!(result_batches.len(), 0);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_get_reserved_bytes_for_record_batch_with_sliced_batches() ->
Result<()>
+ {
+ let schema = Arc::new(Schema::new(vec![Field::new("a",
DataType::Int32, false)]));
+
+ // Create a larger batch then slice it
+ let large_array = Int32Array::from((0..1000).collect::<Vec<i32>>());
+ let sliced_array = large_array.slice(100, 50); // Take 50 elements
starting at 100
+
+ let sliced_batch =
+ RecordBatch::try_new(Arc::clone(&schema),
vec![Arc::new(sliced_array)])?;
+ let batch =
+ RecordBatch::try_new(Arc::clone(&schema),
vec![Arc::new(large_array)])?;
+
+ let sliced_reserved =
get_reserved_bytes_for_record_batch(&sliced_batch)?;
+ let reserved = get_reserved_bytes_for_record_batch(&batch)?;
+
+ // The reserved memory for the sliced batch should be less than that
of the full batch
+ assert!(reserved > sliced_reserved);
+
+ Ok(())
+ }
}
diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs
b/datafusion/physical-plan/src/spill/spill_manager.rs
index d460067339..89b0276206 100644
--- a/datafusion/physical-plan/src/spill/spill_manager.rs
+++ b/datafusion/physical-plan/src/spill/spill_manager.rs
@@ -20,12 +20,11 @@
use arrow::array::StringViewArray;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
-use datafusion_execution::runtime_env::RuntimeEnv;
-use std::sync::Arc;
-
use datafusion_common::{Result, config::SpillCompression};
use datafusion_execution::SendableRecordBatchStream;
use datafusion_execution::disk_manager::RefCountedTempFile;
+use datafusion_execution::runtime_env::RuntimeEnv;
+use std::sync::Arc;
use super::{SpillReaderStream, in_progress_spill_file::InProgressSpillFile};
use crate::coop::cooperative;
diff --git a/datafusion/physical-plan/src/stream.rs
b/datafusion/physical-plan/src/stream.rs
index 8b2ea10068..80c2233d05 100644
--- a/datafusion/physical-plan/src/stream.rs
+++ b/datafusion/physical-plan/src/stream.rs
@@ -27,11 +27,13 @@ use super::metrics::ExecutionPlanMetricsSet;
use super::metrics::{BaselineMetrics, SplitMetrics};
use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
use crate::displayable;
+use crate::spill::get_record_batch_memory_size;
use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
use datafusion_common::{Result, exec_err};
use datafusion_common_runtime::JoinSet;
use datafusion_execution::TaskContext;
+use datafusion_execution::memory_pool::MemoryReservation;
use futures::ready;
use futures::stream::BoxStream;
@@ -699,6 +701,70 @@ impl RecordBatchStream for BatchSplitStream {
}
}
+/// A stream that holds a memory reservation for its lifetime,
+/// shrinking the reservation as batches are consumed.
+/// The original reservation must have its batch sizes calculated using
[`get_record_batch_memory_size`]
+/// On error, the reservation is *NOT* freed, until the stream is dropped.
+pub(crate) struct ReservationStream {
+ schema: SchemaRef,
+ inner: SendableRecordBatchStream,
+ reservation: MemoryReservation,
+}
+
+impl ReservationStream {
+ pub(crate) fn new(
+ schema: SchemaRef,
+ inner: SendableRecordBatchStream,
+ reservation: MemoryReservation,
+ ) -> Self {
+ Self {
+ schema,
+ inner,
+ reservation,
+ }
+ }
+}
+
+impl Stream for ReservationStream {
+ type Item = Result<RecordBatch>;
+
+ fn poll_next(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ let res = self.inner.poll_next_unpin(cx);
+
+ match res {
+ Poll::Ready(res) => {
+ match res {
+ Some(Ok(batch)) => {
+ self.reservation
+ .shrink(get_record_batch_memory_size(&batch));
+ Poll::Ready(Some(Ok(batch)))
+ }
+ Some(Err(err)) => Poll::Ready(Some(Err(err))),
+ None => {
+ // Stream is done so free the reservation completely
+ self.reservation.free();
+ Poll::Ready(None)
+ }
+ }
+ }
+ Poll::Pending => Poll::Pending,
+ }
+ }
+
+ fn size_hint(&self) -> (usize, Option<usize>) {
+ self.inner.size_hint()
+ }
+}
+
+impl RecordBatchStream for ReservationStream {
+ fn schema(&self) -> SchemaRef {
+ Arc::clone(&self.schema)
+ }
+}
+
#[cfg(test)]
mod test {
use super::*;
@@ -924,7 +990,126 @@ mod test {
assert_eq!(
number_of_batches, 2,
- "Should have received exactly one empty batch"
+ "Should have received exactly two empty batches"
+ );
+ }
+
+ #[tokio::test]
+ async fn test_reservation_stream_shrinks_on_poll() {
+ use arrow::array::Int32Array;
+ use datafusion_execution::memory_pool::MemoryConsumer;
+ use datafusion_execution::runtime_env::RuntimeEnvBuilder;
+
+ let runtime = RuntimeEnvBuilder::new()
+ .with_memory_limit(10 * 1024 * 1024, 1.0)
+ .build_arc()
+ .unwrap();
+
+ let mut reservation =
MemoryConsumer::new("test").register(&runtime.memory_pool);
+
+ let schema = Arc::new(Schema::new(vec![Field::new("a",
DataType::Int32, false)]));
+
+ // Create batches
+ let batch1 = RecordBatch::try_new(
+ Arc::clone(&schema),
+ vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
+ )
+ .unwrap();
+ let batch2 = RecordBatch::try_new(
+ Arc::clone(&schema),
+ vec![Arc::new(Int32Array::from(vec![6, 7, 8, 9, 10]))],
+ )
+ .unwrap();
+
+ let batch1_size = get_record_batch_memory_size(&batch1);
+ let batch2_size = get_record_batch_memory_size(&batch2);
+
+ // Reserve memory upfront
+ reservation.try_grow(batch1_size + batch2_size).unwrap();
+ let initial_reserved = runtime.memory_pool.reserved();
+ assert_eq!(initial_reserved, batch1_size + batch2_size);
+
+ // Create stream with batches
+ let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
+ let inner =
Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream))
+ as SendableRecordBatchStream;
+
+ let mut res_stream =
+ ReservationStream::new(Arc::clone(&schema), inner, reservation);
+
+ // Poll first batch
+ let result1 = res_stream.next().await;
+ assert!(result1.is_some());
+
+ // Memory should be reduced by batch1_size
+ let after_first = runtime.memory_pool.reserved();
+ assert_eq!(after_first, batch2_size);
+
+ // Poll second batch
+ let result2 = res_stream.next().await;
+ assert!(result2.is_some());
+
+ // Memory should be reduced by batch2_size
+ let after_second = runtime.memory_pool.reserved();
+ assert_eq!(after_second, 0);
+
+ // Poll None (end of stream)
+ let result3 = res_stream.next().await;
+ assert!(result3.is_none());
+
+ // Memory should still be 0
+ assert_eq!(runtime.memory_pool.reserved(), 0);
+ }
+
+ #[tokio::test]
+ async fn test_reservation_stream_error_handling() {
+ use datafusion_execution::memory_pool::MemoryConsumer;
+ use datafusion_execution::runtime_env::RuntimeEnvBuilder;
+
+ let runtime = RuntimeEnvBuilder::new()
+ .with_memory_limit(10 * 1024 * 1024, 1.0)
+ .build_arc()
+ .unwrap();
+
+ let mut reservation =
MemoryConsumer::new("test").register(&runtime.memory_pool);
+
+ let schema = Arc::new(Schema::new(vec![Field::new("a",
DataType::Int32, false)]));
+
+ reservation.try_grow(1000).unwrap();
+ let initial = runtime.memory_pool.reserved();
+ assert_eq!(initial, 1000);
+
+ // Create a stream that errors
+ let stream = futures::stream::iter(vec![exec_err!("Test error")]);
+ let inner =
Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream))
+ as SendableRecordBatchStream;
+
+ let mut res_stream =
+ ReservationStream::new(Arc::clone(&schema), inner, reservation);
+
+ // Get the error
+ let result = res_stream.next().await;
+ assert!(result.is_some());
+ assert!(result.unwrap().is_err());
+
+ // Verify reservation is NOT automatically freed on error
+ // The reservation is only freed when poll_next returns
Poll::Ready(None)
+ // After an error, the stream may continue to hold the reservation
+ // until it's explicitly dropped or polled to None
+ let after_error = runtime.memory_pool.reserved();
+ assert_eq!(
+ after_error, 1000,
+ "Reservation should still be held after error"
+ );
+
+ // Drop the stream to free the reservation
+ drop(res_stream);
+
+ // Now memory should be freed
+ assert_eq!(
+ runtime.memory_pool.reserved(),
+ 0,
+ "Memory should be freed when stream is dropped"
);
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]