Blizzara commented on code in PR #13931: URL: https://github.com/apache/datafusion/pull/13931#discussion_r1898902943
########## datafusion/substrait/src/logical_plan/producer.rs: ########## @@ -185,257 +501,290 @@ pub fn to_substrait_extended_expr( })) } -/// Convert DataFusion LogicalPlan to Substrait Rel -#[allow(deprecated)] pub fn to_substrait_rel( + producer: &mut impl SubstraitProducer, plan: &LogicalPlan, - state: &dyn SubstraitPlanningState, - extensions: &mut Extensions, ) -> Result<Box<Rel>> { match plan { - LogicalPlan::TableScan(scan) => { - let projection = scan.projection.as_ref().map(|p| { - p.iter() - .map(|i| StructItem { - field: *i as i32, - child: None, - }) - .collect() - }); + LogicalPlan::Projection(plan) => producer.consume_projection(plan), + LogicalPlan::Filter(plan) => producer.consume_filter(plan), + LogicalPlan::Window(plan) => producer.consume_window(plan), + LogicalPlan::Aggregate(plan) => producer.consume_aggregate(plan), + LogicalPlan::Sort(plan) => producer.consume_sort(plan), + LogicalPlan::Join(plan) => producer.consume_join(plan), + LogicalPlan::Repartition(plan) => producer.consume_repartition(plan), + LogicalPlan::Union(plan) => producer.consume_union(plan), + LogicalPlan::TableScan(plan) => producer.consume_table_scan(plan), + LogicalPlan::EmptyRelation(plan) => producer.consume_empty_relation(plan), + LogicalPlan::SubqueryAlias(plan) => producer.consume_subquery_alias(plan), + LogicalPlan::Limit(plan) => producer.consume_limit(plan), + LogicalPlan::Values(plan) => producer.consume_values(plan), + LogicalPlan::Distinct(plan) => producer.consume_distinct(plan), + LogicalPlan::Extension(plan) => producer.consume_extension(plan), + _ => not_impl_err!("Unsupported plan type: {plan:?}")?, + } +} - let projection = projection.map(|struct_items| MaskExpression { - select: Some(StructSelect { struct_items }), - maintain_singular_struct: false, - }); +pub fn from_table_scan( + _producer: &mut impl SubstraitProducer, + scan: &TableScan, +) -> Result<Box<Rel>> { + let projection = scan.projection.as_ref().map(|p| { + p.iter() + .map(|i| StructItem { + 
field: *i as i32, + child: None, + }) + .collect() + }); + + let projection = projection.map(|struct_items| MaskExpression { + select: Some(StructSelect { struct_items }), + maintain_singular_struct: false, + }); + + let table_schema = scan.source.schema().to_dfschema_ref()?; + let base_schema = to_substrait_named_struct(&table_schema)?; + + Ok(Box::new(Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(base_schema), + filter: None, + best_effort_filter: None, + projection, + advanced_extension: None, + read_type: Some(ReadType::NamedTable(NamedTable { + names: scan.table_name.to_vec(), + advanced_extension: None, + })), + }))), + })) +} - let table_schema = scan.source.schema().to_dfschema_ref()?; - let base_schema = to_substrait_named_struct(&table_schema)?; +pub fn from_empty_relation(e: &EmptyRelation) -> Result<Box<Rel>> { + if e.produce_one_row { + return not_impl_err!("Producing a row from empty relation is unsupported"); + } + #[allow(deprecated)] Review Comment: I think previously it was allowed on even higher level so this is fine, but ooc, what's deprecated in all these? ########## datafusion/substrait/src/logical_plan/producer.rs: ########## @@ -998,450 +1304,418 @@ pub fn make_binary_op_scalar_func( /// Convert DataFusion Expr to Substrait Rex /// /// # Arguments -/// -/// * `expr` - DataFusion expression to be parse into a Substrait expression -/// * `schema` - DataFusion input schema for looking up field qualifiers -/// * `col_ref_offset` - Offset for calculating Substrait field reference indices. -/// This should only be set by caller with more than one input relations i.e. Join. -/// Substrait expects one set of indices when joining two relations. -/// Let's say `left` and `right` have `m` and `n` columns, respectively. The `right` -/// relation will have column indices from `0` to `n-1`, however, Substrait will expect -/// the `right` indices to be offset by the `left`. 
This means Substrait will expect to -/// evaluate the join condition expression on indices [0 .. n-1, n .. n+m-1]. For example: -/// ```SELECT * -/// FROM t1 -/// JOIN t2 -/// ON t1.c1 = t2.c0;``` -/// where t1 consists of columns [c0, c1, c2], and t2 = columns [c0, c1] -/// the join condition should become -/// `col_ref(1) = col_ref(3 + 0)` -/// , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index -/// of the join key column from `right` -/// * `extensions` - Substrait extension info. Contains registered function information -#[allow(deprecated)] +/// * `producer` - SubstraitProducer implementation which the handles the actual conversion +/// * `expr` - DataFusion expression to convert into a Substrait expression +/// * `schema` - DataFusion input schema for looking up columns pub fn to_substrait_rex( - state: &dyn SubstraitPlanningState, + producer: &mut impl SubstraitProducer, expr: &Expr, schema: &DFSchemaRef, - col_ref_offset: usize, - extensions: &mut Extensions, ) -> Result<Expression> { match expr { - Expr::InList(InList { - expr, - list, - negated, - }) => { - let substrait_list = list - .iter() - .map(|x| to_substrait_rex(state, x, schema, col_ref_offset, extensions)) - .collect::<Result<Vec<Expression>>>()?; - let substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - - let substrait_or_list = Expression { - rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { - value: Some(Box::new(substrait_expr)), - options: substrait_list, - }))), - }; - - if *negated { - let function_anchor = extensions.register_function("not".to_string()); - - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_or_list)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) - } else { - Ok(substrait_or_list) - } + Expr::Alias(expr) => 
producer.consume_alias(expr, schema), + Expr::Column(expr) => producer.consume_column(expr, schema), + Expr::Literal(expr) => producer.consume_literal(expr), + Expr::BinaryExpr(expr) => producer.consume_binary_expr(expr, schema), + Expr::Like(expr) => producer.consume_like(expr, schema), + Expr::SimilarTo(_) => not_impl_err!("SimilarTo is not supported"), + Expr::Not(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotNull(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNull(_) => producer.consume_unary_expr(expr, schema), + Expr::IsTrue(_) => producer.consume_unary_expr(expr, schema), + Expr::IsFalse(_) => producer.consume_unary_expr(expr, schema), + Expr::IsUnknown(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotTrue(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotFalse(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotUnknown(_) => producer.consume_unary_expr(expr, schema), + Expr::Negative(_) => producer.consume_unary_expr(expr, schema), + Expr::Between(expr) => producer.consume_between(expr, schema), + Expr::Case(expr) => producer.consume_case(expr, schema), + Expr::Cast(expr) => producer.consume_cast(expr, schema), + Expr::TryCast(expr) => producer.consume_try_cast(expr, schema), + Expr::ScalarFunction(expr) => producer.consume_scalar_function(expr, schema), + Expr::AggregateFunction(_) => { + internal_err!( + "AggregateFunction should only be encountered as part of a LogicalPlan::Aggregate" + ) } - Expr::ScalarFunction(fun) => { - let mut arguments: Vec<FunctionArgument> = vec![]; - for arg in &fun.args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex( - state, - arg, - schema, - col_ref_offset, - extensions, - )?)), - }); - } + Expr::WindowFunction(expr) => producer.consume_window_function(expr, schema), + Expr::InList(expr) => producer.consume_in_list(expr, schema), + Expr::InSubquery(expr) => producer.consume_in_subquery(expr, schema), + _ => 
not_impl_err!("Cannot convert {expr:?} to Substrait"), Review Comment: same here - is there a lot of stuff behind this catch-all? or would it be easy enough to just list them all for clarity? ########## datafusion/substrait/src/logical_plan/producer.rs: ########## @@ -101,14 +105,330 @@ use substrait::{ version, }; -use super::state::SubstraitPlanningState; +/// This trait is used to produce Substrait plans, converting them from DataFusion Logical Plans. +/// It can be implemented by users to allow for custom handling of relations, expressions, etc. +/// +/// Combined with the [crate::logical_plan::consumer::SubstraitConsumer] this allows for fully +/// customizable Substrait serde. +/// +/// # Example Usage +/// +/// ``` +/// # use std::sync::Arc; +/// # use substrait::proto::{Expression, Rel}; +/// # use substrait::proto::rel::RelType; +/// # use datafusion::common::DFSchemaRef; +/// # use datafusion::error::Result; +/// # use datafusion::execution::SessionState; +/// # use datafusion::logical_expr::{Between, Extension, Projection}; +/// # use datafusion_substrait::extensions::Extensions; +/// # use datafusion_substrait::logical_plan::producer::{from_projection, SubstraitProducer}; +/// +/// struct CustomSubstraitProducer { +/// extensions: Extensions, +/// state: Arc<SessionState>, +/// } +/// +/// impl SubstraitProducer for CustomSubstraitProducer { +/// +/// fn register_function(&mut self, signature: String) -> u32 { +/// self.extensions.register_function(signature) +/// } +/// +/// fn get_extensions(self) -> Extensions { +/// self.extensions +/// } +/// +/// // You can set additional metadata on the Rels you produce +/// fn consume_projection(&mut self, plan: &Projection) -> Result<Box<Rel>> { +/// let mut rel = from_projection(self, plan)?; +/// match rel.rel_type { +/// Some(RelType::Project(mut project)) => { +/// let mut project = project.clone(); +/// // set common metadata or advanced extension +/// project.common = None; +/// project.advanced_extension 
= None; +/// Ok(Box::new(Rel { +/// rel_type: Some(RelType::Project(project)), +/// })) +/// } +/// rel_type => Ok(Box::new(Rel { rel_type })), +/// } +/// } +/// +/// // You can tweak how you convert expressions for your target system +/// fn consume_between(&mut self, between: &Between, schema: &DFSchemaRef) -> Result<Expression> { +/// // add your own encoding for Between +/// todo!() +/// } +/// +/// // You can fully control how you convert UserDefinedLogicalNodes into Substrait +/// fn consume_extension(&mut self, _plan: &Extension) -> Result<Box<Rel>> { +/// // implement your own serializer into Substrait +/// todo!() +/// } +/// } +/// ``` +pub trait SubstraitProducer: Send + Sync + Sized { + /// Within a Substrait plan, functions are referenced using function anchors that are stored at + /// the top level of the [Plan] within + /// [ExtensionFunction](substrait::proto::extensions::simple_extension_declaration::ExtensionFunction) + /// messages. + /// + /// When given a function signature, this method should return the existing anchor for it if + /// there is one. Otherwise, it should generate a new anchor. + fn register_function(&mut self, signature: String) -> u32; + + /// Consume the producer to generate the [Extensions] for the Substrait plan based on the + /// functions that have been registered + fn get_extensions(self) -> Extensions; + + // Logical Plan Methods + // There is one method per LogicalPlan to allow for easy overriding of producer behaviour. + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. + + fn consume_plan(&mut self, plan: &LogicalPlan) -> Result<Box<Rel>> { + to_substrait_rel(self, plan) + } Review Comment: Yeah, I agree that "produce" doesn't make sense here, as it's more logical to think of the functions in terms of processing DF concepts rather than in producing Substrait things. 
However, the "consume" in producer can be a bit confusing w.r.t "consumer" - would it make sense to use some alternative, like "from" (which is already used for the functions) or "handle", "process", or something? ########## datafusion/substrait/src/logical_plan/producer.rs: ########## @@ -730,32 +1035,23 @@ fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result<NamedStruct> { } fn to_substrait_join_expr( - state: &dyn SubstraitPlanningState, + producer: &mut impl SubstraitProducer, join_conditions: &Vec<(Expr, Expr)>, eq_op: Operator, - left_schema: &DFSchemaRef, - right_schema: &DFSchemaRef, - extensions: &mut Extensions, + join_schema: &DFSchemaRef, ) -> Result<Option<Expression>> { // Only support AND conjunction for each binary expression in join conditions let mut exprs: Vec<Expression> = vec![]; for (left, right) in join_conditions { - // Parse left - let l = to_substrait_rex(state, left, left_schema, 0, extensions)?; - // Parse right - let r = to_substrait_rex( - state, - right, - right_schema, - left_schema.fields().len(), // offset to return the correct index - extensions, - )?; + let l = producer.consume_expr(left, join_schema)?; + let r = producer.consume_expr(right, join_schema)?; Review Comment: I'm not sure I follow how the requalify_sides_if_needed (added by me in https://github.com/apache/datafusion/pull/11049, just for reference) affects the need for this handling, given it's on the consumer side and this is on the producer. https://github.com/apache/datafusion/pull/6135/files#r1215611954 seems to indicate the re-added test doesn't catch the issue. Does this change affect the produced substrait plan? ########## datafusion/substrait/src/logical_plan/producer.rs: ########## @@ -101,14 +105,330 @@ use substrait::{ version, }; -use super::state::SubstraitPlanningState; +/// This trait is used to produce Substrait plans, converting them from DataFusion Logical Plans. 
+/// It can be implemented by users to allow for custom handling of relations, expressions, etc. +/// +/// Combined with the [crate::logical_plan::consumer::SubstraitConsumer] this allows for fully +/// customizable Substrait serde. +/// +/// # Example Usage +/// +/// ``` +/// # use std::sync::Arc; +/// # use substrait::proto::{Expression, Rel}; +/// # use substrait::proto::rel::RelType; +/// # use datafusion::common::DFSchemaRef; +/// # use datafusion::error::Result; +/// # use datafusion::execution::SessionState; +/// # use datafusion::logical_expr::{Between, Extension, Projection}; +/// # use datafusion_substrait::extensions::Extensions; +/// # use datafusion_substrait::logical_plan::producer::{from_projection, SubstraitProducer}; +/// +/// struct CustomSubstraitProducer { +/// extensions: Extensions, +/// state: Arc<SessionState>, +/// } +/// +/// impl SubstraitProducer for CustomSubstraitProducer { +/// +/// fn register_function(&mut self, signature: String) -> u32 { +/// self.extensions.register_function(signature) +/// } +/// +/// fn get_extensions(self) -> Extensions { +/// self.extensions +/// } +/// +/// // You can set additional metadata on the Rels you produce +/// fn consume_projection(&mut self, plan: &Projection) -> Result<Box<Rel>> { +/// let mut rel = from_projection(self, plan)?; +/// match rel.rel_type { +/// Some(RelType::Project(mut project)) => { +/// let mut project = project.clone(); +/// // set common metadata or advanced extension +/// project.common = None; +/// project.advanced_extension = None; +/// Ok(Box::new(Rel { +/// rel_type: Some(RelType::Project(project)), +/// })) +/// } +/// rel_type => Ok(Box::new(Rel { rel_type })), +/// } +/// } +/// +/// // You can tweak how you convert expressions for your target system +/// fn consume_between(&mut self, between: &Between, schema: &DFSchemaRef) -> Result<Expression> { +/// // add your own encoding for Between +/// todo!() +/// } +/// +/// // You can fully control how you convert 
UserDefinedLogicalNodes into Substrait +/// fn consume_extension(&mut self, _plan: &Extension) -> Result<Box<Rel>> { +/// // implement your own serializer into Substrait +/// todo!() +/// } +/// } +/// ``` +pub trait SubstraitProducer: Send + Sync + Sized { + /// Within a Substrait plan, functions are referenced using function anchors that are stored at + /// the top level of the [Plan] within + /// [ExtensionFunction](substrait::proto::extensions::simple_extension_declaration::ExtensionFunction) + /// messages. + /// + /// When given a function signature, this method should return the existing anchor for it if + /// there is one. Otherwise, it should generate a new anchor. + fn register_function(&mut self, signature: String) -> u32; + + /// Consume the producer to generate the [Extensions] for the Substrait plan based on the + /// functions that have been registered + fn get_extensions(self) -> Extensions; + + // Logical Plan Methods + // There is one method per LogicalPlan to allow for easy overriding of producer behaviour. + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. 
+ + fn consume_plan(&mut self, plan: &LogicalPlan) -> Result<Box<Rel>> { + to_substrait_rel(self, plan) + } + + fn consume_projection(&mut self, plan: &Projection) -> Result<Box<Rel>> { + from_projection(self, plan) + } + + fn consume_filter(&mut self, plan: &Filter) -> Result<Box<Rel>> { + from_filter(self, plan) + } + + fn consume_window(&mut self, plan: &Window) -> Result<Box<Rel>> { + from_window(self, plan) + } + + fn consume_aggregate(&mut self, plan: &Aggregate) -> Result<Box<Rel>> { + from_aggregate(self, plan) + } + + fn consume_sort(&mut self, plan: &Sort) -> Result<Box<Rel>> { + from_sort(self, plan) + } + + fn consume_join(&mut self, plan: &Join) -> Result<Box<Rel>> { + from_join(self, plan) + } + + fn consume_repartition(&mut self, plan: &Repartition) -> Result<Box<Rel>> { + from_repartition(self, plan) + } + + fn consume_union(&mut self, plan: &Union) -> Result<Box<Rel>> { + from_union(self, plan) + } + + fn consume_table_scan(&mut self, plan: &TableScan) -> Result<Box<Rel>> { + from_table_scan(self, plan) + } + + fn consume_empty_relation(&mut self, plan: &EmptyRelation) -> Result<Box<Rel>> { + from_empty_relation(plan) + } + + fn consume_subquery_alias(&mut self, plan: &SubqueryAlias) -> Result<Box<Rel>> { + from_subquery_alias(self, plan) + } + + fn consume_limit(&mut self, plan: &Limit) -> Result<Box<Rel>> { + from_limit(self, plan) + } + + fn consume_values(&mut self, plan: &Values) -> Result<Box<Rel>> { + from_values(self, plan) + } + + fn consume_distinct(&mut self, plan: &Distinct) -> Result<Box<Rel>> { + from_distinct(self, plan) + } + + fn consume_extension(&mut self, _plan: &Extension) -> Result<Box<Rel>> { + substrait_err!("Specify handling for LogicalPlan::Extension by implementing the SubstraitProducer trait") + } + + // Expression Methods + // There is one method per DataFusion Expr to allow for easy overriding of producer behaviour + // These methods have default implementations calling the common handler code, to allow for users + // 
to re-use common handling logic. + + fn consume_expr(&mut self, expr: &Expr, schema: &DFSchemaRef) -> Result<Expression> { + to_substrait_rex(self, expr, schema) + } + + fn consume_alias( + &mut self, + alias: &Alias, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_alias(self, alias, schema) + } + + fn consume_column( + &mut self, + column: &Column, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_column(column, schema) + } + + fn consume_literal(&mut self, value: &ScalarValue) -> Result<Expression> { + from_literal(self, value) + } + + fn consume_binary_expr( + &mut self, + expr: &BinaryExpr, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_binary_expr(self, expr, schema) + } + + fn consume_like(&mut self, like: &Like, schema: &DFSchemaRef) -> Result<Expression> { + from_like(self, like, schema) + } + + /// For handling Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown, IsNotTrue, IsNotFalse, IsNotUnknown, Negative + fn consume_unary_expr( + &mut self, + expr: &Expr, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_unary_expr(self, expr, schema) + } + + fn consume_between( + &mut self, + between: &Between, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_between(self, between, schema) + } + + fn consume_case(&mut self, case: &Case, schema: &DFSchemaRef) -> Result<Expression> { + from_case(self, case, schema) + } + + fn consume_cast(&mut self, cast: &Cast, schema: &DFSchemaRef) -> Result<Expression> { + from_cast(self, cast, schema) + } + + fn consume_try_cast( + &mut self, + cast: &TryCast, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_try_cast(self, cast, schema) + } + + fn consume_scalar_function( + &mut self, + scalar_fn: &expr::ScalarFunction, Review Comment: I guess this is to de-conflict with Substrait's ScalarFunction, which is imported? 
👍 ########## datafusion/substrait/src/logical_plan/producer.rs: ########## @@ -1727,29 +2001,26 @@ fn make_substrait_window_function( } } -#[allow(deprecated)] #[allow(clippy::too_many_arguments)] Review Comment: can this be removed now too, or do we still have too many args? ########## datafusion/substrait/src/logical_plan/producer.rs: ########## @@ -185,257 +501,290 @@ pub fn to_substrait_extended_expr( })) } -/// Convert DataFusion LogicalPlan to Substrait Rel -#[allow(deprecated)] pub fn to_substrait_rel( + producer: &mut impl SubstraitProducer, plan: &LogicalPlan, - state: &dyn SubstraitPlanningState, - extensions: &mut Extensions, ) -> Result<Box<Rel>> { match plan { - LogicalPlan::TableScan(scan) => { - let projection = scan.projection.as_ref().map(|p| { - p.iter() - .map(|i| StructItem { - field: *i as i32, - child: None, - }) - .collect() - }); + LogicalPlan::Projection(plan) => producer.consume_projection(plan), + LogicalPlan::Filter(plan) => producer.consume_filter(plan), + LogicalPlan::Window(plan) => producer.consume_window(plan), + LogicalPlan::Aggregate(plan) => producer.consume_aggregate(plan), + LogicalPlan::Sort(plan) => producer.consume_sort(plan), + LogicalPlan::Join(plan) => producer.consume_join(plan), + LogicalPlan::Repartition(plan) => producer.consume_repartition(plan), + LogicalPlan::Union(plan) => producer.consume_union(plan), + LogicalPlan::TableScan(plan) => producer.consume_table_scan(plan), + LogicalPlan::EmptyRelation(plan) => producer.consume_empty_relation(plan), + LogicalPlan::SubqueryAlias(plan) => producer.consume_subquery_alias(plan), + LogicalPlan::Limit(plan) => producer.consume_limit(plan), + LogicalPlan::Values(plan) => producer.consume_values(plan), + LogicalPlan::Distinct(plan) => producer.consume_distinct(plan), + LogicalPlan::Extension(plan) => producer.consume_extension(plan), + _ => not_impl_err!("Unsupported plan type: {plan:?}")?, + } +} - let projection = projection.map(|struct_items| MaskExpression { - select: 
Some(StructSelect { struct_items }), - maintain_singular_struct: false, - }); +pub fn from_table_scan( + _producer: &mut impl SubstraitProducer, Review Comment: > likely Indicates to me we shouldn't add it there yet, since there's a risk it won't be used :) And I think it'll be fine to add it later - it'll be an API break, but only for those customizing the usage, and at least it'll be a clear break. ########## datafusion/substrait/src/logical_plan/producer.rs: ########## @@ -185,257 +501,290 @@ pub fn to_substrait_extended_expr( })) } -/// Convert DataFusion LogicalPlan to Substrait Rel -#[allow(deprecated)] pub fn to_substrait_rel( + producer: &mut impl SubstraitProducer, plan: &LogicalPlan, - state: &dyn SubstraitPlanningState, - extensions: &mut Extensions, ) -> Result<Box<Rel>> { match plan { - LogicalPlan::TableScan(scan) => { - let projection = scan.projection.as_ref().map(|p| { - p.iter() - .map(|i| StructItem { - field: *i as i32, - child: None, - }) - .collect() - }); + LogicalPlan::Projection(plan) => producer.consume_projection(plan), + LogicalPlan::Filter(plan) => producer.consume_filter(plan), + LogicalPlan::Window(plan) => producer.consume_window(plan), + LogicalPlan::Aggregate(plan) => producer.consume_aggregate(plan), + LogicalPlan::Sort(plan) => producer.consume_sort(plan), + LogicalPlan::Join(plan) => producer.consume_join(plan), + LogicalPlan::Repartition(plan) => producer.consume_repartition(plan), + LogicalPlan::Union(plan) => producer.consume_union(plan), + LogicalPlan::TableScan(plan) => producer.consume_table_scan(plan), + LogicalPlan::EmptyRelation(plan) => producer.consume_empty_relation(plan), + LogicalPlan::SubqueryAlias(plan) => producer.consume_subquery_alias(plan), + LogicalPlan::Limit(plan) => producer.consume_limit(plan), + LogicalPlan::Values(plan) => producer.consume_values(plan), + LogicalPlan::Distinct(plan) => producer.consume_distinct(plan), + LogicalPlan::Extension(plan) => producer.consume_extension(plan), + _ => 
not_impl_err!("Unsupported plan type: {plan:?}")?, Review Comment: is there a lot of options behind this `_`? might be nice to explicitly list them out, to make it clear what isn't supported yet ########## datafusion/substrait/tests/cases/roundtrip_logical_plan.rs: ########## @@ -571,6 +571,21 @@ async fn roundtrip_self_implicit_cross_join() -> Result<()> { roundtrip("SELECT left.a left_a, left.b, right.a right_a, right.c FROM data AS left, data AS right").await } +#[tokio::test] +async fn self_join_introduces_aliases() -> Result<()> { Review Comment: This is adding back [this test](https://github.com/apache/datafusion/pull/11049/files#diff-26c40258720a61cb066cd14651b9b3617557cf7394838b3bad6b4e45718b3122L645), right? I seem to have argued back then that it is unnecessary given the roundtrip_self_join test. ########## datafusion/substrait/src/logical_plan/producer.rs: ########## @@ -101,14 +105,330 @@ use substrait::{ version, }; -use super::state::SubstraitPlanningState; +/// This trait is used to produce Substrait plans, converting them from DataFusion Logical Plans. +/// It can be implemented by users to allow for custom handling of relations, expressions, etc. +/// +/// Combined with the [crate::logical_plan::consumer::SubstraitConsumer] this allows for fully +/// customizable Substrait serde. 
+/// +/// # Example Usage +/// +/// ``` +/// # use std::sync::Arc; +/// # use substrait::proto::{Expression, Rel}; +/// # use substrait::proto::rel::RelType; +/// # use datafusion::common::DFSchemaRef; +/// # use datafusion::error::Result; +/// # use datafusion::execution::SessionState; +/// # use datafusion::logical_expr::{Between, Extension, Projection}; +/// # use datafusion_substrait::extensions::Extensions; +/// # use datafusion_substrait::logical_plan::producer::{from_projection, SubstraitProducer}; +/// +/// struct CustomSubstraitProducer { +/// extensions: Extensions, +/// state: Arc<SessionState>, +/// } +/// +/// impl SubstraitProducer for CustomSubstraitProducer { +/// +/// fn register_function(&mut self, signature: String) -> u32 { +/// self.extensions.register_function(signature) +/// } +/// +/// fn get_extensions(self) -> Extensions { +/// self.extensions +/// } +/// +/// // You can set additional metadata on the Rels you produce +/// fn consume_projection(&mut self, plan: &Projection) -> Result<Box<Rel>> { +/// let mut rel = from_projection(self, plan)?; +/// match rel.rel_type { +/// Some(RelType::Project(mut project)) => { +/// let mut project = project.clone(); +/// // set common metadata or advanced extension +/// project.common = None; +/// project.advanced_extension = None; +/// Ok(Box::new(Rel { +/// rel_type: Some(RelType::Project(project)), +/// })) +/// } +/// rel_type => Ok(Box::new(Rel { rel_type })), +/// } +/// } +/// +/// // You can tweak how you convert expressions for your target system +/// fn consume_between(&mut self, between: &Between, schema: &DFSchemaRef) -> Result<Expression> { +/// // add your own encoding for Between +/// todo!() +/// } +/// +/// // You can fully control how you convert UserDefinedLogicalNodes into Substrait +/// fn consume_extension(&mut self, _plan: &Extension) -> Result<Box<Rel>> { +/// // implement your own serializer into Substrait +/// todo!() +/// } +/// } +/// ``` +pub trait SubstraitProducer: Send + 
Sync + Sized { + /// Within a Substrait plan, functions are referenced using function anchors that are stored at + /// the top level of the [Plan] within + /// [ExtensionFunction](substrait::proto::extensions::simple_extension_declaration::ExtensionFunction) + /// messages. + /// + /// When given a function signature, this method should return the existing anchor for it if + /// there is one. Otherwise, it should generate a new anchor. + fn register_function(&mut self, signature: String) -> u32; + + /// Consume the producer to generate the [Extensions] for the Substrait plan based on the + /// functions that have been registered + fn get_extensions(self) -> Extensions; + + // Logical Plan Methods + // There is one method per LogicalPlan to allow for easy overriding of producer behaviour. + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. + + fn consume_plan(&mut self, plan: &LogicalPlan) -> Result<Box<Rel>> { + to_substrait_rel(self, plan) + } + + fn consume_projection(&mut self, plan: &Projection) -> Result<Box<Rel>> { + from_projection(self, plan) + } + + fn consume_filter(&mut self, plan: &Filter) -> Result<Box<Rel>> { + from_filter(self, plan) + } + + fn consume_window(&mut self, plan: &Window) -> Result<Box<Rel>> { + from_window(self, plan) + } + + fn consume_aggregate(&mut self, plan: &Aggregate) -> Result<Box<Rel>> { + from_aggregate(self, plan) + } + + fn consume_sort(&mut self, plan: &Sort) -> Result<Box<Rel>> { + from_sort(self, plan) + } + + fn consume_join(&mut self, plan: &Join) -> Result<Box<Rel>> { + from_join(self, plan) + } + + fn consume_repartition(&mut self, plan: &Repartition) -> Result<Box<Rel>> { + from_repartition(self, plan) + } + + fn consume_union(&mut self, plan: &Union) -> Result<Box<Rel>> { + from_union(self, plan) + } + + fn consume_table_scan(&mut self, plan: &TableScan) -> Result<Box<Rel>> { + from_table_scan(self, plan) + } + + fn 
consume_empty_relation(&mut self, plan: &EmptyRelation) -> Result<Box<Rel>> { + from_empty_relation(plan) + } + + fn consume_subquery_alias(&mut self, plan: &SubqueryAlias) -> Result<Box<Rel>> { + from_subquery_alias(self, plan) + } + + fn consume_limit(&mut self, plan: &Limit) -> Result<Box<Rel>> { + from_limit(self, plan) + } + + fn consume_values(&mut self, plan: &Values) -> Result<Box<Rel>> { + from_values(self, plan) + } + + fn consume_distinct(&mut self, plan: &Distinct) -> Result<Box<Rel>> { + from_distinct(self, plan) + } + + fn consume_extension(&mut self, _plan: &Extension) -> Result<Box<Rel>> { + substrait_err!("Specify handling for LogicalPlan::Extension by implementing the SubstraitProducer trait") + } + + // Expression Methods + // There is one method per DataFusion Expr to allow for easy overriding of producer behaviour + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. + + fn consume_expr(&mut self, expr: &Expr, schema: &DFSchemaRef) -> Result<Expression> { + to_substrait_rex(self, expr, schema) + } + + fn consume_alias( + &mut self, + alias: &Alias, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_alias(self, alias, schema) + } + + fn consume_column( + &mut self, + column: &Column, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_column(column, schema) + } + + fn consume_literal(&mut self, value: &ScalarValue) -> Result<Expression> { + from_literal(self, value) + } + + fn consume_binary_expr( + &mut self, + expr: &BinaryExpr, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_binary_expr(self, expr, schema) + } + + fn consume_like(&mut self, like: &Like, schema: &DFSchemaRef) -> Result<Expression> { + from_like(self, like, schema) + } + + /// For handling Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown, IsNotTrue, IsNotFalse, IsNotUnknown, Negative + fn consume_unary_expr( + &mut self, + expr: &Expr, + schema: &DFSchemaRef, + ) -> 
Result<Expression> { + from_unary_expr(self, expr, schema) + } + + fn consume_between( + &mut self, + between: &Between, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_between(self, between, schema) + } + + fn consume_case(&mut self, case: &Case, schema: &DFSchemaRef) -> Result<Expression> { + from_case(self, case, schema) + } + + fn consume_cast(&mut self, cast: &Cast, schema: &DFSchemaRef) -> Result<Expression> { + from_cast(self, cast, schema) + } + + fn consume_try_cast( + &mut self, + cast: &TryCast, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_try_cast(self, cast, schema) + } + + fn consume_scalar_function( + &mut self, + scalar_fn: &expr::ScalarFunction, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_scalar_function(self, scalar_fn, schema) + } + + fn consume_aggregate_function( + &mut self, + agg_fn: &expr::AggregateFunction, + schema: &DFSchemaRef, + ) -> Result<Measure> { + from_aggregate_function(self, agg_fn, schema) + } + + fn consume_window_function( + &mut self, + window_fn: &WindowFunction, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_window_function(self, window_fn, schema) + } + + fn consume_in_list( + &mut self, + in_list: &InList, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_in_list(self, in_list, schema) + } + + fn consume_in_subquery( + &mut self, + in_subquery: &InSubquery, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_in_subquery(self, in_subquery, schema) + } +} + +struct DefaultSubstraitProducer<'a> { + extensions: Extensions, + state: &'a SessionState, +} + +impl<'a> DefaultSubstraitProducer<'a> { + pub fn new(state: &'a SessionState) -> Self { + DefaultSubstraitProducer { + extensions: Extensions::default(), + state, + } + } +} + +impl SubstraitProducer for DefaultSubstraitProducer<'_> { + fn register_function(&mut self, fn_name: String) -> u32 { + self.extensions.register_function(fn_name) + } + + fn get_extensions(self) -> Extensions { + self.extensions + } 
+ + fn consume_extension(&mut self, plan: &Extension) -> Result<Box<Rel>> { + let extension_bytes = self + .state Review Comment: I think this is the only use for SessionState in the DefaultSubstraitProducer, so presumably it wouldn't need the full state to operate with... But given users have the option of making their own producer if they care, maybe that's fine and better to just have the state here for future needs? ########## datafusion/substrait/src/logical_plan/producer.rs: ########## @@ -998,450 +1304,418 @@ pub fn make_binary_op_scalar_func( /// Convert DataFusion Expr to Substrait Rex /// /// # Arguments -/// -/// * `expr` - DataFusion expression to be parse into a Substrait expression -/// * `schema` - DataFusion input schema for looking up field qualifiers -/// * `col_ref_offset` - Offset for calculating Substrait field reference indices. -/// This should only be set by caller with more than one input relations i.e. Join. -/// Substrait expects one set of indices when joining two relations. -/// Let's say `left` and `right` have `m` and `n` columns, respectively. The `right` -/// relation will have column indices from `0` to `n-1`, however, Substrait will expect -/// the `right` indices to be offset by the `left`. This means Substrait will expect to -/// evaluate the join condition expression on indices [0 .. n-1, n .. n+m-1]. For example: -/// ```SELECT * -/// FROM t1 -/// JOIN t2 -/// ON t1.c1 = t2.c0;``` -/// where t1 consists of columns [c0, c1, c2], and t2 = columns [c0, c1] -/// the join condition should become -/// `col_ref(1) = col_ref(3 + 0)` -/// , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index -/// of the join key column from `right` -/// * `extensions` - Substrait extension info. 
Contains registered function information -#[allow(deprecated)] +/// * `producer` - SubstraitProducer implementation which the handles the actual conversion +/// * `expr` - DataFusion expression to convert into a Substrait expression +/// * `schema` - DataFusion input schema for looking up columns pub fn to_substrait_rex( - state: &dyn SubstraitPlanningState, + producer: &mut impl SubstraitProducer, expr: &Expr, schema: &DFSchemaRef, - col_ref_offset: usize, - extensions: &mut Extensions, ) -> Result<Expression> { match expr { - Expr::InList(InList { - expr, - list, - negated, - }) => { - let substrait_list = list - .iter() - .map(|x| to_substrait_rex(state, x, schema, col_ref_offset, extensions)) - .collect::<Result<Vec<Expression>>>()?; - let substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - - let substrait_or_list = Expression { - rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { - value: Some(Box::new(substrait_expr)), - options: substrait_list, - }))), - }; - - if *negated { - let function_anchor = extensions.register_function("not".to_string()); - - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_or_list)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) - } else { - Ok(substrait_or_list) - } + Expr::Alias(expr) => producer.consume_alias(expr, schema), + Expr::Column(expr) => producer.consume_column(expr, schema), + Expr::Literal(expr) => producer.consume_literal(expr), + Expr::BinaryExpr(expr) => producer.consume_binary_expr(expr, schema), + Expr::Like(expr) => producer.consume_like(expr, schema), + Expr::SimilarTo(_) => not_impl_err!("SimilarTo is not supported"), + Expr::Not(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotNull(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNull(_) => 
producer.consume_unary_expr(expr, schema), + Expr::IsTrue(_) => producer.consume_unary_expr(expr, schema), + Expr::IsFalse(_) => producer.consume_unary_expr(expr, schema), + Expr::IsUnknown(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotTrue(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotFalse(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotUnknown(_) => producer.consume_unary_expr(expr, schema), + Expr::Negative(_) => producer.consume_unary_expr(expr, schema), + Expr::Between(expr) => producer.consume_between(expr, schema), + Expr::Case(expr) => producer.consume_case(expr, schema), + Expr::Cast(expr) => producer.consume_cast(expr, schema), + Expr::TryCast(expr) => producer.consume_try_cast(expr, schema), + Expr::ScalarFunction(expr) => producer.consume_scalar_function(expr, schema), + Expr::AggregateFunction(_) => { + internal_err!( + "AggregateFunction should only be encountered as part of a LogicalPlan::Aggregate" + ) } - Expr::ScalarFunction(fun) => { - let mut arguments: Vec<FunctionArgument> = vec![]; - for arg in &fun.args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex( - state, - arg, - schema, - col_ref_offset, - extensions, - )?)), - }); - } + Expr::WindowFunction(expr) => producer.consume_window_function(expr, schema), + Expr::InList(expr) => producer.consume_in_list(expr, schema), + Expr::InSubquery(expr) => producer.consume_in_subquery(expr, schema), + _ => not_impl_err!("Cannot convert {expr:?} to Substrait"), + } +} - let function_anchor = extensions.register_function(fun.name().to_string()); - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments, - output_type: None, - args: vec![], - options: vec![], - })), - }) - } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { - if *negated { - // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) - let 
substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - let substrait_low = - to_substrait_rex(state, low, schema, col_ref_offset, extensions)?; - let substrait_high = - to_substrait_rex(state, high, schema, col_ref_offset, extensions)?; - - let l_expr = make_binary_op_scalar_func( - &substrait_expr, - &substrait_low, - Operator::Lt, - extensions, - ); - let r_expr = make_binary_op_scalar_func( - &substrait_high, - &substrait_expr, - Operator::Lt, - extensions, - ); +pub fn from_in_list( + producer: &mut impl SubstraitProducer, + in_list: &InList, + schema: &DFSchemaRef, +) -> Result<Expression> { + let InList { + expr, + list, + negated, + } = in_list; + let substrait_list = list + .iter() + .map(|x| producer.consume_expr(x, schema)) + .collect::<Result<Vec<Expression>>>()?; + let substrait_expr = producer.consume_expr(expr, schema)?; + + let substrait_or_list = Expression { + rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { + value: Some(Box::new(substrait_expr)), + options: substrait_list, + }))), + }; - Ok(make_binary_op_scalar_func( - &l_expr, - &r_expr, - Operator::Or, - extensions, - )) - } else { - // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) - let substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - let substrait_low = - to_substrait_rex(state, low, schema, col_ref_offset, extensions)?; - let substrait_high = - to_substrait_rex(state, high, schema, col_ref_offset, extensions)?; - - let l_expr = make_binary_op_scalar_func( - &substrait_low, - &substrait_expr, - Operator::LtEq, - extensions, - ); - let r_expr = make_binary_op_scalar_func( - &substrait_expr, - &substrait_high, - Operator::LtEq, - extensions, - ); + if *negated { + let function_anchor = producer.register_function("not".to_string()); - Ok(make_binary_op_scalar_func( - &l_expr, - &r_expr, - Operator::And, - extensions, - )) - } - } - Expr::Column(col) => { - let index = 
schema.index_of_column(col)?; - substrait_field_ref(index + col_ref_offset) - } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(state, left, schema, col_ref_offset, extensions)?; - let r = to_substrait_rex(state, right, schema, col_ref_offset, extensions)?; + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_or_list)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_or_list) + } +} - Ok(make_binary_op_scalar_func(&l, &r, *op, extensions)) - } - Expr::Case(Case { - expr, - when_then_expr, - else_expr, - }) => { - let mut ifs: Vec<IfClause> = vec![]; - // Parse base - if let Some(e) = expr { - // Base expression exists - ifs.push(IfClause { - r#if: Some(to_substrait_rex( - state, - e, - schema, - col_ref_offset, - extensions, - )?), - then: None, - }); - } - // Parse `when`s - for (r#if, then) in when_then_expr { - ifs.push(IfClause { - r#if: Some(to_substrait_rex( - state, - r#if, - schema, - col_ref_offset, - extensions, - )?), - then: Some(to_substrait_rex( - state, - then, - schema, - col_ref_offset, - extensions, - )?), - }); - } +pub fn from_scalar_function( + producer: &mut impl SubstraitProducer, + fun: &expr::ScalarFunction, + schema: &DFSchemaRef, +) -> Result<Expression> { + let mut arguments: Vec<FunctionArgument> = vec![]; + for arg in &fun.args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(to_substrait_rex(producer, arg, schema)?)), + }); + } - // Parse outer `else` - let r#else: Option<Box<Expression>> = match else_expr { - Some(e) => Some(Box::new(to_substrait_rex( - state, - e, - schema, - col_ref_offset, - extensions, - )?)), - None => None, - }; + let function_anchor = producer.register_function(fun.name().to_string()); + #[allow(deprecated)] + Ok(Expression { + rex_type: 
Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments, + output_type: None, + options: vec![], + args: vec![], + })), + }) +} - Ok(Expression { - rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))), - }) - } - Expr::Cast(Cast { expr, data_type }) => Ok(Expression { - rex_type: Some(RexType::Cast(Box::new( - substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(to_substrait_rex( - state, - expr, - schema, - col_ref_offset, - extensions, - )?)), - failure_behavior: FailureBehavior::ThrowException.into(), - }, - ))), - }), - Expr::TryCast(TryCast { expr, data_type }) => Ok(Expression { - rex_type: Some(RexType::Cast(Box::new( - substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(to_substrait_rex( - state, - expr, - schema, - col_ref_offset, - extensions, - )?)), - failure_behavior: FailureBehavior::ReturnNull.into(), - }, - ))), - }), - Expr::Literal(value) => to_substrait_literal_expr(value, extensions), - Expr::Alias(Alias { expr, .. 
}) => { - to_substrait_rex(state, expr, schema, col_ref_offset, extensions) - } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment: _, - }) => { - // function reference - let function_anchor = extensions.register_function(fun.to_string()); - // arguments - let mut arguments: Vec<FunctionArgument> = vec![]; - for arg in args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex( - state, - arg, - schema, - col_ref_offset, - extensions, - )?)), - }); - } - // partition by expressions - let partition_by = partition_by - .iter() - .map(|e| to_substrait_rex(state, e, schema, col_ref_offset, extensions)) - .collect::<Result<Vec<_>>>()?; - // order by expressions - let order_by = order_by - .iter() - .map(|e| substrait_sort_field(state, e, schema, extensions)) - .collect::<Result<Vec<_>>>()?; - // window frame - let bounds = to_substrait_bounds(window_frame)?; - let bound_type = to_substrait_bound_type(window_frame)?; - Ok(make_substrait_window_function( - function_anchor, - arguments, - partition_by, - order_by, - bounds, - bound_type, - )) - } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => make_substrait_like_expr( - state, - *case_insensitive, - *negated, - expr, - pattern, - *escape_char, - schema, - col_ref_offset, - extensions, - ), - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => { - let substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - - let subquery_plan = - to_substrait_rel(subquery.subquery.as_ref(), state, extensions)?; - - let substrait_subquery = Expression { - rex_type: Some(RexType::Subquery(Box::new(Subquery { - subquery_type: Some( - substrait::proto::expression::subquery::SubqueryType::InPredicate( - Box::new(InPredicate { - needles: (vec![substrait_expr]), - haystack: Some(subquery_plan), - }), - ), +pub fn from_between( + producer: &mut impl 
SubstraitProducer, + between: &Between, + schema: &DFSchemaRef, +) -> Result<Expression> { + let Between { + expr, + negated, + low, + high, + } = between; + if *negated { + // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) + let substrait_expr = producer.consume_expr(expr.as_ref(), schema)?; + let substrait_low = producer.consume_expr(low.as_ref(), schema)?; + let substrait_high = producer.consume_expr(high.as_ref(), schema)?; + + let l_expr = make_binary_op_scalar_func( + producer, + &substrait_expr, + &substrait_low, + Operator::Lt, + ); + let r_expr = make_binary_op_scalar_func( + producer, + &substrait_high, + &substrait_expr, + Operator::Lt, + ); + + Ok(make_binary_op_scalar_func( + producer, + &l_expr, + &r_expr, + Operator::Or, + )) + } else { + // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) + let substrait_expr = producer.consume_expr(expr.as_ref(), schema)?; + let substrait_low = producer.consume_expr(low.as_ref(), schema)?; + let substrait_high = producer.consume_expr(high.as_ref(), schema)?; + + let l_expr = make_binary_op_scalar_func( + producer, + &substrait_low, + &substrait_expr, + Operator::LtEq, + ); + let r_expr = make_binary_op_scalar_func( + producer, + &substrait_expr, + &substrait_high, + Operator::LtEq, + ); + + Ok(make_binary_op_scalar_func( + producer, + &l_expr, + &r_expr, + Operator::And, + )) + } +} +pub fn from_column(col: &Column, schema: &DFSchemaRef) -> Result<Expression> { + let index = schema.index_of_column(col)?; + substrait_field_ref(index) +} + +pub fn from_binary_expr( + producer: &mut impl SubstraitProducer, + expr: &BinaryExpr, + schema: &DFSchemaRef, +) -> Result<Expression> { + let BinaryExpr { left, op, right } = expr; + let l = producer.consume_expr(left, schema)?; + let r = producer.consume_expr(right, schema)?; + Ok(make_binary_op_scalar_func(producer, &l, &r, *op)) +} +pub fn from_case( + producer: &mut impl SubstraitProducer, + case: &Case, + 
schema: &DFSchemaRef, +) -> Result<Expression> { + let Case { + expr, + when_then_expr, + else_expr, + } = case; + let mut ifs: Vec<IfClause> = vec![]; + // Parse base + if let Some(e) = expr { + // Base expression exists + ifs.push(IfClause { + r#if: Some(producer.consume_expr(e, schema)?), + then: None, + }); + } + // Parse `when`s + for (r#if, then) in when_then_expr { + ifs.push(IfClause { + r#if: Some(producer.consume_expr(r#if, schema)?), + then: Some(producer.consume_expr(then, schema)?), + }); + } + + // Parse outer `else` + let r#else: Option<Box<Expression>> = match else_expr { + Some(e) => Some(Box::new(to_substrait_rex(producer, e, schema)?)), + None => None, + }; + + Ok(Expression { + rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))), + }) +} + +pub fn from_cast( + producer: &mut impl SubstraitProducer, + cast: &Cast, + schema: &DFSchemaRef, +) -> Result<Expression> { + let Cast { expr, data_type } = cast; + Ok(Expression { + rex_type: Some(RexType::Cast(Box::new( + substrait::proto::expression::Cast { + r#type: Some(to_substrait_type(data_type, true)?), + input: Some(Box::new(to_substrait_rex(producer, expr, schema)?)), + failure_behavior: FailureBehavior::ThrowException.into(), + }, + ))), + }) +} + +pub fn from_try_cast( + producer: &mut impl SubstraitProducer, + cast: &TryCast, + schema: &DFSchemaRef, +) -> Result<Expression> { + let TryCast { expr, data_type } = cast; + Ok(Expression { + rex_type: Some(RexType::Cast(Box::new( + substrait::proto::expression::Cast { + r#type: Some(to_substrait_type(data_type, true)?), + input: Some(Box::new(to_substrait_rex(producer, expr, schema)?)), + failure_behavior: FailureBehavior::ReturnNull.into(), + }, + ))), + }) +} + +pub fn from_literal( + producer: &mut impl SubstraitProducer, + value: &ScalarValue, +) -> Result<Expression> { + to_substrait_literal_expr(producer, value) +} + +pub fn from_alias( + producer: &mut impl SubstraitProducer, + alias: &Alias, + schema: &DFSchemaRef, +) -> 
Result<Expression> { + producer.consume_expr(alias.expr.as_ref(), schema) +} + +pub fn from_window_function( + producer: &mut impl SubstraitProducer, + window_fn: &WindowFunction, + schema: &DFSchemaRef, +) -> Result<Expression> { + let WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment: _, + } = window_fn; + // function reference + let function_anchor = producer.register_function(fun.to_string()); + // arguments + let mut arguments: Vec<FunctionArgument> = vec![]; + for arg in args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(to_substrait_rex(producer, arg, schema)?)), Review Comment: Should this be `producer.consume_expr(..)`, like on lines 1607/1612 below? There are a couple of other places as well where, within the same function, we call both `to_substrait_rex(producer, ..)` and `producer.consume_expr(..)`. Is there a reason to do that, or should everything go through the `producer.consume_` calls? ########## datafusion/substrait/src/logical_plan/producer.rs: ########## @@ -998,450 +1304,418 @@ pub fn make_binary_op_scalar_func( /// Convert DataFusion Expr to Substrait Rex /// /// # Arguments -/// -/// * `expr` - DataFusion expression to be parse into a Substrait expression -/// * `schema` - DataFusion input schema for looking up field qualifiers -/// * `col_ref_offset` - Offset for calculating Substrait field reference indices. -/// This should only be set by caller with more than one input relations i.e. Join. -/// Substrait expects one set of indices when joining two relations. -/// Let's say `left` and `right` have `m` and `n` columns, respectively. The `right` -/// relation will have column indices from `0` to `n-1`, however, Substrait will expect -/// the `right` indices to be offset by the `left`. This means Substrait will expect to -/// evaluate the join condition expression on indices [0 .. n-1, n .. n+m-1]. 
For example: -/// ```SELECT * -/// FROM t1 -/// JOIN t2 -/// ON t1.c1 = t2.c0;``` -/// where t1 consists of columns [c0, c1, c2], and t2 = columns [c0, c1] -/// the join condition should become -/// `col_ref(1) = col_ref(3 + 0)` -/// , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index -/// of the join key column from `right` -/// * `extensions` - Substrait extension info. Contains registered function information -#[allow(deprecated)] +/// * `producer` - SubstraitProducer implementation which the handles the actual conversion +/// * `expr` - DataFusion expression to convert into a Substrait expression +/// * `schema` - DataFusion input schema for looking up columns pub fn to_substrait_rex( - state: &dyn SubstraitPlanningState, + producer: &mut impl SubstraitProducer, expr: &Expr, schema: &DFSchemaRef, - col_ref_offset: usize, - extensions: &mut Extensions, ) -> Result<Expression> { match expr { - Expr::InList(InList { - expr, - list, - negated, - }) => { - let substrait_list = list - .iter() - .map(|x| to_substrait_rex(state, x, schema, col_ref_offset, extensions)) - .collect::<Result<Vec<Expression>>>()?; - let substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - - let substrait_or_list = Expression { - rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { - value: Some(Box::new(substrait_expr)), - options: substrait_list, - }))), - }; - - if *negated { - let function_anchor = extensions.register_function("not".to_string()); - - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_or_list)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) - } else { - Ok(substrait_or_list) - } + Expr::Alias(expr) => producer.consume_alias(expr, schema), + Expr::Column(expr) => producer.consume_column(expr, schema), + Expr::Literal(expr) => 
producer.consume_literal(expr), + Expr::BinaryExpr(expr) => producer.consume_binary_expr(expr, schema), + Expr::Like(expr) => producer.consume_like(expr, schema), + Expr::SimilarTo(_) => not_impl_err!("SimilarTo is not supported"), + Expr::Not(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotNull(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNull(_) => producer.consume_unary_expr(expr, schema), + Expr::IsTrue(_) => producer.consume_unary_expr(expr, schema), + Expr::IsFalse(_) => producer.consume_unary_expr(expr, schema), + Expr::IsUnknown(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotTrue(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotFalse(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotUnknown(_) => producer.consume_unary_expr(expr, schema), + Expr::Negative(_) => producer.consume_unary_expr(expr, schema), + Expr::Between(expr) => producer.consume_between(expr, schema), + Expr::Case(expr) => producer.consume_case(expr, schema), + Expr::Cast(expr) => producer.consume_cast(expr, schema), + Expr::TryCast(expr) => producer.consume_try_cast(expr, schema), + Expr::ScalarFunction(expr) => producer.consume_scalar_function(expr, schema), + Expr::AggregateFunction(_) => { + internal_err!( + "AggregateFunction should only be encountered as part of a LogicalPlan::Aggregate" + ) } - Expr::ScalarFunction(fun) => { - let mut arguments: Vec<FunctionArgument> = vec![]; - for arg in &fun.args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex( - state, - arg, - schema, - col_ref_offset, - extensions, - )?)), - }); - } + Expr::WindowFunction(expr) => producer.consume_window_function(expr, schema), + Expr::InList(expr) => producer.consume_in_list(expr, schema), + Expr::InSubquery(expr) => producer.consume_in_subquery(expr, schema), + _ => not_impl_err!("Cannot convert {expr:?} to Substrait"), + } +} - let function_anchor = extensions.register_function(fun.name().to_string()); - 
Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments, - output_type: None, - args: vec![], - options: vec![], - })), - }) - } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { - if *negated { - // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) - let substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - let substrait_low = - to_substrait_rex(state, low, schema, col_ref_offset, extensions)?; - let substrait_high = - to_substrait_rex(state, high, schema, col_ref_offset, extensions)?; - - let l_expr = make_binary_op_scalar_func( - &substrait_expr, - &substrait_low, - Operator::Lt, - extensions, - ); - let r_expr = make_binary_op_scalar_func( - &substrait_high, - &substrait_expr, - Operator::Lt, - extensions, - ); +pub fn from_in_list( + producer: &mut impl SubstraitProducer, + in_list: &InList, + schema: &DFSchemaRef, +) -> Result<Expression> { + let InList { + expr, + list, + negated, + } = in_list; + let substrait_list = list + .iter() + .map(|x| producer.consume_expr(x, schema)) + .collect::<Result<Vec<Expression>>>()?; + let substrait_expr = producer.consume_expr(expr, schema)?; + + let substrait_or_list = Expression { + rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { + value: Some(Box::new(substrait_expr)), + options: substrait_list, + }))), + }; - Ok(make_binary_op_scalar_func( - &l_expr, - &r_expr, - Operator::Or, - extensions, - )) - } else { - // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) - let substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - let substrait_low = - to_substrait_rex(state, low, schema, col_ref_offset, extensions)?; - let substrait_high = - to_substrait_rex(state, high, schema, col_ref_offset, extensions)?; - - let l_expr = make_binary_op_scalar_func( - &substrait_low, - &substrait_expr, - 
Operator::LtEq, - extensions, - ); - let r_expr = make_binary_op_scalar_func( - &substrait_expr, - &substrait_high, - Operator::LtEq, - extensions, - ); + if *negated { + let function_anchor = producer.register_function("not".to_string()); - Ok(make_binary_op_scalar_func( - &l_expr, - &r_expr, - Operator::And, - extensions, - )) - } - } - Expr::Column(col) => { - let index = schema.index_of_column(col)?; - substrait_field_ref(index + col_ref_offset) - } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(state, left, schema, col_ref_offset, extensions)?; - let r = to_substrait_rex(state, right, schema, col_ref_offset, extensions)?; + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_or_list)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_or_list) + } +} - Ok(make_binary_op_scalar_func(&l, &r, *op, extensions)) - } - Expr::Case(Case { - expr, - when_then_expr, - else_expr, - }) => { - let mut ifs: Vec<IfClause> = vec![]; - // Parse base - if let Some(e) = expr { - // Base expression exists - ifs.push(IfClause { - r#if: Some(to_substrait_rex( - state, - e, - schema, - col_ref_offset, - extensions, - )?), - then: None, - }); - } - // Parse `when`s - for (r#if, then) in when_then_expr { - ifs.push(IfClause { - r#if: Some(to_substrait_rex( - state, - r#if, - schema, - col_ref_offset, - extensions, - )?), - then: Some(to_substrait_rex( - state, - then, - schema, - col_ref_offset, - extensions, - )?), - }); - } +pub fn from_scalar_function( + producer: &mut impl SubstraitProducer, + fun: &expr::ScalarFunction, + schema: &DFSchemaRef, +) -> Result<Expression> { + let mut arguments: Vec<FunctionArgument> = vec![]; + for arg in &fun.args { + arguments.push(FunctionArgument { + arg_type: 
Some(ArgType::Value(to_substrait_rex(producer, arg, schema)?)), + }); + } - // Parse outer `else` - let r#else: Option<Box<Expression>> = match else_expr { - Some(e) => Some(Box::new(to_substrait_rex( - state, - e, - schema, - col_ref_offset, - extensions, - )?)), - None => None, - }; + let function_anchor = producer.register_function(fun.name().to_string()); + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments, + output_type: None, + options: vec![], + args: vec![], + })), + }) +} - Ok(Expression { - rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))), - }) - } - Expr::Cast(Cast { expr, data_type }) => Ok(Expression { - rex_type: Some(RexType::Cast(Box::new( - substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(to_substrait_rex( - state, - expr, - schema, - col_ref_offset, - extensions, - )?)), - failure_behavior: FailureBehavior::ThrowException.into(), - }, - ))), - }), - Expr::TryCast(TryCast { expr, data_type }) => Ok(Expression { - rex_type: Some(RexType::Cast(Box::new( - substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(to_substrait_rex( - state, - expr, - schema, - col_ref_offset, - extensions, - )?)), - failure_behavior: FailureBehavior::ReturnNull.into(), - }, - ))), - }), - Expr::Literal(value) => to_substrait_literal_expr(value, extensions), - Expr::Alias(Alias { expr, .. 
}) => { - to_substrait_rex(state, expr, schema, col_ref_offset, extensions) - } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment: _, - }) => { - // function reference - let function_anchor = extensions.register_function(fun.to_string()); - // arguments - let mut arguments: Vec<FunctionArgument> = vec![]; - for arg in args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex( - state, - arg, - schema, - col_ref_offset, - extensions, - )?)), - }); - } - // partition by expressions - let partition_by = partition_by - .iter() - .map(|e| to_substrait_rex(state, e, schema, col_ref_offset, extensions)) - .collect::<Result<Vec<_>>>()?; - // order by expressions - let order_by = order_by - .iter() - .map(|e| substrait_sort_field(state, e, schema, extensions)) - .collect::<Result<Vec<_>>>()?; - // window frame - let bounds = to_substrait_bounds(window_frame)?; - let bound_type = to_substrait_bound_type(window_frame)?; - Ok(make_substrait_window_function( - function_anchor, - arguments, - partition_by, - order_by, - bounds, - bound_type, - )) - } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => make_substrait_like_expr( - state, - *case_insensitive, - *negated, - expr, - pattern, - *escape_char, - schema, - col_ref_offset, - extensions, - ), - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => { - let substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - - let subquery_plan = - to_substrait_rel(subquery.subquery.as_ref(), state, extensions)?; - - let substrait_subquery = Expression { - rex_type: Some(RexType::Subquery(Box::new(Subquery { - subquery_type: Some( - substrait::proto::expression::subquery::SubqueryType::InPredicate( - Box::new(InPredicate { - needles: (vec![substrait_expr]), - haystack: Some(subquery_plan), - }), - ), +pub fn from_between( + producer: &mut impl 
SubstraitProducer, + between: &Between, + schema: &DFSchemaRef, +) -> Result<Expression> { + let Between { + expr, + negated, + low, + high, + } = between; + if *negated { + // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) + let substrait_expr = producer.consume_expr(expr.as_ref(), schema)?; + let substrait_low = producer.consume_expr(low.as_ref(), schema)?; + let substrait_high = producer.consume_expr(high.as_ref(), schema)?; Review Comment: Unrelated to this PR, and probably better not to change now to keep the diff small(er), but I think there's no reason to duplicate these three `consume_expr` calls in the `else` branch below; they could just happen once above the `if`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org