This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new b2a31b96ce Add support for Substrait List/EmptyList literals (#10615)
b2a31b96ce is described below
commit b2a31b96ce28135cb0024aaa6c27ebf5a21cb95c
Author: Arttu <[email protected]>
AuthorDate: Wed May 22 19:49:21 2024 +0200
Add support for Substrait List/EmptyList literals (#10615)
* Add support for Substrait List/EmptyList literals
Adds support for converting from DataFusion List/LargeList ScalarValues
into Substrait List/EmptyList Literals and back
* cleanup
* fix test, add literal roundtrip tests for lists, and fix creating null
large lists
* add unit testing for type roundtrips
* fix clippy
* better error if a substrait literal list is empty
---
datafusion/substrait/src/logical_plan/consumer.rs | 63 ++++++++-
datafusion/substrait/src/logical_plan/producer.rs | 145 ++++++++++++++++++---
.../tests/cases/roundtrip_logical_plan.rs | 10 ++
3 files changed, 197 insertions(+), 21 deletions(-)
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index e164791106..5a71ab91db 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -61,6 +61,7 @@ use substrait::proto::{
};
use substrait::proto::{FunctionArgument, SortField};
+use datafusion::arrow::array::GenericListArray;
use datafusion::common::plan_err;
use datafusion::logical_expr::expr::{InList, InSubquery, Sort};
use std::collections::HashMap;
@@ -1058,7 +1059,7 @@ pub async fn from_substrait_rex(
}
}
-fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataType> {
+pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataType> {
match &dt.kind {
Some(s_kind) => match s_kind {
r#type::Kind::Bool(_) => Ok(DataType::Boolean),
@@ -1138,7 +1139,7 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataType> {
from_substrait_type(list.r#type.as_ref().ok_or_else(|| {
substrait_datafusion_err!("List type must have inner type")
})?)?;
- let field = Arc::new(Field::new("list_item", inner_type, true));
+ let field = Arc::new(Field::new_list_field(inner_type, true));
match list.type_variation_reference {
DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::List(field)),
LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeList(field)),
@@ -1278,6 +1279,45 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
s,
)
}
+ Some(LiteralType::List(l)) => {
+ let elements = l
+ .values
+ .iter()
+ .map(from_substrait_literal)
+ .collect::<Result<Vec<_>>>()?;
+ if elements.is_empty() {
+ return substrait_err!(
+ "Empty list must be encoded as EmptyList literal type, not List"
+ );
+ }
+ let element_type = elements[0].data_type();
+ match lit.type_variation_reference {
+ DEFAULT_CONTAINER_TYPE_REF => ScalarValue::List(ScalarValue::new_list(
+ elements.as_slice(),
+ &element_type,
+ )),
+ LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeList(
+ ScalarValue::new_large_list(elements.as_slice(), &element_type),
+ ),
+ others => {
+ return substrait_err!("Unknown type variation reference {others}");
+ }
+ }
+ }
+ Some(LiteralType::EmptyList(l)) => {
+ let element_type = from_substrait_type(l.r#type.clone().unwrap().as_ref())?;
+ match lit.type_variation_reference {
+ DEFAULT_CONTAINER_TYPE_REF => {
+ ScalarValue::List(ScalarValue::new_list(&[], &element_type))
+ }
+ LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeList(
+ ScalarValue::new_large_list(&[], &element_type),
+ ),
+ others => {
+ return substrait_err!("Unknown type variation reference {others}");
+ }
+ }
+ }
Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?,
_ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type),
};
@@ -1361,7 +1401,24 @@ fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
d.precision as u8,
d.scale as i8,
)),
- _ => not_impl_err!("Unsupported Substrait type: {kind:?}"),
+ r#type::Kind::List(l) => {
+ let field = Field::new_list_field(
+ from_substrait_type(l.r#type.clone().unwrap().as_ref())?,
+ true,
+ );
+ match l.type_variation_reference {
+ DEFAULT_CONTAINER_TYPE_REF => Ok(ScalarValue::List(Arc::new(
+ GenericListArray::new_null(field.into(), 1),
+ ))),
+ LARGE_CONTAINER_TYPE_REF => Ok(ScalarValue::LargeList(Arc::new(
+ GenericListArray::new_null(field.into(), 1),
+ ))),
+ v => not_impl_err!(
+ "Unsupported Substrait type variation {v} of type {kind:?}"
+ ),
+ }
+ }
+ _ => not_impl_err!("Unsupported Substrait type for null: {kind:?}"),
}
} else {
not_impl_err!("Null type without kind is not supported")
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 6f0738c38d..bfdffdc3a2 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -30,6 +30,7 @@ use datafusion::{
scalar::ScalarValue,
};
+use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait};
use datafusion::common::{exec_err, internal_err, not_impl_err};
use datafusion::common::{substrait_err, DFSchemaRef};
#[allow(unused_imports)]
@@ -42,6 +43,7 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Opera
use datafusion::prelude::Expr;
use prost_types::Any as ProtoAny;
use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
+use substrait::proto::expression::literal::List;
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
use substrait::proto::{CrossRel, ExchangeRel};
@@ -1100,7 +1102,7 @@ pub fn to_substrait_rex(
))),
})
}
- Expr::Literal(value) => to_substrait_literal(value),
+ Expr::Literal(value) => to_substrait_literal_expr(value),
Expr::Alias(Alias { expr, .. }) => {
to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)
}
@@ -1526,8 +1528,9 @@ fn make_substrait_like_expr(
};
let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?;
let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?;
- let escape_char =
- to_substrait_literal(&ScalarValue::Utf8(escape_char.map(|c| c.to_string())))?;
+ let escape_char = to_substrait_literal_expr(&ScalarValue::Utf8(
+ escape_char.map(|c| c.to_string()),
+ ))?;
let arguments = vec![
FunctionArgument {
arg_type: Some(ArgType::Value(expr)),
@@ -1683,7 +1686,7 @@ fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> {
))
}
-fn to_substrait_literal(value: &ScalarValue) -> Result<Expression> {
+fn to_substrait_literal(value: &ScalarValue) -> Result<Literal> {
let (literal_type, type_variation_reference) = match value {
ScalarValue::Boolean(Some(b)) => (LiteralType::Boolean(*b), DEFAULT_TYPE_REF),
ScalarValue::Int8(Some(n)) => (LiteralType::I8(*n as i32), DEFAULT_TYPE_REF),
@@ -1741,15 +1744,50 @@ fn to_substrait_literal(value: &ScalarValue) -> Result<Expression> {
}),
DECIMAL_128_TYPE_REF,
),
+ ScalarValue::List(l) if !value.is_null() => (
+ convert_array_to_literal_list(l)?,
+ DEFAULT_CONTAINER_TYPE_REF,
+ ),
+ ScalarValue::LargeList(l) if !value.is_null() => {
+ (convert_array_to_literal_list(l)?, LARGE_CONTAINER_TYPE_REF)
+ }
_ => (try_to_substrait_null(value)?, DEFAULT_TYPE_REF),
};
+ Ok(Literal {
+ nullable: true,
+ type_variation_reference,
+ literal_type: Some(literal_type),
+ })
+}
+
+fn convert_array_to_literal_list<T: OffsetSizeTrait>(
+ array: &GenericListArray<T>,
+) -> Result<LiteralType> {
+ assert_eq!(array.len(), 1);
+ let nested_array = array.value(0);
+
+ let values = (0..nested_array.len())
+ .map(|i| to_substrait_literal(&ScalarValue::try_from_array(&nested_array, i)?))
+ .collect::<Result<Vec<_>>>()?;
+
+ if values.is_empty() {
+ let et = match to_substrait_type(array.data_type())? {
+ substrait::proto::Type {
+ kind: Some(r#type::Kind::List(lt)),
+ } => lt.as_ref().to_owned(),
+ _ => unreachable!(),
+ };
+ Ok(LiteralType::EmptyList(et))
+ } else {
+ Ok(LiteralType::List(List { values }))
+ }
+}
+
+fn to_substrait_literal_expr(value: &ScalarValue) -> Result<Expression> {
+ let literal = to_substrait_literal(value)?;
Ok(Expression {
- rex_type: Some(RexType::Literal(Literal {
- nullable: true,
- type_variation_reference,
- literal_type: Some(literal_type),
- })),
+ rex_type: Some(RexType::Literal(literal)),
})
}
@@ -1937,6 +1975,10 @@ fn try_to_substrait_null(v: &ScalarValue) -> Result<LiteralType> {
})),
}))
}
+ ScalarValue::List(l) => Ok(LiteralType::Null(to_substrait_type(l.data_type())?)),
+ ScalarValue::LargeList(l) => {
+ Ok(LiteralType::Null(to_substrait_type(l.data_type())?))
+ }
// TODO: Extend support for remaining data types
_ => not_impl_err!("Unsupported literal: {v:?}"),
}
@@ -2016,7 +2058,9 @@ fn substrait_field_ref(index: usize) -> Result<Expression> {
#[cfg(test)]
mod test {
- use crate::logical_plan::consumer::from_substrait_literal;
+ use crate::logical_plan::consumer::{from_substrait_literal, from_substrait_type};
+ use datafusion::arrow::array::GenericListArray;
+ use datafusion::arrow::datatypes::Field;
use super::*;
@@ -2054,22 +2098,87 @@ mod test {
round_trip_literal(ScalarValue::UInt64(Some(u64::MIN)))?;
round_trip_literal(ScalarValue::UInt64(Some(u64::MAX)))?;
+ round_trip_literal(ScalarValue::List(ScalarValue::new_list(
+ &[ScalarValue::Float32(Some(1.0))],
+ &DataType::Float32,
+ )))?;
+ round_trip_literal(ScalarValue::List(ScalarValue::new_list(
+ &[],
+ &DataType::Float32,
+ )))?;
+ round_trip_literal(ScalarValue::List(Arc::new(GenericListArray::new_null(
+ Field::new_list_field(DataType::Float32, true).into(),
+ 1,
+ ))))?;
+ round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list(
+ &[ScalarValue::Float32(Some(1.0))],
+ &DataType::Float32,
+ )))?;
+ round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list(
+ &[],
+ &DataType::Float32,
+ )))?;
+ round_trip_literal(ScalarValue::LargeList(Arc::new(
+ GenericListArray::new_null(
+ Field::new_list_field(DataType::Float32, true).into(),
+ 1,
+ ),
+ )))?;
+
Ok(())
}
fn round_trip_literal(scalar: ScalarValue) -> Result<()> {
println!("Checking round trip of {scalar:?}");
- let substrait = to_substrait_literal(&scalar)?;
- let Expression {
- rex_type: Some(RexType::Literal(substrait_literal)),
- } = substrait
- else {
- panic!("Expected Literal expression, got {substrait:?}");
- };
-
+ let substrait_literal = to_substrait_literal(&scalar)?;
let roundtrip_scalar = from_substrait_literal(&substrait_literal)?;
assert_eq!(scalar, roundtrip_scalar);
Ok(())
}
+
+ #[test]
+ fn round_trip_types() -> Result<()> {
+ round_trip_type(DataType::Boolean)?;
+ round_trip_type(DataType::Int8)?;
+ round_trip_type(DataType::UInt8)?;
+ round_trip_type(DataType::Int16)?;
+ round_trip_type(DataType::UInt16)?;
+ round_trip_type(DataType::Int32)?;
+ round_trip_type(DataType::UInt32)?;
+ round_trip_type(DataType::Int64)?;
+ round_trip_type(DataType::UInt64)?;
+ round_trip_type(DataType::Float32)?;
+ round_trip_type(DataType::Float64)?;
+ round_trip_type(DataType::Timestamp(TimeUnit::Second, None))?;
+ round_trip_type(DataType::Timestamp(TimeUnit::Millisecond, None))?;
+ round_trip_type(DataType::Timestamp(TimeUnit::Microsecond, None))?;
+ round_trip_type(DataType::Timestamp(TimeUnit::Nanosecond, None))?;
+ round_trip_type(DataType::Date32)?;
+ round_trip_type(DataType::Date64)?;
+ round_trip_type(DataType::Binary)?;
+ round_trip_type(DataType::FixedSizeBinary(10))?;
+ round_trip_type(DataType::LargeBinary)?;
+ round_trip_type(DataType::Utf8)?;
+ round_trip_type(DataType::LargeUtf8)?;
+ round_trip_type(DataType::Decimal128(10, 2))?;
+ round_trip_type(DataType::Decimal256(30, 2))?;
+ round_trip_type(DataType::List(
+ Field::new_list_field(DataType::Int32, true).into(),
+ ))?;
+ round_trip_type(DataType::LargeList(
+ Field::new_list_field(DataType::Int32, true).into(),
+ ))?;
+
+ Ok(())
+ }
+
+ fn round_trip_type(dt: DataType) -> Result<()> {
+ println!("Checking round trip of {dt:?}");
+
+ let substrait = to_substrait_type(&dt)?;
+ let roundtrip_dt = from_substrait_type(&substrait)?;
+ assert_eq!(dt, roundtrip_dt);
+ Ok(())
+ }
}
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 4c7dc87145..02371063ef 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -665,6 +665,16 @@ async fn all_type_literal() -> Result<()> {
.await
}
+#[tokio::test]
+async fn roundtrip_literal_list() -> Result<()> {
+ assert_expected_plan(
+ "SELECT [[1,2,3], [], NULL, [NULL]] FROM data",
+ "Projection: List([[1, 2, 3], [], , []])\
+ \n TableScan: data projection=[]",
+ )
+ .await
+}
+
/// Construct a plan that cast columns. Only those SQL types are supported for now.
#[tokio::test]
async fn new_test_grammar() -> Result<()> {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]