This is an automated email from the ASF dual-hosted git repository.

jonah pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 8bedecc00b More properly handle nullability of types/literals in Substrait (#10640)
8bedecc00b is described below

commit 8bedecc00b2f1f04d7b1a907152ce0d19b7046a5
Author: Arttu <[email protected]>
AuthorDate: Fri May 24 15:05:04 2024 +0200

    More properly handle nullability of types/literals in Substrait (#10640)
    
    * More properly handle nullability of types/literals in Substrait
    
    This isn't perfect; some things are still assumed to just always be nullable (e.g. Literal list elements).
    But it's still giving a closer match than just assuming everything is nullable.
    
    * Avoid cloning and creating DataFusionError
    
    Co-authored-by: Jonah Gao <[email protected]>
    
    * simplify Literal/ScalarValue null handling
    
    ---------
    
    Co-authored-by: Jonah Gao <[email protected]>
---
 datafusion/substrait/src/logical_plan/consumer.rs |  60 ++++-
 datafusion/substrait/src/logical_plan/producer.rs | 282 ++++++----------------
 2 files changed, 122 insertions(+), 220 deletions(-)

diff --git a/datafusion/substrait/src/logical_plan/consumer.rs 
b/datafusion/substrait/src/logical_plan/consumer.rs
index a08485fd35..7e8a0cadb5 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -1136,11 +1136,13 @@ pub(crate) fn from_substrait_type(dt: 
&substrait::proto::Type) -> Result<DataTyp
                 ),
             },
             r#type::Kind::List(list) => {
-                let inner_type =
-                    from_substrait_type(list.r#type.as_ref().ok_or_else(|| {
-                        substrait_datafusion_err!("List type must have inner 
type")
-                    })?)?;
-                let field = Arc::new(Field::new_list_field(inner_type, true));
+                let inner_type = list.r#type.as_ref().ok_or_else(|| {
+                    substrait_datafusion_err!("List type must have inner type")
+                })?;
+                let field = Arc::new(Field::new_list_field(
+                    from_substrait_type(inner_type)?,
+                    is_substrait_type_nullable(inner_type)?,
+                ));
                 match list.type_variation_reference {
                     DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::List(field)),
                     LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeList(field)),
@@ -1163,8 +1165,11 @@ pub(crate) fn from_substrait_type(dt: 
&substrait::proto::Type) -> Result<DataTyp
             r#type::Kind::Struct(s) => {
                 let mut fields = vec![];
                 for (i, f) in s.types.iter().enumerate() {
-                    let field =
-                        Field::new(&format!("c{i}"), from_substrait_type(f)?, 
true);
+                    let field = Field::new(
+                        &format!("c{i}"),
+                        from_substrait_type(f)?,
+                        is_substrait_type_nullable(f)?,
+                    );
                     fields.push(field);
                 }
                 Ok(DataType::Struct(fields.into()))
@@ -1175,6 +1180,47 @@ pub(crate) fn from_substrait_type(dt: 
&substrait::proto::Type) -> Result<DataTyp
     }
 }
 
+fn is_substrait_type_nullable(dtype: &Type) -> Result<bool> {
+    fn is_nullable(nullability: i32) -> bool {
+        nullability != substrait::proto::r#type::Nullability::Required as i32
+    }
+
+    let nullable = match dtype
+        .kind
+        .as_ref()
+        .ok_or_else(|| substrait_datafusion_err!("Type must contain Kind"))?
+    {
+        r#type::Kind::Bool(val) => is_nullable(val.nullability),
+        r#type::Kind::I8(val) => is_nullable(val.nullability),
+        r#type::Kind::I16(val) => is_nullable(val.nullability),
+        r#type::Kind::I32(val) => is_nullable(val.nullability),
+        r#type::Kind::I64(val) => is_nullable(val.nullability),
+        r#type::Kind::Fp32(val) => is_nullable(val.nullability),
+        r#type::Kind::Fp64(val) => is_nullable(val.nullability),
+        r#type::Kind::String(val) => is_nullable(val.nullability),
+        r#type::Kind::Binary(val) => is_nullable(val.nullability),
+        r#type::Kind::Timestamp(val) => is_nullable(val.nullability),
+        r#type::Kind::Date(val) => is_nullable(val.nullability),
+        r#type::Kind::Time(val) => is_nullable(val.nullability),
+        r#type::Kind::IntervalYear(val) => is_nullable(val.nullability),
+        r#type::Kind::IntervalDay(val) => is_nullable(val.nullability),
+        r#type::Kind::TimestampTz(val) => is_nullable(val.nullability),
+        r#type::Kind::Uuid(val) => is_nullable(val.nullability),
+        r#type::Kind::FixedChar(val) => is_nullable(val.nullability),
+        r#type::Kind::Varchar(val) => is_nullable(val.nullability),
+        r#type::Kind::FixedBinary(val) => is_nullable(val.nullability),
+        r#type::Kind::Decimal(val) => is_nullable(val.nullability),
+        r#type::Kind::PrecisionTimestamp(val) => is_nullable(val.nullability),
+        r#type::Kind::PrecisionTimestampTz(val) => 
is_nullable(val.nullability),
+        r#type::Kind::Struct(val) => is_nullable(val.nullability),
+        r#type::Kind::List(val) => is_nullable(val.nullability),
+        r#type::Kind::Map(val) => is_nullable(val.nullability),
+        r#type::Kind::UserDefined(val) => is_nullable(val.nullability),
+        r#type::Kind::UserDefinedTypeReference(_) => true, // not implemented, 
assume nullable
+    };
+    Ok(nullable)
+}
+
 fn from_substrait_bound(
     bound: &Option<Bound>,
     is_lower: bool,
diff --git a/datafusion/substrait/src/logical_plan/producer.rs 
b/datafusion/substrait/src/logical_plan/producer.rs
index e216008c73..c0aac0c0a4 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -1089,7 +1089,7 @@ pub fn to_substrait_rex(
             Ok(Expression {
                 rex_type: Some(RexType::Cast(Box::new(
                     substrait::proto::expression::Cast {
-                        r#type: Some(to_substrait_type(data_type)?),
+                        r#type: Some(to_substrait_type(data_type, true)?),
                         input: Some(Box::new(to_substrait_rex(
                             ctx,
                             expr,
@@ -1296,75 +1296,79 @@ pub fn to_substrait_rex(
     }
 }
 
-fn to_substrait_type(dt: &DataType) -> Result<substrait::proto::Type> {
-    let default_nullability = r#type::Nullability::Required as i32;
+fn to_substrait_type(dt: &DataType, nullable: bool) -> 
Result<substrait::proto::Type> {
+    let nullability = if nullable {
+        r#type::Nullability::Nullable as i32
+    } else {
+        r#type::Nullability::Required as i32
+    };
     match dt {
         DataType::Null => internal_err!("Null cast is not valid"),
         DataType::Boolean => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::Bool(r#type::Boolean {
                 type_variation_reference: DEFAULT_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
             })),
         }),
         DataType::Int8 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::I8(r#type::I8 {
                 type_variation_reference: DEFAULT_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
             })),
         }),
         DataType::UInt8 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::I8(r#type::I8 {
                 type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
             })),
         }),
         DataType::Int16 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::I16(r#type::I16 {
                 type_variation_reference: DEFAULT_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
             })),
         }),
         DataType::UInt16 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::I16(r#type::I16 {
                 type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
             })),
         }),
         DataType::Int32 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::I32(r#type::I32 {
                 type_variation_reference: DEFAULT_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
             })),
         }),
         DataType::UInt32 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::I32(r#type::I32 {
                 type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
             })),
         }),
         DataType::Int64 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::I64(r#type::I64 {
                 type_variation_reference: DEFAULT_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
             })),
         }),
         DataType::UInt64 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::I64(r#type::I64 {
                 type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
             })),
         }),
         // Float16 is not supported in Substrait
         DataType::Float32 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::Fp32(r#type::Fp32 {
                 type_variation_reference: DEFAULT_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
             })),
         }),
         DataType::Float64 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::Fp64(r#type::Fp64 {
                 type_variation_reference: DEFAULT_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
             })),
         }),
         // Timezone is ignored.
@@ -1378,90 +1382,90 @@ fn to_substrait_type(dt: &DataType) -> 
Result<substrait::proto::Type> {
             Ok(substrait::proto::Type {
                 kind: Some(r#type::Kind::Timestamp(r#type::Timestamp {
                     type_variation_reference,
-                    nullability: default_nullability,
+                    nullability,
                 })),
             })
         }
         DataType::Date32 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::Date(r#type::Date {
                 type_variation_reference: DATE_32_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
             })),
         }),
         DataType::Date64 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::Date(r#type::Date {
                 type_variation_reference: DATE_64_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
             })),
         }),
         DataType::Binary => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::Binary(r#type::Binary {
                 type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
             })),
         }),
         DataType::FixedSizeBinary(length) => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::FixedBinary(r#type::FixedBinary {
                 length: *length,
                 type_variation_reference: DEFAULT_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
             })),
         }),
         DataType::LargeBinary => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::Binary(r#type::Binary {
                 type_variation_reference: LARGE_CONTAINER_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
             })),
         }),
         DataType::Utf8 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::String(r#type::String {
                 type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
             })),
         }),
         DataType::LargeUtf8 => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::String(r#type::String {
                 type_variation_reference: LARGE_CONTAINER_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
             })),
         }),
         DataType::List(inner) => {
-            let inner_type = to_substrait_type(inner.data_type())?;
+            let inner_type = to_substrait_type(inner.data_type(), 
inner.is_nullable())?;
             Ok(substrait::proto::Type {
                 kind: Some(r#type::Kind::List(Box::new(r#type::List {
                     r#type: Some(Box::new(inner_type)),
                     type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
-                    nullability: default_nullability,
+                    nullability,
                 }))),
             })
         }
         DataType::LargeList(inner) => {
-            let inner_type = to_substrait_type(inner.data_type())?;
+            let inner_type = to_substrait_type(inner.data_type(), 
inner.is_nullable())?;
             Ok(substrait::proto::Type {
                 kind: Some(r#type::Kind::List(Box::new(r#type::List {
                     r#type: Some(Box::new(inner_type)),
                     type_variation_reference: LARGE_CONTAINER_TYPE_REF,
-                    nullability: default_nullability,
+                    nullability,
                 }))),
             })
         }
         DataType::Struct(fields) => {
             let field_types = fields
                 .iter()
-                .map(|field| to_substrait_type(field.data_type()))
+                .map(|field| to_substrait_type(field.data_type(), 
field.is_nullable()))
                 .collect::<Result<Vec<_>>>()?;
             Ok(substrait::proto::Type {
                 kind: Some(r#type::Kind::Struct(r#type::Struct {
                     types: field_types,
                     type_variation_reference: DEFAULT_TYPE_REF,
-                    nullability: default_nullability,
+                    nullability,
                 })),
             })
         }
         DataType::Decimal128(p, s) => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::Decimal(r#type::Decimal {
                 type_variation_reference: DECIMAL_128_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
                 scale: *s as i32,
                 precision: *p as i32,
             })),
@@ -1469,7 +1473,7 @@ fn to_substrait_type(dt: &DataType) -> 
Result<substrait::proto::Type> {
         DataType::Decimal256(p, s) => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::Decimal(r#type::Decimal {
                 type_variation_reference: DECIMAL_256_TYPE_REF,
-                nullability: default_nullability,
+                nullability,
                 scale: *s as i32,
                 precision: *p as i32,
             })),
@@ -1687,6 +1691,16 @@ fn to_substrait_bounds(window_frame: &WindowFrame) -> 
Result<(Bound, Bound)> {
 }
 
 fn to_substrait_literal(value: &ScalarValue) -> Result<Literal> {
+    if value.is_null() {
+        return Ok(Literal {
+            nullable: true,
+            type_variation_reference: DEFAULT_TYPE_REF,
+            literal_type: Some(LiteralType::Null(to_substrait_type(
+                &value.data_type(),
+                true,
+            )?)),
+        });
+    }
     let (literal_type, type_variation_reference) = match value {
         ScalarValue::Boolean(Some(b)) => (LiteralType::Boolean(*b), 
DEFAULT_TYPE_REF),
         ScalarValue::Int8(Some(n)) => (LiteralType::I8(*n as i32), 
DEFAULT_TYPE_REF),
@@ -1744,14 +1758,14 @@ fn to_substrait_literal(value: &ScalarValue) -> 
Result<Literal> {
             }),
             DECIMAL_128_TYPE_REF,
         ),
-        ScalarValue::List(l) if !value.is_null() => (
+        ScalarValue::List(l) => (
             convert_array_to_literal_list(l)?,
             DEFAULT_CONTAINER_TYPE_REF,
         ),
-        ScalarValue::LargeList(l) if !value.is_null() => {
+        ScalarValue::LargeList(l) => {
             (convert_array_to_literal_list(l)?, LARGE_CONTAINER_TYPE_REF)
         }
-        ScalarValue::Struct(s) if !value.is_null() => (
+        ScalarValue::Struct(s) => (
             LiteralType::Struct(Struct {
                 fields: s
                     .columns()
@@ -1763,11 +1777,14 @@ fn to_substrait_literal(value: &ScalarValue) -> 
Result<Literal> {
             }),
             DEFAULT_TYPE_REF,
         ),
-        _ => (try_to_substrait_null(value)?, DEFAULT_TYPE_REF),
+        _ => (
+            not_impl_err!("Unsupported literal: {value:?}")?,
+            DEFAULT_TYPE_REF,
+        ),
     };
 
     Ok(Literal {
-        nullable: true,
+        nullable: false,
         type_variation_reference,
         literal_type: Some(literal_type),
     })
@@ -1784,7 +1801,7 @@ fn convert_array_to_literal_list<T: OffsetSizeTrait>(
         .collect::<Result<Vec<_>>>()?;
 
     if values.is_empty() {
-        let et = match to_substrait_type(array.data_type())? {
+        let et = match to_substrait_type(array.data_type(), 
array.is_nullable())? {
             substrait::proto::Type {
                 kind: Some(r#type::Kind::List(lt)),
             } => lt.as_ref().to_owned(),
@@ -1832,173 +1849,6 @@ fn to_substrait_unary_scalar_fn(
     })
 }
 
-fn try_to_substrait_null(v: &ScalarValue) -> Result<LiteralType> {
-    let default_nullability = r#type::Nullability::Nullable as i32;
-    match v {
-        ScalarValue::Boolean(None) => 
Ok(LiteralType::Null(substrait::proto::Type {
-            kind: Some(r#type::Kind::Bool(r#type::Boolean {
-                type_variation_reference: DEFAULT_TYPE_REF,
-                nullability: default_nullability,
-            })),
-        })),
-        ScalarValue::Int8(None) => Ok(LiteralType::Null(substrait::proto::Type 
{
-            kind: Some(r#type::Kind::I8(r#type::I8 {
-                type_variation_reference: DEFAULT_TYPE_REF,
-                nullability: default_nullability,
-            })),
-        })),
-        ScalarValue::UInt8(None) => 
Ok(LiteralType::Null(substrait::proto::Type {
-            kind: Some(r#type::Kind::I8(r#type::I8 {
-                type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
-                nullability: default_nullability,
-            })),
-        })),
-        ScalarValue::Int16(None) => 
Ok(LiteralType::Null(substrait::proto::Type {
-            kind: Some(r#type::Kind::I16(r#type::I16 {
-                type_variation_reference: DEFAULT_TYPE_REF,
-                nullability: default_nullability,
-            })),
-        })),
-        ScalarValue::UInt16(None) => 
Ok(LiteralType::Null(substrait::proto::Type {
-            kind: Some(r#type::Kind::I16(r#type::I16 {
-                type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
-                nullability: default_nullability,
-            })),
-        })),
-        ScalarValue::Int32(None) => 
Ok(LiteralType::Null(substrait::proto::Type {
-            kind: Some(r#type::Kind::I32(r#type::I32 {
-                type_variation_reference: DEFAULT_TYPE_REF,
-                nullability: default_nullability,
-            })),
-        })),
-        ScalarValue::UInt32(None) => 
Ok(LiteralType::Null(substrait::proto::Type {
-            kind: Some(r#type::Kind::I32(r#type::I32 {
-                type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
-                nullability: default_nullability,
-            })),
-        })),
-        ScalarValue::Int64(None) => 
Ok(LiteralType::Null(substrait::proto::Type {
-            kind: Some(r#type::Kind::I64(r#type::I64 {
-                type_variation_reference: DEFAULT_TYPE_REF,
-                nullability: default_nullability,
-            })),
-        })),
-        ScalarValue::UInt64(None) => 
Ok(LiteralType::Null(substrait::proto::Type {
-            kind: Some(r#type::Kind::I64(r#type::I64 {
-                type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
-                nullability: default_nullability,
-            })),
-        })),
-        ScalarValue::Float32(None) => 
Ok(LiteralType::Null(substrait::proto::Type {
-            kind: Some(r#type::Kind::Fp32(r#type::Fp32 {
-                type_variation_reference: DEFAULT_TYPE_REF,
-                nullability: default_nullability,
-            })),
-        })),
-        ScalarValue::Float64(None) => 
Ok(LiteralType::Null(substrait::proto::Type {
-            kind: Some(r#type::Kind::Fp64(r#type::Fp64 {
-                type_variation_reference: DEFAULT_TYPE_REF,
-                nullability: default_nullability,
-            })),
-        })),
-        ScalarValue::TimestampSecond(None, _) => {
-            Ok(LiteralType::Null(substrait::proto::Type {
-                kind: Some(r#type::Kind::Timestamp(r#type::Timestamp {
-                    type_variation_reference: TIMESTAMP_SECOND_TYPE_REF,
-                    nullability: default_nullability,
-                })),
-            }))
-        }
-        ScalarValue::TimestampMillisecond(None, _) => {
-            Ok(LiteralType::Null(substrait::proto::Type {
-                kind: Some(r#type::Kind::Timestamp(r#type::Timestamp {
-                    type_variation_reference: TIMESTAMP_MILLI_TYPE_REF,
-                    nullability: default_nullability,
-                })),
-            }))
-        }
-        ScalarValue::TimestampMicrosecond(None, _) => {
-            Ok(LiteralType::Null(substrait::proto::Type {
-                kind: Some(r#type::Kind::Timestamp(r#type::Timestamp {
-                    type_variation_reference: TIMESTAMP_MICRO_TYPE_REF,
-                    nullability: default_nullability,
-                })),
-            }))
-        }
-        ScalarValue::TimestampNanosecond(None, _) => {
-            Ok(LiteralType::Null(substrait::proto::Type {
-                kind: Some(r#type::Kind::Timestamp(r#type::Timestamp {
-                    type_variation_reference: TIMESTAMP_NANO_TYPE_REF,
-                    nullability: default_nullability,
-                })),
-            }))
-        }
-        ScalarValue::Date32(None) => 
Ok(LiteralType::Null(substrait::proto::Type {
-            kind: Some(r#type::Kind::Date(r#type::Date {
-                type_variation_reference: DATE_32_TYPE_REF,
-                nullability: default_nullability,
-            })),
-        })),
-        ScalarValue::Date64(None) => 
Ok(LiteralType::Null(substrait::proto::Type {
-            kind: Some(r#type::Kind::Date(r#type::Date {
-                type_variation_reference: DATE_64_TYPE_REF,
-                nullability: default_nullability,
-            })),
-        })),
-        ScalarValue::Binary(None) => 
Ok(LiteralType::Null(substrait::proto::Type {
-            kind: Some(r#type::Kind::Binary(r#type::Binary {
-                type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
-                nullability: default_nullability,
-            })),
-        })),
-        ScalarValue::LargeBinary(None) => 
Ok(LiteralType::Null(substrait::proto::Type {
-            kind: Some(r#type::Kind::Binary(r#type::Binary {
-                type_variation_reference: LARGE_CONTAINER_TYPE_REF,
-                nullability: default_nullability,
-            })),
-        })),
-        ScalarValue::FixedSizeBinary(_, None) => {
-            Ok(LiteralType::Null(substrait::proto::Type {
-                kind: Some(r#type::Kind::Binary(r#type::Binary {
-                    type_variation_reference: DEFAULT_TYPE_REF,
-                    nullability: default_nullability,
-                })),
-            }))
-        }
-        ScalarValue::Utf8(None) => Ok(LiteralType::Null(substrait::proto::Type 
{
-            kind: Some(r#type::Kind::String(r#type::String {
-                type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
-                nullability: default_nullability,
-            })),
-        })),
-        ScalarValue::LargeUtf8(None) => 
Ok(LiteralType::Null(substrait::proto::Type {
-            kind: Some(r#type::Kind::String(r#type::String {
-                type_variation_reference: LARGE_CONTAINER_TYPE_REF,
-                nullability: default_nullability,
-            })),
-        })),
-        ScalarValue::Decimal128(None, p, s) => {
-            Ok(LiteralType::Null(substrait::proto::Type {
-                kind: Some(r#type::Kind::Decimal(r#type::Decimal {
-                    scale: *s as i32,
-                    precision: *p as i32,
-                    type_variation_reference: DEFAULT_TYPE_REF,
-                    nullability: default_nullability,
-                })),
-            }))
-        }
-        ScalarValue::List(l) => 
Ok(LiteralType::Null(to_substrait_type(l.data_type())?)),
-        ScalarValue::LargeList(l) => {
-            Ok(LiteralType::Null(to_substrait_type(l.data_type())?))
-        }
-        ScalarValue::Struct(s) => {
-            Ok(LiteralType::Null(to_substrait_type(s.data_type())?))
-        }
-        // TODO: Extend support for remaining data types
-        _ => not_impl_err!("Unsupported literal: {v:?}"),
-    }
-}
-
 /// Try to convert an [Expr] to a [FieldReference].
 /// Returns `Err` if the [Expr] is not a [Expr::Column].
 fn try_to_substrait_field_reference(
@@ -2141,8 +1991,8 @@ mod test {
             ),
         )))?;
 
-        let c0 = Field::new("c0", DataType::Boolean, true);
-        let c1 = Field::new("c1", DataType::Int32, true);
+        let c0 = Field::new("c0", DataType::Boolean, false);
+        let c1 = Field::new("c1", DataType::Int32, false);
         let c2 = Field::new("c2", DataType::Utf8, true);
         round_trip_literal(
             ScalarStructBuilder::new()
@@ -2190,16 +2040,20 @@ mod test {
         round_trip_type(DataType::LargeUtf8)?;
         round_trip_type(DataType::Decimal128(10, 2))?;
         round_trip_type(DataType::Decimal256(30, 2))?;
-        round_trip_type(DataType::List(
-            Field::new_list_field(DataType::Int32, true).into(),
-        ))?;
-        round_trip_type(DataType::LargeList(
-            Field::new_list_field(DataType::Int32, true).into(),
-        ))?;
+
+        for nullable in [true, false] {
+            round_trip_type(DataType::List(
+                Field::new_list_field(DataType::Int32, nullable).into(),
+            ))?;
+            round_trip_type(DataType::LargeList(
+                Field::new_list_field(DataType::Int32, nullable).into(),
+            ))?;
+        }
+
         round_trip_type(DataType::Struct(
             vec![
                 Field::new("c0", DataType::Int32, true),
-                Field::new("c1", DataType::Utf8, true),
+                Field::new("c1", DataType::Utf8, false),
             ]
             .into(),
         ))?;
@@ -2210,7 +2064,9 @@ mod test {
     fn round_trip_type(dt: DataType) -> Result<()> {
         println!("Checking round trip of {dt:?}");
 
-        let substrait = to_substrait_type(&dt)?;
+        // As DataFusion doesn't consider nullability as a property of the 
type, but field,
+        // it doesn't matter if we set nullability to true or false here.
+        let substrait = to_substrait_type(&dt, true)?;
         let roundtrip_dt = from_substrait_type(&substrait)?;
         assert_eq!(dt, roundtrip_dt);
         Ok(())


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to