This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new df691d5be Extract IPC ArrayReader struct (#4259)
df691d5be is described below

commit df691d5be14ea334e1d541697457291ba0796c52
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Tue May 23 11:38:36 2023 +0100

    Extract IPC ArrayReader struct (#4259)
    
    * Extract IPC ArrayReader struct
    
    * Review feedback
---
 arrow-ipc/src/reader.rs | 512 +++++++++++++++++-------------------------------
 1 file changed, 175 insertions(+), 337 deletions(-)

diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs
index 162e92914..cabf81fc2 100644
--- a/arrow-ipc/src/reader.rs
+++ b/arrow-ipc/src/reader.rs
@@ -21,6 +21,7 @@
 //! however the `FileReader` expects a reader that supports `Seek`ing
 
 use arrow_buffer::i256;
+use flatbuffers::VectorIter;
 use std::collections::HashMap;
 use std::fmt;
 use std::io::{BufReader, Read, Seek, SeekFrom};
@@ -33,7 +34,7 @@ use arrow_data::ArrayData;
 use arrow_schema::*;
 
 use crate::compression::CompressionCodec;
-use crate::CONTINUATION_MARKER;
+use crate::{FieldNode, MetadataVersion, CONTINUATION_MARKER};
 use DataType::*;
 
 /// Read a buffer based on offset and length
@@ -48,7 +49,7 @@ use DataType::*;
 fn read_buffer(
     buf: &crate::Buffer,
     a_data: &Buffer,
-    compression_codec: &Option<CompressionCodec>,
+    compression_codec: Option<CompressionCodec>,
 ) -> Result<Buffer, ArrowError> {
     let start_offset = buf.offset() as usize;
     let buf_data = a_data.slice_with_length(start_offset, buf.length() as usize);
@@ -68,122 +69,46 @@ fn read_buffer(
 ///     - check if the bit width of non-64-bit numbers is 64, and
 ///     - read the buffer as 64-bit (signed integer or float), and
 ///     - cast the 64-bit array to the appropriate data type
-#[allow(clippy::too_many_arguments)]
-fn create_array(
-    nodes: flatbuffers::Vector<'_, crate::FieldNode>,
-    field: &Field,
-    data: &Buffer,
-    buffers: flatbuffers::Vector<'_, crate::Buffer>,
-    dictionaries_by_id: &HashMap<i64, ArrayRef>,
-    mut node_index: usize,
-    mut buffer_index: usize,
-    compression_codec: &Option<CompressionCodec>,
-    metadata: &crate::MetadataVersion,
-) -> Result<(ArrayRef, usize, usize), ArrowError> {
+fn create_array(reader: &mut ArrayReader, field: &Field) -> Result<ArrayRef, ArrowError> {
     let data_type = field.data_type();
-    let array = match data_type {
-        Utf8 | Binary | LargeBinary | LargeUtf8 => {
-            let array = create_primitive_array(
-                nodes.get(node_index),
-                data_type,
-                &[
-                    read_buffer(buffers.get(buffer_index), data, compression_codec)?,
-                    read_buffer(buffers.get(buffer_index + 1), data, compression_codec)?,
-                    read_buffer(buffers.get(buffer_index + 2), data, compression_codec)?,
-                ],
-            )?;
-            node_index += 1;
-            buffer_index += 3;
-            array
-        }
-        FixedSizeBinary(_) => {
-            let array = create_primitive_array(
-                nodes.get(node_index),
-                data_type,
-                &[
-                    read_buffer(buffers.get(buffer_index), data, compression_codec)?,
-                    read_buffer(buffers.get(buffer_index + 1), data, compression_codec)?,
-                ],
-            )?;
-            node_index += 1;
-            buffer_index += 2;
-            array
-        }
+    match data_type {
+        Utf8 | Binary | LargeBinary | LargeUtf8 => create_primitive_array(
+            reader.next_node(field)?,
+            data_type,
+            &[
+                reader.next_buffer()?,
+                reader.next_buffer()?,
+                reader.next_buffer()?,
+            ],
+        ),
+        FixedSizeBinary(_) => create_primitive_array(
+            reader.next_node(field)?,
+            data_type,
+            &[reader.next_buffer()?, reader.next_buffer()?],
+        ),
         List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => {
-            let list_node = nodes.get(node_index);
-            let list_buffers = [
-                read_buffer(buffers.get(buffer_index), data, compression_codec)?,
-                read_buffer(buffers.get(buffer_index + 1), data, compression_codec)?,
-            ];
-            node_index += 1;
-            buffer_index += 2;
-            let triple = create_array(
-                nodes,
-                list_field,
-                data,
-                buffers,
-                dictionaries_by_id,
-                node_index,
-                buffer_index,
-                compression_codec,
-                metadata,
-            )?;
-            node_index = triple.1;
-            buffer_index = triple.2;
-
-            create_list_array(list_node, data_type, &list_buffers, triple.0)?
+            let list_node = reader.next_node(field)?;
+            let list_buffers = [reader.next_buffer()?, reader.next_buffer()?];
+            let values = create_array(reader, list_field)?;
+            create_list_array(list_node, data_type, &list_buffers, values)
         }
         FixedSizeList(ref list_field, _) => {
-            let list_node = nodes.get(node_index);
-            let list_buffers = [read_buffer(
-                buffers.get(buffer_index),
-                data,
-                compression_codec,
-            )?];
-            node_index += 1;
-            buffer_index += 1;
-            let triple = create_array(
-                nodes,
-                list_field,
-                data,
-                buffers,
-                dictionaries_by_id,
-                node_index,
-                buffer_index,
-                compression_codec,
-                metadata,
-            )?;
-            node_index = triple.1;
-            buffer_index = triple.2;
-
-            create_list_array(list_node, data_type, &list_buffers, triple.0)?
+            let list_node = reader.next_node(field)?;
+            let list_buffers = [reader.next_buffer()?];
+            let values = create_array(reader, list_field)?;
+            create_list_array(list_node, data_type, &list_buffers, values)
         }
         Struct(struct_fields) => {
-            let struct_node = nodes.get(node_index);
-            let null_buffer =
-                read_buffer(buffers.get(buffer_index), data, compression_codec)?;
-            node_index += 1;
-            buffer_index += 1;
+            let struct_node = reader.next_node(field)?;
+            let null_buffer = reader.next_buffer()?;
 
             // read the arrays for each field
             let mut struct_arrays = vec![];
             // TODO investigate whether just knowing the number of buffers could
             // still work
             for struct_field in struct_fields {
-                let triple = create_array(
-                    nodes,
-                    struct_field,
-                    data,
-                    buffers,
-                    dictionaries_by_id,
-                    node_index,
-                    buffer_index,
-                    compression_codec,
-                    metadata,
-                )?;
-                node_index = triple.1;
-                buffer_index = triple.2;
-                struct_arrays.push((struct_field.clone(), triple.0));
+                let child = create_array(reader, struct_field)?;
+                struct_arrays.push((struct_field.clone(), child));
             }
             let null_count = struct_node.null_count() as usize;
             let struct_array = if null_count > 0 {
@@ -192,101 +117,61 @@ fn create_array(
             } else {
                 StructArray::from(struct_arrays)
             };
-            Arc::new(struct_array)
+            Ok(Arc::new(struct_array))
         }
         RunEndEncoded(run_ends_field, values_field) => {
-            let run_node = nodes.get(node_index);
-            node_index += 1;
-
-            let run_ends_triple = create_array(
-                nodes,
-                run_ends_field,
-                data,
-                buffers,
-                dictionaries_by_id,
-                node_index,
-                buffer_index,
-                compression_codec,
-                metadata,
-            )?;
-            node_index = run_ends_triple.1;
-            buffer_index = run_ends_triple.2;
-
-            let values_triple = create_array(
-                nodes,
-                values_field,
-                data,
-                buffers,
-                dictionaries_by_id,
-                node_index,
-                buffer_index,
-                compression_codec,
-                metadata,
-            )?;
-            node_index = values_triple.1;
-            buffer_index = values_triple.2;
+            let run_node = reader.next_node(field)?;
+            let run_ends = create_array(reader, run_ends_field)?;
+            let values = create_array(reader, values_field)?;
 
             let run_array_length = run_node.length() as usize;
             let data = ArrayData::builder(data_type.clone())
                 .len(run_array_length)
                 .offset(0)
-                .add_child_data(run_ends_triple.0.into_data())
-                .add_child_data(values_triple.0.into_data())
+                .add_child_data(run_ends.into_data())
+                .add_child_data(values.into_data())
                 .build()?;
 
-            make_array(data)
+            Ok(make_array(data))
         }
         // Create dictionary array from RecordBatch
         Dictionary(_, _) => {
-            let index_node = nodes.get(node_index);
-            let index_buffers = [
-                read_buffer(buffers.get(buffer_index), data, compression_codec)?,
-                read_buffer(buffers.get(buffer_index + 1), data, compression_codec)?,
-            ];
+            let index_node = reader.next_node(field)?;
+            let index_buffers = [reader.next_buffer()?, reader.next_buffer()?];
 
             let dict_id = field.dict_id().ok_or_else(|| {
                 ArrowError::IoError(format!("Field {field} does not have dict id"))
             })?;
 
-            let value_array = dictionaries_by_id.get(&dict_id).ok_or_else(|| {
-                ArrowError::IoError(format!(
-                    "Cannot find a dictionary batch with dict id: {dict_id}"
-                ))
-            })?;
-            node_index += 1;
-            buffer_index += 2;
+            let value_array =
+                reader.dictionaries_by_id.get(&dict_id).ok_or_else(|| {
+                    ArrowError::IoError(format!(
+                        "Cannot find a dictionary batch with dict id: {dict_id}"
+                    ))
+                })?;
 
             create_dictionary_array(
                 index_node,
                 data_type,
                 &index_buffers,
                 value_array.clone(),
-            )?
+            )
         }
         Union(fields, mode) => {
-            let union_node = nodes.get(node_index);
-            node_index += 1;
-
+            let union_node = reader.next_node(field)?;
             let len = union_node.length() as usize;
 
             // In V4, union types has validity bitmap
             // In V5 and later, union types have no validity bitmap
-            if metadata < &crate::MetadataVersion::V5 {
-                read_buffer(buffers.get(buffer_index), data, compression_codec)?;
-                buffer_index += 1;
+            if reader.version < MetadataVersion::V5 {
+                reader.next_buffer()?;
             }
 
-            let type_ids: Buffer =
-                read_buffer(buffers.get(buffer_index), data, compression_codec)?[..len]
-                    .into();
-
-            buffer_index += 1;
+            let type_ids: Buffer = reader.next_buffer()?[..len].into();
 
             let value_offsets = match mode {
                 UnionMode::Dense => {
-                    let buffer =
-                        read_buffer(buffers.get(buffer_index), data, compression_codec)?;
-                    buffer_index += 1;
+                    let buffer = reader.next_buffer()?;
                     Some(buffer[..len * 4].into())
                 }
                 UnionMode::Sparse => None,
@@ -296,30 +181,16 @@ fn create_array(
             let mut ids = Vec::with_capacity(fields.len());
 
             for (id, field) in fields.iter() {
-                let triple = create_array(
-                    nodes,
-                    field,
-                    data,
-                    buffers,
-                    dictionaries_by_id,
-                    node_index,
-                    buffer_index,
-                    compression_codec,
-                    metadata,
-                )?;
-
-                node_index = triple.1;
-                buffer_index = triple.2;
-
-                children.push((field.as_ref().clone(), triple.0));
+                let child = create_array(reader, field)?;
+                children.push((field.as_ref().clone(), child));
                 ids.push(id);
             }
 
             let array = UnionArray::try_new(&ids, type_ids, value_offsets, children)?;
-            Arc::new(array)
+            Ok(Arc::new(array))
         }
         Null => {
-            let node = nodes.get(node_index);
+            let node = reader.next_node(field)?;
             let length = node.length();
             let null_count = node.null_count();
 
@@ -334,125 +205,21 @@ fn create_array(
                 .offset(0)
                 .build()
                 .unwrap();
-            node_index += 1;
             // no buffer increases
-            make_array(data)
-        }
-        _ => {
-            if nodes.len() <= node_index {
-                return Err(ArrowError::IoError(format!(
-                    "Invalid data for schema. {} refers to node index {} but only {} in schema",
-                    field, node_index, nodes.len()
-                )));
-            }
-            let array = create_primitive_array(
-                nodes.get(node_index),
-                data_type,
-                &[
-                    read_buffer(buffers.get(buffer_index), data, compression_codec)?,
-                    read_buffer(buffers.get(buffer_index + 1), data, compression_codec)?,
-                ],
-            )?;
-            node_index += 1;
-            buffer_index += 2;
-            array
-        }
-    };
-    Ok((array, node_index, buffer_index))
-}
-
-/// Skip fields based on data types to advance `node_index` and `buffer_index`.
-/// This function should be called when doing projection in fn `read_record_batch`.
-/// The advancement logic references fn `create_array`.
-fn skip_field(
-    data_type: &DataType,
-    mut node_index: usize,
-    mut buffer_index: usize,
-) -> Result<(usize, usize), ArrowError> {
-    match data_type {
-        Utf8 | Binary | LargeBinary | LargeUtf8 => {
-            node_index += 1;
-            buffer_index += 3;
-        }
-        FixedSizeBinary(_) => {
-            node_index += 1;
-            buffer_index += 2;
-        }
-        List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => {
-            node_index += 1;
-            buffer_index += 2;
-            let tuple = skip_field(list_field.data_type(), node_index, buffer_index)?;
-            node_index = tuple.0;
-            buffer_index = tuple.1;
+            Ok(Arc::new(NullArray::from(data)))
         }
-        FixedSizeList(ref list_field, _) => {
-            node_index += 1;
-            buffer_index += 1;
-            let tuple = skip_field(list_field.data_type(), node_index, buffer_index)?;
-            node_index = tuple.0;
-            buffer_index = tuple.1;
-        }
-        Struct(struct_fields) => {
-            node_index += 1;
-            buffer_index += 1;
-
-            // skip for each field
-            for struct_field in struct_fields {
-                let tuple =
-                    skip_field(struct_field.data_type(), node_index, buffer_index)?;
-                node_index = tuple.0;
-                buffer_index = tuple.1;
-            }
-        }
-        RunEndEncoded(run_ends_field, values_field) => {
-            node_index += 1;
-
-            let tuple = skip_field(run_ends_field.data_type(), node_index, buffer_index)?;
-            node_index = tuple.0;
-            buffer_index = tuple.1;
-
-            let tuple = skip_field(values_field.data_type(), node_index, buffer_index)?;
-            node_index = tuple.0;
-            buffer_index = tuple.1;
-        }
-        Dictionary(_, _) => {
-            node_index += 1;
-            buffer_index += 2;
-        }
-        Union(fields, mode) => {
-            node_index += 1;
-            buffer_index += 1;
-
-            match mode {
-                UnionMode::Dense => {
-                    buffer_index += 1;
-                }
-                UnionMode::Sparse => {}
-            };
-
-            for (_, field) in fields.iter() {
-                let tuple = skip_field(field.data_type(), node_index, buffer_index)?;
-
-                node_index = tuple.0;
-                buffer_index = tuple.1;
-            }
-        }
-        Null => {
-            node_index += 1;
-            // no buffer increases
-        }
-        _ => {
-            node_index += 1;
-            buffer_index += 2;
-        }
-    };
-    Ok((node_index, buffer_index))
+        _ => create_primitive_array(
+            reader.next_node(field)?,
+            data_type,
+            &[reader.next_buffer()?, reader.next_buffer()?],
+        ),
+    }
 }
 
 /// Reads the correct number of buffers based on data type and null_count, and creates a
 /// primitive array ref
 fn create_primitive_array(
-    field_node: &crate::FieldNode,
+    field_node: &FieldNode,
     data_type: &DataType,
     buffers: &[Buffer],
 ) -> Result<ArrayRef, ArrowError> {
@@ -628,6 +395,100 @@ fn create_dictionary_array(
     }
 }
 
+/// State for decoding arrays from an encoded [`RecordBatch`]
+struct ArrayReader<'a> {
+    /// Decoded dictionaries indexed by dictionary id
+    dictionaries_by_id: &'a HashMap<i64, ArrayRef>,
+    /// Optional compression codec
+    compression: Option<CompressionCodec>,
+    /// The format version
+    version: MetadataVersion,
+    /// The raw data buffer
+    data: &'a Buffer,
+    /// The fields comprising this array
+    nodes: VectorIter<'a, FieldNode>,
+    /// The buffers comprising this array
+    buffers: VectorIter<'a, crate::Buffer>,
+}
+
+impl<'a> ArrayReader<'a> {
+    fn next_buffer(&mut self) -> Result<Buffer, ArrowError> {
+        read_buffer(self.buffers.next().unwrap(), self.data, self.compression)
+    }
+
+    fn skip_buffer(&mut self) {
+        self.buffers.next().unwrap();
+    }
+
+    fn next_node(&mut self, field: &Field) -> Result<&'a FieldNode, ArrowError> {
+        self.nodes.next().ok_or_else(|| {
+            ArrowError::IoError(format!(
+                "Invalid data for schema. {} refers to node not found in schema",
+                field
+            ))
+        })
+    }
+
+    fn skip_field(&mut self, field: &Field) -> Result<(), ArrowError> {
+        self.next_node(field)?;
+
+        match field.data_type() {
+            Utf8 | Binary | LargeBinary | LargeUtf8 => {
+                for _ in 0..3 {
+                    self.skip_buffer()
+                }
+            }
+            FixedSizeBinary(_) => {
+                self.skip_buffer();
+                self.skip_buffer();
+            }
+            List(list_field) | LargeList(list_field) | Map(list_field, _) => {
+                self.skip_buffer();
+                self.skip_buffer();
+                self.skip_field(list_field)?;
+            }
+            FixedSizeList(list_field, _) => {
+                self.skip_buffer();
+                self.skip_field(list_field)?;
+            }
+            Struct(struct_fields) => {
+                self.skip_buffer();
+
+                // skip for each field
+                for struct_field in struct_fields {
+                    self.skip_field(struct_field)?
+                }
+            }
+            RunEndEncoded(run_ends_field, values_field) => {
+                self.skip_field(run_ends_field)?;
+                self.skip_field(values_field)?;
+            }
+            Dictionary(_, _) => {
+                self.skip_buffer(); // Nulls
+                self.skip_buffer(); // Indices
+            }
+            Union(fields, mode) => {
+                self.skip_buffer(); // Nulls
+
+                match mode {
+                    UnionMode::Dense => self.skip_buffer(),
+                    UnionMode::Sparse => {}
+                };
+
+                for (_, field) in fields.iter() {
+                    self.skip_field(field)?
+                }
+            }
+            Null => {} // No buffer increases
+            _ => {
+                self.skip_buffer();
+                self.skip_buffer();
+            }
+        };
+        Ok(())
+    }
+}
+
 /// Creates a record batch from binary data using the `crate::RecordBatch` indexes and the `Schema`
 pub fn read_record_batch(
     buf: &Buffer,
@@ -635,7 +496,7 @@ pub fn read_record_batch(
     schema: SchemaRef,
     dictionaries_by_id: &HashMap<i64, ArrayRef>,
     projection: Option<&[usize]>,
-    metadata: &crate::MetadataVersion,
+    metadata: &MetadataVersion,
 ) -> Result<RecordBatch, ArrowError> {
     let buffers = batch.buffers().ok_or_else(|| {
         ArrowError::IoError("Unable to get buffers from IPC RecordBatch".to_string())
@@ -644,13 +505,18 @@ pub fn read_record_batch(
         ArrowError::IoError("Unable to get field nodes from IPC RecordBatch".to_string())
     })?;
     let batch_compression = batch.compression();
-    let compression_codec: Option<CompressionCodec> = batch_compression
+    let compression = batch_compression
         .map(|batch_compression| batch_compression.codec().try_into())
         .transpose()?;
 
-    // keep track of buffer and node index, the functions that create arrays mutate these
-    let mut buffer_index = 0;
-    let mut node_index = 0;
+    let mut reader = ArrayReader {
+        dictionaries_by_id,
+        compression,
+        version: *metadata,
+        data: buf,
+        nodes: field_nodes.iter(),
+        buffers: buffers.iter(),
+    };
 
     let options = RecordBatchOptions::new().with_row_count(Some(batch.length() as usize));
 
@@ -660,26 +526,10 @@ pub fn read_record_batch(
         for (idx, field) in schema.fields().iter().enumerate() {
             // Create array for projected field
             if let Some(proj_idx) = projection.iter().position(|p| p == &idx) {
-                let triple = create_array(
-                    field_nodes,
-                    field,
-                    buf,
-                    buffers,
-                    dictionaries_by_id,
-                    node_index,
-                    buffer_index,
-                    &compression_codec,
-                    metadata,
-                )?;
-                node_index = triple.1;
-                buffer_index = triple.2;
-                arrays.push((proj_idx, triple.0));
+                let child = create_array(&mut reader, field)?;
+                arrays.push((proj_idx, child));
             } else {
-                // Skip field.
-                // This must be called to advance `node_index` and `buffer_index`.
-                let tuple = skip_field(field.data_type(), node_index, buffer_index)?;
-                node_index = tuple.0;
-                buffer_index = tuple.1;
+                reader.skip_field(field)?;
             }
         }
         arrays.sort_by_key(|t| t.0);
@@ -689,25 +539,13 @@ pub fn read_record_batch(
             &options,
         )
     } else {
-        let mut arrays = vec![];
+        let mut children = vec![];
         // keep track of index as lists require more than one node
         for field in schema.fields() {
-            let triple = create_array(
-                field_nodes,
-                field,
-                buf,
-                buffers,
-                dictionaries_by_id,
-                node_index,
-                buffer_index,
-                &compression_codec,
-                metadata,
-            )?;
-            node_index = triple.1;
-            buffer_index = triple.2;
-            arrays.push(triple.0);
+            let child = create_array(&mut reader, field)?;
+            children.push(child);
         }
-        RecordBatch::try_new_with_options(schema, arrays, &options)
+        RecordBatch::try_new_with_options(schema, children, &options)
     }
 }
 

Reply via email to