This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new ec10c0420a feat(substrait): set ProjectRel output_mapping in producer 
(#12495)
ec10c0420a is described below

commit ec10c0420aa3adb43c1f8793d66438946ae5b49d
Author: Victor Barua <[email protected]>
AuthorDate: Wed Sep 18 04:23:06 2024 -0700

    feat(substrait): set ProjectRel output_mapping in producer (#12495)
---
 datafusion/substrait/src/logical_plan/producer.rs |  79 +++++++++++------
 datafusion/substrait/tests/cases/serialize.rs     | 101 ++++++++++++++++++++++
 2 files changed, 154 insertions(+), 26 deletions(-)

diff --git a/datafusion/substrait/src/logical_plan/producer.rs 
b/datafusion/substrait/src/logical_plan/producer.rs
index a923aaf31a..fada827875 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -61,7 +61,9 @@ use substrait::proto::expression::literal::{
 use substrait::proto::expression::subquery::InPredicate;
 use substrait::proto::expression::window_function::BoundsType;
 use substrait::proto::read_rel::VirtualTable;
-use substrait::proto::{CrossRel, ExchangeRel};
+use substrait::proto::rel_common::EmitKind;
+use substrait::proto::rel_common::EmitKind::Emit;
+use substrait::proto::{rel_common, CrossRel, ExchangeRel, RelCommon};
 use substrait::{
     proto::{
         aggregate_function::AggregationInvocation,
@@ -219,9 +221,20 @@ pub fn to_substrait_rel(
                 .iter()
                 .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, 
extensions))
                 .collect::<Result<Vec<_>>>()?;
+
+            let emit_kind = create_project_remapping(
+                expressions.len(),
+                p.input.as_ref().schema().fields().len(),
+            );
+            let common = RelCommon {
+                emit_kind: Some(emit_kind),
+                hint: None,
+                advanced_extension: None,
+            };
+
             Ok(Box::new(Rel {
                 rel_type: Some(RelType::Project(Box::new(ProjectRel {
-                    common: None,
+                    common: Some(common),
                     input: Some(to_substrait_rel(p.input.as_ref(), ctx, 
extensions)?),
                     expressions,
                     advanced_extension: None,
@@ -432,29 +445,15 @@ pub fn to_substrait_rel(
         }
         LogicalPlan::Window(window) => {
             let input = to_substrait_rel(window.input.as_ref(), ctx, 
extensions)?;
-            // If the input is a Project relation, we can just append the 
WindowFunction expressions
-            // before returning
-            // Otherwise, wrap the input in a Project relation before 
appending the WindowFunction
-            // expressions
-            let mut project_rel: Box<ProjectRel> = match 
&input.as_ref().rel_type {
-                Some(RelType::Project(p)) => Box::new(*p.clone()),
-                _ => {
-                    // Create Projection with field referencing all output 
fields in the input relation
-                    let expressions = (0..window.input.schema().fields().len())
-                        .map(substrait_field_ref)
-                        .collect::<Result<Vec<_>>>()?;
-                    Box::new(ProjectRel {
-                        common: None,
-                        input: Some(input),
-                        expressions,
-                        advanced_extension: None,
-                    })
-                }
-            };
-            // Parse WindowFunction expression
-            let mut window_exprs = vec![];
+
+            // create a field reference for each input field
+            let mut expressions = (0..window.input.schema().fields().len())
+                .map(substrait_field_ref)
+                .collect::<Result<Vec<_>>>()?;
+
+            // process and add each window function expression
             for expr in &window.window_expr {
-                window_exprs.push(to_substrait_rex(
+                expressions.push(to_substrait_rex(
                     ctx,
                     expr,
                     window.input.schema(),
@@ -462,8 +461,23 @@ pub fn to_substrait_rel(
                     extensions,
                 )?);
             }
-            // Append parsed WindowFunction expressions
-            project_rel.expressions.extend(window_exprs);
+
+            let emit_kind = create_project_remapping(
+                expressions.len(),
+                window.input.schema().fields().len(),
+            );
+            let common = RelCommon {
+                emit_kind: Some(emit_kind),
+                hint: None,
+                advanced_extension: None,
+            };
+            let project_rel = Box::new(ProjectRel {
+                common: Some(common),
+                input: Some(input),
+                expressions,
+                advanced_extension: None,
+            });
+
             Ok(Box::new(Rel {
                 rel_type: Some(RelType::Project(project_rel)),
             }))
@@ -553,6 +567,19 @@ pub fn to_substrait_rel(
     }
 }
 
+/// By default, a Substrait Project outputs all input fields followed by all 
expressions.
+/// A DataFusion Projection only outputs expressions. In order to keep the 
Substrait
+/// plan consistent with DataFusion, we must apply an output mapping that 
skips the input
+/// fields so that the Substrait Project will only output the expression 
fields.
+fn create_project_remapping(expr_count: usize, input_field_count: usize) -> 
EmitKind {
+    let expression_field_start = input_field_count;
+    let expression_field_end = expression_field_start + expr_count;
+    let output_mapping = (expression_field_start..expression_field_end)
+        .map(|i| i as i32)
+        .collect();
+    Emit(rel_common::Emit { output_mapping })
+}
+
 fn to_substrait_named_struct(
     schema: &DFSchemaRef,
     extensions: &mut Extensions,
diff --git a/datafusion/substrait/tests/cases/serialize.rs 
b/datafusion/substrait/tests/cases/serialize.rs
index d792ac33c3..da0898d222 100644
--- a/datafusion/substrait/tests/cases/serialize.rs
+++ b/datafusion/substrait/tests/cases/serialize.rs
@@ -26,7 +26,11 @@ mod tests {
     use datafusion::error::Result;
     use datafusion::prelude::*;
 
+    use datafusion_substrait::logical_plan::producer::to_substrait_plan;
     use std::fs;
+    use substrait::proto::plan_rel::RelType;
+    use substrait::proto::rel_common::{Emit, EmitKind};
+    use substrait::proto::{rel, RelCommon};
 
     #[tokio::test]
     async fn serialize_simple_select() -> Result<()> {
@@ -63,6 +67,103 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn include_remaps_for_projects() -> Result<()> {
+        let ctx = create_context().await?;
+        let df = ctx.sql("SELECT b, a + a, a FROM data").await?;
+        let datafusion_plan = df.into_optimized_plan()?;
+
+        assert_eq!(
+            format!("{}", datafusion_plan),
+            "Projection: data.b, data.a + data.a, data.a\
+            \n  TableScan: data projection=[a, b]",
+        );
+
+        let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone();
+
+        let relation = plan.relations.first().unwrap().rel_type.as_ref();
+        let root_rel = match relation {
+            Some(RelType::Root(root)) => root.input.as_ref().unwrap(),
+            _ => panic!("expected Root"),
+        };
+        if let Some(rel::RelType::Project(p)) = root_rel.rel_type.as_ref() {
+            // The input has 2 columns [a, b], the Projection has 3 
expressions [b, a + a, a]
+            // The required output mapping is [2,3,4], which skips the 2 input 
columns.
+            assert_emit(p.common.as_ref(), vec![2, 3, 4]);
+
+            if let Some(rel::RelType::Read(r)) =
+                p.input.as_ref().unwrap().rel_type.as_ref()
+            {
+                let mask_expression = r.projection.as_ref().unwrap();
+                let select = mask_expression.select.as_ref().unwrap();
+                assert_eq!(
+                    2,
+                    select.struct_items.len(),
+                    "Read outputs two columns: a, b"
+                );
+                return Ok(());
+            }
+        }
+        panic!("plan did not match expected structure")
+    }
+
+    #[tokio::test]
+    async fn include_remaps_for_windows() -> Result<()> {
+        let ctx = create_context().await?;
+        // let df = ctx.sql("SELECT a, b, lead(b) OVER (PARTITION BY a) FROM 
data").await?;
+        let df = ctx
+            .sql("SELECT b, RANK() OVER (PARTITION BY a), c FROM data;")
+            .await?;
+        let datafusion_plan = df.into_optimized_plan()?;
+        assert_eq!(
+            format!("{}", datafusion_plan),
+            "Projection: data.b, RANK() PARTITION BY [data.a] ROWS BETWEEN 
UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, data.c\
+            \n  WindowAggr: windowExpr=[[RANK() PARTITION BY [data.a] ROWS 
BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\
+            \n    TableScan: data projection=[a, b, c]",
+        );
+
+        let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone();
+
+        let relation = plan.relations.first().unwrap().rel_type.as_ref();
+        let root_rel = match relation {
+            Some(RelType::Root(root)) => root.input.as_ref().unwrap(),
+            _ => panic!("expected Root"),
+        };
+
+        if let Some(rel::RelType::Project(p1)) = root_rel.rel_type.as_ref() {
+            // The WindowAggr outputs 4 columns, the Projection has 3 expressions
+            assert_emit(p1.common.as_ref(), vec![4, 5, 6]);
+
+            if let Some(rel::RelType::Project(p2)) =
+                p1.input.as_ref().unwrap().rel_type.as_ref()
+            {
+                // The input has 3 columns, the WindowAggr Project has 4 expressions
+                assert_emit(p2.common.as_ref(), vec![3, 4, 5, 6]);
+
+                if let Some(rel::RelType::Read(r)) =
+                    p2.input.as_ref().unwrap().rel_type.as_ref()
+                {
+                    let mask_expression = r.projection.as_ref().unwrap();
+                    let select = mask_expression.select.as_ref().unwrap();
+                    assert_eq!(
+                        3,
+                        select.struct_items.len(),
+                        "Read outputs three columns: a, b, c"
+                    );
+                    return Ok(());
+                }
+            }
+        }
+        panic!("plan did not match expected structure")
+    }
+
+    fn assert_emit(rel_common: Option<&RelCommon>, output_mapping: Vec<i32>) {
+        assert_eq!(
+            rel_common.unwrap().emit_kind.clone(),
+            Some(EmitKind::Emit(Emit { output_mapping }))
+        );
+    }
+
     async fn create_context() -> Result<SessionContext> {
         let ctx = SessionContext::new();
         ctx.register_csv("data", "tests/testdata/data.csv", 
CsvReadOptions::new())


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to