This is an automated email from the ASF dual-hosted git repository.
comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 94b5511321 Fix: Sort Merge Join LeftSemi issues when JoinFilter is set (#10304)
94b5511321 is described below
commit 94b55113218bad0e41522b0c090526e1e822565d
Author: Oleks V <[email protected]>
AuthorDate: Mon May 20 14:56:58 2024 -0700
Fix: Sort Merge Join LeftSemi issues when JoinFilter is set (#10304)
* Fix: Sort Merge Join Left Semi crashes
Co-authored-by: Andrew Lamb <[email protected]>
---
.../physical-plan/src/joins/sort_merge_join.rs | 230 ++++++++++++++++++---
.../sqllogictest/test_files/sort_merge_join.slt | 151 ++++++++++++++
2 files changed, 352 insertions(+), 29 deletions(-)
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs
index d4cf6864d7..1cc7bf4700 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -30,22 +30,13 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
-use crate::expressions::PhysicalSortExpr;
-use crate::joins::utils::{
- build_join_schema, check_join_is_valid, estimate_join_statistics,
- partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
-};
-use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
-use crate::{
- execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution,
- ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties,
- RecordBatchStream, SendableRecordBatchStream, Statistics,
-};
-
use arrow::array::*;
use arrow::compute::{self, concat_batches, take, SortOptions};
use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use arrow::error::ArrowError;
+use futures::{Stream, StreamExt};
+use hashbrown::HashSet;
+
use datafusion_common::{
internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result,
};
@@ -54,7 +45,17 @@ use datafusion_execution::TaskContext;
use datafusion_physical_expr::equivalence::join_equivalence_properties;
use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement};
-use futures::{Stream, StreamExt};
+use crate::expressions::PhysicalSortExpr;
+use crate::joins::utils::{
+ build_join_schema, check_join_is_valid, estimate_join_statistics,
+ partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
+};
+use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
+use crate::{
+ execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution,
+ ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties,
+ RecordBatchStream, SendableRecordBatchStream, Statistics,
+};
/// join execution plan executes partitions in parallel and combines them into a set of
/// partitions.
@@ -491,6 +492,10 @@ struct StreamedBatch {
pub output_indices: Vec<StreamedJoinedChunk>,
/// Index of currently scanned batch from buffered data
pub buffered_batch_idx: Option<usize>,
+ /// Indices that found a match for the given join filter
+ /// Used for semi joins to keep track the streaming index which got a join filter match
+ /// and already emitted to the output.
+ pub join_filter_matched_idxs: HashSet<u64>,
}
impl StreamedBatch {
@@ -502,6 +507,7 @@ impl StreamedBatch {
join_arrays,
output_indices: vec![],
buffered_batch_idx: None,
+ join_filter_matched_idxs: HashSet::new(),
}
}
@@ -512,6 +518,7 @@ impl StreamedBatch {
join_arrays: vec![],
output_indices: vec![],
buffered_batch_idx: None,
+ join_filter_matched_idxs: HashSet::new(),
}
}
@@ -990,7 +997,22 @@ impl SMJStream {
}
Ordering::Equal => {
if matches!(self.join_type, JoinType::LeftSemi) {
- join_streamed = !self.streamed_joined;
+ // if the join filter is specified then its needed to output the streamed index
+ // only if it has not been emitted before
+ // the `join_filter_matched_idxs` keeps track on if streamed index has a successful
+ // filter match and prevents the same index to go into output more than once
+ if self.filter.is_some() {
+ join_streamed = !self
+ .streamed_batch
+ .join_filter_matched_idxs
+ .contains(&(self.streamed_batch.idx as u64))
+ && !self.streamed_joined;
+ // if the join filter specified there can be references to buffered columns
+ // so buffered columns are needed to access them
+ join_buffered = join_streamed;
+ } else {
+ join_streamed = !self.streamed_joined;
+ }
}
if matches!(
self.join_type,
@@ -1134,17 +1156,15 @@ impl SMJStream {
.collect::<Result<Vec<_>, ArrowError>>()?;
let buffered_indices: UInt64Array = chunk.buffered_indices.finish();
-
let mut buffered_columns =
if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
vec![]
} else if let Some(buffered_idx) = chunk.buffered_batch_idx {
- self.buffered_data.batches[buffered_idx]
- .batch
- .columns()
- .iter()
- .map(|column| take(column, &buffered_indices, None))
- .collect::<Result<Vec<_>, ArrowError>>()?
+ get_buffered_columns(
+ &self.buffered_data,
+ buffered_idx,
+ &buffered_indices,
+ )?
} else {
self.buffered_schema
.fields()
@@ -1161,6 +1181,15 @@ impl SMJStream {
let filter_columns = if chunk.buffered_batch_idx.is_some() {
if matches!(self.join_type, JoinType::Right) {
get_filter_column(&self.filter, &buffered_columns, &streamed_columns)
+ } else if matches!(self.join_type, JoinType::LeftSemi) {
+ // unwrap is safe here as we check is_some on top of if statement
+ let buffered_columns = get_buffered_columns(
+ &self.buffered_data,
+ chunk.buffered_batch_idx.unwrap(),
+ &buffered_indices,
+ )?;
+
+ get_filter_column(&self.filter, &streamed_columns, &buffered_columns)
} else {
get_filter_column(&self.filter, &streamed_columns, &buffered_columns)
}
@@ -1195,7 +1224,17 @@ impl SMJStream {
.into_array(filter_batch.num_rows())?;
// The selection mask of the filter
- let mask = datafusion_common::cast::as_boolean_array(&filter_result)?;
+ let mut mask =
+ datafusion_common::cast::as_boolean_array(&filter_result)?;
+
+ let maybe_filtered_join_mask: Option<(BooleanArray, Vec<u64>)> =
+ get_filtered_join_mask(self.join_type, streamed_indices, mask);
+ if let Some(ref filtered_join_mask) = maybe_filtered_join_mask {
+ mask = &filtered_join_mask.0;
+ self.streamed_batch
+ .join_filter_matched_idxs
+ .extend(&filtered_join_mask.1);
+ }
// Push the filtered batch to the output
let filtered_batch =
@@ -1365,6 +1404,69 @@ fn get_filter_column(
filter_columns
}
+/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]`
+#[inline(always)]
+fn get_buffered_columns(
+ buffered_data: &BufferedData,
+ buffered_batch_idx: usize,
+ buffered_indices: &UInt64Array,
+) -> Result<Vec<ArrayRef>, ArrowError> {
+ buffered_data.batches[buffered_batch_idx]
+ .batch
+ .columns()
+ .iter()
+ .map(|column| take(column, &buffered_indices, None))
+ .collect::<Result<Vec<_>, ArrowError>>()
+}
+
+// Calculate join filter bit mask considering join type specifics
+// `streamed_indices` - array of streamed datasource JOINED row indices
+// `mask` - array booleans representing computed join filter expression eval result:
+// true = the row index matches the join filter
+// false = the row index doesn't match the join filter
+// `streamed_indices` have the same length as `mask`
+fn get_filtered_join_mask(
+ join_type: JoinType,
+ streamed_indices: UInt64Array,
+ mask: &BooleanArray,
+) -> Option<(BooleanArray, Vec<u64>)> {
+ // for LeftSemi Join the filter mask should be calculated in its own way:
+ // if we find at least one matching row for specific streaming index
+ // we don't need to check any others for the same index
+ if matches!(join_type, JoinType::LeftSemi) {
+ // have we seen a filter match for a streaming index before
+ let mut seen_as_true: bool = false;
+ let streamed_indices_length = streamed_indices.len();
+ let mut corrected_mask: BooleanBuilder =
+ BooleanBuilder::with_capacity(streamed_indices_length);
+
+ let mut filter_matched_indices: Vec<u64> = vec![];
+
+ #[allow(clippy::needless_range_loop)]
+ for i in 0..streamed_indices_length {
+ // LeftSemi respects only first true values for specific streaming index,
+ // others true values for the same index must be false
+ if mask.value(i) && !seen_as_true {
+ seen_as_true = true;
+ corrected_mask.append_value(true);
+ filter_matched_indices.push(streamed_indices.value(i));
+ } else {
+ corrected_mask.append_value(false);
+ }
+
+ // if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag
+ if i < streamed_indices_length - 1
+ && streamed_indices.value(i) != streamed_indices.value(i + 1)
+ {
+ seen_as_true = false;
+ }
+ }
+ Some((corrected_mask.finish(), filter_matched_indices))
+ } else {
+ None
+ }
+}
+
/// Buffered data contains all buffered batches with one unique join key
#[derive(Debug, Default)]
struct BufferedData {
@@ -1604,17 +1706,13 @@ fn is_join_arrays_equal(
mod tests {
use std::sync::Arc;
- use crate::expressions::Column;
- use crate::joins::utils::JoinOn;
- use crate::joins::SortMergeJoinExec;
- use crate::memory::MemoryExec;
- use crate::test::build_table_i32;
- use crate::{common, ExecutionPlan};
-
use arrow::array::{Date32Array, Date64Array, Int32Array};
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
+ use arrow_array::{BooleanArray, UInt64Array};
+
+ use datafusion_common::JoinType::LeftSemi;
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result,
};
@@ -1622,6 +1720,14 @@ mod tests {
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion_execution::TaskContext;
+ use crate::expressions::Column;
+ use crate::joins::sort_merge_join::get_filtered_join_mask;
+ use crate::joins::utils::JoinOn;
+ use crate::joins::SortMergeJoinExec;
+ use crate::memory::MemoryExec;
+ use crate::test::build_table_i32;
+ use crate::{common, ExecutionPlan};
+
fn build_table(
a: (&str, &Vec<i32>),
b: (&str, &Vec<i32>),
@@ -2641,6 +2747,72 @@ mod tests {
Ok(())
}
+
+ #[tokio::test]
+ async fn left_semi_join_filtered_mask() -> Result<()> {
+ assert_eq!(
+ get_filtered_join_mask(
+ LeftSemi,
+ UInt64Array::from(vec![0, 0, 1, 1]),
+ &BooleanArray::from(vec![true, true, false, false])
+ ),
+ Some((BooleanArray::from(vec![true, false, false, false]), vec![0]))
+ );
+
+ assert_eq!(
+ get_filtered_join_mask(
+ LeftSemi,
+ UInt64Array::from(vec![0, 1]),
+ &BooleanArray::from(vec![true, true])
+ ),
+ Some((BooleanArray::from(vec![true, true]), vec![0, 1]))
+ );
+
+ assert_eq!(
+ get_filtered_join_mask(
+ LeftSemi,
+ UInt64Array::from(vec![0, 1]),
+ &BooleanArray::from(vec![false, true])
+ ),
+ Some((BooleanArray::from(vec![false, true]), vec![1]))
+ );
+
+ assert_eq!(
+ get_filtered_join_mask(
+ LeftSemi,
+ UInt64Array::from(vec![0, 1]),
+ &BooleanArray::from(vec![true, false])
+ ),
+ Some((BooleanArray::from(vec![true, false]), vec![0]))
+ );
+
+ assert_eq!(
+ get_filtered_join_mask(
+ LeftSemi,
+ UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
+ &BooleanArray::from(vec![false, true, true, true, true, true])
+ ),
+ Some((
+ BooleanArray::from(vec![false, true, false, true, false, false]),
+ vec![0, 1]
+ ))
+ );
+
+ assert_eq!(
+ get_filtered_join_mask(
+ LeftSemi,
+ UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
+ &BooleanArray::from(vec![false, false, false, false, false, true])
+ ),
+ Some((
+ BooleanArray::from(vec![false, false, false, false, false, true]),
+ vec![1]
+ ))
+ );
+
+ Ok(())
+ }
+
/// Returns the column names on the schema
fn columns(schema: &Schema) -> Vec<String> {
schema.fields().iter().map(|f| f.name().clone()).collect()
diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt
index 7b7e355fa2..3a27d9693d 100644
--- a/datafusion/sqllogictest/test_files/sort_merge_join.slt
+++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt
@@ -263,6 +263,139 @@ DROP TABLE t1;
statement ok
DROP TABLE t2;
+# LEFTSEMI join tests
+
+query II
+select * from (
+with
+t1 as (
+ select 11 a, 12 b union all
+ select 11 a, 13 b),
+t2 as (
+ select 11 a, 12 b union all
+ select 11 a, 13 b
+ )
+ select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b = t1.b)
+) order by 1, 2
+----
+11 12
+11 13
+
+query II
+select * from (
+with
+t1 as (
+ select 11 a, 12 b union all
+ select 11 a, 13 b),
+t2 as (
+ select 11 a, 12 b union all
+ select 11 a, 13 b
+ )
+ select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b)
+) order by 1, 2;
+----
+11 12
+11 13
+
+query II
+select * from (
+with
+t1 as (
+ select null a, 12 b union all
+ select 11 a, 13 b),
+t2 as (
+ select 11 a, 12 b union all
+ select 11 a, 13 b
+ )
+ select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b)
+) order by 1, 2;
+----
+11 13
+
+query II
+select * from (
+with
+t1 as (
+ select 11 a, 12 b union all
+ select 11 a, 13 b)
+ select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b = t1.b)
+) order by 1, 2;
+----
+11 12
+11 13
+
+query II
+select * from (
+with
+t1 as (
+ select 11 a, 12 b union all
+ select 11 a, 13 b)
+ select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b != t1.b)
+) order by 1, 2;
+----
+11 12
+11 13
+
+query II
+select * from (
+with
+t1 as (
+ select null a, 12 b union all
+ select 11 a, 13 b)
+ select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b != t1.b)
+) order by 1, 2;
+
+query II
+select * from (
+with
+t1 as (
+ select 11 a, 12 b union all
+ select 11 a, 13 b),
+t2 as (
+ select 11 a, 12 b union all
+ select 11 a, 14 b
+ )
+select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b)
+) order by 1, 2;
+----
+11 12
+11 13
+
+query II
+select * from (
+with
+t1 as (
+ select 11 a, 12 b union all
+ select 11 a, 13 b),
+t2 as (
+ select 11 a, 12 b union all
+ select 11 a, 12 b union all
+ select 11 a, 14 b
+ )
+select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b)
+) order by 1, 2;
+----
+11 12
+11 13
+
+#LEFTANTI tests
+# returns no rows instead of correct result
+#query III
+#select * from (
+#with
+#t1 as (
+# select 11 a, 12 b, 1 c union all
+# select 11 a, 13 b, 2 c),
+#t2 as (
+# select 11 a, 12 b, 3 c union all
+# select 11 a, 14 b, 4 c
+# )
+#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c)
+#) order by 1, 2;
+#----
+#11 12 1
+#11 13 2
+
# Set batch size to 1 for sort merge join to test scenario when data spread across multiple batches
statement ok
set datafusion.execution.batch_size = 1;
@@ -280,5 +413,23 @@ SELECT * FROM (
) ORDER BY 1, 2;
----
+
+query II
+select * from (
+with
+t1 as (
+ select 11 a, 12 b union all
+ select 11 a, 13 b),
+t2 as (
+ select 11 a, 12 b union all
+ select 11 a, 12 b union all
+ select 11 a, 14 b
+ )
+select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b)
+) order by 1, 2;
+----
+11 12
+11 13
+
statement ok
set datafusion.optimizer.prefer_hash_join = true;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]