This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new ce3a0bb063 feat: implement substrait join filter support (#6868)
ce3a0bb063 is described below

commit ce3a0bb0630eab47a3b66448d4fa6f1fa8314b8d
Author: Nuttiiya Seekhao <[email protected]>
AuthorDate: Sun Jul 9 10:00:05 2023 -0400

    feat: implement substrait join filter support (#6868)
    
    * Added join filter support
    
    * clippy fix
    
    * Added Arc import
    
    * cargo fmt
---
 datafusion/substrait/src/logical_plan/consumer.rs  | 117 ++++++++++++---------
 datafusion/substrait/src/logical_plan/producer.rs  |  48 +++++----
 .../tests/cases/roundtrip_logical_plan.rs          |  28 ++++-
 .../substrait/tests/roundtrip_logical_plan.rs      |  28 ++++-
 4 files changed, 146 insertions(+), 75 deletions(-)

diff --git a/datafusion/substrait/src/logical_plan/consumer.rs 
b/datafusion/substrait/src/logical_plan/consumer.rs
index ffa935be1f..dc06b64a9e 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -22,8 +22,8 @@ use datafusion::logical_expr::{
     aggregate_function, window_function::find_df_window_func, BinaryExpr,
     BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator,
 };
-use datafusion::logical_expr::{build_join_schema, Extension, 
LogicalPlanBuilder};
 use datafusion::logical_expr::{expr, Cast, WindowFrameBound, WindowFrameUnits};
+use datafusion::logical_expr::{Extension, LogicalPlanBuilder};
 use datafusion::prelude::JoinType;
 use datafusion::sql::TableReference;
 use datafusion::{
@@ -344,63 +344,76 @@ pub async fn from_substrait_rel(
             let join_type = from_substrait_jointype(join.r#type)?;
             // The join condition expression needs full input schema and not 
the output schema from join since we lose columns from
             // certain join types such as semi and anti joins
-            // - if left and right schemas are different, we combine (join) 
the schema to include all fields
-            // - if left and right schemas are the same, we handle the 
duplicate fields by using `build_join_schema()`, which discard the unused schema
-            // TODO: Handle duplicate fields error for other join types 
(non-semi/anti). The current approach does not work due to Substrait's inability
-            //       to encode aliases
-            let join_schema = match left.schema().join(right.schema()) {
-                Ok(schema) => Ok(schema),
-                Err(DataFusionError::SchemaError(
-                    datafusion::common::SchemaError::DuplicateQualifiedField {
-                        qualifier: _,
-                        name: _,
-                    },
-                )) => build_join_schema(left.schema(), right.schema(), 
&join_type),
-                Err(e) => Err(e),
+            let in_join_schema = left.schema().join(right.schema())?;
+            // Parse post join filter if exists
+            let join_filter = match &join.post_join_filter {
+                Some(filter) => {
+                    let parsed_filter =
+                        from_substrait_rex(filter, &in_join_schema, 
extensions).await?;
+                    Some(parsed_filter.as_ref().clone())
+                }
+                None => None,
             };
-            let on = from_substrait_rex(
-                join.expression.as_ref().unwrap(),
-                &join_schema?,
-                extensions,
-            )
-            .await?;
-            let predicates = split_conjunction(&on);
-            // TODO: collect only one null_eq_null
-            let join_exprs: Vec<(Column, Column, bool)> = predicates
-                .iter()
-                .map(|p| match p {
-                    Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
-                        match (left.as_ref(), right.as_ref()) {
-                            (Expr::Column(l), Expr::Column(r)) => match op {
-                                Operator::Eq => Ok((l.clone(), r.clone(), 
false)),
-                                Operator::IsNotDistinctFrom => {
-                                    Ok((l.clone(), r.clone(), true))
+            // If join expression exists, parse the `on` condition expression, 
build join and return
+            // Otherwise, build join with join filter, without join keys
+            match &join.expression.as_ref() {
+                Some(expr) => {
+                    let on =
+                        from_substrait_rex(expr, &in_join_schema, 
extensions).await?;
+                    let predicates = split_conjunction(&on);
+                    // TODO: collect only one null_eq_null
+                    let join_exprs: Vec<(Column, Column, bool)> = predicates
+                        .iter()
+                        .map(|p| {
+                            match p {
+                            Expr::BinaryExpr(BinaryExpr { left, op, right }) 
=> {
+                                match (left.as_ref(), right.as_ref()) {
+                                    (Expr::Column(l), Expr::Column(r)) => 
match op {
+                                        Operator::Eq => Ok((l.clone(), 
r.clone(), false)),
+                                        Operator::IsNotDistinctFrom => {
+                                            Ok((l.clone(), r.clone(), true))
+                                        }
+                                        _ => Err(DataFusionError::Plan(
+                                            "invalid join condition 
op".to_string(),
+                                        )),
+                                    },
+                                    _ => Err(DataFusionError::Plan(
+                                        "invalid join condition 
expression".to_string(),
+                                    )),
                                 }
-                                _ => Err(DataFusionError::Plan(
-                                    "invalid join condition op".to_string(),
-                                )),
-                            },
+                            }
                             _ => Err(DataFusionError::Plan(
-                                "invalid join condition 
expression".to_string(),
+                                "Non-binary expression is not supported in 
join condition"
+                                    .to_string(),
                             )),
                         }
-                    }
-                    _ => Err(DataFusionError::Plan(
-                        "Non-binary expression is not supported in join 
condition"
-                            .to_string(),
+                        })
+                        .collect::<Result<Vec<_>>>()?;
+                    let (left_cols, right_cols, null_eq_nulls): (Vec<_>, 
Vec<_>, Vec<_>) =
+                        itertools::multiunzip(join_exprs);
+                    left.join_detailed(
+                        right.build()?,
+                        join_type,
+                        (left_cols, right_cols),
+                        join_filter,
+                        null_eq_nulls[0],
+                    )?
+                    .build()
+                }
+                None => match &join_filter {
+                    Some(_) => left
+                        .join(
+                            right.build()?,
+                            join_type,
+                            (Vec::<Column>::new(), Vec::<Column>::new()),
+                            join_filter,
+                        )?
+                        .build(),
+                    None => Err(DataFusionError::Plan(
+                        "Join without join keys require a valid 
filter".to_string(),
                     )),
-                })
-                .collect::<Result<Vec<_>>>()?;
-            let (left_cols, right_cols, null_eq_nulls): (Vec<_>, Vec<_>, 
Vec<_>) =
-                itertools::multiunzip(join_exprs);
-            left.join_detailed(
-                right.build()?,
-                join_type,
-                (left_cols, right_cols),
-                None,
-                null_eq_nulls[0],
-            )?
-            .build()
+                },
+            }
         }
         Some(RelType::Read(read)) => match &read.as_ref().read_type {
             Some(ReadType::NamedTable(nt)) => {
diff --git a/datafusion/substrait/src/logical_plan/producer.rs 
b/datafusion/substrait/src/logical_plan/producer.rs
index 6cc9eec75b..5e7ee267c4 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -17,6 +17,7 @@
 
 use std::collections::HashMap;
 use std::ops::Deref;
+use std::sync::Arc;
 
 use datafusion::{
     arrow::datatypes::{DataType, TimeUnit},
@@ -266,17 +267,25 @@ pub fn to_substrait_rel(
             let right = to_substrait_rel(join.right.as_ref(), ctx, 
extension_info)?;
             let join_type = to_substrait_jointype(join.join_type);
             // we only support basic joins so return an error for anything not 
yet supported
-            if join.filter.is_some() {
-                return Err(DataFusionError::NotImplemented("join 
filter".to_string()));
-            }
             match join.join_constraint {
                 JoinConstraint::On => {}
-                _ => {
+                JoinConstraint::Using => {
                     return Err(DataFusionError::NotImplemented(
-                        "join constraint".to_string(),
+                        "join constraint: `using`".to_string(),
                     ))
                 }
             }
+            // parse filter if exists
+            let in_join_schema = join.left.schema().join(join.right.schema())?;
+            let join_filter = match &join.filter {
+                Some(filter) => Some(Box::new(to_substrait_rex(
+                    filter,
+                    &Arc::new(in_join_schema),
+                    0,
+                    extension_info,
+                )?)),
+                None => None,
+            };
             // map the left and right columns to binary expressions in the 
form `l = r`
             // build a single expression for the ON condition, such as `l.a = 
r.a AND l.b = r.b`
             let eq_op = if join.null_equals_null {
@@ -285,20 +294,23 @@ pub fn to_substrait_rel(
                 Operator::Eq
             };
 
+            let join_expr = to_substrait_join_expr(
+                &join.on,
+                eq_op,
+                join.left.schema(),
+                join.right.schema(),
+                extension_info,
+            )?
+            .map(Box::new);
+
             Ok(Box::new(Rel {
                 rel_type: Some(RelType::Join(Box::new(JoinRel {
                     common: None,
                     left: Some(left),
                     right: Some(right),
                     r#type: join_type as i32,
-                    expression: Some(Box::new(to_substrait_join_expr(
-                        &join.on,
-                        eq_op,
-                        join.left.schema(),
-                        join.right.schema(),
-                        extension_info,
-                    )?)),
-                    post_join_filter: None,
+                    expression: join_expr,
+                    post_join_filter: join_filter,
                     advanced_extension: None,
                 }))),
             }))
@@ -395,7 +407,7 @@ fn to_substrait_join_expr(
         Vec<extensions::SimpleExtensionDeclaration>,
         HashMap<String, u32>,
     ),
-) -> Result<Expression> {
+) -> Result<Option<Expression>> {
     // Only support AND conjunction for each binary expression in join 
conditions
     let mut exprs: Vec<Expression> = vec![];
     for (left, right) in join_conditions {
@@ -411,12 +423,10 @@ fn to_substrait_join_expr(
         // AND with existing expression
         exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extension_info));
     }
-    let join_expr: Expression = exprs
-        .into_iter()
-        .reduce(|acc: Expression, e: Expression| {
+    let join_expr: Option<Expression> =
+        exprs.into_iter().reduce(|acc: Expression, e: Expression| {
             make_binary_op_scalar_func(&acc, &e, Operator::And, extension_info)
-        })
-        .unwrap();
+        });
     Ok(join_expr)
 }
 
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs 
b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 26cf845232..1d1efb2e8d 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -342,11 +342,11 @@ async fn roundtrip_inlist_2() -> Result<()> {
 // Test with length > 
datafusion_optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST
 async fn roundtrip_inlist_3() -> Result<()> {
     let inlist = (0..THRESHOLD_INLINE_INLIST + 1)
-        .map(|i| format!("'{}'", i))
+        .map(|i| format!("'{i}'"))
         .collect::<Vec<_>>()
         .join(", ");
 
-    roundtrip(&format!("SELECT * FROM data WHERE f IN ({})", inlist)).await
+    roundtrip(&format!("SELECT * FROM data WHERE f IN ({inlist})")).await
 }
 
 #[tokio::test]
@@ -359,6 +359,30 @@ async fn roundtrip_inner_join() -> Result<()> {
     roundtrip("SELECT data.a FROM data JOIN data2 ON data.a = data2.a").await
 }
 
+#[tokio::test]
+async fn roundtrip_non_equi_inner_join() -> Result<()> {
+    roundtrip("SELECT data.a FROM data JOIN data2 ON data.a <> data2.a").await
+}
+
+#[tokio::test]
+async fn roundtrip_non_equi_join() -> Result<()> {
+    roundtrip(
+        "SELECT data.a FROM data, data2 WHERE data.a = data2.a AND data.e > 
data2.a",
+    )
+    .await
+}
+
+#[tokio::test]
+async fn roundtrip_exists_filter() -> Result<()> {
+    assert_expected_plan(
+        "SELECT b FROM data d1 WHERE EXISTS (SELECT * FROM data2 d2 WHERE d2.a 
= d1.a AND d2.e != d1.e)",
+        "Projection: data.b\
+        \n  LeftSemi Join: data.a = data2.a Filter: data2.e != CAST(data.e AS 
Int64)\
+        \n    TableScan: data projection=[a, b, e]\
+        \n    TableScan: data2 projection=[a, e]"
+    ).await
+}
+
 #[tokio::test]
 async fn inner_join() -> Result<()> {
     assert_expected_plan(
diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs 
b/datafusion/substrait/tests/roundtrip_logical_plan.rs
index 5054042501..e49be18a52 100644
--- a/datafusion/substrait/tests/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs
@@ -349,11 +349,11 @@ mod tests {
     // Test with length > 
datafusion_optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST
     async fn roundtrip_inlist_3() -> Result<()> {
         let inlist = (0..THRESHOLD_INLINE_INLIST + 1)
-            .map(|i| format!("'{}'", i))
+            .map(|i| format!("'{i}'"))
             .collect::<Vec<_>>()
             .join(", ");
 
-        roundtrip(&format!("SELECT * FROM data WHERE f IN ({})", inlist)).await
+        roundtrip(&format!("SELECT * FROM data WHERE f IN ({inlist})")).await
     }
 
     #[tokio::test]
@@ -366,6 +366,30 @@ mod tests {
         roundtrip("SELECT data.a FROM data JOIN data2 ON data.a = 
data2.a").await
     }
 
+    #[tokio::test]
+    async fn roundtrip_non_equi_inner_join() -> Result<()> {
+        roundtrip("SELECT data.a FROM data JOIN data2 ON data.a <> 
data2.a").await
+    }
+
+    #[tokio::test]
+    async fn roundtrip_non_equi_join() -> Result<()> {
+        roundtrip(
+            "SELECT data.a FROM data, data2 WHERE data.a = data2.a AND data.e 
> data2.a",
+        )
+        .await
+    }
+
+    #[tokio::test]
+    async fn roundtrip_exists_filter() -> Result<()> {
+        assert_expected_plan(
+            "SELECT b FROM data d1 WHERE EXISTS (SELECT * FROM data2 d2 WHERE 
d2.a = d1.a AND d2.e != d1.e)",
+            "Projection: data.b\
+            \n  LeftSemi Join: data.a = data2.a Filter: data2.e != CAST(data.e 
AS Int64)\
+            \n    TableScan: data projection=[a, b, e]\
+            \n    TableScan: data2 projection=[a, e]"
+        ).await
+    }
+
     #[tokio::test]
     async fn inner_join() -> Result<()> {
         assert_expected_plan(

Reply via email to