This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 1f9973cd6 Separate ArrayReader::next_batch with read_records and consume_batch (#2237)
1f9973cd6 is described below

commit 1f9973cd6cdde73883da50d9417126390437ef76
Author: Yang Jiang <[email protected]>
AuthorDate: Wed Aug 3 15:50:52 2022 +0800

    Separate ArrayReader::next_batch with read_records and consume_batch (#2237)
    
    * replace ArrayReader::next_batch with ArrayReader::read_records and ArrayReader::consume_batch.
    
    * fix ut
    
    * fix comment
    
    * avoid clone.
    
    * fix new ut
    
    * fix comment
    
    Co-authored-by: Raphael Taylor-Davies <[email protected]>
---
 parquet/src/arrow/array_reader/byte_array.rs       |   7 +-
 .../arrow/array_reader/byte_array_dictionary.rs    |   9 +-
 .../src/arrow/array_reader/complex_object_array.rs | 133 +++++++++++++++------
 parquet/src/arrow/array_reader/empty_array.rs      |  12 +-
 parquet/src/arrow/array_reader/list_array.rs       |   8 +-
 parquet/src/arrow/array_reader/map_array.rs        |  18 ++-
 parquet/src/arrow/array_reader/mod.rs              |  15 ++-
 parquet/src/arrow/array_reader/null_array.rs       |   9 +-
 parquet/src/arrow/array_reader/primitive_array.rs  |   7 +-
 parquet/src/arrow/array_reader/struct_array.rs     |  27 ++++-
 parquet/src/arrow/array_reader/test_util.rs        |  15 ++-
 parquet/src/arrow/arrow_reader.rs                  |  38 ++++++
 parquet/src/arrow/arrow_writer/mod.rs              |   2 +-
 13 files changed, 233 insertions(+), 67 deletions(-)

diff --git a/parquet/src/arrow/array_reader/byte_array.rs b/parquet/src/arrow/array_reader/byte_array.rs
index ec4188890..172aeb96d 100644
--- a/parquet/src/arrow/array_reader/byte_array.rs
+++ b/parquet/src/arrow/array_reader/byte_array.rs
@@ -108,8 +108,11 @@ impl<I: OffsetSizeTrait + ScalarValue> ArrayReader for ByteArrayReader<I> {
         &self.data_type
     }
 
-    fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
-        read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?;
+    fn read_records(&mut self, batch_size: usize) -> Result<usize> {
+        read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)
+    }
+
+    fn consume_batch(&mut self) -> Result<ArrayRef> {
         let buffer = self.record_reader.consume_record_data();
         let null_buffer = self.record_reader.consume_bitmap_buffer();
         self.def_levels_buffer = self.record_reader.consume_def_levels();
diff --git a/parquet/src/arrow/array_reader/byte_array_dictionary.rs b/parquet/src/arrow/array_reader/byte_array_dictionary.rs
index 51ef38d0d..0a5d94fa6 100644
--- a/parquet/src/arrow/array_reader/byte_array_dictionary.rs
+++ b/parquet/src/arrow/array_reader/byte_array_dictionary.rs
@@ -25,7 +25,7 @@ use arrow::buffer::Buffer;
 use arrow::datatypes::{ArrowNativeType, DataType as ArrowType};
 
 use crate::arrow::array_reader::byte_array::{ByteArrayDecoder, ByteArrayDecoderPlain};
-use crate::arrow::array_reader::{read_records, ArrayReader, skip_records};
+use crate::arrow::array_reader::{read_records, skip_records, ArrayReader};
 use crate::arrow::buffer::{
     dictionary_buffer::DictionaryBuffer, offset_buffer::OffsetBuffer,
 };
@@ -167,8 +167,11 @@ where
         &self.data_type
     }
 
-    fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
-        read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?;
+    fn read_records(&mut self, batch_size: usize) -> Result<usize> {
+        read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)
+    }
+
+    fn consume_batch(&mut self) -> Result<ArrayRef> {
         let buffer = self.record_reader.consume_record_data();
         let null_buffer = self.record_reader.consume_bitmap_buffer();
         let array = buffer.into_array(null_buffer, &self.data_type)?;
diff --git a/parquet/src/arrow/array_reader/complex_object_array.rs b/parquet/src/arrow/array_reader/complex_object_array.rs
index 1390866cf..79b537331 100644
--- a/parquet/src/arrow/array_reader/complex_object_array.rs
+++ b/parquet/src/arrow/array_reader/complex_object_array.rs
@@ -39,9 +39,13 @@ where
     pages: Box<dyn PageIterator>,
     def_levels_buffer: Option<Vec<i16>>,
     rep_levels_buffer: Option<Vec<i16>>,
+    data_buffer: Vec<T::T>,
     column_desc: ColumnDescPtr,
     column_reader: Option<ColumnReaderImpl<T>>,
     converter: C,
+    in_progress_def_levels_buffer: Option<Vec<i16>>,
+    in_progress_rep_levels_buffer: Option<Vec<i16>>,
+    before_consume: bool,
     _parquet_type_marker: PhantomData<T>,
     _converter_marker: PhantomData<C>,
 }
@@ -59,7 +63,10 @@ where
         &self.data_type
     }
 
-    fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
+    fn read_records(&mut self, batch_size: usize) -> Result<usize> {
+        if !self.before_consume {
+            self.before_consume = true;
+        }
         // Try to initialize column reader
         if self.column_reader.is_none() {
             self.next_column_reader()?;
@@ -126,7 +133,6 @@ where
                 break;
             }
         }
-
         data_buffer.truncate(num_read);
         def_levels_buffer
             .iter_mut()
@@ -135,13 +141,35 @@ where
             .iter_mut()
             .for_each(|buf| buf.truncate(num_read));
 
-        self.def_levels_buffer = def_levels_buffer;
-        self.rep_levels_buffer = rep_levels_buffer;
+        if let Some(mut def_levels_buffer) = def_levels_buffer {
+            match &mut self.in_progress_def_levels_buffer {
+                None => {
+                    self.in_progress_def_levels_buffer = Some(def_levels_buffer);
+                }
+                Some(buf) => buf.append(&mut def_levels_buffer),
+            }
+        }
+
+        if let Some(mut rep_levels_buffer) = rep_levels_buffer {
+            match &mut self.in_progress_rep_levels_buffer {
+                None => {
+                    self.in_progress_rep_levels_buffer = Some(rep_levels_buffer);
+                }
+                Some(buf) => buf.append(&mut rep_levels_buffer),
+            }
+        }
+
+        self.data_buffer.append(&mut data_buffer);
+
+        Ok(num_read)
+    }
 
-        let data: Vec<Option<T::T>> = if self.def_levels_buffer.is_some() {
+    fn consume_batch(&mut self) -> Result<ArrayRef> {
+        let data: Vec<Option<T::T>> = if self.in_progress_def_levels_buffer.is_some() {
+            let data_buffer = std::mem::take(&mut self.data_buffer);
             data_buffer
                 .into_iter()
-                .zip(self.def_levels_buffer.as_ref().unwrap().iter())
+                .zip(self.in_progress_def_levels_buffer.as_ref().unwrap().iter())
                 .map(|(t, def_level)| {
                     if *def_level == self.column_desc.max_def_level() {
                         Some(t)
@@ -151,7 +179,7 @@ where
                 })
                 .collect()
         } else {
-            data_buffer.into_iter().map(Some).collect()
+            self.data_buffer.iter().map(|x| Some(x.clone())).collect()
         };
 
         let mut array = self.converter.convert(data)?;
@@ -160,6 +188,11 @@ where
             array = arrow::compute::cast(&array, &self.data_type)?;
         }
 
+        self.data_buffer = vec![];
+        self.def_levels_buffer = std::mem::take(&mut self.in_progress_def_levels_buffer);
+        self.rep_levels_buffer = std::mem::take(&mut self.in_progress_rep_levels_buffer);
+        self.before_consume = false;
+
         Ok(array)
     }
 
@@ -168,8 +201,11 @@ where
             Some(reader) => reader.skip_records(num_records),
             None => {
                 if self.next_column_reader()? {
-                    self.column_reader.as_mut().unwrap().skip_records(num_records)
-                }else {
+                    self.column_reader
+                        .as_mut()
+                        .unwrap()
+                        .skip_records(num_records)
+                } else {
                     Ok(0)
                 }
             }
@@ -177,11 +213,19 @@ where
     }
 
     fn get_def_levels(&self) -> Option<&[i16]> {
-        self.def_levels_buffer.as_deref()
+        if self.before_consume {
+            self.in_progress_def_levels_buffer.as_deref()
+        } else {
+            self.def_levels_buffer.as_deref()
+        }
     }
 
     fn get_rep_levels(&self) -> Option<&[i16]> {
-        self.rep_levels_buffer.as_deref()
+        if self.before_consume {
+            self.in_progress_rep_levels_buffer.as_deref()
+        } else {
+            self.rep_levels_buffer.as_deref()
+        }
     }
 }
 
@@ -208,9 +252,13 @@ where
             pages,
             def_levels_buffer: None,
             rep_levels_buffer: None,
+            data_buffer: vec![],
             column_desc,
             column_reader: None,
             converter,
+            in_progress_def_levels_buffer: None,
+            in_progress_rep_levels_buffer: None,
+            before_consume: true,
             _parquet_type_marker: PhantomData,
             _converter_marker: PhantomData,
         })
@@ -349,30 +397,32 @@ mod tests {
 
         let mut accu_len: usize = 0;
 
-        let array = array_reader.next_batch(values_per_page / 2).unwrap();
-        assert_eq!(array.len(), values_per_page / 2);
+        let len = array_reader.read_records(values_per_page / 2).unwrap();
+        assert_eq!(len, values_per_page / 2);
         assert_eq!(
-            Some(&def_levels[accu_len..(accu_len + array.len())]),
+            Some(&def_levels[accu_len..(accu_len + len)]),
             array_reader.get_def_levels()
         );
         assert_eq!(
-            Some(&rep_levels[accu_len..(accu_len + array.len())]),
+            Some(&rep_levels[accu_len..(accu_len + len)]),
             array_reader.get_rep_levels()
         );
-        accu_len += array.len();
+        accu_len += len;
+        array_reader.consume_batch().unwrap();
 
         // Read next values_per_page values, the first values_per_page/2 ones are from the first column chunk,
         // and the last values_per_page/2 ones are from the second column chunk
-        let array = array_reader.next_batch(values_per_page).unwrap();
-        assert_eq!(array.len(), values_per_page);
+        let len = array_reader.read_records(values_per_page).unwrap();
+        assert_eq!(len, values_per_page);
         assert_eq!(
-            Some(&def_levels[accu_len..(accu_len + array.len())]),
+            Some(&def_levels[accu_len..(accu_len + len)]),
             array_reader.get_def_levels()
         );
         assert_eq!(
-            Some(&rep_levels[accu_len..(accu_len + array.len())]),
+            Some(&rep_levels[accu_len..(accu_len + len)]),
             array_reader.get_rep_levels()
         );
+        let array = array_reader.consume_batch().unwrap();
         let strings = array.as_any().downcast_ref::<StringArray>().unwrap();
         for i in 0..array.len() {
             if array.is_valid(i) {
@@ -384,19 +434,20 @@ mod tests {
                 assert_eq!(all_values[i + accu_len], None)
             }
         }
-        accu_len += array.len();
+        accu_len += len;
 
         // Try to read values_per_page values, however there are only values_per_page/2 values
-        let array = array_reader.next_batch(values_per_page).unwrap();
-        assert_eq!(array.len(), values_per_page / 2);
+        let len = array_reader.read_records(values_per_page).unwrap();
+        assert_eq!(len, values_per_page / 2);
         assert_eq!(
-            Some(&def_levels[accu_len..(accu_len + array.len())]),
+            Some(&def_levels[accu_len..(accu_len + len)]),
             array_reader.get_def_levels()
         );
         assert_eq!(
-            Some(&rep_levels[accu_len..(accu_len + array.len())]),
+            Some(&rep_levels[accu_len..(accu_len + len)]),
             array_reader.get_rep_levels()
         );
+        array_reader.consume_batch().unwrap();
     }
 
     #[test]
@@ -491,31 +542,34 @@ mod tests {
         let mut accu_len: usize = 0;
 
         // println!("---------- reading a batch of {} values ----------", values_per_page / 2);
-        let array = array_reader.next_batch(values_per_page / 2).unwrap();
-        assert_eq!(array.len(), values_per_page / 2);
+        let len = array_reader.read_records(values_per_page / 2).unwrap();
+        assert_eq!(len, values_per_page / 2);
         assert_eq!(
-            Some(&def_levels[accu_len..(accu_len + array.len())]),
+            Some(&def_levels[accu_len..(accu_len + len)]),
             array_reader.get_def_levels()
         );
         assert_eq!(
-            Some(&rep_levels[accu_len..(accu_len + array.len())]),
+            Some(&rep_levels[accu_len..(accu_len + len)]),
             array_reader.get_rep_levels()
         );
-        accu_len += array.len();
+        accu_len += len;
+        array_reader.consume_batch().unwrap();
 
         // Read next values_per_page values, the first values_per_page/2 ones are from the first column chunk,
         // and the last values_per_page/2 ones are from the second column chunk
         // println!("---------- reading a batch of {} values ----------", values_per_page);
-        let array = array_reader.next_batch(values_per_page).unwrap();
-        assert_eq!(array.len(), values_per_page);
+        //let array = array_reader.next_batch(values_per_page).unwrap();
+        let len = array_reader.read_records(values_per_page).unwrap();
+        assert_eq!(len, values_per_page);
         assert_eq!(
-            Some(&def_levels[accu_len..(accu_len + array.len())]),
+            Some(&def_levels[accu_len..(accu_len + len)]),
             array_reader.get_def_levels()
         );
         assert_eq!(
-            Some(&rep_levels[accu_len..(accu_len + array.len())]),
+            Some(&rep_levels[accu_len..(accu_len + len)]),
             array_reader.get_rep_levels()
         );
+        let array = array_reader.consume_batch().unwrap();
         let strings = array.as_any().downcast_ref::<StringArray>().unwrap();
         for i in 0..array.len() {
             if array.is_valid(i) {
@@ -527,19 +581,20 @@ mod tests {
                 assert_eq!(all_values[i + accu_len], None)
             }
         }
-        accu_len += array.len();
+        accu_len += len;
 
         // Try to read values_per_page values, however there are only values_per_page/2 values
         // println!("---------- reading a batch of {} values ----------", values_per_page);
-        let array = array_reader.next_batch(values_per_page).unwrap();
-        assert_eq!(array.len(), values_per_page / 2);
+        let len = array_reader.read_records(values_per_page).unwrap();
+        assert_eq!(len, values_per_page / 2);
         assert_eq!(
-            Some(&def_levels[accu_len..(accu_len + array.len())]),
+            Some(&def_levels[accu_len..(accu_len + len)]),
             array_reader.get_def_levels()
         );
         assert_eq!(
-            Some(&rep_levels[accu_len..(accu_len + array.len())]),
+            Some(&rep_levels[accu_len..(accu_len + len)]),
             array_reader.get_rep_levels()
         );
+        array_reader.consume_batch().unwrap();
     }
 }
diff --git a/parquet/src/arrow/array_reader/empty_array.rs b/parquet/src/arrow/array_reader/empty_array.rs
index b06646cc1..abe839b9d 100644
--- a/parquet/src/arrow/array_reader/empty_array.rs
+++ b/parquet/src/arrow/array_reader/empty_array.rs
@@ -33,6 +33,7 @@ pub fn make_empty_array_reader(row_count: usize) -> Box<dyn ArrayReader> {
 struct EmptyArrayReader {
     data_type: ArrowType,
     remaining_rows: usize,
+    need_consume_records: usize,
 }
 
 impl EmptyArrayReader {
@@ -40,6 +41,7 @@ impl EmptyArrayReader {
         Self {
             data_type: ArrowType::Struct(vec![]),
             remaining_rows: row_count,
+            need_consume_records: 0,
         }
     }
 }
@@ -53,15 +55,19 @@ impl ArrayReader for EmptyArrayReader {
         &self.data_type
     }
 
-    fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
+    fn read_records(&mut self, batch_size: usize) -> Result<usize> {
         let len = self.remaining_rows.min(batch_size);
         self.remaining_rows -= len;
+        self.need_consume_records += len;
+        Ok(len)
+    }
 
+    fn consume_batch(&mut self) -> Result<ArrayRef> {
         let data = ArrayDataBuilder::new(self.data_type.clone())
-            .len(len)
+            .len(self.need_consume_records)
             .build()
             .unwrap();
-
+        self.need_consume_records = 0;
         Ok(Arc::new(StructArray::from(data)))
     }
 
diff --git a/parquet/src/arrow/array_reader/list_array.rs b/parquet/src/arrow/array_reader/list_array.rs
index 33bd9772a..c245c6131 100644
--- a/parquet/src/arrow/array_reader/list_array.rs
+++ b/parquet/src/arrow/array_reader/list_array.rs
@@ -78,9 +78,13 @@ impl<OffsetSize: OffsetSizeTrait> ArrayReader for ListArrayReader<OffsetSize> {
         &self.data_type
     }
 
-    fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
-        let next_batch_array = self.item_reader.next_batch(batch_size)?;
+    fn read_records(&mut self, batch_size: usize) -> Result<usize> {
+        let size = self.item_reader.read_records(batch_size)?;
+        Ok(size)
+    }
 
+    fn consume_batch(&mut self) -> Result<ArrayRef> {
+        let next_batch_array = self.item_reader.consume_batch()?;
         if next_batch_array.len() == 0 {
             return Ok(new_empty_array(&self.data_type));
         }
diff --git a/parquet/src/arrow/array_reader/map_array.rs b/parquet/src/arrow/array_reader/map_array.rs
index 00c3db41a..83ba63ca1 100644
--- a/parquet/src/arrow/array_reader/map_array.rs
+++ b/parquet/src/arrow/array_reader/map_array.rs
@@ -62,9 +62,21 @@ impl ArrayReader for MapArrayReader {
         &self.data_type
     }
 
-    fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
-        let key_array = self.key_reader.next_batch(batch_size)?;
-        let value_array = self.value_reader.next_batch(batch_size)?;
+    fn read_records(&mut self, batch_size: usize) -> Result<usize> {
+        let key_len = self.key_reader.read_records(batch_size)?;
+        let value_len = self.value_reader.read_records(batch_size)?;
+        // Check that key and value have the same lengths
+        if key_len != value_len {
+            return Err(general_err!(
+                "Map key and value should have the same lengths."
+            ));
+        }
+        Ok(key_len)
+    }
+
+    fn consume_batch(&mut self) -> Result<ArrayRef> {
+        let key_array = self.key_reader.consume_batch()?;
+        let value_array = self.value_reader.consume_batch()?;
 
         // Check that key and value have the same lengths
         let key_length = key_array.len();
diff --git a/parquet/src/arrow/array_reader/mod.rs b/parquet/src/arrow/array_reader/mod.rs
index 8bdd6c071..d7665ef0f 100644
--- a/parquet/src/arrow/array_reader/mod.rs
+++ b/parquet/src/arrow/array_reader/mod.rs
@@ -62,7 +62,20 @@ pub trait ArrayReader: Send {
     fn get_data_type(&self) -> &ArrowType;
 
     /// Reads at most `batch_size` records into an arrow array and return it.
-    fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef>;
+    fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
+        self.read_records(batch_size)?;
+        self.consume_batch()
+    }
+
+    /// Reads at most `batch_size` records' bytes into buffer
+    ///
+    /// Returns the number of records read, which can be less than `batch_size` if
+    /// pages is exhausted.
+    fn read_records(&mut self, batch_size: usize) -> Result<usize>;
+
+    /// Consume all currently stored buffer data
+    /// into an arrow array and return it.
+    fn consume_batch(&mut self) -> Result<ArrayRef>;
 
     /// Skips over `num_records` records, returning the number of rows skipped
     fn skip_records(&mut self, num_records: usize) -> Result<usize>;
diff --git a/parquet/src/arrow/array_reader/null_array.rs b/parquet/src/arrow/array_reader/null_array.rs
index 63f73d41e..682d15f8a 100644
--- a/parquet/src/arrow/array_reader/null_array.rs
+++ b/parquet/src/arrow/array_reader/null_array.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::arrow::array_reader::{read_records, ArrayReader, skip_records};
+use crate::arrow::array_reader::{read_records, skip_records, ArrayReader};
 use crate::arrow::record_reader::buffer::ScalarValue;
 use crate::arrow::record_reader::RecordReader;
 use crate::column::page::PageIterator;
@@ -78,10 +78,11 @@ where
         &self.data_type
     }
 
-    /// Reads at most `batch_size` records into array.
-    fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
-        read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?;
+    fn read_records(&mut self, batch_size: usize) -> Result<usize> {
+        read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)
+    }
 
+    fn consume_batch(&mut self) -> Result<ArrayRef> {
         // convert to arrays
         let array = arrow::array::NullArray::new(self.record_reader.num_values());
 
diff --git a/parquet/src/arrow/array_reader/primitive_array.rs b/parquet/src/arrow/array_reader/primitive_array.rs
index 89f2ce51b..59526f093 100644
--- a/parquet/src/arrow/array_reader/primitive_array.rs
+++ b/parquet/src/arrow/array_reader/primitive_array.rs
@@ -95,10 +95,11 @@ where
         &self.data_type
     }
 
-    /// Reads at most `batch_size` records into array.
-    fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
-        read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?;
+    fn read_records(&mut self, batch_size: usize) -> Result<usize> {
+        read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)
+    }
 
+    fn consume_batch(&mut self) -> Result<ArrayRef> {
         let target_type = self.get_data_type().clone();
         let arrow_data_type = match T::get_physical_type() {
             PhysicalType::BOOLEAN => ArrowType::Boolean,
diff --git a/parquet/src/arrow/array_reader/struct_array.rs b/parquet/src/arrow/array_reader/struct_array.rs
index 602c598f8..b333c66cb 100644
--- a/parquet/src/arrow/array_reader/struct_array.rs
+++ b/parquet/src/arrow/array_reader/struct_array.rs
@@ -63,7 +63,27 @@ impl ArrayReader for StructArrayReader {
         &self.data_type
     }
 
-    /// Read `batch_size` struct records.
+    fn read_records(&mut self, batch_size: usize) -> Result<usize> {
+        let mut read = None;
+        for child in self.children.iter_mut() {
+            let child_read = child.read_records(batch_size)?;
+            match read {
+                Some(expected) => {
+                    if expected != child_read {
+                        return Err(general_err!(
+                            "StructArrayReader out of sync in read_records, expected {} skipped, got {}",
+                            expected,
+                            child_read
+                        ));
+                    }
+                }
+                None => read = Some(child_read),
+            }
+        }
+        Ok(read.unwrap_or(0))
+    }
+
+    /// Consume struct records.
     ///
     /// Definition levels of struct array is calculated as following:
     /// ```ignore
@@ -80,7 +100,8 @@ impl ArrayReader for StructArrayReader {
     /// ```ignore
     /// null_bitmap[i] = (def_levels[i] >= self.def_level);
     /// ```
-    fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
+    ///
+    fn consume_batch(&mut self) -> Result<ArrayRef> {
         if self.children.is_empty() {
             return Ok(Arc::new(StructArray::from(Vec::new())));
         }
@@ -88,7 +109,7 @@ impl ArrayReader for StructArrayReader {
         let children_array = self
             .children
             .iter_mut()
-            .map(|reader| reader.next_batch(batch_size))
+            .map(|reader| reader.consume_batch())
             .collect::<Result<Vec<_>>>()?;
 
         // check that array child data has same size
diff --git a/parquet/src/arrow/array_reader/test_util.rs b/parquet/src/arrow/array_reader/test_util.rs
index 04c0f6c68..da9b8d3bf 100644
--- a/parquet/src/arrow/array_reader/test_util.rs
+++ b/parquet/src/arrow/array_reader/test_util.rs
@@ -101,6 +101,7 @@ pub struct InMemoryArrayReader {
     rep_levels: Option<Vec<i16>>,
     last_idx: usize,
     cur_idx: usize,
+    need_consume_records: usize,
 }
 
 impl InMemoryArrayReader {
@@ -127,6 +128,7 @@ impl InMemoryArrayReader {
             rep_levels,
             cur_idx: 0,
             last_idx: 0,
+            need_consume_records: 0,
         }
     }
 }
@@ -140,7 +142,7 @@ impl ArrayReader for InMemoryArrayReader {
         &self.data_type
     }
 
-    fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
+    fn read_records(&mut self, batch_size: usize) -> Result<usize> {
         assert_ne!(batch_size, 0);
         // This replicates the logical normally performed by
         // RecordReader to delimit semantic records
@@ -164,10 +166,17 @@ impl ArrayReader for InMemoryArrayReader {
             }
             None => batch_size.min(self.array.len() - self.cur_idx),
         };
+        self.need_consume_records += read;
+        Ok(read)
+    }
 
+    fn consume_batch(&mut self) -> Result<ArrayRef> {
+        let batch_size = self.need_consume_records;
+        assert_ne!(batch_size, 0);
         self.last_idx = self.cur_idx;
-        self.cur_idx += read;
-        Ok(self.array.slice(self.last_idx, read))
+        self.cur_idx += batch_size;
+        self.need_consume_records = 0;
+        Ok(self.array.slice(self.last_idx, batch_size))
     }
 
     fn skip_records(&mut self, num_records: usize) -> Result<usize> {
diff --git a/parquet/src/arrow/arrow_reader.rs b/parquet/src/arrow/arrow_reader.rs
index 26305cd41..3cd5cb9d4 100644
--- a/parquet/src/arrow/arrow_reader.rs
+++ b/parquet/src/arrow/arrow_reader.rs
@@ -769,6 +769,44 @@ mod tests {
         assert_eq!(&written.slice(6, 2), &read[2]);
     }
 
+    #[test]
+    fn test_int32_nullable_struct() {
+        let int32 = Int32Array::from_iter_values([1, 2, 3, 4, 5, 6, 7, 8]);
+        let data = ArrayDataBuilder::new(ArrowDataType::Struct(vec![Field::new(
+            "int32",
+            int32.data_type().clone(),
+            false,
+        )]))
+        .len(8)
+        .null_bit_buffer(Some(Buffer::from(&[0b11101111])))
+        .child_data(vec![int32.into_data()])
+        .build()
+        .unwrap();
+
+        let written = RecordBatch::try_from_iter([(
+            "struct",
+            Arc::new(StructArray::from(data)) as ArrayRef,
+        )])
+        .unwrap();
+
+        let mut buffer = Vec::with_capacity(1024);
+        let mut writer =
+            ArrowWriter::try_new(&mut buffer, written.schema(), None).unwrap();
+        writer.write(&written).unwrap();
+        writer.close().unwrap();
+
+        let read = ParquetFileArrowReader::try_new(Bytes::from(buffer))
+            .unwrap()
+            .get_record_reader(3)
+            .unwrap()
+            .collect::<ArrowResult<Vec<_>>>()
+            .unwrap();
+
+        assert_eq!(&written.slice(0, 3), &read[0]);
+        assert_eq!(&written.slice(3, 3), &read[1]);
+        assert_eq!(&written.slice(6, 2), &read[2]);
+    }
+
     #[test]
     #[ignore] // https://github.com/apache/arrow-rs/issues/2253
     fn test_decimal_list() {
diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs
index 1c95fcc27..49531d972 100644
--- a/parquet/src/arrow/arrow_writer/mod.rs
+++ b/parquet/src/arrow/arrow_writer/mod.rs
@@ -1161,7 +1161,7 @@ mod tests {
             Some(props),
         )
         .expect("Unable to write file");
-        writer.write(&expected_batch).unwrap();
+        writer.write(expected_batch).unwrap();
         writer.close().unwrap();
 
         let mut arrow_reader =

Reply via email to