This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new df691d5be Extract IPC ArrayReader struct (#4259)
df691d5be is described below
commit df691d5be14ea334e1d541697457291ba0796c52
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Tue May 23 11:38:36 2023 +0100
Extract IPC ArrayReader struct (#4259)
* Extract IPC ArrayReader struct
* Review feedback
---
arrow-ipc/src/reader.rs | 512 +++++++++++++++++-------------------------------
1 file changed, 175 insertions(+), 337 deletions(-)
diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs
index 162e92914..cabf81fc2 100644
--- a/arrow-ipc/src/reader.rs
+++ b/arrow-ipc/src/reader.rs
@@ -21,6 +21,7 @@
//! however the `FileReader` expects a reader that supports `Seek`ing
use arrow_buffer::i256;
+use flatbuffers::VectorIter;
use std::collections::HashMap;
use std::fmt;
use std::io::{BufReader, Read, Seek, SeekFrom};
@@ -33,7 +34,7 @@ use arrow_data::ArrayData;
use arrow_schema::*;
use crate::compression::CompressionCodec;
-use crate::CONTINUATION_MARKER;
+use crate::{FieldNode, MetadataVersion, CONTINUATION_MARKER};
use DataType::*;
/// Read a buffer based on offset and length
@@ -48,7 +49,7 @@ use DataType::*;
fn read_buffer(
buf: &crate::Buffer,
a_data: &Buffer,
- compression_codec: &Option<CompressionCodec>,
+ compression_codec: Option<CompressionCodec>,
) -> Result<Buffer, ArrowError> {
let start_offset = buf.offset() as usize;
    let buf_data = a_data.slice_with_length(start_offset, buf.length() as usize);
@@ -68,122 +69,46 @@ fn read_buffer(
/// - check if the bit width of non-64-bit numbers is 64, and
/// - read the buffer as 64-bit (signed integer or float), and
/// - cast the 64-bit array to the appropriate data type
-#[allow(clippy::too_many_arguments)]
-fn create_array(
- nodes: flatbuffers::Vector<'_, crate::FieldNode>,
- field: &Field,
- data: &Buffer,
- buffers: flatbuffers::Vector<'_, crate::Buffer>,
- dictionaries_by_id: &HashMap<i64, ArrayRef>,
- mut node_index: usize,
- mut buffer_index: usize,
- compression_codec: &Option<CompressionCodec>,
- metadata: &crate::MetadataVersion,
-) -> Result<(ArrayRef, usize, usize), ArrowError> {
+fn create_array(reader: &mut ArrayReader, field: &Field) -> Result<ArrayRef, ArrowError> {
let data_type = field.data_type();
- let array = match data_type {
- Utf8 | Binary | LargeBinary | LargeUtf8 => {
- let array = create_primitive_array(
- nodes.get(node_index),
- data_type,
- &[
- read_buffer(buffers.get(buffer_index), data,
compression_codec)?,
- read_buffer(buffers.get(buffer_index + 1), data,
compression_codec)?,
- read_buffer(buffers.get(buffer_index + 2), data,
compression_codec)?,
- ],
- )?;
- node_index += 1;
- buffer_index += 3;
- array
- }
- FixedSizeBinary(_) => {
- let array = create_primitive_array(
- nodes.get(node_index),
- data_type,
- &[
- read_buffer(buffers.get(buffer_index), data,
compression_codec)?,
- read_buffer(buffers.get(buffer_index + 1), data,
compression_codec)?,
- ],
- )?;
- node_index += 1;
- buffer_index += 2;
- array
- }
+ match data_type {
+ Utf8 | Binary | LargeBinary | LargeUtf8 => create_primitive_array(
+ reader.next_node(field)?,
+ data_type,
+ &[
+ reader.next_buffer()?,
+ reader.next_buffer()?,
+ reader.next_buffer()?,
+ ],
+ ),
+ FixedSizeBinary(_) => create_primitive_array(
+ reader.next_node(field)?,
+ data_type,
+ &[reader.next_buffer()?, reader.next_buffer()?],
+ ),
List(ref list_field) | LargeList(ref list_field) | Map(ref list_field,
_) => {
- let list_node = nodes.get(node_index);
- let list_buffers = [
- read_buffer(buffers.get(buffer_index), data,
compression_codec)?,
- read_buffer(buffers.get(buffer_index + 1), data,
compression_codec)?,
- ];
- node_index += 1;
- buffer_index += 2;
- let triple = create_array(
- nodes,
- list_field,
- data,
- buffers,
- dictionaries_by_id,
- node_index,
- buffer_index,
- compression_codec,
- metadata,
- )?;
- node_index = triple.1;
- buffer_index = triple.2;
-
- create_list_array(list_node, data_type, &list_buffers, triple.0)?
+ let list_node = reader.next_node(field)?;
+ let list_buffers = [reader.next_buffer()?, reader.next_buffer()?];
+ let values = create_array(reader, list_field)?;
+ create_list_array(list_node, data_type, &list_buffers, values)
}
FixedSizeList(ref list_field, _) => {
- let list_node = nodes.get(node_index);
- let list_buffers = [read_buffer(
- buffers.get(buffer_index),
- data,
- compression_codec,
- )?];
- node_index += 1;
- buffer_index += 1;
- let triple = create_array(
- nodes,
- list_field,
- data,
- buffers,
- dictionaries_by_id,
- node_index,
- buffer_index,
- compression_codec,
- metadata,
- )?;
- node_index = triple.1;
- buffer_index = triple.2;
-
- create_list_array(list_node, data_type, &list_buffers, triple.0)?
+ let list_node = reader.next_node(field)?;
+ let list_buffers = [reader.next_buffer()?];
+ let values = create_array(reader, list_field)?;
+ create_list_array(list_node, data_type, &list_buffers, values)
}
Struct(struct_fields) => {
- let struct_node = nodes.get(node_index);
- let null_buffer =
- read_buffer(buffers.get(buffer_index), data,
compression_codec)?;
- node_index += 1;
- buffer_index += 1;
+ let struct_node = reader.next_node(field)?;
+ let null_buffer = reader.next_buffer()?;
// read the arrays for each field
let mut struct_arrays = vec![];
// TODO investigate whether just knowing the number of buffers
could
// still work
for struct_field in struct_fields {
- let triple = create_array(
- nodes,
- struct_field,
- data,
- buffers,
- dictionaries_by_id,
- node_index,
- buffer_index,
- compression_codec,
- metadata,
- )?;
- node_index = triple.1;
- buffer_index = triple.2;
- struct_arrays.push((struct_field.clone(), triple.0));
+ let child = create_array(reader, struct_field)?;
+ struct_arrays.push((struct_field.clone(), child));
}
let null_count = struct_node.null_count() as usize;
let struct_array = if null_count > 0 {
@@ -192,101 +117,61 @@ fn create_array(
} else {
StructArray::from(struct_arrays)
};
- Arc::new(struct_array)
+ Ok(Arc::new(struct_array))
}
RunEndEncoded(run_ends_field, values_field) => {
- let run_node = nodes.get(node_index);
- node_index += 1;
-
- let run_ends_triple = create_array(
- nodes,
- run_ends_field,
- data,
- buffers,
- dictionaries_by_id,
- node_index,
- buffer_index,
- compression_codec,
- metadata,
- )?;
- node_index = run_ends_triple.1;
- buffer_index = run_ends_triple.2;
-
- let values_triple = create_array(
- nodes,
- values_field,
- data,
- buffers,
- dictionaries_by_id,
- node_index,
- buffer_index,
- compression_codec,
- metadata,
- )?;
- node_index = values_triple.1;
- buffer_index = values_triple.2;
+ let run_node = reader.next_node(field)?;
+ let run_ends = create_array(reader, run_ends_field)?;
+ let values = create_array(reader, values_field)?;
let run_array_length = run_node.length() as usize;
let data = ArrayData::builder(data_type.clone())
.len(run_array_length)
.offset(0)
- .add_child_data(run_ends_triple.0.into_data())
- .add_child_data(values_triple.0.into_data())
+ .add_child_data(run_ends.into_data())
+ .add_child_data(values.into_data())
.build()?;
- make_array(data)
+ Ok(make_array(data))
}
// Create dictionary array from RecordBatch
Dictionary(_, _) => {
- let index_node = nodes.get(node_index);
- let index_buffers = [
- read_buffer(buffers.get(buffer_index), data,
compression_codec)?,
- read_buffer(buffers.get(buffer_index + 1), data,
compression_codec)?,
- ];
+ let index_node = reader.next_node(field)?;
+ let index_buffers = [reader.next_buffer()?, reader.next_buffer()?];
let dict_id = field.dict_id().ok_or_else(|| {
                ArrowError::IoError(format!("Field {field} does not have dict id"))
})?;
- let value_array = dictionaries_by_id.get(&dict_id).ok_or_else(|| {
- ArrowError::IoError(format!(
- "Cannot find a dictionary batch with dict id: {dict_id}"
- ))
- })?;
- node_index += 1;
- buffer_index += 2;
+ let value_array =
+ reader.dictionaries_by_id.get(&dict_id).ok_or_else(|| {
+ ArrowError::IoError(format!(
+                        "Cannot find a dictionary batch with dict id: {dict_id}"
+ ))
+ })?;
create_dictionary_array(
index_node,
data_type,
&index_buffers,
value_array.clone(),
- )?
+ )
}
Union(fields, mode) => {
- let union_node = nodes.get(node_index);
- node_index += 1;
-
+ let union_node = reader.next_node(field)?;
let len = union_node.length() as usize;
// In V4, union types has validity bitmap
// In V5 and later, union types have no validity bitmap
- if metadata < &crate::MetadataVersion::V5 {
- read_buffer(buffers.get(buffer_index), data,
compression_codec)?;
- buffer_index += 1;
+ if reader.version < MetadataVersion::V5 {
+ reader.next_buffer()?;
}
- let type_ids: Buffer =
- read_buffer(buffers.get(buffer_index), data,
compression_codec)?[..len]
- .into();
-
- buffer_index += 1;
+ let type_ids: Buffer = reader.next_buffer()?[..len].into();
let value_offsets = match mode {
UnionMode::Dense => {
- let buffer =
- read_buffer(buffers.get(buffer_index), data,
compression_codec)?;
- buffer_index += 1;
+ let buffer = reader.next_buffer()?;
Some(buffer[..len * 4].into())
}
UnionMode::Sparse => None,
@@ -296,30 +181,16 @@ fn create_array(
let mut ids = Vec::with_capacity(fields.len());
for (id, field) in fields.iter() {
- let triple = create_array(
- nodes,
- field,
- data,
- buffers,
- dictionaries_by_id,
- node_index,
- buffer_index,
- compression_codec,
- metadata,
- )?;
-
- node_index = triple.1;
- buffer_index = triple.2;
-
- children.push((field.as_ref().clone(), triple.0));
+ let child = create_array(reader, field)?;
+ children.push((field.as_ref().clone(), child));
ids.push(id);
}
let array = UnionArray::try_new(&ids, type_ids, value_offsets,
children)?;
- Arc::new(array)
+ Ok(Arc::new(array))
}
Null => {
- let node = nodes.get(node_index);
+ let node = reader.next_node(field)?;
let length = node.length();
let null_count = node.null_count();
@@ -334,125 +205,21 @@ fn create_array(
.offset(0)
.build()
.unwrap();
- node_index += 1;
// no buffer increases
- make_array(data)
- }
- _ => {
- if nodes.len() <= node_index {
- return Err(ArrowError::IoError(format!(
- "Invalid data for schema. {} refers to node index {} but
only {} in schema",
- field, node_index, nodes.len()
- )));
- }
- let array = create_primitive_array(
- nodes.get(node_index),
- data_type,
- &[
- read_buffer(buffers.get(buffer_index), data,
compression_codec)?,
- read_buffer(buffers.get(buffer_index + 1), data,
compression_codec)?,
- ],
- )?;
- node_index += 1;
- buffer_index += 2;
- array
- }
- };
- Ok((array, node_index, buffer_index))
-}
-
-/// Skip fields based on data types to advance `node_index` and `buffer_index`.
-/// This function should be called when doing projection in fn
`read_record_batch`.
-/// The advancement logic references fn `create_array`.
-fn skip_field(
- data_type: &DataType,
- mut node_index: usize,
- mut buffer_index: usize,
-) -> Result<(usize, usize), ArrowError> {
- match data_type {
- Utf8 | Binary | LargeBinary | LargeUtf8 => {
- node_index += 1;
- buffer_index += 3;
- }
- FixedSizeBinary(_) => {
- node_index += 1;
- buffer_index += 2;
- }
- List(ref list_field) | LargeList(ref list_field) | Map(ref list_field,
_) => {
- node_index += 1;
- buffer_index += 2;
- let tuple = skip_field(list_field.data_type(), node_index,
buffer_index)?;
- node_index = tuple.0;
- buffer_index = tuple.1;
+ Ok(Arc::new(NullArray::from(data)))
}
- FixedSizeList(ref list_field, _) => {
- node_index += 1;
- buffer_index += 1;
- let tuple = skip_field(list_field.data_type(), node_index,
buffer_index)?;
- node_index = tuple.0;
- buffer_index = tuple.1;
- }
- Struct(struct_fields) => {
- node_index += 1;
- buffer_index += 1;
-
- // skip for each field
- for struct_field in struct_fields {
- let tuple =
- skip_field(struct_field.data_type(), node_index,
buffer_index)?;
- node_index = tuple.0;
- buffer_index = tuple.1;
- }
- }
- RunEndEncoded(run_ends_field, values_field) => {
- node_index += 1;
-
- let tuple = skip_field(run_ends_field.data_type(), node_index,
buffer_index)?;
- node_index = tuple.0;
- buffer_index = tuple.1;
-
- let tuple = skip_field(values_field.data_type(), node_index,
buffer_index)?;
- node_index = tuple.0;
- buffer_index = tuple.1;
- }
- Dictionary(_, _) => {
- node_index += 1;
- buffer_index += 2;
- }
- Union(fields, mode) => {
- node_index += 1;
- buffer_index += 1;
-
- match mode {
- UnionMode::Dense => {
- buffer_index += 1;
- }
- UnionMode::Sparse => {}
- };
-
- for (_, field) in fields.iter() {
- let tuple = skip_field(field.data_type(), node_index,
buffer_index)?;
-
- node_index = tuple.0;
- buffer_index = tuple.1;
- }
- }
- Null => {
- node_index += 1;
- // no buffer increases
- }
- _ => {
- node_index += 1;
- buffer_index += 2;
- }
- };
- Ok((node_index, buffer_index))
+ _ => create_primitive_array(
+ reader.next_node(field)?,
+ data_type,
+ &[reader.next_buffer()?, reader.next_buffer()?],
+ ),
+ }
}
/// Reads the correct number of buffers based on data type and null_count, and
creates a
/// primitive array ref
fn create_primitive_array(
- field_node: &crate::FieldNode,
+ field_node: &FieldNode,
data_type: &DataType,
buffers: &[Buffer],
) -> Result<ArrayRef, ArrowError> {
@@ -628,6 +395,100 @@ fn create_dictionary_array(
}
}
+/// State for decoding arrays from an encoded [`RecordBatch`]
+struct ArrayReader<'a> {
+ /// Decoded dictionaries indexed by dictionary id
+ dictionaries_by_id: &'a HashMap<i64, ArrayRef>,
+ /// Optional compression codec
+ compression: Option<CompressionCodec>,
+ /// The format version
+ version: MetadataVersion,
+ /// The raw data buffer
+ data: &'a Buffer,
+ /// The fields comprising this array
+ nodes: VectorIter<'a, FieldNode>,
+ /// The buffers comprising this array
+ buffers: VectorIter<'a, crate::Buffer>,
+}
+
+impl<'a> ArrayReader<'a> {
+ fn next_buffer(&mut self) -> Result<Buffer, ArrowError> {
+ read_buffer(self.buffers.next().unwrap(), self.data, self.compression)
+ }
+
+ fn skip_buffer(&mut self) {
+ self.buffers.next().unwrap();
+ }
+
+    fn next_node(&mut self, field: &Field) -> Result<&'a FieldNode, ArrowError> {
+ self.nodes.next().ok_or_else(|| {
+ ArrowError::IoError(format!(
+                "Invalid data for schema. {} refers to node not found in schema",
+ field
+ ))
+ })
+ }
+
+ fn skip_field(&mut self, field: &Field) -> Result<(), ArrowError> {
+ self.next_node(field)?;
+
+ match field.data_type() {
+ Utf8 | Binary | LargeBinary | LargeUtf8 => {
+ for _ in 0..3 {
+ self.skip_buffer()
+ }
+ }
+ FixedSizeBinary(_) => {
+ self.skip_buffer();
+ self.skip_buffer();
+ }
+ List(list_field) | LargeList(list_field) | Map(list_field, _) => {
+ self.skip_buffer();
+ self.skip_buffer();
+ self.skip_field(list_field)?;
+ }
+ FixedSizeList(list_field, _) => {
+ self.skip_buffer();
+ self.skip_field(list_field)?;
+ }
+ Struct(struct_fields) => {
+ self.skip_buffer();
+
+ // skip for each field
+ for struct_field in struct_fields {
+ self.skip_field(struct_field)?
+ }
+ }
+ RunEndEncoded(run_ends_field, values_field) => {
+ self.skip_field(run_ends_field)?;
+ self.skip_field(values_field)?;
+ }
+ Dictionary(_, _) => {
+ self.skip_buffer(); // Nulls
+ self.skip_buffer(); // Indices
+ }
+ Union(fields, mode) => {
+ self.skip_buffer(); // Nulls
+
+ match mode {
+ UnionMode::Dense => self.skip_buffer(),
+ UnionMode::Sparse => {}
+ };
+
+ for (_, field) in fields.iter() {
+ self.skip_field(field)?
+ }
+ }
+ Null => {} // No buffer increases
+ _ => {
+ self.skip_buffer();
+ self.skip_buffer();
+ }
+ };
+ Ok(())
+ }
+}
+
/// Creates a record batch from binary data using the `crate::RecordBatch`
indexes and the `Schema`
pub fn read_record_batch(
buf: &Buffer,
@@ -635,7 +496,7 @@ pub fn read_record_batch(
schema: SchemaRef,
dictionaries_by_id: &HashMap<i64, ArrayRef>,
projection: Option<&[usize]>,
- metadata: &crate::MetadataVersion,
+ metadata: &MetadataVersion,
) -> Result<RecordBatch, ArrowError> {
let buffers = batch.buffers().ok_or_else(|| {
ArrowError::IoError("Unable to get buffers from IPC
RecordBatch".to_string())
@@ -644,13 +505,18 @@ pub fn read_record_batch(
ArrowError::IoError("Unable to get field nodes from IPC
RecordBatch".to_string())
})?;
let batch_compression = batch.compression();
- let compression_codec: Option<CompressionCodec> = batch_compression
+ let compression = batch_compression
.map(|batch_compression| batch_compression.codec().try_into())
.transpose()?;
- // keep track of buffer and node index, the functions that create arrays
mutate these
- let mut buffer_index = 0;
- let mut node_index = 0;
+ let mut reader = ArrayReader {
+ dictionaries_by_id,
+ compression,
+ version: *metadata,
+ data: buf,
+ nodes: field_nodes.iter(),
+ buffers: buffers.iter(),
+ };
let options = RecordBatchOptions::new().with_row_count(Some(batch.length()
as usize));
@@ -660,26 +526,10 @@ pub fn read_record_batch(
for (idx, field) in schema.fields().iter().enumerate() {
// Create array for projected field
if let Some(proj_idx) = projection.iter().position(|p| p == &idx) {
- let triple = create_array(
- field_nodes,
- field,
- buf,
- buffers,
- dictionaries_by_id,
- node_index,
- buffer_index,
- &compression_codec,
- metadata,
- )?;
- node_index = triple.1;
- buffer_index = triple.2;
- arrays.push((proj_idx, triple.0));
+ let child = create_array(&mut reader, field)?;
+ arrays.push((proj_idx, child));
} else {
- // Skip field.
- // This must be called to advance `node_index` and
`buffer_index`.
- let tuple = skip_field(field.data_type(), node_index,
buffer_index)?;
- node_index = tuple.0;
- buffer_index = tuple.1;
+ reader.skip_field(field)?;
}
}
arrays.sort_by_key(|t| t.0);
@@ -689,25 +539,13 @@ pub fn read_record_batch(
&options,
)
} else {
- let mut arrays = vec![];
+ let mut children = vec![];
// keep track of index as lists require more than one node
for field in schema.fields() {
- let triple = create_array(
- field_nodes,
- field,
- buf,
- buffers,
- dictionaries_by_id,
- node_index,
- buffer_index,
- &compression_codec,
- metadata,
- )?;
- node_index = triple.1;
- buffer_index = triple.2;
- arrays.push(triple.0);
+ let child = create_array(&mut reader, field)?;
+ children.push(child);
}
- RecordBatch::try_new_with_options(schema, arrays, &options)
+ RecordBatch::try_new_with_options(schema, children, &options)
}
}