This is an automated email from the ASF dual-hosted git repository.

etseidl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new edfb9aba45 Use Thrift macro to generate Parquet `LogicalType` 
serialization code (#9997)
edfb9aba45 is described below

commit edfb9aba45e3fba6fdc52d525c8d4b4132a0d857
Author: Ed Seidl <[email protected]>
AuthorDate: Fri May 22 08:33:47 2026 -0700

    Use Thrift macro to generate Parquet `LogicalType` serialization code 
(#9997)
    
    # Which issue does this PR close?
    
    - Closes #9995.
    
    # Rationale for this change
    See issue. Improves code maintainability by using a Thrift macro to
    generate the `LogicalType` serialization code.
    
    # What changes are included in this PR?
    
    Adds a new macro to generate code for a Thrift `union` that needs to be
    forward compatible. The macro achieves this by adding a catch-all
    `_Unknown` variant that records the id of any unrecognized field.
    
    # Are there any user-facing changes?
    Yes, this is a breaking API change because the `LogicalType` enum will
    now use tuple variants rather than struct variants. This also makes
    public some structs that were previously private.
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 parquet/src/arrow/schema/extension.rs |  15 +-
 parquet/src/arrow/schema/primitive.rs |  58 ++---
 parquet/src/basic.rs                  | 415 ++++++++--------------------------
 parquet/src/column/writer/mod.rs      |   9 +-
 parquet/src/parquet_macros.rs         |  80 +++++++
 parquet/src/schema/printer.rs         |  39 ++--
 parquet/src/schema/types.rs           |  39 ++--
 parquet/tests/geospatial.rs           |  23 +-
 8 files changed, 259 insertions(+), 419 deletions(-)

diff --git a/parquet/src/arrow/schema/extension.rs 
b/parquet/src/arrow/schema/extension.rs
index 0244c1b6bb..353770ddbd 100644
--- a/parquet/src/arrow/schema/extension.rs
+++ b/parquet/src/arrow/schema/extension.rs
@@ -48,7 +48,7 @@ pub(crate) fn try_add_extension_type(
     };
     Ok(match parquet_logical_type {
         #[cfg(feature = "variant_experimental")]
-        LogicalType::Variant { .. } => {
+        LogicalType::Variant(_) => {
             let mut arrow_field = arrow_field;
             
arrow_field.try_with_extension_type(parquet_variant_compute::VariantType)?;
             arrow_field
@@ -66,16 +66,19 @@ pub(crate) fn try_add_extension_type(
             arrow_field
         }
         #[cfg(feature = "geospatial")]
-        LogicalType::Geometry { crs } => {
-            let md = parquet_geospatial::WkbMetadata::new(crs.as_deref(), 
None);
+        LogicalType::Geometry(geometry) => {
+            let md = 
parquet_geospatial::WkbMetadata::new(geometry.crs.as_deref(), None);
             let mut arrow_field = arrow_field;
             
arrow_field.try_with_extension_type(parquet_geospatial::WkbType::new(Some(md)))?;
             arrow_field
         }
         #[cfg(feature = "geospatial")]
-        LogicalType::Geography { crs, algorithm } => {
-            let algorithm = algorithm.map(|a| a.try_as_edges()).transpose()?;
-            let md = parquet_geospatial::WkbMetadata::new(crs.as_deref(), 
algorithm);
+        LogicalType::Geography(geography) => {
+            let algorithm = geography
+                .algorithm()
+                .map(|a| a.try_as_edges())
+                .transpose()?;
+            let md = 
parquet_geospatial::WkbMetadata::new(geography.crs.as_deref(), algorithm);
             let mut arrow_field = arrow_field;
             
arrow_field.try_with_extension_type(parquet_geospatial::WkbType::new(Some(md)))?;
             arrow_field
diff --git a/parquet/src/arrow/schema/primitive.rs 
b/parquet/src/arrow/schema/primitive.rs
index b440753cc8..2272014a93 100644
--- a/parquet/src/arrow/schema/primitive.rs
+++ b/parquet/src/arrow/schema/primitive.rs
@@ -15,7 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::basic::{ConvertedType, LogicalType, TimeUnit as ParquetTimeUnit, 
Type as PhysicalType};
+use crate::basic::{
+    ConvertedType, IntType, LogicalType, TimeUnit as ParquetTimeUnit, Type as 
PhysicalType,
+};
 use crate::errors::{ParquetError, Result};
 use crate::schema::types::{BasicTypeInfo, Type};
 use arrow_schema::{DECIMAL128_MAX_PRECISION, DataType, IntervalUnit, TimeUnit};
@@ -171,15 +173,7 @@ fn decimal_256_type(scale: i32, precision: i32) -> 
Result<DataType> {
 fn from_int32(info: &BasicTypeInfo, scale: i32, precision: i32) -> 
Result<DataType> {
     match (info.logical_type_ref(), info.converted_type()) {
         (None, ConvertedType::NONE) => Ok(DataType::Int32),
-        (
-            Some(
-                ref t @ LogicalType::Integer {
-                    bit_width,
-                    is_signed,
-                },
-            ),
-            _,
-        ) => match (bit_width, is_signed) {
+        (Some(ref t @ LogicalType::Integer(int)), _) => match (int.bit_width, 
int.is_signed) {
             (8, true) => Ok(DataType::Int8),
             (16, true) => Ok(DataType::Int16),
             (32, true) => Ok(DataType::Int32),
@@ -188,15 +182,15 @@ fn from_int32(info: &BasicTypeInfo, scale: i32, 
precision: i32) -> Result<DataTy
             (32, false) => Ok(DataType::UInt32),
             _ => Err(arrow_err!("Cannot create INT32 physical type from {:?}", 
t)),
         },
-        (Some(LogicalType::Decimal { scale, precision }), _) => {
-            decimal_128_type(*scale, *precision)
+        (Some(LogicalType::Decimal(decimal)), _) => {
+            decimal_128_type(decimal.scale, decimal.precision)
         }
         (Some(LogicalType::Date), _) => Ok(DataType::Date32),
-        (Some(LogicalType::Time { unit, .. }), _) => match unit {
+        (Some(LogicalType::Time(time)), _) => match time.unit {
             ParquetTimeUnit::MILLIS => 
Ok(DataType::Time32(TimeUnit::Millisecond)),
             _ => Err(arrow_err!(
                 "Cannot create INT32 physical type from {:?}",
-                unit
+                time.unit
             )),
         },
         (None, ConvertedType::UINT_8) => Ok(DataType::UInt8),
@@ -220,35 +214,29 @@ fn from_int64(info: &BasicTypeInfo, scale: i32, 
precision: i32) -> Result<DataTy
     match (info.logical_type_ref(), info.converted_type()) {
         (None, ConvertedType::NONE) => Ok(DataType::Int64),
         (
-            Some(LogicalType::Integer {
+            Some(LogicalType::Integer(IntType {
                 bit_width: 64,
                 is_signed,
-            }),
+            })),
             _,
         ) => match is_signed {
             true => Ok(DataType::Int64),
             false => Ok(DataType::UInt64),
         },
-        (Some(LogicalType::Time { unit, .. }), _) => match unit {
+        (Some(LogicalType::Time(time)), _) => match time.unit {
             ParquetTimeUnit::MILLIS => {
                 Err(arrow_err!("Cannot create INT64 from MILLIS time unit",))
             }
             ParquetTimeUnit::MICROS => 
Ok(DataType::Time64(TimeUnit::Microsecond)),
             ParquetTimeUnit::NANOS => 
Ok(DataType::Time64(TimeUnit::Nanosecond)),
         },
-        (
-            Some(LogicalType::Timestamp {
-                is_adjusted_to_u_t_c,
-                unit,
-            }),
-            _,
-        ) => Ok(DataType::Timestamp(
-            match unit {
+        (Some(LogicalType::Timestamp(timestamp)), _) => Ok(DataType::Timestamp(
+            match timestamp.unit {
                 ParquetTimeUnit::MILLIS => TimeUnit::Millisecond,
                 ParquetTimeUnit::MICROS => TimeUnit::Microsecond,
                 ParquetTimeUnit::NANOS => TimeUnit::Nanosecond,
             },
-            if *is_adjusted_to_u_t_c {
+            if timestamp.is_adjusted_to_u_t_c {
                 Some("UTC".into())
             } else {
                 None
@@ -265,9 +253,7 @@ fn from_int64(info: &BasicTypeInfo, scale: i32, precision: 
i32) -> Result<DataTy
             TimeUnit::Microsecond,
             Some("UTC".into()),
         )),
-        (Some(LogicalType::Decimal { scale, precision }), _) => {
-            decimal_128_type(*scale, *precision)
-        }
+        (Some(LogicalType::Decimal(dec)), _) => decimal_128_type(dec.scale, 
dec.precision),
         (None, ConvertedType::DECIMAL) => decimal_128_type(scale, precision),
         (logical, converted) => Err(arrow_err!(
             "Unable to convert parquet INT64 logical type {:?} or converted 
type {}",
@@ -291,13 +277,7 @@ fn from_byte_array(info: &BasicTypeInfo, precision: i32, 
scale: i32) -> Result<D
         (None, ConvertedType::BSON) => Ok(DataType::Binary),
         (None, ConvertedType::ENUM) => Ok(DataType::Binary),
         (None, ConvertedType::UTF8) => Ok(DataType::Utf8),
-        (
-            Some(LogicalType::Decimal {
-                scale: s,
-                precision: p,
-            }),
-            _,
-        ) => decimal_type(*s, *p),
+        (Some(LogicalType::Decimal(decimal)), _) => 
decimal_type(decimal.scale, decimal.precision),
         (None, ConvertedType::DECIMAL) => decimal_type(scale, precision),
         (logical, converted) => Err(arrow_err!(
             "Unable to convert parquet BYTE_ARRAY logical type {:?} or 
converted type {}",
@@ -315,11 +295,11 @@ fn from_fixed_len_byte_array(
 ) -> Result<DataType> {
     // TODO: This should check the type length for the decimal and interval 
types
     match (info.logical_type_ref(), info.converted_type()) {
-        (Some(LogicalType::Decimal { scale, precision }), _) => {
+        (Some(LogicalType::Decimal(decimal)), _) => {
             if type_length <= 16 {
-                decimal_128_type(*scale, *precision)
+                decimal_128_type(decimal.scale, decimal.precision)
             } else {
-                decimal_256_type(*scale, *precision)
+                decimal_256_type(decimal.scale, decimal.precision)
             }
         }
         (None, ConvertedType::DECIMAL) => {
diff --git a/parquet/src/basic.rs b/parquet/src/basic.rs
index 796779358c..b4f18f3117 100644
--- a/parquet/src/basic.rs
+++ b/parquet/src/basic.rs
@@ -30,7 +30,10 @@ use crate::parquet_thrift::{
     ElementType, FieldType, ReadThrift, ThriftCompactInputProtocol, 
ThriftCompactOutputProtocol,
     WriteThrift, WriteThriftField, validate_list_type,
 };
-use crate::{thrift_enum, thrift_struct, thrift_union_all_empty, 
write_thrift_field};
+use crate::{
+    thrift_enum, thrift_struct, thrift_union_all_empty, 
thrift_union_with_unknown,
+    write_thrift_field,
+};
 
 use crate::errors::{ParquetError, Result};
 
@@ -183,383 +186,166 @@ union TimeUnit {
 // ----------------------------------------------------------------------
 // Mirrors thrift union `LogicalType`
 
-// private structs for decoding logical type
-
 thrift_struct!(
-struct DecimalType {
+pub struct DecimalType {
+  /// The number of digits to the right of the decimal point.
   1: required i32 scale
+  /// The total number of digits in the decimal.
   2: required i32 precision
 }
 );
 
 thrift_struct!(
-struct TimestampType {
+pub struct TimestampType {
+  /// Whether the timestamp is adjusted to UTC.
   1: required bool is_adjusted_to_u_t_c
+  /// The unit of time.
   2: required TimeUnit unit
 }
 );
 
-// they are identical
-use TimestampType as TimeType;
+/// Identical to [`TimestampType`]
+pub use TimestampType as TimeType;
 
 thrift_struct!(
-struct IntType {
+pub struct IntType {
+  /// The number of bits in the integer.
   1: required i8 bit_width
+  /// Whether the integer is signed.
   2: required bool is_signed
 }
 );
 
 thrift_struct!(
-struct VariantType {
-  // The version of the variant specification that the variant was
-  // written with.
+pub struct VariantType {
+  /// The version of the variant specification that the variant was
+  /// written with.
   1: optional i8 specification_version
 }
 );
 
 thrift_struct!(
-struct GeometryType<'a> {
-  1: optional string<'a> crs;
+pub struct GeometryType {
+  /// A custom CRS. If unset the CRS `OGC:CRS84` should be used, which means 
that the geometries
+  /// must be stored in longitude, latitude based on the WGS84 datum.
+  1: optional string crs;
 }
 );
 
 thrift_struct!(
-struct GeographyType<'a> {
-  1: optional string<'a> crs;
+pub struct GeographyType {
+  /// A custom CRS. If unset the CRS `OGC:CRS84` should be used.
+  1: optional string crs;
+  /// An optional algorithm can be set to correctly interpret edges 
interpolation
+  /// of the geometries. If unset, the `SPHERICAL` algorithm should be used.
   2: optional EdgeInterpolationAlgorithm algorithm;
 }
 );
 
-// TODO(ets): should we switch to tuple variants so we can use
-// the thrift macros?
+impl GeographyType {
+    /// Accessor for the `GeographyType::algorithm` field. If this field is 
not set, this
+    /// function returns the default value (currently 
[`EdgeInterpolationAlgorithm::SPHERICAL`]
+    /// per the Parquet [specification]).
+    ///
+    /// [specification]: 
https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#geography
+    pub fn algorithm(&self) -> Option<EdgeInterpolationAlgorithm> {
+        self.algorithm.or(Some(Default::default()))
+    }
+}
 
+thrift_union_with_unknown!(
 /// Logical types used by version 2.4.0+ of the Parquet format.
 ///
 /// This is an *entirely new* struct as of version
 /// 4.0.0. The struct previously named `LogicalType` was renamed to
 /// [`ConvertedType`]. Please see the README.md for more details.
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub enum LogicalType {
-    /// A UTF8 encoded string.
-    String,
-    /// A map of key-value pairs.
-    Map,
-    /// A list of elements.
-    List,
-    /// A set of predefined values.
-    Enum,
-    /// A decimal value with a specified scale and precision.
-    Decimal {
-        /// The number of digits in the decimal.
-        scale: i32,
-        /// The location of the decimal point.
-        precision: i32,
-    },
-    /// A date stored as days since Unix epoch.
-    Date,
-    /// A time stored as [`TimeUnit`] since midnight.
-    Time {
-        /// Whether the time is adjusted to UTC.
-        is_adjusted_to_u_t_c: bool,
-        /// The unit of time.
-        unit: TimeUnit,
-    },
-    /// A timestamp stored as [`TimeUnit`] since Unix epoch.
-    Timestamp {
-        /// Whether the timestamp is adjusted to UTC.
-        is_adjusted_to_u_t_c: bool,
-        /// The unit of time.
-        unit: TimeUnit,
-    },
-    /// An integer with a specified bit width and signedness.
-    Integer {
-        /// The number of bits in the integer.
-        bit_width: i8,
-        /// Whether the integer is signed.
-        is_signed: bool,
-    },
-    /// An unknown logical type.
-    Unknown,
-    /// A JSON document.
-    Json,
-    /// A BSON document.
-    Bson,
-    /// A UUID.
-    Uuid,
-    /// A 16-bit floating point number.
-    Float16,
-    /// A Variant value.
-    Variant {
-        /// The version of the variant specification that the variant was 
written with.
-        specification_version: Option<i8>,
-    },
-    /// A geospatial feature in the Well-Known Binary (WKB) format with 
linear/planar edges interpolation.
-    Geometry {
-        /// A custom CRS. If unset the defaults to `OGC:CRS84`, which means 
that the geometries
-        /// must be stored in longitude, latitude based on the WGS84 datum.
-        crs: Option<String>,
-    },
-    /// A geospatial feature in the WKB format with an explicit 
(non-linear/non-planar) edges interpolation.
-    Geography {
-        /// A custom CRS. If unset the defaults to `OGC:CRS84`.
-        crs: Option<String>,
-        /// An optional algorithm can be set to correctly interpret edges 
interpolation
-        /// of the geometries. If unset, the algorithm defaults to `SPHERICAL`.
-        algorithm: Option<EdgeInterpolationAlgorithm>,
-    },
-    /// For forward compatibility; used when an unknown union value is 
encountered.
-    _Unknown {
-        /// The field id encountered when parsing the unknown logical type.
-        field_id: i16,
-    },
+union LogicalType {
+   /// A UTF8 encoded string.
+   1:  String
+   /// A map of key-value pairs.
+   2:  Map
+   /// A list of elements.
+   3:  List
+   /// A set of predefined values.
+   4:  Enum
+   /// A decimal value with a specified scale and precision.
+   5:  (DecimalType) Decimal
+   /// A date stored as days since Unix epoch.
+   6:  Date
+   /// A time stored as [`TimeUnit`] since midnight.
+   7:  (TimeType) Time
+   /// A timestamp stored as [`TimeUnit`] since Unix epoch.
+   8:  (TimestampType) Timestamp
+   // 9: reserved for INTERVAL
+   /// An integer with a specified bit width and signedness.
+   10: (IntType) Integer
+   /// An unknown logical type.
+   11: Unknown
+   /// A JSON document.
+   12: Json
+   /// A BSON document.
+   13: Bson
+   /// A UUID.
+   14: Uuid
+   /// A 16-bit floating point number.
+   15: Float16
+   /// A Variant value.
+   16: (VariantType) Variant
+   /// A geospatial feature in the Well-Known Binary (WKB) format with 
linear/planar edges interpolation.
+   17: (GeometryType) Geometry
+   /// A geospatial feature in the WKB format with an explicit 
(non-linear/non-planar) edges interpolation.
+   18: (GeographyType) Geography
 }
+);
 
 impl LogicalType {
     /// Create a [`LogicalType::Integer`] variant with the given `bit_width` 
and `is_signed`
     pub fn integer(bit_width: i8, is_signed: bool) -> Self {
-        Self::Integer {
+        Self::Integer(IntType {
             bit_width,
             is_signed,
-        }
+        })
     }
 
     /// Create a [`LogicalType::Decimal`] variant with the given `scale` and 
`precision`
     pub fn decimal(scale: i32, precision: i32) -> Self {
-        Self::Decimal { scale, precision }
+        Self::Decimal(DecimalType { scale, precision })
     }
 
     /// Create a [`LogicalType::Time`] variant with the given 
`is_adjusted_to_u_t_c` and `unit`
     pub fn time(is_adjusted_to_u_t_c: bool, unit: TimeUnit) -> Self {
-        Self::Time {
+        Self::Time(TimeType {
             is_adjusted_to_u_t_c,
             unit,
-        }
+        })
     }
 
     /// Create a [`LogicalType::Timestamp`] variant with the given 
`is_adjusted_to_u_t_c` and `unit`
     pub fn timestamp(is_adjusted_to_u_t_c: bool, unit: TimeUnit) -> Self {
-        Self::Timestamp {
+        Self::Timestamp(TimestampType {
             is_adjusted_to_u_t_c,
             unit,
-        }
+        })
     }
 
     /// Create a [`LogicalType::Variant`] variant with the given 
`specification_version`
     pub fn variant(specification_version: Option<i8>) -> Self {
-        Self::Variant {
+        Self::Variant(VariantType {
             specification_version,
-        }
+        })
     }
 
     /// Create a [`LogicalType::Geometry`] variant with the given `crs`
     pub fn geometry(crs: Option<String>) -> Self {
-        Self::Geometry { crs }
+        Self::Geometry(GeometryType { crs })
     }
 
     /// Create a [`LogicalType::Geography`] variant with the given `crs` and 
`algorithm`
     pub fn geography(crs: Option<String>, algorithm: 
Option<EdgeInterpolationAlgorithm>) -> Self {
-        Self::Geography { crs, algorithm }
-    }
-}
-
-impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for LogicalType {
-    fn read_thrift(prot: &mut R) -> Result<Self> {
-        let field_ident = prot.read_field_begin(0)?;
-        if field_ident.field_type == FieldType::Stop {
-            return Err(general_err!("received empty union from remote 
LogicalType"));
-        }
-        let ret = match field_ident.id {
-            1 => {
-                prot.skip_empty_struct()?;
-                Self::String
-            }
-            2 => {
-                prot.skip_empty_struct()?;
-                Self::Map
-            }
-            3 => {
-                prot.skip_empty_struct()?;
-                Self::List
-            }
-            4 => {
-                prot.skip_empty_struct()?;
-                Self::Enum
-            }
-            5 => {
-                let val = DecimalType::read_thrift(&mut *prot)?;
-                Self::decimal(val.scale, val.precision)
-            }
-            6 => {
-                prot.skip_empty_struct()?;
-                Self::Date
-            }
-            7 => {
-                let val = TimeType::read_thrift(&mut *prot)?;
-                Self::time(val.is_adjusted_to_u_t_c, val.unit)
-            }
-            8 => {
-                let val = TimestampType::read_thrift(&mut *prot)?;
-                Self::timestamp(val.is_adjusted_to_u_t_c, val.unit)
-            }
-            10 => {
-                let val = IntType::read_thrift(&mut *prot)?;
-                Self::integer(val.bit_width, val.is_signed)
-            }
-            11 => {
-                prot.skip_empty_struct()?;
-                Self::Unknown
-            }
-            12 => {
-                prot.skip_empty_struct()?;
-                Self::Json
-            }
-            13 => {
-                prot.skip_empty_struct()?;
-                Self::Bson
-            }
-            14 => {
-                prot.skip_empty_struct()?;
-                Self::Uuid
-            }
-            15 => {
-                prot.skip_empty_struct()?;
-                Self::Float16
-            }
-            16 => {
-                let val = VariantType::read_thrift(&mut *prot)?;
-                Self::variant(val.specification_version)
-            }
-            17 => {
-                let val = GeometryType::read_thrift(&mut *prot)?;
-                Self::geometry(val.crs.map(|s| s.to_owned()))
-            }
-            18 => {
-                let val = GeographyType::read_thrift(&mut *prot)?;
-                // unset algorithm means SPHERICAL, per the spec:
-                // 
https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#geography
-                let algorithm = val
-                    .algorithm
-                    .unwrap_or(EdgeInterpolationAlgorithm::SPHERICAL);
-                Self::geography(val.crs.map(|s| s.to_owned()), Some(algorithm))
-            }
-            _ => {
-                prot.skip(field_ident.field_type)?;
-                Self::_Unknown {
-                    field_id: field_ident.id,
-                }
-            }
-        };
-        let field_ident = prot.read_field_begin(field_ident.id)?;
-        if field_ident.field_type != FieldType::Stop {
-            return Err(general_err!(
-                "Received multiple fields for union from remote LogicalType"
-            ));
-        }
-        Ok(ret)
+        Self::Geography(GeographyType { crs, algorithm })
     }
 }
 
-impl WriteThrift for LogicalType {
-    const ELEMENT_TYPE: ElementType = ElementType::Struct;
-
-    fn write_thrift<W: Write>(&self, writer: &mut 
ThriftCompactOutputProtocol<W>) -> Result<()> {
-        match self {
-            Self::String => {
-                writer.write_empty_struct(1, 0)?;
-            }
-            Self::Map => {
-                writer.write_empty_struct(2, 0)?;
-            }
-            Self::List => {
-                writer.write_empty_struct(3, 0)?;
-            }
-            Self::Enum => {
-                writer.write_empty_struct(4, 0)?;
-            }
-            Self::Decimal { scale, precision } => {
-                DecimalType {
-                    scale: *scale,
-                    precision: *precision,
-                }
-                .write_thrift_field(writer, 5, 0)?;
-            }
-            Self::Date => {
-                writer.write_empty_struct(6, 0)?;
-            }
-            Self::Time {
-                is_adjusted_to_u_t_c,
-                unit,
-            } => {
-                TimeType {
-                    is_adjusted_to_u_t_c: *is_adjusted_to_u_t_c,
-                    unit: *unit,
-                }
-                .write_thrift_field(writer, 7, 0)?;
-            }
-            Self::Timestamp {
-                is_adjusted_to_u_t_c,
-                unit,
-            } => {
-                TimestampType {
-                    is_adjusted_to_u_t_c: *is_adjusted_to_u_t_c,
-                    unit: *unit,
-                }
-                .write_thrift_field(writer, 8, 0)?;
-            }
-            Self::Integer {
-                bit_width,
-                is_signed,
-            } => {
-                IntType {
-                    bit_width: *bit_width,
-                    is_signed: *is_signed,
-                }
-                .write_thrift_field(writer, 10, 0)?;
-            }
-            Self::Unknown => {
-                writer.write_empty_struct(11, 0)?;
-            }
-            Self::Json => {
-                writer.write_empty_struct(12, 0)?;
-            }
-            Self::Bson => {
-                writer.write_empty_struct(13, 0)?;
-            }
-            Self::Uuid => {
-                writer.write_empty_struct(14, 0)?;
-            }
-            Self::Float16 => {
-                writer.write_empty_struct(15, 0)?;
-            }
-            Self::Variant {
-                specification_version,
-            } => {
-                VariantType {
-                    specification_version: *specification_version,
-                }
-                .write_thrift_field(writer, 16, 0)?;
-            }
-            Self::Geometry { crs } => {
-                GeometryType {
-                    crs: crs.as_ref().map(|s| s.as_str()),
-                }
-                .write_thrift_field(writer, 17, 0)?;
-            }
-            Self::Geography { crs, algorithm } => {
-                GeographyType {
-                    crs: crs.as_ref().map(|s| s.as_str()),
-                    algorithm: *algorithm,
-                }
-                .write_thrift_field(writer, 18, 0)?;
-            }
-            _ => return Err(nyi_err!("logical type")),
-        }
-        writer.write_struct_end()
-    }
-}
-
-write_thrift_field!(LogicalType, FieldType::Struct);
-
 // ----------------------------------------------------------------------
 // Mirrors thrift enum `FieldRepetitionType`
 //
@@ -1253,21 +1039,21 @@ impl ColumnOrder {
                 LogicalType::String | LogicalType::Enum | LogicalType::Json | 
LogicalType::Bson => {
                     SortOrder::UNSIGNED
                 }
-                LogicalType::Integer { is_signed, .. } => match is_signed {
+                LogicalType::Integer(int) => match int.is_signed {
                     true => SortOrder::SIGNED,
                     false => SortOrder::UNSIGNED,
                 },
                 LogicalType::Map | LogicalType::List => SortOrder::UNDEFINED,
-                LogicalType::Decimal { .. } => SortOrder::SIGNED,
+                LogicalType::Decimal(_) => SortOrder::SIGNED,
                 LogicalType::Date => SortOrder::SIGNED,
-                LogicalType::Time { .. } => SortOrder::SIGNED,
-                LogicalType::Timestamp { .. } => SortOrder::SIGNED,
+                LogicalType::Time(_) => SortOrder::SIGNED,
+                LogicalType::Timestamp(_) => SortOrder::SIGNED,
                 LogicalType::Unknown => SortOrder::UNDEFINED,
                 LogicalType::Uuid => SortOrder::UNSIGNED,
                 LogicalType::Float16 => SortOrder::SIGNED,
-                LogicalType::Variant { .. }
-                | LogicalType::Geometry { .. }
-                | LogicalType::Geography { .. }
+                LogicalType::Variant(_)
+                | LogicalType::Geometry(_)
+                | LogicalType::Geography(_)
                 | LogicalType::_Unknown { .. } => SortOrder::UNDEFINED,
             },
             // Fall back to converted type
@@ -1426,20 +1212,17 @@ impl From<Option<LogicalType>> for ConvertedType {
                 LogicalType::Enum => ConvertedType::ENUM,
                 LogicalType::Decimal { .. } => ConvertedType::DECIMAL,
                 LogicalType::Date => ConvertedType::DATE,
-                LogicalType::Time { unit, .. } => match unit {
+                LogicalType::Time(time) => match time.unit {
                     TimeUnit::MILLIS => ConvertedType::TIME_MILLIS,
                     TimeUnit::MICROS => ConvertedType::TIME_MICROS,
                     TimeUnit::NANOS => ConvertedType::NONE,
                 },
-                LogicalType::Timestamp { unit, .. } => match unit {
+                LogicalType::Timestamp(time) => match time.unit {
                     TimeUnit::MILLIS => ConvertedType::TIMESTAMP_MILLIS,
                     TimeUnit::MICROS => ConvertedType::TIMESTAMP_MICROS,
                     TimeUnit::NANOS => ConvertedType::NONE,
                 },
-                LogicalType::Integer {
-                    bit_width,
-                    is_signed,
-                } => match (bit_width, is_signed) {
+                LogicalType::Integer(int_type) => match (int_type.bit_width, 
int_type.is_signed) {
                     (8, true) => ConvertedType::INT_8,
                     (16, true) => ConvertedType::INT_16,
                     (32, true) => ConvertedType::INT_32,
@@ -1456,9 +1239,9 @@ impl From<Option<LogicalType>> for ConvertedType {
                 LogicalType::Bson => ConvertedType::BSON,
                 LogicalType::Uuid
                 | LogicalType::Float16
-                | LogicalType::Variant { .. }
-                | LogicalType::Geometry { .. }
-                | LogicalType::Geography { .. }
+                | LogicalType::Variant(_)
+                | LogicalType::Geometry(_)
+                | LogicalType::Geography(_)
                 | LogicalType::_Unknown { .. }
                 | LogicalType::Unknown => ConvertedType::NONE,
             },
diff --git a/parquet/src/column/writer/mod.rs b/parquet/src/column/writer/mod.rs
index 595eadbc90..4e53230bbf 100644
--- a/parquet/src/column/writer/mod.rs
+++ b/parquet/src/column/writer/mod.rs
@@ -27,7 +27,8 @@ use std::collections::{BTreeSet, VecDeque};
 use std::str;
 
 use crate::basic::{
-    BoundaryOrder, Compression, ConvertedType, Encoding, EncodingMask, 
LogicalType, PageType, Type,
+    BoundaryOrder, Compression, ConvertedType, Encoding, EncodingMask, 
IntType, LogicalType,
+    PageType, Type,
 };
 use crate::column::page::{CompressedPage, Page, PageWriteSpec, PageWriter};
 use crate::column::writer::encoder::{ColumnValueEncoder, 
ColumnValueEncoderImpl, ColumnValues};
@@ -1522,9 +1523,9 @@ fn update_stat<T: ParquetValueType, F>(
 fn compare_greater<T: ParquetValueType>(descr: &ColumnDescriptor, a: &T, b: 
&T) -> bool {
     match T::PHYSICAL_TYPE {
         Type::INT32 | Type::INT64 => {
-            if let Some(LogicalType::Integer {
+            if let Some(LogicalType::Integer(IntType {
                 is_signed: false, ..
-            }) = descr.logical_type_ref()
+            })) = descr.logical_type_ref()
             {
                 // need to compare unsigned
                 return compare_greater_unsigned_int(a, b);
@@ -1541,7 +1542,7 @@ fn compare_greater<T: ParquetValueType>(descr: 
&ColumnDescriptor, a: &T, b: &T)
             };
         }
         Type::FIXED_LEN_BYTE_ARRAY | Type::BYTE_ARRAY => {
-            if let Some(LogicalType::Decimal { .. }) = 
descr.logical_type_ref() {
+            if let Some(LogicalType::Decimal(_)) = descr.logical_type_ref() {
                 return compare_greater_byte_array_decimals(a.as_bytes(), 
b.as_bytes());
             }
             if let ConvertedType::DECIMAL = descr.converted_type() {
diff --git a/parquet/src/parquet_macros.rs b/parquet/src/parquet_macros.rs
index 8bb2ad23b0..f7ddf57b14 100644
--- a/parquet/src/parquet_macros.rs
+++ b/parquet/src/parquet_macros.rs
@@ -260,6 +260,86 @@ macro_rules! thrift_union {
     }
 }
 
+/// Macro used to generate Rust enums for Thrift unions where variants are a mix of unit and
+/// tuple types. This version allows for unknown variants for forwards compatibility.
+///
+/// Use of this macro requires modifying the thrift IDL. For variants with empty structs as their
+/// type, delete the typename (i.e. `1: EmptyStruct Var1;` becomes `1: Var1`). For variants with a
+/// non-empty type, the typename must be contained within parens (e.g. `1: MyType Var1;` becomes
+/// `1: (MyType) Var1;`).
+///
+/// This macro allows for specifying lifetime annotations for the resulting `enum` and its fields.
+///
+/// When utilizing this macro the Thrift serialization traits and structs need to be in scope.
+#[doc(hidden)]
+#[macro_export]
+#[allow(clippy::crate_in_macro_def)]
+macro_rules! thrift_union_with_unknown {
+    ($(#[$($def_attrs:tt)*])* union $identifier:ident $(< $lt:lifetime >)? { 
$($(#[$($field_attrs:tt)*])* $field_id:literal : $( ( $field_type:ident $(< 
$element_type:ident >)? $(< $field_lt:lifetime >)?) )? $field_name:ident 
$(;)?)* }) => {
+        $(#[cfg_attr(not(doctest), $($def_attrs)*)])*
+        #[derive(Clone, Debug, Eq, PartialEq)]
+        #[allow(non_camel_case_types)]
+        #[allow(non_snake_case)]
+        #[allow(missing_docs)]
+        pub enum $identifier $(<$lt>)? {
+            $($(#[cfg_attr(not(doctest), $($field_attrs)*)])* $field_name $( ( 
$crate::__thrift_union_type!{$field_type $($field_lt)? $($element_type)?} ) 
)?),*,
+            _Unknown {
+                /// The field id encountered when parsing the unknown variant.
+                field_id: i16,
+            },
+        }
+
+        impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for 
$identifier $(<$lt>)? {
+            fn read_thrift(prot: &mut R) -> Result<Self> {
+                let field_ident = prot.read_field_begin(0)?;
+                if field_ident.field_type == FieldType::Stop {
+                    return Err(general_err!("Received empty union from remote 
{}", stringify!($identifier)));
+                }
+                let ret = match field_ident.id {
+                    $($field_id => {
+                        let val = $crate::__thrift_read_variant!(prot, 
$field_name $($field_type $($element_type)?)?);
+                        val
+                    })*
+                    _ => {
+                        prot.skip(field_ident.field_type)?;
+                        Self::_Unknown {
+                            field_id: field_ident.id,
+                        }
+                    }
+                };
+                let field_ident = prot.read_field_begin(field_ident.id)?;
+                if field_ident.field_type != FieldType::Stop {
+                    return Err(general_err!(
+                        concat!("Received multiple fields for union from 
remote {}", stringify!($identifier))
+                    ));
+                }
+                Ok(ret)
+            }
+        }
+
+        impl $(<$lt>)? WriteThrift for $identifier $(<$lt>)? {
+            const ELEMENT_TYPE: ElementType = ElementType::Struct;
+
+            fn write_thrift<W: Write>(&self, writer: &mut 
ThriftCompactOutputProtocol<W>) -> Result<()> {
+                match self {
+                    $($crate::__thrift_write_variant_lhs!($field_name 
$($field_type)?, variant_val) =>
+                      $crate::__thrift_write_variant_rhs!($field_id 
$($field_type)?, writer, variant_val),)*
+                    Self::_Unknown{..} => return Err(general_err!("Trying to 
write unknown variant")),
+                };
+                writer.write_struct_end()
+            }
+        }
+
+        impl $(<$lt>)? WriteThriftField for $identifier $(<$lt>)? {
+            fn write_thrift_field<W: Write>(&self, writer: &mut 
ThriftCompactOutputProtocol<W>, field_id: i16, last_field_id: i16) -> 
Result<i16> {
+                writer.write_field_begin(FieldType::Struct, field_id, 
last_field_id)?;
+                self.write_thrift(writer)?;
+                Ok(field_id)
+            }
+        }
+    }
+}
+
 /// Macro used to generate Rust structs from a Thrift `struct` definition.
 ///
 /// Note:
diff --git a/parquet/src/schema/printer.rs b/parquet/src/schema/printer.rs
index dbeddcfc12..31d00fc23b 100644
--- a/parquet/src/schema/printer.rs
+++ b/parquet/src/schema/printer.rs
@@ -45,7 +45,10 @@
 
 use std::{fmt, io};
 
-use crate::basic::{ConvertedType, LogicalType, TimeUnit, Type as PhysicalType};
+use crate::basic::{
+    ConvertedType, DecimalType, GeographyType, GeometryType, IntType, 
LogicalType, TimeUnit,
+    Type as PhysicalType, VariantType,
+};
 use crate::file::metadata::{ColumnChunkMetaData, FileMetaData, 
ParquetMetaData, RowGroupMetaData};
 use crate::schema::types::Type;
 
@@ -284,30 +287,28 @@ fn print_logical_and_converted(
 ) -> String {
     match logical_type {
         Some(logical_type) => match logical_type {
-            LogicalType::Integer {
+            LogicalType::Integer(IntType {
                 bit_width,
                 is_signed,
-            } => {
+            }) => {
                 format!("INTEGER({bit_width},{is_signed})")
             }
-            LogicalType::Decimal { scale, precision } => {
+            LogicalType::Decimal(DecimalType { scale, precision }) => {
                 format!("DECIMAL({precision},{scale})")
             }
-            LogicalType::Timestamp {
-                is_adjusted_to_u_t_c,
-                unit,
-            } => {
+            LogicalType::Timestamp(timestamp) => {
                 format!(
                     "TIMESTAMP({},{})",
-                    print_timeunit(unit),
-                    is_adjusted_to_u_t_c
+                    print_timeunit(&timestamp.unit),
+                    timestamp.is_adjusted_to_u_t_c
                 )
             }
-            LogicalType::Time {
-                is_adjusted_to_u_t_c,
-                unit,
-            } => {
-                format!("TIME({},{})", print_timeunit(unit), 
is_adjusted_to_u_t_c)
+            LogicalType::Time(time) => {
+                format!(
+                    "TIME({},{})",
+                    print_timeunit(&time.unit),
+                    time.is_adjusted_to_u_t_c
+                )
             }
             LogicalType::Date => "DATE".to_string(),
             LogicalType::Bson => "BSON".to_string(),
@@ -318,17 +319,17 @@ fn print_logical_and_converted(
             LogicalType::List => "LIST".to_string(),
             LogicalType::Map => "MAP".to_string(),
             LogicalType::Float16 => "FLOAT16".to_string(),
-            LogicalType::Variant {
+            LogicalType::Variant(VariantType {
                 specification_version,
-            } => format!("VARIANT({specification_version:?})"),
-            LogicalType::Geometry { crs } => {
+            }) => format!("VARIANT({specification_version:?})"),
+            LogicalType::Geometry(GeometryType { crs }) => {
                 if let Some(crs) = crs {
                     format!("GEOMETRY({crs})")
                 } else {
                     "GEOMETRY".to_string()
                 }
             }
-            LogicalType::Geography { crs, algorithm } => {
+            LogicalType::Geography(GeographyType { crs, algorithm }) => {
                 let algorithm = algorithm.unwrap_or_default();
                 if let Some(crs) = crs {
                     format!("GEOGRAPHY({algorithm}, {crs})")
diff --git a/parquet/src/schema/types.rs b/parquet/src/schema/types.rs
index d8b3456d47..1f9b8590fc 100644
--- a/parquet/src/schema/types.rs
+++ b/parquet/src/schema/types.rs
@@ -24,7 +24,8 @@ use crate::file::metadata::HeapSize;
 use crate::file::metadata::thrift::SchemaElement;
 
 use crate::basic::{
-    ColumnOrder, ConvertedType, LogicalType, Repetition, SortOrder, TimeUnit, 
Type as PhysicalType,
+    ColumnOrder, ConvertedType, IntType, LogicalType, Repetition, SortOrder, 
TimeType, TimeUnit,
+    Type as PhysicalType,
 };
 use crate::errors::{ParquetError, Result};
 
@@ -356,20 +357,20 @@ impl<'a> PrimitiveTypeBuilder<'a> {
                     ));
                 }
                 (LogicalType::Enum, PhysicalType::BYTE_ARRAY) => {}
-                (LogicalType::Decimal { scale, precision }, _) => {
+                (LogicalType::Decimal(decimal), _) => {
                     // Check that scale and precision are consistent with legacy values
-                    if *scale != self.scale {
+                    if decimal.scale != self.scale {
                         return Err(general_err!(
                             "DECIMAL logical type scale {} must match 
self.scale {} for field '{}'",
-                            scale,
+                            decimal.scale,
                             self.scale,
                             self.name
                         ));
                     }
-                    if *precision != self.precision {
+                    if decimal.precision != self.precision {
                         return Err(general_err!(
                             "DECIMAL logical type precision {} must match 
self.precision {} for field '{}'",
-                            precision,
+                            decimal.precision,
                             self.precision,
                             self.name
                         ));
@@ -378,32 +379,30 @@ impl<'a> PrimitiveTypeBuilder<'a> {
                 }
                 (LogicalType::Date, PhysicalType::INT32) => {}
                 (
-                    LogicalType::Time {
+                    LogicalType::Time(TimeType {
                         unit: TimeUnit::MILLIS,
                         ..
-                    },
+                    }),
                     PhysicalType::INT32,
                 ) => {}
-                (LogicalType::Time { unit, .. }, PhysicalType::INT64) => {
-                    if *unit == TimeUnit::MILLIS {
+                (LogicalType::Time(time), PhysicalType::INT64) => {
+                    if time.unit == TimeUnit::MILLIS {
                         return Err(general_err!(
                             "Cannot use millisecond unit on INT64 type for 
field '{}'",
                             self.name
                         ));
                     }
                 }
-                (LogicalType::Timestamp { .. }, PhysicalType::INT64) => {}
-                (LogicalType::Integer { bit_width, .. }, PhysicalType::INT32)
-                    if *bit_width <= 32 => {}
-                (LogicalType::Integer { bit_width, .. }, PhysicalType::INT64)
-                    if *bit_width == 64 => {}
+                (LogicalType::Timestamp(_), PhysicalType::INT64) => {}
+                (LogicalType::Integer(int), PhysicalType::INT32) if 
int.bit_width <= 32 => {}
+                (LogicalType::Integer(int), PhysicalType::INT64) if 
int.bit_width == 64 => {}
                 // Null type
                 (LogicalType::Unknown, _) => {}
                 (LogicalType::String, PhysicalType::BYTE_ARRAY) => {}
                 (LogicalType::Json, PhysicalType::BYTE_ARRAY) => {}
                 (LogicalType::Bson, PhysicalType::BYTE_ARRAY) => {}
-                (LogicalType::Geometry { .. }, PhysicalType::BYTE_ARRAY) => {}
-                (LogicalType::Geography { .. }, PhysicalType::BYTE_ARRAY) => {}
+                (LogicalType::Geometry(_), PhysicalType::BYTE_ARRAY) => {}
+                (LogicalType::Geography(_), PhysicalType::BYTE_ARRAY) => {}
                 (LogicalType::Uuid, PhysicalType::FIXED_LEN_BYTE_ARRAY) if 
self.length == 16 => {}
                 (LogicalType::Uuid, PhysicalType::FIXED_LEN_BYTE_ARRAY) => {
                     return Err(general_err!(
@@ -1280,8 +1279,8 @@ fn build_tree<'a>(
 
 /// Checks if the logical type is valid.
 fn check_logical_type(logical_type: &Option<LogicalType>) -> Result<()> {
-    if let Some(LogicalType::Integer { bit_width, .. }) = *logical_type {
-        if bit_width != 8 && bit_width != 16 && bit_width != 32 && bit_width 
!= 64 {
+    if let Some(LogicalType::Integer(IntType { bit_width, .. })) = 
logical_type {
+        if *bit_width != 8 && *bit_width != 16 && *bit_width != 32 && 
*bit_width != 64 {
             return Err(general_err!(
                 "Bit width must be 8, 16, 32, or 64 for Integer logical type"
             ));
@@ -1482,7 +1481,7 @@ mod tests {
         if let Err(e) = result {
             assert_eq!(
                 format!("{e}"),
-                "Parquet error: Cannot annotate Integer { bit_width: 8, 
is_signed: true } from INT64 for field 'foo'"
+                "Parquet error: Cannot annotate Integer(IntType { bit_width: 
8, is_signed: true }) from INT64 for field 'foo'"
             );
         }
 
diff --git a/parquet/tests/geospatial.rs b/parquet/tests/geospatial.rs
index fcc93661ed..bf34528d03 100644
--- a/parquet/tests/geospatial.rs
+++ b/parquet/tests/geospatial.rs
@@ -30,7 +30,7 @@ mod test {
             ArrowSchemaConverter, ArrowWriter, 
arrow_reader::ParquetRecordBatchReaderBuilder,
             arrow_writer::ArrowWriterOptions,
         },
-        basic::{EdgeInterpolationAlgorithm, LogicalType},
+        basic::LogicalType,
         column::reader::ColumnReader,
         data_type::{ByteArray, ByteArrayType},
         file::{
@@ -62,29 +62,22 @@ mod test {
         let expected_metadata = [
             (
                 "crs-default.parquet",
-                LogicalType::Geometry { crs: None },
+                LogicalType::geometry(None),
                 WkbMetadata::new(None, None),
             ),
             (
                 "crs-srid.parquet",
-                LogicalType::Geometry {
-                    crs: Some("srid:5070".to_string()),
-                },
+                LogicalType::geometry(Some("srid:5070".to_string())),
                 WkbMetadata::new(Some("srid:5070"), None),
             ),
             (
                 "crs-projjson.parquet",
-                LogicalType::Geometry {
-                    crs: Some("projjson:projjson_epsg_5070".to_string()),
-                },
+                
LogicalType::geometry(Some("projjson:projjson_epsg_5070".to_string())),
                 WkbMetadata::new(Some("projjson:projjson_epsg_5070"), None),
             ),
             (
                 "crs-geography.parquet",
-                LogicalType::Geography {
-                    crs: None,
-                    algorithm: Some(EdgeInterpolationAlgorithm::SPHERICAL),
-                },
+                LogicalType::geography(None, None),
                 WkbMetadata::new(None, Some(WkbEdges::Spherical)),
             ),
         ];
@@ -109,8 +102,8 @@ mod test {
         let column_descr = metadata.file_metadata().schema_descr().column(1);
         let logical_type = column_descr.logical_type_ref().unwrap();
 
-        if let LogicalType::Geometry { crs } = logical_type {
-            let crs = crs.as_ref();
+        if let LogicalType::Geometry(geometry) = logical_type {
+            let crs = geometry.crs.as_ref();
             let crs_parsed: Value = 
serde_json::from_str(crs.unwrap()).unwrap();
             assert_eq!(crs_parsed.get("id").unwrap().get("code").unwrap(), 
5070);
         } else {
@@ -136,7 +129,7 @@ mod test {
         //    optional binary field_id=-1 geometry (Geometry(crs=));
         let fields = metadata.file_metadata().schema().get_fields();
         let logical_type = 
fields[2].get_basic_info().logical_type_ref().unwrap();
-        assert_eq!(logical_type, &LogicalType::Geometry { crs: None });
+        assert_eq!(logical_type, &LogicalType::geometry(None));
 
         let geo_statistics = metadata.row_group(0).column(2).geo_statistics();
         assert!(geo_statistics.is_some());


Reply via email to