XiangpengHao commented on code in PR #7850:
URL: https://github.com/apache/arrow-rs/pull/7850#discussion_r2180920730
##########
parquet/src/arrow/array_reader/cached_array_reader.rs:
##########
@@ -0,0 +1,621 @@
+use crate::arrow::array_reader::row_group_cache::BatchID;
+use crate::arrow::array_reader::{row_group_cache::RowGroupCache, ArrayReader};
+use crate::arrow::arrow_reader::RowSelector;
+use crate::errors::Result;
+use arrow_array::{new_empty_array, ArrayRef, BooleanArray};
+use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder};
+use arrow_schema::DataType as ArrowType;
+use std::any::Any;
+use std::collections::{HashMap, VecDeque};
+use std::sync::{Arc, Mutex};
+
/// Which side of the shared row-group cache a [`CachedArrayReader`] sits on.
///
/// The same column can be decoded once for predicate evaluation and again for
/// the output projection; the role decides how the reader treats the cache.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CacheRole {
    /// Filter phase: decoded batches are inserted into the cache.
    Producer,
    /// Output-building phase: cached batches are read and removed from the
    /// cache once they have been consumed.
    Consumer,
}
+
+/// A cached wrapper around an ArrayReader that avoids duplicate decoding
+/// when the same column appears in both filter predicates and output
projection.
+///
+/// This reader acts as a transparent layer over the inner reader, using a
cache
+/// to avoid redundant work when the same data is needed multiple times.
+///
+/// The reader can operate in two roles:
+/// - Producer: During filter phase, inserts decoded data into the cache
+/// - Consumer: During output building, consumes and removes data from the
cache
+///
+/// This means the memory consumption of the cache has two stages:
+/// 1. During the filter phase, the memory increases as the cache is populated
+/// 2. It peaks when filters are built.
+/// 3. It decreases as the cached data is consumed.
+/// ▲
+/// │ ╭─╮
+/// │ ╱ ╲
+/// │ ╱ ╲
+/// │ ╱ ╲
+/// │ ╱ ╲
+/// │╱ ╲
+/// └─────────────╲──────► Time
+/// │ │ │
+/// Filter Peak Consume
+/// Phase (Built) (Decrease)
+pub struct CachedArrayReader {
+ /// The underlying array reader
+ inner: Box<dyn ArrayReader>,
+ /// Shared cache for this row group
+ cache: Arc<Mutex<RowGroupCache>>,
+ /// Column index for cache key generation
+ column_idx: usize,
+ /// Current logical position in the data stream (for cache key generation)
+ outer_position: usize,
+ /// Current position in the inner reader
+ inner_position: usize,
+ /// Batch size for the cache
+ batch_size: usize,
+ /// Selections to be applied to the next consume_batch()
+ selections: VecDeque<RowSelector>,
+ /// Role of this reader (Producer or Consumer)
+ role: CacheRole,
+ /// Local buffer to store batches between read_records and consume_batch
calls
+ /// This ensures data is available even if the shared cache evicts items
+ local_buffer: HashMap<BatchID, ArrayRef>,
+}
+
+impl CachedArrayReader {
+ /// Creates a new cached array reader with the specified role
+ pub fn new(
+ inner: Box<dyn ArrayReader>,
+ cache: Arc<Mutex<RowGroupCache>>,
+ column_idx: usize,
+ role: CacheRole,
+ ) -> Self {
+ let batch_size = cache.lock().unwrap().batch_size();
+
+ Self {
+ inner,
+ cache,
+ column_idx,
+ outer_position: 0,
+ inner_position: 0,
+ batch_size,
+ selections: VecDeque::new(),
+ role,
+ local_buffer: HashMap::new(),
+ }
+ }
+
+ fn get_batch_id_from_position(&self, row_id: usize) -> BatchID {
+ BatchID {
+ val: row_id / self.batch_size,
+ }
+ }
+
+ fn fetch_batch(&mut self, batch_id: BatchID) -> Result<usize> {
+ let row_id = batch_id.val * self.batch_size;
+ if self.inner_position < row_id {
+ let to_skip = row_id - self.inner_position;
+ let skipped = self.inner.skip_records(to_skip)?;
+ assert_eq!(skipped, to_skip);
+ self.inner_position += skipped;
+ }
+
+ let read = self.inner.read_records(self.batch_size)?;
+
+ // If there are no remaining records (EOF), return immediately without
+ // attempting to cache an empty batch. This prevents inserting
zero-length
+ // arrays into the cache which can later cause panics when slicing.
+ if read == 0 {
+ return Ok(0);
+ }
+
+ let array = self.inner.consume_batch()?;
+
+ // Store in both shared cache and local cache
+ // The shared cache is for coordination between readers
+ // The local cache ensures data is available for our consume_batch call
+ let _cached = self
+ .cache
+ .lock()
+ .unwrap()
+ .insert(self.column_idx, batch_id, array.clone());
+ // Note: if the shared cache is full (_cached == false), we continue
without caching
+ // The local cache will still store the data for this reader's use
+
+ self.local_buffer.insert(batch_id, array);
+
+ self.inner_position += read;
+ Ok(read)
+ }
+
+ /// Remove batches from cache that have been completely consumed
+ /// This is only called for Consumer role readers
+ fn cleanup_consumed_batches(&mut self) {
+ let current_batch_id =
self.get_batch_id_from_position(self.outer_position);
+
+ // Remove batches that are at least one batch behind the current
position
+ // This ensures we don't remove batches that might still be needed for
the current batch
+ // We can safely remove batch_id if current_batch_id > batch_id + 1
+ if current_batch_id.val > 1 {
+ let mut cache = self.cache.lock().unwrap();
+ for batch_id_to_remove in 0..(current_batch_id.val - 1) {
+ cache.remove(
+ self.column_idx,
+ BatchID {
+ val: batch_id_to_remove,
+ },
+ );
+ }
+ }
+ }
+}
+
+impl ArrayReader for CachedArrayReader {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn get_data_type(&self) -> &ArrowType {
+ self.inner.get_data_type()
+ }
+
+ fn read_records(&mut self, num_records: usize) -> Result<usize> {
+ let mut read = 0;
+ while read < num_records {
+ let batch_id = self.get_batch_id_from_position(self.outer_position
+ read);
+
+ // Check local cache first
+ let cached = if let Some(array) = self.local_buffer.get(&batch_id)
{
+ Some(array.clone())
+ } else {
+ // If not in local cache, check shared cache
+ let shared_cached =
self.cache.lock().unwrap().get(self.column_idx, batch_id);
+ if let Some(array) = shared_cached.as_ref() {
+ // Store in local cache for later use in consume_batch
+ self.local_buffer.insert(batch_id, array.clone());
+ }
+ shared_cached
+ };
+
+ match cached {
+ Some(array) => {
+ let array_len = array.len();
+ if array_len + batch_id.val * self.batch_size -
self.outer_position > 0 {
+ // the cache batch has some records that we can select
+ let v = array_len + batch_id.val * self.batch_size -
self.outer_position;
+ let select_cnt = std::cmp::min(num_records - read, v);
+ read += select_cnt;
+
self.selections.push_back(RowSelector::select(select_cnt));
+ } else {
+ // this is last batch and we have used all records
from it
+ break;
+ }
+ }
+ None => {
+ let read_from_inner = self.fetch_batch(batch_id)?;
+
+ // Reached end-of-file, no more records to read
+ if read_from_inner == 0 {
+ break;
+ }
+
+ let select_from_this_batch = std::cmp::min(num_records -
read, read_from_inner);
+ read += select_from_this_batch;
+ self.selections
+
.push_back(RowSelector::select(select_from_this_batch));
+ if read_from_inner < self.batch_size {
+ // this is last batch from inner reader
+ break;
+ }
+ }
+ }
+ }
+ self.outer_position += read;
+ Ok(read)
+ }
+
+ fn skip_records(&mut self, num_records: usize) -> Result<usize> {
+ let mut skipped = 0;
+ while skipped < num_records {
+ let size = std::cmp::min(num_records - skipped, self.batch_size);
+ skipped += size;
+ self.selections.push_back(RowSelector::skip(size));
+ self.outer_position += size;
+ }
+ Ok(num_records)
+ }
+
+ fn consume_batch(&mut self) -> Result<ArrayRef> {
+ let row_count = self.selections.iter().map(|s|
s.row_count).sum::<usize>();
+ if row_count == 0 {
+ return Ok(new_empty_array(self.inner.get_data_type()));
+ }
+
+ let start_position = self.outer_position - row_count;
+
+ let selection_buffer = row_selection_to_boolean_buffer(row_count,
self.selections.iter());
+
+ let start_batch = start_position / self.batch_size;
+ let end_batch = (start_position + row_count - 1) / self.batch_size;
+
+ let mut selected_arrays = Vec::new();
+ for batch_id in start_batch..=end_batch {
+ let batch_start = batch_id * self.batch_size;
+ let batch_end = batch_start + self.batch_size - 1;
+ let batch_id = self.get_batch_id_from_position(batch_start);
+
+ // Calculate the overlap between the start_position and the batch
+ let overlap_start = start_position.max(batch_start);
+ let overlap_end = (start_position + row_count - 1).min(batch_end);
+
+ if overlap_start > overlap_end {
+ continue;
+ }
+
+ let selection_start = overlap_start - start_position;
+ let selection_length = overlap_end - overlap_start + 1;
+ let mask = selection_buffer.slice(selection_start,
selection_length);
+
+ if mask.count_set_bits() == 0 {
+ continue;
+ }
+
+ let mask_array = BooleanArray::from(mask);
+ // Read from local cache instead of shared cache to avoid cache
eviction issues
+ let cached = self
+ .local_buffer
+ .get(&batch_id)
+ .expect("data must be already cached in the read_records call,
this is a bug");
+ let cached = cached.slice(overlap_start - batch_start,
selection_length);
+ let filtered = arrow_select::filter::filter(&cached, &mask_array)?;
+ selected_arrays.push(filtered);
+ }
+
+ self.selections.clear();
+ self.local_buffer.clear();
+
+ // For consumers, cleanup batches that have been completely consumed
+ // This reduces the memory usage of the shared cache
+ if self.role == CacheRole::Consumer {
+ self.cleanup_consumed_batches();
+ }
+
+ match selected_arrays.len() {
+ 0 => Ok(new_empty_array(self.inner.get_data_type())),
+ 1 => Ok(selected_arrays.into_iter().next().unwrap()),
+ _ => Ok(arrow_select::concat::concat(
+ &selected_arrays
+ .iter()
+ .map(|a| a.as_ref())
+ .collect::<Vec<_>>(),
+ )?),
+ }
+ }
+
+ fn get_def_levels(&self) -> Option<&[i16]> {
+ None // we don't allow nullable parent for now.
Review Comment:
nested columns are not supported yet
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]