This is an automated email from the ASF dual-hosted git repository.

comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 94b5511321 Fix: Sort Merge Join LeftSemi issues when JoinFilter is set 
(#10304)
94b5511321 is described below

commit 94b55113218bad0e41522b0c090526e1e822565d
Author: Oleks V <[email protected]>
AuthorDate: Mon May 20 14:56:58 2024 -0700

    Fix: Sort Merge Join LeftSemi issues when JoinFilter is set (#10304)
    
    
    * Fix: Sort Merge Join Left Semi crashes
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 .../physical-plan/src/joins/sort_merge_join.rs     | 230 ++++++++++++++++++---
 .../sqllogictest/test_files/sort_merge_join.slt    | 151 ++++++++++++++
 2 files changed, 352 insertions(+), 29 deletions(-)

diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs 
b/datafusion/physical-plan/src/joins/sort_merge_join.rs
index d4cf6864d7..1cc7bf4700 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -30,22 +30,13 @@ use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
 
-use crate::expressions::PhysicalSortExpr;
-use crate::joins::utils::{
-    build_join_schema, check_join_is_valid, estimate_join_statistics,
-    partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
-};
-use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
-use crate::{
-    execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, 
Distribution,
-    ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties,
-    RecordBatchStream, SendableRecordBatchStream, Statistics,
-};
-
 use arrow::array::*;
 use arrow::compute::{self, concat_batches, take, SortOptions};
 use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
 use arrow::error::ArrowError;
+use futures::{Stream, StreamExt};
+use hashbrown::HashSet;
+
 use datafusion_common::{
     internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, 
Result,
 };
@@ -54,7 +45,17 @@ use datafusion_execution::TaskContext;
 use datafusion_physical_expr::equivalence::join_equivalence_properties;
 use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement};
 
-use futures::{Stream, StreamExt};
+use crate::expressions::PhysicalSortExpr;
+use crate::joins::utils::{
+    build_join_schema, check_join_is_valid, estimate_join_statistics,
+    partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
+};
+use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
+use crate::{
+    execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, 
Distribution,
+    ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties,
+    RecordBatchStream, SendableRecordBatchStream, Statistics,
+};
 
 /// join execution plan executes partitions in parallel and combines them into 
a set of
 /// partitions.
@@ -491,6 +492,10 @@ struct StreamedBatch {
     pub output_indices: Vec<StreamedJoinedChunk>,
     /// Index of currently scanned batch from buffered data
     pub buffered_batch_idx: Option<usize>,
+    /// Indices that found a match for the given join filter.
+    /// Used for semi joins to keep track of the streamed indices which got a join filter match
+    /// and were already emitted to the output.
+    pub join_filter_matched_idxs: HashSet<u64>,
 }
 
 impl StreamedBatch {
@@ -502,6 +507,7 @@ impl StreamedBatch {
             join_arrays,
             output_indices: vec![],
             buffered_batch_idx: None,
+            join_filter_matched_idxs: HashSet::new(),
         }
     }
 
@@ -512,6 +518,7 @@ impl StreamedBatch {
             join_arrays: vec![],
             output_indices: vec![],
             buffered_batch_idx: None,
+            join_filter_matched_idxs: HashSet::new(),
         }
     }
 
@@ -990,7 +997,22 @@ impl SMJStream {
             }
             Ordering::Equal => {
                 if matches!(self.join_type, JoinType::LeftSemi) {
-                    join_streamed = !self.streamed_joined;
+                    // if a join filter is specified, the streamed index must be output
+                    // only if it has not been emitted before;
+                    // `join_filter_matched_idxs` keeps track of whether a streamed index has had a
+                    // successful filter match and prevents the same index from being output more than once
+                    if self.filter.is_some() {
+                        join_streamed = !self
+                            .streamed_batch
+                            .join_filter_matched_idxs
+                            .contains(&(self.streamed_batch.idx as u64))
+                            && !self.streamed_joined;
+                        // if the join filter is specified, it can reference buffered columns,
+                        // so the buffered columns are needed to evaluate it
+                        join_buffered = join_streamed;
+                    } else {
+                        join_streamed = !self.streamed_joined;
+                    }
                 }
                 if matches!(
                     self.join_type,
@@ -1134,17 +1156,15 @@ impl SMJStream {
                 .collect::<Result<Vec<_>, ArrowError>>()?;
 
             let buffered_indices: UInt64Array = 
chunk.buffered_indices.finish();
-
             let mut buffered_columns =
                 if matches!(self.join_type, JoinType::LeftSemi | 
JoinType::LeftAnti) {
                     vec![]
                 } else if let Some(buffered_idx) = chunk.buffered_batch_idx {
-                    self.buffered_data.batches[buffered_idx]
-                        .batch
-                        .columns()
-                        .iter()
-                        .map(|column| take(column, &buffered_indices, None))
-                        .collect::<Result<Vec<_>, ArrowError>>()?
+                    get_buffered_columns(
+                        &self.buffered_data,
+                        buffered_idx,
+                        &buffered_indices,
+                    )?
                 } else {
                     self.buffered_schema
                         .fields()
@@ -1161,6 +1181,15 @@ impl SMJStream {
             let filter_columns = if chunk.buffered_batch_idx.is_some() {
                 if matches!(self.join_type, JoinType::Right) {
                     get_filter_column(&self.filter, &buffered_columns, 
&streamed_columns)
+                } else if matches!(self.join_type, JoinType::LeftSemi) {
+                    // unwrap is safe here as `is_some` is checked at the top of the enclosing if statement
+                    let buffered_columns = get_buffered_columns(
+                        &self.buffered_data,
+                        chunk.buffered_batch_idx.unwrap(),
+                        &buffered_indices,
+                    )?;
+
+                    get_filter_column(&self.filter, &streamed_columns, 
&buffered_columns)
                 } else {
                     get_filter_column(&self.filter, &streamed_columns, 
&buffered_columns)
                 }
@@ -1195,7 +1224,17 @@ impl SMJStream {
                         .into_array(filter_batch.num_rows())?;
 
                     // The selection mask of the filter
-                    let mask = 
datafusion_common::cast::as_boolean_array(&filter_result)?;
+                    let mut mask =
+                        
datafusion_common::cast::as_boolean_array(&filter_result)?;
+
+                    let maybe_filtered_join_mask: Option<(BooleanArray, 
Vec<u64>)> =
+                        get_filtered_join_mask(self.join_type, 
streamed_indices, mask);
+                    if let Some(ref filtered_join_mask) = 
maybe_filtered_join_mask {
+                        mask = &filtered_join_mask.0;
+                        self.streamed_batch
+                            .join_filter_matched_idxs
+                            .extend(&filtered_join_mask.1);
+                    }
 
                     // Push the filtered batch to the output
                     let filtered_batch =
@@ -1365,6 +1404,69 @@ fn get_filter_column(
     filter_columns
 }
 
+/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]`
+#[inline(always)]
+fn get_buffered_columns(
+    buffered_data: &BufferedData,
+    buffered_batch_idx: usize,
+    buffered_indices: &UInt64Array,
+) -> Result<Vec<ArrayRef>, ArrowError> {
+    buffered_data.batches[buffered_batch_idx]
+        .batch
+        .columns()
+        .iter()
+        .map(|column| take(column, &buffered_indices, None))
+        .collect::<Result<Vec<_>, ArrowError>>()
+}
+
+// Calculate the join filter bit mask considering join type specifics
+// `streamed_indices` - array of streamed datasource JOINED row indices
+// `mask` - array of booleans representing the computed join filter expression eval result:
+//      true = the row index matches the join filter
+//      false = the row index doesn't match the join filter
+// `streamed_indices` has the same length as `mask`
+fn get_filtered_join_mask(
+    join_type: JoinType,
+    streamed_indices: UInt64Array,
+    mask: &BooleanArray,
+) -> Option<(BooleanArray, Vec<u64>)> {
+    // for LeftSemi Join the filter mask should be calculated in its own way:
+    // if we find at least one matching row for specific streaming index
+    // we don't need to check any others for the same index
+    if matches!(join_type, JoinType::LeftSemi) {
+        // have we seen a filter match for a streaming index before
+        let mut seen_as_true: bool = false;
+        let streamed_indices_length = streamed_indices.len();
+        let mut corrected_mask: BooleanBuilder =
+            BooleanBuilder::with_capacity(streamed_indices_length);
+
+        let mut filter_matched_indices: Vec<u64> = vec![];
+
+        #[allow(clippy::needless_range_loop)]
+        for i in 0..streamed_indices_length {
+            // LeftSemi respects only the first true value for a specific streaming index,
+            // other true values for the same index must be false
+            if mask.value(i) && !seen_as_true {
+                seen_as_true = true;
+                corrected_mask.append_value(true);
+                filter_matched_indices.push(streamed_indices.value(i));
+            } else {
+                corrected_mask.append_value(false);
+            }
+
+            // if switched to the next streaming index (e.g. from 0 to 1, or from 1 to 2), reset the seen_as_true flag
+            if i < streamed_indices_length - 1
+                && streamed_indices.value(i) != streamed_indices.value(i + 1)
+            {
+                seen_as_true = false;
+            }
+        }
+        Some((corrected_mask.finish(), filter_matched_indices))
+    } else {
+        None
+    }
+}
+
 /// Buffered data contains all buffered batches with one unique join key
 #[derive(Debug, Default)]
 struct BufferedData {
@@ -1604,17 +1706,13 @@ fn is_join_arrays_equal(
 mod tests {
     use std::sync::Arc;
 
-    use crate::expressions::Column;
-    use crate::joins::utils::JoinOn;
-    use crate::joins::SortMergeJoinExec;
-    use crate::memory::MemoryExec;
-    use crate::test::build_table_i32;
-    use crate::{common, ExecutionPlan};
-
     use arrow::array::{Date32Array, Date64Array, Int32Array};
     use arrow::compute::SortOptions;
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow::record_batch::RecordBatch;
+    use arrow_array::{BooleanArray, UInt64Array};
+
+    use datafusion_common::JoinType::LeftSemi;
     use datafusion_common::{
         assert_batches_eq, assert_batches_sorted_eq, assert_contains, 
JoinType, Result,
     };
@@ -1622,6 +1720,14 @@ mod tests {
     use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
     use datafusion_execution::TaskContext;
 
+    use crate::expressions::Column;
+    use crate::joins::sort_merge_join::get_filtered_join_mask;
+    use crate::joins::utils::JoinOn;
+    use crate::joins::SortMergeJoinExec;
+    use crate::memory::MemoryExec;
+    use crate::test::build_table_i32;
+    use crate::{common, ExecutionPlan};
+
     fn build_table(
         a: (&str, &Vec<i32>),
         b: (&str, &Vec<i32>),
@@ -2641,6 +2747,72 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn left_semi_join_filtered_mask() -> Result<()> {
+        assert_eq!(
+            get_filtered_join_mask(
+                LeftSemi,
+                UInt64Array::from(vec![0, 0, 1, 1]),
+                &BooleanArray::from(vec![true, true, false, false])
+            ),
+            Some((BooleanArray::from(vec![true, false, false, false]), 
vec![0]))
+        );
+
+        assert_eq!(
+            get_filtered_join_mask(
+                LeftSemi,
+                UInt64Array::from(vec![0, 1]),
+                &BooleanArray::from(vec![true, true])
+            ),
+            Some((BooleanArray::from(vec![true, true]), vec![0, 1]))
+        );
+
+        assert_eq!(
+            get_filtered_join_mask(
+                LeftSemi,
+                UInt64Array::from(vec![0, 1]),
+                &BooleanArray::from(vec![false, true])
+            ),
+            Some((BooleanArray::from(vec![false, true]), vec![1]))
+        );
+
+        assert_eq!(
+            get_filtered_join_mask(
+                LeftSemi,
+                UInt64Array::from(vec![0, 1]),
+                &BooleanArray::from(vec![true, false])
+            ),
+            Some((BooleanArray::from(vec![true, false]), vec![0]))
+        );
+
+        assert_eq!(
+            get_filtered_join_mask(
+                LeftSemi,
+                UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
+                &BooleanArray::from(vec![false, true, true, true, true, true])
+            ),
+            Some((
+                BooleanArray::from(vec![false, true, false, true, false, 
false]),
+                vec![0, 1]
+            ))
+        );
+
+        assert_eq!(
+            get_filtered_join_mask(
+                LeftSemi,
+                UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
+                &BooleanArray::from(vec![false, false, false, false, false, 
true])
+            ),
+            Some((
+                BooleanArray::from(vec![false, false, false, false, false, 
true]),
+                vec![1]
+            ))
+        );
+
+        Ok(())
+    }
+
     /// Returns the column names on the schema
     fn columns(schema: &Schema) -> Vec<String> {
         schema.fields().iter().map(|f| f.name().clone()).collect()
diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt 
b/datafusion/sqllogictest/test_files/sort_merge_join.slt
index 7b7e355fa2..3a27d9693d 100644
--- a/datafusion/sqllogictest/test_files/sort_merge_join.slt
+++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt
@@ -263,6 +263,139 @@ DROP TABLE t1;
 statement ok
 DROP TABLE t2;
 
+# LEFTSEMI join tests
+
+query II
+select * from (
+with
+t1 as (
+    select 11 a, 12 b union all
+    select 11 a, 13 b),
+t2 as (
+    select 11 a, 12 b union all
+    select 11 a, 13 b
+    )
+    select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and 
t2.b = t1.b)
+) order by 1, 2
+----
+11 12
+11 13
+
+query II
+select * from (
+with
+t1 as (
+    select 11 a, 12 b union all
+    select 11 a, 13 b),
+t2 as (
+    select 11 a, 12 b union all
+    select 11 a, 13 b
+    )
+    select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and 
t2.b != t1.b)
+) order by 1, 2;
+----
+11 12
+11 13
+
+query II
+select * from (
+with
+t1 as (
+    select null a, 12 b union all
+    select 11 a, 13 b),
+t2 as (
+    select 11 a, 12 b union all
+    select 11 a, 13 b
+    )
+    select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and 
t2.b != t1.b)
+) order by 1, 2;
+----
+11 13
+
+query II
+select * from (
+with
+t1 as (
+    select 11 a, 12 b union all
+    select 11 a, 13 b)
+    select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a 
and t2.b = t1.b)
+) order by 1, 2;
+----
+11 12
+11 13
+
+query II
+select * from (
+with
+t1 as (
+    select 11 a, 12 b union all
+    select 11 a, 13 b)
+    select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a 
and t2.b != t1.b)
+) order by 1, 2;
+----
+11 12
+11 13
+
+query II
+select * from (
+with
+t1 as (
+    select null a, 12 b union all
+    select 11 a, 13 b)
+    select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a 
and t2.b != t1.b)
+) order by 1, 2;
+
+query II
+select * from (
+with
+t1 as (
+    select 11 a, 12 b union all
+    select 11 a, 13 b),
+t2 as (
+    select 11 a, 12 b union all
+    select 11 a, 14 b
+    )
+select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b 
!= t1.b)
+) order by 1, 2;
+----
+11 12
+11 13
+
+query II
+select * from (
+with
+t1 as (
+    select 11 a, 12 b union all
+    select 11 a, 13 b),
+t2 as (
+    select 11 a, 12 b union all
+    select 11 a, 12 b union all
+    select 11 a, 14 b
+    )
+select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b 
!= t1.b)
+) order by 1, 2;
+----
+11 12
+11 13
+
+#LEFTANTI tests
+# returns no rows instead of correct result
+#query III
+#select * from (
+#with
+#t1 as (
+#    select 11 a, 12 b, 1 c union all
+#    select 11 a, 13 b, 2 c),
+#t2 as (
+#    select 11 a, 12 b, 3 c union all
+#    select 11 a, 14 b, 4 c
+#    )
+#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and 
t2.b != t1.b and t1.c > t2.c)
+#) order by 1, 2;
+#----
+#11 12 1
+#11 13 2
+
 # Set batch size to 1 for sort merge join to test scenario when data spread 
across multiple batches
 statement ok
 set datafusion.execution.batch_size = 1;
@@ -280,5 +413,23 @@ SELECT * FROM (
 ) ORDER BY 1, 2;
 ----
 
+
+query II
+select * from (
+with
+t1 as (
+    select 11 a, 12 b union all
+    select 11 a, 13 b),
+t2 as (
+    select 11 a, 12 b union all
+    select 11 a, 12 b union all
+    select 11 a, 14 b
+    )
+select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b 
!= t1.b)
+) order by 1, 2;
+----
+11 12
+11 13
+
 statement ok
 set datafusion.optimizer.prefer_hash_join = true;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to