This is an automated email from the ASF dual-hosted git repository.

areeve pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 857614c87e Fix reading encrypted Parquet pages when using the page index (#7633)
857614c87e is described below

commit 857614c87e389509fa165238aa4e7cf9e00dbc7f
Author: Adam Reeve <[email protected]>
AuthorDate: Wed Jun 11 16:15:06 2025 +1200

    Fix reading encrypted Parquet pages when using the page index (#7633)
    
    # Which issue does this PR close?
    
    Closes #7629.
    
    I also noticed that skipping pages in encrypted files was broken, so I
    have fixed that too.
    
    # What changes are included in this PR?
    
    * Refactors `SerializedPageReader` to reduce the inline use of
    `#[cfg(...)]`. To satisfy the borrow checker, I created a new
    `SerializedPageReaderContext` type to hold the `CryptoContext`.
    * Updates `SerializedPageReader::get_next_page` so that page headers and
    page data are decrypted when page indexes are used.
    * Updates `SerializedPageReader::skip_next_page` to advance the page
    ordinal counter (`page_index`) so that encryption AADs are calculated
    correctly for the pages that follow a skipped page.
    * Adds new unit tests for reading with a page index and skipping pages
    in encrypted files.
    
    # Are there any user-facing changes?
    
    Only bug fixes.
    
    ---------
    
    Co-authored-by: Ed Seidl <[email protected]>
---
 parquet/src/file/serialized_reader.rs  | 363 +++++++++++++++++++--------------
 parquet/tests/encryption/encryption.rs | 175 ++++++++++++++--
 2 files changed, 364 insertions(+), 174 deletions(-)

diff --git a/parquet/src/file/serialized_reader.rs b/parquet/src/file/serialized_reader.rs
index ea99530d67..5d50a8c49d 100644
--- a/parquet/src/file/serialized_reader.rs
+++ b/parquet/src/file/serialized_reader.rs
@@ -36,7 +36,9 @@ use crate::format::{PageHeader, PageLocation, PageType};
 use crate::record::reader::RowIter;
 use crate::record::Row;
 use crate::schema::types::Type as SchemaType;
-use crate::thrift::{TCompactSliceInputProtocol, TSerializable};
+#[cfg(feature = "encryption")]
+use crate::thrift::TCompactSliceInputProtocol;
+use crate::thrift::TSerializable;
 use bytes::Bytes;
 use std::collections::VecDeque;
 use std::iter;
@@ -177,7 +179,7 @@ pub struct ReadOptions {
 
 impl<R: 'static + ChunkReader> SerializedFileReader<R> {
     /// Creates file reader from a Parquet file.
-    /// Returns error if Parquet file does not exist or is corrupt.
+    /// Returns an error if the Parquet file does not exist or is corrupt.
     pub fn new(chunk_reader: R) -> Result<Self> {
         let metadata = 
ParquetMetaDataReader::new().parse_and_finish(&chunk_reader)?;
         let props = Arc::new(ReaderProperties::builder().build());
@@ -189,7 +191,7 @@ impl<R: 'static + ChunkReader> SerializedFileReader<R> {
     }
 
     /// Creates file reader from a Parquet file with read options.
-    /// Returns error if Parquet file does not exist or is corrupt.
+    /// Returns an error if the Parquet file does not exist or is corrupt.
     pub fn new_with_options(chunk_reader: R, options: ReadOptions) -> 
Result<Self> {
         let mut metadata_builder = ParquetMetaDataReader::new()
             .parse_and_finish(&chunk_reader)?
@@ -338,84 +340,6 @@ impl<R: 'static + ChunkReader> RowGroupReader for 
SerializedRowGroupReader<'_, R
     }
 }
 
-/// Reads a [`PageHeader`] from the provided [`Read`]
-pub(crate) fn read_page_header<T: Read>(input: &mut T) -> Result<PageHeader> {
-    let mut prot = TCompactInputProtocol::new(input);
-    Ok(PageHeader::read_from_in_protocol(&mut prot)?)
-}
-
-#[cfg(feature = "encryption")]
-pub(crate) fn read_encrypted_page_header<T: Read>(
-    input: &mut T,
-    crypto_context: Arc<CryptoContext>,
-) -> Result<PageHeader> {
-    let data_decryptor = crypto_context.data_decryptor();
-    let aad = crypto_context.create_page_header_aad()?;
-
-    let buf = read_and_decrypt(data_decryptor, input, 
aad.as_ref()).map_err(|_| {
-        ParquetError::General(format!(
-            "Error decrypting column {}, decryptor may be wrong or missing",
-            crypto_context.column_ordinal
-        ))
-    })?;
-
-    let mut prot = TCompactSliceInputProtocol::new(buf.as_slice());
-    Ok(PageHeader::read_from_in_protocol(&mut prot)?)
-}
-
-/// Reads a [`PageHeader`] from the provided [`Read`] returning the number of 
bytes read.
-/// If the page header is encrypted [`CryptoContext`] must be provided.
-#[cfg(feature = "encryption")]
-fn read_encrypted_page_header_len<T: Read>(
-    input: &mut T,
-    crypto_context: Option<Arc<CryptoContext>>,
-) -> Result<(usize, PageHeader)> {
-    /// A wrapper around a [`std::io::Read`] that keeps track of the bytes read
-    struct TrackedRead<R> {
-        inner: R,
-        bytes_read: usize,
-    }
-
-    impl<R: Read> Read for TrackedRead<R> {
-        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
-            let v = self.inner.read(buf)?;
-            self.bytes_read += v;
-            Ok(v)
-        }
-    }
-
-    let mut tracked = TrackedRead {
-        inner: input,
-        bytes_read: 0,
-    };
-    let header = read_encrypted_page_header(&mut tracked, 
crypto_context.unwrap())?;
-    Ok((tracked.bytes_read, header))
-}
-
-/// Reads a [`PageHeader`] from the provided [`Read`] returning the number of 
bytes read.
-fn read_page_header_len<T: Read>(input: &mut T) -> Result<(usize, PageHeader)> 
{
-    /// A wrapper around a [`std::io::Read`] that keeps track of the bytes read
-    struct TrackedRead<R> {
-        inner: R,
-        bytes_read: usize,
-    }
-
-    impl<R: Read> Read for TrackedRead<R> {
-        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
-            let v = self.inner.read(buf)?;
-            self.bytes_read += v;
-            Ok(v)
-        }
-    }
-
-    let mut tracked = TrackedRead {
-        inner: input,
-        bytes_read: 0,
-    };
-    let header = read_page_header(&mut tracked)?;
-    Ok((tracked.bytes_read, header))
-}
-
 /// Decodes a [`Page`] from the provided `buffer`
 pub(crate) fn decode_page(
     page_header: PageHeader,
@@ -554,7 +478,7 @@ enum SerializedPageReaderState {
         next_page_header: Option<Box<PageHeader>>,
 
         /// The index of the data page within this column chunk
-        page_ordinal: usize,
+        page_index: usize,
 
         /// Whether the next page is expected to be a dictionary page
         require_dictionary: bool,
@@ -566,9 +490,18 @@ enum SerializedPageReaderState {
         dictionary_page: Option<PageLocation>,
         /// The total number of rows in this column chunk
         total_rows: usize,
+        /// The index of the data page within this column chunk
+        page_index: usize,
     },
 }
 
+#[derive(Default)]
+struct SerializedPageReaderContext {
+    /// Crypto context carrying objects required for decryption
+    #[cfg(feature = "encryption")]
+    crypto_context: Option<Arc<CryptoContext>>,
+}
+
 /// A serialized implementation for Parquet [`PageReader`].
 pub struct SerializedPageReader<R: ChunkReader> {
     /// The chunk reader
@@ -582,9 +515,7 @@ pub struct SerializedPageReader<R: ChunkReader> {
 
     state: SerializedPageReaderState,
 
-    /// Crypto context carrying objects required for decryption
-    #[cfg(feature = "encryption")]
-    crypto_context: Option<Arc<CryptoContext>>,
+    context: SerializedPageReaderContext,
 }
 
 impl<R: ChunkReader> SerializedPageReader<R> {
@@ -634,7 +565,7 @@ impl<R: ChunkReader> SerializedPageReader<R> {
         };
         let crypto_context =
             CryptoContext::for_column(file_decryptor, crypto_metadata, rg_idx, 
column_idx)?;
-        self.crypto_context = Some(Arc::new(crypto_context));
+        self.context.crypto_context = Some(Arc::new(crypto_context));
         Ok(self)
     }
 
@@ -651,6 +582,8 @@ impl<R: ChunkReader> SerializedPageReader<R> {
 
         let state = match page_locations {
             Some(locations) => {
+                // If the offset of the first page doesn't match the start of 
the column chunk
+                // then the preceding space must contain a dictionary page.
                 let dictionary_page = match locations.first() {
                     Some(dict_offset) if dict_offset.offset as u64 != start => 
Some(PageLocation {
                         offset: start as i64,
@@ -664,13 +597,14 @@ impl<R: ChunkReader> SerializedPageReader<R> {
                     page_locations: locations.into(),
                     dictionary_page,
                     total_rows,
+                    page_index: 0,
                 }
             }
             None => SerializedPageReaderState::Values {
                 offset: usize::try_from(start)?,
                 remaining_bytes: usize::try_from(len)?,
                 next_page_header: None,
-                page_ordinal: 0,
+                page_index: 0,
                 require_dictionary: meta.dictionary_page_offset().is_some(),
             },
         };
@@ -679,8 +613,7 @@ impl<R: ChunkReader> SerializedPageReader<R> {
             decompressor,
             state,
             physical_type: meta.column_type(),
-            #[cfg(feature = "encryption")]
-            crypto_context: None,
+            context: Default::default(),
         })
     }
 
@@ -696,7 +629,8 @@ impl<R: ChunkReader> SerializedPageReader<R> {
                 offset,
                 remaining_bytes,
                 next_page_header,
-                ..
+                page_index,
+                require_dictionary,
             } => {
                 loop {
                     if *remaining_bytes == 0 {
@@ -712,7 +646,12 @@ impl<R: ChunkReader> SerializedPageReader<R> {
                         }
                     } else {
                         let mut read = self.reader.get_read(*offset as u64)?;
-                        let (header_len, header) = read_page_header_len(&mut 
read)?;
+                        let (header_len, header) = Self::read_page_header_len(
+                            &self.context,
+                            &mut read,
+                            *page_index,
+                            *require_dictionary,
+                        )?;
                         *offset += header_len;
                         *remaining_bytes -= header_len;
                         let page_meta = if let Ok(_page_meta) = 
PageMetadata::try_from(&header) {
@@ -741,6 +680,129 @@ impl<R: ChunkReader> SerializedPageReader<R> {
             }
         }
     }
+
+    fn read_page_header_len<T: Read>(
+        context: &SerializedPageReaderContext,
+        input: &mut T,
+        page_index: usize,
+        dictionary_page: bool,
+    ) -> Result<(usize, PageHeader)> {
+        /// A wrapper around a [`std::io::Read`] that keeps track of the bytes 
read
+        struct TrackedRead<R> {
+            inner: R,
+            bytes_read: usize,
+        }
+
+        impl<R: Read> Read for TrackedRead<R> {
+            fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
+                let v = self.inner.read(buf)?;
+                self.bytes_read += v;
+                Ok(v)
+            }
+        }
+
+        let mut tracked = TrackedRead {
+            inner: input,
+            bytes_read: 0,
+        };
+        let header = context.read_page_header(&mut tracked, page_index, 
dictionary_page)?;
+        Ok((tracked.bytes_read, header))
+    }
+
+    fn read_page_header_len_from_bytes(
+        context: &SerializedPageReaderContext,
+        buffer: &[u8],
+        page_index: usize,
+        dictionary_page: bool,
+    ) -> Result<(usize, PageHeader)> {
+        let mut input = std::io::Cursor::new(buffer);
+        let header = context.read_page_header(&mut input, page_index, 
dictionary_page)?;
+        let header_len = input.position() as usize;
+        Ok((header_len, header))
+    }
+}
+
+#[cfg(not(feature = "encryption"))]
+impl SerializedPageReaderContext {
+    fn read_page_header<T: Read>(
+        &self,
+        input: &mut T,
+        _page_index: usize,
+        _dictionary_page: bool,
+    ) -> Result<PageHeader> {
+        let mut prot = TCompactInputProtocol::new(input);
+        Ok(PageHeader::read_from_in_protocol(&mut prot)?)
+    }
+
+    fn decrypt_page_data<T>(
+        &self,
+        buffer: T,
+        _page_index: usize,
+        _dictionary_page: bool,
+    ) -> Result<T> {
+        Ok(buffer)
+    }
+}
+
+#[cfg(feature = "encryption")]
+impl SerializedPageReaderContext {
+    fn read_page_header<T: Read>(
+        &self,
+        input: &mut T,
+        page_index: usize,
+        dictionary_page: bool,
+    ) -> Result<PageHeader> {
+        match self.page_crypto_context(page_index, dictionary_page) {
+            None => {
+                let mut prot = TCompactInputProtocol::new(input);
+                Ok(PageHeader::read_from_in_protocol(&mut prot)?)
+            }
+            Some(page_crypto_context) => {
+                let data_decryptor = page_crypto_context.data_decryptor();
+                let aad = page_crypto_context.create_page_header_aad()?;
+
+                let buf = read_and_decrypt(data_decryptor, input, 
aad.as_ref()).map_err(|_| {
+                    ParquetError::General(format!(
+                        "Error decrypting page header for column {}, 
decryption key may be wrong",
+                        page_crypto_context.column_ordinal
+                    ))
+                })?;
+
+                let mut prot = TCompactSliceInputProtocol::new(buf.as_slice());
+                Ok(PageHeader::read_from_in_protocol(&mut prot)?)
+            }
+        }
+    }
+
+    fn decrypt_page_data<T>(&self, buffer: T, page_index: usize, 
dictionary_page: bool) -> Result<T>
+    where
+        T: AsRef<[u8]>,
+        T: From<Vec<u8>>,
+    {
+        let page_crypto_context = self.page_crypto_context(page_index, 
dictionary_page);
+        if let Some(page_crypto_context) = page_crypto_context {
+            let decryptor = page_crypto_context.data_decryptor();
+            let aad = page_crypto_context.create_page_aad()?;
+            let decrypted = decryptor.decrypt(buffer.as_ref(), &aad)?;
+            Ok(T::from(decrypted))
+        } else {
+            Ok(buffer)
+        }
+    }
+
+    fn page_crypto_context(
+        &self,
+        page_index: usize,
+        dictionary_page: bool,
+    ) -> Option<Arc<CryptoContext>> {
+        self.crypto_context.as_ref().map(|c| {
+            Arc::new(if dictionary_page {
+                c.for_dictionary_page()
+            } else {
+                c.with_page_ordinal(page_index)
+            })
+        })
+    }
 }
 
 impl<R: ChunkReader> Iterator for SerializedPageReader<R> {
@@ -780,7 +842,7 @@ impl<R: ChunkReader> PageReader for SerializedPageReader<R> 
{
                     offset,
                     remaining_bytes: remaining,
                     next_page_header,
-                    page_ordinal,
+                    page_index,
                     require_dictionary,
                 } => {
                     if *remaining == 0 {
@@ -791,21 +853,12 @@ impl<R: ChunkReader> PageReader for 
SerializedPageReader<R> {
                     let header = if let Some(header) = next_page_header.take() 
{
                         *header
                     } else {
-                        #[cfg(feature = "encryption")]
-                        let (header_len, header) = if 
self.crypto_context.is_some() {
-                            let crypto_context = page_crypto_context(
-                                &self.crypto_context,
-                                *page_ordinal,
-                                *require_dictionary,
-                            )?;
-                            read_encrypted_page_header_len(&mut read, 
crypto_context)?
-                        } else {
-                            read_page_header_len(&mut read)?
-                        };
-
-                        #[cfg(not(feature = "encryption"))]
-                        let (header_len, header) = read_page_header_len(&mut 
read)?;
-
+                        let (header_len, header) = Self::read_page_header_len(
+                            &self.context,
+                            &mut read,
+                            *page_index,
+                            *require_dictionary,
+                        )?;
                         verify_page_header_len(header_len, *remaining)?;
                         *offset += header_len;
                         *remaining -= header_len;
@@ -835,20 +888,9 @@ impl<R: ChunkReader> PageReader for 
SerializedPageReader<R> {
                         ));
                     }
 
-                    #[cfg(feature = "encryption")]
-                    let crypto_context = page_crypto_context(
-                        &self.crypto_context,
-                        *page_ordinal,
-                        *require_dictionary,
-                    )?;
-                    #[cfg(feature = "encryption")]
-                    let buffer: Vec<u8> = if let Some(crypto_context) = 
crypto_context {
-                        let decryptor = crypto_context.data_decryptor();
-                        let aad = crypto_context.create_page_aad()?;
-                        decryptor.decrypt(buffer.as_ref(), &aad)?
-                    } else {
-                        buffer
-                    };
+                    let buffer =
+                        self.context
+                            .decrypt_page_data(buffer, *page_index, 
*require_dictionary)?;
 
                     let page = decode_page(
                         header,
@@ -857,7 +899,7 @@ impl<R: ChunkReader> PageReader for SerializedPageReader<R> 
{
                         self.decompressor.as_mut(),
                     )?;
                     if page.is_data_page() {
-                        *page_ordinal += 1;
+                        *page_index += 1;
                     } else if page.is_dictionary_page() {
                         *require_dictionary = false;
                     }
@@ -866,25 +908,34 @@ impl<R: ChunkReader> PageReader for 
SerializedPageReader<R> {
                 SerializedPageReaderState::Pages {
                     page_locations,
                     dictionary_page,
+                    page_index,
                     ..
                 } => {
-                    let front = match dictionary_page
-                        .take()
-                        .or_else(|| page_locations.pop_front())
-                    {
-                        Some(front) => front,
-                        None => return Ok(None),
+                    let (front, is_dictionary_page) = match 
dictionary_page.take() {
+                        Some(front) => (front, true),
+                        None => match page_locations.pop_front() {
+                            Some(front) => (front, false),
+                            None => return Ok(None),
+                        },
                     };
 
                     let page_len = 
usize::try_from(front.compressed_page_size)?;
-
                     let buffer = self.reader.get_bytes(front.offset as u64, 
page_len)?;
 
-                    let mut prot = 
TCompactSliceInputProtocol::new(buffer.as_ref());
-                    let header = PageHeader::read_from_in_protocol(&mut prot)?;
-                    let offset = buffer.len() - prot.as_slice().len();
-
+                    let (offset, header) = 
Self::read_page_header_len_from_bytes(
+                        &self.context,
+                        buffer.as_ref(),
+                        *page_index,
+                        is_dictionary_page,
+                    )?;
                     let bytes = buffer.slice(offset..);
+                    let bytes =
+                        self.context
+                            .decrypt_page_data(bytes, *page_index, 
is_dictionary_page)?;
+
+                    if !is_dictionary_page {
+                        *page_index += 1;
+                    }
                     decode_page(
                         header,
                         bytes,
@@ -904,7 +955,8 @@ impl<R: ChunkReader> PageReader for SerializedPageReader<R> 
{
                 offset,
                 remaining_bytes,
                 next_page_header,
-                ..
+                page_index,
+                require_dictionary,
             } => {
                 loop {
                     if *remaining_bytes == 0 {
@@ -920,7 +972,12 @@ impl<R: ChunkReader> PageReader for 
SerializedPageReader<R> {
                         }
                     } else {
                         let mut read = self.reader.get_read(*offset as u64)?;
-                        let (header_len, header) = read_page_header_len(&mut 
read)?;
+                        let (header_len, header) = Self::read_page_header_len(
+                            &self.context,
+                            &mut read,
+                            *page_index,
+                            *require_dictionary,
+                        )?;
                         verify_page_header_len(header_len, *remaining_bytes)?;
                         *offset += header_len;
                         *remaining_bytes -= header_len;
@@ -939,6 +996,7 @@ impl<R: ChunkReader> PageReader for SerializedPageReader<R> 
{
                 page_locations,
                 dictionary_page,
                 total_rows,
+                page_index: _,
             } => {
                 if dictionary_page.is_some() {
                     Ok(Some(PageMetadata {
@@ -970,7 +1028,8 @@ impl<R: ChunkReader> PageReader for 
SerializedPageReader<R> {
                 offset,
                 remaining_bytes,
                 next_page_header,
-                ..
+                page_index,
+                require_dictionary,
             } => {
                 if let Some(buffered_header) = next_page_header.take() {
                     verify_page_size(
@@ -983,7 +1042,12 @@ impl<R: ChunkReader> PageReader for 
SerializedPageReader<R> {
                     *remaining_bytes -= buffered_header.compressed_page_size 
as usize;
                 } else {
                     let mut read = self.reader.get_read(*offset as u64)?;
-                    let (header_len, header) = read_page_header_len(&mut 
read)?;
+                    let (header_len, header) = Self::read_page_header_len(
+                        &self.context,
+                        &mut read,
+                        *page_index,
+                        *require_dictionary,
+                    )?;
                     verify_page_header_len(header_len, *remaining_bytes)?;
                     verify_page_size(
                         header.compressed_page_size,
@@ -994,11 +1058,17 @@ impl<R: ChunkReader> PageReader for 
SerializedPageReader<R> {
                     *offset += header_len + data_page_size;
                     *remaining_bytes -= header_len + data_page_size;
                 }
+                if *require_dictionary {
+                    *require_dictionary = false;
+                } else {
+                    *page_index += 1;
+                }
                 Ok(())
             }
             SerializedPageReaderState::Pages {
                 page_locations,
                 dictionary_page,
+                page_index,
                 ..
             } => {
                 if dictionary_page.is_some() {
@@ -1006,7 +1076,9 @@ impl<R: ChunkReader> PageReader for 
SerializedPageReader<R> {
                     dictionary_page.take();
                 } else {
                     // If no dictionary page exists, simply pop the data page 
from page_locations
-                    page_locations.pop_front();
+                    if page_locations.pop_front().is_some() {
+                        *page_index += 1;
+                    }
                 }
 
                 Ok(())
@@ -1022,21 +1094,6 @@ impl<R: ChunkReader> PageReader for 
SerializedPageReader<R> {
     }
 }
 
-#[cfg(feature = "encryption")]
-fn page_crypto_context(
-    crypto_context: &Option<Arc<CryptoContext>>,
-    page_ordinal: usize,
-    dictionary_page: bool,
-) -> Result<Option<Arc<CryptoContext>>> {
-    Ok(crypto_context.as_ref().map(|c| {
-        Arc::new(if dictionary_page {
-            c.for_dictionary_page()
-        } else {
-            c.with_page_ordinal(page_ordinal)
-        })
-    }))
-}
-
 #[cfg(test)]
 mod tests {
     use std::collections::HashSet;
diff --git a/parquet/tests/encryption/encryption.rs b/parquet/tests/encryption/encryption.rs
index a46794a85f..7079e91d12 100644
--- a/parquet/tests/encryption/encryption.rs
+++ b/parquet/tests/encryption/encryption.rs
@@ -25,7 +25,8 @@ use arrow::error::Result as ArrowResult;
 use arrow_array::{Int32Array, RecordBatch};
 use arrow_schema::{DataType as ArrowDataType, DataType, Field, Schema};
 use parquet::arrow::arrow_reader::{
-    ArrowReaderMetadata, ArrowReaderOptions, ParquetRecordBatchReaderBuilder,
+    ArrowReaderMetadata, ArrowReaderOptions, ParquetRecordBatchReaderBuilder, 
RowSelection,
+    RowSelector,
 };
 use parquet::arrow::ArrowWriter;
 use parquet::data_type::{ByteArray, ByteArrayType};
@@ -397,6 +398,28 @@ fn row_group_sizes(metadata: &ParquetMetaData) -> Vec<i64> 
{
 
 #[test]
 fn test_uniform_encryption_roundtrip() {
+    uniform_encryption_roundtrip(false, false).unwrap();
+}
+
+#[test]
+fn test_uniform_encryption_roundtrip_with_dictionary() {
+    uniform_encryption_roundtrip(false, true).unwrap();
+}
+
+#[test]
+fn test_uniform_encryption_roundtrip_with_page_index() {
+    uniform_encryption_roundtrip(true, false).unwrap();
+}
+
+#[test]
+fn test_uniform_encryption_roundtrip_with_page_index_and_dictionary() {
+    uniform_encryption_roundtrip(true, true).unwrap();
+}
+
+fn uniform_encryption_roundtrip(
+    page_index: bool,
+    dictionary_encoding: bool,
+) -> parquet::errors::Result<()> {
     let x0_arrays = [
         Int32Array::from((0..100).collect::<Vec<_>>()),
         Int32Array::from((100..150).collect::<Vec<_>>()),
@@ -411,12 +434,11 @@ fn test_uniform_encryption_roundtrip() {
         Field::new("x1", ArrowDataType::Int32, false),
     ]));
 
-    let file = tempfile::tempfile().unwrap();
+    let file = tempfile::tempfile()?;
 
     let footer_key = b"0123456789012345";
-    let file_encryption_properties = 
FileEncryptionProperties::builder(footer_key.to_vec())
-        .build()
-        .unwrap();
+    let file_encryption_properties =
+        FileEncryptionProperties::builder(footer_key.to_vec()).build()?;
 
     let props = WriterProperties::builder()
         // Ensure multiple row groups
@@ -424,34 +446,32 @@ fn test_uniform_encryption_roundtrip() {
         // Ensure multiple pages per row group
         .set_write_batch_size(20)
         .set_data_page_row_count_limit(20)
+        .set_dictionary_enabled(dictionary_encoding)
         .with_file_encryption_properties(file_encryption_properties)
         .build();
 
-    let mut writer =
-        ArrowWriter::try_new(file.try_clone().unwrap(), schema.clone(), 
Some(props)).unwrap();
+    let mut writer = ArrowWriter::try_new(file.try_clone()?, schema.clone(), 
Some(props))?;
 
     for (x0, x1) in x0_arrays.into_iter().zip(x1_arrays.into_iter()) {
-        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(x0), 
Arc::new(x1)]).unwrap();
-        writer.write(&batch).unwrap();
+        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(x0), 
Arc::new(x1)])?;
+        writer.write(&batch)?;
     }
 
-    writer.close().unwrap();
+    writer.close()?;
 
-    let decryption_properties = 
FileDecryptionProperties::builder(footer_key.to_vec())
-        .build()
-        .unwrap();
+    let decryption_properties = 
FileDecryptionProperties::builder(footer_key.to_vec()).build()?;
 
-    let options = 
ArrowReaderOptions::new().with_file_decryption_properties(decryption_properties);
+    let options = ArrowReaderOptions::new()
+        .with_file_decryption_properties(decryption_properties)
+        .with_page_index(page_index);
 
-    let builder = ParquetRecordBatchReaderBuilder::try_new_with_options(file, 
options).unwrap();
+    let builder = ParquetRecordBatchReaderBuilder::try_new_with_options(file, 
options)?;
     assert_eq!(&row_group_sizes(builder.metadata()), &[50, 50, 50]);
 
     let batches = builder
         .with_batch_size(100)
-        .build()
-        .unwrap()
-        .collect::<ArrowResult<Vec<_>>>()
-        .unwrap();
+        .build()?
+        .collect::<ArrowResult<Vec<_>>>()?;
 
     assert_eq!(batches.len(), 2);
     assert!(batches.iter().all(|x| x.num_columns() == 2));
@@ -486,11 +506,124 @@ fn test_uniform_encryption_roundtrip() {
         })
         .collect();
 
-    let expected_x0_values: Vec<_> = [0..100, 
100..150].into_iter().flatten().collect();
+    let expected_x0_values: Vec<_> = (0..150).collect();
     assert_eq!(&x0_values, &expected_x0_values);
 
-    let expected_x1_values: Vec<_> = [100..200, 
200..250].into_iter().flatten().collect();
+    let expected_x1_values: Vec<_> = (100..250).collect();
     assert_eq!(&x1_values, &expected_x1_values);
+    Ok(())
+}
+
+#[test]
+fn test_uniform_encryption_page_skipping() {
+    uniform_encryption_page_skipping(false).unwrap();
+}
+
+#[test]
+fn test_uniform_encryption_page_skipping_with_page_index() {
+    uniform_encryption_page_skipping(true).unwrap();
+}
+
+fn uniform_encryption_page_skipping(page_index: bool) -> 
parquet::errors::Result<()> {
+    let x0_arrays = [
+        Int32Array::from((0..100).collect::<Vec<_>>()),
+        Int32Array::from((100..150).collect::<Vec<_>>()),
+    ];
+    let x1_arrays = [
+        Int32Array::from((100..200).collect::<Vec<_>>()),
+        Int32Array::from((200..250).collect::<Vec<_>>()),
+    ];
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("x0", ArrowDataType::Int32, false),
+        Field::new("x1", ArrowDataType::Int32, false),
+    ]));
+
+    let file = tempfile::tempfile()?;
+
+    let footer_key = b"0123456789012345";
+    let file_encryption_properties =
+        FileEncryptionProperties::builder(footer_key.to_vec()).build()?;
+
+    let props = WriterProperties::builder()
+        // Ensure multiple row groups
+        .set_max_row_group_size(50)
+        // Ensure multiple pages per row group
+        .set_write_batch_size(20)
+        .set_data_page_row_count_limit(20)
+        .with_file_encryption_properties(file_encryption_properties)
+        .build();
+
+    let mut writer = ArrowWriter::try_new(file.try_clone()?, schema.clone(), 
Some(props))?;
+
+    for (x0, x1) in x0_arrays.into_iter().zip(x1_arrays.into_iter()) {
+        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(x0), 
Arc::new(x1)])?;
+        writer.write(&batch)?;
+    }
+
+    writer.close()?;
+
+    let decryption_properties = 
FileDecryptionProperties::builder(footer_key.to_vec()).build()?;
+
+    let options = ArrowReaderOptions::new()
+        .with_file_decryption_properties(decryption_properties)
+        .with_page_index(page_index);
+
+    let builder = ParquetRecordBatchReaderBuilder::try_new_with_options(file, 
options)?;
+
+    let selection = RowSelection::from(vec![
+        RowSelector::skip(25),
+        RowSelector::select(50),
+        RowSelector::skip(25),
+        RowSelector::select(25),
+        RowSelector::skip(25),
+    ]);
+
+    let batches = builder
+        .with_row_selection(selection)
+        .with_batch_size(100)
+        .build()?
+        .collect::<ArrowResult<Vec<_>>>()?;
+
+    assert_eq!(batches.len(), 1);
+    assert!(batches.iter().all(|x| x.num_columns() == 2));
+
+    let batch_sizes: Vec<_> = batches.iter().map(|x| x.num_rows()).collect();
+
+    assert_eq!(&batch_sizes, &[75]);
+
+    let x0_values: Vec<_> = batches
+        .iter()
+        .flat_map(|x| {
+            x.column(0)
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .unwrap()
+                .values()
+                .iter()
+                .cloned()
+        })
+        .collect();
+
+    let x1_values: Vec<_> = batches
+        .iter()
+        .flat_map(|x| {
+            x.column(1)
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .unwrap()
+                .values()
+                .iter()
+                .cloned()
+        })
+        .collect();
+
+    let expected_x0_values: Vec<_> = [25..75, 
100..125].into_iter().flatten().collect();
+    assert_eq!(&x0_values, &expected_x0_values);
+
+    let expected_x1_values: Vec<_> = [125..175, 
200..225].into_iter().flatten().collect();
+    assert_eq!(&x1_values, &expected_x1_values);
+    Ok(())
 }
 
 #[test]

Reply via email to