This is an automated email from the ASF dual-hosted git repository.

comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 97ea05c0f6 Extending join fuzz tests to support join filtering (#10728)
97ea05c0f6 is described below

commit 97ea05c0f60aa11a270420968d3fefc859c0d346
Author: Edmondo Porcu <[email protected]>
AuthorDate: Tue Jun 11 22:19:55 2024 -0400

    Extending join fuzz tests to support join filtering (#10728)
    
    * Extending join fuzz tests to support join filtering
    
    
    ---------
    
    Co-authored-by: Oleks V <[email protected]>
---
 datafusion/core/tests/fuzz_cases/join_fuzz.rs | 407 +++++++++++++++++++-------
 1 file changed, 296 insertions(+), 111 deletions(-)

diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
index 824f1eec4a..8c2e24de56 100644
--- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -22,6 +22,11 @@ use arrow::compute::SortOptions;
 use arrow::record_batch::RecordBatch;
 use arrow::util::pretty::pretty_format_batches;
 use arrow_schema::Schema;
+
+use datafusion_common::ScalarValue;
+use datafusion_physical_expr::expressions::Literal;
+use datafusion_physical_expr::PhysicalExprRef;
+
 use rand::Rng;
 
 use datafusion::common::JoinSide;
@@ -40,92 +45,207 @@ use test_utils::stagger_batch_with_seed;
 
 #[tokio::test]
 async fn test_inner_join_1k() {
-    run_join_test(
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::Inner,
+        None,
+    )
+    .run_test()
+    .await
+}
+
+fn less_than_10_join_filter(schema1: Arc<Schema>, _schema2: Arc<Schema>) -> JoinFilter {
+    let less_than_100 = Arc::new(BinaryExpr::new(
+        Arc::new(Column::new("a", 0)),
+        Operator::Lt,
+        Arc::new(Literal::new(ScalarValue::from(100))),
+    )) as _;
+    let column_indices = vec![ColumnIndex {
+        index: 0,
+        side: JoinSide::Left,
+    }];
+    let intermediate_schema =
+        Schema::new(vec![schema1.field_with_name("a").unwrap().to_owned()]);
+
+    JoinFilter::new(less_than_100, column_indices, intermediate_schema)
+}
+
+#[tokio::test]
+async fn test_inner_join_1k_filtered() {
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::Inner,
+        Some(Box::new(less_than_10_join_filter)),
+    )
+    .run_test()
+    .await
+}
+
+#[tokio::test]
+async fn test_inner_join_1k_smjoin() {
+    JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
         JoinType::Inner,
+        None,
     )
+    .run_test()
     .await
 }
 
 #[tokio::test]
 async fn test_left_join_1k() {
-    run_join_test(
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::Left,
+        None,
+    )
+    .run_test()
+    .await
+}
+
+#[tokio::test]
+async fn test_left_join_1k_filtered() {
+    JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
         JoinType::Left,
+        Some(Box::new(less_than_10_join_filter)),
     )
+    .run_test()
     .await
 }
 
 #[tokio::test]
 async fn test_right_join_1k() {
-    run_join_test(
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::Right,
+        None,
+    )
+    .run_test()
+    .await
+}
+// Add support for Right filtered joins
+#[ignore]
+#[tokio::test]
+async fn test_right_join_1k_filtered() {
+    JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
         JoinType::Right,
+        Some(Box::new(less_than_10_join_filter)),
     )
+    .run_test()
     .await
 }
 
 #[tokio::test]
 async fn test_full_join_1k() {
-    run_join_test(
+    JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
         JoinType::Full,
+        None,
     )
+    .run_test()
+    .await
+}
+
+#[tokio::test]
+async fn test_full_join_1k_filtered() {
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::Full,
+        Some(Box::new(less_than_10_join_filter)),
+    )
+    .run_test()
     .await
 }
 
 #[tokio::test]
 async fn test_semi_join_1k() {
-    run_join_test(
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::LeftSemi,
+        None,
+    )
+    .run_test()
+    .await
+}
+
+#[tokio::test]
+async fn test_semi_join_1k_filtered() {
+    JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
         JoinType::LeftSemi,
+        Some(Box::new(less_than_10_join_filter)),
     )
+    .run_test()
     .await
 }
 
 #[tokio::test]
 async fn test_anti_join_1k() {
-    run_join_test(
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::LeftAnti,
+        None,
+    )
+    .run_test()
+    .await
+}
+
+// Test failed for now. https://github.com/apache/datafusion/issues/10872
+#[ignore]
+#[tokio::test]
+async fn test_anti_join_1k_filtered() {
+    JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
         JoinType::LeftAnti,
+        Some(Box::new(less_than_10_join_filter)),
     )
+    .run_test()
     .await
 }
 
-/// Perform sort-merge join and hash join on same input
-/// and verify two outputs are equal
-async fn run_join_test(
+type JoinFilterBuilder = Box<dyn Fn(Arc<Schema>, Arc<Schema>) -> JoinFilter>;
+
+struct JoinFuzzTestCase {
+    batch_sizes: &'static [usize],
     input1: Vec<RecordBatch>,
     input2: Vec<RecordBatch>,
     join_type: JoinType,
-) {
-    let batch_sizes = [1, 2, 7, 49, 50, 51, 100];
-    for batch_size in batch_sizes {
-        let session_config = SessionConfig::new().with_batch_size(batch_size);
-        let ctx = SessionContext::new_with_config(session_config);
-        let task_ctx = ctx.task_ctx();
-
-        let schema1 = input1[0].schema();
-        let schema2 = input2[0].schema();
-        let on_columns = vec![
-            (
-                Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _,
-                Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _,
-            ),
-            (
-                Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _,
-                Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _,
-            ),
-        ];
+    join_filter_builder: Option<JoinFilterBuilder>,
+}
 
-        // Nested loop join uses filter for joining records
-        let column_indices = vec![
+impl JoinFuzzTestCase {
+    fn new(
+        input1: Vec<RecordBatch>,
+        input2: Vec<RecordBatch>,
+        join_type: JoinType,
+        join_filter_builder: Option<JoinFilterBuilder>,
+    ) -> Self {
+        Self {
+            batch_sizes: &[1, 2, 7, 49, 50, 51, 100],
+            input1,
+            input2,
+            join_type,
+            join_filter_builder,
+        }
+    }
+
+    fn column_indices(&self) -> Vec<ColumnIndex> {
+        vec![
             ColumnIndex {
                 index: 0,
                 side: JoinSide::Left,
@@ -142,120 +262,185 @@ async fn run_join_test(
                 index: 1,
                 side: JoinSide::Right,
             },
-        ];
-        let intermediate_schema = Schema::new(vec![
-            schema1.field_with_name("a").unwrap().to_owned(),
-            schema1.field_with_name("b").unwrap().to_owned(),
-            schema2.field_with_name("a").unwrap().to_owned(),
-            schema2.field_with_name("b").unwrap().to_owned(),
-        ]);
+        ]
+    }
 
-        let equal_a = Arc::new(BinaryExpr::new(
-            Arc::new(Column::new("a", 0)),
-            Operator::Eq,
-            Arc::new(Column::new("a", 2)),
-        )) as _;
-        let equal_b = Arc::new(BinaryExpr::new(
-            Arc::new(Column::new("b", 1)),
-            Operator::Eq,
-            Arc::new(Column::new("b", 3)),
-        )) as _;
-        let expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b)) as _;
+    fn on_columns(&self) -> Vec<(PhysicalExprRef, PhysicalExprRef)> {
+        let schema1 = self.input1[0].schema();
+        let schema2 = self.input2[0].schema();
+        vec![
+            (
+                Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _,
+                Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _,
+            ),
+            (
+                Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _,
+                Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _,
+            ),
+        ]
+    }
 
-        let on_filter = JoinFilter::new(expression, column_indices, intermediate_schema);
+    fn intermediate_schema(&self) -> Schema {
+        let schema1 = self.input1[0].schema();
+        let schema2 = self.input2[0].schema();
+        Schema::new(vec![
+            schema1
+                .field_with_name("a")
+                .unwrap()
+                .to_owned()
+                .with_nullable(true),
+            schema1
+                .field_with_name("b")
+                .unwrap()
+                .to_owned()
+                .with_nullable(true),
+            schema2.field_with_name("a").unwrap().to_owned(),
+            schema2.field_with_name("b").unwrap().to_owned(),
+        ])
+    }
 
-        // sort-merge join
+    fn left_right(&self) -> (Arc<MemoryExec>, Arc<MemoryExec>) {
+        let schema1 = self.input1[0].schema();
+        let schema2 = self.input2[0].schema();
         let left = Arc::new(
-            MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(),
+            MemoryExec::try_new(&[self.input1.clone()], schema1.clone(), None).unwrap(),
         );
         let right = Arc::new(
-            MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(),
+            MemoryExec::try_new(&[self.input2.clone()], schema2.clone(), None).unwrap(),
         );
-        let smj = Arc::new(
+        (left, right)
+    }
+
+    fn join_filter(&self) -> Option<JoinFilter> {
+        let schema1 = self.input1[0].schema();
+        let schema2 = self.input2[0].schema();
+        self.join_filter_builder
+            .as_ref()
+            .map(|builder| builder(schema1, schema2))
+    }
+
+    fn sort_merge_join(&self) -> Arc<SortMergeJoinExec> {
+        let (left, right) = self.left_right();
+        Arc::new(
             SortMergeJoinExec::try_new(
                 left,
                 right,
-                on_columns.clone(),
-                None,
-                join_type,
+                self.on_columns().clone(),
+                self.join_filter(),
+                self.join_type,
                 vec![SortOptions::default(), SortOptions::default()],
                 false,
             )
             .unwrap(),
-        );
-        let smj_collected = collect(smj, task_ctx.clone()).await.unwrap();
+        )
+    }
 
-        // hash join
-        let left = Arc::new(
-            MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(),
-        );
-        let right = Arc::new(
-            MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(),
-        );
-        let hj = Arc::new(
+    fn hash_join(&self) -> Arc<HashJoinExec> {
+        let (left, right) = self.left_right();
+        Arc::new(
             HashJoinExec::try_new(
                 left,
                 right,
-                on_columns.clone(),
-                None,
-                &join_type,
+                self.on_columns().clone(),
+                self.join_filter(),
+                &self.join_type,
                 None,
                 PartitionMode::Partitioned,
                 false,
             )
             .unwrap(),
-        );
-        let hj_collected = collect(hj, task_ctx.clone()).await.unwrap();
+        )
+    }
 
-        // nested loop join
-        let left = Arc::new(
-            MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(),
-        );
-        let right = Arc::new(
-            MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(),
-        );
-        let nlj = Arc::new(
-            NestedLoopJoinExec::try_new(left, right, Some(on_filter), &join_type)
+    fn nested_loop_join(&self) -> Arc<NestedLoopJoinExec> {
+        let (left, right) = self.left_right();
+        // Nested loop join uses filter for joining records
+        let column_indices = self.column_indices();
+        let intermediate_schema = self.intermediate_schema();
+
+        let equal_a = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("a", 0)),
+            Operator::Eq,
+            Arc::new(Column::new("a", 2)),
+        )) as _;
+        let equal_b = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("b", 1)),
+            Operator::Eq,
+            Arc::new(Column::new("b", 3)),
+        )) as _;
+        let expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b)) as _;
+
+        let on_filter = JoinFilter::new(expression, column_indices, intermediate_schema);
+
+        Arc::new(
+            NestedLoopJoinExec::try_new(left, right, Some(on_filter), &self.join_type)
                 .unwrap(),
-        );
-        let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap();
+        )
+    }
 
-        // compare
-        let smj_formatted = pretty_format_batches(&smj_collected).unwrap().to_string();
-        let hj_formatted = pretty_format_batches(&hj_collected).unwrap().to_string();
-        let nlj_formatted = pretty_format_batches(&nlj_collected).unwrap().to_string();
+    /// Perform sort-merge join and hash join on same input
+    /// and verify two outputs are equal
+    async fn run_test(&self) {
+        for batch_size in self.batch_sizes {
+            let session_config = SessionConfig::new().with_batch_size(*batch_size);
+            let ctx = SessionContext::new_with_config(session_config);
+            let task_ctx = ctx.task_ctx();
+            let smj = self.sort_merge_join();
+            let smj_collected = collect(smj, task_ctx.clone()).await.unwrap();
 
-        let mut smj_formatted_sorted: Vec<&str> = smj_formatted.trim().lines().collect();
-        smj_formatted_sorted.sort_unstable();
+            let hj = self.hash_join();
+            let hj_collected = collect(hj, task_ctx.clone()).await.unwrap();
 
-        let mut hj_formatted_sorted: Vec<&str> = hj_formatted.trim().lines().collect();
-        hj_formatted_sorted.sort_unstable();
+            let nlj = self.nested_loop_join();
+            let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap();
 
-        let mut nlj_formatted_sorted: Vec<&str> = nlj_formatted.trim().lines().collect();
-        nlj_formatted_sorted.sort_unstable();
+            // compare
+            let smj_formatted =
+                pretty_format_batches(&smj_collected).unwrap().to_string();
+            let hj_formatted = pretty_format_batches(&hj_collected).unwrap().to_string();
+            let nlj_formatted =
+                pretty_format_batches(&nlj_collected).unwrap().to_string();
 
-        for (i, (smj_line, hj_line)) in smj_formatted_sorted
-            .iter()
-            .zip(&hj_formatted_sorted)
-            .enumerate()
-        {
-            assert_eq!(
-                (i, smj_line),
-                (i, hj_line),
-                "SortMergeJoinExec and HashJoinExec produced different results"
-            );
-        }
+            let mut smj_formatted_sorted: Vec<&str> =
+                smj_formatted.trim().lines().collect();
+            smj_formatted_sorted.sort_unstable();
+
+            let mut hj_formatted_sorted: Vec<&str> =
+                hj_formatted.trim().lines().collect();
+            hj_formatted_sorted.sort_unstable();
+
+            let mut nlj_formatted_sorted: Vec<&str> =
+                nlj_formatted.trim().lines().collect();
+            nlj_formatted_sorted.sort_unstable();
 
-        for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
-            .iter()
-            .zip(&hj_formatted_sorted)
-            .enumerate()
-        {
             assert_eq!(
-                (i, nlj_line),
-                (i, hj_line),
-                "NestedLoopJoinExec and HashJoinExec produced different results"
+                smj_formatted_sorted.len(),
+                hj_formatted_sorted.len(),
+                "SortMergeJoinExec and HashJoinExec produced different row counts"
             );
+            for (i, (smj_line, hj_line)) in smj_formatted_sorted
+                .iter()
+                .zip(&hj_formatted_sorted)
+                .enumerate()
+            {
+                assert_eq!(
+                    (i, smj_line),
+                    (i, hj_line),
+                    "SortMergeJoinExec and HashJoinExec produced different results"
+                );
+            }
+
+            for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
+                .iter()
+                .zip(&hj_formatted_sorted)
+                .enumerate()
+            {
+                assert_eq!(
+                    (i, nlj_line),
+                    (i, hj_line),
+                    "NestedLoopJoinExec and HashJoinExec produced different results"
+                );
+            }
         }
     }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to