This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new b2a31b96ce Add support for Substrait List/EmptyList literals (#10615)
b2a31b96ce is described below

commit b2a31b96ce28135cb0024aaa6c27ebf5a21cb95c
Author: Arttu <[email protected]>
AuthorDate: Wed May 22 19:49:21 2024 +0200

    Add support for Substrait List/EmptyList literals (#10615)
    
    * Add support for Substrait List/EmptyList literals
    
    Adds support for converting from DataFusion List/LargeList ScalarValues
    into Substrait List/EmptyList Literals and back
    
    * cleanup
    
    * fix test, add literal roundtrip tests for lists, and fix creating null
      large lists
    
    * add unit testing for type roundtrips
    
    * fix clippy
    
    * better error if a substrait literal list is empty
---
 datafusion/substrait/src/logical_plan/consumer.rs  |  63 ++++++++-
 datafusion/substrait/src/logical_plan/producer.rs  | 145 ++++++++++++++++++---
 .../tests/cases/roundtrip_logical_plan.rs          |  10 ++
 3 files changed, 197 insertions(+), 21 deletions(-)

diff --git a/datafusion/substrait/src/logical_plan/consumer.rs 
b/datafusion/substrait/src/logical_plan/consumer.rs
index e164791106..5a71ab91db 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -61,6 +61,7 @@ use substrait::proto::{
 };
 use substrait::proto::{FunctionArgument, SortField};
 
+use datafusion::arrow::array::GenericListArray;
 use datafusion::common::plan_err;
 use datafusion::logical_expr::expr::{InList, InSubquery, Sort};
 use std::collections::HashMap;
@@ -1058,7 +1059,7 @@ pub async fn from_substrait_rex(
     }
 }
 
-fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataType> {
+pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> 
Result<DataType> {
     match &dt.kind {
         Some(s_kind) => match s_kind {
             r#type::Kind::Bool(_) => Ok(DataType::Boolean),
@@ -1138,7 +1139,7 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> 
Result<DataType> {
                     from_substrait_type(list.r#type.as_ref().ok_or_else(|| {
                         substrait_datafusion_err!("List type must have inner 
type")
                     })?)?;
-                let field = Arc::new(Field::new("list_item", inner_type, 
true));
+                let field = Arc::new(Field::new_list_field(inner_type, true));
                 match list.type_variation_reference {
                     DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::List(field)),
                     LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeList(field)),
@@ -1278,6 +1279,45 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> 
Result<ScalarValue> {
                 s,
             )
         }
+        Some(LiteralType::List(l)) => {
+            let elements = l
+                .values
+                .iter()
+                .map(from_substrait_literal)
+                .collect::<Result<Vec<_>>>()?;
+            if elements.is_empty() {
+                return substrait_err!(
+                    "Empty list must be encoded as EmptyList literal type, not 
List"
+                );
+            }
+            let element_type = elements[0].data_type();
+            match lit.type_variation_reference {
+                DEFAULT_CONTAINER_TYPE_REF => 
ScalarValue::List(ScalarValue::new_list(
+                    elements.as_slice(),
+                    &element_type,
+                )),
+                LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeList(
+                    ScalarValue::new_large_list(elements.as_slice(), 
&element_type),
+                ),
+                others => {
+                    return substrait_err!("Unknown type variation reference 
{others}");
+                }
+            }
+        }
+        Some(LiteralType::EmptyList(l)) => {
+            let element_type = 
from_substrait_type(l.r#type.clone().unwrap().as_ref())?;
+            match lit.type_variation_reference {
+                DEFAULT_CONTAINER_TYPE_REF => {
+                    ScalarValue::List(ScalarValue::new_list(&[], 
&element_type))
+                }
+                LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeList(
+                    ScalarValue::new_large_list(&[], &element_type),
+                ),
+                others => {
+                    return substrait_err!("Unknown type variation reference 
{others}");
+                }
+            }
+        }
         Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?,
         _ => return not_impl_err!("Unsupported literal_type: {:?}", 
lit.literal_type),
     };
@@ -1361,7 +1401,24 @@ fn from_substrait_null(null_type: &Type) -> 
Result<ScalarValue> {
                 d.precision as u8,
                 d.scale as i8,
             )),
-            _ => not_impl_err!("Unsupported Substrait type: {kind:?}"),
+            r#type::Kind::List(l) => {
+                let field = Field::new_list_field(
+                    from_substrait_type(l.r#type.clone().unwrap().as_ref())?,
+                    true,
+                );
+                match l.type_variation_reference {
+                    DEFAULT_CONTAINER_TYPE_REF => 
Ok(ScalarValue::List(Arc::new(
+                        GenericListArray::new_null(field.into(), 1),
+                    ))),
+                    LARGE_CONTAINER_TYPE_REF => 
Ok(ScalarValue::LargeList(Arc::new(
+                        GenericListArray::new_null(field.into(), 1),
+                    ))),
+                    v => not_impl_err!(
+                        "Unsupported Substrait type variation {v} of type 
{kind:?}"
+                    ),
+                }
+            }
+            _ => not_impl_err!("Unsupported Substrait type for null: 
{kind:?}"),
         }
     } else {
         not_impl_err!("Null type without kind is not supported")
diff --git a/datafusion/substrait/src/logical_plan/producer.rs 
b/datafusion/substrait/src/logical_plan/producer.rs
index 6f0738c38d..bfdffdc3a2 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -30,6 +30,7 @@ use datafusion::{
     scalar::ScalarValue,
 };
 
+use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait};
 use datafusion::common::{exec_err, internal_err, not_impl_err};
 use datafusion::common::{substrait_err, DFSchemaRef};
 #[allow(unused_imports)]
@@ -42,6 +43,7 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, 
LogicalPlan, Opera
 use datafusion::prelude::Expr;
 use prost_types::Any as ProtoAny;
 use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
+use substrait::proto::expression::literal::List;
 use substrait::proto::expression::subquery::InPredicate;
 use substrait::proto::expression::window_function::BoundsType;
 use substrait::proto::{CrossRel, ExchangeRel};
@@ -1100,7 +1102,7 @@ pub fn to_substrait_rex(
                 ))),
             })
         }
-        Expr::Literal(value) => to_substrait_literal(value),
+        Expr::Literal(value) => to_substrait_literal_expr(value),
         Expr::Alias(Alias { expr, .. }) => {
             to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)
         }
@@ -1526,8 +1528,9 @@ fn make_substrait_like_expr(
     };
     let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, 
extension_info)?;
     let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, 
extension_info)?;
-    let escape_char =
-        to_substrait_literal(&ScalarValue::Utf8(escape_char.map(|c| 
c.to_string())))?;
+    let escape_char = to_substrait_literal_expr(&ScalarValue::Utf8(
+        escape_char.map(|c| c.to_string()),
+    ))?;
     let arguments = vec![
         FunctionArgument {
             arg_type: Some(ArgType::Value(expr)),
@@ -1683,7 +1686,7 @@ fn to_substrait_bounds(window_frame: &WindowFrame) -> 
Result<(Bound, Bound)> {
     ))
 }
 
-fn to_substrait_literal(value: &ScalarValue) -> Result<Expression> {
+fn to_substrait_literal(value: &ScalarValue) -> Result<Literal> {
     let (literal_type, type_variation_reference) = match value {
         ScalarValue::Boolean(Some(b)) => (LiteralType::Boolean(*b), 
DEFAULT_TYPE_REF),
         ScalarValue::Int8(Some(n)) => (LiteralType::I8(*n as i32), 
DEFAULT_TYPE_REF),
@@ -1741,15 +1744,50 @@ fn to_substrait_literal(value: &ScalarValue) -> 
Result<Expression> {
             }),
             DECIMAL_128_TYPE_REF,
         ),
+        ScalarValue::List(l) if !value.is_null() => (
+            convert_array_to_literal_list(l)?,
+            DEFAULT_CONTAINER_TYPE_REF,
+        ),
+        ScalarValue::LargeList(l) if !value.is_null() => {
+            (convert_array_to_literal_list(l)?, LARGE_CONTAINER_TYPE_REF)
+        }
         _ => (try_to_substrait_null(value)?, DEFAULT_TYPE_REF),
     };
 
+    Ok(Literal {
+        nullable: true,
+        type_variation_reference,
+        literal_type: Some(literal_type),
+    })
+}
+
+fn convert_array_to_literal_list<T: OffsetSizeTrait>(
+    array: &GenericListArray<T>,
+) -> Result<LiteralType> {
+    assert_eq!(array.len(), 1);
+    let nested_array = array.value(0);
+
+    let values = (0..nested_array.len())
+        .map(|i| 
to_substrait_literal(&ScalarValue::try_from_array(&nested_array, i)?))
+        .collect::<Result<Vec<_>>>()?;
+
+    if values.is_empty() {
+        let et = match to_substrait_type(array.data_type())? {
+            substrait::proto::Type {
+                kind: Some(r#type::Kind::List(lt)),
+            } => lt.as_ref().to_owned(),
+            _ => unreachable!(),
+        };
+        Ok(LiteralType::EmptyList(et))
+    } else {
+        Ok(LiteralType::List(List { values }))
+    }
+}
+
+fn to_substrait_literal_expr(value: &ScalarValue) -> Result<Expression> {
+    let literal = to_substrait_literal(value)?;
     Ok(Expression {
-        rex_type: Some(RexType::Literal(Literal {
-            nullable: true,
-            type_variation_reference,
-            literal_type: Some(literal_type),
-        })),
+        rex_type: Some(RexType::Literal(literal)),
     })
 }
 
@@ -1937,6 +1975,10 @@ fn try_to_substrait_null(v: &ScalarValue) -> 
Result<LiteralType> {
                 })),
             }))
         }
+        ScalarValue::List(l) => 
Ok(LiteralType::Null(to_substrait_type(l.data_type())?)),
+        ScalarValue::LargeList(l) => {
+            Ok(LiteralType::Null(to_substrait_type(l.data_type())?))
+        }
         // TODO: Extend support for remaining data types
         _ => not_impl_err!("Unsupported literal: {v:?}"),
     }
@@ -2016,7 +2058,9 @@ fn substrait_field_ref(index: usize) -> 
Result<Expression> {
 
 #[cfg(test)]
 mod test {
-    use crate::logical_plan::consumer::from_substrait_literal;
+    use crate::logical_plan::consumer::{from_substrait_literal, 
from_substrait_type};
+    use datafusion::arrow::array::GenericListArray;
+    use datafusion::arrow::datatypes::Field;
 
     use super::*;
 
@@ -2054,22 +2098,87 @@ mod test {
         round_trip_literal(ScalarValue::UInt64(Some(u64::MIN)))?;
         round_trip_literal(ScalarValue::UInt64(Some(u64::MAX)))?;
 
+        round_trip_literal(ScalarValue::List(ScalarValue::new_list(
+            &[ScalarValue::Float32(Some(1.0))],
+            &DataType::Float32,
+        )))?;
+        round_trip_literal(ScalarValue::List(ScalarValue::new_list(
+            &[],
+            &DataType::Float32,
+        )))?;
+        
round_trip_literal(ScalarValue::List(Arc::new(GenericListArray::new_null(
+            Field::new_list_field(DataType::Float32, true).into(),
+            1,
+        ))))?;
+        round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list(
+            &[ScalarValue::Float32(Some(1.0))],
+            &DataType::Float32,
+        )))?;
+        round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list(
+            &[],
+            &DataType::Float32,
+        )))?;
+        round_trip_literal(ScalarValue::LargeList(Arc::new(
+            GenericListArray::new_null(
+                Field::new_list_field(DataType::Float32, true).into(),
+                1,
+            ),
+        )))?;
+
         Ok(())
     }
 
     fn round_trip_literal(scalar: ScalarValue) -> Result<()> {
         println!("Checking round trip of {scalar:?}");
 
-        let substrait = to_substrait_literal(&scalar)?;
-        let Expression {
-            rex_type: Some(RexType::Literal(substrait_literal)),
-        } = substrait
-        else {
-            panic!("Expected Literal expression, got {substrait:?}");
-        };
-
+        let substrait_literal = to_substrait_literal(&scalar)?;
         let roundtrip_scalar = from_substrait_literal(&substrait_literal)?;
         assert_eq!(scalar, roundtrip_scalar);
         Ok(())
     }
+
+    #[test]
+    fn round_trip_types() -> Result<()> {
+        round_trip_type(DataType::Boolean)?;
+        round_trip_type(DataType::Int8)?;
+        round_trip_type(DataType::UInt8)?;
+        round_trip_type(DataType::Int16)?;
+        round_trip_type(DataType::UInt16)?;
+        round_trip_type(DataType::Int32)?;
+        round_trip_type(DataType::UInt32)?;
+        round_trip_type(DataType::Int64)?;
+        round_trip_type(DataType::UInt64)?;
+        round_trip_type(DataType::Float32)?;
+        round_trip_type(DataType::Float64)?;
+        round_trip_type(DataType::Timestamp(TimeUnit::Second, None))?;
+        round_trip_type(DataType::Timestamp(TimeUnit::Millisecond, None))?;
+        round_trip_type(DataType::Timestamp(TimeUnit::Microsecond, None))?;
+        round_trip_type(DataType::Timestamp(TimeUnit::Nanosecond, None))?;
+        round_trip_type(DataType::Date32)?;
+        round_trip_type(DataType::Date64)?;
+        round_trip_type(DataType::Binary)?;
+        round_trip_type(DataType::FixedSizeBinary(10))?;
+        round_trip_type(DataType::LargeBinary)?;
+        round_trip_type(DataType::Utf8)?;
+        round_trip_type(DataType::LargeUtf8)?;
+        round_trip_type(DataType::Decimal128(10, 2))?;
+        round_trip_type(DataType::Decimal256(30, 2))?;
+        round_trip_type(DataType::List(
+            Field::new_list_field(DataType::Int32, true).into(),
+        ))?;
+        round_trip_type(DataType::LargeList(
+            Field::new_list_field(DataType::Int32, true).into(),
+        ))?;
+
+        Ok(())
+    }
+
+    fn round_trip_type(dt: DataType) -> Result<()> {
+        println!("Checking round trip of {dt:?}");
+
+        let substrait = to_substrait_type(&dt)?;
+        let roundtrip_dt = from_substrait_type(&substrait)?;
+        assert_eq!(dt, roundtrip_dt);
+        Ok(())
+    }
 }
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs 
b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 4c7dc87145..02371063ef 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -665,6 +665,16 @@ async fn all_type_literal() -> Result<()> {
         .await
 }
 
+#[tokio::test]
+async fn roundtrip_literal_list() -> Result<()> {
+    assert_expected_plan(
+        "SELECT [[1,2,3], [], NULL, [NULL]] FROM data",
+        "Projection: List([[1, 2, 3], [], , []])\
+        \n  TableScan: data projection=[]",
+    )
+    .await
+}
+
 /// Construct a plan that cast columns. Only those SQL types are supported for 
now.
 #[tokio::test]
 async fn new_test_grammar() -> Result<()> {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to