This is an automated email from the ASF dual-hosted git repository.

comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 63efaee255 Support SortMergeJoin spilling (#11218)
63efaee255 is described below

commit 63efaee2555ddd1381b4885867860621ec791f82
Author: Oleks V <[email protected]>
AuthorDate: Sun Jul 21 17:09:54 2024 -0700

    Support SortMergeJoin spilling (#11218)
    
    * Support SortMerge spilling
---
 datafusion/core/tests/memory_limit/mod.rs          |  27 +-
 datafusion/execution/src/memory_pool/mod.rs        |  19 +-
 .../physical-plan/src/joins/sort_merge_join.rs     | 457 +++++++++++++++++----
 datafusion/physical-plan/src/sorts/sort.rs         |   7 +-
 datafusion/physical-plan/src/spill.rs              | 103 ++++-
 5 files changed, 529 insertions(+), 84 deletions(-)

diff --git a/datafusion/core/tests/memory_limit/mod.rs 
b/datafusion/core/tests/memory_limit/mod.rs
index f4f4f8cd89..bc2c3315da 100644
--- a/datafusion/core/tests/memory_limit/mod.rs
+++ b/datafusion/core/tests/memory_limit/mod.rs
@@ -164,7 +164,7 @@ async fn cross_join() {
 }
 
 #[tokio::test]
-async fn merge_join() {
+async fn sort_merge_join_no_spill() {
     // Planner chooses MergeJoin only if number of partitions > 1
     let config = SessionConfig::new()
         .with_target_partitions(2)
@@ -175,11 +175,32 @@ async fn merge_join() {
             "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = 
t2.time",
         )
         .with_expected_errors(vec![
-            "Resources exhausted: Failed to allocate additional",
+            "Failed to allocate additional",
             "SMJStream",
+            "Disk spilling disabled",
         ])
         .with_memory_limit(1_000)
         .with_config(config)
+        .with_scenario(Scenario::AccessLogStreaming)
+        .run()
+        .await
+}
+
+#[tokio::test]
+async fn sort_merge_join_spill() {
+    // Planner chooses MergeJoin only if number of partitions > 1
+    let config = SessionConfig::new()
+        .with_target_partitions(2)
+        .set_bool("datafusion.optimizer.prefer_hash_join", false);
+
+    TestCase::new()
+        .with_query(
+            "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = 
t2.time",
+        )
+        .with_memory_limit(1_000)
+        .with_config(config)
+        .with_disk_manager_config(DiskManagerConfig::NewOs)
+        .with_scenario(Scenario::AccessLogStreaming)
         .run()
         .await
 }
@@ -453,7 +474,7 @@ impl TestCase {
         let table = scenario.table();
 
         let rt_config = RuntimeConfig::new()
-            // do not allow spilling
+            // disk manager setting controls the spilling
             .with_disk_manager(disk_manager_config)
             .with_memory_limit(memory_limit, MEMORY_FRACTION);
 
diff --git a/datafusion/execution/src/memory_pool/mod.rs 
b/datafusion/execution/src/memory_pool/mod.rs
index 3f66a304dc..92ed1b2918 100644
--- a/datafusion/execution/src/memory_pool/mod.rs
+++ b/datafusion/execution/src/memory_pool/mod.rs
@@ -18,7 +18,7 @@
 //! [`MemoryPool`] for memory management during query execution, [`proxy]` for
 //! help with allocation accounting.
 
-use datafusion_common::Result;
+use datafusion_common::{internal_err, Result};
 use std::{cmp::Ordering, sync::Arc};
 
 mod pool;
@@ -220,6 +220,23 @@ impl MemoryReservation {
         self.size = new_size
     }
 
+    /// Tries to free `capacity` bytes from this reservation
+    /// if `capacity` does not exceed [`Self::size`]
+    /// Returns new reservation size
+    /// or error if shrinking capacity is more than allocated size
+    pub fn try_shrink(&mut self, capacity: usize) -> Result<usize> {
+        if let Some(new_size) = self.size.checked_sub(capacity) {
+            self.registration.pool.shrink(self, capacity);
+            self.size = new_size;
+            Ok(new_size)
+        } else {
+            internal_err!(
+                "Cannot free the capacity {capacity} out of allocated size {}",
+                self.size
+            )
+        }
+    }
+
     /// Sets the size of this reservation to `capacity`
     pub fn resize(&mut self, capacity: usize) {
         match capacity.cmp(&self.size) {
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs 
b/datafusion/physical-plan/src/joins/sort_merge_join.rs
index a03e4a83fd..5fde028c7f 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -24,40 +24,46 @@ use std::any::Any;
 use std::cmp::Ordering;
 use std::collections::VecDeque;
 use std::fmt::Formatter;
+use std::fs::File;
+use std::io::BufReader;
 use std::mem;
 use std::ops::Range;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
 
-use crate::expressions::PhysicalSortExpr;
-use crate::joins::utils::{
-    build_join_schema, check_join_is_valid, estimate_join_statistics,
-    symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
-};
-use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
-use crate::{
-    execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, 
Distribution,
-    ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties,
-    RecordBatchStream, SendableRecordBatchStream, Statistics,
-};
-
 use arrow::array::*;
 use arrow::compute::{self, concat_batches, take, SortOptions};
 use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
 use arrow::error::ArrowError;
+use arrow::ipc::reader::FileReader;
 use arrow_array::types::UInt64Type;
+use futures::{Stream, StreamExt};
+use hashbrown::HashSet;
 
 use datafusion_common::{
-    internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, 
Result,
+    exec_err, internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, 
JoinType,
+    Result,
 };
+use datafusion_execution::disk_manager::RefCountedTempFile;
 use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
+use datafusion_execution::runtime_env::RuntimeEnv;
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::equivalence::join_equivalence_properties;
 use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement};
 
-use futures::{Stream, StreamExt};
-use hashbrown::HashSet;
+use crate::expressions::PhysicalSortExpr;
+use crate::joins::utils::{
+    build_join_schema, check_join_is_valid, estimate_join_statistics,
+    symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
+};
+use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, 
MetricsSet};
+use crate::spill::spill_record_batches;
+use crate::{
+    execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, 
Distribution,
+    ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties,
+    RecordBatchStream, SendableRecordBatchStream, Statistics,
+};
 
 /// join execution plan executes partitions in parallel and combines them into 
a set of
 /// partitions.
@@ -234,11 +240,6 @@ impl SortMergeJoinExec {
 
 impl DisplayAs for SortMergeJoinExec {
     fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> 
std::fmt::Result {
-        let display_filter = self.filter.as_ref().map_or_else(
-            || "".to_string(),
-            |f| format!(", filter={}", f.expression()),
-        );
-
         match t {
             DisplayFormatType::Default | DisplayFormatType::Verbose => {
                 let on = self
@@ -250,7 +251,12 @@ impl DisplayAs for SortMergeJoinExec {
                 write!(
                     f,
                     "SortMergeJoin: join_type={:?}, on=[{}]{}",
-                    self.join_type, on, display_filter
+                    self.join_type,
+                    on,
+                    self.filter.as_ref().map_or("".to_string(), |f| format!(
+                        ", filter={}",
+                        f.expression()
+                    ))
                 )
             }
         }
@@ -375,6 +381,7 @@ impl ExecutionPlan for SortMergeJoinExec {
             batch_size,
             SortMergeJoinMetrics::new(partition, &self.metrics),
             reservation,
+            context.runtime_env(),
         )?))
     }
 
@@ -412,6 +419,12 @@ struct SortMergeJoinMetrics {
     /// Peak memory used for buffered data.
     /// Calculated as sum of peak memory values across partitions
     peak_mem_used: metrics::Gauge,
+    /// count of spills during the execution of the operator
+    spill_count: Count,
+    /// total spilled bytes during the execution of the operator
+    spilled_bytes: Count,
+    /// total spilled rows during the execution of the operator
+    spilled_rows: Count,
 }
 
 impl SortMergeJoinMetrics {
@@ -425,6 +438,9 @@ impl SortMergeJoinMetrics {
             MetricBuilder::new(metrics).counter("output_batches", partition);
         let output_rows = MetricBuilder::new(metrics).output_rows(partition);
         let peak_mem_used = MetricBuilder::new(metrics).gauge("peak_mem_used", 
partition);
+        let spill_count = MetricBuilder::new(metrics).spill_count(partition);
+        let spilled_bytes = 
MetricBuilder::new(metrics).spilled_bytes(partition);
+        let spilled_rows = MetricBuilder::new(metrics).spilled_rows(partition);
 
         Self {
             join_time,
@@ -433,6 +449,9 @@ impl SortMergeJoinMetrics {
             output_batches,
             output_rows,
             peak_mem_used,
+            spill_count,
+            spilled_bytes,
+            spilled_rows,
         }
     }
 }
@@ -565,7 +584,8 @@ impl StreamedBatch {
 #[derive(Debug)]
 struct BufferedBatch {
     /// The buffered record batch
-    pub batch: RecordBatch,
+    /// None if the batch spilled to disk
+    pub batch: Option<RecordBatch>,
     /// The range in which the rows share the same join key
     pub range: Range<usize>,
     /// Array refs of the join key
@@ -577,6 +597,14 @@ struct BufferedBatch {
     /// The indices of buffered batch that failed the join filter.
     /// When dequeuing the buffered batch, we need to produce null joined rows 
for these indices.
     pub join_filter_failed_idxs: HashSet<u64>,
+    /// Current buffered batch number of rows. Equal to batch.num_rows()
+    /// but if batch is spilled to disk this property is preferable
+    /// and less expensive
+    pub num_rows: usize,
+    /// An optional temp spill file name on the disk if the batch spilled
+    /// None by default
+    /// Some(fileName) if the batch spilled to the disk
+    pub spill_file: Option<RefCountedTempFile>,
 }
 
 impl BufferedBatch {
@@ -602,13 +630,16 @@ impl BufferedBatch {
             + mem::size_of::<Range<usize>>()
             + mem::size_of::<usize>();
 
+        let num_rows = batch.num_rows();
         BufferedBatch {
-            batch,
+            batch: Some(batch),
             range,
             join_arrays,
             null_joined: vec![],
             size_estimation,
             join_filter_failed_idxs: HashSet::new(),
+            num_rows,
+            spill_file: None,
         }
     }
 }
@@ -666,6 +697,8 @@ struct SMJStream {
     pub join_metrics: SortMergeJoinMetrics,
     /// Memory reservation
     pub reservation: MemoryReservation,
+    /// Runtime env
+    pub runtime_env: Arc<RuntimeEnv>,
 }
 
 impl RecordBatchStream for SMJStream {
@@ -785,6 +818,7 @@ impl SMJStream {
         batch_size: usize,
         join_metrics: SortMergeJoinMetrics,
         reservation: MemoryReservation,
+        runtime_env: Arc<RuntimeEnv>,
     ) -> Result<Self> {
         let streamed_schema = streamed.schema();
         let buffered_schema = buffered.schema();
@@ -813,6 +847,7 @@ impl SMJStream {
             join_type,
             join_metrics,
             reservation,
+            runtime_env,
         })
     }
 
@@ -858,6 +893,58 @@ impl SMJStream {
         }
     }
 
+    fn free_reservation(&mut self, buffered_batch: BufferedBatch) -> 
Result<()> {
+        // Shrink memory usage for in-memory batches only
+        if buffered_batch.spill_file.is_none() && 
buffered_batch.batch.is_some() {
+            self.reservation
+                .try_shrink(buffered_batch.size_estimation)?;
+        }
+
+        Ok(())
+    }
+
+    fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) -> 
Result<()> {
+        match self.reservation.try_grow(buffered_batch.size_estimation) {
+            Ok(_) => {
+                self.join_metrics
+                    .peak_mem_used
+                    .set_max(self.reservation.size());
+                Ok(())
+            }
+            Err(_) if self.runtime_env.disk_manager.tmp_files_enabled() => {
+                // spill buffered batch to disk
+                let spill_file = self
+                    .runtime_env
+                    .disk_manager
+                    .create_tmp_file("sort_merge_join_buffered_spill")?;
+
+                if let Some(batch) = buffered_batch.batch {
+                    spill_record_batches(
+                        vec![batch],
+                        spill_file.path().into(),
+                        Arc::clone(&self.buffered_schema),
+                    )?;
+                    buffered_batch.spill_file = Some(spill_file);
+                    buffered_batch.batch = None;
+
+                    // update metrics to register spill
+                    self.join_metrics.spill_count.add(1);
+                    self.join_metrics
+                        .spilled_bytes
+                        .add(buffered_batch.size_estimation);
+                    
self.join_metrics.spilled_rows.add(buffered_batch.num_rows);
+                    Ok(())
+                } else {
+                    internal_err!("Buffered batch has empty body")
+                }
+            }
+            Err(e) => exec_err!("{}. Disk spilling disabled.", e.message()),
+        }?;
+
+        self.buffered_data.batches.push_back(buffered_batch);
+        Ok(())
+    }
+
     /// Poll next buffered batches
     fn poll_buffered_batches(&mut self, cx: &mut Context) -> 
Poll<Option<Result<()>>> {
         loop {
@@ -867,12 +954,12 @@ impl SMJStream {
                     while !self.buffered_data.batches.is_empty() {
                         let head_batch = self.buffered_data.head_batch();
                         // If the head batch is fully processed, dequeue it 
and produce output of it.
-                        if head_batch.range.end == head_batch.batch.num_rows() 
{
+                        if head_batch.range.end == head_batch.num_rows {
                             self.freeze_dequeuing_buffered()?;
                             if let Some(buffered_batch) =
                                 self.buffered_data.batches.pop_front()
                             {
-                                
self.reservation.shrink(buffered_batch.size_estimation);
+                                self.free_reservation(buffered_batch)?;
                             }
                         } else {
                             // If the head batch is not fully processed, break 
the loop.
@@ -900,25 +987,22 @@ impl SMJStream {
                     Poll::Ready(Some(batch)) => {
                         self.join_metrics.input_batches.add(1);
                         self.join_metrics.input_rows.add(batch.num_rows());
+
                         if batch.num_rows() > 0 {
                             let buffered_batch =
                                 BufferedBatch::new(batch, 0..1, 
&self.on_buffered);
-                            
self.reservation.try_grow(buffered_batch.size_estimation)?;
-                            self.join_metrics
-                                .peak_mem_used
-                                .set_max(self.reservation.size());
 
-                            
self.buffered_data.batches.push_back(buffered_batch);
+                            self.allocate_reservation(buffered_batch)?;
                             self.buffered_state = BufferedState::PollingRest;
                         }
                     }
                 },
                 BufferedState::PollingRest => {
                     if self.buffered_data.tail_batch().range.end
-                        < self.buffered_data.tail_batch().batch.num_rows()
+                        < self.buffered_data.tail_batch().num_rows
                     {
                         while self.buffered_data.tail_batch().range.end
-                            < self.buffered_data.tail_batch().batch.num_rows()
+                            < self.buffered_data.tail_batch().num_rows
                         {
                             if is_join_arrays_equal(
                                 &self.buffered_data.head_batch().join_arrays,
@@ -941,6 +1025,7 @@ impl SMJStream {
                                 self.buffered_state = BufferedState::Ready;
                             }
                             Poll::Ready(Some(batch)) => {
+                                // Polling batches coming concurrently as 
multiple partitions
                                 self.join_metrics.input_batches.add(1);
                                 
self.join_metrics.input_rows.add(batch.num_rows());
                                 if batch.num_rows() > 0 {
@@ -949,12 +1034,7 @@ impl SMJStream {
                                         0..0,
                                         &self.on_buffered,
                                     );
-                                    self.reservation
-                                        
.try_grow(buffered_batch.size_estimation)?;
-                                    self.join_metrics
-                                        .peak_mem_used
-                                        .set_max(self.reservation.size());
-                                    
self.buffered_data.batches.push_back(buffered_batch);
+                                    self.allocate_reservation(buffered_batch)?;
                                 }
                             }
                         }
@@ -1473,13 +1553,8 @@ fn produce_buffered_null_batch(
     }
 
     // Take buffered (right) columns
-    let buffered_columns = buffered_batch
-        .batch
-        .columns()
-        .iter()
-        .map(|column| take(column, &buffered_indices, None))
-        .collect::<Result<Vec<_>, ArrowError>>()
-        .map_err(Into::<DataFusionError>::into)?;
+    let buffered_columns =
+        get_buffered_columns_from_batch(buffered_batch, buffered_indices)?;
 
     // Create null streamed (left) columns
     let mut streamed_columns = streamed_schema
@@ -1502,13 +1577,45 @@ fn get_buffered_columns(
     buffered_data: &BufferedData,
     buffered_batch_idx: usize,
     buffered_indices: &UInt64Array,
-) -> Result<Vec<ArrayRef>, ArrowError> {
-    buffered_data.batches[buffered_batch_idx]
-        .batch
-        .columns()
-        .iter()
-        .map(|column| take(column, &buffered_indices, None))
-        .collect::<Result<Vec<_>, ArrowError>>()
+) -> Result<Vec<ArrayRef>> {
+    get_buffered_columns_from_batch(
+        &buffered_data.batches[buffered_batch_idx],
+        buffered_indices,
+    )
+}
+
+#[inline(always)]
+fn get_buffered_columns_from_batch(
+    buffered_batch: &BufferedBatch,
+    buffered_indices: &UInt64Array,
+) -> Result<Vec<ArrayRef>> {
+    match (&buffered_batch.spill_file, &buffered_batch.batch) {
+        // In memory batch
+        (None, Some(batch)) => Ok(batch
+            .columns()
+            .iter()
+            .map(|column| take(column, &buffered_indices, None))
+            .collect::<Result<Vec<_>, ArrowError>>()
+            .map_err(Into::<DataFusionError>::into)?),
+        // If the batch was spilled to disk, less likely
+        (Some(spill_file), None) => {
+            let mut buffered_cols: Vec<ArrayRef> =
+                Vec::with_capacity(buffered_indices.len());
+
+            let file = BufReader::new(File::open(spill_file.path())?);
+            let reader = FileReader::try_new(file, None)?;
+
+            for batch in reader {
+                batch?.columns().iter().for_each(|column| {
+                    buffered_cols.extend(take(column, &buffered_indices, None))
+                });
+            }
+
+            Ok(buffered_cols)
+        }
+        // Invalid combination
+        (spill, batch) => internal_err!("Unexpected buffered batch spill 
status. Spill exists: {}. In-memory exists: {}", spill.is_some(), 
batch.is_some()),
+    }
 }
 
 /// Calculate join filter bit mask considering join type specifics
@@ -1854,6 +1961,7 @@ mod tests {
         assert_batches_eq, assert_batches_sorted_eq, assert_contains, 
JoinType, Result,
     };
     use datafusion_execution::config::SessionConfig;
+    use datafusion_execution::disk_manager::DiskManagerConfig;
     use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
     use datafusion_execution::TaskContext;
 
@@ -2749,7 +2857,7 @@ mod tests {
     }
 
     #[tokio::test]
-    async fn overallocation_single_batch() -> Result<()> {
+    async fn overallocation_single_batch_no_spill() -> Result<()> {
         let left = build_table(
             ("a1", &vec![0, 1, 2, 3, 4, 5]),
             ("b1", &vec![1, 2, 3, 4, 5, 6]),
@@ -2775,14 +2883,17 @@ mod tests {
             JoinType::LeftAnti,
         ];
 
-        for join_type in join_types {
-            let runtime_config = RuntimeConfig::new().with_memory_limit(100, 
1.0);
-            let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
-            let session_config = SessionConfig::default().with_batch_size(50);
+        // Disable DiskManager to prevent spilling
+        let runtime_config = RuntimeConfig::new()
+            .with_memory_limit(100, 1.0)
+            .with_disk_manager(DiskManagerConfig::Disabled);
+        let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
+        let session_config = SessionConfig::default().with_batch_size(50);
 
+        for join_type in join_types {
             let task_ctx = TaskContext::default()
-                .with_session_config(session_config)
-                .with_runtime(runtime);
+                .with_session_config(session_config.clone())
+                .with_runtime(Arc::clone(&runtime));
             let task_ctx = Arc::new(task_ctx);
 
             let join = join_with_options(
@@ -2797,18 +2908,20 @@ mod tests {
             let stream = join.execute(0, task_ctx)?;
             let err = common::collect(stream).await.unwrap_err();
 
-            assert_contains!(
-                err.to_string(),
-                "Resources exhausted: Failed to allocate additional"
-            );
+            assert_contains!(err.to_string(), "Failed to allocate additional");
             assert_contains!(err.to_string(), "SMJStream[0]");
+            assert_contains!(err.to_string(), "Disk spilling disabled");
+            assert!(join.metrics().is_some());
+            assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
+            assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
+            assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
         }
 
         Ok(())
     }
 
     #[tokio::test]
-    async fn overallocation_multi_batch() -> Result<()> {
+    async fn overallocation_multi_batch_no_spill() -> Result<()> {
         let left_batch_1 = build_table_i32(
             ("a1", &vec![0, 1]),
             ("b1", &vec![1, 1]),
@@ -2855,13 +2968,17 @@ mod tests {
             JoinType::LeftAnti,
         ];
 
+        // Disable DiskManager to prevent spilling
+        let runtime_config = RuntimeConfig::new()
+            .with_memory_limit(100, 1.0)
+            .with_disk_manager(DiskManagerConfig::Disabled);
+        let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
+        let session_config = SessionConfig::default().with_batch_size(50);
+
         for join_type in join_types {
-            let runtime_config = RuntimeConfig::new().with_memory_limit(100, 
1.0);
-            let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
-            let session_config = SessionConfig::default().with_batch_size(50);
             let task_ctx = TaskContext::default()
-                .with_session_config(session_config)
-                .with_runtime(runtime);
+                .with_session_config(session_config.clone())
+                .with_runtime(Arc::clone(&runtime));
             let task_ctx = Arc::new(task_ctx);
             let join = join_with_options(
                 Arc::clone(&left),
@@ -2875,11 +2992,205 @@ mod tests {
             let stream = join.execute(0, task_ctx)?;
             let err = common::collect(stream).await.unwrap_err();
 
-            assert_contains!(
-                err.to_string(),
-                "Resources exhausted: Failed to allocate additional"
-            );
+            assert_contains!(err.to_string(), "Failed to allocate additional");
             assert_contains!(err.to_string(), "SMJStream[0]");
+            assert_contains!(err.to_string(), "Disk spilling disabled");
+            assert!(join.metrics().is_some());
+            assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
+            assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
+            assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
+        }
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn overallocation_single_batch_spill() -> Result<()> {
+        let left = build_table(
+            ("a1", &vec![0, 1, 2, 3, 4, 5]),
+            ("b1", &vec![1, 2, 3, 4, 5, 6]),
+            ("c1", &vec![4, 5, 6, 7, 8, 9]),
+        );
+        let right = build_table(
+            ("a2", &vec![0, 10, 20, 30, 40]),
+            ("b2", &vec![1, 3, 4, 6, 8]),
+            ("c2", &vec![50, 60, 70, 80, 90]),
+        );
+        let on = vec![(
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
+        )];
+        let sort_options = vec![SortOptions::default(); on.len()];
+
+        let join_types = [
+            JoinType::Inner,
+            JoinType::Left,
+            JoinType::Right,
+            JoinType::Full,
+            JoinType::LeftSemi,
+            JoinType::LeftAnti,
+        ];
+
+        // Enable DiskManager to allow spilling
+        let runtime_config = RuntimeConfig::new()
+            .with_memory_limit(100, 1.0)
+            .with_disk_manager(DiskManagerConfig::NewOs);
+        let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
+
+        for batch_size in [1, 50] {
+            let session_config = 
SessionConfig::default().with_batch_size(batch_size);
+
+            for join_type in &join_types {
+                let task_ctx = TaskContext::default()
+                    .with_session_config(session_config.clone())
+                    .with_runtime(Arc::clone(&runtime));
+                let task_ctx = Arc::new(task_ctx);
+
+                let join = join_with_options(
+                    Arc::clone(&left),
+                    Arc::clone(&right),
+                    on.clone(),
+                    *join_type,
+                    sort_options.clone(),
+                    false,
+                )?;
+
+                let stream = join.execute(0, task_ctx)?;
+                let spilled_join_result = 
common::collect(stream).await.unwrap();
+
+                assert!(join.metrics().is_some());
+                assert!(join.metrics().unwrap().spill_count().unwrap() > 0);
+                assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0);
+                assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0);
+
+                // Run the same join again with no spill configuration as a baseline
+                let task_ctx_no_spill =
+                    
TaskContext::default().with_session_config(session_config.clone());
+                let task_ctx_no_spill = Arc::new(task_ctx_no_spill);
+
+                let join = join_with_options(
+                    Arc::clone(&left),
+                    Arc::clone(&right),
+                    on.clone(),
+                    *join_type,
+                    sort_options.clone(),
+                    false,
+                )?;
+                let stream = join.execute(0, task_ctx_no_spill)?;
+                let no_spilled_join_result = 
common::collect(stream).await.unwrap();
+
+                assert!(join.metrics().is_some());
+                assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
+                assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
+                assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
+                // Compare spilled and non spilled data to check spill logic 
doesn't corrupt the data
+                assert_eq!(spilled_join_result, no_spilled_join_result);
+            }
+        }
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn overallocation_multi_batch_spill() -> Result<()> {
+        let left_batch_1 = build_table_i32(
+            ("a1", &vec![0, 1]),
+            ("b1", &vec![1, 1]),
+            ("c1", &vec![4, 5]),
+        );
+        let left_batch_2 = build_table_i32(
+            ("a1", &vec![2, 3]),
+            ("b1", &vec![1, 1]),
+            ("c1", &vec![6, 7]),
+        );
+        let left_batch_3 = build_table_i32(
+            ("a1", &vec![4, 5]),
+            ("b1", &vec![1, 1]),
+            ("c1", &vec![8, 9]),
+        );
+        let right_batch_1 = build_table_i32(
+            ("a2", &vec![0, 10]),
+            ("b2", &vec![1, 1]),
+            ("c2", &vec![50, 60]),
+        );
+        let right_batch_2 = build_table_i32(
+            ("a2", &vec![20, 30]),
+            ("b2", &vec![1, 1]),
+            ("c2", &vec![70, 80]),
+        );
+        let right_batch_3 =
+            build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", 
&vec![90]));
+        let left =
+            build_table_from_batches(vec![left_batch_1, left_batch_2, 
left_batch_3]);
+        let right =
+            build_table_from_batches(vec![right_batch_1, right_batch_2, 
right_batch_3]);
+        let on = vec![(
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
+        )];
+        let sort_options = vec![SortOptions::default(); on.len()];
+
+        let join_types = [
+            JoinType::Inner,
+            JoinType::Left,
+            JoinType::Right,
+            JoinType::Full,
+            JoinType::LeftSemi,
+            JoinType::LeftAnti,
+        ];
+
+        // Enable DiskManager to allow spilling
+        let runtime_config = RuntimeConfig::new()
+            .with_memory_limit(500, 1.0)
+            .with_disk_manager(DiskManagerConfig::NewOs);
+        let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
+        for batch_size in [1, 50] {
+            let session_config = 
SessionConfig::default().with_batch_size(batch_size);
+
+            for join_type in &join_types {
+                let task_ctx = TaskContext::default()
+                    .with_session_config(session_config.clone())
+                    .with_runtime(Arc::clone(&runtime));
+                let task_ctx = Arc::new(task_ctx);
+                let join = join_with_options(
+                    Arc::clone(&left),
+                    Arc::clone(&right),
+                    on.clone(),
+                    *join_type,
+                    sort_options.clone(),
+                    false,
+                )?;
+
+                let stream = join.execute(0, task_ctx)?;
+                let spilled_join_result = 
common::collect(stream).await.unwrap();
+                assert!(join.metrics().is_some());
+                assert!(join.metrics().unwrap().spill_count().unwrap() > 0);
+                assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0);
+                assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0);
+
+                // Run the same join again with no spill configuration as a baseline
+                let task_ctx_no_spill =
+                    
TaskContext::default().with_session_config(session_config.clone());
+                let task_ctx_no_spill = Arc::new(task_ctx_no_spill);
+
+                let join = join_with_options(
+                    Arc::clone(&left),
+                    Arc::clone(&right),
+                    on.clone(),
+                    *join_type,
+                    sort_options.clone(),
+                    false,
+                )?;
+                let stream = join.execute(0, task_ctx_no_spill)?;
+                let no_spilled_join_result = 
common::collect(stream).await.unwrap();
+
+                assert!(join.metrics().is_some());
+                assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
+                assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
+                assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
+                // Compare spilled and non spilled data to check spill logic 
doesn't corrupt the data
+                assert_eq!(spilled_join_result, no_spilled_join_result);
+            }
         }
 
         Ok(())
diff --git a/datafusion/physical-plan/src/sorts/sort.rs 
b/datafusion/physical-plan/src/sorts/sort.rs
index d576f77d9f..13ff63c174 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -45,7 +45,7 @@ use arrow::record_batch::RecordBatch;
 use arrow::row::{RowConverter, SortField};
 use arrow_array::{Array, RecordBatchOptions, UInt32Array};
 use arrow_schema::DataType;
-use datafusion_common::{DataFusionError, Result};
+use datafusion_common::{internal_err, Result};
 use datafusion_execution::disk_manager::RefCountedTempFile;
 use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use datafusion_execution::runtime_env::RuntimeEnv;
@@ -333,10 +333,7 @@ impl ExternalSorter {
 
             for spill in self.spills.drain(..) {
                 if !spill.path().exists() {
-                    return Err(DataFusionError::Internal(format!(
-                        "Spill file {:?} does not exist",
-                        spill.path()
-                    )));
+                    return internal_err!("Spill file {:?} does not exist", 
spill.path());
                 }
                 let stream = read_spill_as_stream(spill, 
Arc::clone(&self.schema), 2)?;
                 streams.push(stream);
diff --git a/datafusion/physical-plan/src/spill.rs 
b/datafusion/physical-plan/src/spill.rs
index 0018a27bd2..21ca58fa0a 100644
--- a/datafusion/physical-plan/src/spill.rs
+++ b/datafusion/physical-plan/src/spill.rs
@@ -40,7 +40,7 @@ use crate::stream::RecordBatchReceiverStream;
 /// `path` - temp file
 /// `schema` - batches schema, should be the same across batches
 /// `buffer` - internal buffer of capacity batches
-pub fn read_spill_as_stream(
+pub(crate) fn read_spill_as_stream(
     path: RefCountedTempFile,
     schema: SchemaRef,
     buffer: usize,
@@ -56,7 +56,7 @@ pub fn read_spill_as_stream(
 /// Spills in-memory `batches` to disk.
 ///
 /// Returns total number of the rows spilled to disk.
-pub fn spill_record_batches(
+pub(crate) fn spill_record_batches(
     batches: Vec<RecordBatch>,
     path: PathBuf,
     schema: SchemaRef,
@@ -85,3 +85,102 @@ fn read_spill(sender: Sender<Result<RecordBatch>>, path: 
&Path) -> Result<()> {
     }
     Ok(())
 }
+
+/// Spills the `RecordBatch` to disk as a sequence of smaller batches,
+/// each holding at most `batch_size_rows` rows.
+/// Returns `Ok(())` once every row has been written; no row count is returned.
+pub fn spill_record_batch_by_size(
+    batch: &RecordBatch,
+    path: PathBuf,
+    schema: SchemaRef,
+    batch_size_rows: usize,
+) -> Result<()> {
+    let mut offset = 0;
+    let total_rows = batch.num_rows();
+    let mut writer = IPCWriter::new(&path, schema.as_ref())?;
+
+    while offset < total_rows {
+        let length = std::cmp::min(total_rows - offset, batch_size_rows);
+        let batch = batch.slice(offset, length);
+        offset += batch.num_rows();
+        writer.write(&batch)?;
+    }
+    writer.finish()?;
+
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::spill::{spill_record_batch_by_size, spill_record_batches};
+    use crate::test::build_table_i32;
+    use datafusion_common::Result;
+    use datafusion_execution::disk_manager::DiskManagerConfig;
+    use datafusion_execution::DiskManager;
+    use std::fs::File;
+    use std::io::BufReader;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_batch_spill_and_read() -> Result<()> {
+        let batch1 = build_table_i32(
+            ("a2", &vec![0, 1, 2]),
+            ("b2", &vec![3, 4, 5]),
+            ("c2", &vec![4, 5, 6]),
+        );
+
+        let batch2 = build_table_i32(
+            ("a2", &vec![10, 11, 12]),
+            ("b2", &vec![13, 14, 15]),
+            ("c2", &vec![14, 15, 16]),
+        );
+
+        let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?;
+
+        let spill_file = disk_manager.create_tmp_file("Test Spill")?;
+        let schema = batch1.schema();
+        let num_rows = batch1.num_rows() + batch2.num_rows();
+        let cnt = spill_record_batches(
+            vec![batch1, batch2],
+            spill_file.path().into(),
+            Arc::clone(&schema),
+        );
+        assert_eq!(cnt.unwrap(), num_rows);
+
+        let file = BufReader::new(File::open(spill_file.path())?);
+        let reader = arrow::ipc::reader::FileReader::try_new(file, None)?;
+
+        assert_eq!(reader.num_batches(), 2);
+        assert_eq!(reader.schema(), schema);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_batch_spill_by_size() -> Result<()> {
+        let batch1 = build_table_i32(
+            ("a2", &vec![0, 1, 2, 3]),
+            ("b2", &vec![3, 4, 5, 6]),
+            ("c2", &vec![4, 5, 6, 7]),
+        );
+
+        let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?;
+
+        let spill_file = disk_manager.create_tmp_file("Test Spill")?;
+        let schema = batch1.schema();
+        spill_record_batch_by_size(
+            &batch1,
+            spill_file.path().into(),
+            Arc::clone(&schema),
+            1,
+        )?;
+
+        let file = BufReader::new(File::open(spill_file.path())?);
+        let reader = arrow::ipc::reader::FileReader::try_new(file, None)?;
+
+        assert_eq!(reader.num_batches(), 4);
+        assert_eq!(reader.schema(), schema);
+
+        Ok(())
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to