This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 1b89da4455 Support correct output column names and struct field names
when consuming/producing Substrait (#10829)
1b89da4455 is described below
commit 1b89da4455f68c3199c56a7c4a4298ce3120a714
Author: Arttu <[email protected]>
AuthorDate: Tue Jun 11 20:35:34 2024 +0200
Support correct output column names and struct field names when
consuming/producing Substrait (#10829)
* produce flattened list of names including inner struct fields
* add a (failing) test
* rename output columns (incl. inner struct fields) according to the given
list of names
* fix a test
* add the column-names Projection to the new TPC-H test's expected plan and
fix its letter case (assert_eq gives nicer error messages than assert)
---
datafusion/substrait/src/logical_plan/consumer.rs | 133 ++++++++++++++++++++-
datafusion/substrait/src/logical_plan/producer.rs | 2 +-
.../substrait/tests/cases/consumer_integration.rs | 17 +--
datafusion/substrait/tests/cases/logical_plans.rs | 2 +-
.../tests/cases/roundtrip_logical_plan.rs | 43 +++----
5 files changed, 156 insertions(+), 41 deletions(-)
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs
b/datafusion/substrait/src/logical_plan/consumer.rs
index 8a483db8c4..648a281832 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -17,7 +17,7 @@
use async_recursion::async_recursion;
use datafusion::arrow::datatypes::{
- DataType, Field, Fields, IntervalUnit, Schema, TimeUnit,
+ DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit,
};
use datafusion::common::{
not_impl_err, substrait_datafusion_err, substrait_err, DFSchema,
DFSchemaRef,
@@ -29,12 +29,13 @@ use url::Url;
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
- aggregate_function, expr::find_df_window_func, BinaryExpr, Case,
EmptyRelation, Expr,
- LogicalPlan, Operator, ScalarUDF, Values,
+ aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case,
+ EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection,
ScalarUDF,
+ Values,
};
use datafusion::logical_expr::{
- expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
+ col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder,
Partitioning,
Repartition, Subquery, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
};
use datafusion::prelude::JoinType;
@@ -225,6 +226,7 @@ pub async fn from_substrait_plan(
None => not_impl_err!("Cannot parse empty extension"),
})
.collect::<Result<HashMap<_, _>>>()?;
+
// Parse relations
match plan.relations.len() {
1 => {
@@ -234,7 +236,29 @@ pub async fn from_substrait_plan(
Ok(from_substrait_rel(ctx, rel,
&function_extension).await?)
},
plan_rel::RelType::Root(root) => {
- Ok(from_substrait_rel(ctx,
root.input.as_ref().unwrap(), &function_extension).await?)
+ let plan = from_substrait_rel(ctx,
root.input.as_ref().unwrap(), &function_extension).await?;
+ if root.names.is_empty() {
+ // Backwards compatibility for plans missing names
+ return Ok(plan);
+ }
+ let renamed_schema =
make_renamed_schema(plan.schema(), &root.names)?;
+ if
renamed_schema.equivalent_names_and_types(plan.schema()) {
+ // Nothing to do if the schema is already
equivalent
+ return Ok(plan);
+ }
+
+ match plan {
+ // If the last node of the plan produces
expressions, bake the renames into those expressions.
+ // This isn't necessary for correctness, but helps
with roundtrip tests.
+ LogicalPlan::Projection(p) =>
Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr,
p.input.schema(), renamed_schema)?, p.input)?)),
+ LogicalPlan::Aggregate(a) => {
+ let new_aggr_exprs =
rename_expressions(a.aggr_expr, a.input.schema(), renamed_schema)?;
+
Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr,
new_aggr_exprs)?))
+ },
+ // There are probably more plans where we could
bake things in, can add them later as needed.
+ // Otherwise, add a new Project to handle the
renaming.
+ _ =>
Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c|
col(c.to_owned())), plan.schema(), renamed_schema)?, Arc::new(plan))?))
+ }
}
},
None => plan_err!("Cannot parse plan relation: None")
@@ -284,6 +308,105 @@ pub fn extract_projection(
}
}
+fn rename_expressions(
+ exprs: impl IntoIterator<Item = Expr>,
+ input_schema: &DFSchema,
+ new_schema: DFSchemaRef,
+) -> Result<Vec<Expr>> {
+ exprs
+ .into_iter()
+ .zip(new_schema.fields())
+ .map(|(old_expr, new_field)| {
+ if &old_expr.get_type(input_schema)? == new_field.data_type() {
+ // Alias column if needed
+ old_expr.alias_if_changed(new_field.name().into())
+ } else {
+ // Use Cast to rename inner struct fields + alias column if
needed
+ Expr::Cast(Cast::new(
+ Box::new(old_expr),
+ new_field.data_type().to_owned(),
+ ))
+ .alias_if_changed(new_field.name().into())
+ }
+ })
+ .collect()
+}
+
+fn make_renamed_schema(
+ schema: &DFSchemaRef,
+ dfs_names: &Vec<String>,
+) -> Result<DFSchemaRef> {
+ fn rename_inner_fields(
+ dtype: &DataType,
+ dfs_names: &Vec<String>,
+ name_idx: &mut usize,
+ ) -> Result<DataType> {
+ match dtype {
+ DataType::Struct(fields) => {
+ let fields = fields
+ .iter()
+ .map(|f| {
+ let name = next_struct_field_name(0, dfs_names,
name_idx)?;
+ Ok((**f).to_owned().with_name(name).with_data_type(
+ rename_inner_fields(f.data_type(), dfs_names,
name_idx)?,
+ ))
+ })
+ .collect::<Result<_>>()?;
+ Ok(DataType::Struct(fields))
+ }
+ DataType::List(inner) => Ok(DataType::List(FieldRef::new(
+ (**inner).to_owned().with_data_type(rename_inner_fields(
+ inner.data_type(),
+ dfs_names,
+ name_idx,
+ )?),
+ ))),
+ DataType::LargeList(inner) => Ok(DataType::LargeList(FieldRef::new(
+ (**inner).to_owned().with_data_type(rename_inner_fields(
+ inner.data_type(),
+ dfs_names,
+ name_idx,
+ )?),
+ ))),
+ _ => Ok(dtype.to_owned()),
+ }
+ }
+
+ let mut name_idx = 0;
+
+ let (qualifiers, fields): (_, Vec<Field>) = schema
+ .iter()
+ .map(|(q, f)| {
+ let name = next_struct_field_name(0, dfs_names, &mut name_idx)?;
+ Ok((
+ q.cloned(),
+ (**f)
+ .to_owned()
+ .with_name(name)
+ .with_data_type(rename_inner_fields(
+ f.data_type(),
+ dfs_names,
+ &mut name_idx,
+ )?),
+ ))
+ })
+ .collect::<Result<Vec<_>>>()?
+ .into_iter()
+ .unzip();
+
+ if name_idx != dfs_names.len() {
+ return substrait_err!(
+ "Names list must match exactly to nested schema, but found {} uses
for {} names",
+ name_idx,
+ dfs_names.len());
+ }
+
+ Ok(Arc::new(DFSchema::from_field_specific_qualified_schema(
+ qualifiers,
+ &Arc::new(Schema::new(fields)),
+ )?))
+}
+
/// Convert Substrait Rel to DataFusion DataFrame
#[async_recursion]
pub async fn from_substrait_rel(
diff --git a/datafusion/substrait/src/logical_plan/producer.rs
b/datafusion/substrait/src/logical_plan/producer.rs
index 6c8be4aa9b..88dc894ecc 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -115,7 +115,7 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx:
&SessionContext) -> Result<Box
let plan_rels = vec![PlanRel {
rel_type: Some(plan_rel::RelType::Root(RelRoot {
input: Some(*to_substrait_rel(plan, ctx, &mut extension_info)?),
- names: plan.schema().field_names(),
+ names: to_substrait_named_struct(plan.schema())?.names,
})),
}];
diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs
b/datafusion/substrait/tests/cases/consumer_integration.rs
index c2ae569113..e0151ecc3a 100644
--- a/datafusion/substrait/tests/cases/consumer_integration.rs
+++ b/datafusion/substrait/tests/cases/consumer_integration.rs
@@ -43,14 +43,15 @@ mod tests {
let plan = from_substrait_plan(&ctx, &proto).await?;
- assert!(
- format!("{:?}", plan).eq_ignore_ascii_case(
- "Sort: FILENAME_PLACEHOLDER_0.l_returnflag ASC NULLS LAST,
FILENAME_PLACEHOLDER_0.l_linestatus ASC NULLS LAST\n \
- Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.l_returnflag,
FILENAME_PLACEHOLDER_0.l_linestatus]],
aggr=[[SUM(FILENAME_PLACEHOLDER_0.l_quantity),
SUM(FILENAME_PLACEHOLDER_0.l_extendedprice),
SUM(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) -
FILENAME_PLACEHOLDER_0.l_discount), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice
* Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) +
FILENAME_PLACEHOLDER_0.l_tax), AVG(FILENAME_PLACEHOLDER_0.l_quantity),
AVG(FILENAME_PLACEHO [...]
- Projection: FILENAME_PLACEHOLDER_0.l_returnflag,
FILENAME_PLACEHOLDER_0.l_linestatus, FILENAME_PLACEHOLDER_0.l_quantity,
FILENAME_PLACEHOLDER_0.l_extendedprice, FILENAME_PLACEHOLDER_0.l_extendedprice
* (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount),
FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) -
FILENAME_PLACEHOLDER_0.l_discount) * (CAST(Int32(1) AS Decimal128(19, 0)) +
FILENAME_PLACEHOLDER_0.l_tax), FILENAME_PL [...]
- Filter: FILENAME_PLACEHOLDER_0.l_shipdate <=
Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 120,
milliseconds: 0 }\")\n \
- TableScan: FILENAME_PLACEHOLDER_0 projection=[l_orderkey,
l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount,
l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate,
l_shipinstruct, l_shipmode, l_comment]"
- )
+ let plan_str = format!("{:?}", plan);
+ assert_eq!(
+ plan_str,
+ "Projection: FILENAME_PLACEHOLDER_0.l_returnflag AS L_RETURNFLAG,
FILENAME_PLACEHOLDER_0.l_linestatus AS L_LINESTATUS,
sum(FILENAME_PLACEHOLDER_0.l_quantity) AS SUM_QTY,
sum(FILENAME_PLACEHOLDER_0.l_extendedprice) AS SUM_BASE_PRICE,
sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) -
FILENAME_PLACEHOLDER_0.l_discount) AS SUM_DISC_PRICE,
sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) -
FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax) AS
S [...]
+ \n Sort: FILENAME_PLACEHOLDER_0.l_returnflag ASC NULLS LAST,
FILENAME_PLACEHOLDER_0.l_linestatus ASC NULLS LAST\
+ \n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.l_returnflag,
FILENAME_PLACEHOLDER_0.l_linestatus]],
aggr=[[sum(FILENAME_PLACEHOLDER_0.l_quantity),
sum(FILENAME_PLACEHOLDER_0.l_extendedprice),
sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) -
FILENAME_PLACEHOLDER_0.l_discount), sum(FILENAME_PLACEHOLDER_0.l_extendedprice
* Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) +
FILENAME_PLACEHOLDER_0.l_tax), AVG(FILENAME_PLACEHOLDER_0.l_quantity),
AVG(FILENAME_PLACE [...]
+ \n Projection: FILENAME_PLACEHOLDER_0.l_returnflag,
FILENAME_PLACEHOLDER_0.l_linestatus, FILENAME_PLACEHOLDER_0.l_quantity,
FILENAME_PLACEHOLDER_0.l_extendedprice, FILENAME_PLACEHOLDER_0.l_extendedprice
* (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount),
FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) -
FILENAME_PLACEHOLDER_0.l_discount) * (CAST(Int32(1) AS Decimal128(19, 0)) +
FILENAME_PLACEHOLDER_0.l_tax), FILENAM [...]
+ \n Filter: FILENAME_PLACEHOLDER_0.l_shipdate <=
Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 120,
milliseconds: 0 }\")\
+ \n TableScan: FILENAME_PLACEHOLDER_0
projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity,
l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate,
l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]"
);
Ok(())
}
diff --git a/datafusion/substrait/tests/cases/logical_plans.rs
b/datafusion/substrait/tests/cases/logical_plans.rs
index 4d485b7f12..994a932c30 100644
--- a/datafusion/substrait/tests/cases/logical_plans.rs
+++ b/datafusion/substrait/tests/cases/logical_plans.rs
@@ -48,7 +48,7 @@ mod tests {
assert_eq!(
format!("{:?}", plan),
- "Projection: NOT DATA.a\
+ "Projection: NOT DATA.a AS EXPR$0\
\n TableScan: DATA projection=[a, b, c, d, e, f]"
);
Ok(())
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index ae148119ad..4e4fa45a15 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -162,6 +162,11 @@ async fn wildcard_select() -> Result<()> {
roundtrip("SELECT * FROM data").await
}
+#[tokio::test]
+async fn select_with_alias() -> Result<()> {
+ roundtrip("SELECT a AS aliased_a FROM data").await
+}
+
#[tokio::test]
async fn select_with_filter() -> Result<()> {
roundtrip("SELECT * FROM data WHERE a > 1").await
@@ -367,9 +372,9 @@ async fn implicit_cast() -> Result<()> {
async fn aggregate_case() -> Result<()> {
assert_expected_plan(
"SELECT sum(CASE WHEN a > 0 THEN 1 ELSE NULL END) FROM data",
- "Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN
Int64(1) ELSE Int64(NULL) END)]]\
+ "Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN
Int64(1) ELSE Int64(NULL) END) AS sum(CASE WHEN data.a > Int64(0) THEN Int64(1)
ELSE NULL END)]]\
\n TableScan: data projection=[a]",
- false // NULL vs Int64(NULL)
+ true
)
.await
}
@@ -589,32 +594,23 @@ async fn roundtrip_union_all() -> Result<()> {
#[tokio::test]
async fn simple_intersect() -> Result<()> {
+ // Substrait treats both COUNT(*) and COUNT(1) the same
assert_expected_plan(
"SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT
data2.a FROM data2);",
- "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\
+ "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]]\
\n Projection: \
\n LeftSemi Join: data.a = data2.a\
\n Aggregate: groupBy=[[data.a]], aggr=[[]]\
\n TableScan: data projection=[a]\
\n TableScan: data2 projection=[a]",
- false // COUNT(*) vs COUNT(Int64(1))
+ true
)
.await
}
#[tokio::test]
async fn simple_intersect_table_reuse() -> Result<()> {
- assert_expected_plan(
- "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data.a
FROM data);",
- "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\
- \n Projection: \
- \n LeftSemi Join: data.a = data.a\
- \n Aggregate: groupBy=[[data.a]], aggr=[[]]\
- \n TableScan: data projection=[a]\
- \n TableScan: data projection=[a]",
- false // COUNT(*) vs COUNT(Int64(1))
- )
- .await
+ roundtrip("SELECT COUNT(1) FROM (SELECT data.a FROM data INTERSECT SELECT
data.a FROM data);").await
}
#[tokio::test]
@@ -694,20 +690,14 @@ async fn all_type_literal() -> Result<()> {
#[tokio::test]
async fn roundtrip_literal_list() -> Result<()> {
- assert_expected_plan(
- "SELECT [[1,2,3], [], NULL, [NULL]] FROM data",
- "Projection: List([[1, 2, 3], [], , []])\
- \n TableScan: data projection=[]",
- false, // "List(..)" vs "make_array(..)"
- )
- .await
+ roundtrip("SELECT [[1,2,3], [], NULL, [NULL]] FROM data").await
}
#[tokio::test]
async fn roundtrip_literal_struct() -> Result<()> {
assert_expected_plan(
"SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data",
- "Projection: Struct({c0:1,c1:true,c2:})\
+ "Projection: Struct({c0:1,c1:true,c2:}) AS
struct(Int64(1),Boolean(true),NULL)\
\n TableScan: data projection=[]",
false, // "Struct(..)" vs "struct(..)"
)
@@ -980,12 +970,13 @@ async fn assert_expected_plan(
println!("{proto:?}");
- let plan2str = format!("{plan2:?}");
- assert_eq!(expected_plan_str, &plan2str);
-
if assert_schema {
assert_eq!(plan.schema(), plan2.schema());
}
+
+ let plan2str = format!("{plan2:?}");
+ assert_eq!(expected_plan_str, &plan2str);
+
Ok(())
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]