This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new ec10c0420a feat(substrait): set ProjectRel output_mapping in producer
(#12495)
ec10c0420a is described below
commit ec10c0420aa3adb43c1f8793d66438946ae5b49d
Author: Victor Barua <[email protected]>
AuthorDate: Wed Sep 18 04:23:06 2024 -0700
feat(substrait): set ProjectRel output_mapping in producer (#12495)
---
datafusion/substrait/src/logical_plan/producer.rs | 79 +++++++++++------
datafusion/substrait/tests/cases/serialize.rs | 101 ++++++++++++++++++++++
2 files changed, 154 insertions(+), 26 deletions(-)
diff --git a/datafusion/substrait/src/logical_plan/producer.rs
b/datafusion/substrait/src/logical_plan/producer.rs
index a923aaf31a..fada827875 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -61,7 +61,9 @@ use substrait::proto::expression::literal::{
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
use substrait::proto::read_rel::VirtualTable;
-use substrait::proto::{CrossRel, ExchangeRel};
+use substrait::proto::rel_common::EmitKind;
+use substrait::proto::rel_common::EmitKind::Emit;
+use substrait::proto::{rel_common, CrossRel, ExchangeRel, RelCommon};
use substrait::{
proto::{
aggregate_function::AggregationInvocation,
@@ -219,9 +221,20 @@ pub fn to_substrait_rel(
.iter()
.map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0,
extensions))
.collect::<Result<Vec<_>>>()?;
+
+ let emit_kind = create_project_remapping(
+ expressions.len(),
+ p.input.as_ref().schema().fields().len(),
+ );
+ let common = RelCommon {
+ emit_kind: Some(emit_kind),
+ hint: None,
+ advanced_extension: None,
+ };
+
Ok(Box::new(Rel {
rel_type: Some(RelType::Project(Box::new(ProjectRel {
- common: None,
+ common: Some(common),
input: Some(to_substrait_rel(p.input.as_ref(), ctx,
extensions)?),
expressions,
advanced_extension: None,
@@ -432,29 +445,15 @@ pub fn to_substrait_rel(
}
LogicalPlan::Window(window) => {
let input = to_substrait_rel(window.input.as_ref(), ctx,
extensions)?;
- // If the input is a Project relation, we can just append the
WindowFunction expressions
- // before returning
- // Otherwise, wrap the input in a Project relation before
appending the WindowFunction
- // expressions
- let mut project_rel: Box<ProjectRel> = match
&input.as_ref().rel_type {
- Some(RelType::Project(p)) => Box::new(*p.clone()),
- _ => {
- // Create Projection with field referencing all output
fields in the input relation
- let expressions = (0..window.input.schema().fields().len())
- .map(substrait_field_ref)
- .collect::<Result<Vec<_>>>()?;
- Box::new(ProjectRel {
- common: None,
- input: Some(input),
- expressions,
- advanced_extension: None,
- })
- }
- };
- // Parse WindowFunction expression
- let mut window_exprs = vec![];
+
+ // create a field reference for each input field
+ let mut expressions = (0..window.input.schema().fields().len())
+ .map(substrait_field_ref)
+ .collect::<Result<Vec<_>>>()?;
+
+ // process and add each window function expression
for expr in &window.window_expr {
- window_exprs.push(to_substrait_rex(
+ expressions.push(to_substrait_rex(
ctx,
expr,
window.input.schema(),
@@ -462,8 +461,23 @@ pub fn to_substrait_rel(
extensions,
)?);
}
- // Append parsed WindowFunction expressions
- project_rel.expressions.extend(window_exprs);
+
+ let emit_kind = create_project_remapping(
+ expressions.len(),
+ window.input.schema().fields().len(),
+ );
+ let common = RelCommon {
+ emit_kind: Some(emit_kind),
+ hint: None,
+ advanced_extension: None,
+ };
+ let project_rel = Box::new(ProjectRel {
+ common: Some(common),
+ input: Some(input),
+ expressions,
+ advanced_extension: None,
+ });
+
Ok(Box::new(Rel {
rel_type: Some(RelType::Project(project_rel)),
}))
@@ -553,6 +567,19 @@ pub fn to_substrait_rel(
}
}
+/// By default, a Substrait Project outputs all input fields followed by all
expressions.
+/// A DataFusion Projection only outputs expressions. In order to keep the
Substrait
+/// plan consistent with DataFusion, we must apply an output mapping that
skips the input
+/// fields so that the Substrait Project will only output the expression
fields.
+fn create_project_remapping(expr_count: usize, input_field_count: usize) ->
EmitKind {
+ let expression_field_start = input_field_count;
+ let expression_field_end = expression_field_start + expr_count;
+ let output_mapping = (expression_field_start..expression_field_end)
+ .map(|i| i as i32)
+ .collect();
+ Emit(rel_common::Emit { output_mapping })
+}
+
fn to_substrait_named_struct(
schema: &DFSchemaRef,
extensions: &mut Extensions,
diff --git a/datafusion/substrait/tests/cases/serialize.rs
b/datafusion/substrait/tests/cases/serialize.rs
index d792ac33c3..da0898d222 100644
--- a/datafusion/substrait/tests/cases/serialize.rs
+++ b/datafusion/substrait/tests/cases/serialize.rs
@@ -26,7 +26,11 @@ mod tests {
use datafusion::error::Result;
use datafusion::prelude::*;
+ use datafusion_substrait::logical_plan::producer::to_substrait_plan;
use std::fs;
+ use substrait::proto::plan_rel::RelType;
+ use substrait::proto::rel_common::{Emit, EmitKind};
+ use substrait::proto::{rel, RelCommon};
#[tokio::test]
async fn serialize_simple_select() -> Result<()> {
@@ -63,6 +67,103 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn include_remaps_for_projects() -> Result<()> {
+ let ctx = create_context().await?;
+ let df = ctx.sql("SELECT b, a + a, a FROM data").await?;
+ let datafusion_plan = df.into_optimized_plan()?;
+
+ assert_eq!(
+ format!("{}", datafusion_plan),
+ "Projection: data.b, data.a + data.a, data.a\
+ \n TableScan: data projection=[a, b]",
+ );
+
+ let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone();
+
+ let relation = plan.relations.first().unwrap().rel_type.as_ref();
+ let root_rel = match relation {
+ Some(RelType::Root(root)) => root.input.as_ref().unwrap(),
+ _ => panic!("expected Root"),
+ };
+ if let Some(rel::RelType::Project(p)) = root_rel.rel_type.as_ref() {
+ // The input has 2 columns [a, b], the Projection has 3
expressions [b, a + a, a]
+ // The required output mapping is [2,3,4], which skips the 2 input
columns.
+ assert_emit(p.common.as_ref(), vec![2, 3, 4]);
+
+ if let Some(rel::RelType::Read(r)) =
+ p.input.as_ref().unwrap().rel_type.as_ref()
+ {
+ let mask_expression = r.projection.as_ref().unwrap();
+ let select = mask_expression.select.as_ref().unwrap();
+ assert_eq!(
+ 2,
+ select.struct_items.len(),
+ "Read outputs two columns: a, b"
+ );
+ return Ok(());
+ }
+ }
+ panic!("plan did not match expected structure")
+ }
+
+ #[tokio::test]
+ async fn include_remaps_for_windows() -> Result<()> {
+ let ctx = create_context().await?;
+ // let df = ctx.sql("SELECT a, b, lead(b) OVER (PARTITION BY a) FROM
data").await?;
+ let df = ctx
+ .sql("SELECT b, RANK() OVER (PARTITION BY a), c FROM data;")
+ .await?;
+ let datafusion_plan = df.into_optimized_plan()?;
+ assert_eq!(
+ format!("{}", datafusion_plan),
+ "Projection: data.b, RANK() PARTITION BY [data.a] ROWS BETWEEN
UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, data.c\
+ \n WindowAggr: windowExpr=[[RANK() PARTITION BY [data.a] ROWS
BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\
+ \n TableScan: data projection=[a, b, c]",
+ );
+
+ let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone();
+
+ let relation = plan.relations.first().unwrap().rel_type.as_ref();
+ let root_rel = match relation {
+ Some(RelType::Root(root)) => root.input.as_ref().unwrap(),
+ _ => panic!("expected Root"),
+ };
+
+ if let Some(rel::RelType::Project(p1)) = root_rel.rel_type.as_ref() {
+ // The WindowAggr outputs 4 columns, the Projection has 3 expressions
+ assert_emit(p1.common.as_ref(), vec![4, 5, 6]);
+
+ if let Some(rel::RelType::Project(p2)) =
+ p1.input.as_ref().unwrap().rel_type.as_ref()
+ {
+ // The input has 3 columns, the WindowAggr has 4 expressions (3 field refs + 1 window fn)
+ assert_emit(p2.common.as_ref(), vec![3, 4, 5, 6]);
+
+ if let Some(rel::RelType::Read(r)) =
+ p2.input.as_ref().unwrap().rel_type.as_ref()
+ {
+ let mask_expression = r.projection.as_ref().unwrap();
+ let select = mask_expression.select.as_ref().unwrap();
+ assert_eq!(
+ 3,
+ select.struct_items.len(),
+ "Read outputs three columns: a, b, c"
+ );
+ return Ok(());
+ }
+ }
+ }
+ panic!("plan did not match expected structure")
+ }
+
+ fn assert_emit(rel_common: Option<&RelCommon>, output_mapping: Vec<i32>) {
+ assert_eq!(
+ rel_common.unwrap().emit_kind.clone(),
+ Some(EmitKind::Emit(Emit { output_mapping }))
+ );
+ }
+
async fn create_context() -> Result<SessionContext> {
let ctx = SessionContext::new();
ctx.register_csv("data", "tests/testdata/data.csv",
CsvReadOptions::new())
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]