This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 77311a5896 support Decimal256 type in datafusion-proto (#11606)
77311a5896 is described below
commit 77311a5896272c7ed252d8cd53d48ec6ea7c0ccf
Author: Leonardo Yvens <[email protected]>
AuthorDate: Tue Jul 23 11:18:00 2024 +0100
support Decimal256 type in datafusion-proto (#11606)
---
.../proto-common/proto/datafusion_common.proto | 7 ++
datafusion/proto-common/src/from_proto/mod.rs | 4 +
datafusion/proto-common/src/generated/pbjson.rs | 125 +++++++++++++++++++++
datafusion/proto-common/src/generated/prost.rs | 12 +-
datafusion/proto-common/src/to_proto/mod.rs | 7 +-
.../proto/src/generated/datafusion_proto_common.rs | 12 +-
.../proto/tests/cases/roundtrip_logical_plan.rs | 2 +
7 files changed, 164 insertions(+), 5 deletions(-)
diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto
index ca95136dad..8e8fd2352c 100644
--- a/datafusion/proto-common/proto/datafusion_common.proto
+++ b/datafusion/proto-common/proto/datafusion_common.proto
@@ -130,6 +130,12 @@ message Decimal{
int32 scale = 4;
}
+message Decimal256Type{
+ reserved 1, 2;
+ uint32 precision = 3;
+ int32 scale = 4;
+}
+
message List{
Field field_type = 1;
}
@@ -335,6 +341,7 @@ message ArrowType{
TimeUnit TIME64 = 22 ;
IntervalUnit INTERVAL = 23 ;
Decimal DECIMAL = 24 ;
+ Decimal256Type DECIMAL256 = 36;
List LIST = 25;
List LARGE_LIST = 26;
FixedSizeList FIXED_SIZE_LIST = 27;
diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs
index 9191ff185a..5fe9d937f7 100644
--- a/datafusion/proto-common/src/from_proto/mod.rs
+++ b/datafusion/proto-common/src/from_proto/mod.rs
@@ -260,6 +260,10 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType {
precision,
scale,
}) => DataType::Decimal128(*precision as u8, *scale as i8),
+ arrow_type::ArrowTypeEnum::Decimal256(protobuf::Decimal256Type {
+ precision,
+ scale,
+ }) => DataType::Decimal256(*precision as u8, *scale as i8),
arrow_type::ArrowTypeEnum::List(list) => {
let list_type =
list.as_ref().field_type.as_deref().required("field_type")?;
diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs
index 4b34660ae2..511072f3cb 100644
--- a/datafusion/proto-common/src/generated/pbjson.rs
+++ b/datafusion/proto-common/src/generated/pbjson.rs
@@ -175,6 +175,9 @@ impl serde::Serialize for ArrowType {
arrow_type::ArrowTypeEnum::Decimal(v) => {
struct_ser.serialize_field("DECIMAL", v)?;
}
+ arrow_type::ArrowTypeEnum::Decimal256(v) => {
+ struct_ser.serialize_field("DECIMAL256", v)?;
+ }
arrow_type::ArrowTypeEnum::List(v) => {
struct_ser.serialize_field("LIST", v)?;
}
@@ -241,6 +244,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
"TIME64",
"INTERVAL",
"DECIMAL",
+ "DECIMAL256",
"LIST",
"LARGE_LIST",
"LARGELIST",
@@ -282,6 +286,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
Time64,
Interval,
Decimal,
+ Decimal256,
List,
LargeList,
FixedSizeList,
@@ -338,6 +343,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
"TIME64" => Ok(GeneratedField::Time64),
"INTERVAL" => Ok(GeneratedField::Interval),
"DECIMAL" => Ok(GeneratedField::Decimal),
+ "DECIMAL256" => Ok(GeneratedField::Decimal256),
"LIST" => Ok(GeneratedField::List),
"LARGELIST" | "LARGE_LIST" =>
Ok(GeneratedField::LargeList),
"FIXEDSIZELIST" | "FIXED_SIZE_LIST" =>
Ok(GeneratedField::FixedSizeList),
@@ -556,6 +562,13 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
return
Err(serde::de::Error::duplicate_field("DECIMAL"));
}
arrow_type_enum__ =
map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal)
+;
+ }
+ GeneratedField::Decimal256 => {
+ if arrow_type_enum__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("DECIMAL256"));
+ }
+ arrow_type_enum__ =
map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal256)
;
}
GeneratedField::List => {
@@ -2849,6 +2862,118 @@ impl<'de> serde::Deserialize<'de> for Decimal256 {
deserializer.deserialize_struct("datafusion_common.Decimal256",
FIELDS, GeneratedVisitor)
}
}
+impl serde::Serialize for Decimal256Type {
+ #[allow(deprecated)]
+ fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
+ where
+ S: serde::Serializer,
+ {
+ use serde::ser::SerializeStruct;
+ let mut len = 0;
+ if self.precision != 0 {
+ len += 1;
+ }
+ if self.scale != 0 {
+ len += 1;
+ }
+ let mut struct_ser =
serializer.serialize_struct("datafusion_common.Decimal256Type", len)?;
+ if self.precision != 0 {
+ struct_ser.serialize_field("precision", &self.precision)?;
+ }
+ if self.scale != 0 {
+ struct_ser.serialize_field("scale", &self.scale)?;
+ }
+ struct_ser.end()
+ }
+}
+impl<'de> serde::Deserialize<'de> for Decimal256Type {
+ #[allow(deprecated)]
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ const FIELDS: &[&str] = &[
+ "precision",
+ "scale",
+ ];
+
+ #[allow(clippy::enum_variant_names)]
+ enum GeneratedField {
+ Precision,
+ Scale,
+ }
+ impl<'de> serde::Deserialize<'de> for GeneratedField {
+ fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ struct GeneratedVisitor;
+
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = GeneratedField;
+
+ fn expecting(&self, formatter: &mut
std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(formatter, "expected one of: {:?}", &FIELDS)
+ }
+
+ #[allow(unused_variables)]
+ fn visit_str<E>(self, value: &str) ->
std::result::Result<GeneratedField, E>
+ where
+ E: serde::de::Error,
+ {
+ match value {
+ "precision" => Ok(GeneratedField::Precision),
+ "scale" => Ok(GeneratedField::Scale),
+ _ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
+ }
+ }
+ }
+ deserializer.deserialize_identifier(GeneratedVisitor)
+ }
+ }
+ struct GeneratedVisitor;
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = Decimal256Type;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) ->
std::fmt::Result {
+ formatter.write_str("struct datafusion_common.Decimal256Type")
+ }
+
+ fn visit_map<V>(self, mut map_: V) ->
std::result::Result<Decimal256Type, V::Error>
+ where
+ V: serde::de::MapAccess<'de>,
+ {
+ let mut precision__ = None;
+ let mut scale__ = None;
+ while let Some(k) = map_.next_key()? {
+ match k {
+ GeneratedField::Precision => {
+ if precision__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("precision"));
+ }
+ precision__ =
+
Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+ ;
+ }
+ GeneratedField::Scale => {
+ if scale__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("scale"));
+ }
+ scale__ =
+
Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+ ;
+ }
+ }
+ }
+ Ok(Decimal256Type {
+ precision: precision__.unwrap_or_default(),
+ scale: scale__.unwrap_or_default(),
+ })
+ }
+ }
+ deserializer.deserialize_struct("datafusion_common.Decimal256Type",
FIELDS, GeneratedVisitor)
+ }
+}
impl serde::Serialize for DfField {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs
index 9a2770997f..62919e218b 100644
--- a/datafusion/proto-common/src/generated/prost.rs
+++ b/datafusion/proto-common/src/generated/prost.rs
@@ -140,6 +140,14 @@ pub struct Decimal {
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct Decimal256Type {
+ #[prost(uint32, tag = "3")]
+ pub precision: u32,
+ #[prost(int32, tag = "4")]
+ pub scale: i32,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
pub struct List {
#[prost(message, optional, boxed, tag = "1")]
pub field_type: ::core::option::Option<::prost::alloc::boxed::Box<Field>>,
@@ -446,7 +454,7 @@ pub struct Decimal256 {
pub struct ArrowType {
#[prost(
oneof = "arrow_type::ArrowTypeEnum",
- tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34,
16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33"
+ tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34,
16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33"
)]
pub arrow_type_enum: ::core::option::Option<arrow_type::ArrowTypeEnum>,
}
@@ -516,6 +524,8 @@ pub mod arrow_type {
Interval(i32),
#[prost(message, tag = "24")]
Decimal(super::Decimal),
+ #[prost(message, tag = "36")]
+ Decimal256(super::Decimal256Type),
#[prost(message, tag = "25")]
List(::prost::alloc::boxed::Box<super::List>),
#[prost(message, tag = "26")]
diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs
index 9dcb65444a..c15da2895b 100644
--- a/datafusion/proto-common/src/to_proto/mod.rs
+++ b/datafusion/proto-common/src/to_proto/mod.rs
@@ -191,9 +191,10 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum {
precision: *precision as u32,
scale: *scale as i32,
}),
- DataType::Decimal256(_, _) => {
- return Err(Error::General("Proto serialization error: The
Decimal256 data type is not yet supported".to_owned()))
- }
+ DataType::Decimal256(precision, scale) =>
Self::Decimal256(protobuf::Decimal256Type {
+ precision: *precision as u32,
+ scale: *scale as i32,
+ }),
DataType::Map(field, sorted) => {
Self::Map(Box::new(
protobuf::Map {
diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs
index 9a2770997f..62919e218b 100644
--- a/datafusion/proto/src/generated/datafusion_proto_common.rs
+++ b/datafusion/proto/src/generated/datafusion_proto_common.rs
@@ -140,6 +140,14 @@ pub struct Decimal {
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct Decimal256Type {
+ #[prost(uint32, tag = "3")]
+ pub precision: u32,
+ #[prost(int32, tag = "4")]
+ pub scale: i32,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
pub struct List {
#[prost(message, optional, boxed, tag = "1")]
pub field_type: ::core::option::Option<::prost::alloc::boxed::Box<Field>>,
@@ -446,7 +454,7 @@ pub struct Decimal256 {
pub struct ArrowType {
#[prost(
oneof = "arrow_type::ArrowTypeEnum",
- tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34,
16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33"
+ tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34,
16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33"
)]
pub arrow_type_enum: ::core::option::Option<arrow_type::ArrowTypeEnum>,
}
@@ -516,6 +524,8 @@ pub mod arrow_type {
Interval(i32),
#[prost(message, tag = "24")]
Decimal(super::Decimal),
+ #[prost(message, tag = "36")]
+ Decimal256(super::Decimal256Type),
#[prost(message, tag = "25")]
List(::prost::alloc::boxed::Box<super::List>),
#[prost(message, tag = "26")]
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 3476d5d042..f6557c7b2d 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -27,6 +27,7 @@ use arrow::array::{
use arrow::datatypes::{
DataType, Field, Fields, Int32Type, IntervalDayTimeType,
IntervalMonthDayNanoType,
IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
+ DECIMAL256_MAX_PRECISION,
};
use prost::Message;
@@ -1379,6 +1380,7 @@ fn round_trip_datatype() {
DataType::Utf8,
DataType::LargeUtf8,
DataType::Decimal128(7, 12),
+ DataType::Decimal256(DECIMAL256_MAX_PRECISION, 0),
// Recursive list tests
DataType::List(new_arc_field("Level1", DataType::Binary, true)),
DataType::List(new_arc_field(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]