This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 1f9973cd6 Separate ArrayReader::next_batch with read_records and consume_batch (#2237)
1f9973cd6 is described below
commit 1f9973cd6cdde73883da50d9417126390437ef76
Author: Yang Jiang <[email protected]>
AuthorDate: Wed Aug 3 15:50:52 2022 +0800
Separate ArrayReader::next_batch with read_records and consume_batch (#2237)
* replace ArrayReader::next_batch with ArrayReader::read_records and ArrayReader::consume_batch.
* fix unit tests
* fix comment
* avoid clone.
* fix new unit tests
* fix comment
Co-authored-by: Raphael Taylor-Davies <[email protected]>
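The change splits the one-shot API into a read phase and a consume phase. A minimal usage sketch of the new trait methods (the `reader` value and batch size are hypothetical; only the method names come from this patch):

    // Before: a single call both read and materialized a batch.
    // let array = reader.next_batch(1024)?;

    // After: buffer records first, possibly over several calls, then
    // materialize everything buffered as one arrow array.
    let mut read = 0;
    while read < 1024 {
        let n = reader.read_records(1024 - read)?;
        if n == 0 {
            break; // pages exhausted
        }
        read += n;
    }
    let array = reader.consume_batch()?;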
---
parquet/src/arrow/array_reader/byte_array.rs | 7 +-
.../arrow/array_reader/byte_array_dictionary.rs | 9 +-
.../src/arrow/array_reader/complex_object_array.rs | 133 +++++++++++++++------
parquet/src/arrow/array_reader/empty_array.rs | 12 +-
parquet/src/arrow/array_reader/list_array.rs | 8 +-
parquet/src/arrow/array_reader/map_array.rs | 18 ++-
parquet/src/arrow/array_reader/mod.rs | 15 ++-
parquet/src/arrow/array_reader/null_array.rs | 9 +-
parquet/src/arrow/array_reader/primitive_array.rs | 7 +-
parquet/src/arrow/array_reader/struct_array.rs | 27 ++++-
parquet/src/arrow/array_reader/test_util.rs | 15 ++-
parquet/src/arrow/arrow_reader.rs | 38 ++++++
parquet/src/arrow/arrow_writer/mod.rs | 2 +-
13 files changed, 233 insertions(+), 67 deletions(-)
diff --git a/parquet/src/arrow/array_reader/byte_array.rs b/parquet/src/arrow/array_reader/byte_array.rs
index ec4188890..172aeb96d 100644
--- a/parquet/src/arrow/array_reader/byte_array.rs
+++ b/parquet/src/arrow/array_reader/byte_array.rs
@@ -108,8 +108,11 @@ impl<I: OffsetSizeTrait + ScalarValue> ArrayReader for ByteArrayReader<I> {
&self.data_type
}
- fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
- read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?;
+ fn read_records(&mut self, batch_size: usize) -> Result<usize> {
+ read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)
+ }
+
+ fn consume_batch(&mut self) -> Result<ArrayRef> {
let buffer = self.record_reader.consume_record_data();
let null_buffer = self.record_reader.consume_bitmap_buffer();
self.def_levels_buffer = self.record_reader.consume_def_levels();
diff --git a/parquet/src/arrow/array_reader/byte_array_dictionary.rs b/parquet/src/arrow/array_reader/byte_array_dictionary.rs
index 51ef38d0d..0a5d94fa6 100644
--- a/parquet/src/arrow/array_reader/byte_array_dictionary.rs
+++ b/parquet/src/arrow/array_reader/byte_array_dictionary.rs
@@ -25,7 +25,7 @@ use arrow::buffer::Buffer;
use arrow::datatypes::{ArrowNativeType, DataType as ArrowType};
use crate::arrow::array_reader::byte_array::{ByteArrayDecoder, ByteArrayDecoderPlain};
-use crate::arrow::array_reader::{read_records, ArrayReader, skip_records};
+use crate::arrow::array_reader::{read_records, skip_records, ArrayReader};
use crate::arrow::buffer::{
dictionary_buffer::DictionaryBuffer, offset_buffer::OffsetBuffer,
};
@@ -167,8 +167,11 @@ where
&self.data_type
}
- fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
- read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?;
+ fn read_records(&mut self, batch_size: usize) -> Result<usize> {
+ read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)
+ }
+
+ fn consume_batch(&mut self) -> Result<ArrayRef> {
let buffer = self.record_reader.consume_record_data();
let null_buffer = self.record_reader.consume_bitmap_buffer();
let array = buffer.into_array(null_buffer, &self.data_type)?;
diff --git a/parquet/src/arrow/array_reader/complex_object_array.rs b/parquet/src/arrow/array_reader/complex_object_array.rs
index 1390866cf..79b537331 100644
--- a/parquet/src/arrow/array_reader/complex_object_array.rs
+++ b/parquet/src/arrow/array_reader/complex_object_array.rs
@@ -39,9 +39,13 @@ where
pages: Box<dyn PageIterator>,
def_levels_buffer: Option<Vec<i16>>,
rep_levels_buffer: Option<Vec<i16>>,
+ data_buffer: Vec<T::T>,
column_desc: ColumnDescPtr,
column_reader: Option<ColumnReaderImpl<T>>,
converter: C,
+ in_progress_def_levels_buffer: Option<Vec<i16>>,
+ in_progress_rep_levels_buffer: Option<Vec<i16>>,
+ before_consume: bool,
_parquet_type_marker: PhantomData<T>,
_converter_marker: PhantomData<C>,
}
@@ -59,7 +63,10 @@ where
&self.data_type
}
- fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
+ fn read_records(&mut self, batch_size: usize) -> Result<usize> {
+ if !self.before_consume {
+ self.before_consume = true;
+ }
// Try to initialize column reader
if self.column_reader.is_none() {
self.next_column_reader()?;
@@ -126,7 +133,6 @@ where
break;
}
}
-
data_buffer.truncate(num_read);
def_levels_buffer
.iter_mut()
@@ -135,13 +141,35 @@ where
.iter_mut()
.for_each(|buf| buf.truncate(num_read));
- self.def_levels_buffer = def_levels_buffer;
- self.rep_levels_buffer = rep_levels_buffer;
+ if let Some(mut def_levels_buffer) = def_levels_buffer {
+ match &mut self.in_progress_def_levels_buffer {
+ None => {
+ self.in_progress_def_levels_buffer = Some(def_levels_buffer);
+ }
+ Some(buf) => buf.append(&mut def_levels_buffer),
+ }
+ }
+
+ if let Some(mut rep_levels_buffer) = rep_levels_buffer {
+ match &mut self.in_progress_rep_levels_buffer {
+ None => {
+ self.in_progress_rep_levels_buffer = Some(rep_levels_buffer);
+ }
+ Some(buf) => buf.append(&mut rep_levels_buffer),
+ }
+ }
+
+ self.data_buffer.append(&mut data_buffer);
+
+ Ok(num_read)
+ }
- let data: Vec<Option<T::T>> = if self.def_levels_buffer.is_some() {
+ fn consume_batch(&mut self) -> Result<ArrayRef> {
+ let data: Vec<Option<T::T>> = if self.in_progress_def_levels_buffer.is_some() {
+ let data_buffer = std::mem::take(&mut self.data_buffer);
data_buffer
.into_iter()
- .zip(self.def_levels_buffer.as_ref().unwrap().iter())
+ .zip(self.in_progress_def_levels_buffer.as_ref().unwrap().iter())
.map(|(t, def_level)| {
if *def_level == self.column_desc.max_def_level() {
Some(t)
@@ -151,7 +179,7 @@ where
})
.collect()
} else {
- data_buffer.into_iter().map(Some).collect()
+ self.data_buffer.iter().map(|x| Some(x.clone())).collect()
};
let mut array = self.converter.convert(data)?;
@@ -160,6 +188,11 @@ where
array = arrow::compute::cast(&array, &self.data_type)?;
}
+ self.data_buffer = vec![];
+ self.def_levels_buffer = std::mem::take(&mut self.in_progress_def_levels_buffer);
+ self.rep_levels_buffer = std::mem::take(&mut self.in_progress_rep_levels_buffer);
+ self.before_consume = false;
+
Ok(array)
}
@@ -168,8 +201,11 @@ where
Some(reader) => reader.skip_records(num_records),
None => {
if self.next_column_reader()? {
- self.column_reader.as_mut().unwrap().skip_records(num_records)
- }else {
+ self.column_reader
+ .as_mut()
+ .unwrap()
+ .skip_records(num_records)
+ } else {
Ok(0)
}
}
@@ -177,11 +213,19 @@ where
}
fn get_def_levels(&self) -> Option<&[i16]> {
- self.def_levels_buffer.as_deref()
+ if self.before_consume {
+ self.in_progress_def_levels_buffer.as_deref()
+ } else {
+ self.def_levels_buffer.as_deref()
+ }
}
fn get_rep_levels(&self) -> Option<&[i16]> {
- self.rep_levels_buffer.as_deref()
+ if self.before_consume {
+ self.in_progress_rep_levels_buffer.as_deref()
+ } else {
+ self.rep_levels_buffer.as_deref()
+ }
}
}
@@ -208,9 +252,13 @@ where
pages,
def_levels_buffer: None,
rep_levels_buffer: None,
+ data_buffer: vec![],
column_desc,
column_reader: None,
converter,
+ in_progress_def_levels_buffer: None,
+ in_progress_rep_levels_buffer: None,
+ before_consume: true,
_parquet_type_marker: PhantomData,
_converter_marker: PhantomData,
})
@@ -349,30 +397,32 @@ mod tests {
let mut accu_len: usize = 0;
- let array = array_reader.next_batch(values_per_page / 2).unwrap();
- assert_eq!(array.len(), values_per_page / 2);
+ let len = array_reader.read_records(values_per_page / 2).unwrap();
+ assert_eq!(len, values_per_page / 2);
assert_eq!(
- Some(&def_levels[accu_len..(accu_len + array.len())]),
+ Some(&def_levels[accu_len..(accu_len + len)]),
array_reader.get_def_levels()
);
assert_eq!(
- Some(&rep_levels[accu_len..(accu_len + array.len())]),
+ Some(&rep_levels[accu_len..(accu_len + len)]),
array_reader.get_rep_levels()
);
- accu_len += array.len();
+ accu_len += len;
+ array_reader.consume_batch().unwrap();
// Read next values_per_page values, the first values_per_page/2 ones are from the first column chunk,
// and the last values_per_page/2 ones are from the second column chunk
- let array = array_reader.next_batch(values_per_page).unwrap();
- assert_eq!(array.len(), values_per_page);
+ let len = array_reader.read_records(values_per_page).unwrap();
+ assert_eq!(len, values_per_page);
assert_eq!(
- Some(&def_levels[accu_len..(accu_len + array.len())]),
+ Some(&def_levels[accu_len..(accu_len + len)]),
array_reader.get_def_levels()
);
assert_eq!(
- Some(&rep_levels[accu_len..(accu_len + array.len())]),
+ Some(&rep_levels[accu_len..(accu_len + len)]),
array_reader.get_rep_levels()
);
+ let array = array_reader.consume_batch().unwrap();
let strings = array.as_any().downcast_ref::<StringArray>().unwrap();
for i in 0..array.len() {
if array.is_valid(i) {
@@ -384,19 +434,20 @@ mod tests {
assert_eq!(all_values[i + accu_len], None)
}
}
- accu_len += array.len();
+ accu_len += len;
// Try to read values_per_page values, however there are only values_per_page/2 values
- let array = array_reader.next_batch(values_per_page).unwrap();
- assert_eq!(array.len(), values_per_page / 2);
+ let len = array_reader.read_records(values_per_page).unwrap();
+ assert_eq!(len, values_per_page / 2);
assert_eq!(
- Some(&def_levels[accu_len..(accu_len + array.len())]),
+ Some(&def_levels[accu_len..(accu_len + len)]),
array_reader.get_def_levels()
);
assert_eq!(
- Some(&rep_levels[accu_len..(accu_len + array.len())]),
+ Some(&rep_levels[accu_len..(accu_len + len)]),
array_reader.get_rep_levels()
);
+ array_reader.consume_batch().unwrap();
}
#[test]
@@ -491,31 +542,34 @@ mod tests {
let mut accu_len: usize = 0;
// println!("---------- reading a batch of {} values ----------", values_per_page / 2);
- let array = array_reader.next_batch(values_per_page / 2).unwrap();
- assert_eq!(array.len(), values_per_page / 2);
+ let len = array_reader.read_records(values_per_page / 2).unwrap();
+ assert_eq!(len, values_per_page / 2);
assert_eq!(
- Some(&def_levels[accu_len..(accu_len + array.len())]),
+ Some(&def_levels[accu_len..(accu_len + len)]),
array_reader.get_def_levels()
);
assert_eq!(
- Some(&rep_levels[accu_len..(accu_len + array.len())]),
+ Some(&rep_levels[accu_len..(accu_len + len)]),
array_reader.get_rep_levels()
);
- accu_len += array.len();
+ accu_len += len;
+ array_reader.consume_batch().unwrap();
// Read next values_per_page values, the first values_per_page/2 ones are from the first column chunk,
// and the last values_per_page/2 ones are from the second column chunk
// println!("---------- reading a batch of {} values ----------", values_per_page);
- let array = array_reader.next_batch(values_per_page).unwrap();
- assert_eq!(array.len(), values_per_page);
+ let len = array_reader.read_records(values_per_page).unwrap();
+ assert_eq!(len, values_per_page);
assert_eq!(
- Some(&def_levels[accu_len..(accu_len + array.len())]),
+ Some(&def_levels[accu_len..(accu_len + len)]),
array_reader.get_def_levels()
);
assert_eq!(
- Some(&rep_levels[accu_len..(accu_len + array.len())]),
+ Some(&rep_levels[accu_len..(accu_len + len)]),
array_reader.get_rep_levels()
);
+ let array = array_reader.consume_batch().unwrap();
let strings = array.as_any().downcast_ref::<StringArray>().unwrap();
for i in 0..array.len() {
if array.is_valid(i) {
@@ -527,19 +581,20 @@ mod tests {
assert_eq!(all_values[i + accu_len], None)
}
}
- accu_len += array.len();
+ accu_len += len;
// Try to read values_per_page values, however there are only values_per_page/2 values
// println!("---------- reading a batch of {} values ----------", values_per_page);
- let array = array_reader.next_batch(values_per_page).unwrap();
- assert_eq!(array.len(), values_per_page / 2);
+ let len = array_reader.read_records(values_per_page).unwrap();
+ assert_eq!(len, values_per_page / 2);
assert_eq!(
- Some(&def_levels[accu_len..(accu_len + array.len())]),
+ Some(&def_levels[accu_len..(accu_len + len)]),
array_reader.get_def_levels()
);
assert_eq!(
- Some(&rep_levels[accu_len..(accu_len + array.len())]),
+ Some(&rep_levels[accu_len..(accu_len + len)]),
array_reader.get_rep_levels()
);
+ array_reader.consume_batch().unwrap();
}
}
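The complex object reader above shows the pattern used throughout this patch: read_records appends decoded values and levels into in-progress buffers, and consume_batch takes those buffers, builds the array, and resets for the next cycle. A standalone sketch of that accumulate-then-take idea (names simplified, not the exact types from the patch):

    struct LevelAccumulator {
        in_progress: Option<Vec<i16>>, // appended to by read_records
        finished: Option<Vec<i16>>,    // exposed after consume_batch
    }

    impl LevelAccumulator {
        fn append(&mut self, mut levels: Vec<i16>) {
            match &mut self.in_progress {
                None => self.in_progress = Some(levels),
                Some(buf) => buf.append(&mut levels),
            }
        }

        fn consume(&mut self) {
            // hand the accumulated levels off and reset for the next batch
            self.finished = std::mem::take(&mut self.in_progress);
        }
    }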
diff --git a/parquet/src/arrow/array_reader/empty_array.rs b/parquet/src/arrow/array_reader/empty_array.rs
index b06646cc1..abe839b9d 100644
--- a/parquet/src/arrow/array_reader/empty_array.rs
+++ b/parquet/src/arrow/array_reader/empty_array.rs
@@ -33,6 +33,7 @@ pub fn make_empty_array_reader(row_count: usize) -> Box<dyn ArrayReader> {
struct EmptyArrayReader {
data_type: ArrowType,
remaining_rows: usize,
+ need_consume_records: usize,
}
impl EmptyArrayReader {
@@ -40,6 +41,7 @@ impl EmptyArrayReader {
Self {
data_type: ArrowType::Struct(vec![]),
remaining_rows: row_count,
+ need_consume_records: 0,
}
}
}
@@ -53,15 +55,19 @@ impl ArrayReader for EmptyArrayReader {
&self.data_type
}
- fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
+ fn read_records(&mut self, batch_size: usize) -> Result<usize> {
let len = self.remaining_rows.min(batch_size);
self.remaining_rows -= len;
+ self.need_consume_records += len;
+ Ok(len)
+ }
+ fn consume_batch(&mut self) -> Result<ArrayRef> {
let data = ArrayDataBuilder::new(self.data_type.clone())
- .len(len)
+ .len(self.need_consume_records)
.build()
.unwrap();
-
+ self.need_consume_records = 0;
Ok(Arc::new(StructArray::from(data)))
}
diff --git a/parquet/src/arrow/array_reader/list_array.rs b/parquet/src/arrow/array_reader/list_array.rs
index 33bd9772a..c245c6131 100644
--- a/parquet/src/arrow/array_reader/list_array.rs
+++ b/parquet/src/arrow/array_reader/list_array.rs
@@ -78,9 +78,13 @@ impl<OffsetSize: OffsetSizeTrait> ArrayReader for ListArrayReader<OffsetSize> {
&self.data_type
}
- fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
- let next_batch_array = self.item_reader.next_batch(batch_size)?;
+ fn read_records(&mut self, batch_size: usize) -> Result<usize> {
+ let size = self.item_reader.read_records(batch_size)?;
+ Ok(size)
+ }
+ fn consume_batch(&mut self) -> Result<ArrayRef> {
+ let next_batch_array = self.item_reader.consume_batch()?;
if next_batch_array.len() == 0 {
return Ok(new_empty_array(&self.data_type));
}
diff --git a/parquet/src/arrow/array_reader/map_array.rs b/parquet/src/arrow/array_reader/map_array.rs
index 00c3db41a..83ba63ca1 100644
--- a/parquet/src/arrow/array_reader/map_array.rs
+++ b/parquet/src/arrow/array_reader/map_array.rs
@@ -62,9 +62,21 @@ impl ArrayReader for MapArrayReader {
&self.data_type
}
- fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
- let key_array = self.key_reader.next_batch(batch_size)?;
- let value_array = self.value_reader.next_batch(batch_size)?;
+ fn read_records(&mut self, batch_size: usize) -> Result<usize> {
+ let key_len = self.key_reader.read_records(batch_size)?;
+ let value_len = self.value_reader.read_records(batch_size)?;
+ // Check that key and value have the same lengths
+ if key_len != value_len {
+ return Err(general_err!(
+ "Map key and value should have the same lengths."
+ ));
+ }
+ Ok(key_len)
+ }
+
+ fn consume_batch(&mut self) -> Result<ArrayRef> {
+ let key_array = self.key_reader.consume_batch()?;
+ let value_array = self.value_reader.consume_batch()?;
// Check that key and value have the same lengths
let key_length = key_array.len();
diff --git a/parquet/src/arrow/array_reader/mod.rs b/parquet/src/arrow/array_reader/mod.rs
index 8bdd6c071..d7665ef0f 100644
--- a/parquet/src/arrow/array_reader/mod.rs
+++ b/parquet/src/arrow/array_reader/mod.rs
@@ -62,7 +62,20 @@ pub trait ArrayReader: Send {
fn get_data_type(&self) -> &ArrowType;
/// Reads at most `batch_size` records into an arrow array and returns it.
- fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef>;
+ fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
+ self.read_records(batch_size)?;
+ self.consume_batch()
+ }
+
+ /// Reads at most `batch_size` records' worth of bytes into the buffer
+ ///
+ /// Returns the number of records read, which can be less than `batch_size` if
+ /// the pages are exhausted.
+ fn read_records(&mut self, batch_size: usize) -> Result<usize>;
+
+ /// Consumes all currently buffered data
+ /// into an arrow array and returns it.
+ fn consume_batch(&mut self) -> Result<ArrayRef>;
/// Skips over `num_records` records, returning the number of rows skipped
fn skip_records(&mut self, num_records: usize) -> Result<usize>;
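Splitting the trait this way lets a caller buffer records across multiple read_records calls before materializing a single array. A sketch of that usage (assuming `reader: &mut dyn ArrayReader` over a flat column with enough remaining records):

    let a = reader.read_records(100)?;   // may stop early at a row group boundary
    let b = reader.read_records(100)?;   // keeps buffering into the same batch
    let array = reader.consume_batch()?; // one array covering both reads
    assert_eq!(array.len(), a + b);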
diff --git a/parquet/src/arrow/array_reader/null_array.rs b/parquet/src/arrow/array_reader/null_array.rs
index 63f73d41e..682d15f8a 100644
--- a/parquet/src/arrow/array_reader/null_array.rs
+++ b/parquet/src/arrow/array_reader/null_array.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use crate::arrow::array_reader::{read_records, ArrayReader, skip_records};
+use crate::arrow::array_reader::{read_records, skip_records, ArrayReader};
use crate::arrow::record_reader::buffer::ScalarValue;
use crate::arrow::record_reader::RecordReader;
use crate::column::page::PageIterator;
@@ -78,10 +78,11 @@ where
&self.data_type
}
- /// Reads at most `batch_size` records into array.
- fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
- read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?;
+ fn read_records(&mut self, batch_size: usize) -> Result<usize> {
+ read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)
+ }
+ fn consume_batch(&mut self) -> Result<ArrayRef> {
// convert to arrays
let array =
arrow::array::NullArray::new(self.record_reader.num_values());
diff --git a/parquet/src/arrow/array_reader/primitive_array.rs b/parquet/src/arrow/array_reader/primitive_array.rs
index 89f2ce51b..59526f093 100644
--- a/parquet/src/arrow/array_reader/primitive_array.rs
+++ b/parquet/src/arrow/array_reader/primitive_array.rs
@@ -95,10 +95,11 @@ where
&self.data_type
}
- /// Reads at most `batch_size` records into array.
- fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
- read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?;
+ fn read_records(&mut self, batch_size: usize) -> Result<usize> {
+ read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)
+ }
+ fn consume_batch(&mut self) -> Result<ArrayRef> {
let target_type = self.get_data_type().clone();
let arrow_data_type = match T::get_physical_type() {
PhysicalType::BOOLEAN => ArrowType::Boolean,
diff --git a/parquet/src/arrow/array_reader/struct_array.rs b/parquet/src/arrow/array_reader/struct_array.rs
index 602c598f8..b333c66cb 100644
--- a/parquet/src/arrow/array_reader/struct_array.rs
+++ b/parquet/src/arrow/array_reader/struct_array.rs
@@ -63,7 +63,27 @@ impl ArrayReader for StructArrayReader {
&self.data_type
}
- /// Read `batch_size` struct records.
+ fn read_records(&mut self, batch_size: usize) -> Result<usize> {
+ let mut read = None;
+ for child in self.children.iter_mut() {
+ let child_read = child.read_records(batch_size)?;
+ match read {
+ Some(expected) => {
+ if expected != child_read {
+ return Err(general_err!(
+ "StructArrayReader out of sync in read_records,
expected {} skipped, got {}",
+ expected,
+ child_read
+ ));
+ }
+ }
+ None => read = Some(child_read),
+ }
+ }
+ Ok(read.unwrap_or(0))
+ }
+
+ /// Consume struct records.
///
/// Definition levels of struct array are calculated as follows:
/// ```ignore
@@ -80,7 +100,8 @@ impl ArrayReader for StructArrayReader {
/// ```ignore
/// null_bitmap[i] = (def_levels[i] >= self.def_level);
/// ```
- fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
+ ///
+ fn consume_batch(&mut self) -> Result<ArrayRef> {
if self.children.is_empty() {
return Ok(Arc::new(StructArray::from(Vec::new())));
}
@@ -88,7 +109,7 @@ impl ArrayReader for StructArrayReader {
let children_array = self
.children
.iter_mut()
- .map(|reader| reader.next_batch(batch_size))
+ .map(|reader| reader.consume_batch())
.collect::<Result<Vec<_>>>()?;
// check that array child data has same size
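Because read_records drives every child with the same batch_size and errors if their counts diverge, consume_batch can assume the children's arrays agree in length, and the size check above is a defensive re-validation. Illustrative only (a hypothetical struct reader over a non-repeated column):

    let n = struct_reader.read_records(64)?;    // each child buffered exactly n records
    let array = struct_reader.consume_batch()?; // children consumed in lockstep
    assert_eq!(array.len(), n);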
diff --git a/parquet/src/arrow/array_reader/test_util.rs b/parquet/src/arrow/array_reader/test_util.rs
index 04c0f6c68..da9b8d3bf 100644
--- a/parquet/src/arrow/array_reader/test_util.rs
+++ b/parquet/src/arrow/array_reader/test_util.rs
@@ -101,6 +101,7 @@ pub struct InMemoryArrayReader {
rep_levels: Option<Vec<i16>>,
last_idx: usize,
cur_idx: usize,
+ need_consume_records: usize,
}
impl InMemoryArrayReader {
@@ -127,6 +128,7 @@ impl InMemoryArrayReader {
rep_levels,
cur_idx: 0,
last_idx: 0,
+ need_consume_records: 0,
}
}
}
@@ -140,7 +142,7 @@ impl ArrayReader for InMemoryArrayReader {
&self.data_type
}
- fn next_batch(&mut self, batch_size: usize) -> Result<ArrayRef> {
+ fn read_records(&mut self, batch_size: usize) -> Result<usize> {
assert_ne!(batch_size, 0);
// This replicates the logic normally performed by
// RecordReader to delimit semantic records
@@ -164,10 +166,17 @@ impl ArrayReader for InMemoryArrayReader {
}
None => batch_size.min(self.array.len() - self.cur_idx),
};
+ self.need_consume_records += read;
+ Ok(read)
+ }
+ fn consume_batch(&mut self) -> Result<ArrayRef> {
+ let batch_size = self.need_consume_records;
+ assert_ne!(batch_size, 0);
self.last_idx = self.cur_idx;
- self.cur_idx += read;
- Ok(self.array.slice(self.last_idx, read))
+ self.cur_idx += batch_size;
+ self.need_consume_records = 0;
+ Ok(self.array.slice(self.last_idx, batch_size))
}
fn skip_records(&mut self, num_records: usize) -> Result<usize> {
diff --git a/parquet/src/arrow/arrow_reader.rs b/parquet/src/arrow/arrow_reader.rs
index 26305cd41..3cd5cb9d4 100644
--- a/parquet/src/arrow/arrow_reader.rs
+++ b/parquet/src/arrow/arrow_reader.rs
@@ -769,6 +769,44 @@ mod tests {
assert_eq!(&written.slice(6, 2), &read[2]);
}
+ #[test]
+ fn test_int32_nullable_struct() {
+ let int32 = Int32Array::from_iter_values([1, 2, 3, 4, 5, 6, 7, 8]);
+ let data = ArrayDataBuilder::new(ArrowDataType::Struct(vec![Field::new(
+ "int32",
+ int32.data_type().clone(),
+ false,
+ )]))
+ .len(8)
+ .null_bit_buffer(Some(Buffer::from(&[0b11101111])))
+ .child_data(vec![int32.into_data()])
+ .build()
+ .unwrap();
+
+ let written = RecordBatch::try_from_iter([(
+ "struct",
+ Arc::new(StructArray::from(data)) as ArrayRef,
+ )])
+ .unwrap();
+
+ let mut buffer = Vec::with_capacity(1024);
+ let mut writer =
+ ArrowWriter::try_new(&mut buffer, written.schema(), None).unwrap();
+ writer.write(&written).unwrap();
+ writer.close().unwrap();
+
+ let read = ParquetFileArrowReader::try_new(Bytes::from(buffer))
+ .unwrap()
+ .get_record_reader(3)
+ .unwrap()
+ .collect::<ArrowResult<Vec<_>>>()
+ .unwrap();
+
+ assert_eq!(&written.slice(0, 3), &read[0]);
+ assert_eq!(&written.slice(3, 3), &read[1]);
+ assert_eq!(&written.slice(6, 2), &read[2]);
+ }
+
#[test]
#[ignore] // https://github.com/apache/arrow-rs/issues/2253
fn test_decimal_list() {
diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs
index 1c95fcc27..49531d972 100644
--- a/parquet/src/arrow/arrow_writer/mod.rs
+++ b/parquet/src/arrow/arrow_writer/mod.rs
@@ -1161,7 +1161,7 @@ mod tests {
Some(props),
)
.expect("Unable to write file");
- writer.write(&expected_batch).unwrap();
+ writer.write(expected_batch).unwrap();
writer.close().unwrap();
let mut arrow_reader =