This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 21fe0b7762 Implement semi/anti join output statistics estimation (#9800)
21fe0b7762 is described below

commit 21fe0b7762d088731689750e2cef1762d4f9db5e
Author: Eduard Karacharov <[email protected]>
AuthorDate: Sat Mar 30 14:44:35 2024 +0200

    Implement semi/anti join output statistics estimation (#9800)
    
    * semi/anti join output statistics
    
    * fix antijoin cardinality estimation
---
 datafusion/physical-plan/src/joins/utils.rs | 373 ++++++++++++++++++++++++----
 1 file changed, 323 insertions(+), 50 deletions(-)

diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs
index 1cb2b100e2..a3d20b97d1 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -825,27 +825,27 @@ fn estimate_join_cardinality(
     right_stats: Statistics,
     on: &JoinOn,
 ) -> Option<PartialJoinStatistics> {
+    let (left_col_stats, right_col_stats) = on
+        .iter()
+        .map(|(left, right)| {
+            match (
+                left.as_any().downcast_ref::<Column>(),
+                right.as_any().downcast_ref::<Column>(),
+            ) {
+                (Some(left), Some(right)) => (
+                    left_stats.column_statistics[left.index()].clone(),
+                    right_stats.column_statistics[right.index()].clone(),
+                ),
+                _ => (
+                    ColumnStatistics::new_unknown(),
+                    ColumnStatistics::new_unknown(),
+                ),
+            }
+        })
+        .unzip::<_, _, Vec<_>, Vec<_>>();
+
     match join_type {
         JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => 
{
-            let (left_col_stats, right_col_stats) = on
-                .iter()
-                .map(|(left, right)| {
-                    match (
-                        left.as_any().downcast_ref::<Column>(),
-                        right.as_any().downcast_ref::<Column>(),
-                    ) {
-                        (Some(left), Some(right)) => (
-                            left_stats.column_statistics[left.index()].clone(),
-                            
right_stats.column_statistics[right.index()].clone(),
-                        ),
-                        _ => (
-                            ColumnStatistics::new_unknown(),
-                            ColumnStatistics::new_unknown(),
-                        ),
-                    }
-                })
-                .unzip::<_, _, Vec<_>, Vec<_>>();
-
             let ij_cardinality = estimate_inner_join_cardinality(
                 Statistics {
                     num_rows: left_stats.num_rows.clone(),
@@ -888,10 +888,38 @@ fn estimate_join_cardinality(
             })
         }
 
-        JoinType::LeftSemi
-        | JoinType::RightSemi
-        | JoinType::LeftAnti
-        | JoinType::RightAnti => None,
+        // For SemiJoins estimation result is either zero, in cases when inputs
+        // are non-overlapping according to statistics, or equal to number of 
rows
+        // for outer input
+        JoinType::LeftSemi | JoinType::RightSemi => {
+            let (outer_stats, inner_stats) = match join_type {
+                JoinType::LeftSemi => (left_stats, right_stats),
+                _ => (right_stats, left_stats),
+            };
+            let cardinality = match estimate_disjoint_inputs(&outer_stats, 
&inner_stats) {
+                Some(estimation) => *estimation.get_value()?,
+                None => *outer_stats.num_rows.get_value()?,
+            };
+
+            Some(PartialJoinStatistics {
+                num_rows: cardinality,
+                column_statistics: outer_stats.column_statistics,
+            })
+        }
+
+        // For AntiJoins estimation always equals to outer statistics, as
+        // non-overlapping inputs won't affect estimation
+        JoinType::LeftAnti | JoinType::RightAnti => {
+            let outer_stats = match join_type {
+                JoinType::LeftAnti => left_stats,
+                _ => right_stats,
+            };
+
+            Some(PartialJoinStatistics {
+                num_rows: *outer_stats.num_rows.get_value()?,
+                column_statistics: outer_stats.column_statistics,
+            })
+        }
     }
 }
 
@@ -903,6 +931,11 @@ fn estimate_inner_join_cardinality(
     left_stats: Statistics,
     right_stats: Statistics,
 ) -> Option<Precision<usize>> {
+    // Immediatedly return if inputs considered as non-overlapping
+    if let Some(estimation) = estimate_disjoint_inputs(&left_stats, 
&right_stats) {
+        return Some(estimation);
+    };
+
     // The algorithm here is partly based on the non-histogram selectivity 
estimation
     // from Spark's Catalyst optimizer.
     let mut join_selectivity = Precision::Absent;
@@ -911,30 +944,13 @@ fn estimate_inner_join_cardinality(
         .iter()
         .zip(right_stats.column_statistics.iter())
     {
-        // If there is no overlap in any of the join columns, this means the 
join
-        // itself is disjoint and the cardinality is 0. Though we can only 
assume
-        // this when the statistics are exact (since it is a very strong 
assumption).
-        if left_stat.min_value.get_value()? > 
right_stat.max_value.get_value()? {
-            return Some(
-                if left_stat.min_value.is_exact().unwrap_or(false)
-                    && right_stat.max_value.is_exact().unwrap_or(false)
-                {
-                    Precision::Exact(0)
-                } else {
-                    Precision::Inexact(0)
-                },
-            );
-        }
-        if left_stat.max_value.get_value()? < 
right_stat.min_value.get_value()? {
-            return Some(
-                if left_stat.max_value.is_exact().unwrap_or(false)
-                    && right_stat.min_value.is_exact().unwrap_or(false)
-                {
-                    Precision::Exact(0)
-                } else {
-                    Precision::Inexact(0)
-                },
-            );
+        // Break if any of statistics bounds are undefined
+        if left_stat.min_value.get_value().is_none()
+            || left_stat.max_value.get_value().is_none()
+            || right_stat.min_value.get_value().is_none()
+            || right_stat.max_value.get_value().is_none()
+        {
+            return None;
         }
 
         let left_max_distinct = max_distinct_count(&left_stats.num_rows, 
left_stat);
@@ -968,6 +984,58 @@ fn estimate_inner_join_cardinality(
     }
 }
 
+/// Estimates if inputs are non-overlapping, using input statistics.
+/// If inputs are disjoint, returns zero estimation, otherwise returns None
+fn estimate_disjoint_inputs(
+    left_stats: &Statistics,
+    right_stats: &Statistics,
+) -> Option<Precision<usize>> {
+    for (left_stat, right_stat) in left_stats
+        .column_statistics
+        .iter()
+        .zip(right_stats.column_statistics.iter())
+    {
+        // If there is no overlap in any of the join columns, this means the 
join
+        // itself is disjoint and the cardinality is 0. Though we can only 
assume
+        // this when the statistics are exact (since it is a very strong 
assumption).
+        let left_min_val = left_stat.min_value.get_value();
+        let right_max_val = right_stat.max_value.get_value();
+        if left_min_val.is_some()
+            && right_max_val.is_some()
+            && left_min_val > right_max_val
+        {
+            return Some(
+                if left_stat.min_value.is_exact().unwrap_or(false)
+                    && right_stat.max_value.is_exact().unwrap_or(false)
+                {
+                    Precision::Exact(0)
+                } else {
+                    Precision::Inexact(0)
+                },
+            );
+        }
+
+        let left_max_val = left_stat.max_value.get_value();
+        let right_min_val = right_stat.min_value.get_value();
+        if left_max_val.is_some()
+            && right_min_val.is_some()
+            && left_max_val < right_min_val
+        {
+            return Some(
+                if left_stat.max_value.is_exact().unwrap_or(false)
+                    && right_stat.min_value.is_exact().unwrap_or(false)
+                {
+                    Precision::Exact(0)
+                } else {
+                    Precision::Inexact(0)
+                },
+            );
+        }
+    }
+
+    None
+}
+
 /// Estimate the number of maximum distinct values that can be present in the
 /// given column from its statistics. If distinct_count is available, uses it
 /// directly. Otherwise, if the column is numeric and has min/max values, it
@@ -1716,9 +1784,11 @@ mod tests {
     #[test]
     fn test_inner_join_cardinality_single_column() -> Result<()> {
         let cases: Vec<(PartialStats, PartialStats, Option<Precision<usize>>)> 
= vec![
-            // 
-----------------------------------------------------------------------------
-            // | left(rows, min, max, distinct), right(rows, min, max, 
distinct), expected |
-            // 
-----------------------------------------------------------------------------
+            // ------------------------------------------------
+            // | left(rows, min, max, distinct, null_count),  |
+            // | right(rows, min, max, distinct, null_count), |
+            // | expected,                                    |
+            // ------------------------------------------------
 
             // Cardinality computation
             // =======================
@@ -1824,6 +1894,11 @@ mod tests {
                 None,
             ),
             // Non overlapping min/max (when exact=False).
+            (
+                (10, Absent, Inexact(4), Absent, Absent),
+                (10, Inexact(5), Absent, Absent, Absent),
+                Some(Inexact(0)),
+            ),
             (
                 (10, Inexact(0), Inexact(10), Absent, Absent),
                 (10, Inexact(11), Inexact(20), Absent, Absent),
@@ -2106,6 +2181,204 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn test_anti_semi_join_cardinality() -> Result<()> {
+        let cases: Vec<(JoinType, PartialStats, PartialStats, Option<usize>)> = vec![
+            // ------------------------------------------------
+            // | join_type,                                   |
+            // | left(rows, min, max, distinct, null_count),  |
+            // | right(rows, min, max, distinct, null_count), |
+            // | expected,                                    |
+            // ------------------------------------------------
+
+            // Cardinality computation
+            // =======================
+
+            // Semi joins: overlapping or unknown bounds yield the row count of
+            // the outer side (left for LeftSemi, right for RightSemi).
+            (
+                JoinType::LeftSemi,
+                (50, Inexact(10), Inexact(20), Absent, Absent),
+                (10, Inexact(15), Inexact(25), Absent, Absent),
+                Some(50),
+            ),
+            (
+                JoinType::RightSemi,
+                (50, Inexact(10), Inexact(20), Absent, Absent),
+                (10, Inexact(15), Inexact(25), Absent, Absent),
+                Some(10),
+            ),
+            (
+                JoinType::LeftSemi,
+                (10, Absent, Absent, Absent, Absent),
+                (50, Absent, Absent, Absent, Absent),
+                Some(10),
+            ),
+            // Semi joins: non-overlapping bounds yield zero rows.
+            (
+                JoinType::LeftSemi,
+                (50, Inexact(10), Inexact(20), Absent, Absent),
+                (10, Inexact(30), Inexact(40), Absent, Absent),
+                Some(0),
+            ),
+            (
+                JoinType::LeftSemi,
+                (50, Inexact(10), Absent, Absent, Absent),
+                (10, Absent, Inexact(5), Absent, Absent),
+                Some(0),
+            ),
+            (
+                JoinType::LeftSemi,
+                (50, Absent, Inexact(20), Absent, Absent),
+                (10, Inexact(30), Absent, Absent, Absent),
+                Some(0),
+            ),
+            // Anti joins: always the row count of the outer side, whether or
+            // not the bounds overlap.
+            (
+                JoinType::LeftAnti,
+                (50, Inexact(10), Inexact(20), Absent, Absent),
+                (10, Inexact(15), Inexact(25), Absent, Absent),
+                Some(50),
+            ),
+            (
+                JoinType::RightAnti,
+                (50, Inexact(10), Inexact(20), Absent, Absent),
+                (10, Inexact(15), Inexact(25), Absent, Absent),
+                Some(10),
+            ),
+            (
+                JoinType::LeftAnti,
+                (10, Absent, Absent, Absent, Absent),
+                (50, Absent, Absent, Absent, Absent),
+                Some(10),
+            ),
+            (
+                JoinType::LeftAnti,
+                (50, Inexact(10), Inexact(20), Absent, Absent),
+                (10, Inexact(30), Inexact(40), Absent, Absent),
+                Some(50),
+            ),
+            (
+                JoinType::LeftAnti,
+                (50, Inexact(10), Absent, Absent, Absent),
+                (10, Absent, Inexact(5), Absent, Absent),
+                Some(50),
+            ),
+            (
+                JoinType::LeftAnti,
+                (50, Absent, Inexact(20), Absent, Absent),
+                (10, Inexact(30), Absent, Absent, Absent),
+                Some(50),
+            ),
+        ];
+
+        let join_on = vec![(
+            Arc::new(Column::new("l_col", 0)) as _,
+            Arc::new(Column::new("r_col", 0)) as _,
+        )];
+
+        // Each case describes the left and right inputs; which of them is the
+        // outer side is decided by `join_type`, not by tuple position.
+        for (join_type, left_info, right_info, expected) in cases {
+            let left_num_rows = left_info.0;
+            let left_col_stats = vec![create_column_stats(
+                left_info.1,
+                left_info.2,
+                left_info.3,
+                left_info.4,
+            )];
+
+            let right_num_rows = right_info.0;
+            let right_col_stats = vec![create_column_stats(
+                right_info.1,
+                right_info.2,
+                right_info.3,
+                right_info.4,
+            )];
+
+            let output_cardinality = estimate_join_cardinality(
+                &join_type,
+                Statistics {
+                    num_rows: Inexact(left_num_rows),
+                    total_byte_size: Absent,
+                    column_statistics: left_col_stats,
+                },
+                Statistics {
+                    num_rows: Inexact(right_num_rows),
+                    total_byte_size: Absent,
+                    column_statistics: right_col_stats,
+                },
+                &join_on,
+            )
+            .map(|cardinality| cardinality.num_rows);
+
+            assert_eq!(
+                output_cardinality, expected,
+                "failure for join_type: {}",
+                join_type
+            );
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_semi_join_cardinality_absent_rows() -> Result<()> {
+        // One column with absent min/max/distinct/null_count statistics, so the
+        // bounds can never prove disjointness and only `num_rows` matters.
+        let dummy_column_stats =
+            vec![create_column_stats(Absent, Absent, Absent, Absent)];
+        let join_on = vec![(
+            Arc::new(Column::new("l_col", 0)) as _,
+            Arc::new(Column::new("r_col", 0)) as _,
+        )];
+
+        // Absent outer num_rows: no estimation is produced.
+        let absent_outer_estimation = estimate_join_cardinality(
+            &JoinType::LeftSemi,
+            Statistics {
+                num_rows: Absent,
+                total_byte_size: Absent,
+                column_statistics: dummy_column_stats.clone(),
+            },
+            Statistics {
+                num_rows: Exact(10),
+                total_byte_size: Absent,
+                column_statistics: dummy_column_stats.clone(),
+            },
+            &join_on,
+        );
+        assert!(
+            absent_outer_estimation.is_none(),
+            "Expected \"None\" estimated SemiJoin cardinality for absent outer num_rows"
+        );
+
+        // Absent inner num_rows: the outer num_rows is used as the estimation.
+        let absent_inner_estimation = estimate_join_cardinality(
+            &JoinType::LeftSemi,
+            Statistics {
+                num_rows: Inexact(500),
+                total_byte_size: Absent,
+                column_statistics: dummy_column_stats.clone(),
+            },
+            Statistics {
+                num_rows: Absent,
+                total_byte_size: Absent,
+                column_statistics: dummy_column_stats.clone(),
+            },
+            &join_on,
+        )
+        .expect(
+            "Expected non-empty PartialJoinStatistics for SemiJoin with absent inner num_rows",
+        );
+        assert_eq!(
+            absent_inner_estimation.num_rows, 500,
+            "Expected outer.num_rows estimated SemiJoin cardinality for absent inner num_rows"
+        );
+
+        // Both num_rows absent: no estimation is produced.
+        let absent_both_estimation = estimate_join_cardinality(
+            &JoinType::LeftSemi,
+            Statistics {
+                num_rows: Absent,
+                total_byte_size: Absent,
+                column_statistics: dummy_column_stats.clone(),
+            },
+            Statistics {
+                num_rows: Absent,
+                total_byte_size: Absent,
+                column_statistics: dummy_column_stats,
+            },
+            &join_on,
+        );
+        assert!(
+            absent_both_estimation.is_none(),
+            "Expected \"None\" estimated SemiJoin cardinality for absent outer and inner num_rows"
+        );
+
+        Ok(())
+    }
+
     #[test]
     fn test_calculate_join_output_ordering() -> Result<()> {
         let options = SortOptions::default();

Reply via email to