This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new d49f017fe1 Introduce a ThriftProtocolError to avoid allocating and 
formatting strings for error messages (#8636)
d49f017fe1 is described below

commit d49f017fe1c6712ba32e2222c6f031278b588ca5
Author: Jörn Horstmann <[email protected]>
AuthorDate: Fri Oct 17 20:24:42 2025 +0200

    Introduce a ThriftProtocolError to avoid allocating and formatting strings 
for error messages (#8636)
    
    # Which issue does this PR close?
    
    This is a small performance improvement for the thrift remodeling
    
    - Part of https://github.com/apache/arrow-rs/issues/5853.
    
    # Rationale for this change
    
    Some of the often-called methods in the thrift protocol implementation
    created `ParquetError` instances with a string message that had to be
    allocated and formatted. This formatting code, and probably also some
    drop glue, bloated these otherwise small methods and prevented inlining.
    
    # What changes are included in this PR?
    
    Introduce a separate error type `ThriftProtocolError` that is smaller
    than `ParquetError` and does not contain any allocated data. The
    `ReadThrift` trait is not changed, since its custom implementations
    actually require the more expressive `ParquetError`.
    
    # Are these changes tested?
    
    The success path is covered by existing tests. Testing the error paths
    would require crafting malformed files or using a fuzzer.
    
    # Are there any user-facing changes?
    
    The `ThriftProtocolError` is crate-internal, so there should be no API
    changes. Some error messages might differ slightly.
---
 parquet/src/parquet_thrift.rs | 191 ++++++++++++++++++++++++++++--------------
 1 file changed, 126 insertions(+), 65 deletions(-)

diff --git a/parquet/src/parquet_thrift.rs b/parquet/src/parquet_thrift.rs
index 221532ea83..8ee018ef95 100644
--- a/parquet/src/parquet_thrift.rs
+++ b/parquet/src/parquet_thrift.rs
@@ -35,6 +35,66 @@ use crate::{
     errors::{ParquetError, Result},
     write_thrift_field,
 };
+use std::io::Error;
+use std::str::Utf8Error;
+
+#[derive(Debug)]
+pub(crate) enum ThriftProtocolError {
+    Eof,
+    IO(Error),
+    InvalidFieldType(u8),
+    InvalidElementType(u8),
+    FieldDeltaOverflow { field_delta: u8, last_field_id: i16 },
+    InvalidBoolean(u8),
+    Utf8Error,
+    SkipDepth(FieldType),
+    SkipUnsupportedType(FieldType),
+}
+
+impl From<ThriftProtocolError> for ParquetError {
+    #[inline(never)]
+    fn from(e: ThriftProtocolError) -> Self {
+        match e {
+            ThriftProtocolError::Eof => eof_err!("Unexpected EOF"),
+            ThriftProtocolError::IO(e) => e.into(),
+            ThriftProtocolError::InvalidFieldType(value) => {
+                general_err!("Unexpected struct field type {}", value)
+            }
+            ThriftProtocolError::InvalidElementType(value) => {
+                general_err!("Unexpected list/set element type {}", value)
+            }
+            ThriftProtocolError::FieldDeltaOverflow {
+                field_delta,
+                last_field_id,
+            } => general_err!("cannot add {} to {}", field_delta, last_field_id),
+            ThriftProtocolError::InvalidBoolean(value) => {
+                general_err!("cannot convert {} into bool", value)
+            }
+            ThriftProtocolError::Utf8Error => general_err!("invalid utf8"),
+            ThriftProtocolError::SkipDepth(field_type) => {
+                general_err!("cannot parse past {:?}", field_type)
+            }
+            ThriftProtocolError::SkipUnsupportedType(field_type) => {
+                general_err!("cannot skip field type {:?}", field_type)
+            }
+        }
+    }
+}
+
+impl From<Utf8Error> for ThriftProtocolError {
+    fn from(_: Utf8Error) -> Self {
+        // ignore error payload to reduce the size of ThriftProtocolError
+        Self::Utf8Error
+    }
+}
+
+impl From<Error> for ThriftProtocolError {
+    fn from(e: Error) -> Self {
+        Self::IO(e)
+    }
+}
+
+pub(crate) type ThriftProtocolResult<T> = Result<T, ThriftProtocolError>;
 
 /// Wrapper for thrift `double` fields. This is used to provide
 /// an implementation of `Eq` for floats. This implementation
@@ -87,8 +147,8 @@ pub(crate) enum FieldType {
 }
 
 impl TryFrom<u8> for FieldType {
-    type Error = ParquetError;
-    fn try_from(value: u8) -> Result<Self> {
+    type Error = ThriftProtocolError;
+    fn try_from(value: u8) -> ThriftProtocolResult<Self> {
         match value {
             0 => Ok(Self::Stop),
             1 => Ok(Self::BooleanTrue),
@@ -103,13 +163,13 @@ impl TryFrom<u8> for FieldType {
             10 => Ok(Self::Set),
             11 => Ok(Self::Map),
             12 => Ok(Self::Struct),
-            _ => Err(general_err!("Unexpected struct field type{}", value)),
+            _ => Err(ThriftProtocolError::InvalidFieldType(value)),
         }
     }
 }
 
 impl TryFrom<ElementType> for FieldType {
-    type Error = ParquetError;
+    type Error = ThriftProtocolError;
     fn try_from(value: ElementType) -> std::result::Result<Self, Self::Error> {
         match value {
             ElementType::Bool => Ok(Self::BooleanTrue),
@@ -121,7 +181,7 @@ impl TryFrom<ElementType> for FieldType {
             ElementType::Binary => Ok(Self::Binary),
             ElementType::List => Ok(Self::List),
             ElementType::Struct => Ok(Self::Struct),
-            _ => Err(general_err!("Unexpected list element type{:?}", value)),
+            _ => Err(ThriftProtocolError::InvalidElementType(value as u8)),
         }
     }
 }
@@ -143,8 +203,8 @@ pub(crate) enum ElementType {
 }
 
 impl TryFrom<u8> for ElementType {
-    type Error = ParquetError;
-    fn try_from(value: u8) -> Result<Self> {
+    type Error = ThriftProtocolError;
+    fn try_from(value: u8) -> ThriftProtocolResult<Self> {
         match value {
             // For historical and compatibility reasons, a reader should be 
capable to deal with both cases.
             // The only valid value in the original spec was 2, but due to an 
widespread implementation bug
@@ -162,7 +222,7 @@ impl TryFrom<u8> for ElementType {
             10 => Ok(Self::Set),
             11 => Ok(Self::Map),
             12 => Ok(Self::Struct),
-            _ => Err(general_err!("Unexpected list/set element type{}", 
value)),
+            _ => Err(ThriftProtocolError::InvalidElementType(value)),
         }
     }
 }
@@ -202,20 +262,20 @@ pub(crate) struct ListIdentifier {
 /// [compact]: 
https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md
 pub(crate) trait ThriftCompactInputProtocol<'a> {
     /// Read a single byte from the input.
-    fn read_byte(&mut self) -> Result<u8>;
+    fn read_byte(&mut self) -> ThriftProtocolResult<u8>;
 
     /// Read a Thrift encoded [binary] from the input.
     ///
     /// [binary]: 
https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#binary-encoding
-    fn read_bytes(&mut self) -> Result<&'a [u8]>;
+    fn read_bytes(&mut self) -> ThriftProtocolResult<&'a [u8]>;
 
-    fn read_bytes_owned(&mut self) -> Result<Vec<u8>>;
+    fn read_bytes_owned(&mut self) -> ThriftProtocolResult<Vec<u8>>;
 
     /// Skip the next `n` bytes of input.
-    fn skip_bytes(&mut self, n: usize) -> Result<()>;
+    fn skip_bytes(&mut self, n: usize) -> ThriftProtocolResult<()>;
 
     /// Read a ULEB128 encoded unsigned varint from the input.
-    fn read_vlq(&mut self) -> Result<u64> {
+    fn read_vlq(&mut self) -> ThriftProtocolResult<u64> {
         let mut in_progress = 0;
         let mut shift = 0;
         loop {
@@ -229,13 +289,13 @@ pub(crate) trait ThriftCompactInputProtocol<'a> {
     }
 
     /// Read a zig-zag encoded signed varint from the input.
-    fn read_zig_zag(&mut self) -> Result<i64> {
+    fn read_zig_zag(&mut self) -> ThriftProtocolResult<i64> {
         let val = self.read_vlq()?;
         Ok((val >> 1) as i64 ^ -((val & 1) as i64))
     }
 
     /// Read the [`ListIdentifier`] for a Thrift encoded list.
-    fn read_list_begin(&mut self) -> Result<ListIdentifier> {
+    fn read_list_begin(&mut self) -> ThriftProtocolResult<ListIdentifier> {
         let header = self.read_byte()?;
         let element_type = ElementType::try_from(header & 0x0f)?;
 
@@ -253,8 +313,16 @@ pub(crate) trait ThriftCompactInputProtocol<'a> {
         })
     }
 
+    // Full field ids are uncommon.
+    // Not inlining this method reduces the code size of `read_field_begin`, 
which then ideally gets
+    // inlined everywhere.
+    #[cold]
+    fn read_full_field_id(&mut self) -> ThriftProtocolResult<i16> {
+        self.read_i16()
+    }
+
     /// Read the [`FieldIdentifier`] for a field in a Thrift encoded struct.
-    fn read_field_begin(&mut self, last_field_id: i16) -> 
Result<FieldIdentifier> {
+    fn read_field_begin(&mut self, last_field_id: i16) -> 
ThriftProtocolResult<FieldIdentifier> {
         // we can read at least one byte, which is:
         // - the type
         // - the field delta and the type
@@ -277,17 +345,14 @@ pub(crate) trait ThriftCompactInputProtocol<'a> {
                     bool_val = Some(true);
                 }
                 let field_id = if field_delta != 0 {
-                    last_field_id.checked_add(field_delta as i16).map_or_else(
-                        || {
-                            Err(general_err!(format!(
-                                "cannot add {} to {}",
-                                field_delta, last_field_id
-                            )))
+                    last_field_id.checked_add(field_delta as i16).ok_or(
+                        ThriftProtocolError::FieldDeltaOverflow {
+                            field_delta,
+                            last_field_id,
                         },
-                        Ok,
                     )?
                 } else {
-                    self.read_i16()?
+                    self.read_full_field_id()?
                 };
 
                 Ok(FieldIdentifier {
@@ -305,7 +370,7 @@ pub(crate) trait ThriftCompactInputProtocol<'a> {
     /// This also skips validation of the field type.
     ///
     /// Returns a tuple of `(field_type, field_delta)`.
-    fn read_field_header(&mut self) -> Result<(u8, u8)> {
+    fn read_field_header(&mut self) -> ThriftProtocolResult<(u8, u8)> {
         let field_type = self.read_byte()?;
         let field_delta = (field_type & 0xf0) >> 4;
         let field_type = field_type & 0xf;
@@ -314,7 +379,7 @@ pub(crate) trait ThriftCompactInputProtocol<'a> {
 
     /// Read a boolean list element. This should not be used for struct 
fields. For the latter,
     /// use the [`FieldIdentifier::bool_val`] field.
-    fn read_bool(&mut self) -> Result<bool> {
+    fn read_bool(&mut self) -> ThriftProtocolResult<bool> {
         let b = self.read_byte()?;
         // Previous versions of the thrift specification said to use 0 and 1 
inside collections,
         // but that differed from existing implementations.
@@ -323,43 +388,43 @@ pub(crate) trait ThriftCompactInputProtocol<'a> {
         match b {
             0x01 => Ok(true),
             0x00 | 0x02 => Ok(false),
-            unkn => Err(general_err!(format!("cannot convert {unkn} into 
bool"))),
+            _ => Err(ThriftProtocolError::InvalidBoolean(b)),
         }
     }
 
     /// Read a Thrift [binary] as a UTF-8 encoded string.
     ///
     /// [binary]: 
https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#binary-encoding
-    fn read_string(&mut self) -> Result<&'a str> {
+    fn read_string(&mut self) -> ThriftProtocolResult<&'a str> {
         let slice = self.read_bytes()?;
         Ok(std::str::from_utf8(slice)?)
     }
 
     /// Read an `i8`.
-    fn read_i8(&mut self) -> Result<i8> {
+    fn read_i8(&mut self) -> ThriftProtocolResult<i8> {
         Ok(self.read_byte()? as _)
     }
 
     /// Read an `i16`.
-    fn read_i16(&mut self) -> Result<i16> {
+    fn read_i16(&mut self) -> ThriftProtocolResult<i16> {
         Ok(self.read_zig_zag()? as _)
     }
 
     /// Read an `i32`.
-    fn read_i32(&mut self) -> Result<i32> {
+    fn read_i32(&mut self) -> ThriftProtocolResult<i32> {
         Ok(self.read_zig_zag()? as _)
     }
 
     /// Read an `i64`.
-    fn read_i64(&mut self) -> Result<i64> {
+    fn read_i64(&mut self) -> ThriftProtocolResult<i64> {
         self.read_zig_zag()
     }
 
     /// Read a Thrift `double` as `f64`.
-    fn read_double(&mut self) -> Result<f64>;
+    fn read_double(&mut self) -> ThriftProtocolResult<f64>;
 
     /// Skip a ULEB128 encoded varint.
-    fn skip_vlq(&mut self) -> Result<()> {
+    fn skip_vlq(&mut self) -> ThriftProtocolResult<()> {
         loop {
             let byte = self.read_byte()?;
             if byte & 0x80 == 0 {
@@ -371,14 +436,14 @@ pub(crate) trait ThriftCompactInputProtocol<'a> {
     /// Skip a thrift [binary].
     ///
     /// [binary]: 
https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#binary-encoding
-    fn skip_binary(&mut self) -> Result<()> {
+    fn skip_binary(&mut self) -> ThriftProtocolResult<()> {
         let len = self.read_vlq()? as usize;
         self.skip_bytes(len)
     }
 
     /// Skip a field with type `field_type` recursively until the default
     /// maximum skip depth (currently 64) is reached.
-    fn skip(&mut self, field_type: FieldType) -> Result<()> {
+    fn skip(&mut self, field_type: FieldType) -> ThriftProtocolResult<()> {
         const DEFAULT_SKIP_DEPTH: i8 = 64;
         self.skip_till_depth(field_type, DEFAULT_SKIP_DEPTH)
     }
@@ -396,9 +461,9 @@ pub(crate) trait ThriftCompactInputProtocol<'a> {
     }
 
     /// Skip a field with type `field_type` recursively up to `depth` levels.
-    fn skip_till_depth(&mut self, field_type: FieldType, depth: i8) -> 
Result<()> {
+    fn skip_till_depth(&mut self, field_type: FieldType, depth: i8) -> 
ThriftProtocolResult<()> {
         if depth == 0 {
-            return Err(general_err!(format!("cannot parse past {:?}", 
field_type)));
+            return Err(ThriftProtocolError::SkipDepth(field_type));
         }
 
         match field_type {
@@ -431,7 +496,7 @@ pub(crate) trait ThriftCompactInputProtocol<'a> {
                 Ok(())
             }
             // no list or map types in parquet format
-            u => Err(general_err!(format!("cannot skip field type {:?}", &u))),
+            _ => Err(ThriftProtocolError::SkipUnsupportedType(field_type)),
         }
     }
 }
@@ -455,44 +520,40 @@ impl<'a> ThriftSliceInputProtocol<'a> {
 
 impl<'b, 'a: 'b> ThriftCompactInputProtocol<'b> for 
ThriftSliceInputProtocol<'a> {
     #[inline]
-    fn read_byte(&mut self) -> Result<u8> {
-        let ret = *self.buf.first().ok_or_else(eof_error)?;
+    fn read_byte(&mut self) -> ThriftProtocolResult<u8> {
+        let ret = *self.buf.first().ok_or(ThriftProtocolError::Eof)?;
         self.buf = &self.buf[1..];
         Ok(ret)
     }
 
-    fn read_bytes(&mut self) -> Result<&'b [u8]> {
+    fn read_bytes(&mut self) -> ThriftProtocolResult<&'b [u8]> {
         let len = self.read_vlq()? as usize;
-        let ret = self.buf.get(..len).ok_or_else(eof_error)?;
+        let ret = self.buf.get(..len).ok_or(ThriftProtocolError::Eof)?;
         self.buf = &self.buf[len..];
         Ok(ret)
     }
 
-    fn read_bytes_owned(&mut self) -> Result<Vec<u8>> {
+    fn read_bytes_owned(&mut self) -> ThriftProtocolResult<Vec<u8>> {
         Ok(self.read_bytes()?.to_vec())
     }
 
     #[inline]
-    fn skip_bytes(&mut self, n: usize) -> Result<()> {
-        self.buf.get(..n).ok_or_else(eof_error)?;
+    fn skip_bytes(&mut self, n: usize) -> ThriftProtocolResult<()> {
+        self.buf.get(..n).ok_or(ThriftProtocolError::Eof)?;
         self.buf = &self.buf[n..];
         Ok(())
     }
 
-    fn read_double(&mut self) -> Result<f64> {
-        let slice = self.buf.get(..8).ok_or_else(eof_error)?;
+    fn read_double(&mut self) -> ThriftProtocolResult<f64> {
+        let slice = self.buf.get(..8).ok_or(ThriftProtocolError::Eof)?;
         self.buf = &self.buf[8..];
         match slice.try_into() {
             Ok(slice) => Ok(f64::from_le_bytes(slice)),
-            Err(_) => Err(general_err!("Unexpected error converting slice")),
+            Err(_) => unreachable!(),
         }
     }
 }
 
-fn eof_error() -> ParquetError {
-    eof_err!("Unexpected EOF")
-}
-
 /// A Thrift input protocol that wraps a [`Read`] object.
 ///
 /// Note that this is only intended for use in reading Parquet page headers. 
This will panic
@@ -509,24 +570,24 @@ impl<R: Read> ThriftReadInputProtocol<R> {
 
 impl<'a, R: Read> ThriftCompactInputProtocol<'a> for 
ThriftReadInputProtocol<R> {
     #[inline]
-    fn read_byte(&mut self) -> Result<u8> {
+    fn read_byte(&mut self) -> ThriftProtocolResult<u8> {
         let mut buf = [0_u8; 1];
         self.reader.read_exact(&mut buf)?;
         Ok(buf[0])
     }
 
-    fn read_bytes(&mut self) -> Result<&'a [u8]> {
+    fn read_bytes(&mut self) -> ThriftProtocolResult<&'a [u8]> {
         unimplemented!()
     }
 
-    fn read_bytes_owned(&mut self) -> Result<Vec<u8>> {
+    fn read_bytes_owned(&mut self) -> ThriftProtocolResult<Vec<u8>> {
         let len = self.read_vlq()? as usize;
         let mut v = Vec::with_capacity(len);
         std::io::copy(&mut self.reader.by_ref().take(len as u64), &mut v)?;
         Ok(v)
     }
 
-    fn skip_bytes(&mut self, n: usize) -> Result<()> {
+    fn skip_bytes(&mut self, n: usize) -> ThriftProtocolResult<()> {
         std::io::copy(
             &mut self.reader.by_ref().take(n as u64),
             &mut std::io::sink(),
@@ -534,7 +595,7 @@ impl<'a, R: Read> ThriftCompactInputProtocol<'a> for 
ThriftReadInputProtocol<R>
         Ok(())
     }
 
-    fn read_double(&mut self) -> Result<f64> {
+    fn read_double(&mut self) -> ThriftProtocolResult<f64> {
         let mut buf = [0_u8; 8];
         self.reader.read_exact(&mut buf)?;
         Ok(f64::from_le_bytes(buf))
@@ -552,31 +613,31 @@ pub(crate) trait ReadThrift<'a, R: 
ThriftCompactInputProtocol<'a>> {
 
 impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for bool {
     fn read_thrift(prot: &mut R) -> Result<Self> {
-        prot.read_bool()
+        Ok(prot.read_bool()?)
     }
 }
 
 impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i8 {
     fn read_thrift(prot: &mut R) -> Result<Self> {
-        prot.read_i8()
+        Ok(prot.read_i8()?)
     }
 }
 
 impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i16 {
     fn read_thrift(prot: &mut R) -> Result<Self> {
-        prot.read_i16()
+        Ok(prot.read_i16()?)
     }
 }
 
 impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i32 {
     fn read_thrift(prot: &mut R) -> Result<Self> {
-        prot.read_i32()
+        Ok(prot.read_i32()?)
     }
 }
 
 impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for i64 {
     fn read_thrift(prot: &mut R) -> Result<Self> {
-        prot.read_i64()
+        Ok(prot.read_i64()?)
     }
 }
 
@@ -588,7 +649,7 @@ impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, 
R> for OrderedF64 {
 
 impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for &'a str {
     fn read_thrift(prot: &mut R) -> Result<Self> {
-        prot.read_string()
+        Ok(prot.read_string()?)
     }
 }
 
@@ -600,7 +661,7 @@ impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, 
R> for String {
 
 impl<'a, R: ThriftCompactInputProtocol<'a>> ReadThrift<'a, R> for &'a [u8] {
     fn read_thrift(prot: &mut R) -> Result<Self> {
-        prot.read_bytes()
+        Ok(prot.read_bytes()?)
     }
 }
 

Reply via email to