This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new b16ad9badc fix: SortMergeJoin don't wait for all input before emitting (#20482)
b16ad9badc is described below

commit b16ad9badc45ef7b08b4044d5d238f5575a2dd82
Author: Raz Luvaton <[email protected]>
AuthorDate: Tue Feb 24 21:12:42 2026 +0200

    fix: SortMergeJoin don't wait for all input before emitting (#20482)
    
    ## Which issue does this PR close?
    
    N/A
    
    ## Rationale for this change
    
    While experimenting with local tests and debugging a memory issue, I
    noticed that `SortMergeJoinStream` waits for all input before it starts
    emitting. That shouldn't be the case: we can emit early once we have
    enough data.
    
    This also causes huge memory pressure.
    
    ## What changes are included in this PR?
    
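    Add a `SortMergeJoinState::EmitReadyThenInit` state to
    `SortMergeJoinStream`. When the buffered side finishes scanning, the
    stream now emits a completed output batch (if one is ready) before
    returning to `Init`, instead of holding output until all input has been
    consumed. Joins that need deferred filtering skip this emission and go
    straight back to `Init`.
    
    The tests extend `BarrierExec` with an optional finish barrier so they
    can assert that output is produced before the inputs complete, and
    compare the result against a `HashJoinExec` oracle.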
    
    ## Are these changes tested?
    
    Yes
    
    ## Are there any user-facing changes?
    
    
    -----
    
    
    ## TODO:
    - [x] update docs
    - [x] finish fix
---
 .../src/joins/sort_merge_join/stream.rs            |  36 +-
 .../src/joins/sort_merge_join/tests.rs             | 449 ++++++++++++++++++++-
 datafusion/physical-plan/src/test/exec.rs          | 111 ++++-
 3 files changed, 562 insertions(+), 34 deletions(-)

diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
index 37213401fd..11e4a903ac 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
@@ -70,6 +70,8 @@ pub(super) enum SortMergeJoinState {
     Polling,
     /// Joining polled data and making output
     JoinOutput,
+    /// Emit ready data if have any and then go back to [`Self::Init`] state
+    EmitReadyThenInit,
     /// No more output
     Exhausted,
 }
@@ -598,13 +600,45 @@ impl Stream for SortMergeJoinStream {
                     self.current_ordering = self.compare_streamed_buffered()?;
                     self.state = SortMergeJoinState::JoinOutput;
                 }
+                SortMergeJoinState::EmitReadyThenInit => {
+                    // If have data to emit, emit it and if no more, change to next
+
+                    // Verify metadata alignment before checking if we have batches to output
+                    self.joined_record_batches
+                        .filter_metadata
+                        .debug_assert_metadata_aligned();
+
+                    // For filtered joins, skip output and let Init state handle it
+                    if needs_deferred_filtering(&self.filter, self.join_type) {
+                        self.state = SortMergeJoinState::Init;
+                        continue;
+                    }
+
+                    // For non-filtered joins, only output if we have a completed batch
+                    // (opportunistic output when target batch size is reached)
+                    if self
+                        .joined_record_batches
+                        .joined_batches
+                        .has_completed_batch()
+                    {
+                        let record_batch = self
+                            .joined_record_batches
+                            .joined_batches
+                            .next_completed_batch()
+                            .expect("has_completed_batch was true");
+                        (&record_batch)
+                            .record_output(&self.join_metrics.baseline_metrics());
+                        return Poll::Ready(Some(Ok(record_batch)));
+                    }
+                    self.state = SortMergeJoinState::Init;
+                }
                 SortMergeJoinState::JoinOutput => {
                     self.join_partial()?;
 
                     if self.num_unfrozen_pairs() < self.batch_size {
                         if self.buffered_data.scanning_finished() {
                             self.buffered_data.scanning_reset();
-                            self.state = SortMergeJoinState::Init;
+                            self.state = SortMergeJoinState::EmitReadyThenInit;
                         }
                     } else {
                         self.freeze_all()?;
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs
index 5163eb44ee..b16ad59abc 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs
@@ -24,41 +24,44 @@
 //!
 //! Add relevant tests under the specified sections.
 
-use std::sync::Arc;
-
+use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn};
+use crate::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec};
+use crate::test::TestMemoryExec;
+use crate::test::exec::BarrierExec;
+use crate::test::{build_table_i32, build_table_i32_two_cols};
+use crate::{ExecutionPlan, common};
+use crate::{
+    expressions::Column, joins::sort_merge_join::filter::get_corrected_filter_mask,
+    joins::sort_merge_join::stream::JoinedRecordBatches,
+};
 use arrow::array::{
     BinaryArray, BooleanArray, Date32Array, Date64Array, FixedSizeBinaryArray,
     Int32Array, RecordBatch, UInt64Array,
 };
 use arrow::compute::{BatchCoalescer, SortOptions, filter_record_batch};
 use arrow::datatypes::{DataType, Field, Schema};
-
+use arrow_ord::sort::SortColumn;
+use arrow_schema::SchemaRef;
 use datafusion_common::JoinType::*;
 use datafusion_common::{
-    JoinSide,
+    JoinSide, internal_err,
     test_util::{batches_to_sort_string, batches_to_string},
 };
 use datafusion_common::{
     JoinType, NullEquality, Result, assert_batches_eq, assert_contains,
 };
-use datafusion_execution::TaskContext;
+use datafusion_common_runtime::JoinSet;
 use datafusion_execution::config::SessionConfig;
 use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};
 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
+use datafusion_execution::{SendableRecordBatchStream, TaskContext};
 use datafusion_expr::Operator;
 use datafusion_physical_expr::expressions::BinaryExpr;
+use futures::StreamExt;
 use insta::{allow_duplicates, assert_snapshot};
-
-use crate::{
-    expressions::Column, joins::sort_merge_join::filter::get_corrected_filter_mask,
-    joins::sort_merge_join::stream::JoinedRecordBatches,
-};
-
-use crate::joins::SortMergeJoinExec;
-use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn};
-use crate::test::TestMemoryExec;
-use crate::test::{build_table_i32, build_table_i32_two_cols};
-use crate::{ExecutionPlan, common};
+use itertools::Itertools;
+use std::sync::Arc;
+use std::task::Poll;
 
 fn build_table(
     a: (&str, &Vec<i32>),
@@ -3130,6 +3133,420 @@ fn test_partition_statistics() -> Result<()> {
     Ok(())
 }
 
+fn build_batches(
+    a: (&str, &[Vec<bool>]),
+    b: (&str, &[Vec<i32>]),
+    c: (&str, &[Vec<i32>]),
+) -> (Vec<RecordBatch>, SchemaRef) {
+    assert_eq!(a.1.len(), b.1.len());
+    let mut batches = vec![];
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new(a.0, DataType::Boolean, false),
+        Field::new(b.0, DataType::Int32, false),
+        Field::new(c.0, DataType::Int32, false),
+    ]));
+
+    for i in 0..a.1.len() {
+        batches.push(
+            RecordBatch::try_new(
+                Arc::clone(&schema),
+                vec![
+                    Arc::new(BooleanArray::from(a.1[i].clone())),
+                    Arc::new(Int32Array::from(b.1[i].clone())),
+                    Arc::new(Int32Array::from(c.1[i].clone())),
+                ],
+            )
+            .unwrap(),
+        );
+    }
+    let schema = batches[0].schema();
+    (batches, schema)
+}
+
+fn build_batched_finish_barrier_table(
+    a: (&str, &[Vec<bool>]),
+    b: (&str, &[Vec<i32>]),
+    c: (&str, &[Vec<i32>]),
+) -> (Arc<BarrierExec>, Arc<TestMemoryExec>) {
+    let (batches, schema) = build_batches(a, b, c);
+
+    let memory_exec = TestMemoryExec::try_new_exec(
+        std::slice::from_ref(&batches),
+        Arc::clone(&schema),
+        None,
+    )
+    .unwrap();
+
+    let barrier_exec = Arc::new(
+        BarrierExec::new(vec![batches], schema)
+            .with_log(false)
+            .without_start_barrier()
+            .with_finish_barrier(),
+    );
+
+    (barrier_exec, memory_exec)
+}
+
+/// Concat and sort batches by all the columns to make sure we can compare them with different join
+fn prepare_record_batches_for_cmp(output: Vec<RecordBatch>) -> RecordBatch {
+    let output_batch = arrow::compute::concat_batches(output[0].schema_ref(), &output)
+        .expect("failed to concat batches");
+
+    // Sort on all columns to make sure we have a deterministic order for the assertion
+    let sort_columns = output_batch
+        .columns()
+        .iter()
+        .map(|c| SortColumn {
+            values: Arc::clone(c),
+            options: None,
+        })
+        .collect::<Vec<_>>();
+
+    let sorted_columns =
+        arrow::compute::lexsort(&sort_columns, None).expect("failed to sort");
+
+    RecordBatch::try_new(output_batch.schema(), sorted_columns)
+        .expect("failed to create batch")
+}
+
+#[expect(clippy::too_many_arguments)]
+async fn join_get_stream_and_get_expected(
+    left: Arc<dyn ExecutionPlan>,
+    right: Arc<dyn ExecutionPlan>,
+    oracle_left: Arc<dyn ExecutionPlan>,
+    oracle_right: Arc<dyn ExecutionPlan>,
+    on: JoinOn,
+    join_type: JoinType,
+    filter: Option<JoinFilter>,
+    batch_size: usize,
+) -> Result<(SendableRecordBatchStream, RecordBatch)> {
+    let sort_options = vec![SortOptions::default(); on.len()];
+    let null_equality = NullEquality::NullEqualsNothing;
+    let task_ctx = Arc::new(
+        TaskContext::default()
+            .with_session_config(SessionConfig::default().with_batch_size(batch_size)),
+    );
+
+    let expected_output = {
+        let oracle = HashJoinExec::try_new(
+            oracle_left,
+            oracle_right,
+            on.clone(),
+            filter.clone(),
+            &join_type,
+            None,
+            PartitionMode::Partitioned,
+            null_equality,
+            false,
+        )?;
+
+        let stream = oracle.execute(0, Arc::clone(&task_ctx))?;
+
+        let batches = common::collect(stream).await?;
+
+        prepare_record_batches_for_cmp(batches)
+    };
+
+    let join = SortMergeJoinExec::try_new(
+        left,
+        right,
+        on,
+        filter,
+        join_type,
+        sort_options,
+        null_equality,
+    )?;
+
+    let stream = join.execute(0, task_ctx)?;
+
+    Ok((stream, expected_output))
+}
+
+fn generate_data_for_emit_early_test(
+    batch_size: usize,
+    number_of_batches: usize,
+    join_type: JoinType,
+) -> (
+    Arc<BarrierExec>,
+    Arc<BarrierExec>,
+    Arc<TestMemoryExec>,
+    Arc<TestMemoryExec>,
+) {
+    let number_of_rows_per_batch = number_of_batches * batch_size;
+    // Prepare data
+    let left_a1 = (0..number_of_rows_per_batch as i32)
+        .chunks(batch_size)
+        .into_iter()
+        .map(|chunk| chunk.collect::<Vec<_>>())
+        .collect::<Vec<_>>();
+    let left_b1 = (0..1000000)
+        .filter(|item| {
+            match join_type {
+                LeftAnti | RightAnti => {
+                    let remainder = item % (batch_size as i32);
+
+                    // Make sure to have one that match and one that don't
+                    remainder == 0 || remainder == 1
+                }
+                // Have at least 1 that is not matching
+                _ => item % batch_size as i32 != 0,
+            }
+        })
+        .take(number_of_rows_per_batch)
+        .chunks(batch_size)
+        .into_iter()
+        .map(|chunk| chunk.collect::<Vec<_>>())
+        .collect::<Vec<_>>();
+
+    let left_bool_col1 = left_a1
+        .clone()
+        .into_iter()
+        .map(|b| {
+            b.into_iter()
+                // Mostly true but have some false that not overlap with the right column
+                .map(|a| a % (batch_size as i32) != (batch_size as i32) - 2)
+                .collect::<Vec<_>>()
+        })
+        .collect::<Vec<_>>();
+
+    let (left, left_memory) = build_batched_finish_barrier_table(
+        ("bool_col1", left_bool_col1.as_slice()),
+        ("b1", left_b1.as_slice()),
+        ("a1", left_a1.as_slice()),
+    );
+
+    let right_a2 = (0..number_of_rows_per_batch as i32)
+        .map(|item| item * 11)
+        .chunks(batch_size)
+        .into_iter()
+        .map(|chunk| chunk.collect::<Vec<_>>())
+        .collect::<Vec<_>>();
+    let right_b1 = (0..1000000)
+        .filter(|item| {
+            match join_type {
+                LeftAnti | RightAnti => {
+                    let remainder = item % (batch_size as i32);
+
+                    // Make sure to have one that match and one that don't
+                    remainder == 1 || remainder == 2
+                }
+                // Have at least 1 that is not matching
+                _ => item % batch_size as i32 != 1,
+            }
+        })
+        .take(number_of_rows_per_batch)
+        .chunks(batch_size)
+        .into_iter()
+        .map(|chunk| chunk.collect::<Vec<_>>())
+        .collect::<Vec<_>>();
+    let right_bool_col2 = right_a2
+        .clone()
+        .into_iter()
+        .map(|b| {
+            b.into_iter()
+                // Mostly true but have some false that not overlap with the left column
+                .map(|a| a % (batch_size as i32) != (batch_size as i32) - 1)
+                .collect::<Vec<_>>()
+        })
+        .collect::<Vec<_>>();
+
+    let (right, right_memory) = build_batched_finish_barrier_table(
+        ("bool_col2", right_bool_col2.as_slice()),
+        ("b1", right_b1.as_slice()),
+        ("a2", right_a2.as_slice()),
+    );
+
+    (left, right, left_memory, right_memory)
+}
+
+#[tokio::test]
+async fn test_should_emit_early_when_have_enough_data_to_emit() -> Result<()> {
+    for with_filtering in [false, true] {
+        let join_types = vec![
+            Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, RightMark,
+        ];
+        const BATCH_SIZE: usize = 10;
+        for join_type in join_types {
+            for output_batch_size in [
+                BATCH_SIZE / 3,
+                BATCH_SIZE / 2,
+                BATCH_SIZE,
+                BATCH_SIZE * 2,
+                BATCH_SIZE * 3,
+            ] {
+                // Make sure the number of batches is enough for all join type to emit some output
+                let number_of_batches = if output_batch_size <= BATCH_SIZE {
+                    100
+                } else {
+                    // Have enough batches
+                    (output_batch_size * 100) / BATCH_SIZE
+                };
+
+                let (left, right, left_memory, right_memory) =
+                    generate_data_for_emit_early_test(
+                        BATCH_SIZE,
+                        number_of_batches,
+                        join_type,
+                    );
+
+                let on = vec![(
+                    Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+                    Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+                )];
+
+                let join_filter = if with_filtering {
+                    let filter = JoinFilter::new(
+                        Arc::new(BinaryExpr::new(
+                            Arc::new(Column::new("bool_col1", 0)),
+                            Operator::And,
+                            Arc::new(Column::new("bool_col2", 1)),
+                        )),
+                        vec![
+                            ColumnIndex {
+                                index: 0,
+                                side: JoinSide::Left,
+                            },
+                            ColumnIndex {
+                                index: 0,
+                                side: JoinSide::Right,
+                            },
+                        ],
+                        Arc::new(Schema::new(vec![
+                            Field::new("bool_col1", DataType::Boolean, true),
+                            Field::new("bool_col2", DataType::Boolean, true),
+                        ])),
+                    );
+                    Some(filter)
+                } else {
+                    None
+                };
+
+                // select *
+                // from t1
+                // right join t2 on t1.b1 = t2.b1 and t1.bool_col1 AND t2.bool_col2
+                let (mut output_stream, expected) = join_get_stream_and_get_expected(
+                    Arc::clone(&left) as Arc<dyn ExecutionPlan>,
+                    Arc::clone(&right) as Arc<dyn ExecutionPlan>,
+                    left_memory as Arc<dyn ExecutionPlan>,
+                    right_memory as Arc<dyn ExecutionPlan>,
+                    on,
+                    join_type,
+                    join_filter,
+                    output_batch_size,
+                )
+                .await?;
+
+                let (output_batched, output_batches_after_finish) =
+                  consume_stream_until_finish_barrier_reached(left, right, &mut output_stream).await.unwrap_or_else(|e| panic!("Failed to consume stream for join type: '{join_type}' and with filtering '{with_filtering}': {e:?}"));
+
+                // It should emit more than that, but we are being generous
+                // and to make sure the test pass for all
+                const MINIMUM_OUTPUT_BATCHES: usize = 5;
+                assert!(
+                    MINIMUM_OUTPUT_BATCHES <= number_of_batches / 5,
+                    "Make sure that the minimum output batches is realistic"
+                );
+                // Test to make sure that we are not waiting for input to be fully consumed to emit some output
+                assert!(
+                    output_batched.len() >= MINIMUM_OUTPUT_BATCHES,
+                    "[Sort Merge Join {join_type}] Stream must have at least emit {} batches, but only got {} batches",
+                    MINIMUM_OUTPUT_BATCHES,
+                    output_batched.len()
+                );
+
+                // Just sanity test to make sure we are still producing valid output
+                {
+                    let output = [output_batched, output_batches_after_finish].concat();
+                    let actual_prepared = prepare_record_batches_for_cmp(output);
+
+                    assert_eq!(actual_prepared.columns(), expected.columns());
+                }
+            }
+        }
+    }
+    Ok(())
+}
+
+/// Polls the stream until both barriers are reached,
+/// collecting the emitted batches along the way.
+///
+/// If the stream is pending for too long (5s) without emitting any batches,
+/// it panics to avoid hanging the test indefinitely.
+///
+/// Note: The left and right BarrierExec might be the input of the output stream
+async fn consume_stream_until_finish_barrier_reached(
+    left: Arc<BarrierExec>,
+    right: Arc<BarrierExec>,
+    output_stream: &mut SendableRecordBatchStream,
+) -> Result<(Vec<RecordBatch>, Vec<RecordBatch>)> {
+    let mut switch_to_finish_barrier = false;
+    let mut output_batched = vec![];
+    let mut after_finish_barrier_reached = vec![];
+    let mut background_task = JoinSet::new();
+
+    let mut start_time_since_last_ready = datafusion_common::instant::Instant::now();
+    loop {
+        let next_item = output_stream.next();
+
+        // Manual polling
+        let poll_output = futures::poll!(next_item);
+
+        // Wake up the stream to make sure it makes progress
+        tokio::task::yield_now().await;
+
+        match poll_output {
+            Poll::Ready(Some(Ok(batch))) => {
+                if batch.num_rows() == 0 {
+                    return internal_err!("join stream should not emit empty batch");
+                }
+                if switch_to_finish_barrier {
+                    after_finish_barrier_reached.push(batch);
+                } else {
+                    output_batched.push(batch);
+                }
+                start_time_since_last_ready = datafusion_common::instant::Instant::now();
+            }
+            Poll::Ready(Some(Err(e))) => return Err(e),
+            Poll::Ready(None) if !switch_to_finish_barrier => {
+                unreachable!("Stream should not end before manually finishing it")
+            }
+            Poll::Ready(None) => {
+                break;
+            }
+            Poll::Pending => {
+                if right.is_finish_barrier_reached()
+                    && left.is_finish_barrier_reached()
+                    && !switch_to_finish_barrier
+                {
+                    switch_to_finish_barrier = true;
+
+                    let right = Arc::clone(&right);
+                    background_task.spawn(async move {
+                        right.wait_finish().await;
+                    });
+                    let left = Arc::clone(&left);
+                    background_task.spawn(async move {
+                        left.wait_finish().await;
+                    });
+                }
+
+                // Make sure the test doesn't run forever
+                if start_time_since_last_ready.elapsed()
+                    > std::time::Duration::from_secs(5)
+                {
+                    return internal_err!(
+                        "Stream should have emitted data by now, but it's still pending. Output batches so far: {}",
+                        output_batched.len()
+                    );
+                }
+            }
+        }
+    }
+
+    Ok((output_batched, after_finish_barrier_reached))
+}
+
 /// Returns the column names on the schema
 fn columns(schema: &Schema) -> Vec<String> {
     schema.fields().iter().map(|f| f.name().clone()).collect()
diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs
index ebed84477a..df5093226e 100644
--- a/datafusion/physical-plan/src/test/exec.rs
+++ b/datafusion/physical-plan/src/test/exec.rs
@@ -17,13 +17,6 @@
 
 //! Simple iterator over batches for use in testing
 
-use std::{
-    any::Any,
-    pin::Pin,
-    sync::{Arc, Weak},
-    task::{Context, Poll},
-};
-
 use crate::{
     DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
     RecordBatchStream, SendableRecordBatchStream, Statistics, common,
@@ -33,6 +26,13 @@ use crate::{
     execution_plan::EmissionType,
     stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter},
 };
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::{
+    any::Any,
+    pin::Pin,
+    sync::{Arc, Weak},
+    task::{Context, Poll},
+};
 
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
@@ -294,29 +294,91 @@ pub struct BarrierExec {
     schema: SchemaRef,
 
     /// all streams wait on this barrier to produce
-    barrier: Arc<Barrier>,
+    start_data_barrier: Option<Arc<Barrier>>,
+
+    /// the stream wait for this to return Poll::Ready(None)
+    finish_barrier: Option<Arc<(Barrier, AtomicUsize)>>,
+
     cache: PlanProperties,
+
+    log: bool,
 }
 
 impl BarrierExec {
     /// Create a new exec with some number of partitions.
     pub fn new(data: Vec<Vec<RecordBatch>>, schema: SchemaRef) -> Self {
         // wait for all streams and the input
-        let barrier = Arc::new(Barrier::new(data.len() + 1));
+        let barrier = Some(Arc::new(Barrier::new(data.len() + 1)));
         let cache = Self::compute_properties(Arc::clone(&schema), &data);
         Self {
             data,
             schema,
-            barrier,
+            start_data_barrier: barrier,
             cache,
+            finish_barrier: None,
+            log: true,
         }
     }
 
+    pub fn with_log(mut self, log: bool) -> Self {
+        self.log = log;
+        self
+    }
+
+    pub fn without_start_barrier(mut self) -> Self {
+        self.start_data_barrier = None;
+        self
+    }
+
+    pub fn with_finish_barrier(mut self) -> Self {
+        let barrier = Arc::new((
+            // wait for all streams and the input
+            Barrier::new(self.data.len() + 1),
+            AtomicUsize::new(0),
+        ));
+
+        self.finish_barrier = Some(barrier);
+        self
+    }
+
     /// wait until all the input streams and this function is ready
     pub async fn wait(&self) {
-        println!("BarrierExec::wait waiting on barrier");
-        self.barrier.wait().await;
-        println!("BarrierExec::wait done waiting");
+        let barrier = &self
+            .start_data_barrier
+            .as_ref()
+            .expect("Must only be called when having a start barrier");
+        if self.log {
+            println!("BarrierExec::wait waiting on barrier");
+        }
+        barrier.wait().await;
+        if self.log {
+            println!("BarrierExec::wait done waiting");
+        }
+    }
+
+    pub async fn wait_finish(&self) {
+        let (barrier, _) = &self
+            .finish_barrier
+            .as_deref()
+            .expect("Must only be called when having a finish barrier");
+
+        if self.log {
+            println!("BarrierExec::wait_finish waiting on barrier");
+        }
+        barrier.wait().await;
+        if self.log {
+            println!("BarrierExec::wait_finish done waiting");
+        }
+    }
+
+    /// Return true if the finish barrier has been reached in all partitions
+    pub fn is_finish_barrier_reached(&self) -> bool {
+        let (_, reached_finish) = self
+            .finish_barrier
+            .as_deref()
+            .expect("Must only be called when having finish barrier");
+
+        reached_finish.load(Ordering::Relaxed) == self.data.len()
     }
 
     /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
@@ -387,17 +449,32 @@ impl ExecutionPlan for BarrierExec {
 
         // task simply sends data in order after barrier is reached
         let data = self.data[partition].clone();
-        let b = Arc::clone(&self.barrier);
+        let start_barrier = self.start_data_barrier.as_ref().map(Arc::clone);
+        let finish_barrier = self.finish_barrier.as_ref().map(Arc::clone);
+        let log = self.log;
         let tx = builder.tx();
         builder.spawn(async move {
-            println!("Partition {partition} waiting on barrier");
-            b.wait().await;
+            if let Some(barrier) = start_barrier {
+                if log {
+                    println!("Partition {partition} waiting on barrier");
+                }
+                barrier.wait().await;
+            }
             for batch in data {
-                println!("Partition {partition} sending batch");
+                if log {
+                    println!("Partition {partition} sending batch");
+                }
                 if let Err(e) = tx.send(Ok(batch)).await {
                     println!("ERROR batch via barrier stream stream: {e}");
                 }
             }
+            if let Some((barrier, reached_finish)) = finish_barrier.as_deref() {
+                if log {
+                    println!("Partition {partition} waiting on finish barrier");
+                }
+                reached_finish.fetch_add(1, Ordering::Relaxed);
+                barrier.wait().await;
+            }
 
             Ok(())
         });


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to