This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 36991aca9f limit intermediate batch size in nested_loop_join (#16443)
36991aca9f is described below
commit 36991aca9fabac8fe010a2b27d49b64d96658d2e
Author: UBarney <[email protected]>
AuthorDate: Tue Jul 15 03:49:04 2025 +0800
limit intermediate batch size in nested_loop_join (#16443)
* limit intermediate_batch_size in nested_loop_join
* address some todo
* address comment
* address comments
* address comment
* fix typo
* address comment
---
datafusion/physical-plan/src/joins/hash_join.rs | 1 +
.../physical-plan/src/joins/nested_loop_join.rs | 505 ++++++++++++++-------
.../physical-plan/src/joins/symmetric_hash_join.rs | 1 +
datafusion/physical-plan/src/joins/utils.rs | 60 ++-
4 files changed, 382 insertions(+), 185 deletions(-)
diff --git a/datafusion/physical-plan/src/joins/hash_join.rs
b/datafusion/physical-plan/src/joins/hash_join.rs
index 40ab9f2a0b..a7f28ede44 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -1533,6 +1533,7 @@ impl HashJoinStream {
right_indices,
filter,
JoinSide::Left,
+ None,
)?
} else {
(left_indices, right_indices)
diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs
b/datafusion/physical-plan/src/joins/nested_loop_join.rs
index c84b3a9d40..343edbb0d4 100644
--- a/datafusion/physical-plan/src/joins/nested_loop_join.rs
+++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs
@@ -18,6 +18,7 @@
//! [`NestedLoopJoinExec`]: joins without equijoin (equality predicates).
use std::any::Any;
+use std::cmp::min;
use std::fmt::Formatter;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
@@ -26,7 +27,7 @@ use std::task::Poll;
use super::utils::{
asymmetric_join_output_partitioning, get_final_indices_from_shared_bitmap,
need_produce_result_in_final, reorder_output_after_swap,
swap_join_projection,
- BatchSplitter, BatchTransformer, NoopBatchTransformer,
StatefulStreamResult,
+ StatefulStreamResult,
};
use crate::common::can_project;
use crate::execution_plan::{boundedness_from_children, EmissionType};
@@ -47,12 +48,13 @@ use crate::{
SendableRecordBatchStream,
};
-use arrow::array::{BooleanBufferBuilder, UInt32Array, UInt64Array};
+use arrow::array::{BooleanBufferBuilder, PrimitiveArray, UInt32Array,
UInt64Array};
use arrow::compute::concat_batches;
-use arrow::datatypes::{Schema, SchemaRef};
+use arrow::datatypes::{Schema, SchemaRef, UInt32Type, UInt64Type};
use arrow::record_batch::RecordBatch;
use datafusion_common::{
- exec_datafusion_err, internal_err, project_schema, JoinSide, Result,
Statistics,
+ exec_datafusion_err, internal_datafusion_err, internal_err,
project_schema, JoinSide,
+ Result, Statistics,
};
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_execution::TaskContext;
@@ -510,8 +512,6 @@ impl ExecutionPlan for NestedLoopJoinExec {
})?;
let batch_size = context.session_config().batch_size();
- let enforce_batch_size_in_joins =
- context.session_config().enforce_batch_size_in_joins();
let outer_table = self.right.execute(partition, context)?;
@@ -530,37 +530,21 @@ impl ExecutionPlan for NestedLoopJoinExec {
None => self.column_indices.clone(),
};
- if enforce_batch_size_in_joins {
- Ok(Box::pin(NestedLoopJoinStream {
- schema: self.schema(),
- filter: self.filter.clone(),
- join_type: self.join_type,
- outer_table,
- inner_table,
- column_indices: column_indices_after_projection,
- join_metrics,
- indices_cache,
- right_side_ordered,
- state: NestedLoopJoinStreamState::WaitBuildSide,
- batch_transformer: BatchSplitter::new(batch_size),
- left_data: None,
- }))
- } else {
- Ok(Box::pin(NestedLoopJoinStream {
- schema: self.schema(),
- filter: self.filter.clone(),
- join_type: self.join_type,
- outer_table,
- inner_table,
- column_indices: column_indices_after_projection,
- join_metrics,
- indices_cache,
- right_side_ordered,
- state: NestedLoopJoinStreamState::WaitBuildSide,
- batch_transformer: NoopBatchTransformer::new(),
- left_data: None,
- }))
- }
+ Ok(Box::pin(NestedLoopJoinStream {
+ schema: self.schema(),
+ filter: self.filter.clone(),
+ join_type: self.join_type,
+ outer_table,
+ inner_table,
+ column_indices: column_indices_after_projection,
+ join_metrics,
+ indices_cache,
+ right_side_ordered,
+ state: NestedLoopJoinStreamState::WaitBuildSide,
+ left_data: None,
+ join_result_status: None,
+ intermediate_batch_size: batch_size,
+ }))
}
fn metrics(&self) -> Option<MetricsSet> {
@@ -687,8 +671,15 @@ enum NestedLoopJoinStreamState {
/// Indicates that a non-empty batch has been fetched from probe-side, and
/// is ready to be processed
ProcessProbeBatch(RecordBatch),
- /// Indicates that probe-side has been fully processed
- ExhaustedProbeSide,
+ /// Preparation phase: Gathers the indices of unmatched rows from the
build-side.
+ /// This state is entered for join types that emit unmatched build-side
rows
+ /// (e.g., LEFT and FULL joins) after the entire probe-side input has been
consumed.
+ PrepareUnmatchedBuildRows,
+ /// Output unmatched build-side rows.
+    /// The indices for rows to output have already been calculated in the
previous
+ /// `PrepareUnmatchedBuildRows` state. In this state the final batch will
be materialized incrementally.
+    // The inner `RecordBatch` is an empty dummy batch used to get the right
schema.
+ OutputUnmatchedBuildRows(RecordBatch),
/// Indicates that NestedLoopJoinStream execution is completed
Completed,
}
@@ -705,8 +696,29 @@ impl NestedLoopJoinStreamState {
}
}
+/// Tracks incremental output of join result batches.
+///
+/// Initialized with all matching pairs that satisfy the join predicate.
+/// Pairs are stored as indices in `build_indices` and `probe_indices`.
+/// Each poll outputs a batch within the configured size limit and updates
+/// `processed_count` until all pairs are consumed.
+///
+/// Example: 5000 matches, batch size limit is 100
+/// - Poll 1: output batch[0..100], processed_count = 100
+/// - Poll 2: output batch[100..200], processed_count = 200
+/// - ...continues until processed_count = 5000
+struct JoinResultProgress {
+ /// Row indices from build-side table (left table).
+ build_indices: PrimitiveArray<UInt64Type>,
+ /// Row indices from probe-side table (right table).
+ probe_indices: PrimitiveArray<UInt32Type>,
+ /// Number of index pairs already processed into output batches.
+ /// We have completed join result for indices [0..processed_count).
+ processed_count: usize,
+}
+
/// A stream that issues [RecordBatch]es as they arrive from the right of the
join.
-struct NestedLoopJoinStream<T> {
+struct NestedLoopJoinStream {
/// Input schema
schema: Arc<Schema>,
/// join filter
@@ -729,10 +741,13 @@ struct NestedLoopJoinStream<T> {
right_side_ordered: bool,
/// Current state of the stream
state: NestedLoopJoinStreamState,
- /// Transforms the output batch before returning.
- batch_transformer: T,
/// Result of the left data future
left_data: Option<Arc<JoinLeftData>>,
+
+ /// Tracks progress when building join result batches incrementally.
+ join_result_status: Option<JoinResultProgress>,
+
+ intermediate_batch_size: usize,
}
/// Creates a Cartesian product of two input batches, preserving the order of
the right batch,
@@ -755,6 +770,7 @@ fn build_join_indices(
right_batch: &RecordBatch,
filter: Option<&JoinFilter>,
indices_cache: &mut (UInt64Array, UInt32Array),
+ max_intermediate_batch_size: usize,
) -> Result<(UInt64Array, UInt32Array)> {
let left_row_count = left_batch.num_rows();
let right_row_count = right_batch.num_rows();
@@ -805,13 +821,14 @@ fn build_join_indices(
right_indices,
filter,
JoinSide::Left,
+ Some(max_intermediate_batch_size),
)
} else {
Ok((left_indices, right_indices))
}
}
-impl<T: BatchTransformer> NestedLoopJoinStream<T> {
+impl NestedLoopJoinStream {
fn poll_next_impl(
&mut self,
cx: &mut std::task::Context<'_>,
@@ -828,8 +845,11 @@ impl<T: BatchTransformer> NestedLoopJoinStream<T> {
let poll = handle_state!(self.process_probe_batch());
self.join_metrics.baseline.record_poll(poll)
}
- NestedLoopJoinStreamState::ExhaustedProbeSide => {
- let poll =
handle_state!(self.process_unmatched_build_batch());
+ NestedLoopJoinStreamState::PrepareUnmatchedBuildRows => {
+ handle_state!(self.prepare_unmatched_output_indices())
+ }
+ NestedLoopJoinStreamState::OutputUnmatchedBuildRows(_) => {
+ let poll = handle_state!(self.build_unmatched_output());
self.join_metrics.baseline.record_poll(poll)
}
NestedLoopJoinStreamState::Completed => Poll::Ready(None),
@@ -837,6 +857,116 @@ impl<T: BatchTransformer> NestedLoopJoinStream<T> {
}
}
+ // This function's main job is to construct an output `RecordBatch` based
on pre-calculated join indices.
+ // It operates in a chunk-based manner, meaning it processes a portion of
the results in each call,
+ // making it suitable for streaming large datasets without high memory
consumption.
+ // This function behaves like an iterator. It returns `Ok(None)`
+ // to signal that the result stream is exhausted and there is no more data.
+ fn get_next_join_result(&mut self) -> Result<Option<RecordBatch>> {
+ let status = self.join_result_status.as_mut().ok_or_else(|| {
+ internal_datafusion_err!(
+ "get_next_join_result called without initializing
join_result_status"
+ )
+ })?;
+
+ let (left_indices, right_indices, current_start) = (
+ &status.build_indices,
+ &status.probe_indices,
+ status.processed_count,
+ );
+
+ let left_batch = self
+ .left_data
+ .as_ref()
+ .ok_or_else(|| internal_datafusion_err!("should have left_batch"))?
+ .batch();
+
+ let right_batch = match &self.state {
+ NestedLoopJoinStreamState::ProcessProbeBatch(record_batch)
+ |
NestedLoopJoinStreamState::OutputUnmatchedBuildRows(record_batch) => {
+ record_batch
+ }
+ _ => {
+ return internal_err!(
+ "State should be ProcessProbeBatch or
OutputUnmatchedBuildRows"
+ )
+ }
+ };
+
+ if left_indices.is_empty() && right_indices.is_empty() &&
current_start == 0 {
+ // To match the behavior of the previous implementation, return an
empty RecordBatch.
+ let res = RecordBatch::new_empty(Arc::clone(&self.schema));
+ status.processed_count = 1;
+ return Ok(Some(res));
+ }
+
+ if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti)
{
+            // in this case left_indices is empty (left_indices.len() == 0)
+ let end = min(
+ current_start + self.intermediate_batch_size,
+ right_indices.len(),
+ );
+
+ if current_start >= end {
+ return Ok(None);
+ }
+
+ let res = Some(build_batch_from_indices(
+ &self.schema,
+ left_batch,
+ right_batch,
+ left_indices,
+ &right_indices.slice(current_start, end - current_start),
+ &self.column_indices,
+ JoinSide::Left,
+ )?);
+
+ status.processed_count = end;
+ return Ok(res);
+ }
+
+ if current_start >= left_indices.len() {
+ return Ok(None);
+ }
+
+ let end = min(
+ current_start + self.intermediate_batch_size,
+ left_indices.len(),
+ );
+
+ let left_indices = &left_indices.slice(current_start, end -
current_start);
+ let right_indices = &right_indices.slice(current_start, end -
current_start);
+
+ // Switch around the build side and probe side for
`JoinType::RightMark`
+ // because in a RightMark join, we want to mark rows on the right table
+ // by looking for matches in the left.
+ let res = if self.join_type == JoinType::RightMark {
+ build_batch_from_indices(
+ &self.schema,
+ right_batch,
+ left_batch,
+ left_indices,
+ right_indices,
+ &self.column_indices,
+ JoinSide::Right,
+ )
+ } else {
+ build_batch_from_indices(
+ &self.schema,
+ left_batch,
+ right_batch,
+ left_indices,
+ right_indices,
+ &self.column_indices,
+ JoinSide::Left,
+ )
+ }?;
+
+ status.processed_count = end;
+
+ Ok(Some(res))
+ }
+
fn collect_build_side(
&mut self,
cx: &mut std::task::Context<'_>,
@@ -861,9 +991,12 @@ impl<T: BatchTransformer> NestedLoopJoinStream<T> {
) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
match ready!(self.outer_table.poll_next_unpin(cx)) {
None => {
- self.state = NestedLoopJoinStreamState::ExhaustedProbeSide;
+ self.state =
NestedLoopJoinStreamState::PrepareUnmatchedBuildRows;
}
Some(Ok(right_batch)) => {
+ self.join_metrics.input_batches.add(1);
+ self.join_metrics.input_rows.add(right_batch.num_rows());
+
self.state =
NestedLoopJoinStreamState::ProcessProbeBatch(right_batch);
}
Some(Err(err)) => return Poll::Ready(Err(err)),
@@ -885,43 +1018,64 @@ impl<T: BatchTransformer> NestedLoopJoinStream<T> {
let visited_left_side = left_data.bitmap();
let batch = self.state.try_as_process_probe_batch()?;
- match self.batch_transformer.next() {
- None => {
- // Setting up timer & updating input metrics
- self.join_metrics.input_batches.add(1);
- self.join_metrics.input_rows.add(batch.num_rows());
- let timer = self.join_metrics.join_time.timer();
-
- let result = join_left_and_right_batch(
- left_data.batch(),
- batch,
- self.join_type,
- self.filter.as_ref(),
- &self.column_indices,
- &self.schema,
- visited_left_side,
- &mut self.indices_cache,
- self.right_side_ordered,
- );
- timer.done();
+ let binding = self.join_metrics.join_time.clone();
+ let _timer = binding.timer();
+
+ if self.join_result_status.is_none() {
+ let (left_side_indices, right_side_indices) =
join_left_and_right_batch(
+ left_data.batch(),
+ batch,
+ self.join_type,
+ self.filter.as_ref(),
+ visited_left_side,
+ &mut self.indices_cache,
+ self.right_side_ordered,
+ self.intermediate_batch_size,
+ )?;
+ self.join_result_status = Some(JoinResultProgress {
+ build_indices: left_side_indices,
+ probe_indices: right_side_indices,
+ processed_count: 0,
+ })
+ }
+
+ let join_result = self.get_next_join_result()?;
- self.batch_transformer.set_batch(result?);
+ match join_result {
+ Some(res) => {
+ self.join_metrics.output_batches.add(1);
+ Ok(StatefulStreamResult::Ready(Some(res)))
+ }
+ None => {
+ self.state = NestedLoopJoinStreamState::FetchProbeBatch;
+ self.join_result_status = None;
Ok(StatefulStreamResult::Continue)
}
- Some((batch, last)) => {
- if last {
- self.state = NestedLoopJoinStreamState::FetchProbeBatch;
- }
+ }
+ }
+ fn build_unmatched_output(
+ &mut self,
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
+ let binding = self.join_metrics.join_time.clone();
+ let _timer = binding.timer();
+
+ let res = self.get_next_join_result()?;
+ match res {
+ Some(res) => {
self.join_metrics.output_batches.add(1);
- Ok(StatefulStreamResult::Ready(Some(batch)))
+ Ok(StatefulStreamResult::Ready(Some(res)))
+ }
+ None => {
+ self.state = NestedLoopJoinStreamState::Completed;
+ Ok(StatefulStreamResult::Ready(None))
}
}
}
- /// Processes unmatched build-side rows for certain join types and produces
- /// output batch, updates state to `Completed`.
- fn process_unmatched_build_batch(
+ /// This function's primary purpose is to handle the final output stage
required by specific join types after all right-side (probe) data has been
exhausted.
+ /// It is critically important for LEFT*/FULL joins, which must emit
left-side (build) rows that found no match. For these cases, it identifies the
unmatched rows and prepares the necessary state to output them.
+ fn prepare_unmatched_output_indices(
&mut self,
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
let Some(left_data) = self.left_data.clone() else {
@@ -942,31 +1096,21 @@ impl<T: BatchTransformer> NestedLoopJoinStream<T> {
};
// Only setting up timer, input is exhausted
- let timer = self.join_metrics.join_time.timer();
+ let _timer = self.join_metrics.join_time.timer();
// use the global left bitmap to produce the left indices and
right indices
let (left_side, right_side) =
get_final_indices_from_shared_bitmap(visited_left_side,
self.join_type);
- let empty_right_batch =
RecordBatch::new_empty(self.outer_table.schema());
- // use the left and right indices to produce the batch result
- let result = build_batch_from_indices(
- &self.schema,
- left_data.batch(),
- &empty_right_batch,
- &left_side,
- &right_side,
- &self.column_indices,
- JoinSide::Left,
- );
- self.state = NestedLoopJoinStreamState::Completed;
-
- // Recording time
- if result.is_ok() {
- timer.done();
- }
- self.join_metrics.output_batches.add(1);
+ self.join_result_status = Some(JoinResultProgress {
+ build_indices: left_side,
+ probe_indices: right_side,
+ processed_count: 0,
+ });
+ self.state = NestedLoopJoinStreamState::OutputUnmatchedBuildRows(
+ RecordBatch::new_empty(self.outer_table.schema()),
+ );
- Ok(StatefulStreamResult::Ready(Some(result?)))
+ Ok(StatefulStreamResult::Continue)
} else {
// end of the join loop
self.state = NestedLoopJoinStreamState::Completed;
@@ -981,20 +1125,23 @@ fn join_left_and_right_batch(
right_batch: &RecordBatch,
join_type: JoinType,
filter: Option<&JoinFilter>,
- column_indices: &[ColumnIndex],
- schema: &Schema,
visited_left_side: &SharedBitmapBuilder,
indices_cache: &mut (UInt64Array, UInt32Array),
right_side_ordered: bool,
-) -> Result<RecordBatch> {
- let (left_side, right_side) =
- build_join_indices(left_batch, right_batch, filter,
indices_cache).map_err(
- |e| {
- exec_datafusion_err!(
- "Fail to build join indices in NestedLoopJoinExec, error:
{e}"
- )
- },
- )?;
+ max_intermediate_batch_size: usize,
+) -> Result<(PrimitiveArray<UInt64Type>, PrimitiveArray<UInt32Type>)> {
+ let (left_side, right_side) = build_join_indices(
+ left_batch,
+ right_batch,
+ filter,
+ indices_cache,
+ max_intermediate_batch_size,
+ )
+ .map_err(|e| {
+ exec_datafusion_err!(
+ "Fail to build join indices in NestedLoopJoinExec, error: {e}"
+ )
+ })?;
// set the left bitmap
// and only full join need the left bitmap
@@ -1013,33 +1160,10 @@ fn join_left_and_right_batch(
right_side_ordered,
)?;
- // Switch around the build side and probe side for `JoinType::RightMark`
- // because in a RightMark join, we want to mark rows on the right table
- // by looking for matches in the left.
- if join_type == JoinType::RightMark {
- build_batch_from_indices(
- schema,
- right_batch,
- left_batch,
- &left_side,
- &right_side,
- column_indices,
- JoinSide::Right,
- )
- } else {
- build_batch_from_indices(
- schema,
- left_batch,
- right_batch,
- &left_side,
- &right_side,
- column_indices,
- JoinSide::Left,
- )
- }
+ Ok((left_side, right_side))
}
-impl<T: BatchTransformer + Unpin + Send> Stream for NestedLoopJoinStream<T> {
+impl Stream for NestedLoopJoinStream {
type Item = Result<RecordBatch>;
fn poll_next(
@@ -1050,7 +1174,7 @@ impl<T: BatchTransformer + Unpin + Send> Stream for
NestedLoopJoinStream<T> {
}
}
-impl<T: BatchTransformer + Unpin + Send> RecordBatchStream for
NestedLoopJoinStream<T> {
+impl RecordBatchStream for NestedLoopJoinStream {
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
@@ -1081,6 +1205,7 @@ pub(crate) mod tests {
use datafusion_physical_expr::{Partitioning, PhysicalExpr};
use datafusion_physical_expr_common::sort_expr::{LexOrdering,
PhysicalSortExpr};
+ use insta::allow_duplicates;
use insta::assert_snapshot;
use rstest::rstest;
@@ -1218,6 +1343,9 @@ pub(crate) mod tests {
batches.extend(
more_batches
.into_iter()
+ .inspect(|b| {
+ assert!(b.num_rows() <=
context.session_config().batch_size())
+ })
.filter(|b| b.num_rows() > 0)
.collect::<Vec<_>>(),
);
@@ -1228,9 +1356,18 @@ pub(crate) mod tests {
Ok((columns, batches, metrics))
}
+ fn new_task_ctx(batch_size: usize) -> Arc<TaskContext> {
+ let base = TaskContext::default();
+        // limit max size of intermediate batch used in nlj to `batch_size`
+ let cfg = base.session_config().clone().with_batch_size(batch_size);
+ Arc::new(base.with_session_config(cfg))
+ }
+
+ #[rstest]
#[tokio::test]
- async fn join_inner_with_filter() -> Result<()> {
- let task_ctx = Arc::new(TaskContext::default());
+ async fn join_inner_with_filter(#[values(1, 2, 16)] batch_size: usize) ->
Result<()> {
+ let task_ctx = new_task_ctx(batch_size);
+ dbg!(&batch_size);
let left = build_left_table();
let right = build_right_table();
let filter = prepare_join_filter();
@@ -1242,23 +1379,25 @@ pub(crate) mod tests {
task_ctx,
)
.await?;
+
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
- assert_snapshot!(batches_to_sort_string(&batches), @r#"
+ allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches),
@r#"
+----+----+----+----+----+----+
| a1 | b1 | c1 | a2 | b2 | c2 |
+----+----+----+----+----+----+
| 5 | 5 | 50 | 2 | 2 | 80 |
+----+----+----+----+----+----+
- "#);
+ "#));
assert_join_metrics!(metrics, 1);
Ok(())
}
+ #[rstest]
#[tokio::test]
- async fn join_left_with_filter() -> Result<()> {
- let task_ctx = Arc::new(TaskContext::default());
+ async fn join_left_with_filter(#[values(1, 2, 16)] batch_size: usize) ->
Result<()> {
+ let task_ctx = new_task_ctx(batch_size);
let left = build_left_table();
let right = build_right_table();
@@ -1272,7 +1411,7 @@ pub(crate) mod tests {
)
.await?;
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
- assert_snapshot!(batches_to_sort_string(&batches), @r#"
+ allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches),
@r#"
+----+----+-----+----+----+----+
| a1 | b1 | c1 | a2 | b2 | c2 |
+----+----+-----+----+----+----+
@@ -1280,16 +1419,17 @@ pub(crate) mod tests {
| 5 | 5 | 50 | 2 | 2 | 80 |
| 9 | 8 | 90 | | | |
+----+----+-----+----+----+----+
- "#);
+ "#));
assert_join_metrics!(metrics, 3);
Ok(())
}
+ #[rstest]
#[tokio::test]
- async fn join_right_with_filter() -> Result<()> {
- let task_ctx = Arc::new(TaskContext::default());
+ async fn join_right_with_filter(#[values(1, 2, 16)] batch_size: usize) ->
Result<()> {
+ let task_ctx = new_task_ctx(batch_size);
let left = build_left_table();
let right = build_right_table();
@@ -1303,7 +1443,7 @@ pub(crate) mod tests {
)
.await?;
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
- assert_snapshot!(batches_to_sort_string(&batches), @r#"
+ allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches),
@r#"
+----+----+----+----+----+-----+
| a1 | b1 | c1 | a2 | b2 | c2 |
+----+----+----+----+----+-----+
@@ -1311,16 +1451,17 @@ pub(crate) mod tests {
| | | | 12 | 10 | 40 |
| 5 | 5 | 50 | 2 | 2 | 80 |
+----+----+----+----+----+-----+
- "#);
+ "#));
assert_join_metrics!(metrics, 3);
Ok(())
}
+ #[rstest]
#[tokio::test]
- async fn join_full_with_filter() -> Result<()> {
- let task_ctx = Arc::new(TaskContext::default());
+ async fn join_full_with_filter(#[values(1, 2, 16)] batch_size: usize) ->
Result<()> {
+ let task_ctx = new_task_ctx(batch_size);
let left = build_left_table();
let right = build_right_table();
@@ -1334,7 +1475,7 @@ pub(crate) mod tests {
)
.await?;
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
- assert_snapshot!(batches_to_sort_string(&batches), @r#"
+ allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches),
@r#"
+----+----+-----+----+----+-----+
| a1 | b1 | c1 | a2 | b2 | c2 |
+----+----+-----+----+----+-----+
@@ -1344,16 +1485,19 @@ pub(crate) mod tests {
| 5 | 5 | 50 | 2 | 2 | 80 |
| 9 | 8 | 90 | | | |
+----+----+-----+----+----+-----+
- "#);
+ "#));
assert_join_metrics!(metrics, 5);
Ok(())
}
+ #[rstest]
#[tokio::test]
- async fn join_left_semi_with_filter() -> Result<()> {
- let task_ctx = Arc::new(TaskContext::default());
+ async fn join_left_semi_with_filter(
+ #[values(1, 2, 16)] batch_size: usize,
+ ) -> Result<()> {
+ let task_ctx = new_task_ctx(batch_size);
let left = build_left_table();
let right = build_right_table();
@@ -1367,22 +1511,25 @@ pub(crate) mod tests {
)
.await?;
assert_eq!(columns, vec!["a1", "b1", "c1"]);
- assert_snapshot!(batches_to_sort_string(&batches), @r#"
+ allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches),
@r#"
+----+----+----+
| a1 | b1 | c1 |
+----+----+----+
| 5 | 5 | 50 |
+----+----+----+
- "#);
+ "#));
assert_join_metrics!(metrics, 1);
Ok(())
}
+ #[rstest]
#[tokio::test]
- async fn join_left_anti_with_filter() -> Result<()> {
- let task_ctx = Arc::new(TaskContext::default());
+ async fn join_left_anti_with_filter(
+ #[values(1, 2, 16)] batch_size: usize,
+ ) -> Result<()> {
+ let task_ctx = new_task_ctx(batch_size);
let left = build_left_table();
let right = build_right_table();
@@ -1396,23 +1543,26 @@ pub(crate) mod tests {
)
.await?;
assert_eq!(columns, vec!["a1", "b1", "c1"]);
- assert_snapshot!(batches_to_sort_string(&batches), @r#"
+ allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches),
@r#"
+----+----+-----+
| a1 | b1 | c1 |
+----+----+-----+
| 11 | 8 | 110 |
| 9 | 8 | 90 |
+----+----+-----+
- "#);
+ "#));
assert_join_metrics!(metrics, 2);
Ok(())
}
+ #[rstest]
#[tokio::test]
- async fn join_right_semi_with_filter() -> Result<()> {
- let task_ctx = Arc::new(TaskContext::default());
+ async fn join_right_semi_with_filter(
+ #[values(1, 2, 16)] batch_size: usize,
+ ) -> Result<()> {
+ let task_ctx = new_task_ctx(batch_size);
let left = build_left_table();
let right = build_right_table();
@@ -1426,22 +1576,25 @@ pub(crate) mod tests {
)
.await?;
assert_eq!(columns, vec!["a2", "b2", "c2"]);
- assert_snapshot!(batches_to_sort_string(&batches), @r#"
+ allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches),
@r#"
+----+----+----+
| a2 | b2 | c2 |
+----+----+----+
| 2 | 2 | 80 |
+----+----+----+
- "#);
+ "#));
assert_join_metrics!(metrics, 1);
Ok(())
}
+ #[rstest]
#[tokio::test]
- async fn join_right_anti_with_filter() -> Result<()> {
- let task_ctx = Arc::new(TaskContext::default());
+ async fn join_right_anti_with_filter(
+ #[values(1, 2, 16)] batch_size: usize,
+ ) -> Result<()> {
+ let task_ctx = new_task_ctx(batch_size);
let left = build_left_table();
let right = build_right_table();
@@ -1455,23 +1608,26 @@ pub(crate) mod tests {
)
.await?;
assert_eq!(columns, vec!["a2", "b2", "c2"]);
- assert_snapshot!(batches_to_sort_string(&batches), @r#"
+ allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches),
@r#"
+----+----+-----+
| a2 | b2 | c2 |
+----+----+-----+
| 10 | 10 | 100 |
| 12 | 10 | 40 |
+----+----+-----+
- "#);
+ "#));
assert_join_metrics!(metrics, 2);
Ok(())
}
+ #[rstest]
#[tokio::test]
- async fn join_left_mark_with_filter() -> Result<()> {
- let task_ctx = Arc::new(TaskContext::default());
+ async fn join_left_mark_with_filter(
+ #[values(1, 2, 16)] batch_size: usize,
+ ) -> Result<()> {
+ let task_ctx = new_task_ctx(batch_size);
let left = build_left_table();
let right = build_right_table();
@@ -1485,7 +1641,7 @@ pub(crate) mod tests {
)
.await?;
assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]);
- assert_snapshot!(batches_to_sort_string(&batches), @r#"
+ allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches),
@r#"
+----+----+-----+-------+
| a1 | b1 | c1 | mark |
+----+----+-----+-------+
@@ -1493,16 +1649,19 @@ pub(crate) mod tests {
| 5 | 5 | 50 | true |
| 9 | 8 | 90 | false |
+----+----+-----+-------+
- "#);
+ "#));
assert_join_metrics!(metrics, 3);
Ok(())
}
+ #[rstest]
#[tokio::test]
- async fn join_right_mark_with_filter() -> Result<()> {
- let task_ctx = Arc::new(TaskContext::default());
+ async fn join_right_mark_with_filter(
+ #[values(1, 2, 16)] batch_size: usize,
+ ) -> Result<()> {
+ let task_ctx = new_task_ctx(batch_size);
let left = build_left_table();
let right = build_right_table();
@@ -1517,7 +1676,7 @@ pub(crate) mod tests {
.await?;
assert_eq!(columns, vec!["a2", "b2", "c2", "mark"]);
- assert_snapshot!(batches_to_sort_string(&batches), @r#"
+ allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches),
@r#"
+----+----+-----+-------+
| a2 | b2 | c2 | mark |
+----+----+-----+-------+
@@ -1525,7 +1684,7 @@ pub(crate) mod tests {
| 12 | 10 | 40 | false |
| 2 | 2 | 80 | true |
+----+----+-----+-------+
- "#);
+ "#));
assert_join_metrics!(metrics, 3);
@@ -1659,6 +1818,7 @@ pub(crate) mod tests {
join_type: JoinType,
#[values(1, 100, 1000)] left_batch_size: usize,
#[values(1, 100, 1000)] right_batch_size: usize,
+ #[values(2, 10000)] batch_size: usize,
) -> Result<()> {
let left_columns = generate_columns(3, 1000);
let left = build_table(
@@ -1708,8 +1868,9 @@ pub(crate) mod tests {
assert_eq!(right.options, join.options);
}
+ let task_ctx = new_task_ctx(batch_size);
let batches = nested_loop_join
- .execute(0, Arc::new(TaskContext::default()))?
+ .execute(0, task_ctx)?
.try_collect::<Vec<_>>()
.await?;
diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
index 013d4497a8..9a8d4cbb66 100644
--- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
+++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
@@ -992,6 +992,7 @@ pub(crate) fn join_with_probe_batch(
probe_indices,
filter,
build_hash_joiner.build_side,
+ None,
)?
} else {
(build_indices, probe_indices)
diff --git a/datafusion/physical-plan/src/joins/utils.rs
b/datafusion/physical-plan/src/joins/utils.rs
index 5248eafef0..35827d4fcd 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -17,6 +17,7 @@
//! Join related functionality used both on logical and physical plans
+use std::cmp::min;
use std::collections::HashSet;
use std::fmt::{self, Debug};
use std::future::Future;
@@ -844,24 +845,56 @@ pub(crate) fn apply_join_filter_to_indices(
probe_indices: UInt32Array,
filter: &JoinFilter,
build_side: JoinSide,
+ max_intermediate_size: Option<usize>,
) -> Result<(UInt64Array, UInt32Array)> {
if build_indices.is_empty() && probe_indices.is_empty() {
return Ok((build_indices, probe_indices));
};
- let intermediate_batch = build_batch_from_indices(
- filter.schema(),
- build_input_buffer,
- probe_batch,
- &build_indices,
- &probe_indices,
- filter.column_indices(),
- build_side,
- )?;
- let filter_result = filter
- .expression()
- .evaluate(&intermediate_batch)?
- .into_array(intermediate_batch.num_rows())?;
+ let filter_result = if let Some(max_size) = max_intermediate_size {
+ let mut filter_results =
+ Vec::with_capacity(build_indices.len().div_ceil(max_size));
+
+ for i in (0..build_indices.len()).step_by(max_size) {
+ let end = min(build_indices.len(), i + max_size);
+ let len = end - i;
+ let intermediate_batch = build_batch_from_indices(
+ filter.schema(),
+ build_input_buffer,
+ probe_batch,
+ &build_indices.slice(i, len),
+ &probe_indices.slice(i, len),
+ filter.column_indices(),
+ build_side,
+ )?;
+ let filter_result = filter
+ .expression()
+ .evaluate(&intermediate_batch)?
+ .into_array(intermediate_batch.num_rows())?;
+ filter_results.push(filter_result);
+ }
+
+ let filter_refs: Vec<&dyn Array> =
+ filter_results.iter().map(|a| a.as_ref()).collect();
+
+ compute::concat(&filter_refs)?
+ } else {
+ let intermediate_batch = build_batch_from_indices(
+ filter.schema(),
+ build_input_buffer,
+ probe_batch,
+ &build_indices,
+ &probe_indices,
+ filter.column_indices(),
+ build_side,
+ )?;
+
+ filter
+ .expression()
+ .evaluate(&intermediate_batch)?
+ .into_array(intermediate_batch.num_rows())?
+ };
+
let mask = as_boolean_array(&filter_result)?;
let left_filtered = compute::filter(&build_indices, mask)?;
@@ -924,6 +957,7 @@ pub(crate) fn build_batch_from_indices(
compute::take(array.as_ref(), probe_indices, None)?
}
};
+
columns.push(array);
}
Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]