itsjunetime commented on code in PR #6690:
URL: https://github.com/apache/arrow-rs/pull/6690#discussion_r1832863993


##########
arrow-ipc/src/writer.rs:
##########
@@ -1462,280 +1391,508 @@ fn get_list_array_buffers<O: OffsetSizeTrait>(data: 
&ArrayData) -> (Buffer, Arra
         );
     }
 
-    let (offsets, original_start_offset, len) = 
reencode_offsets::<O>(&data.buffers()[0], data);
+    let (offsets, original_start_offset, len) = reencode_offsets::<O>(data);
     let child_data = data.child_data()[0].slice(original_start_offset, len);
     (offsets, child_data)
 }
 
-/// Write array data to a vector of bytes
-#[allow(clippy::too_many_arguments)]
-fn write_array_data(
-    array_data: &ArrayData,
-    buffers: &mut Vec<crate::Buffer>,
-    arrow_data: &mut Vec<u8>,
-    nodes: &mut Vec<crate::FieldNode>,
-    offset: i64,
-    num_rows: usize,
-    null_count: usize,
-    compression_codec: Option<CompressionCodec>,
+const DEFAULT_ALIGNMENT: u8 = 64;
+const PADDING: [u8; DEFAULT_ALIGNMENT as usize] = [0; DEFAULT_ALIGNMENT as 
usize];
+
+/// Calculate an alignment boundary and return the number of bytes needed to 
pad to the alignment boundary
+#[inline]
+fn pad_to_alignment(alignment: u8, len: usize) -> usize {
+    let a = usize::from(alignment - 1);
+    ((len + a) & !a) - len
+}
+
+fn chunked_encoded_batch_bytes(
+    batch: &RecordBatch,
     write_options: &IpcWriteOptions,
-) -> Result<i64, ArrowError> {
-    let mut offset = offset;
-    if !matches!(array_data.data_type(), DataType::Null) {
-        nodes.push(crate::FieldNode::new(num_rows as i64, null_count as i64));
-    } else {
-        // NullArray's null_count equals to len, but the `null_count` passed 
in is from ArrayData
-        // where null_count is always 0.
-        nodes.push(crate::FieldNode::new(num_rows as i64, num_rows as i64));
-    }
-    if has_validity_bitmap(array_data.data_type(), write_options) {
-        // write null buffer if exists
-        let null_buffer = match array_data.nulls() {
-            None => {
-                // create a buffer and fill it with valid bits
-                let num_bytes = bit_util::ceil(num_rows, 8);
-                let buffer = MutableBuffer::new(num_bytes);
-                let buffer = buffer.with_bitset(num_bytes, true);
-                buffer.into()
-            }
-            Some(buffer) => buffer.inner().sliced(),
+    max_flight_data_size: usize,
+) -> Result<Vec<EncodedData>, ArrowError> {
+    encode_array_datas(
+        &batch
+            .columns()
+            .iter()
+            .map(ArrayRef::to_data)
+            .collect::<Vec<_>>(),
+        batch.num_rows(),
+        |_, offset| offset.as_union_value(),
+        MessageHeader::RecordBatch,
+        max_flight_data_size,
+        write_options,
+    )
+}
+
+fn get_encoded_arr_batch_size<AD: Borrow<ArrayData>>(
+    iter: impl IntoIterator<Item = AD>,
+    write_options: &IpcWriteOptions,
+) -> Result<usize, ArrowError> {
+    iter.into_iter()
+        .map(|arr| {
+            let arr = arr.borrow();
+            
arr.get_slice_memory_size_with_alignment(Some(write_options.alignment))
+                .map(|size| {
+                    let didnt_count_nulls = arr.nulls().is_none();
+                    let will_write_nulls = 
has_validity_bitmap(arr.data_type(), write_options);
+
+                    if will_write_nulls && didnt_count_nulls {
+                        let null_len = bit_util::ceil(arr.len(), 8);
+                        size + null_len + 
pad_to_alignment(write_options.alignment, null_len)
+                    } else {
+                        size
+                    }
+                })
+        })
+        .sum()
+}
+
+fn encode_array_datas(
+    arr_datas: &[ArrayData],
+    n_rows: usize,
+    encode_root: impl Fn(
+        &mut FlatBufferSizeTracker,
+        WIPOffset<crate::gen::Message::RecordBatch>,
+    ) -> WIPOffset<UnionWIPOffset>,
+    header_type: MessageHeader,
+    mut max_msg_size: usize,
+    write_options: &IpcWriteOptions,
+) -> Result<Vec<EncodedData>, ArrowError> {
+    let mut fbb = FlatBufferSizeTracker::for_dry_run(arr_datas.len());
+    fbb.encode_array_datas(
+        arr_datas,
+        n_rows as i64,
+        &encode_root,
+        header_type,
+        write_options,
+    )?;
+
+    let header_len = fbb.fbb.finished_data().len();
+    max_msg_size = max_msg_size.saturating_sub(header_len).max(1);
+
+    let total_size = get_encoded_arr_batch_size(arr_datas.iter(), 
write_options)?;
+
+    let n_batches = bit_util::ceil(total_size, max_msg_size);
+    let mut out = Vec::with_capacity(n_batches);
+
+    let mut offset = 0;
+    while offset < n_rows.max(1) {
+        let slice_arrays = |len: usize| {
+            arr_datas.iter().map(move |arr| {
+                if len >= arr.len() {
+                    arr.clone()
+                } else {
+                    arr.slice(offset, len)
+                }
+            })
         };
 
-        offset = write_buffer(
-            null_buffer.as_slice(),
-            buffers,
-            arrow_data,
-            offset,
-            compression_codec,
-            write_options.alignment,
-        )?;
-    }
+        let rows_left = n_rows - offset;
+        // TODO? maybe this could be more efficient by continually 
approximating the maximum number
+        // of rows based on (size / n_rows) of the current ArrayData slice 
until we've found the
+        // maximum that can fit? e.g. 'oh, it's 200 bytes and 10 rows, so each 
row is probably 20
+        // bytes - let's do (max_size / 20) rows and see if that fits'
+        let length = (1..=rows_left)
+            .find(|len| {
+                // If we've exhausted the available length of the array datas, 
then just return -
+                // we've got it.
+                if offset + len > n_rows {
+                    return true;
+                }
 
-    let data_type = array_data.data_type();
-    if matches!(data_type, DataType::Binary | DataType::Utf8) {
-        let (offsets, values) = get_byte_array_buffers::<i32>(array_data);
-        for buffer in [offsets, values] {
-            offset = write_buffer(
-                buffer.as_slice(),
-                buffers,
-                arrow_data,
-                offset,
-                compression_codec,
-                write_options.alignment,
+                // we can unwrap this here b/c this only errors on malformed 
buffer-type/data-type
+                // combinations, and if any of these arrays had that, this 
function would've
+                // already short-circuited on an earlier call of this function
+                get_encoded_arr_batch_size(slice_arrays(*len), 
write_options).unwrap()
+                    > max_msg_size
+            })
+            // If no rows fit in the given max size, we want to try to get the 
data across anyways,
+            // so that just means doing a single row. Calling `max(2)` is how 
we ensure that - if
+            // the very first item would go over the max size, giving us a 
length of 0, we want to
+            // set this to `2` so that taking away 1 leaves us with one row to 
encode.
+            .map(|len| len.max(2) - 1)
+            // If all rows can comfortably fit in this given size, then just 
get them all
+            .unwrap_or(rows_left);
+
+        // We could get into a situtation where we were given all 0-row arrays 
to be sent over
+        // flight - we do need to send a flight message to show that there is 
no data, but we also
+        // can't have `length` be 0 at this point because it could also be 
that all rows are too
+        // large to send with the provided limits and so we just want to try 
to send one now
+        // anyways, so the checks in this fn are just how we cover our bases 
there.
+        let new_arrs = slice_arrays(length).collect::<Vec<_>>();
+
+        // If we've got more than one row to encode or if we have 0 rows to 
encode but we haven't
+        // encoded anything yet, then continue with encoding. We don't need to 
do encoding, though,
+        // if we've already encoded some rows and there's no rows left
+        if length != 0 || offset == 0 {
+            fbb.reset_for_real_run();
+            fbb.encode_array_datas(
+                &new_arrs,
+                length as i64,
+                &encode_root,
+                header_type,
+                write_options,
             )?;
+
+            let finished_data = fbb.fbb.finished_data();
+
+            out.push(EncodedData {
+                ipc_message: finished_data.to_vec(),
+                arrow_data: fbb.arrow_data.clone(),
+            });
         }
-    } else if matches!(data_type, DataType::BinaryView | DataType::Utf8View) {
-        // Slicing the views buffer is safe and easy,
-        // but pruning unneeded data buffers is much more nuanced since it's 
complicated to prove that no views reference the pruned buffers
-        //
-        // Current implementation just serialize the raw arrays as given and 
not try to optimize anything.
-        // If users wants to "compact" the arrays prior to sending them over 
IPC,
-        // they should consider the gc API suggested in #5513
-        for buffer in array_data.buffers() {
-            offset = write_buffer(
-                buffer.as_slice(),
-                buffers,
-                arrow_data,
-                offset,
-                compression_codec,
-                write_options.alignment,
-            )?;
+
+        // If length == 0, that means they gave us ArrayData with no rows, so 
a single iteration is
+        // always sufficient.
+        if length == 0 {
+            break;
         }
-    } else if matches!(data_type, DataType::LargeBinary | DataType::LargeUtf8) 
{
-        let (offsets, values) = get_byte_array_buffers::<i64>(array_data);
-        for buffer in [offsets, values] {
-            offset = write_buffer(
-                buffer.as_slice(),
-                buffers,
-                arrow_data,
-                offset,
-                compression_codec,
-                write_options.alignment,
+
+        offset += length;
+    }
+
+    Ok(out)
+}
+
+/// A struct to help ensure that the size of encoded flight messages never 
goes over a provided
+/// limit (except in ridiculous cases like a limit of 1 byte). The way it does 
so is by first
+/// running through a provided slice of [`ArrayData`], producing an IPC header 
message for that
+/// slice, and then subtracting the size of that generated header from the 
message limit it has
+/// been given. Because IPC header message sizes don't change due to a 
different amount of rows,
+/// this header size will stay consistent throughout the entire time that we 
have to transmit a
+/// chunk of rows, so we can just subtract it from the overall limit and use 
that to check
+/// different slices of `ArrayData` against to know how many to transmit each 
time.
+///
+/// This whole process is done in [`encode_array_datas()`] above
+#[derive(Default)]
+struct FlatBufferSizeTracker<'fbb> {
+    // the builder and backing flatbuffer that we use to write the arrow data 
into.
+    fbb: FlatBufferBuilder<'fbb>,
+    // tracks the data in `arrow_data` - `buffers` contains the offsets and 
length of different
+    // buffers encoded within the big chunk that is `arrow_data`.
+    buffers: Vec<crate::Buffer>,
+    // the raw array data that we need to send across the wire
+    arrow_data: Vec<u8>,
+    nodes: Vec<crate::FieldNode>,
+    dry_run: bool,
+}
+
+impl<'fbb> FlatBufferSizeTracker<'fbb> {
+    /// Preferred initializer, as this should always be used with a dry-run 
before a real run to
+    /// figure out the size of the IPC header.
+    #[must_use]
+    fn for_dry_run(capacity: usize) -> Self {
+        Self {
+            dry_run: true,
+            buffers: Vec::with_capacity(capacity),
+            nodes: Vec::with_capacity(capacity),
+            ..Self::default()
+        }
+    }
+
+    /// Should be called in-between calls to `encode_array_datas` to ensure we 
don't accidentally
+    /// keep & encode old data each time.
+    fn reset_for_real_run(&mut self) {
+        self.fbb.reset();
+        self.buffers.clear();
+        self.arrow_data.clear();
+        self.nodes.clear();
+        self.dry_run = false;
+
+        // this helps us avoid completely re-allocating the buffers by just 
creating a new `Self`.
+        // So everything should be allocated correctly now besides arrow_data. 
If we're calling
+        // this after only a dry run, `arrow_data` shouldn't have anything 
written into it, but
+        // we call this after every real run loop, so we still need to clear 
it.
+    }
+
+    fn encode_array_datas(
+        &mut self,
+        arr_datas: &[ArrayData],
+        n_rows: i64,
+        encode_root: impl FnOnce(
+            &mut FlatBufferSizeTracker,
+            WIPOffset<crate::gen::Message::RecordBatch>,
+        ) -> WIPOffset<UnionWIPOffset>,
+        header_type: MessageHeader,
+        write_options: &IpcWriteOptions,
+    ) -> Result<(), ArrowError> {
+        let batch_compression_type = write_options.batch_compression_type;
+
+        let compression = batch_compression_type.map(|compression_type| {
+            let mut builder = BodyCompressionBuilder::new(&mut self.fbb);
+            builder.add_method(BodyCompressionMethod::BUFFER);
+            builder.add_codec(compression_type);
+            builder.finish()
+        });
+
+        let mut variadic_buffer_counts = Vec::<i64>::default();
+        let mut offset = 0;
+
+        for array in arr_datas {
+            self.write_array_data(
+                array,
+                &mut offset,
+                array.len(),
+                array.null_count(),
+                write_options,
             )?;
+
+            append_variadic_buffer_counts(&mut variadic_buffer_counts, array);
         }
-    } else if DataType::is_numeric(data_type)
-        || DataType::is_temporal(data_type)
-        || matches!(
-            array_data.data_type(),
-            DataType::FixedSizeBinary(_) | DataType::Dictionary(_, _)
-        )
-    {
-        // Truncate values
-        assert_eq!(array_data.buffers().len(), 1);
-
-        let buffer = &array_data.buffers()[0];
-        let layout = layout(data_type);
-        let spec = &layout.buffers[0];
-
-        let byte_width = get_buffer_element_width(spec);
-        let min_length = array_data.len() * byte_width;
-        let buffer_slice = if buffer_need_truncate(array_data.offset(), 
buffer, spec, min_length) {
-            let byte_offset = array_data.offset() * byte_width;
-            let buffer_length = min(min_length, buffer.len() - byte_offset);
-            &buffer.as_slice()[byte_offset..(byte_offset + buffer_length)]
-        } else {
-            buffer.as_slice()
+
+        // pad the tail of the body data
+        let pad_len = pad_to_alignment(write_options.alignment, 
self.arrow_data.len());
+        self.arrow_data.extend_from_slice(&PADDING[..pad_len]);
+
+        let buffers = self.fbb.create_vector(&self.buffers);
+        let nodes = self.fbb.create_vector(&self.nodes);
+        let variadic_buffer = (!variadic_buffer_counts.is_empty())
+            .then(|| self.fbb.create_vector(&variadic_buffer_counts));
+
+        let root = {
+            let mut builder = RecordBatchBuilder::new(&mut self.fbb);
+
+            builder.add_length(n_rows);
+            builder.add_nodes(nodes);
+            builder.add_buffers(buffers);
+            if let Some(c) = compression {
+                builder.add_compression(c);
+            }
+            if let Some(v) = variadic_buffer {
+                builder.add_variadicBufferCounts(v);
+            }
+
+            builder.finish()
         };
-        offset = write_buffer(
-            buffer_slice,
-            buffers,
-            arrow_data,
-            offset,
-            compression_codec,
-            write_options.alignment,
-        )?;
-    } else if matches!(data_type, DataType::Boolean) {
-        // Bools are special because the payload (= 1 bit) is smaller than the 
physical container elements (= bytes).
-        // The array data may not start at the physical boundary of the 
underlying buffer, so we need to shift bits around.
-        assert_eq!(array_data.buffers().len(), 1);
-
-        let buffer = &array_data.buffers()[0];
-        let buffer = buffer.bit_slice(array_data.offset(), array_data.len());
-        offset = write_buffer(
-            &buffer,
-            buffers,
-            arrow_data,
-            offset,
-            compression_codec,
-            write_options.alignment,
-        )?;
-    } else if matches!(
-        data_type,
-        DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _)
-    ) {
-        assert_eq!(array_data.buffers().len(), 1);
-        assert_eq!(array_data.child_data().len(), 1);
-
-        // Truncate offsets and the child data to avoid writing unnecessary 
data
-        let (offsets, sliced_child_data) = match data_type {
-            DataType::List(_) => get_list_array_buffers::<i32>(array_data),
-            DataType::Map(_, _) => get_list_array_buffers::<i32>(array_data),
-            DataType::LargeList(_) => 
get_list_array_buffers::<i64>(array_data),
-            _ => unreachable!(),
+
+        let root = encode_root(self, root);
+
+        let arrow_len = self.arrow_data.len() as i64;
+        let msg = {
+            let mut builder = MessageBuilder::new(&mut self.fbb);
+            builder.add_version(write_options.metadata_version);
+            builder.add_header_type(header_type);
+            builder.add_bodyLength(arrow_len);
+
+            builder.add_header(root);
+            builder.finish()
         };
-        offset = write_buffer(
-            offsets.as_slice(),
-            buffers,
-            arrow_data,
-            offset,
-            compression_codec,
-            write_options.alignment,
-        )?;
-        offset = write_array_data(
-            &sliced_child_data,
-            buffers,
-            arrow_data,
-            nodes,
-            offset,
-            sliced_child_data.len(),
-            sliced_child_data.null_count(),
-            compression_codec,
-            write_options,
-        )?;
-        return Ok(offset);
-    } else {
-        for buffer in array_data.buffers() {
-            offset = write_buffer(
-                buffer,
-                buffers,
-                arrow_data,
+
+        self.fbb.finish(msg, None);
+        Ok(())
+    }
+
+    fn write_array_data(
+        &mut self,
+        array_data: &ArrayData,
+        offset: &mut i64,
+        num_rows: usize,
+        null_count: usize,
+        write_options: &IpcWriteOptions,
+    ) -> Result<(), ArrowError> {
+        let compression_codec: Option<CompressionCodec> = write_options
+            .batch_compression_type
+            .map(TryInto::try_into)
+            .transpose()?;
+
+        // NullArray's null_count equals to len, but the `null_count` passed 
in is from ArrayData
+        // where null_count is always 0.
+        self.nodes.push(crate::FieldNode::new(
+            num_rows as i64,
+            match array_data.data_type() {
+                DataType::Null => num_rows,
+                _ => null_count,
+            } as i64,
+        ));
+
+        if has_validity_bitmap(array_data.data_type(), write_options) {
+            // write null buffer if exists
+            let null_buffer = match array_data.nulls() {
+                None => {
+                    let num_bytes = bit_util::ceil(num_rows, 8);
+                    // create a buffer and fill it with valid bits
+                    MutableBuffer::new(num_bytes)
+                        .with_bitset(num_bytes, true)
+                        .into()
+                }
+                Some(buffer) => buffer.inner().sliced(),
+            };
+
+            self.write_buffer(
+                &null_buffer,
                 offset,
                 compression_codec,
                 write_options.alignment,
             )?;
         }
-    }
 
-    match array_data.data_type() {
-        DataType::Dictionary(_, _) => {}
-        DataType::RunEndEncoded(_, _) => {
-            // unslice the run encoded array.
-            let arr = unslice_run_array(array_data.clone())?;
-            // recursively write out nested structures
-            for data_ref in arr.child_data() {
-                // write the nested data (e.g list data)
-                offset = write_array_data(
-                    data_ref,
-                    buffers,
-                    arrow_data,
-                    nodes,
+        let mut write_byte_array_byffers = |(offsets, values): (Buffer, 
Buffer)| {
+            for buffer in [offsets, values] {
+                self.write_buffer(&buffer, offset, compression_codec, 
write_options.alignment)?;
+            }
+            Ok::<_, ArrowError>(())
+        };
+
+        match array_data.data_type() {
+            DataType::Binary | DataType::Utf8 => {
+                
write_byte_array_byffers(get_byte_array_buffers::<i32>(array_data))?
+            }
+            DataType::LargeBinary | DataType::LargeUtf8 => {
+                
write_byte_array_byffers(get_byte_array_buffers::<i64>(array_data))?
+            }
+            dt if DataType::is_numeric(dt)
+                || DataType::is_temporal(dt)
+                || matches!(
+                    dt,
+                    DataType::FixedSizeBinary(_) | DataType::Dictionary(_, _)
+                ) =>
+            {
+                // Truncate values
+                let [buffer] = array_data.buffers() else {
+                    panic!("Temporal, Numeric, FixedSizeBinary, and Dictionary 
data types must contain only one buffer");
+                };
+
+                let layout = layout(dt);
+                let spec = &layout.buffers[0];
+
+                let byte_width = get_buffer_element_width(spec);
+                let min_length = array_data.len() * byte_width;
+                let mut buffer_slice = buffer.as_slice();
+
+                if buffer_need_truncate(array_data.offset(), buffer, spec, 
min_length) {
+                    let byte_offset = array_data.offset() * byte_width;
+                    let buffer_length = min(min_length, buffer.len() - 
byte_offset);
+                    buffer_slice = &buffer_slice[byte_offset..(byte_offset + 
buffer_length)];
+                }
+
+                self.write_buffer(
+                    buffer_slice,
                     offset,
-                    data_ref.len(),
-                    data_ref.null_count(),
                     compression_codec,
+                    write_options.alignment,
+                )?;
+            }
+            DataType::Boolean => {
+                // Bools are special because the payload (= 1 bit) is smaller 
than the physical container elements (= bytes).
+                // The array data may not start at the physical boundary of 
the underlying buffer,
+                // so we need to shift bits around.
+                let [single_buf] = array_data.buffers() else {
+                    panic!("ArrayData of type Boolean should only contain 1 
buffer");
+                };
+
+                let buffer = &single_buf.bit_slice(array_data.offset(), 
array_data.len());
+                self.write_buffer(buffer, offset, compression_codec, 
write_options.alignment)?;
+            }
+            dt @ (DataType::List(_) | DataType::LargeList(_) | 
DataType::Map(_, _)) => {
+                assert_eq!(array_data.buffers().len(), 1);
+                assert_eq!(array_data.child_data().len(), 1);
+
+                // Truncate offsets and the child data to avoid writing 
unnecessary data
+                let (offsets, sliced_child_data) = match dt {
+                    DataType::List(_) | DataType::Map(_, _) => {
+                        get_list_array_buffers::<i32>(array_data)
+                    }
+                    DataType::LargeList(_) => 
get_list_array_buffers::<i64>(array_data),
+                    _ => unreachable!(),
+                };
+                self.write_buffer(&offsets, offset, compression_codec, 
write_options.alignment)?;
+                self.write_array_data(
+                    &sliced_child_data,
+                    offset,
+                    sliced_child_data.len(),
+                    sliced_child_data.null_count(),
                     write_options,
                 )?;
+                return Ok(());
+            }
+            _ => {
+                // This accommodates for even the `View` types (e.g. 
BinaryView and Utf8View):
+                // Slicing the views buffer is safe and easy,
+                // but pruning unneeded data buffers is much more nuanced 
since it's complicated
+                // to prove that no views reference the pruned buffers
+                //
+                // Current implementation just serialize the raw arrays as 
given and not try to optimize anything.
+                // If users wants to "compact" the arrays prior to sending 
them over IPC,
+                // they should consider the gc API suggested in #5513
+                for buffer in array_data.buffers() {
+                    self.write_buffer(buffer, offset, compression_codec, 
write_options.alignment)?;
+                }
             }
         }
-        _ => {
-            // recursively write out nested structures
-            for data_ref in array_data.child_data() {
-                // write the nested data (e.g list data)
-                offset = write_array_data(
+
+        let mut write_arr = |arr: &ArrayData| {
+            for data_ref in arr.child_data() {
+                self.write_array_data(
                     data_ref,
-                    buffers,
-                    arrow_data,
-                    nodes,
                     offset,
                     data_ref.len(),
                     data_ref.null_count(),
-                    compression_codec,
                     write_options,
                 )?;
             }
+            Ok::<_, ArrowError>(())
+        };
+
+        match array_data.data_type() {
+            DataType::Dictionary(_, _) => Ok(()),

Review Comment:
   One thing I wasn't able to figure out about this encoding process: it seems
we don't write the `child_data` for dictionaries into the encoded message, even
though that's where all of the dictionary's values live. Without it, only the
keys get written. Does anyone know why this is?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to