This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 36991aca9f limit intermediate batch size in nested_loop_join (#16443)
36991aca9f is described below

commit 36991aca9fabac8fe010a2b27d49b64d96658d2e
Author: UBarney <[email protected]>
AuthorDate: Tue Jul 15 03:49:04 2025 +0800

    limit intermediate batch size in nested_loop_join (#16443)
    
    * limit intermediate_batch_size in nested_loop_join
    
    * address some todo
    
    * address comment
    
    * address comments
    
    * address comment
    
    * fix typo
    
    * address comment
---
 datafusion/physical-plan/src/joins/hash_join.rs    |   1 +
 .../physical-plan/src/joins/nested_loop_join.rs    | 505 ++++++++++++++-------
 .../physical-plan/src/joins/symmetric_hash_join.rs |   1 +
 datafusion/physical-plan/src/joins/utils.rs        |  60 ++-
 4 files changed, 382 insertions(+), 185 deletions(-)

diff --git a/datafusion/physical-plan/src/joins/hash_join.rs 
b/datafusion/physical-plan/src/joins/hash_join.rs
index 40ab9f2a0b..a7f28ede44 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -1533,6 +1533,7 @@ impl HashJoinStream {
                 right_indices,
                 filter,
                 JoinSide::Left,
+                None,
             )?
         } else {
             (left_indices, right_indices)
diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs 
b/datafusion/physical-plan/src/joins/nested_loop_join.rs
index c84b3a9d40..343edbb0d4 100644
--- a/datafusion/physical-plan/src/joins/nested_loop_join.rs
+++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs
@@ -18,6 +18,7 @@
 //! [`NestedLoopJoinExec`]: joins without equijoin (equality predicates).
 
 use std::any::Any;
+use std::cmp::min;
 use std::fmt::Formatter;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
@@ -26,7 +27,7 @@ use std::task::Poll;
 use super::utils::{
     asymmetric_join_output_partitioning, get_final_indices_from_shared_bitmap,
     need_produce_result_in_final, reorder_output_after_swap, 
swap_join_projection,
-    BatchSplitter, BatchTransformer, NoopBatchTransformer, 
StatefulStreamResult,
+    StatefulStreamResult,
 };
 use crate::common::can_project;
 use crate::execution_plan::{boundedness_from_children, EmissionType};
@@ -47,12 +48,13 @@ use crate::{
     SendableRecordBatchStream,
 };
 
-use arrow::array::{BooleanBufferBuilder, UInt32Array, UInt64Array};
+use arrow::array::{BooleanBufferBuilder, PrimitiveArray, UInt32Array, 
UInt64Array};
 use arrow::compute::concat_batches;
-use arrow::datatypes::{Schema, SchemaRef};
+use arrow::datatypes::{Schema, SchemaRef, UInt32Type, UInt64Type};
 use arrow::record_batch::RecordBatch;
 use datafusion_common::{
-    exec_datafusion_err, internal_err, project_schema, JoinSide, Result, 
Statistics,
+    exec_datafusion_err, internal_datafusion_err, internal_err, 
project_schema, JoinSide,
+    Result, Statistics,
 };
 use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use datafusion_execution::TaskContext;
@@ -510,8 +512,6 @@ impl ExecutionPlan for NestedLoopJoinExec {
         })?;
 
         let batch_size = context.session_config().batch_size();
-        let enforce_batch_size_in_joins =
-            context.session_config().enforce_batch_size_in_joins();
 
         let outer_table = self.right.execute(partition, context)?;
 
@@ -530,37 +530,21 @@ impl ExecutionPlan for NestedLoopJoinExec {
             None => self.column_indices.clone(),
         };
 
-        if enforce_batch_size_in_joins {
-            Ok(Box::pin(NestedLoopJoinStream {
-                schema: self.schema(),
-                filter: self.filter.clone(),
-                join_type: self.join_type,
-                outer_table,
-                inner_table,
-                column_indices: column_indices_after_projection,
-                join_metrics,
-                indices_cache,
-                right_side_ordered,
-                state: NestedLoopJoinStreamState::WaitBuildSide,
-                batch_transformer: BatchSplitter::new(batch_size),
-                left_data: None,
-            }))
-        } else {
-            Ok(Box::pin(NestedLoopJoinStream {
-                schema: self.schema(),
-                filter: self.filter.clone(),
-                join_type: self.join_type,
-                outer_table,
-                inner_table,
-                column_indices: column_indices_after_projection,
-                join_metrics,
-                indices_cache,
-                right_side_ordered,
-                state: NestedLoopJoinStreamState::WaitBuildSide,
-                batch_transformer: NoopBatchTransformer::new(),
-                left_data: None,
-            }))
-        }
+        Ok(Box::pin(NestedLoopJoinStream {
+            schema: self.schema(),
+            filter: self.filter.clone(),
+            join_type: self.join_type,
+            outer_table,
+            inner_table,
+            column_indices: column_indices_after_projection,
+            join_metrics,
+            indices_cache,
+            right_side_ordered,
+            state: NestedLoopJoinStreamState::WaitBuildSide,
+            left_data: None,
+            join_result_status: None,
+            intermediate_batch_size: batch_size,
+        }))
     }
 
     fn metrics(&self) -> Option<MetricsSet> {
@@ -687,8 +671,15 @@ enum NestedLoopJoinStreamState {
     /// Indicates that a non-empty batch has been fetched from probe-side, and
     /// is ready to be processed
     ProcessProbeBatch(RecordBatch),
-    /// Indicates that probe-side has been fully processed
-    ExhaustedProbeSide,
+    /// Preparation phase: Gathers the indices of unmatched rows from the 
build-side.
+    /// This state is entered for join types that emit unmatched build-side 
rows
+    /// (e.g., LEFT and FULL joins) after the entire probe-side input has been 
consumed.
+    PrepareUnmatchedBuildRows,
+    /// Output unmatched build-side rows.
+    /// The indices for rows to output have already been calculated in the previous
+    /// `PrepareUnmatchedBuildRows` state. In this state the final batch will be materialized incrementally.
+    // The inner `RecordBatch` is an empty dummy batch used to obtain the right-side schema.
+    OutputUnmatchedBuildRows(RecordBatch),
     /// Indicates that NestedLoopJoinStream execution is completed
     Completed,
 }
@@ -705,8 +696,29 @@ impl NestedLoopJoinStreamState {
     }
 }
 
+/// Tracks incremental output of join result batches.
+///
+/// Initialized with all matching pairs that satisfy the join predicate.
+/// Pairs are stored as indices in `build_indices` and `probe_indices`
+/// Each poll outputs a batch within the configured size limit and updates
+/// processed_count until all pairs are consumed.
+///
+/// Example: 5000 matches, batch size limit is 100
+/// - Poll 1: output batch[0..100], processed_count = 100  
+/// - Poll 2: output batch[100..200], processed_count = 200
+/// - ...continues until processed_count = 5000
+struct JoinResultProgress {
+    /// Row indices from build-side table (left table).
+    build_indices: PrimitiveArray<UInt64Type>,
+    /// Row indices from probe-side table (right table).
+    probe_indices: PrimitiveArray<UInt32Type>,
+    /// Number of index pairs already processed into output batches.
+    /// We have completed join result for indices [0..processed_count).
+    processed_count: usize,
+}
+
 /// A stream that issues [RecordBatch]es as they arrive from the right  of the 
join.
-struct NestedLoopJoinStream<T> {
+struct NestedLoopJoinStream {
     /// Input schema
     schema: Arc<Schema>,
     /// join filter
@@ -729,10 +741,13 @@ struct NestedLoopJoinStream<T> {
     right_side_ordered: bool,
     /// Current state of the stream
     state: NestedLoopJoinStreamState,
-    /// Transforms the output batch before returning.
-    batch_transformer: T,
     /// Result of the left data future
     left_data: Option<Arc<JoinLeftData>>,
+
+    /// Tracks progress when building join result batches incrementally.
+    join_result_status: Option<JoinResultProgress>,
+
+    intermediate_batch_size: usize,
 }
 
 /// Creates a Cartesian product of two input batches, preserving the order of 
the right batch,
@@ -755,6 +770,7 @@ fn build_join_indices(
     right_batch: &RecordBatch,
     filter: Option<&JoinFilter>,
     indices_cache: &mut (UInt64Array, UInt32Array),
+    max_intermediate_batch_size: usize,
 ) -> Result<(UInt64Array, UInt32Array)> {
     let left_row_count = left_batch.num_rows();
     let right_row_count = right_batch.num_rows();
@@ -805,13 +821,14 @@ fn build_join_indices(
             right_indices,
             filter,
             JoinSide::Left,
+            Some(max_intermediate_batch_size),
         )
     } else {
         Ok((left_indices, right_indices))
     }
 }
 
-impl<T: BatchTransformer> NestedLoopJoinStream<T> {
+impl NestedLoopJoinStream {
     fn poll_next_impl(
         &mut self,
         cx: &mut std::task::Context<'_>,
@@ -828,8 +845,11 @@ impl<T: BatchTransformer> NestedLoopJoinStream<T> {
                     let poll = handle_state!(self.process_probe_batch());
                     self.join_metrics.baseline.record_poll(poll)
                 }
-                NestedLoopJoinStreamState::ExhaustedProbeSide => {
-                    let poll = 
handle_state!(self.process_unmatched_build_batch());
+                NestedLoopJoinStreamState::PrepareUnmatchedBuildRows => {
+                    handle_state!(self.prepare_unmatched_output_indices())
+                }
+                NestedLoopJoinStreamState::OutputUnmatchedBuildRows(_) => {
+                    let poll = handle_state!(self.build_unmatched_output());
                     self.join_metrics.baseline.record_poll(poll)
                 }
                 NestedLoopJoinStreamState::Completed => Poll::Ready(None),
@@ -837,6 +857,116 @@ impl<T: BatchTransformer> NestedLoopJoinStream<T> {
         }
     }
 
+    // This function's main job is to construct an output `RecordBatch` based 
on pre-calculated join indices.
+    // It operates in a chunk-based manner, meaning it processes a portion of 
the results in each call,
+    // making it suitable for streaming large datasets without high memory 
consumption.
+    // This function behaves like an iterator. It returns `Ok(None)`
+    // to signal that the result stream is exhausted and there is no more data.
+    fn get_next_join_result(&mut self) -> Result<Option<RecordBatch>> {
+        let status = self.join_result_status.as_mut().ok_or_else(|| {
+            internal_datafusion_err!(
+                "get_next_join_result called without initializing 
join_result_status"
+            )
+        })?;
+
+        let (left_indices, right_indices, current_start) = (
+            &status.build_indices,
+            &status.probe_indices,
+            status.processed_count,
+        );
+
+        let left_batch = self
+            .left_data
+            .as_ref()
+            .ok_or_else(|| internal_datafusion_err!("should have left_batch"))?
+            .batch();
+
+        let right_batch = match &self.state {
+            NestedLoopJoinStreamState::ProcessProbeBatch(record_batch)
+            | 
NestedLoopJoinStreamState::OutputUnmatchedBuildRows(record_batch) => {
+                record_batch
+            }
+            _ => {
+                return internal_err!(
+                    "State should be ProcessProbeBatch or 
OutputUnmatchedBuildRows"
+                )
+            }
+        };
+
+        if left_indices.is_empty() && right_indices.is_empty() && 
current_start == 0 {
+            // To match the behavior of the previous implementation, return an 
empty RecordBatch.
+            let res = RecordBatch::new_empty(Arc::clone(&self.schema));
+            status.processed_count = 1;
+            return Ok(Some(res));
+        }
+
+        if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti) 
{
+            // in this case left_indices.len() == 0
+            let end = min(
+                current_start + self.intermediate_batch_size,
+                right_indices.len(),
+            );
+
+            if current_start >= end {
+                return Ok(None);
+            }
+
+            let res = Some(build_batch_from_indices(
+                &self.schema,
+                left_batch,
+                right_batch,
+                left_indices,
+                &right_indices.slice(current_start, end - current_start),
+                &self.column_indices,
+                JoinSide::Left,
+            )?);
+
+            status.processed_count = end;
+            return Ok(res);
+        }
+
+        if current_start >= left_indices.len() {
+            return Ok(None);
+        }
+
+        let end = min(
+            current_start + self.intermediate_batch_size,
+            left_indices.len(),
+        );
+
+        let left_indices = &left_indices.slice(current_start, end - 
current_start);
+        let right_indices = &right_indices.slice(current_start, end - 
current_start);
+
+        // Switch around the build side and probe side for 
`JoinType::RightMark`
+        // because in a RightMark join, we want to mark rows on the right table
+        // by looking for matches in the left.
+        let res = if self.join_type == JoinType::RightMark {
+            build_batch_from_indices(
+                &self.schema,
+                right_batch,
+                left_batch,
+                left_indices,
+                right_indices,
+                &self.column_indices,
+                JoinSide::Right,
+            )
+        } else {
+            build_batch_from_indices(
+                &self.schema,
+                left_batch,
+                right_batch,
+                left_indices,
+                right_indices,
+                &self.column_indices,
+                JoinSide::Left,
+            )
+        }?;
+
+        status.processed_count = end;
+
+        Ok(Some(res))
+    }
+
     fn collect_build_side(
         &mut self,
         cx: &mut std::task::Context<'_>,
@@ -861,9 +991,12 @@ impl<T: BatchTransformer> NestedLoopJoinStream<T> {
     ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
         match ready!(self.outer_table.poll_next_unpin(cx)) {
             None => {
-                self.state = NestedLoopJoinStreamState::ExhaustedProbeSide;
+                self.state = 
NestedLoopJoinStreamState::PrepareUnmatchedBuildRows;
             }
             Some(Ok(right_batch)) => {
+                self.join_metrics.input_batches.add(1);
+                self.join_metrics.input_rows.add(right_batch.num_rows());
+
                 self.state = 
NestedLoopJoinStreamState::ProcessProbeBatch(right_batch);
             }
             Some(Err(err)) => return Poll::Ready(Err(err)),
@@ -885,43 +1018,64 @@ impl<T: BatchTransformer> NestedLoopJoinStream<T> {
         let visited_left_side = left_data.bitmap();
         let batch = self.state.try_as_process_probe_batch()?;
 
-        match self.batch_transformer.next() {
-            None => {
-                // Setting up timer & updating input metrics
-                self.join_metrics.input_batches.add(1);
-                self.join_metrics.input_rows.add(batch.num_rows());
-                let timer = self.join_metrics.join_time.timer();
-
-                let result = join_left_and_right_batch(
-                    left_data.batch(),
-                    batch,
-                    self.join_type,
-                    self.filter.as_ref(),
-                    &self.column_indices,
-                    &self.schema,
-                    visited_left_side,
-                    &mut self.indices_cache,
-                    self.right_side_ordered,
-                );
-                timer.done();
+        let binding = self.join_metrics.join_time.clone();
+        let _timer = binding.timer();
+
+        if self.join_result_status.is_none() {
+            let (left_side_indices, right_side_indices) = 
join_left_and_right_batch(
+                left_data.batch(),
+                batch,
+                self.join_type,
+                self.filter.as_ref(),
+                visited_left_side,
+                &mut self.indices_cache,
+                self.right_side_ordered,
+                self.intermediate_batch_size,
+            )?;
+            self.join_result_status = Some(JoinResultProgress {
+                build_indices: left_side_indices,
+                probe_indices: right_side_indices,
+                processed_count: 0,
+            })
+        }
+
+        let join_result = self.get_next_join_result()?;
 
-                self.batch_transformer.set_batch(result?);
+        match join_result {
+            Some(res) => {
+                self.join_metrics.output_batches.add(1);
+                Ok(StatefulStreamResult::Ready(Some(res)))
+            }
+            None => {
+                self.state = NestedLoopJoinStreamState::FetchProbeBatch;
+                self.join_result_status = None;
                 Ok(StatefulStreamResult::Continue)
             }
-            Some((batch, last)) => {
-                if last {
-                    self.state = NestedLoopJoinStreamState::FetchProbeBatch;
-                }
+        }
+    }
 
+    fn build_unmatched_output(
+        &mut self,
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
+        let binding = self.join_metrics.join_time.clone();
+        let _timer = binding.timer();
+
+        let res = self.get_next_join_result()?;
+        match res {
+            Some(res) => {
                 self.join_metrics.output_batches.add(1);
-                Ok(StatefulStreamResult::Ready(Some(batch)))
+                Ok(StatefulStreamResult::Ready(Some(res)))
+            }
+            None => {
+                self.state = NestedLoopJoinStreamState::Completed;
+                Ok(StatefulStreamResult::Ready(None))
             }
         }
     }
 
-    /// Processes unmatched build-side rows for certain join types and produces
-    /// output batch, updates state to `Completed`.
-    fn process_unmatched_build_batch(
+    /// This function's primary purpose is to handle the final output stage 
required by specific join types after all right-side (probe) data has been 
exhausted.
+    /// It is critically important for LEFT*/FULL joins, which must emit 
left-side (build) rows that found no match. For these cases, it identifies the 
unmatched rows and prepares the necessary state to output them.
+    fn prepare_unmatched_output_indices(
         &mut self,
     ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
         let Some(left_data) = self.left_data.clone() else {
@@ -942,31 +1096,21 @@ impl<T: BatchTransformer> NestedLoopJoinStream<T> {
             };
 
             // Only setting up timer, input is exhausted
-            let timer = self.join_metrics.join_time.timer();
+            let _timer = self.join_metrics.join_time.timer();
             // use the global left bitmap to produce the left indices and 
right indices
             let (left_side, right_side) =
                 get_final_indices_from_shared_bitmap(visited_left_side, 
self.join_type);
-            let empty_right_batch = 
RecordBatch::new_empty(self.outer_table.schema());
-            // use the left and right indices to produce the batch result
-            let result = build_batch_from_indices(
-                &self.schema,
-                left_data.batch(),
-                &empty_right_batch,
-                &left_side,
-                &right_side,
-                &self.column_indices,
-                JoinSide::Left,
-            );
-            self.state = NestedLoopJoinStreamState::Completed;
-
-            // Recording time
-            if result.is_ok() {
-                timer.done();
-            }
 
-            self.join_metrics.output_batches.add(1);
+            self.join_result_status = Some(JoinResultProgress {
+                build_indices: left_side,
+                probe_indices: right_side,
+                processed_count: 0,
+            });
+            self.state = NestedLoopJoinStreamState::OutputUnmatchedBuildRows(
+                RecordBatch::new_empty(self.outer_table.schema()),
+            );
 
-            Ok(StatefulStreamResult::Ready(Some(result?)))
+            Ok(StatefulStreamResult::Continue)
         } else {
             // end of the join loop
             self.state = NestedLoopJoinStreamState::Completed;
@@ -981,20 +1125,23 @@ fn join_left_and_right_batch(
     right_batch: &RecordBatch,
     join_type: JoinType,
     filter: Option<&JoinFilter>,
-    column_indices: &[ColumnIndex],
-    schema: &Schema,
     visited_left_side: &SharedBitmapBuilder,
     indices_cache: &mut (UInt64Array, UInt32Array),
     right_side_ordered: bool,
-) -> Result<RecordBatch> {
-    let (left_side, right_side) =
-        build_join_indices(left_batch, right_batch, filter, 
indices_cache).map_err(
-            |e| {
-                exec_datafusion_err!(
-                    "Fail to build join indices in NestedLoopJoinExec, error: 
{e}"
-                )
-            },
-        )?;
+    max_intermediate_batch_size: usize,
+) -> Result<(PrimitiveArray<UInt64Type>, PrimitiveArray<UInt32Type>)> {
+    let (left_side, right_side) = build_join_indices(
+        left_batch,
+        right_batch,
+        filter,
+        indices_cache,
+        max_intermediate_batch_size,
+    )
+    .map_err(|e| {
+        exec_datafusion_err!(
+            "Fail to build join indices in NestedLoopJoinExec, error: {e}"
+        )
+    })?;
 
     // set the left bitmap
     // and only full join need the left bitmap
@@ -1013,33 +1160,10 @@ fn join_left_and_right_batch(
         right_side_ordered,
     )?;
 
-    // Switch around the build side and probe side for `JoinType::RightMark`
-    // because in a RightMark join, we want to mark rows on the right table
-    // by looking for matches in the left.
-    if join_type == JoinType::RightMark {
-        build_batch_from_indices(
-            schema,
-            right_batch,
-            left_batch,
-            &left_side,
-            &right_side,
-            column_indices,
-            JoinSide::Right,
-        )
-    } else {
-        build_batch_from_indices(
-            schema,
-            left_batch,
-            right_batch,
-            &left_side,
-            &right_side,
-            column_indices,
-            JoinSide::Left,
-        )
-    }
+    Ok((left_side, right_side))
 }
 
-impl<T: BatchTransformer + Unpin + Send> Stream for NestedLoopJoinStream<T> {
+impl Stream for NestedLoopJoinStream {
     type Item = Result<RecordBatch>;
 
     fn poll_next(
@@ -1050,7 +1174,7 @@ impl<T: BatchTransformer + Unpin + Send> Stream for 
NestedLoopJoinStream<T> {
     }
 }
 
-impl<T: BatchTransformer + Unpin + Send> RecordBatchStream for 
NestedLoopJoinStream<T> {
+impl RecordBatchStream for NestedLoopJoinStream {
     fn schema(&self) -> SchemaRef {
         Arc::clone(&self.schema)
     }
@@ -1081,6 +1205,7 @@ pub(crate) mod tests {
     use datafusion_physical_expr::{Partitioning, PhysicalExpr};
     use datafusion_physical_expr_common::sort_expr::{LexOrdering, 
PhysicalSortExpr};
 
+    use insta::allow_duplicates;
     use insta::assert_snapshot;
     use rstest::rstest;
 
@@ -1218,6 +1343,9 @@ pub(crate) mod tests {
             batches.extend(
                 more_batches
                     .into_iter()
+                    .inspect(|b| {
+                        assert!(b.num_rows() <= 
context.session_config().batch_size())
+                    })
                     .filter(|b| b.num_rows() > 0)
                     .collect::<Vec<_>>(),
             );
@@ -1228,9 +1356,18 @@ pub(crate) mod tests {
         Ok((columns, batches, metrics))
     }
 
+    fn new_task_ctx(batch_size: usize) -> Arc<TaskContext> {
+        let base = TaskContext::default();
+        // limit the max size of intermediate batches used in the nested loop join to `batch_size`
+        let cfg = base.session_config().clone().with_batch_size(batch_size);
+        Arc::new(base.with_session_config(cfg))
+    }
+
+    #[rstest]
     #[tokio::test]
-    async fn join_inner_with_filter() -> Result<()> {
-        let task_ctx = Arc::new(TaskContext::default());
+    async fn join_inner_with_filter(#[values(1, 2, 16)] batch_size: usize) -> 
Result<()> {
+        let task_ctx = new_task_ctx(batch_size);
+        dbg!(&batch_size);
         let left = build_left_table();
         let right = build_right_table();
         let filter = prepare_join_filter();
@@ -1242,23 +1379,25 @@ pub(crate) mod tests {
             task_ctx,
         )
         .await?;
+
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
-        assert_snapshot!(batches_to_sort_string(&batches), @r#"
+        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), 
@r#"
             +----+----+----+----+----+----+
             | a1 | b1 | c1 | a2 | b2 | c2 |
             +----+----+----+----+----+----+
             | 5  | 5  | 50 | 2  | 2  | 80 |
             +----+----+----+----+----+----+
-            "#);
+            "#));
 
         assert_join_metrics!(metrics, 1);
 
         Ok(())
     }
 
+    #[rstest]
     #[tokio::test]
-    async fn join_left_with_filter() -> Result<()> {
-        let task_ctx = Arc::new(TaskContext::default());
+    async fn join_left_with_filter(#[values(1, 2, 16)] batch_size: usize) -> 
Result<()> {
+        let task_ctx = new_task_ctx(batch_size);
         let left = build_left_table();
         let right = build_right_table();
 
@@ -1272,7 +1411,7 @@ pub(crate) mod tests {
         )
         .await?;
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
-        assert_snapshot!(batches_to_sort_string(&batches), @r#"
+        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), 
@r#"
             +----+----+-----+----+----+----+
             | a1 | b1 | c1  | a2 | b2 | c2 |
             +----+----+-----+----+----+----+
@@ -1280,16 +1419,17 @@ pub(crate) mod tests {
             | 5  | 5  | 50  | 2  | 2  | 80 |
             | 9  | 8  | 90  |    |    |    |
             +----+----+-----+----+----+----+
-            "#);
+            "#));
 
         assert_join_metrics!(metrics, 3);
 
         Ok(())
     }
 
+    #[rstest]
     #[tokio::test]
-    async fn join_right_with_filter() -> Result<()> {
-        let task_ctx = Arc::new(TaskContext::default());
+    async fn join_right_with_filter(#[values(1, 2, 16)] batch_size: usize) -> 
Result<()> {
+        let task_ctx = new_task_ctx(batch_size);
         let left = build_left_table();
         let right = build_right_table();
 
@@ -1303,7 +1443,7 @@ pub(crate) mod tests {
         )
         .await?;
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
-        assert_snapshot!(batches_to_sort_string(&batches), @r#"
+        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), 
@r#"
             +----+----+----+----+----+-----+
             | a1 | b1 | c1 | a2 | b2 | c2  |
             +----+----+----+----+----+-----+
@@ -1311,16 +1451,17 @@ pub(crate) mod tests {
             |    |    |    | 12 | 10 | 40  |
             | 5  | 5  | 50 | 2  | 2  | 80  |
             +----+----+----+----+----+-----+
-            "#);
+            "#));
 
         assert_join_metrics!(metrics, 3);
 
         Ok(())
     }
 
+    #[rstest]
     #[tokio::test]
-    async fn join_full_with_filter() -> Result<()> {
-        let task_ctx = Arc::new(TaskContext::default());
+    async fn join_full_with_filter(#[values(1, 2, 16)] batch_size: usize) -> 
Result<()> {
+        let task_ctx = new_task_ctx(batch_size);
         let left = build_left_table();
         let right = build_right_table();
 
@@ -1334,7 +1475,7 @@ pub(crate) mod tests {
         )
         .await?;
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
-        assert_snapshot!(batches_to_sort_string(&batches), @r#"
+        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), 
@r#"
             +----+----+-----+----+----+-----+
             | a1 | b1 | c1  | a2 | b2 | c2  |
             +----+----+-----+----+----+-----+
@@ -1344,16 +1485,19 @@ pub(crate) mod tests {
             | 5  | 5  | 50  | 2  | 2  | 80  |
             | 9  | 8  | 90  |    |    |     |
             +----+----+-----+----+----+-----+
-            "#);
+            "#));
 
         assert_join_metrics!(metrics, 5);
 
         Ok(())
     }
 
+    #[rstest]
     #[tokio::test]
-    async fn join_left_semi_with_filter() -> Result<()> {
-        let task_ctx = Arc::new(TaskContext::default());
+    async fn join_left_semi_with_filter(
+        #[values(1, 2, 16)] batch_size: usize,
+    ) -> Result<()> {
+        let task_ctx = new_task_ctx(batch_size);
         let left = build_left_table();
         let right = build_right_table();
 
@@ -1367,22 +1511,25 @@ pub(crate) mod tests {
         )
         .await?;
         assert_eq!(columns, vec!["a1", "b1", "c1"]);
-        assert_snapshot!(batches_to_sort_string(&batches), @r#"
+        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), 
@r#"
             +----+----+----+
             | a1 | b1 | c1 |
             +----+----+----+
             | 5  | 5  | 50 |
             +----+----+----+
-            "#);
+            "#));
 
         assert_join_metrics!(metrics, 1);
 
         Ok(())
     }
 
+    #[rstest]
     #[tokio::test]
-    async fn join_left_anti_with_filter() -> Result<()> {
-        let task_ctx = Arc::new(TaskContext::default());
+    async fn join_left_anti_with_filter(
+        #[values(1, 2, 16)] batch_size: usize,
+    ) -> Result<()> {
+        let task_ctx = new_task_ctx(batch_size);
         let left = build_left_table();
         let right = build_right_table();
 
@@ -1396,23 +1543,26 @@ pub(crate) mod tests {
         )
         .await?;
         assert_eq!(columns, vec!["a1", "b1", "c1"]);
-        assert_snapshot!(batches_to_sort_string(&batches), @r#"
+        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), 
@r#"
             +----+----+-----+
             | a1 | b1 | c1  |
             +----+----+-----+
             | 11 | 8  | 110 |
             | 9  | 8  | 90  |
             +----+----+-----+
-            "#);
+            "#));
 
         assert_join_metrics!(metrics, 2);
 
         Ok(())
     }
 
+    #[rstest]
     #[tokio::test]
-    async fn join_right_semi_with_filter() -> Result<()> {
-        let task_ctx = Arc::new(TaskContext::default());
+    async fn join_right_semi_with_filter(
+        #[values(1, 2, 16)] batch_size: usize,
+    ) -> Result<()> {
+        let task_ctx = new_task_ctx(batch_size);
         let left = build_left_table();
         let right = build_right_table();
 
@@ -1426,22 +1576,25 @@ pub(crate) mod tests {
         )
         .await?;
         assert_eq!(columns, vec!["a2", "b2", "c2"]);
-        assert_snapshot!(batches_to_sort_string(&batches), @r#"
+        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), 
@r#"
             +----+----+----+
             | a2 | b2 | c2 |
             +----+----+----+
             | 2  | 2  | 80 |
             +----+----+----+
-            "#);
+            "#));
 
         assert_join_metrics!(metrics, 1);
 
         Ok(())
     }
 
+    #[rstest]
     #[tokio::test]
-    async fn join_right_anti_with_filter() -> Result<()> {
-        let task_ctx = Arc::new(TaskContext::default());
+    async fn join_right_anti_with_filter(
+        #[values(1, 2, 16)] batch_size: usize,
+    ) -> Result<()> {
+        let task_ctx = new_task_ctx(batch_size);
         let left = build_left_table();
         let right = build_right_table();
 
@@ -1455,23 +1608,26 @@ pub(crate) mod tests {
         )
         .await?;
         assert_eq!(columns, vec!["a2", "b2", "c2"]);
-        assert_snapshot!(batches_to_sort_string(&batches), @r#"
+        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), 
@r#"
             +----+----+-----+
             | a2 | b2 | c2  |
             +----+----+-----+
             | 10 | 10 | 100 |
             | 12 | 10 | 40  |
             +----+----+-----+
-            "#);
+            "#));
 
         assert_join_metrics!(metrics, 2);
 
         Ok(())
     }
 
+    #[rstest]
     #[tokio::test]
-    async fn join_left_mark_with_filter() -> Result<()> {
-        let task_ctx = Arc::new(TaskContext::default());
+    async fn join_left_mark_with_filter(
+        #[values(1, 2, 16)] batch_size: usize,
+    ) -> Result<()> {
+        let task_ctx = new_task_ctx(batch_size);
         let left = build_left_table();
         let right = build_right_table();
 
@@ -1485,7 +1641,7 @@ pub(crate) mod tests {
         )
         .await?;
         assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]);
-        assert_snapshot!(batches_to_sort_string(&batches), @r#"
+        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), 
@r#"
             +----+----+-----+-------+
             | a1 | b1 | c1  | mark  |
             +----+----+-----+-------+
@@ -1493,16 +1649,19 @@ pub(crate) mod tests {
             | 5  | 5  | 50  | true  |
             | 9  | 8  | 90  | false |
             +----+----+-----+-------+
-            "#);
+            "#));
 
         assert_join_metrics!(metrics, 3);
 
         Ok(())
     }
 
+    #[rstest]
     #[tokio::test]
-    async fn join_right_mark_with_filter() -> Result<()> {
-        let task_ctx = Arc::new(TaskContext::default());
+    async fn join_right_mark_with_filter(
+        #[values(1, 2, 16)] batch_size: usize,
+    ) -> Result<()> {
+        let task_ctx = new_task_ctx(batch_size);
         let left = build_left_table();
         let right = build_right_table();
 
@@ -1517,7 +1676,7 @@ pub(crate) mod tests {
         .await?;
         assert_eq!(columns, vec!["a2", "b2", "c2", "mark"]);
 
-        assert_snapshot!(batches_to_sort_string(&batches), @r#"
+        allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), 
@r#"
             +----+----+-----+-------+
             | a2 | b2 | c2  | mark  |
             +----+----+-----+-------+
@@ -1525,7 +1684,7 @@ pub(crate) mod tests {
             | 12 | 10 | 40  | false |
             | 2  | 2  | 80  | true  |
             +----+----+-----+-------+
-            "#);
+            "#));
 
         assert_join_metrics!(metrics, 3);
 
@@ -1659,6 +1818,7 @@ pub(crate) mod tests {
         join_type: JoinType,
         #[values(1, 100, 1000)] left_batch_size: usize,
         #[values(1, 100, 1000)] right_batch_size: usize,
+        #[values(2, 10000)] batch_size: usize,
     ) -> Result<()> {
         let left_columns = generate_columns(3, 1000);
         let left = build_table(
@@ -1708,8 +1868,9 @@ pub(crate) mod tests {
             assert_eq!(right.options, join.options);
         }
 
+        let task_ctx = new_task_ctx(batch_size);
         let batches = nested_loop_join
-            .execute(0, Arc::new(TaskContext::default()))?
+            .execute(0, task_ctx)?
             .try_collect::<Vec<_>>()
             .await?;
 
diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs 
b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
index 013d4497a8..9a8d4cbb66 100644
--- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
+++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
@@ -992,6 +992,7 @@ pub(crate) fn join_with_probe_batch(
             probe_indices,
             filter,
             build_hash_joiner.build_side,
+            None,
         )?
     } else {
         (build_indices, probe_indices)
diff --git a/datafusion/physical-plan/src/joins/utils.rs 
b/datafusion/physical-plan/src/joins/utils.rs
index 5248eafef0..35827d4fcd 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -17,6 +17,7 @@
 
 //! Join related functionality used both on logical and physical plans
 
+use std::cmp::min;
 use std::collections::HashSet;
 use std::fmt::{self, Debug};
 use std::future::Future;
@@ -844,24 +845,56 @@ pub(crate) fn apply_join_filter_to_indices(
     probe_indices: UInt32Array,
     filter: &JoinFilter,
     build_side: JoinSide,
+    max_intermediate_size: Option<usize>,
 ) -> Result<(UInt64Array, UInt32Array)> {
     if build_indices.is_empty() && probe_indices.is_empty() {
         return Ok((build_indices, probe_indices));
     };
 
-    let intermediate_batch = build_batch_from_indices(
-        filter.schema(),
-        build_input_buffer,
-        probe_batch,
-        &build_indices,
-        &probe_indices,
-        filter.column_indices(),
-        build_side,
-    )?;
-    let filter_result = filter
-        .expression()
-        .evaluate(&intermediate_batch)?
-        .into_array(intermediate_batch.num_rows())?;
+    let filter_result = if let Some(max_size) = max_intermediate_size {
+        let mut filter_results =
+            Vec::with_capacity(build_indices.len().div_ceil(max_size));
+
+        for i in (0..build_indices.len()).step_by(max_size) {
+            let end = min(build_indices.len(), i + max_size);
+            let len = end - i;
+            let intermediate_batch = build_batch_from_indices(
+                filter.schema(),
+                build_input_buffer,
+                probe_batch,
+                &build_indices.slice(i, len),
+                &probe_indices.slice(i, len),
+                filter.column_indices(),
+                build_side,
+            )?;
+            let filter_result = filter
+                .expression()
+                .evaluate(&intermediate_batch)?
+                .into_array(intermediate_batch.num_rows())?;
+            filter_results.push(filter_result);
+        }
+
+        let filter_refs: Vec<&dyn Array> =
+            filter_results.iter().map(|a| a.as_ref()).collect();
+
+        compute::concat(&filter_refs)?
+    } else {
+        let intermediate_batch = build_batch_from_indices(
+            filter.schema(),
+            build_input_buffer,
+            probe_batch,
+            &build_indices,
+            &probe_indices,
+            filter.column_indices(),
+            build_side,
+        )?;
+
+        filter
+            .expression()
+            .evaluate(&intermediate_batch)?
+            .into_array(intermediate_batch.num_rows())?
+    };
+
     let mask = as_boolean_array(&filter_result)?;
 
     let left_filtered = compute::filter(&build_indices, mask)?;
@@ -924,6 +957,7 @@ pub(crate) fn build_batch_from_indices(
                 compute::take(array.as_ref(), probe_indices, None)?
             }
         };
+
         columns.push(array);
     }
     Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to