This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new ce3a0bb063 feat: implement substrait join filter support (#6868)
ce3a0bb063 is described below
commit ce3a0bb0630eab47a3b66448d4fa6f1fa8314b8d
Author: Nuttiiya Seekhao <[email protected]>
AuthorDate: Sun Jul 9 10:00:05 2023 -0400
feat: implement substrait join filter support (#6868)
* Added join filter support
* clippy fix
* Added Arc import
* cargo fmt
---
datafusion/substrait/src/logical_plan/consumer.rs | 117 ++++++++++++---------
datafusion/substrait/src/logical_plan/producer.rs | 48 +++++----
.../tests/cases/roundtrip_logical_plan.rs | 28 ++++-
.../substrait/tests/roundtrip_logical_plan.rs | 28 ++++-
4 files changed, 146 insertions(+), 75 deletions(-)
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs
b/datafusion/substrait/src/logical_plan/consumer.rs
index ffa935be1f..dc06b64a9e 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -22,8 +22,8 @@ use datafusion::logical_expr::{
aggregate_function, window_function::find_df_window_func, BinaryExpr,
BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator,
};
-use datafusion::logical_expr::{build_join_schema, Extension,
LogicalPlanBuilder};
use datafusion::logical_expr::{expr, Cast, WindowFrameBound, WindowFrameUnits};
+use datafusion::logical_expr::{Extension, LogicalPlanBuilder};
use datafusion::prelude::JoinType;
use datafusion::sql::TableReference;
use datafusion::{
@@ -344,63 +344,76 @@ pub async fn from_substrait_rel(
let join_type = from_substrait_jointype(join.r#type)?;
// The join condition expression needs full input schema and not
the output schema from join since we lose columns from
// certain join types such as semi and anti joins
- // - if left and right schemas are different, we combine (join)
the schema to include all fields
- // - if left and right schemas are the same, we handle the
duplicate fields by using `build_join_schema()`, which discard the unused schema
- // TODO: Handle duplicate fields error for other join types
(non-semi/anti). The current approach does not work due to Substrait's inability
- // to encode aliases
- let join_schema = match left.schema().join(right.schema()) {
- Ok(schema) => Ok(schema),
- Err(DataFusionError::SchemaError(
- datafusion::common::SchemaError::DuplicateQualifiedField {
- qualifier: _,
- name: _,
- },
- )) => build_join_schema(left.schema(), right.schema(),
&join_type),
- Err(e) => Err(e),
+ let in_join_schema = left.schema().join(right.schema())?;
+ // Parse the post-join filter, if one exists, against the combined left + right input schema
+ let join_filter = match &join.post_join_filter {
+ Some(filter) => {
+ let parsed_filter =
+ from_substrait_rex(filter, &in_join_schema,
extensions).await?;
+ Some(parsed_filter.as_ref().clone())
+ }
+ None => None,
};
- let on = from_substrait_rex(
- join.expression.as_ref().unwrap(),
- &join_schema?,
- extensions,
- )
- .await?;
- let predicates = split_conjunction(&on);
- // TODO: collect only one null_eq_null
- let join_exprs: Vec<(Column, Column, bool)> = predicates
- .iter()
- .map(|p| match p {
- Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
- match (left.as_ref(), right.as_ref()) {
- (Expr::Column(l), Expr::Column(r)) => match op {
- Operator::Eq => Ok((l.clone(), r.clone(),
false)),
- Operator::IsNotDistinctFrom => {
- Ok((l.clone(), r.clone(), true))
+ // If a join expression exists, parse the `on` condition, then build and return the join.
+ // Otherwise, build the join using only the join filter, without join keys.
+ match &join.expression.as_ref() {
+ Some(expr) => {
+ let on =
+ from_substrait_rex(expr, &in_join_schema,
extensions).await?;
+ let predicates = split_conjunction(&on);
+ // TODO: collect only one null_eq_null
+ let join_exprs: Vec<(Column, Column, bool)> = predicates
+ .iter()
+ .map(|p| {
+ match p {
+ Expr::BinaryExpr(BinaryExpr { left, op, right })
=> {
+ match (left.as_ref(), right.as_ref()) {
+ (Expr::Column(l), Expr::Column(r)) =>
match op {
+ Operator::Eq => Ok((l.clone(),
r.clone(), false)),
+ Operator::IsNotDistinctFrom => {
+ Ok((l.clone(), r.clone(), true))
+ }
+ _ => Err(DataFusionError::Plan(
+ "invalid join condition
op".to_string(),
+ )),
+ },
+ _ => Err(DataFusionError::Plan(
+ "invalid join condition
expression".to_string(),
+ )),
}
- _ => Err(DataFusionError::Plan(
- "invalid join condition op".to_string(),
- )),
- },
+ }
_ => Err(DataFusionError::Plan(
- "invalid join condition
expression".to_string(),
+ "Non-binary expression is not supported in
join condition"
+ .to_string(),
)),
}
- }
- _ => Err(DataFusionError::Plan(
- "Non-binary expression is not supported in join
condition"
- .to_string(),
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let (left_cols, right_cols, null_eq_nulls): (Vec<_>,
Vec<_>, Vec<_>) =
+ itertools::multiunzip(join_exprs);
+ left.join_detailed(
+ right.build()?,
+ join_type,
+ (left_cols, right_cols),
+ join_filter,
+ null_eq_nulls[0],
+ )?
+ .build()
+ }
+ None => match &join_filter {
+ Some(_) => left
+ .join(
+ right.build()?,
+ join_type,
+ (Vec::<Column>::new(), Vec::<Column>::new()),
+ join_filter,
+ )?
+ .build(),
+ None => Err(DataFusionError::Plan(
+ "Join without join keys require a valid
filter".to_string(),
)),
- })
- .collect::<Result<Vec<_>>>()?;
- let (left_cols, right_cols, null_eq_nulls): (Vec<_>, Vec<_>,
Vec<_>) =
- itertools::multiunzip(join_exprs);
- left.join_detailed(
- right.build()?,
- join_type,
- (left_cols, right_cols),
- None,
- null_eq_nulls[0],
- )?
- .build()
+ },
+ }
}
Some(RelType::Read(read)) => match &read.as_ref().read_type {
Some(ReadType::NamedTable(nt)) => {
diff --git a/datafusion/substrait/src/logical_plan/producer.rs
b/datafusion/substrait/src/logical_plan/producer.rs
index 6cc9eec75b..5e7ee267c4 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -17,6 +17,7 @@
use std::collections::HashMap;
use std::ops::Deref;
+use std::sync::Arc;
use datafusion::{
arrow::datatypes::{DataType, TimeUnit},
@@ -266,17 +267,25 @@ pub fn to_substrait_rel(
let right = to_substrait_rel(join.right.as_ref(), ctx,
extension_info)?;
let join_type = to_substrait_jointype(join.join_type);
// we only support basic joins so return an error for anything not
yet supported
- if join.filter.is_some() {
- return Err(DataFusionError::NotImplemented("join
filter".to_string()));
- }
match join.join_constraint {
JoinConstraint::On => {}
- _ => {
+ JoinConstraint::Using => {
return Err(DataFusionError::NotImplemented(
- "join constraint".to_string(),
+ "join constraint: `using`".to_string(),
))
}
}
+ // Convert the join filter, if present, into a Substrait post-join filter expression
+ let in_join_schema = join.left.schema().join(join.right.schema())?;
+ let join_filter = match &join.filter {
+ Some(filter) => Some(Box::new(to_substrait_rex(
+ filter,
+ &Arc::new(in_join_schema),
+ 0,
+ extension_info,
+ )?)),
+ None => None,
+ };
// map the left and right columns to binary expressions in the
form `l = r`
// build a single expression for the ON condition, such as `l.a =
r.a AND l.b = r.b`
let eq_op = if join.null_equals_null {
@@ -285,20 +294,23 @@ pub fn to_substrait_rel(
Operator::Eq
};
+ let join_expr = to_substrait_join_expr(
+ &join.on,
+ eq_op,
+ join.left.schema(),
+ join.right.schema(),
+ extension_info,
+ )?
+ .map(Box::new);
+
Ok(Box::new(Rel {
rel_type: Some(RelType::Join(Box::new(JoinRel {
common: None,
left: Some(left),
right: Some(right),
r#type: join_type as i32,
- expression: Some(Box::new(to_substrait_join_expr(
- &join.on,
- eq_op,
- join.left.schema(),
- join.right.schema(),
- extension_info,
- )?)),
- post_join_filter: None,
+ expression: join_expr,
+ post_join_filter: join_filter,
advanced_extension: None,
}))),
}))
@@ -395,7 +407,7 @@ fn to_substrait_join_expr(
Vec<extensions::SimpleExtensionDeclaration>,
HashMap<String, u32>,
),
-) -> Result<Expression> {
+) -> Result<Option<Expression>> {
// Only support AND conjunction for each binary expression in join
conditions
let mut exprs: Vec<Expression> = vec![];
for (left, right) in join_conditions {
@@ -411,12 +423,10 @@ fn to_substrait_join_expr(
// AND with existing expression
exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extension_info));
}
- let join_expr: Expression = exprs
- .into_iter()
- .reduce(|acc: Expression, e: Expression| {
+ let join_expr: Option<Expression> =
+ exprs.into_iter().reduce(|acc: Expression, e: Expression| {
make_binary_op_scalar_func(&acc, &e, Operator::And, extension_info)
- })
- .unwrap();
+ });
Ok(join_expr)
}
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 26cf845232..1d1efb2e8d 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -342,11 +342,11 @@ async fn roundtrip_inlist_2() -> Result<()> {
// Test with length >
datafusion_optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST
async fn roundtrip_inlist_3() -> Result<()> {
let inlist = (0..THRESHOLD_INLINE_INLIST + 1)
- .map(|i| format!("'{}'", i))
+ .map(|i| format!("'{i}'"))
.collect::<Vec<_>>()
.join(", ");
- roundtrip(&format!("SELECT * FROM data WHERE f IN ({})", inlist)).await
+ roundtrip(&format!("SELECT * FROM data WHERE f IN ({inlist})")).await
}
#[tokio::test]
@@ -359,6 +359,30 @@ async fn roundtrip_inner_join() -> Result<()> {
roundtrip("SELECT data.a FROM data JOIN data2 ON data.a = data2.a").await
}
+#[tokio::test]
+async fn roundtrip_non_equi_inner_join() -> Result<()> {
+ roundtrip("SELECT data.a FROM data JOIN data2 ON data.a <> data2.a").await
+}
+
+#[tokio::test]
+async fn roundtrip_non_equi_join() -> Result<()> {
+ roundtrip(
+ "SELECT data.a FROM data, data2 WHERE data.a = data2.a AND data.e >
data2.a",
+ )
+ .await
+}
+
+#[tokio::test]
+async fn roundtrip_exists_filter() -> Result<()> {
+ assert_expected_plan(
+ "SELECT b FROM data d1 WHERE EXISTS (SELECT * FROM data2 d2 WHERE d2.a
= d1.a AND d2.e != d1.e)",
+ "Projection: data.b\
+ \n LeftSemi Join: data.a = data2.a Filter: data2.e != CAST(data.e AS
Int64)\
+ \n TableScan: data projection=[a, b, e]\
+ \n TableScan: data2 projection=[a, e]"
+ ).await
+}
+
#[tokio::test]
async fn inner_join() -> Result<()> {
assert_expected_plan(
diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs
b/datafusion/substrait/tests/roundtrip_logical_plan.rs
index 5054042501..e49be18a52 100644
--- a/datafusion/substrait/tests/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs
@@ -349,11 +349,11 @@ mod tests {
// Test with length >
datafusion_optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST
async fn roundtrip_inlist_3() -> Result<()> {
let inlist = (0..THRESHOLD_INLINE_INLIST + 1)
- .map(|i| format!("'{}'", i))
+ .map(|i| format!("'{i}'"))
.collect::<Vec<_>>()
.join(", ");
- roundtrip(&format!("SELECT * FROM data WHERE f IN ({})", inlist)).await
+ roundtrip(&format!("SELECT * FROM data WHERE f IN ({inlist})")).await
}
#[tokio::test]
@@ -366,6 +366,30 @@ mod tests {
roundtrip("SELECT data.a FROM data JOIN data2 ON data.a =
data2.a").await
}
+ #[tokio::test]
+ async fn roundtrip_non_equi_inner_join() -> Result<()> {
+ roundtrip("SELECT data.a FROM data JOIN data2 ON data.a <>
data2.a").await
+ }
+
+ #[tokio::test]
+ async fn roundtrip_non_equi_join() -> Result<()> {
+ roundtrip(
+ "SELECT data.a FROM data, data2 WHERE data.a = data2.a AND data.e
> data2.a",
+ )
+ .await
+ }
+
+ #[tokio::test]
+ async fn roundtrip_exists_filter() -> Result<()> {
+ assert_expected_plan(
+ "SELECT b FROM data d1 WHERE EXISTS (SELECT * FROM data2 d2 WHERE
d2.a = d1.a AND d2.e != d1.e)",
+ "Projection: data.b\
+ \n LeftSemi Join: data.a = data2.a Filter: data2.e != CAST(data.e
AS Int64)\
+ \n TableScan: data projection=[a, b, e]\
+ \n TableScan: data2 projection=[a, e]"
+ ).await
+ }
+
#[tokio::test]
async fn inner_join() -> Result<()> {
assert_expected_plan(