This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 77311a5896 support Decimal256 type in datafusion-proto (#11606)
77311a5896 is described below

commit 77311a5896272c7ed252d8cd53d48ec6ea7c0ccf
Author: Leonardo Yvens <[email protected]>
AuthorDate: Tue Jul 23 11:18:00 2024 +0100

    support Decimal256 type in datafusion-proto (#11606)
---
 .../proto-common/proto/datafusion_common.proto     |   7 ++
 datafusion/proto-common/src/from_proto/mod.rs      |   4 +
 datafusion/proto-common/src/generated/pbjson.rs    | 125 +++++++++++++++++++++
 datafusion/proto-common/src/generated/prost.rs     |  12 +-
 datafusion/proto-common/src/to_proto/mod.rs        |   7 +-
 .../proto/src/generated/datafusion_proto_common.rs |  12 +-
 .../proto/tests/cases/roundtrip_logical_plan.rs    |   2 +
 7 files changed, 164 insertions(+), 5 deletions(-)

diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto
index ca95136dad..8e8fd2352c 100644
--- a/datafusion/proto-common/proto/datafusion_common.proto
+++ b/datafusion/proto-common/proto/datafusion_common.proto
@@ -130,6 +130,12 @@ message Decimal{
   int32 scale = 4;
 }
 
+message Decimal256Type{
+  reserved 1, 2;
+  uint32 precision = 3;
+  int32 scale = 4;
+}
+
 message List{
   Field field_type = 1;
 }
@@ -335,6 +341,7 @@ message ArrowType{
     TimeUnit TIME64 = 22 ;
     IntervalUnit INTERVAL = 23 ;
     Decimal DECIMAL = 24 ;
+    Decimal256Type DECIMAL256 = 36;
     List LIST = 25;
     List LARGE_LIST = 26;
     FixedSizeList FIXED_SIZE_LIST = 27;
diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs
index 9191ff185a..5fe9d937f7 100644
--- a/datafusion/proto-common/src/from_proto/mod.rs
+++ b/datafusion/proto-common/src/from_proto/mod.rs
@@ -260,6 +260,10 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType {
                 precision,
                 scale,
             }) => DataType::Decimal128(*precision as u8, *scale as i8),
+            arrow_type::ArrowTypeEnum::Decimal256(protobuf::Decimal256Type {
+                precision,
+                scale,
+            }) => DataType::Decimal256(*precision as u8, *scale as i8),
             arrow_type::ArrowTypeEnum::List(list) => {
                 let list_type =
                     list.as_ref().field_type.as_deref().required("field_type")?;
diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs
index 4b34660ae2..511072f3cb 100644
--- a/datafusion/proto-common/src/generated/pbjson.rs
+++ b/datafusion/proto-common/src/generated/pbjson.rs
@@ -175,6 +175,9 @@ impl serde::Serialize for ArrowType {
                 arrow_type::ArrowTypeEnum::Decimal(v) => {
                     struct_ser.serialize_field("DECIMAL", v)?;
                 }
+                arrow_type::ArrowTypeEnum::Decimal256(v) => {
+                    struct_ser.serialize_field("DECIMAL256", v)?;
+                }
                 arrow_type::ArrowTypeEnum::List(v) => {
                     struct_ser.serialize_field("LIST", v)?;
                 }
@@ -241,6 +244,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
             "TIME64",
             "INTERVAL",
             "DECIMAL",
+            "DECIMAL256",
             "LIST",
             "LARGE_LIST",
             "LARGELIST",
@@ -282,6 +286,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
             Time64,
             Interval,
             Decimal,
+            Decimal256,
             List,
             LargeList,
             FixedSizeList,
@@ -338,6 +343,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
                             "TIME64" => Ok(GeneratedField::Time64),
                             "INTERVAL" => Ok(GeneratedField::Interval),
                             "DECIMAL" => Ok(GeneratedField::Decimal),
+                            "DECIMAL256" => Ok(GeneratedField::Decimal256),
                             "LIST" => Ok(GeneratedField::List),
                             "LARGELIST" | "LARGE_LIST" => Ok(GeneratedField::LargeList),
                             "FIXEDSIZELIST" | "FIXED_SIZE_LIST" => Ok(GeneratedField::FixedSizeList),
@@ -556,6 +562,13 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
                                 return Err(serde::de::Error::duplicate_field("DECIMAL"));
                             }
                             arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal)
+;
+                        }
+                        GeneratedField::Decimal256 => {
+                            if arrow_type_enum__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("DECIMAL256"));
+                            }
+                            arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal256)
 ;
                         }
                         GeneratedField::List => {
@@ -2849,6 +2862,118 @@ impl<'de> serde::Deserialize<'de> for Decimal256 {
         deserializer.deserialize_struct("datafusion_common.Decimal256", FIELDS, GeneratedVisitor)
     }
 }
+impl serde::Serialize for Decimal256Type {
+    #[allow(deprecated)]
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        use serde::ser::SerializeStruct;
+        let mut len = 0;
+        if self.precision != 0 {
+            len += 1;
+        }
+        if self.scale != 0 {
+            len += 1;
+        }
+        let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal256Type", len)?;
+        if self.precision != 0 {
+            struct_ser.serialize_field("precision", &self.precision)?;
+        }
+        if self.scale != 0 {
+            struct_ser.serialize_field("scale", &self.scale)?;
+        }
+        struct_ser.end()
+    }
+}
+impl<'de> serde::Deserialize<'de> for Decimal256Type {
+    #[allow(deprecated)]
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        const FIELDS: &[&str] = &[
+            "precision",
+            "scale",
+        ];
+
+        #[allow(clippy::enum_variant_names)]
+        enum GeneratedField {
+            Precision,
+            Scale,
+        }
+        impl<'de> serde::Deserialize<'de> for GeneratedField {
+            fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
+            where
+                D: serde::Deserializer<'de>,
+            {
+                struct GeneratedVisitor;
+
+                impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+                    type Value = GeneratedField;
+
+                    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                        write!(formatter, "expected one of: {:?}", &FIELDS)
+                    }
+
+                    #[allow(unused_variables)]
+                    fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
+                    where
+                        E: serde::de::Error,
+                    {
+                        match value {
+                            "precision" => Ok(GeneratedField::Precision),
+                            "scale" => Ok(GeneratedField::Scale),
+                            _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
+                        }
+                    }
+                }
+                deserializer.deserialize_identifier(GeneratedVisitor)
+            }
+        }
+        struct GeneratedVisitor;
+        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+            type Value = Decimal256Type;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                formatter.write_str("struct datafusion_common.Decimal256Type")
+            }
+
+            fn visit_map<V>(self, mut map_: V) -> std::result::Result<Decimal256Type, V::Error>
+                where
+                    V: serde::de::MapAccess<'de>,
+            {
+                let mut precision__ = None;
+                let mut scale__ = None;
+                while let Some(k) = map_.next_key()? {
+                    match k {
+                        GeneratedField::Precision => {
+                            if precision__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("precision"));
+                            }
+                            precision__ = 
+                                Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+                            ;
+                        }
+                        GeneratedField::Scale => {
+                            if scale__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("scale"));
+                            }
+                            scale__ = 
+                                Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+                            ;
+                        }
+                    }
+                }
+                Ok(Decimal256Type {
+                    precision: precision__.unwrap_or_default(),
+                    scale: scale__.unwrap_or_default(),
+                })
+            }
+        }
+        deserializer.deserialize_struct("datafusion_common.Decimal256Type", FIELDS, GeneratedVisitor)
+    }
+}
 impl serde::Serialize for DfField {
     #[allow(deprecated)]
     fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs
index 9a2770997f..62919e218b 100644
--- a/datafusion/proto-common/src/generated/prost.rs
+++ b/datafusion/proto-common/src/generated/prost.rs
@@ -140,6 +140,14 @@ pub struct Decimal {
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
+pub struct Decimal256Type {
+    #[prost(uint32, tag = "3")]
+    pub precision: u32,
+    #[prost(int32, tag = "4")]
+    pub scale: i32,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
 pub struct List {
     #[prost(message, optional, boxed, tag = "1")]
     pub field_type: ::core::option::Option<::prost::alloc::boxed::Box<Field>>,
@@ -446,7 +454,7 @@ pub struct Decimal256 {
 pub struct ArrowType {
     #[prost(
         oneof = "arrow_type::ArrowTypeEnum",
-        tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33"
+        tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33"
     )]
     pub arrow_type_enum: ::core::option::Option<arrow_type::ArrowTypeEnum>,
 }
@@ -516,6 +524,8 @@ pub mod arrow_type {
         Interval(i32),
         #[prost(message, tag = "24")]
         Decimal(super::Decimal),
+        #[prost(message, tag = "36")]
+        Decimal256(super::Decimal256Type),
         #[prost(message, tag = "25")]
         List(::prost::alloc::boxed::Box<super::List>),
         #[prost(message, tag = "26")]
diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs
index 9dcb65444a..c15da2895b 100644
--- a/datafusion/proto-common/src/to_proto/mod.rs
+++ b/datafusion/proto-common/src/to_proto/mod.rs
@@ -191,9 +191,10 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum {
                 precision: *precision as u32,
                 scale: *scale as i32,
             }),
-            DataType::Decimal256(_, _) => {
-                return Err(Error::General("Proto serialization error: The Decimal256 data type is not yet supported".to_owned()))
-            }
+            DataType::Decimal256(precision, scale) => Self::Decimal256(protobuf::Decimal256Type {
+                precision: *precision as u32,
+                scale: *scale as i32,
+            }),
             DataType::Map(field, sorted) => {
                 Self::Map(Box::new(
                     protobuf::Map {
diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs
index 9a2770997f..62919e218b 100644
--- a/datafusion/proto/src/generated/datafusion_proto_common.rs
+++ b/datafusion/proto/src/generated/datafusion_proto_common.rs
@@ -140,6 +140,14 @@ pub struct Decimal {
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
+pub struct Decimal256Type {
+    #[prost(uint32, tag = "3")]
+    pub precision: u32,
+    #[prost(int32, tag = "4")]
+    pub scale: i32,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
 pub struct List {
     #[prost(message, optional, boxed, tag = "1")]
     pub field_type: ::core::option::Option<::prost::alloc::boxed::Box<Field>>,
@@ -446,7 +454,7 @@ pub struct Decimal256 {
 pub struct ArrowType {
     #[prost(
         oneof = "arrow_type::ArrowTypeEnum",
-        tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33"
+        tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33"
     )]
     pub arrow_type_enum: ::core::option::Option<arrow_type::ArrowTypeEnum>,
 }
@@ -516,6 +524,8 @@ pub mod arrow_type {
         Interval(i32),
         #[prost(message, tag = "24")]
         Decimal(super::Decimal),
+        #[prost(message, tag = "36")]
+        Decimal256(super::Decimal256Type),
         #[prost(message, tag = "25")]
         List(::prost::alloc::boxed::Box<super::List>),
         #[prost(message, tag = "26")]
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 3476d5d042..f6557c7b2d 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -27,6 +27,7 @@ use arrow::array::{
 use arrow::datatypes::{
     DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType,
     IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
+    DECIMAL256_MAX_PRECISION,
 };
 use prost::Message;
 
@@ -1379,6 +1380,7 @@ fn round_trip_datatype() {
         DataType::Utf8,
         DataType::LargeUtf8,
         DataType::Decimal128(7, 12),
+        DataType::Decimal256(DECIMAL256_MAX_PRECISION, 0),
         // Recursive list tests
         DataType::List(new_arc_field("Level1", DataType::Binary, true)),
         DataType::List(new_arc_field(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to