This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 71ff27c05a feat: preserve metadata for Field and Schema in proto (#6865)
71ff27c05a is described below

commit 71ff27c05ac0b2492a2e4618dbe2288020ee8fea
Author: Jonah Gao <[email protected]>
AuthorDate: Fri Jul 7 04:34:17 2023 +0800

    feat: preserve metadata for Field and Schema in proto (#6865)
---
 datafusion/proto/proto/datafusion.proto         |  2 ++
 datafusion/proto/src/generated/pbjson.rs        | 38 +++++++++++++++++++++++++
 datafusion/proto/src/generated/prost.rs         | 10 +++++++
 datafusion/proto/src/logical_plan/from_proto.rs | 18 +++---------
 datafusion/proto/src/logical_plan/mod.rs        | 31 ++++++++++++++++++++
 datafusion/proto/src/logical_plan/to_proto.rs   |  3 ++
 6 files changed, 88 insertions(+), 14 deletions(-)

diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 00fa28906c..528c675570 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -740,6 +740,7 @@ message WindowFrameBound {
 
 message Schema {
   repeated Field columns = 1;
+  map<string, string> metadata = 2;
 }
 
 message Field {
@@ -749,6 +750,7 @@ message Field {
   bool nullable = 3;
   // for complex data types like structs, unions
   repeated Field children = 4;
+  map<string, string> metadata = 5;
 }
 
 message FixedSizeBinary{
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 63303fc32c..d6a770159b 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -6049,6 +6049,9 @@ impl serde::Serialize for Field {
         if !self.children.is_empty() {
             len += 1;
         }
+        if !self.metadata.is_empty() {
+            len += 1;
+        }
         let mut struct_ser = serializer.serialize_struct("datafusion.Field", len)?;
         if !self.name.is_empty() {
             struct_ser.serialize_field("name", &self.name)?;
@@ -6062,6 +6065,9 @@ impl serde::Serialize for Field {
         if !self.children.is_empty() {
             struct_ser.serialize_field("children", &self.children)?;
         }
+        if !self.metadata.is_empty() {
+            struct_ser.serialize_field("metadata", &self.metadata)?;
+        }
         struct_ser.end()
     }
 }
@@ -6077,6 +6083,7 @@ impl<'de> serde::Deserialize<'de> for Field {
             "arrowType",
             "nullable",
             "children",
+            "metadata",
         ];
 
         #[allow(clippy::enum_variant_names)]
@@ -6085,6 +6092,7 @@ impl<'de> serde::Deserialize<'de> for Field {
             ArrowType,
             Nullable,
             Children,
+            Metadata,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
@@ -6110,6 +6118,7 @@ impl<'de> serde::Deserialize<'de> for Field {
                             "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType),
                             "nullable" => Ok(GeneratedField::Nullable),
                             "children" => Ok(GeneratedField::Children),
+                            "metadata" => Ok(GeneratedField::Metadata),
                             _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
                         }
                     }
@@ -6133,6 +6142,7 @@ impl<'de> serde::Deserialize<'de> for Field {
                 let mut arrow_type__ = None;
                 let mut nullable__ = None;
                 let mut children__ = None;
+                let mut metadata__ = None;
                 while let Some(k) = map.next_key()? {
                     match k {
                         GeneratedField::Name => {
@@ -6159,6 +6169,14 @@ impl<'de> serde::Deserialize<'de> for Field {
                             }
                             children__ = Some(map.next_value()?);
                         }
+                        GeneratedField::Metadata => {
+                            if metadata__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("metadata"));
+                            }
+                            metadata__ = Some(
+                                map.next_value::<std::collections::HashMap<_, _>>()?
+                            );
+                        }
                     }
                 }
                 Ok(Field {
@@ -6166,6 +6184,7 @@ impl<'de> serde::Deserialize<'de> for Field {
                     arrow_type: arrow_type__,
                     nullable: nullable__.unwrap_or_default(),
                     children: children__.unwrap_or_default(),
+                    metadata: metadata__.unwrap_or_default(),
                 })
             }
         }
@@ -19493,10 +19512,16 @@ impl serde::Serialize for Schema {
         if !self.columns.is_empty() {
             len += 1;
         }
+        if !self.metadata.is_empty() {
+            len += 1;
+        }
         let mut struct_ser = serializer.serialize_struct("datafusion.Schema", len)?;
         if !self.columns.is_empty() {
             struct_ser.serialize_field("columns", &self.columns)?;
         }
+        if !self.metadata.is_empty() {
+            struct_ser.serialize_field("metadata", &self.metadata)?;
+        }
         struct_ser.end()
     }
 }
@@ -19508,11 +19533,13 @@ impl<'de> serde::Deserialize<'de> for Schema {
     {
         const FIELDS: &[&str] = &[
             "columns",
+            "metadata",
         ];
 
         #[allow(clippy::enum_variant_names)]
         enum GeneratedField {
             Columns,
+            Metadata,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
@@ -19535,6 +19562,7 @@ impl<'de> serde::Deserialize<'de> for Schema {
                     {
                         match value {
                             "columns" => Ok(GeneratedField::Columns),
+                            "metadata" => Ok(GeneratedField::Metadata),
                             _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
                         }
                     }
@@ -19555,6 +19583,7 @@ impl<'de> serde::Deserialize<'de> for Schema {
                     V: serde::de::MapAccess<'de>,
             {
                 let mut columns__ = None;
+                let mut metadata__ = None;
                 while let Some(k) = map.next_key()? {
                     match k {
                         GeneratedField::Columns => {
@@ -19563,10 +19592,19 @@ impl<'de> serde::Deserialize<'de> for Schema {
                             }
                             columns__ = Some(map.next_value()?);
                         }
+                        GeneratedField::Metadata => {
+                            if metadata__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("metadata"));
+                            }
+                            metadata__ = Some(
+                                map.next_value::<std::collections::HashMap<_, _>>()?
+                            );
+                        }
                     }
                 }
                 Ok(Schema {
                     columns: columns__.unwrap_or_default(),
+                    metadata: metadata__.unwrap_or_default(),
                 })
             }
         }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 00eea4d6ed..4e91fbab19 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -893,6 +893,11 @@ pub struct WindowFrameBound {
 pub struct Schema {
     #[prost(message, repeated, tag = "1")]
     pub columns: ::prost::alloc::vec::Vec<Field>,
+    #[prost(map = "string, string", tag = "2")]
+    pub metadata: ::std::collections::HashMap<
+        ::prost::alloc::string::String,
+        ::prost::alloc::string::String,
+    >,
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
@@ -907,6 +912,11 @@ pub struct Field {
     /// for complex data types like structs, unions
     #[prost(message, repeated, tag = "4")]
     pub children: ::prost::alloc::vec::Vec<Field>,
+    #[prost(map = "string, string", tag = "5")]
+    pub metadata: ::std::collections::HashMap<
+        ::prost::alloc::string::String,
+        ::prost::alloc::string::String,
+    >,
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 4e2f59a118..1b48364ad4 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -365,8 +365,8 @@ impl TryFrom<&protobuf::Field> for Field {
     type Error = Error;
     fn try_from(field: &protobuf::Field) -> Result<Self, Self::Error> {
         let datatype = field.arrow_type.as_deref().required("arrow_type")?;
-
-        Ok(Self::new(field.name.as_str(), datatype, field.nullable))
+        Ok(Self::new(field.name.as_str(), datatype, field.nullable)
+            .with_metadata(field.metadata.clone()))
     }
 }
 
@@ -581,19 +581,9 @@ impl TryFrom<&protobuf::Schema> for Schema {
         let fields = schema
             .columns
             .iter()
-            .map(|c| {
-                let pb_arrow_type_res = c
-                    .arrow_type
-                    .as_ref()
-                    .ok_or_else(|| proto_error("Protobuf deserialization error: Field message was missing required field 'arrow_type'"));
-                let pb_arrow_type: &protobuf::ArrowType = match pb_arrow_type_res {
-                    Ok(res) => res,
-                    Err(e) => return Err(e),
-                };
-                Ok(Field::new(&c.name, pb_arrow_type.try_into()?, c.nullable))
-            })
+            .map(Field::try_from)
             .collect::<Result<Vec<_>, _>>()?;
-        Ok(Self::new(fields))
+        Ok(Self::new_with_metadata(fields, schema.metadata.clone()))
     }
 }
 
diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs
index 7d0ddac484..ea293067b7 100644
--- a/datafusion/proto/src/logical_plan/mod.rs
+++ b/datafusion/proto/src/logical_plan/mod.rs
@@ -2341,6 +2341,37 @@ mod roundtrip_tests {
         }
     }
 
+    #[test]
+    fn roundtrip_field() {
+        let field =
+            Field::new("f", DataType::Int32, true).with_metadata(HashMap::from([
+                (String::from("k1"), String::from("v1")),
+                (String::from("k2"), String::from("v2")),
+            ]));
+        let proto_field: super::protobuf::Field = (&field).try_into().unwrap();
+        let returned_field: Field = (&proto_field).try_into().unwrap();
+        assert_eq!(field, returned_field);
+    }
+
+    #[test]
+    fn roundtrip_schema() {
+        let schema = Schema::new_with_metadata(
+            vec![
+                Field::new("a", DataType::Int64, false),
+                Field::new("b", DataType::Decimal128(15, 2), true).with_metadata(
+                    HashMap::from([(String::from("k1"), String::from("v1"))]),
+                ),
+            ],
+            HashMap::from([
+                (String::from("k2"), String::from("v2")),
+                (String::from("k3"), String::from("v3")),
+            ]),
+        );
+        let proto_schema: super::protobuf::Schema = (&schema).try_into().unwrap();
+        let returned_schema: Schema = (&proto_schema).try_into().unwrap();
+        assert_eq!(schema, returned_schema);
+    }
+
     #[test]
     fn roundtrip_not() {
         let test_expr = Expr::Not(Box::new(lit(1.0_f32)));
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index 4a4b16db80..8665ca00c3 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -117,6 +117,7 @@ impl TryFrom<&Field> for protobuf::Field {
             arrow_type: Some(Box::new(arrow_type)),
             nullable: field.is_nullable(),
             children: Vec::new(),
+            metadata: field.metadata().clone(),
         })
     }
 }
@@ -266,6 +267,7 @@ impl TryFrom<&Schema> for protobuf::Schema {
                 .iter()
                 .map(|f| f.as_ref().try_into())
                 .collect::<Result<Vec<_>, Error>>()?,
+            metadata: schema.metadata.clone(),
         })
     }
 }
@@ -280,6 +282,7 @@ impl TryFrom<SchemaRef> for protobuf::Schema {
                 .iter()
                 .map(|f| f.as_ref().try_into())
                 .collect::<Result<Vec<_>, Error>>()?,
+            metadata: schema.metadata.clone(),
         })
     }
 }

Reply via email to