This is an automated email from the ASF dual-hosted git repository.
berkay pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new f514e12ec4 Preserve the order of right table in NestedLoopJoinExec
(#12504)
f514e12ec4 is described below
commit f514e12ec4b73a4e6a417c5756152dd1ddceac10
Author: Alihan Çelikcan <[email protected]>
AuthorDate: Wed Sep 18 16:17:51 2024 +0300
Preserve the order of right table in NestedLoopJoinExec (#12504)
* Maintain right child's order in NestedLoopJoinExec
* Format
* Refactor monotonicity check
* Update sqllogictest according to new behavior
* Check output ordering properties
* Parameterize only batch sizes for left and right tables
* Document maintains_input_order
---
.../physical-plan/src/joins/nested_loop_join.rs | 260 +++++++++++++++++++--
datafusion/sqllogictest/test_files/join.slt | 12 +-
datafusion/sqllogictest/test_files/joins.slt | 2 +-
3 files changed, 254 insertions(+), 20 deletions(-)
diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs
b/datafusion/physical-plan/src/joins/nested_loop_join.rs
index c6f1833c13..b30e5184f0 100644
--- a/datafusion/physical-plan/src/joins/nested_loop_join.rs
+++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs
@@ -221,7 +221,7 @@ impl NestedLoopJoinExec {
right.equivalence_properties().clone(),
&join_type,
schema,
- &[false, false],
+ &Self::maintains_input_order(join_type),
None,
// No on columns in nested loop join
&[],
@@ -238,6 +238,31 @@ impl NestedLoopJoinExec {
PlanProperties::new(eq_properties, output_partitioning, mode)
}
+
+ /// Returns a vector indicating whether the left and right inputs maintain their order.
+ /// The first element corresponds to the left input, and the second to the right.
+ ///
+ /// The left (build-side) input's order may change, but the right (probe-side) input's
+ /// order is maintained for INNER, RIGHT, RIGHT ANTI, and RIGHT SEMI joins.
+ ///
+ /// Maintaining the right input's order helps optimize the nodes down the pipeline
+ /// (See [`ExecutionPlan::maintains_input_order`]).
+ ///
+ /// This is a separate method because it is also called when computing properties, before
+ /// a [`NestedLoopJoinExec`] is created. It also takes [`JoinType`] as an argument, as
+ /// opposed to `Self`, for the same reason.
+ fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
+ vec![
+ // Left: every right row is paired with the build-side rows 0..left_row_count
+ // (see `build_join_indices`), so the left order restarts for each right row and
+ // is not preserved in the overall output.
+ false,
+ // Right: output is produced by walking right rows in order. NOTE(review): the other
+ // join types presumably emit build-side rows in a separate final pass, which breaks
+ // the right order — confirm against the final-batch logic.
+ matches!(
+ join_type,
+ JoinType::Inner
+ | JoinType::Right
+ | JoinType::RightAnti
+ | JoinType::RightSemi
+ ),
+ ]
+ }
}
impl DisplayAs for NestedLoopJoinExec {
@@ -278,6 +303,10 @@ impl ExecutionPlan for NestedLoopJoinExec {
]
}
+ /// Reports the same per-child ordering guarantees that were used to compute this
+ /// plan's properties, by delegating to the inherent associated function.
+ ///
+ /// `Self::maintains_input_order` resolves to the inherent `fn(JoinType)` rather than
+ /// to this trait method, because inherent items take precedence in path resolution.
+ /// So this call does not recurse.
+ fn maintains_input_order(&self) -> Vec<bool> {
+ Self::maintains_input_order(self.join_type)
+ }
+
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.left, &self.right]
}
@@ -430,17 +459,17 @@ struct NestedLoopJoinStream {
}
fn build_join_indices(
- left_row_index: usize,
- right_batch: &RecordBatch,
+ right_row_index: usize,
left_batch: &RecordBatch,
+ right_batch: &RecordBatch,
filter: Option<&JoinFilter>,
) -> Result<(UInt64Array, UInt32Array)> {
- // left indices: [left_index, left_index, ...., left_index]
- // right indices: [0, 1, 2, 3, 4,....,right_row_count]
+ // left indices: [0, 1, 2, 3, 4, ..., left_row_count]
+ // right indices: [right_index, right_index, ..., right_index]
- let right_row_count = right_batch.num_rows();
- let left_indices = UInt64Array::from(vec![left_row_index as u64;
right_row_count]);
- let right_indices = UInt32Array::from_iter_values(0..(right_row_count as
u32));
+ let left_row_count = left_batch.num_rows();
+ let left_indices = UInt64Array::from_iter_values(0..(left_row_count as
u64));
+ let right_indices = UInt32Array::from(vec![right_row_index as u32;
left_row_count]);
// in the nested loop join, the filter can contain non-equal and equal
condition.
if let Some(filter) = filter {
apply_join_filter_to_indices(
@@ -567,9 +596,9 @@ fn join_left_and_right_batch(
schema: &Schema,
visited_left_side: &SharedBitmapBuilder,
) -> Result<RecordBatch> {
- let indices = (0..left_batch.num_rows())
- .map(|left_row_index| {
- build_join_indices(left_row_index, right_batch, left_batch, filter)
+ let indices = (0..right_batch.num_rows())
+ .map(|right_row_index| {
+ build_join_indices(right_row_index, left_batch, right_batch,
filter)
})
.collect::<Result<Vec<(UInt64Array, UInt32Array)>>>()
.map_err(|e| {
@@ -601,7 +630,7 @@ fn join_left_and_right_batch(
right_side,
0..right_batch.num_rows(),
join_type,
- false,
+ true,
);
build_batch_from_indices(
@@ -649,20 +678,59 @@ mod tests {
};
use arrow::datatypes::{DataType, Field};
+ use arrow_array::Int32Array;
+ use arrow_schema::SortOptions;
use datafusion_common::{assert_batches_sorted_eq, assert_contains,
ScalarValue};
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
use datafusion_physical_expr::{Partitioning, PhysicalExpr};
+ use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
+
+ use rstest::rstest;
+ /// Builds a single-partition `MemoryExec` from three `i32` columns.
+ ///
+ /// * `batch_size`: `None` keeps all rows in one batch. `Some(n)` slices the rows into
+ ///   consecutive batches of at most `n` rows (the last batch may be shorter).
+ ///   `Some(0)` panics, because `usize::div_ceil` divides by zero.
+ /// * `sorted_column_names`: if non-empty, declares a lexicographic output ordering on
+ ///   these columns (ascending, nulls last). The data itself is not sorted here.
+ ///   Callers must pass columns that already satisfy the declared ordering.
fn build_table(
a: (&str, &Vec<i32>),
b: (&str, &Vec<i32>),
c: (&str, &Vec<i32>),
+ batch_size: Option<usize>,
+ sorted_column_names: Vec<&str>,
) -> Arc<dyn ExecutionPlan> {
let batch = build_table_i32(a, b, c);
let schema = batch.schema();
- Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
+
+ // Zero-copy slices of `batch`; the final slice holds the remaining rows.
+ let batches = if let Some(batch_size) = batch_size {
+ let num_batches = batch.num_rows().div_ceil(batch_size);
+ (0..num_batches)
+ .map(|i| {
+ let start = i * batch_size;
+ let remaining_rows = batch.num_rows() - start;
+ batch.slice(start, batch_size.min(remaining_rows))
+ })
+ .collect::<Vec<_>>()
+ } else {
+ vec![batch]
+ };
+
+ let mut exec =
+ MemoryExec::try_new(&[batches], Arc::clone(&schema),
None).unwrap();
+ // Only declares the sort order as metadata; `schema.index_of` panics on an unknown name.
+ if !sorted_column_names.is_empty() {
+ let mut sort_info = Vec::new();
+ for name in sorted_column_names {
+ let index = schema.index_of(name).unwrap();
+ let sort_expr = PhysicalSortExpr {
+ expr: Arc::new(Column::new(name, index)),
+ options: SortOptions {
+ descending: false,
+ nulls_first: false,
+ },
+ };
+ sort_info.push(sort_expr);
+ }
+ exec = exec.with_sort_information(vec![sort_info]);
+ }
+
+ Arc::new(exec)
}
fn build_left_table() -> Arc<dyn ExecutionPlan> {
@@ -670,6 +738,8 @@ mod tests {
("a1", &vec![5, 9, 11]),
("b1", &vec![5, 8, 8]),
("c1", &vec![50, 90, 110]),
+ None,
+ Vec::new(),
)
}
@@ -678,6 +748,8 @@ mod tests {
("a2", &vec![12, 2, 10]),
("b2", &vec![10, 2, 10]),
("c2", &vec![40, 80, 100]),
+ None,
+ Vec::new(),
)
}
@@ -1005,11 +1077,15 @@ mod tests {
("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+ None,
+ Vec::new(),
);
let right = build_table(
("a2", &vec![10, 11]),
("b2", &vec![12, 13]),
("c2", &vec![14, 15]),
+ None,
+ Vec::new(),
);
let filter = prepare_join_filter();
@@ -1050,6 +1126,164 @@ mod tests {
Ok(())
}
+ /// Builds the join filter `left.b1 % 3 != 0 AND right.b2 % 5 != 0`.
+ ///
+ /// Each conjunct reads only one side, so whether a pair matches depends on each row
+ /// independently. The filter rejects some rows on both sides without coupling them.
+ fn prepare_mod_join_filter() -> JoinFilter {
+ // Column 1 of each input. NOTE(review): this is `b1`/`b2` assuming `build_table` lays
+ // columns out as (a, b, c), which the comments below also assume.
+ let column_indices = vec![
+ ColumnIndex {
+ index: 1,
+ side: JoinSide::Left,
+ },
+ ColumnIndex {
+ index: 1,
+ side: JoinSide::Right,
+ },
+ ];
+ // Both intermediate fields are named "x". The `Column` expressions below carry an
+ // explicit index (0 = left value, 1 = right value), which disambiguates them.
+ let intermediate_schema = Schema::new(vec![
+ Field::new("x", DataType::Int32, true),
+ Field::new("x", DataType::Int32, true),
+ ]);
+
+ // left.b1 % 3
+ let left_mod = Arc::new(BinaryExpr::new(
+ Arc::new(Column::new("x", 0)),
+ Operator::Modulo,
+ Arc::new(Literal::new(ScalarValue::Int32(Some(3)))),
+ )) as Arc<dyn PhysicalExpr>;
+ // left.b1 % 3 != 0
+ let left_filter = Arc::new(BinaryExpr::new(
+ left_mod,
+ Operator::NotEq,
+ Arc::new(Literal::new(ScalarValue::Int32(Some(0)))),
+ )) as Arc<dyn PhysicalExpr>;
+
+ // right.b2 % 5
+ let right_mod = Arc::new(BinaryExpr::new(
+ Arc::new(Column::new("x", 1)),
+ Operator::Modulo,
+ Arc::new(Literal::new(ScalarValue::Int32(Some(5)))),
+ )) as Arc<dyn PhysicalExpr>;
+ // right.b2 % 5 != 0
+ let right_filter = Arc::new(BinaryExpr::new(
+ right_mod,
+ Operator::NotEq,
+ Arc::new(Literal::new(ScalarValue::Int32(Some(0)))),
+ )) as Arc<dyn PhysicalExpr>;
+ // filter = left.b1 % 3 != 0 and right.b2 % 5 != 0
+ let filter_expression =
+ Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
+ as Arc<dyn PhysicalExpr>;
+
+ JoinFilter::new(filter_expression, column_indices, intermediate_schema)
+ }
+
+ /// Returns `num_columns` identical columns, each holding `1, 2, ..., num_rows` as `i32`.
+ /// Every generated column is therefore strictly ascending.
+ fn generate_columns(num_columns: usize, num_rows: usize) -> Vec<Vec<i32>> {
+ // `as i32` would wrap for num_rows > i32::MAX; the tests here use 1000 rows.
+ let column = (1..=num_rows).map(|x| x as i32).collect();
+ vec![column; num_columns]
+ }
+
+ /// Verifies that `NestedLoopJoinExec` keeps the right (probe) side's order for the join
+ /// types that report `maintains_input_order() == [false, true]`.
+ ///
+ /// The test checks the declared output ordering, which must be the right input's
+ /// ordering re-indexed into the join schema. It also checks the rows actually emitted.
+ /// Both checks run across several left/right batch sizes.
+ #[rstest]
+ #[tokio::test]
+ async fn join_maintains_right_order(
+ #[values(
+ JoinType::Inner,
+ JoinType::Right,
+ JoinType::RightAnti,
+ JoinType::RightSemi
+ )]
+ join_type: JoinType,
+ #[values(1, 100, 1000)] left_batch_size: usize,
+ #[values(1, 100, 1000)] right_batch_size: usize,
+ ) -> Result<()> {
+ let left_columns = generate_columns(3, 1000);
+ let left = build_table(
+ ("a1", &left_columns[0]),
+ ("b1", &left_columns[1]),
+ ("c1", &left_columns[2]),
+ Some(left_batch_size),
+ Vec::new(),
+ );
+
+ let right_columns = generate_columns(3, 1000);
+ let right = build_table(
+ ("a2", &right_columns[0]),
+ ("b2", &right_columns[1]),
+ ("c2", &right_columns[2]),
+ Some(right_batch_size),
+ vec!["a2", "b2", "c2"],
+ );
+
+ let filter = prepare_mod_join_filter();
+
+ let nested_loop_join = Arc::new(NestedLoopJoinExec::try_new(
+ left,
+ Arc::clone(&right),
+ Some(filter),
+ &join_type,
+ )?) as Arc<dyn ExecutionPlan>;
+ assert_eq!(nested_loop_join.maintains_input_order(), vec![false, true]);
+
+ // Position of the right columns (a2, b2, c2) in the join output. For Inner/Right they
+ // come after the three left columns. Semi/anti joins output only the right columns.
+ let right_column_indices = match join_type {
+ JoinType::Inner | JoinType::Right => vec![3, 4, 5],
+ JoinType::RightAnti | JoinType::RightSemi => vec![0, 1, 2],
+ _ => unreachable!(),
+ };
+
+ let right_ordering = right.output_ordering().unwrap();
+ let join_ordering = nested_loop_join.output_ordering().unwrap();
+ for (right_expr, join_expr) in right_ordering.iter().zip(join_ordering.iter()) {
+ let right_column =
+ right_expr.expr.as_any().downcast_ref::<Column>().unwrap();
+ let join_column =
+ join_expr.expr.as_any().downcast_ref::<Column>().unwrap();
+ // The join must order by the same right-side column. Its name carries over
+ // unchanged, and its index shifts into the join's output schema.
+ assert_eq!(right_column.name(), join_column.name());
+ assert_eq!(
+ right_column_indices[right_column.index()],
+ join_column.index()
+ );
+ assert_eq!(right_expr.options, join_expr.options);
+ }
+
+ let batches = nested_loop_join
+ .execute(0, Arc::new(TaskContext::default()))?
+ .try_collect::<Vec<_>>()
+ .await?;
+
+ // Make sure that the order of the right side is maintained
+ let mut prev_values = [i32::MIN, i32::MIN, i32::MIN];
+
+ for (batch_index, batch) in batches.iter().enumerate() {
+ let columns: Vec<_> = right_column_indices
+ .iter()
+ .map(|&i| {
+ batch
+ .column(i)
+ .as_any()
+ .downcast_ref::<Int32Array>()
+ .unwrap()
+ })
+ .collect();
+
+ for row in 0..batch.num_rows() {
+ let current_values = [
+ columns[0].value(row),
+ columns[1].value(row),
+ columns[2].value(row),
+ ];
+ assert!(
+ current_values
+ .into_iter()
+ .zip(prev_values)
+ .all(|(current, prev)| current >= prev),
+ "batch_index: {} row: {} current: {:?}, prev: {:?}",
+ batch_index,
+ row,
+ current_values,
+ prev_values
+ );
+ prev_values = current_values;
+ }
+ }
+
+ Ok(())
+ }
+
/// Returns the column names on the schema
fn columns(schema: &Schema) -> Vec<String> {
schema.fields().iter().map(|f| f.name().clone()).collect()
diff --git a/datafusion/sqllogictest/test_files/join.slt
b/datafusion/sqllogictest/test_files/join.slt
index 3e7a08981e..2f505c9fc7 100644
--- a/datafusion/sqllogictest/test_files/join.slt
+++ b/datafusion/sqllogictest/test_files/join.slt
@@ -838,10 +838,10 @@ LEFT JOIN department AS d
ON (e.name = 'Alice' OR e.name = 'Bob');
----
1 Alice HR
-2 Bob HR
1 Alice Engineering
-2 Bob Engineering
1 Alice Sales
+2 Bob HR
+2 Bob Engineering
2 Bob Sales
3 Carol NULL
@@ -853,10 +853,10 @@ RIGHT JOIN employees AS e
ON (e.name = 'Alice' OR e.name = 'Bob');
----
1 Alice HR
-2 Bob HR
1 Alice Engineering
-2 Bob Engineering
1 Alice Sales
+2 Bob HR
+2 Bob Engineering
2 Bob Sales
3 Carol NULL
@@ -868,10 +868,10 @@ FULL JOIN employees AS e
ON (e.name = 'Alice' OR e.name = 'Bob');
----
1 Alice HR
-2 Bob HR
1 Alice Engineering
-2 Bob Engineering
1 Alice Sales
+2 Bob HR
+2 Bob Engineering
2 Bob Sales
3 Carol NULL
diff --git a/datafusion/sqllogictest/test_files/joins.slt
b/datafusion/sqllogictest/test_files/joins.slt
index 7d0262952b..679c2eee10 100644
--- a/datafusion/sqllogictest/test_files/joins.slt
+++ b/datafusion/sqllogictest/test_files/joins.slt
@@ -2136,10 +2136,10 @@ FROM (select t1_id from join_t1 where join_t1.t1_id >
22) as join_t1
RIGHT JOIN (select t2_id from join_t2 where join_t2.t2_id > 11) as join_t2
ON join_t1.t1_id < join_t2.t2_id
----
+NULL 22
33 44
33 55
44 55
-NULL 22
#####
# Configuration teardown
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]