This is an automated email from the ASF dual-hosted git repository.
comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 63efaee255 Support SortMergeJoin spilling (#11218)
63efaee255 is described below
commit 63efaee2555ddd1381b4885867860621ec791f82
Author: Oleks V <[email protected]>
AuthorDate: Sun Jul 21 17:09:54 2024 -0700
Support SortMergeJoin spilling (#11218)
* Support SortMerge spilling
---
datafusion/core/tests/memory_limit/mod.rs | 27 +-
datafusion/execution/src/memory_pool/mod.rs | 19 +-
.../physical-plan/src/joins/sort_merge_join.rs | 457 +++++++++++++++++----
datafusion/physical-plan/src/sorts/sort.rs | 7 +-
datafusion/physical-plan/src/spill.rs | 103 ++++-
5 files changed, 529 insertions(+), 84 deletions(-)
diff --git a/datafusion/core/tests/memory_limit/mod.rs
b/datafusion/core/tests/memory_limit/mod.rs
index f4f4f8cd89..bc2c3315da 100644
--- a/datafusion/core/tests/memory_limit/mod.rs
+++ b/datafusion/core/tests/memory_limit/mod.rs
@@ -164,7 +164,7 @@ async fn cross_join() {
}
#[tokio::test]
-async fn merge_join() {
+async fn sort_merge_join_no_spill() {
// Planner chooses MergeJoin only if number of partitions > 1
let config = SessionConfig::new()
.with_target_partitions(2)
@@ -175,11 +175,32 @@ async fn merge_join() {
"select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time =
t2.time",
)
.with_expected_errors(vec![
- "Resources exhausted: Failed to allocate additional",
+ "Failed to allocate additional",
"SMJStream",
+ "Disk spilling disabled",
])
.with_memory_limit(1_000)
.with_config(config)
+ .with_scenario(Scenario::AccessLogStreaming)
+ .run()
+ .await
+}
+
+#[tokio::test]
+async fn sort_merge_join_spill() {
+ // Planner chooses MergeJoin only if number of partitions > 1
+ let config = SessionConfig::new()
+ .with_target_partitions(2)
+ .set_bool("datafusion.optimizer.prefer_hash_join", false);
+
+ TestCase::new()
+ .with_query(
+ "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time =
t2.time",
+ )
+ .with_memory_limit(1_000)
+ .with_config(config)
+ .with_disk_manager_config(DiskManagerConfig::NewOs)
+ .with_scenario(Scenario::AccessLogStreaming)
.run()
.await
}
@@ -453,7 +474,7 @@ impl TestCase {
let table = scenario.table();
let rt_config = RuntimeConfig::new()
- // do not allow spilling
+ // disk manager setting controls the spilling
.with_disk_manager(disk_manager_config)
.with_memory_limit(memory_limit, MEMORY_FRACTION);
diff --git a/datafusion/execution/src/memory_pool/mod.rs
b/datafusion/execution/src/memory_pool/mod.rs
index 3f66a304dc..92ed1b2918 100644
--- a/datafusion/execution/src/memory_pool/mod.rs
+++ b/datafusion/execution/src/memory_pool/mod.rs
@@ -18,7 +18,7 @@
//! [`MemoryPool`] for memory management during query execution, [`proxy`] for
//! help with allocation accounting.
-use datafusion_common::Result;
+use datafusion_common::{internal_err, Result};
use std::{cmp::Ordering, sync::Arc};
mod pool;
@@ -220,6 +220,23 @@ impl MemoryReservation {
self.size = new_size
}
+ /// Tries to free `capacity` bytes from this reservation
+ /// if `capacity` does not exceed [`Self::size`]
+ /// Returns the new reservation size,
+ /// or an error if `capacity` exceeds the currently allocated size
+ pub fn try_shrink(&mut self, capacity: usize) -> Result<usize> {
+ if let Some(new_size) = self.size.checked_sub(capacity) {
+ self.registration.pool.shrink(self, capacity);
+ self.size = new_size;
+ Ok(new_size)
+ } else {
+ internal_err!(
+ "Cannot free the capacity {capacity} out of allocated size {}",
+ self.size
+ )
+ }
+ }
+
/// Sets the size of this reservation to `capacity`
pub fn resize(&mut self, capacity: usize) {
match capacity.cmp(&self.size) {
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs
b/datafusion/physical-plan/src/joins/sort_merge_join.rs
index a03e4a83fd..5fde028c7f 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -24,40 +24,46 @@ use std::any::Any;
use std::cmp::Ordering;
use std::collections::VecDeque;
use std::fmt::Formatter;
+use std::fs::File;
+use std::io::BufReader;
use std::mem;
use std::ops::Range;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
-use crate::expressions::PhysicalSortExpr;
-use crate::joins::utils::{
- build_join_schema, check_join_is_valid, estimate_join_statistics,
- symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
-};
-use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
-use crate::{
- execution_mode_from_children, metrics, DisplayAs, DisplayFormatType,
Distribution,
- ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties,
- RecordBatchStream, SendableRecordBatchStream, Statistics,
-};
-
use arrow::array::*;
use arrow::compute::{self, concat_batches, take, SortOptions};
use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use arrow::error::ArrowError;
+use arrow::ipc::reader::FileReader;
use arrow_array::types::UInt64Type;
+use futures::{Stream, StreamExt};
+use hashbrown::HashSet;
use datafusion_common::{
- internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType,
Result,
+ exec_err, internal_err, not_impl_err, plan_err, DataFusionError, JoinSide,
JoinType,
+ Result,
};
+use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
+use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::equivalence::join_equivalence_properties;
use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement};
-use futures::{Stream, StreamExt};
-use hashbrown::HashSet;
+use crate::expressions::PhysicalSortExpr;
+use crate::joins::utils::{
+ build_join_schema, check_join_is_valid, estimate_join_statistics,
+ symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
+};
+use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder,
MetricsSet};
+use crate::spill::spill_record_batches;
+use crate::{
+ execution_mode_from_children, metrics, DisplayAs, DisplayFormatType,
Distribution,
+ ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties,
+ RecordBatchStream, SendableRecordBatchStream, Statistics,
+};
/// join execution plan executes partitions in parallel and combines them into
a set of
/// partitions.
@@ -234,11 +240,6 @@ impl SortMergeJoinExec {
impl DisplayAs for SortMergeJoinExec {
fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) ->
std::fmt::Result {
- let display_filter = self.filter.as_ref().map_or_else(
- || "".to_string(),
- |f| format!(", filter={}", f.expression()),
- );
-
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
let on = self
@@ -250,7 +251,12 @@ impl DisplayAs for SortMergeJoinExec {
write!(
f,
"SortMergeJoin: join_type={:?}, on=[{}]{}",
- self.join_type, on, display_filter
+ self.join_type,
+ on,
+ self.filter.as_ref().map_or("".to_string(), |f| format!(
+ ", filter={}",
+ f.expression()
+ ))
)
}
}
@@ -375,6 +381,7 @@ impl ExecutionPlan for SortMergeJoinExec {
batch_size,
SortMergeJoinMetrics::new(partition, &self.metrics),
reservation,
+ context.runtime_env(),
)?))
}
@@ -412,6 +419,12 @@ struct SortMergeJoinMetrics {
/// Peak memory used for buffered data.
/// Calculated as sum of peak memory values across partitions
peak_mem_used: metrics::Gauge,
+ /// count of spills during the execution of the operator
+ spill_count: Count,
+ /// total spilled bytes during the execution of the operator
+ spilled_bytes: Count,
+ /// total spilled rows during the execution of the operator
+ spilled_rows: Count,
}
impl SortMergeJoinMetrics {
@@ -425,6 +438,9 @@ impl SortMergeJoinMetrics {
MetricBuilder::new(metrics).counter("output_batches", partition);
let output_rows = MetricBuilder::new(metrics).output_rows(partition);
let peak_mem_used = MetricBuilder::new(metrics).gauge("peak_mem_used",
partition);
+ let spill_count = MetricBuilder::new(metrics).spill_count(partition);
+ let spilled_bytes =
MetricBuilder::new(metrics).spilled_bytes(partition);
+ let spilled_rows = MetricBuilder::new(metrics).spilled_rows(partition);
Self {
join_time,
@@ -433,6 +449,9 @@ impl SortMergeJoinMetrics {
output_batches,
output_rows,
peak_mem_used,
+ spill_count,
+ spilled_bytes,
+ spilled_rows,
}
}
}
@@ -565,7 +584,8 @@ impl StreamedBatch {
#[derive(Debug)]
struct BufferedBatch {
/// The buffered record batch
- pub batch: RecordBatch,
+ /// None if the batch was spilled to disk
+ pub batch: Option<RecordBatch>,
/// The range in which the rows share the same join key
pub range: Range<usize>,
/// Array refs of the join key
@@ -577,6 +597,14 @@ struct BufferedBatch {
/// The indices of buffered batch that failed the join filter.
/// When dequeuing the buffered batch, we need to produce null joined rows
for these indices.
pub join_filter_failed_idxs: HashSet<u64>,
+ /// Number of rows in the buffered batch. Equal to `batch.num_rows()`,
+ /// but cheaper to access once the batch has been spilled to disk
+ pub num_rows: usize,
+ /// An optional temporary spill file on disk holding this batch.
+ /// `None` by default;
+ /// `Some(file)` if the batch was spilled to disk
+ pub spill_file: Option<RefCountedTempFile>,
}
impl BufferedBatch {
@@ -602,13 +630,16 @@ impl BufferedBatch {
+ mem::size_of::<Range<usize>>()
+ mem::size_of::<usize>();
+ let num_rows = batch.num_rows();
BufferedBatch {
- batch,
+ batch: Some(batch),
range,
join_arrays,
null_joined: vec![],
size_estimation,
join_filter_failed_idxs: HashSet::new(),
+ num_rows,
+ spill_file: None,
}
}
}
@@ -666,6 +697,8 @@ struct SMJStream {
pub join_metrics: SortMergeJoinMetrics,
/// Memory reservation
pub reservation: MemoryReservation,
+ /// Runtime env
+ pub runtime_env: Arc<RuntimeEnv>,
}
impl RecordBatchStream for SMJStream {
@@ -785,6 +818,7 @@ impl SMJStream {
batch_size: usize,
join_metrics: SortMergeJoinMetrics,
reservation: MemoryReservation,
+ runtime_env: Arc<RuntimeEnv>,
) -> Result<Self> {
let streamed_schema = streamed.schema();
let buffered_schema = buffered.schema();
@@ -813,6 +847,7 @@ impl SMJStream {
join_type,
join_metrics,
reservation,
+ runtime_env,
})
}
@@ -858,6 +893,58 @@ impl SMJStream {
}
}
+ fn free_reservation(&mut self, buffered_batch: BufferedBatch) ->
Result<()> {
+ // Shrink memory usage for in-memory batches only
+ if buffered_batch.spill_file.is_none() &&
buffered_batch.batch.is_some() {
+ self.reservation
+ .try_shrink(buffered_batch.size_estimation)?;
+ }
+
+ Ok(())
+ }
+
+ fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) ->
Result<()> {
+ match self.reservation.try_grow(buffered_batch.size_estimation) {
+ Ok(_) => {
+ self.join_metrics
+ .peak_mem_used
+ .set_max(self.reservation.size());
+ Ok(())
+ }
+ Err(_) if self.runtime_env.disk_manager.tmp_files_enabled() => {
+ // spill buffered batch to disk
+ let spill_file = self
+ .runtime_env
+ .disk_manager
+ .create_tmp_file("sort_merge_join_buffered_spill")?;
+
+ if let Some(batch) = buffered_batch.batch {
+ spill_record_batches(
+ vec![batch],
+ spill_file.path().into(),
+ Arc::clone(&self.buffered_schema),
+ )?;
+ buffered_batch.spill_file = Some(spill_file);
+ buffered_batch.batch = None;
+
+ // update metrics to register spill
+ self.join_metrics.spill_count.add(1);
+ self.join_metrics
+ .spilled_bytes
+ .add(buffered_batch.size_estimation);
+
self.join_metrics.spilled_rows.add(buffered_batch.num_rows);
+ Ok(())
+ } else {
+ internal_err!("Buffered batch has empty body")
+ }
+ }
+ Err(e) => exec_err!("{}. Disk spilling disabled.", e.message()),
+ }?;
+
+ self.buffered_data.batches.push_back(buffered_batch);
+ Ok(())
+ }
+
/// Poll next buffered batches
fn poll_buffered_batches(&mut self, cx: &mut Context) ->
Poll<Option<Result<()>>> {
loop {
@@ -867,12 +954,12 @@ impl SMJStream {
while !self.buffered_data.batches.is_empty() {
let head_batch = self.buffered_data.head_batch();
// If the head batch is fully processed, dequeue it
and produce output of it.
- if head_batch.range.end == head_batch.batch.num_rows()
{
+ if head_batch.range.end == head_batch.num_rows {
self.freeze_dequeuing_buffered()?;
if let Some(buffered_batch) =
self.buffered_data.batches.pop_front()
{
-
self.reservation.shrink(buffered_batch.size_estimation);
+ self.free_reservation(buffered_batch)?;
}
} else {
// If the head batch is not fully processed, break
the loop.
@@ -900,25 +987,22 @@ impl SMJStream {
Poll::Ready(Some(batch)) => {
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
+
if batch.num_rows() > 0 {
let buffered_batch =
BufferedBatch::new(batch, 0..1,
&self.on_buffered);
-
self.reservation.try_grow(buffered_batch.size_estimation)?;
- self.join_metrics
- .peak_mem_used
- .set_max(self.reservation.size());
-
self.buffered_data.batches.push_back(buffered_batch);
+ self.allocate_reservation(buffered_batch)?;
self.buffered_state = BufferedState::PollingRest;
}
}
},
BufferedState::PollingRest => {
if self.buffered_data.tail_batch().range.end
- < self.buffered_data.tail_batch().batch.num_rows()
+ < self.buffered_data.tail_batch().num_rows
{
while self.buffered_data.tail_batch().range.end
- < self.buffered_data.tail_batch().batch.num_rows()
+ < self.buffered_data.tail_batch().num_rows
{
if is_join_arrays_equal(
&self.buffered_data.head_batch().join_arrays,
@@ -941,6 +1025,7 @@ impl SMJStream {
self.buffered_state = BufferedState::Ready;
}
Poll::Ready(Some(batch)) => {
+ // Polling batches coming concurrently as
multiple partitions
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
if batch.num_rows() > 0 {
@@ -949,12 +1034,7 @@ impl SMJStream {
0..0,
&self.on_buffered,
);
- self.reservation
-
.try_grow(buffered_batch.size_estimation)?;
- self.join_metrics
- .peak_mem_used
- .set_max(self.reservation.size());
-
self.buffered_data.batches.push_back(buffered_batch);
+ self.allocate_reservation(buffered_batch)?;
}
}
}
@@ -1473,13 +1553,8 @@ fn produce_buffered_null_batch(
}
// Take buffered (right) columns
- let buffered_columns = buffered_batch
- .batch
- .columns()
- .iter()
- .map(|column| take(column, &buffered_indices, None))
- .collect::<Result<Vec<_>, ArrowError>>()
- .map_err(Into::<DataFusionError>::into)?;
+ let buffered_columns =
+ get_buffered_columns_from_batch(buffered_batch, buffered_indices)?;
// Create null streamed (left) columns
let mut streamed_columns = streamed_schema
@@ -1502,13 +1577,45 @@ fn get_buffered_columns(
buffered_data: &BufferedData,
buffered_batch_idx: usize,
buffered_indices: &UInt64Array,
-) -> Result<Vec<ArrayRef>, ArrowError> {
- buffered_data.batches[buffered_batch_idx]
- .batch
- .columns()
- .iter()
- .map(|column| take(column, &buffered_indices, None))
- .collect::<Result<Vec<_>, ArrowError>>()
+) -> Result<Vec<ArrayRef>> {
+ get_buffered_columns_from_batch(
+ &buffered_data.batches[buffered_batch_idx],
+ buffered_indices,
+ )
+}
+
+#[inline(always)]
+fn get_buffered_columns_from_batch(
+ buffered_batch: &BufferedBatch,
+ buffered_indices: &UInt64Array,
+) -> Result<Vec<ArrayRef>> {
+ match (&buffered_batch.spill_file, &buffered_batch.batch) {
+ // In memory batch
+ (None, Some(batch)) => Ok(batch
+ .columns()
+ .iter()
+ .map(|column| take(column, &buffered_indices, None))
+ .collect::<Result<Vec<_>, ArrowError>>()
+ .map_err(Into::<DataFusionError>::into)?),
+ // If the batch was spilled to disk, less likely
+ (Some(spill_file), None) => {
+ let mut buffered_cols: Vec<ArrayRef> =
+ Vec::with_capacity(buffered_indices.len());
+
+ let file = BufReader::new(File::open(spill_file.path())?);
+ let reader = FileReader::try_new(file, None)?;
+
+ for batch in reader {
+ batch?.columns().iter().for_each(|column| {
+ buffered_cols.extend(take(column, &buffered_indices, None))
+ });
+ }
+
+ Ok(buffered_cols)
+ }
+ // Invalid combination
+ (spill, batch) => internal_err!("Unexpected buffered batch spill
status. Spill exists: {}. In-memory exists: {}", spill.is_some(),
batch.is_some()),
+ }
}
/// Calculate join filter bit mask considering join type specifics
@@ -1854,6 +1961,7 @@ mod tests {
assert_batches_eq, assert_batches_sorted_eq, assert_contains,
JoinType, Result,
};
use datafusion_execution::config::SessionConfig;
+ use datafusion_execution::disk_manager::DiskManagerConfig;
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion_execution::TaskContext;
@@ -2749,7 +2857,7 @@ mod tests {
}
#[tokio::test]
- async fn overallocation_single_batch() -> Result<()> {
+ async fn overallocation_single_batch_no_spill() -> Result<()> {
let left = build_table(
("a1", &vec![0, 1, 2, 3, 4, 5]),
("b1", &vec![1, 2, 3, 4, 5, 6]),
@@ -2775,14 +2883,17 @@ mod tests {
JoinType::LeftAnti,
];
- for join_type in join_types {
- let runtime_config = RuntimeConfig::new().with_memory_limit(100,
1.0);
- let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
- let session_config = SessionConfig::default().with_batch_size(50);
+ // Disable DiskManager to prevent spilling
+ let runtime_config = RuntimeConfig::new()
+ .with_memory_limit(100, 1.0)
+ .with_disk_manager(DiskManagerConfig::Disabled);
+ let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
+ let session_config = SessionConfig::default().with_batch_size(50);
+ for join_type in join_types {
let task_ctx = TaskContext::default()
- .with_session_config(session_config)
- .with_runtime(runtime);
+ .with_session_config(session_config.clone())
+ .with_runtime(Arc::clone(&runtime));
let task_ctx = Arc::new(task_ctx);
let join = join_with_options(
@@ -2797,18 +2908,20 @@ mod tests {
let stream = join.execute(0, task_ctx)?;
let err = common::collect(stream).await.unwrap_err();
- assert_contains!(
- err.to_string(),
- "Resources exhausted: Failed to allocate additional"
- );
+ assert_contains!(err.to_string(), "Failed to allocate additional");
assert_contains!(err.to_string(), "SMJStream[0]");
+ assert_contains!(err.to_string(), "Disk spilling disabled");
+ assert!(join.metrics().is_some());
+ assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
+ assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
+ assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
}
Ok(())
}
#[tokio::test]
- async fn overallocation_multi_batch() -> Result<()> {
+ async fn overallocation_multi_batch_no_spill() -> Result<()> {
let left_batch_1 = build_table_i32(
("a1", &vec![0, 1]),
("b1", &vec![1, 1]),
@@ -2855,13 +2968,17 @@ mod tests {
JoinType::LeftAnti,
];
+ // Disable DiskManager to prevent spilling
+ let runtime_config = RuntimeConfig::new()
+ .with_memory_limit(100, 1.0)
+ .with_disk_manager(DiskManagerConfig::Disabled);
+ let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
+ let session_config = SessionConfig::default().with_batch_size(50);
+
for join_type in join_types {
- let runtime_config = RuntimeConfig::new().with_memory_limit(100,
1.0);
- let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
- let session_config = SessionConfig::default().with_batch_size(50);
let task_ctx = TaskContext::default()
- .with_session_config(session_config)
- .with_runtime(runtime);
+ .with_session_config(session_config.clone())
+ .with_runtime(Arc::clone(&runtime));
let task_ctx = Arc::new(task_ctx);
let join = join_with_options(
Arc::clone(&left),
@@ -2875,11 +2992,205 @@ mod tests {
let stream = join.execute(0, task_ctx)?;
let err = common::collect(stream).await.unwrap_err();
- assert_contains!(
- err.to_string(),
- "Resources exhausted: Failed to allocate additional"
- );
+ assert_contains!(err.to_string(), "Failed to allocate additional");
assert_contains!(err.to_string(), "SMJStream[0]");
+ assert_contains!(err.to_string(), "Disk spilling disabled");
+ assert!(join.metrics().is_some());
+ assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
+ assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
+ assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
+ }
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn overallocation_single_batch_spill() -> Result<()> {
+ let left = build_table(
+ ("a1", &vec![0, 1, 2, 3, 4, 5]),
+ ("b1", &vec![1, 2, 3, 4, 5, 6]),
+ ("c1", &vec![4, 5, 6, 7, 8, 9]),
+ );
+ let right = build_table(
+ ("a2", &vec![0, 10, 20, 30, 40]),
+ ("b2", &vec![1, 3, 4, 6, 8]),
+ ("c2", &vec![50, 60, 70, 80, 90]),
+ );
+ let on = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
+ )];
+ let sort_options = vec![SortOptions::default(); on.len()];
+
+ let join_types = [
+ JoinType::Inner,
+ JoinType::Left,
+ JoinType::Right,
+ JoinType::Full,
+ JoinType::LeftSemi,
+ JoinType::LeftAnti,
+ ];
+
+ // Enable DiskManager to allow spilling
+ let runtime_config = RuntimeConfig::new()
+ .with_memory_limit(100, 1.0)
+ .with_disk_manager(DiskManagerConfig::NewOs);
+ let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
+
+ for batch_size in [1, 50] {
+ let session_config =
SessionConfig::default().with_batch_size(batch_size);
+
+ for join_type in &join_types {
+ let task_ctx = TaskContext::default()
+ .with_session_config(session_config.clone())
+ .with_runtime(Arc::clone(&runtime));
+ let task_ctx = Arc::new(task_ctx);
+
+ let join = join_with_options(
+ Arc::clone(&left),
+ Arc::clone(&right),
+ on.clone(),
+ *join_type,
+ sort_options.clone(),
+ false,
+ )?;
+
+ let stream = join.execute(0, task_ctx)?;
+ let spilled_join_result =
common::collect(stream).await.unwrap();
+
+ assert!(join.metrics().is_some());
+ assert!(join.metrics().unwrap().spill_count().unwrap() > 0);
+ assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0);
+ assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0);
+
+ // Run the same test without a spill configuration, for comparison
+ let task_ctx_no_spill =
+
TaskContext::default().with_session_config(session_config.clone());
+ let task_ctx_no_spill = Arc::new(task_ctx_no_spill);
+
+ let join = join_with_options(
+ Arc::clone(&left),
+ Arc::clone(&right),
+ on.clone(),
+ *join_type,
+ sort_options.clone(),
+ false,
+ )?;
+ let stream = join.execute(0, task_ctx_no_spill)?;
+ let no_spilled_join_result =
common::collect(stream).await.unwrap();
+
+ assert!(join.metrics().is_some());
+ assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
+ assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
+ assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
+ // Compare spilled and non spilled data to check spill logic
doesn't corrupt the data
+ assert_eq!(spilled_join_result, no_spilled_join_result);
+ }
+ }
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn overallocation_multi_batch_spill() -> Result<()> {
+ let left_batch_1 = build_table_i32(
+ ("a1", &vec![0, 1]),
+ ("b1", &vec![1, 1]),
+ ("c1", &vec![4, 5]),
+ );
+ let left_batch_2 = build_table_i32(
+ ("a1", &vec![2, 3]),
+ ("b1", &vec![1, 1]),
+ ("c1", &vec![6, 7]),
+ );
+ let left_batch_3 = build_table_i32(
+ ("a1", &vec![4, 5]),
+ ("b1", &vec![1, 1]),
+ ("c1", &vec![8, 9]),
+ );
+ let right_batch_1 = build_table_i32(
+ ("a2", &vec![0, 10]),
+ ("b2", &vec![1, 1]),
+ ("c2", &vec![50, 60]),
+ );
+ let right_batch_2 = build_table_i32(
+ ("a2", &vec![20, 30]),
+ ("b2", &vec![1, 1]),
+ ("c2", &vec![70, 80]),
+ );
+ let right_batch_3 =
+ build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2",
&vec![90]));
+ let left =
+ build_table_from_batches(vec![left_batch_1, left_batch_2,
left_batch_3]);
+ let right =
+ build_table_from_batches(vec![right_batch_1, right_batch_2,
right_batch_3]);
+ let on = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
+ )];
+ let sort_options = vec![SortOptions::default(); on.len()];
+
+ let join_types = [
+ JoinType::Inner,
+ JoinType::Left,
+ JoinType::Right,
+ JoinType::Full,
+ JoinType::LeftSemi,
+ JoinType::LeftAnti,
+ ];
+
+ // Enable DiskManager to allow spilling
+ let runtime_config = RuntimeConfig::new()
+ .with_memory_limit(500, 1.0)
+ .with_disk_manager(DiskManagerConfig::NewOs);
+ let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
+ for batch_size in [1, 50] {
+ let session_config =
SessionConfig::default().with_batch_size(batch_size);
+
+ for join_type in &join_types {
+ let task_ctx = TaskContext::default()
+ .with_session_config(session_config.clone())
+ .with_runtime(Arc::clone(&runtime));
+ let task_ctx = Arc::new(task_ctx);
+ let join = join_with_options(
+ Arc::clone(&left),
+ Arc::clone(&right),
+ on.clone(),
+ *join_type,
+ sort_options.clone(),
+ false,
+ )?;
+
+ let stream = join.execute(0, task_ctx)?;
+ let spilled_join_result =
common::collect(stream).await.unwrap();
+ assert!(join.metrics().is_some());
+ assert!(join.metrics().unwrap().spill_count().unwrap() > 0);
+ assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0);
+ assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0);
+
+ // Run the same test without a spill configuration, for comparison
+ let task_ctx_no_spill =
+
TaskContext::default().with_session_config(session_config.clone());
+ let task_ctx_no_spill = Arc::new(task_ctx_no_spill);
+
+ let join = join_with_options(
+ Arc::clone(&left),
+ Arc::clone(&right),
+ on.clone(),
+ *join_type,
+ sort_options.clone(),
+ false,
+ )?;
+ let stream = join.execute(0, task_ctx_no_spill)?;
+ let no_spilled_join_result =
common::collect(stream).await.unwrap();
+
+ assert!(join.metrics().is_some());
+ assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
+ assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
+ assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
+ // Compare spilled and non spilled data to check spill logic
doesn't corrupt the data
+ assert_eq!(spilled_join_result, no_spilled_join_result);
+ }
}
Ok(())
diff --git a/datafusion/physical-plan/src/sorts/sort.rs
b/datafusion/physical-plan/src/sorts/sort.rs
index d576f77d9f..13ff63c174 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -45,7 +45,7 @@ use arrow::record_batch::RecordBatch;
use arrow::row::{RowConverter, SortField};
use arrow_array::{Array, RecordBatchOptions, UInt32Array};
use arrow_schema::DataType;
-use datafusion_common::{DataFusionError, Result};
+use datafusion_common::{internal_err, Result};
use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_execution::runtime_env::RuntimeEnv;
@@ -333,10 +333,7 @@ impl ExternalSorter {
for spill in self.spills.drain(..) {
if !spill.path().exists() {
- return Err(DataFusionError::Internal(format!(
- "Spill file {:?} does not exist",
- spill.path()
- )));
+ return internal_err!("Spill file {:?} does not exist",
spill.path());
}
let stream = read_spill_as_stream(spill,
Arc::clone(&self.schema), 2)?;
streams.push(stream);
diff --git a/datafusion/physical-plan/src/spill.rs
b/datafusion/physical-plan/src/spill.rs
index 0018a27bd2..21ca58fa0a 100644
--- a/datafusion/physical-plan/src/spill.rs
+++ b/datafusion/physical-plan/src/spill.rs
@@ -40,7 +40,7 @@ use crate::stream::RecordBatchReceiverStream;
/// `path` - temp file
/// `schema` - batches schema, should be the same across batches
/// `buffer` - internal buffer of capacity batches
-pub fn read_spill_as_stream(
+pub(crate) fn read_spill_as_stream(
path: RefCountedTempFile,
schema: SchemaRef,
buffer: usize,
@@ -56,7 +56,7 @@ pub fn read_spill_as_stream(
/// Spills in-memory `batches` to disk.
///
/// Returns the total number of rows spilled to disk.
-pub fn spill_record_batches(
+pub(crate) fn spill_record_batches(
batches: Vec<RecordBatch>,
path: PathBuf,
schema: SchemaRef,
@@ -85,3 +85,102 @@ fn read_spill(sender: Sender<Result<RecordBatch>>, path:
&Path) -> Result<()> {
}
Ok(())
}
+
+/// Spill the `RecordBatch` to disk as smaller batches
+/// split by `batch_size_rows`
+/// Returns `Ok(())` once all rows of the batch have been spilled
+pub fn spill_record_batch_by_size(
+ batch: &RecordBatch,
+ path: PathBuf,
+ schema: SchemaRef,
+ batch_size_rows: usize,
+) -> Result<()> {
+ let mut offset = 0;
+ let total_rows = batch.num_rows();
+ let mut writer = IPCWriter::new(&path, schema.as_ref())?;
+
+ while offset < total_rows {
+ let length = std::cmp::min(total_rows - offset, batch_size_rows);
+ let batch = batch.slice(offset, length);
+ offset += batch.num_rows();
+ writer.write(&batch)?;
+ }
+ writer.finish()?;
+
+ Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::spill::{spill_record_batch_by_size, spill_record_batches};
+ use crate::test::build_table_i32;
+ use datafusion_common::Result;
+ use datafusion_execution::disk_manager::DiskManagerConfig;
+ use datafusion_execution::DiskManager;
+ use std::fs::File;
+ use std::io::BufReader;
+ use std::sync::Arc;
+
+ #[test]
+ fn test_batch_spill_and_read() -> Result<()> {
+ let batch1 = build_table_i32(
+ ("a2", &vec![0, 1, 2]),
+ ("b2", &vec![3, 4, 5]),
+ ("c2", &vec![4, 5, 6]),
+ );
+
+ let batch2 = build_table_i32(
+ ("a2", &vec![10, 11, 12]),
+ ("b2", &vec![13, 14, 15]),
+ ("c2", &vec![14, 15, 16]),
+ );
+
+ let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?;
+
+ let spill_file = disk_manager.create_tmp_file("Test Spill")?;
+ let schema = batch1.schema();
+ let num_rows = batch1.num_rows() + batch2.num_rows();
+ let cnt = spill_record_batches(
+ vec![batch1, batch2],
+ spill_file.path().into(),
+ Arc::clone(&schema),
+ );
+ assert_eq!(cnt.unwrap(), num_rows);
+
+ let file = BufReader::new(File::open(spill_file.path())?);
+ let reader = arrow::ipc::reader::FileReader::try_new(file, None)?;
+
+ assert_eq!(reader.num_batches(), 2);
+ assert_eq!(reader.schema(), schema);
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_batch_spill_by_size() -> Result<()> {
+ let batch1 = build_table_i32(
+ ("a2", &vec![0, 1, 2, 3]),
+ ("b2", &vec![3, 4, 5, 6]),
+ ("c2", &vec![4, 5, 6, 7]),
+ );
+
+ let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?;
+
+ let spill_file = disk_manager.create_tmp_file("Test Spill")?;
+ let schema = batch1.schema();
+ spill_record_batch_by_size(
+ &batch1,
+ spill_file.path().into(),
+ Arc::clone(&schema),
+ 1,
+ )?;
+
+ let file = BufReader::new(File::open(spill_file.path())?);
+ let reader = arrow::ipc::reader::FileReader::try_new(file, None)?;
+
+ assert_eq!(reader.num_batches(), 4);
+ assert_eq!(reader.schema(), schema);
+
+ Ok(())
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]