This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 19d9174150 Add support for Substrait Struct literals and type (#10622)
19d9174150 is described below
commit 19d91741509ea975f32f99b17d3e0595a85c0e09
Author: Arttu <[email protected]>
AuthorDate: Thu May 23 19:33:08 2024 +0200
Add support for Substrait Struct literals and type (#10622)
* Add support for (un-named) Substrait Struct literal
Adds support for converting from DataFusion Struct ScalarValues into
Substrait Struct Literals and back.
All structs are assumed to be unnamed, i.e. their fields are renamed
to "c0", "c1", etc.
* add converting from Substrait Struct type
* cargo fmt --all
* Unit test for NULL inside Struct
* retry ci
---
datafusion/substrait/src/logical_plan/consumer.rs | 22 +++++++++++++
datafusion/substrait/src/logical_plan/producer.rs | 36 +++++++++++++++++++++-
.../tests/cases/roundtrip_logical_plan.rs | 10 ++++++
3 files changed, 67 insertions(+), 1 deletion(-)
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index 5a71ab91db..a08485fd35 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -63,6 +63,7 @@ use substrait::proto::{FunctionArgument, SortField};
use datafusion::arrow::array::GenericListArray;
use datafusion::common::plan_err;
+use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::logical_expr::expr::{InList, InSubquery, Sort};
use std::collections::HashMap;
use std::str::FromStr;
@@ -1159,6 +1160,15 @@ pub(crate) fn from_substrait_type(dt:
&substrait::proto::Type) -> Result<DataTyp
"Unsupported Substrait type variation {v} of type
{s_kind:?}"
),
},
+ r#type::Kind::Struct(s) => {
+ let mut fields = vec![];
+ for (i, f) in s.types.iter().enumerate() {
+ let field =
+ Field::new(&format!("c{i}"), from_substrait_type(f)?,
true);
+ fields.push(field);
+ }
+ Ok(DataType::Struct(fields.into()))
+ }
_ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"),
},
_ => not_impl_err!("`None` Substrait kind is not supported"),
@@ -1318,6 +1328,18 @@ pub(crate) fn from_substrait_literal(lit: &Literal) ->
Result<ScalarValue> {
}
}
}
+ Some(LiteralType::Struct(s)) => {
+ let mut builder = ScalarStructBuilder::new();
+ for (i, field) in s.fields.iter().enumerate() {
+ let sv = from_substrait_literal(field)?;
+ // c0, c1, ... align with e.g. SqlToRel::create_named_struct
+ builder = builder.with_scalar(
+ Field::new(&format!("c{i}"), sv.data_type(),
field.nullable),
+ sv,
+ );
+ }
+ builder.build()?
+ }
Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?,
_ => return not_impl_err!("Unsupported literal_type: {:?}",
lit.literal_type),
};
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index bfdffdc3a2..e216008c73 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -43,7 +43,7 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint,
LogicalPlan, Opera
use datafusion::prelude::Expr;
use prost_types::Any as ProtoAny;
use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
-use substrait::proto::expression::literal::List;
+use substrait::proto::expression::literal::{List, Struct};
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
use substrait::proto::{CrossRel, ExchangeRel};
@@ -1751,6 +1751,18 @@ fn to_substrait_literal(value: &ScalarValue) ->
Result<Literal> {
ScalarValue::LargeList(l) if !value.is_null() => {
(convert_array_to_literal_list(l)?, LARGE_CONTAINER_TYPE_REF)
}
+ ScalarValue::Struct(s) if !value.is_null() => (
+ LiteralType::Struct(Struct {
+ fields: s
+ .columns()
+ .iter()
+ .map(|col| {
+ to_substrait_literal(&ScalarValue::try_from_array(col,
0)?)
+ })
+ .collect::<Result<Vec<_>>>()?,
+ }),
+ DEFAULT_TYPE_REF,
+ ),
_ => (try_to_substrait_null(value)?, DEFAULT_TYPE_REF),
};
@@ -1979,6 +1991,9 @@ fn try_to_substrait_null(v: &ScalarValue) ->
Result<LiteralType> {
ScalarValue::LargeList(l) => {
Ok(LiteralType::Null(to_substrait_type(l.data_type())?))
}
+ ScalarValue::Struct(s) => {
+ Ok(LiteralType::Null(to_substrait_type(s.data_type())?))
+ }
// TODO: Extend support for remaining data types
_ => not_impl_err!("Unsupported literal: {v:?}"),
}
@@ -2061,6 +2076,7 @@ mod test {
use crate::logical_plan::consumer::{from_substrait_literal,
from_substrait_type};
use datafusion::arrow::array::GenericListArray;
use datafusion::arrow::datatypes::Field;
+ use datafusion::common::scalar::ScalarStructBuilder;
use super::*;
@@ -2125,6 +2141,17 @@ mod test {
),
)))?;
+ let c0 = Field::new("c0", DataType::Boolean, true);
+ let c1 = Field::new("c1", DataType::Int32, true);
+ let c2 = Field::new("c2", DataType::Utf8, true);
+ round_trip_literal(
+ ScalarStructBuilder::new()
+ .with_scalar(c0, ScalarValue::Boolean(Some(true)))
+ .with_scalar(c1, ScalarValue::Int32(Some(1)))
+ .with_scalar(c2, ScalarValue::Utf8(None))
+ .build()?,
+ )?;
+
Ok(())
}
@@ -2169,6 +2196,13 @@ mod test {
round_trip_type(DataType::LargeList(
Field::new_list_field(DataType::Int32, true).into(),
))?;
+ round_trip_type(DataType::Struct(
+ vec![
+ Field::new("c0", DataType::Int32, true),
+ Field::new("c1", DataType::Utf8, true),
+ ]
+ .into(),
+ ))?;
Ok(())
}
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 02371063ef..8d0e96cedd 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -675,6 +675,16 @@ async fn roundtrip_literal_list() -> Result<()> {
.await
}
+#[tokio::test]
+async fn roundtrip_literal_struct() -> Result<()> {
+ assert_expected_plan(
+ "SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data",
+ "Projection: Struct({c0:1,c1:true,c2:})\
+ \n TableScan: data projection=[]",
+ )
+ .await
+}
+
/// Construct a plan that cast columns. Only those SQL types are supported for
now.
#[tokio::test]
async fn new_test_grammar() -> Result<()> {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]