This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new a60ad5ffc perf: Improve performance of native row-to-columnar
transition used by JVM shuffle (#3289)
a60ad5ffc is described below
commit a60ad5ffcf41a405bb785d0cf0fc9a269d730212
Author: Andy Grove <[email protected]>
AuthorDate: Tue Mar 10 17:43:12 2026 -0600
perf: Improve performance of native row-to-columnar transition used by JVM
shuffle (#3289)
---
native/core/Cargo.toml | 4 +
native/core/benches/array_element_append.rs | 272 +++++++
native/core/src/execution/jni_api.rs | 2 +-
.../src/execution/shuffle/spark_unsafe/list.rs | 305 ++++++--
.../core/src/execution/shuffle/spark_unsafe/mod.rs | 2 +-
.../core/src/execution/shuffle/spark_unsafe/row.rs | 864 ++++++++++++++++++---
6 files changed, 1264 insertions(+), 185 deletions(-)
diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml
index 37be0f282..9c4ec9775 100644
--- a/native/core/Cargo.toml
+++ b/native/core/Cargo.toml
@@ -134,3 +134,7 @@ harness = false
[[bench]]
name = "parquet_decode"
harness = false
+
+[[bench]]
+name = "array_element_append"
+harness = false
diff --git a/native/core/benches/array_element_append.rs
b/native/core/benches/array_element_append.rs
new file mode 100644
index 000000000..2c46f9ba1
--- /dev/null
+++ b/native/core/benches/array_element_append.rs
@@ -0,0 +1,272 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Micro-benchmarks for SparkUnsafeArray element iteration.
+//!
+//! This tests the low-level `append_to_builder` function which converts
+//! SparkUnsafeArray elements to Arrow array builders. This is the inner loop
+//! used when processing List/Array columns in JVM shuffle.
+
+use arrow::array::builder::{
+ Date32Builder, Float64Builder, Int32Builder, Int64Builder,
TimestampMicrosecondBuilder,
+};
+use arrow::datatypes::{DataType, TimeUnit};
+use comet::execution::shuffle::spark_unsafe::list::{append_to_builder,
SparkUnsafeArray};
+use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
+
+const NUM_ELEMENTS: usize = 10000;
+
+/// Create a SparkUnsafeArray in memory with i32 elements.
+/// Layout:
+/// - 8 bytes: num_elements (i64)
+/// - null bitset: 8 bytes per 64 elements
+/// - element data: 4 bytes per element (i32)
+fn create_spark_unsafe_array_i32(num_elements: usize, with_nulls: bool) ->
Vec<u8> {
+ // Header size: 8 (num_elements) + ceil(num_elements/64) * 8 (null bitset)
+ let null_bitset_words = num_elements.div_ceil(64);
+ let header_size = 8 + null_bitset_words * 8;
+ let data_size = num_elements * 4; // i32 = 4 bytes
+ let total_size = header_size + data_size;
+
+ let mut buffer = vec![0u8; total_size];
+
+ // Write num_elements
+ buffer[0..8].copy_from_slice(&(num_elements as i64).to_le_bytes());
+
+ // Write null bitset (set every 10th element as null if with_nulls)
+ if with_nulls {
+ for i in (0..num_elements).step_by(10) {
+ let word_idx = i / 64;
+ let bit_idx = i % 64;
+ let word_offset = 8 + word_idx * 8;
+ let current_word =
+ i64::from_le_bytes(buffer[word_offset..word_offset +
8].try_into().unwrap());
+ let new_word = current_word | (1i64 << bit_idx);
+ buffer[word_offset..word_offset +
8].copy_from_slice(&new_word.to_le_bytes());
+ }
+ }
+
+ // Write element data
+ for i in 0..num_elements {
+ let offset = header_size + i * 4;
+ buffer[offset..offset + 4].copy_from_slice(&(i as i32).to_le_bytes());
+ }
+
+ buffer
+}
+
+/// Create a SparkUnsafeArray in memory with i64 elements.
+fn create_spark_unsafe_array_i64(num_elements: usize, with_nulls: bool) ->
Vec<u8> {
+ let null_bitset_words = num_elements.div_ceil(64);
+ let header_size = 8 + null_bitset_words * 8;
+ let data_size = num_elements * 8; // i64 = 8 bytes
+ let total_size = header_size + data_size;
+
+ let mut buffer = vec![0u8; total_size];
+
+ // Write num_elements
+ buffer[0..8].copy_from_slice(&(num_elements as i64).to_le_bytes());
+
+ // Write null bitset
+ if with_nulls {
+ for i in (0..num_elements).step_by(10) {
+ let word_idx = i / 64;
+ let bit_idx = i % 64;
+ let word_offset = 8 + word_idx * 8;
+ let current_word =
+ i64::from_le_bytes(buffer[word_offset..word_offset +
8].try_into().unwrap());
+ let new_word = current_word | (1i64 << bit_idx);
+ buffer[word_offset..word_offset +
8].copy_from_slice(&new_word.to_le_bytes());
+ }
+ }
+
+ // Write element data
+ for i in 0..num_elements {
+ let offset = header_size + i * 8;
+ buffer[offset..offset + 8].copy_from_slice(&(i as i64).to_le_bytes());
+ }
+
+ buffer
+}
+
+/// Create a SparkUnsafeArray in memory with f64 elements.
+fn create_spark_unsafe_array_f64(num_elements: usize, with_nulls: bool) ->
Vec<u8> {
+ let null_bitset_words = num_elements.div_ceil(64);
+ let header_size = 8 + null_bitset_words * 8;
+ let data_size = num_elements * 8; // f64 = 8 bytes
+ let total_size = header_size + data_size;
+
+ let mut buffer = vec![0u8; total_size];
+
+ // Write num_elements
+ buffer[0..8].copy_from_slice(&(num_elements as i64).to_le_bytes());
+
+ // Write null bitset
+ if with_nulls {
+ for i in (0..num_elements).step_by(10) {
+ let word_idx = i / 64;
+ let bit_idx = i % 64;
+ let word_offset = 8 + word_idx * 8;
+ let current_word =
+ i64::from_le_bytes(buffer[word_offset..word_offset +
8].try_into().unwrap());
+ let new_word = current_word | (1i64 << bit_idx);
+ buffer[word_offset..word_offset +
8].copy_from_slice(&new_word.to_le_bytes());
+ }
+ }
+
+ // Write element data
+ for i in 0..num_elements {
+ let offset = header_size + i * 8;
+ buffer[offset..offset + 8].copy_from_slice(&(i as f64).to_le_bytes());
+ }
+
+ buffer
+}
+
+fn benchmark_array_conversion(c: &mut Criterion) {
+ let mut group = c.benchmark_group("spark_unsafe_array_to_arrow");
+
+ // Benchmark i32 array conversion
+ for with_nulls in [false, true] {
+ let buffer = create_spark_unsafe_array_i32(NUM_ELEMENTS, with_nulls);
+ let array = SparkUnsafeArray::new(buffer.as_ptr() as i64);
+ let null_str = if with_nulls { "with_nulls" } else { "no_nulls" };
+
+ group.bench_with_input(
+ BenchmarkId::new("i32", null_str),
+ &(&array, &buffer),
+ |b, (array, _buffer)| {
+ b.iter(|| {
+ let mut builder =
Int32Builder::with_capacity(NUM_ELEMENTS);
+ if with_nulls {
+ append_to_builder::<true>(&DataType::Int32, &mut
builder, array).unwrap();
+ } else {
+ append_to_builder::<false>(&DataType::Int32, &mut
builder, array).unwrap();
+ }
+ builder.finish()
+ });
+ },
+ );
+ }
+
+ // Benchmark i64 array conversion
+ for with_nulls in [false, true] {
+ let buffer = create_spark_unsafe_array_i64(NUM_ELEMENTS, with_nulls);
+ let array = SparkUnsafeArray::new(buffer.as_ptr() as i64);
+ let null_str = if with_nulls { "with_nulls" } else { "no_nulls" };
+
+ group.bench_with_input(
+ BenchmarkId::new("i64", null_str),
+ &(&array, &buffer),
+ |b, (array, _buffer)| {
+ b.iter(|| {
+ let mut builder =
Int64Builder::with_capacity(NUM_ELEMENTS);
+ if with_nulls {
+ append_to_builder::<true>(&DataType::Int64, &mut
builder, array).unwrap();
+ } else {
+ append_to_builder::<false>(&DataType::Int64, &mut
builder, array).unwrap();
+ }
+ builder.finish()
+ });
+ },
+ );
+ }
+
+ // Benchmark f64 array conversion
+ for with_nulls in [false, true] {
+ let buffer = create_spark_unsafe_array_f64(NUM_ELEMENTS, with_nulls);
+ let array = SparkUnsafeArray::new(buffer.as_ptr() as i64);
+ let null_str = if with_nulls { "with_nulls" } else { "no_nulls" };
+
+ group.bench_with_input(
+ BenchmarkId::new("f64", null_str),
+ &(&array, &buffer),
+ |b, (array, _buffer)| {
+ b.iter(|| {
+ let mut builder =
Float64Builder::with_capacity(NUM_ELEMENTS);
+ if with_nulls {
+ append_to_builder::<true>(&DataType::Float64, &mut
builder, array).unwrap();
+ } else {
+ append_to_builder::<false>(&DataType::Float64, &mut
builder, array)
+ .unwrap();
+ }
+ builder.finish()
+ });
+ },
+ );
+ }
+
+ // Benchmark date32 array conversion (same memory layout as i32)
+ for with_nulls in [false, true] {
+ let buffer = create_spark_unsafe_array_i32(NUM_ELEMENTS, with_nulls);
+ let array = SparkUnsafeArray::new(buffer.as_ptr() as i64);
+ let null_str = if with_nulls { "with_nulls" } else { "no_nulls" };
+
+ group.bench_with_input(
+ BenchmarkId::new("date32", null_str),
+ &(&array, &buffer),
+ |b, (array, _buffer)| {
+ b.iter(|| {
+ let mut builder =
Date32Builder::with_capacity(NUM_ELEMENTS);
+ if with_nulls {
+ append_to_builder::<true>(&DataType::Date32, &mut
builder, array).unwrap();
+ } else {
+ append_to_builder::<false>(&DataType::Date32, &mut
builder, array).unwrap();
+ }
+ builder.finish()
+ });
+ },
+ );
+ }
+
+ // Benchmark timestamp array conversion (same memory layout as i64)
+ for with_nulls in [false, true] {
+ let buffer = create_spark_unsafe_array_i64(NUM_ELEMENTS, with_nulls);
+ let array = SparkUnsafeArray::new(buffer.as_ptr() as i64);
+ let null_str = if with_nulls { "with_nulls" } else { "no_nulls" };
+
+ group.bench_with_input(
+ BenchmarkId::new("timestamp", null_str),
+ &(&array, &buffer),
+ |b, (array, _buffer)| {
+ b.iter(|| {
+ let mut builder =
TimestampMicrosecondBuilder::with_capacity(NUM_ELEMENTS);
+ let dt = DataType::Timestamp(TimeUnit::Microsecond, None);
+ if with_nulls {
+ append_to_builder::<true>(&dt, &mut builder,
array).unwrap();
+ } else {
+ append_to_builder::<false>(&dt, &mut builder,
array).unwrap();
+ }
+ builder.finish()
+ });
+ },
+ );
+ }
+
+ group.finish();
+}
+
+fn config() -> Criterion {
+ Criterion::default()
+}
+
+criterion_group! {
+ name = benches;
+ config = config();
+ targets = benchmark_array_conversion
+}
+criterion_main!(benches);
diff --git a/native/core/src/execution/jni_api.rs
b/native/core/src/execution/jni_api.rs
index 4d039c5c2..361deae18 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -547,7 +547,6 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_executePlan(
let physical_plan_time = start.elapsed();
exec_context.plan_creation_time += physical_plan_time;
- exec_context.root_op = Some(Arc::clone(&root_op));
exec_context.scans = scans;
if exec_context.explain_native {
@@ -602,6 +601,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_executePlan(
} else {
exec_context.stream = Some(stream);
}
+ exec_context.root_op = Some(root_op);
} else {
// Pull input batches
pull_input_batches(exec_context)?;
diff --git a/native/core/src/execution/shuffle/spark_unsafe/list.rs
b/native/core/src/execution/shuffle/spark_unsafe/list.rs
index 9e58c71d3..72610d2d8 100644
--- a/native/core/src/execution/shuffle/spark_unsafe/list.rs
+++ b/native/core/src/execution/shuffle/spark_unsafe/list.rs
@@ -32,6 +32,59 @@ use arrow::array::{
};
use arrow::datatypes::{DataType, TimeUnit};
+/// Generates bulk append methods for primitive types in SparkUnsafeArray.
+///
+/// # Safety invariants for all generated methods:
+/// - `element_offset` points to contiguous element data of length
`num_elements`
+/// - `null_bitset_ptr()` returns a pointer to `ceil(num_elements/64)` i64
words
+/// - These invariants are guaranteed by the SparkUnsafeArray layout from the
JVM
+macro_rules! impl_append_to_builder {
+ ($method_name:ident, $builder_type:ty, $element_type:ty) => {
+ pub(crate) fn $method_name<const NULLABLE: bool>(&self, builder: &mut
$builder_type) {
+ let num_elements = self.num_elements;
+ if num_elements == 0 {
+ return;
+ }
+
+ if NULLABLE {
+ let mut ptr = self.element_offset as *const $element_type;
+ let null_words = self.null_bitset_ptr();
+ debug_assert!(!null_words.is_null(), "null_bitset_ptr is
null");
+ debug_assert!(!ptr.is_null(), "element_offset pointer is
null");
+ for idx in 0..num_elements {
+ // SAFETY: null_words has ceil(num_elements/64) words, idx
< num_elements
+ let is_null = unsafe { Self::is_null_in_bitset(null_words,
idx) };
+
+ if is_null {
+ builder.append_null();
+ } else {
+ // SAFETY: ptr is within element data bounds
+ builder.append_value(unsafe { ptr.read_unaligned() });
+ }
+ // SAFETY: ptr stays within bounds, iterating num_elements
times
+ ptr = unsafe { ptr.add(1) };
+ }
+ } else {
+ // SAFETY: element_offset points to contiguous data of length
num_elements
+ debug_assert!(self.element_offset != 0, "element_offset is
null");
+ let ptr = self.element_offset as *const $element_type;
+ // Use bulk copy when data is properly aligned, fall back to
+ // per-element unaligned reads otherwise
+ if (ptr as
usize).is_multiple_of(std::mem::align_of::<$element_type>()) {
+ let slice = unsafe { std::slice::from_raw_parts(ptr,
num_elements) };
+ builder.append_slice(slice);
+ } else {
+ let mut ptr = ptr;
+ for _ in 0..num_elements {
+ builder.append_value(unsafe { ptr.read_unaligned() });
+ ptr = unsafe { ptr.add(1) };
+ }
+ }
+ }
+ }
+ };
+}
+
pub struct SparkUnsafeArray {
row_addr: i64,
num_elements: usize,
@@ -39,10 +92,12 @@ pub struct SparkUnsafeArray {
}
impl SparkUnsafeObject for SparkUnsafeArray {
+ #[inline]
fn get_row_addr(&self) -> i64 {
self.row_addr
}
+ #[inline]
fn get_element_offset(&self, index: usize, element_size: usize) -> *const
u8 {
(self.element_offset + (index * element_size) as i64) as *const u8
}
@@ -100,6 +155,183 @@ impl SparkUnsafeArray {
(word & mask) != 0
}
}
+
+ /// Returns the null bitset pointer (starts at row_addr + 8).
+ #[inline]
+ fn null_bitset_ptr(&self) -> *const i64 {
+ (self.row_addr + 8) as *const i64
+ }
+
+ /// Checks whether the null bit at `idx` is set in the given null bitset
pointer.
+ ///
+ /// # Safety
+ /// `null_words` must point to at least `ceil((idx+1)/64)` i64 words.
+ #[inline]
+ unsafe fn is_null_in_bitset(null_words: *const i64, idx: usize) -> bool {
+ let word_idx = idx >> 6;
+ let bit_idx = idx & 0x3f;
+ (null_words.add(word_idx).read_unaligned() & (1i64 << bit_idx)) != 0
+ }
+
+ impl_append_to_builder!(append_ints_to_builder, Int32Builder, i32);
+ impl_append_to_builder!(append_longs_to_builder, Int64Builder, i64);
+ impl_append_to_builder!(append_shorts_to_builder, Int16Builder, i16);
+ impl_append_to_builder!(append_bytes_to_builder, Int8Builder, i8);
+ impl_append_to_builder!(append_floats_to_builder, Float32Builder, f32);
+ impl_append_to_builder!(append_doubles_to_builder, Float64Builder, f64);
+
+ /// Bulk append boolean values to builder.
+ /// Booleans are stored as 1 byte each in SparkUnsafeArray, requiring
special handling.
+ pub(crate) fn append_booleans_to_builder<const NULLABLE: bool>(
+ &self,
+ builder: &mut BooleanBuilder,
+ ) {
+ let num_elements = self.num_elements;
+ if num_elements == 0 {
+ return;
+ }
+
+ let mut ptr = self.element_offset as *const u8;
+ debug_assert!(
+ !ptr.is_null(),
+ "append_booleans: element_offset pointer is null"
+ );
+
+ if NULLABLE {
+ let null_words = self.null_bitset_ptr();
+ debug_assert!(
+ !null_words.is_null(),
+ "append_booleans: null_bitset_ptr is null"
+ );
+ for idx in 0..num_elements {
+ // SAFETY: null_words has ceil(num_elements/64) words, idx <
num_elements
+ let is_null = unsafe { Self::is_null_in_bitset(null_words,
idx) };
+
+ if is_null {
+ builder.append_null();
+ } else {
+ // SAFETY: ptr is within element data bounds
+ builder.append_value(unsafe { *ptr != 0 });
+ }
+ // SAFETY: ptr stays within bounds, iterating num_elements
times
+ ptr = unsafe { ptr.add(1) };
+ }
+ } else {
+ for _ in 0..num_elements {
+ // SAFETY: ptr is within element data bounds
+ builder.append_value(unsafe { *ptr != 0 });
+ ptr = unsafe { ptr.add(1) };
+ }
+ }
+ }
+
+ /// Bulk append timestamp values to builder (stored as i64 microseconds).
+ pub(crate) fn append_timestamps_to_builder<const NULLABLE: bool>(
+ &self,
+ builder: &mut TimestampMicrosecondBuilder,
+ ) {
+ let num_elements = self.num_elements;
+ if num_elements == 0 {
+ return;
+ }
+
+ if NULLABLE {
+ let mut ptr = self.element_offset as *const i64;
+ let null_words = self.null_bitset_ptr();
+ debug_assert!(
+ !null_words.is_null(),
+ "append_timestamps: null_bitset_ptr is null"
+ );
+ debug_assert!(
+ !ptr.is_null(),
+ "append_timestamps: element_offset pointer is null"
+ );
+ for idx in 0..num_elements {
+ // SAFETY: null_words has ceil(num_elements/64) words, idx <
num_elements
+ let is_null = unsafe { Self::is_null_in_bitset(null_words,
idx) };
+
+ if is_null {
+ builder.append_null();
+ } else {
+ // SAFETY: ptr is within element data bounds
+ builder.append_value(unsafe { ptr.read_unaligned() });
+ }
+ // SAFETY: ptr stays within bounds, iterating num_elements
times
+ ptr = unsafe { ptr.add(1) };
+ }
+ } else {
+ // SAFETY: element_offset points to contiguous i64 data of length
num_elements
+ debug_assert!(
+ self.element_offset != 0,
+ "append_timestamps: element_offset is null"
+ );
+ let ptr = self.element_offset as *const i64;
+ if (ptr as usize).is_multiple_of(std::mem::align_of::<i64>()) {
+ let slice = unsafe { std::slice::from_raw_parts(ptr,
num_elements) };
+ builder.append_slice(slice);
+ } else {
+ let mut ptr = ptr;
+ for _ in 0..num_elements {
+ builder.append_value(unsafe { ptr.read_unaligned() });
+ ptr = unsafe { ptr.add(1) };
+ }
+ }
+ }
+ }
+
+ /// Bulk append date values to builder (stored as i32 days since epoch).
+ pub(crate) fn append_dates_to_builder<const NULLABLE: bool>(
+ &self,
+ builder: &mut Date32Builder,
+ ) {
+ let num_elements = self.num_elements;
+ if num_elements == 0 {
+ return;
+ }
+
+ if NULLABLE {
+ let mut ptr = self.element_offset as *const i32;
+ let null_words = self.null_bitset_ptr();
+ debug_assert!(
+ !null_words.is_null(),
+ "append_dates: null_bitset_ptr is null"
+ );
+ debug_assert!(
+ !ptr.is_null(),
+ "append_dates: element_offset pointer is null"
+ );
+ for idx in 0..num_elements {
+ // SAFETY: null_words has ceil(num_elements/64) words, idx <
num_elements
+ let is_null = unsafe { Self::is_null_in_bitset(null_words,
idx) };
+
+ if is_null {
+ builder.append_null();
+ } else {
+ // SAFETY: ptr is within element data bounds
+ builder.append_value(unsafe { ptr.read_unaligned() });
+ }
+ // SAFETY: ptr stays within bounds, iterating num_elements
times
+ ptr = unsafe { ptr.add(1) };
+ }
+ } else {
+ // SAFETY: element_offset points to contiguous i32 data of length
num_elements
+ debug_assert!(
+ self.element_offset != 0,
+ "append_dates: element_offset is null"
+ );
+ let ptr = self.element_offset as *const i32;
+ if (ptr as usize).is_multiple_of(std::mem::align_of::<i32>()) {
+ let slice = unsafe { std::slice::from_raw_parts(ptr,
num_elements) };
+ builder.append_slice(slice);
+ } else {
+ let mut ptr = ptr;
+ for _ in 0..num_elements {
+ builder.append_value(unsafe { ptr.read_unaligned() });
+ ptr = unsafe { ptr.add(1) };
+ }
+ }
+ }
+ }
}
pub fn append_to_builder<const NULLABLE: bool>(
@@ -122,77 +354,40 @@ pub fn append_to_builder<const NULLABLE: bool>(
match data_type {
DataType::Boolean => {
- add_values!(
- BooleanBuilder,
- |builder: &mut BooleanBuilder, values: &SparkUnsafeArray, idx:
usize| builder
- .append_value(values.get_boolean(idx)),
- |builder: &mut BooleanBuilder| builder.append_null()
- );
+ let builder = downcast_builder_ref!(BooleanBuilder, builder);
+ array.append_booleans_to_builder::<NULLABLE>(builder);
}
DataType::Int8 => {
- add_values!(
- Int8Builder,
- |builder: &mut Int8Builder, values: &SparkUnsafeArray, idx:
usize| builder
- .append_value(values.get_byte(idx)),
- |builder: &mut Int8Builder| builder.append_null()
- );
+ let builder = downcast_builder_ref!(Int8Builder, builder);
+ array.append_bytes_to_builder::<NULLABLE>(builder);
}
DataType::Int16 => {
- add_values!(
- Int16Builder,
- |builder: &mut Int16Builder, values: &SparkUnsafeArray, idx:
usize| builder
- .append_value(values.get_short(idx)),
- |builder: &mut Int16Builder| builder.append_null()
- );
+ let builder = downcast_builder_ref!(Int16Builder, builder);
+ array.append_shorts_to_builder::<NULLABLE>(builder);
}
DataType::Int32 => {
- add_values!(
- Int32Builder,
- |builder: &mut Int32Builder, values: &SparkUnsafeArray, idx:
usize| builder
- .append_value(values.get_int(idx)),
- |builder: &mut Int32Builder| builder.append_null()
- );
+ let builder = downcast_builder_ref!(Int32Builder, builder);
+ array.append_ints_to_builder::<NULLABLE>(builder);
}
DataType::Int64 => {
- add_values!(
- Int64Builder,
- |builder: &mut Int64Builder, values: &SparkUnsafeArray, idx:
usize| builder
- .append_value(values.get_long(idx)),
- |builder: &mut Int64Builder| builder.append_null()
- );
+ let builder = downcast_builder_ref!(Int64Builder, builder);
+ array.append_longs_to_builder::<NULLABLE>(builder);
}
DataType::Float32 => {
- add_values!(
- Float32Builder,
- |builder: &mut Float32Builder, values: &SparkUnsafeArray, idx:
usize| builder
- .append_value(values.get_float(idx)),
- |builder: &mut Float32Builder| builder.append_null()
- );
+ let builder = downcast_builder_ref!(Float32Builder, builder);
+ array.append_floats_to_builder::<NULLABLE>(builder);
}
DataType::Float64 => {
- add_values!(
- Float64Builder,
- |builder: &mut Float64Builder, values: &SparkUnsafeArray, idx:
usize| builder
- .append_value(values.get_double(idx)),
- |builder: &mut Float64Builder| builder.append_null()
- );
+ let builder = downcast_builder_ref!(Float64Builder, builder);
+ array.append_doubles_to_builder::<NULLABLE>(builder);
}
DataType::Timestamp(TimeUnit::Microsecond, _) => {
- add_values!(
- TimestampMicrosecondBuilder,
- |builder: &mut TimestampMicrosecondBuilder,
- values: &SparkUnsafeArray,
- idx: usize| builder.append_value(values.get_timestamp(idx)),
- |builder: &mut TimestampMicrosecondBuilder|
builder.append_null()
- );
+ let builder = downcast_builder_ref!(TimestampMicrosecondBuilder,
builder);
+ array.append_timestamps_to_builder::<NULLABLE>(builder);
}
DataType::Date32 => {
- add_values!(
- Date32Builder,
- |builder: &mut Date32Builder, values: &SparkUnsafeArray, idx:
usize| builder
- .append_value(values.get_date(idx)),
- |builder: &mut Date32Builder| builder.append_null()
- );
+ let builder = downcast_builder_ref!(Date32Builder, builder);
+ array.append_dates_to_builder::<NULLABLE>(builder);
}
DataType::Binary => {
add_values!(
diff --git a/native/core/src/execution/shuffle/spark_unsafe/mod.rs
b/native/core/src/execution/shuffle/spark_unsafe/mod.rs
index b052df29b..6390a0f23 100644
--- a/native/core/src/execution/shuffle/spark_unsafe/mod.rs
+++ b/native/core/src/execution/shuffle/spark_unsafe/mod.rs
@@ -15,6 +15,6 @@
// specific language governing permissions and limitations
// under the License.
-mod list;
+pub mod list;
mod map;
pub mod row;
diff --git a/native/core/src/execution/shuffle/spark_unsafe/row.rs
b/native/core/src/execution/shuffle/spark_unsafe/row.rs
index 7962caace..6b41afae8 100644
--- a/native/core/src/execution/shuffle/spark_unsafe/row.rs
+++ b/native/core/src/execution/shuffle/spark_unsafe/row.rs
@@ -88,6 +88,7 @@ pub trait SparkUnsafeObject {
}
/// Returns boolean value at the given index of the object.
+ #[inline]
fn get_boolean(&self, index: usize) -> bool {
let addr = self.get_element_offset(index, 1);
// SAFETY: addr points to valid element data within the
UnsafeRow/UnsafeArray region.
@@ -100,6 +101,7 @@ pub trait SparkUnsafeObject {
}
/// Returns byte value at the given index of the object.
+ #[inline]
fn get_byte(&self, index: usize) -> i8 {
let addr = self.get_element_offset(index, 1);
// SAFETY: addr points to valid element data (1 byte) within the
row/array region.
@@ -109,6 +111,7 @@ pub trait SparkUnsafeObject {
}
/// Returns short value at the given index of the object.
+ #[inline]
fn get_short(&self, index: usize) -> i16 {
let addr = self.get_element_offset(index, 2);
// SAFETY: addr points to valid element data (2 bytes) within the
row/array region.
@@ -118,6 +121,7 @@ pub trait SparkUnsafeObject {
}
/// Returns integer value at the given index of the object.
+ #[inline]
fn get_int(&self, index: usize) -> i32 {
let addr = self.get_element_offset(index, 4);
// SAFETY: addr points to valid element data (4 bytes) within the
row/array region.
@@ -127,6 +131,7 @@ pub trait SparkUnsafeObject {
}
/// Returns long value at the given index of the object.
+ #[inline]
fn get_long(&self, index: usize) -> i64 {
let addr = self.get_element_offset(index, 8);
// SAFETY: addr points to valid element data (8 bytes) within the
row/array region.
@@ -136,6 +141,7 @@ pub trait SparkUnsafeObject {
}
/// Returns float value at the given index of the object.
+ #[inline]
fn get_float(&self, index: usize) -> f32 {
let addr = self.get_element_offset(index, 4);
// SAFETY: addr points to valid element data (4 bytes) within the
row/array region.
@@ -145,6 +151,7 @@ pub trait SparkUnsafeObject {
}
/// Returns double value at the given index of the object.
+ #[inline]
fn get_double(&self, index: usize) -> f64 {
let addr = self.get_element_offset(index, 8);
// SAFETY: addr points to valid element data (8 bytes) within the
row/array region.
@@ -184,6 +191,7 @@ pub trait SparkUnsafeObject {
}
/// Returns date value at the given index of the object.
+ #[inline]
fn get_date(&self, index: usize) -> i32 {
let addr = self.get_element_offset(index, 4);
// SAFETY: addr points to valid element data (4 bytes) within the
row/array region.
@@ -193,6 +201,7 @@ pub trait SparkUnsafeObject {
}
/// Returns timestamp value at the given index of the object.
+ #[inline]
fn get_timestamp(&self, index: usize) -> i64 {
let addr = self.get_element_offset(index, 8);
// SAFETY: addr points to valid element data (8 bytes) within the
row/array region.
@@ -325,6 +334,7 @@ impl SparkUnsafeRow {
pub fn set_not_null_at(&mut self, index: usize) {
// SAFETY: row_addr points to valid Spark UnsafeRow data with at least
// ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures
index < num_fields.
+ // word_offset is within the bitset region since (index >> 6) << 3 <
bitset size.
// Writing is safe because we have mutable access and the memory is
owned by the JVM.
debug_assert!(self.row_addr != -1, "set_not_null_at: row not
initialized");
unsafe {
@@ -337,11 +347,32 @@ impl SparkUnsafeRow {
}
macro_rules! downcast_builder_ref {
- ($builder_type:ty, $builder:expr) => {
+ ($builder_type:ty, $builder:expr) => {{
+ let actual_type_id = $builder.as_any().type_id();
$builder
.as_any_mut()
.downcast_mut::<$builder_type>()
- .expect(stringify!($builder_type))
+ .ok_or_else(|| {
+ CometError::Internal(format!(
+ "Failed to downcast builder: expected {}, got {:?}",
+ stringify!($builder_type),
+ actual_type_id
+ ))
+ })?
+ }};
+}
+
+macro_rules! get_field_builder {
+ ($struct_builder:expr, $builder_type:ty, $idx:expr) => {
+ $struct_builder
+ .field_builder::<$builder_type>($idx)
+ .ok_or_else(|| {
+ CometError::Internal(format!(
+ "Failed to get field builder at index {}: expected {}",
+ $idx,
+ stringify!($builder_type)
+ ))
+ })?
};
}
@@ -364,7 +395,7 @@ pub(super) fn append_field(
/// A macro for generating code of appending value into field builder of
Arrow struct builder.
macro_rules! append_field_to_builder {
($builder_type:ty, $accessor:expr) => {{
- let field_builder =
struct_builder.field_builder::<$builder_type>(idx).unwrap();
+ let field_builder = get_field_builder!(struct_builder,
$builder_type, idx);
if row.is_null_row() {
// The row is null.
@@ -437,7 +468,7 @@ pub(super) fn append_field(
}
DataType::Struct(fields) => {
// Appending value into struct field builder of Arrow struct
builder.
- let field_builder =
struct_builder.field_builder::<StructBuilder>(idx).unwrap();
+ let field_builder = get_field_builder!(struct_builder,
StructBuilder, idx);
let nested_row = if row.is_null_row() || row.is_null_at(idx) {
// The row is null, or the field in the row is null, i.e., a
null nested row.
@@ -454,9 +485,11 @@ pub(super) fn append_field(
}
}
DataType::Map(field, _) => {
- let field_builder = struct_builder
- .field_builder::<MapBuilder<Box<dyn ArrayBuilder>, Box<dyn
ArrayBuilder>>>(idx)
- .unwrap();
+ let field_builder = get_field_builder!(
+ struct_builder,
+ MapBuilder<Box<dyn ArrayBuilder>, Box<dyn ArrayBuilder>>,
+ idx
+ );
if row.is_null_row() {
// The row is null.
@@ -474,9 +507,8 @@ pub(super) fn append_field(
}
}
DataType::List(field) => {
- let field_builder = struct_builder
- .field_builder::<ListBuilder<Box<dyn ArrayBuilder>>>(idx)
- .unwrap();
+ let field_builder =
+ get_field_builder!(struct_builder, ListBuilder<Box<dyn
ArrayBuilder>>, idx);
if row.is_null_row() {
// The row is null.
@@ -501,7 +533,667 @@ pub(super) fn append_field(
Ok(())
}
+/// Appends nested struct fields to the struct builder using field-major order.
+/// This is a helper function for processing nested struct fields recursively.
+///
+/// Unlike `append_struct_fields_field_major`, this function takes slices of
row addresses,
+/// sizes, and null flags directly, without needing to navigate from a parent
row.
+#[allow(clippy::redundant_closure_call)]
+fn append_nested_struct_fields_field_major(
+ row_addresses: &[jlong],
+ row_sizes: &[jint],
+ struct_is_null: &[bool],
+ struct_builder: &mut StructBuilder,
+ fields: &arrow::datatypes::Fields,
+) -> Result<(), CometError> {
+ let num_rows = row_addresses.len();
+ let mut row = SparkUnsafeRow::new_with_num_fields(fields.len());
+
+ // Helper macro for processing primitive fields
+ macro_rules! process_field {
+ ($builder_type:ty, $field_idx:expr, $get_value:expr) => {{
+ let field_builder = get_field_builder!(struct_builder,
$builder_type, $field_idx);
+
+ for row_idx in 0..num_rows {
+ if struct_is_null[row_idx] {
+ // Struct is null, field is also null
+ field_builder.append_null();
+ } else {
+ let row_addr = row_addresses[row_idx];
+ let row_size = row_sizes[row_idx];
+ row.point_to(row_addr, row_size);
+
+ if row.is_null_at($field_idx) {
+ field_builder.append_null();
+ } else {
+ field_builder.append_value($get_value(&row,
$field_idx));
+ }
+ }
+ }
+ }};
+ }
+
+ // Process each field across all rows
+ for (field_idx, field) in fields.iter().enumerate() {
+ match field.data_type() {
+ DataType::Boolean => {
+ process_field!(BooleanBuilder, field_idx, |row:
&SparkUnsafeRow, idx| row
+ .get_boolean(idx));
+ }
+ DataType::Int8 => {
+ process_field!(Int8Builder, field_idx, |row: &SparkUnsafeRow,
idx| row
+ .get_byte(idx));
+ }
+ DataType::Int16 => {
+ process_field!(Int16Builder, field_idx, |row: &SparkUnsafeRow,
idx| row
+ .get_short(idx));
+ }
+ DataType::Int32 => {
+ process_field!(Int32Builder, field_idx, |row: &SparkUnsafeRow,
idx| row
+ .get_int(idx));
+ }
+ DataType::Int64 => {
+ process_field!(Int64Builder, field_idx, |row: &SparkUnsafeRow,
idx| row
+ .get_long(idx));
+ }
+ DataType::Float32 => {
+ process_field!(Float32Builder, field_idx, |row:
&SparkUnsafeRow, idx| row
+ .get_float(idx));
+ }
+ DataType::Float64 => {
+ process_field!(Float64Builder, field_idx, |row:
&SparkUnsafeRow, idx| row
+ .get_double(idx));
+ }
+ DataType::Date32 => {
+ process_field!(Date32Builder, field_idx, |row:
&SparkUnsafeRow, idx| row
+ .get_date(idx));
+ }
+ DataType::Timestamp(TimeUnit::Microsecond, _) => {
+ process_field!(
+ TimestampMicrosecondBuilder,
+ field_idx,
+ |row: &SparkUnsafeRow, idx| row.get_timestamp(idx)
+ );
+ }
+ DataType::Binary => {
+ let field_builder = get_field_builder!(struct_builder,
BinaryBuilder, field_idx);
+
+ for row_idx in 0..num_rows {
+ if struct_is_null[row_idx] {
+ field_builder.append_null();
+ } else {
+ let row_addr = row_addresses[row_idx];
+ let row_size = row_sizes[row_idx];
+ row.point_to(row_addr, row_size);
+
+ if row.is_null_at(field_idx) {
+ field_builder.append_null();
+ } else {
+
field_builder.append_value(row.get_binary(field_idx));
+ }
+ }
+ }
+ }
+ DataType::Utf8 => {
+ let field_builder = get_field_builder!(struct_builder,
StringBuilder, field_idx);
+
+ for row_idx in 0..num_rows {
+ if struct_is_null[row_idx] {
+ field_builder.append_null();
+ } else {
+ let row_addr = row_addresses[row_idx];
+ let row_size = row_sizes[row_idx];
+ row.point_to(row_addr, row_size);
+
+ if row.is_null_at(field_idx) {
+ field_builder.append_null();
+ } else {
+
field_builder.append_value(row.get_string(field_idx));
+ }
+ }
+ }
+ }
+ DataType::Decimal128(p, _) => {
+ let p = *p;
+ let field_builder =
+ get_field_builder!(struct_builder, Decimal128Builder,
field_idx);
+
+ for row_idx in 0..num_rows {
+ if struct_is_null[row_idx] {
+ field_builder.append_null();
+ } else {
+ let row_addr = row_addresses[row_idx];
+ let row_size = row_sizes[row_idx];
+ row.point_to(row_addr, row_size);
+
+ if row.is_null_at(field_idx) {
+ field_builder.append_null();
+ } else {
+
field_builder.append_value(row.get_decimal(field_idx, p));
+ }
+ }
+ }
+ }
+ DataType::Struct(nested_fields) => {
+ let nested_builder = get_field_builder!(struct_builder,
StructBuilder, field_idx);
+
+ // Collect nested struct addresses and sizes in one pass,
building validity
+ let mut nested_addresses: Vec<jlong> =
Vec::with_capacity(num_rows);
+ let mut nested_sizes: Vec<jint> = Vec::with_capacity(num_rows);
+ let mut nested_is_null: Vec<bool> =
Vec::with_capacity(num_rows);
+
+ for row_idx in 0..num_rows {
+ if struct_is_null[row_idx] {
+ // Parent struct is null, nested struct is also null
+ nested_builder.append_null();
+ nested_is_null.push(true);
+ nested_addresses.push(0);
+ nested_sizes.push(0);
+ } else {
+ let row_addr = row_addresses[row_idx];
+ let row_size = row_sizes[row_idx];
+ row.point_to(row_addr, row_size);
+
+ if row.is_null_at(field_idx) {
+ nested_builder.append_null();
+ nested_is_null.push(true);
+ nested_addresses.push(0);
+ nested_sizes.push(0);
+ } else {
+ nested_builder.append(true);
+ nested_is_null.push(false);
+ // Get nested struct address and size
+ let nested_row = row.get_struct(field_idx,
nested_fields.len());
+ nested_addresses.push(nested_row.get_row_addr());
+ nested_sizes.push(nested_row.get_row_size());
+ }
+ }
+ }
+
+ // Recursively process nested struct fields in field-major
order
+ append_nested_struct_fields_field_major(
+ &nested_addresses,
+ &nested_sizes,
+ &nested_is_null,
+ nested_builder,
+ nested_fields,
+ )?;
+ }
+ // For list and map, fall back to append_field since they have
variable-length elements
+ dt @ (DataType::List(_) | DataType::Map(_, _)) => {
+ for row_idx in 0..num_rows {
+ if struct_is_null[row_idx] {
+ let null_row = SparkUnsafeRow::default();
+ append_field(dt, struct_builder, &null_row,
field_idx)?;
+ } else {
+ let row_addr = row_addresses[row_idx];
+ let row_size = row_sizes[row_idx];
+ row.point_to(row_addr, row_size);
+ append_field(dt, struct_builder, &row, field_idx)?;
+ }
+ }
+ }
+ _ => {
+ unreachable!(
+ "Unsupported data type of struct field: {:?}",
+ field.data_type()
+ )
+ }
+ }
+ }
+
+ Ok(())
+}
+
/// Reads row address and size from JVM-provided pointer arrays and points the row to that data.
///
/// # Safety
/// Caller must ensure row_addresses_ptr and row_sizes_ptr are valid for index i.
/// This is guaranteed when called from append_columns with indices in [row_start, row_end).
macro_rules! read_row_at {
    ($target:expr, $addrs:expr, $sizes:expr, $idx:expr) => {{
        debug_assert!(!$addrs.is_null(), "read_row_at: null row_addresses_ptr");
        debug_assert!(!$sizes.is_null(), "read_row_at: null row_sizes_ptr");
        // SAFETY: caller guarantees both pointers are valid for this index
        // (see the macro-level safety note above).
        let (addr, size) = unsafe { (*$addrs.add($idx), *$sizes.add($idx)) };
        $target.point_to(addr, size);
    }};
}
+
+/// Appends a batch of list values to the list builder with a single type
dispatch.
+/// This moves type dispatch from O(rows) to O(1), significantly improving
performance
+/// for large batches.
+#[allow(clippy::too_many_arguments)]
+fn append_list_column_batch(
+ row_addresses_ptr: *mut jlong,
+ row_sizes_ptr: *mut jint,
+ row_start: usize,
+ row_end: usize,
+ schema: &[DataType],
+ column_idx: usize,
+ element_type: &DataType,
+ list_builder: &mut ListBuilder<Box<dyn ArrayBuilder>>,
+) -> Result<(), CometError> {
+ let mut row = SparkUnsafeRow::new(schema);
+
+ // Helper macro for primitive element types - gets builder fresh each
iteration
+ // to avoid borrow conflicts with list_builder.append()
+ macro_rules! process_primitive_lists {
+ ($builder_type:ty, $append_fn:ident) => {{
+ for i in row_start..row_end {
+ read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i);
+
+ if row.is_null_at(column_idx) {
+ list_builder.append_null();
+ } else {
+ let array = row.get_array(column_idx);
+ // Get values builder fresh each iteration to avoid borrow
conflict
+ let values_builder = list_builder
+ .values()
+ .as_any_mut()
+ .downcast_mut::<$builder_type>()
+ .expect(stringify!($builder_type));
+ array.$append_fn::<true>(values_builder);
+ list_builder.append(true);
+ }
+ }
+ }};
+ }
+
+ match element_type {
+ DataType::Boolean => {
+ process_primitive_lists!(BooleanBuilder,
append_booleans_to_builder);
+ }
+ DataType::Int8 => {
+ process_primitive_lists!(Int8Builder, append_bytes_to_builder);
+ }
+ DataType::Int16 => {
+ process_primitive_lists!(Int16Builder, append_shorts_to_builder);
+ }
+ DataType::Int32 => {
+ process_primitive_lists!(Int32Builder, append_ints_to_builder);
+ }
+ DataType::Int64 => {
+ process_primitive_lists!(Int64Builder, append_longs_to_builder);
+ }
+ DataType::Float32 => {
+ process_primitive_lists!(Float32Builder, append_floats_to_builder);
+ }
+ DataType::Float64 => {
+ process_primitive_lists!(Float64Builder,
append_doubles_to_builder);
+ }
+ DataType::Date32 => {
+ process_primitive_lists!(Date32Builder, append_dates_to_builder);
+ }
+ DataType::Timestamp(TimeUnit::Microsecond, _) => {
+ process_primitive_lists!(TimestampMicrosecondBuilder,
append_timestamps_to_builder);
+ }
+ // For complex element types, fall back to per-row dispatch
+ _ => {
+ for i in row_start..row_end {
+ read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i);
+
+ if row.is_null_at(column_idx) {
+ list_builder.append_null();
+ } else {
+ append_list_element(element_type, list_builder,
&row.get_array(column_idx))?;
+ }
+ }
+ }
+ }
+
+ Ok(())
+}
+
+/// Appends a batch of map values to the map builder with a single type
dispatch.
+/// This moves type dispatch from O(rows × 2) to O(2), improving performance
for maps.
+#[allow(clippy::too_many_arguments)]
+fn append_map_column_batch(
+ row_addresses_ptr: *mut jlong,
+ row_sizes_ptr: *mut jint,
+ row_start: usize,
+ row_end: usize,
+ schema: &[DataType],
+ column_idx: usize,
+ field: &arrow::datatypes::FieldRef,
+ map_builder: &mut MapBuilder<Box<dyn ArrayBuilder>, Box<dyn ArrayBuilder>>,
+) -> Result<(), CometError> {
+ let mut row = SparkUnsafeRow::new(schema);
+ let (key_field, value_field, _) = get_map_key_value_fields(field)?;
+ let key_type = key_field.data_type();
+ let value_type = value_field.data_type();
+
+ // Helper macro for processing maps with primitive key/value types
+ // Uses scoped borrows to avoid borrow checker conflicts
+ macro_rules! process_primitive_maps {
+ ($key_builder:ty, $key_append:ident, $val_builder:ty,
$val_append:ident) => {{
+ for i in row_start..row_end {
+ read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i);
+
+ if row.is_null_at(column_idx) {
+ map_builder.append(false)?;
+ } else {
+ let map = row.get_map(column_idx);
+ // Process keys in a scope so borrow ends
+ {
+ let keys_builder = map_builder
+ .keys()
+ .as_any_mut()
+ .downcast_mut::<$key_builder>()
+ .expect(stringify!($key_builder));
+ map.keys.$key_append::<false>(keys_builder);
+ }
+ // Process values in a scope so borrow ends
+ {
+ let values_builder = map_builder
+ .values()
+ .as_any_mut()
+ .downcast_mut::<$val_builder>()
+ .expect(stringify!($val_builder));
+ map.values.$val_append::<true>(values_builder);
+ }
+ map_builder.append(true)?;
+ }
+ }
+ }};
+ }
+
+ // Optimize common map type combinations
+ match (key_type, value_type) {
+ // Map<Int64, Int64>
+ (DataType::Int64, DataType::Int64) => {
+ process_primitive_maps!(
+ Int64Builder,
+ append_longs_to_builder,
+ Int64Builder,
+ append_longs_to_builder
+ );
+ }
+ // Map<Int64, Float64>
+ (DataType::Int64, DataType::Float64) => {
+ process_primitive_maps!(
+ Int64Builder,
+ append_longs_to_builder,
+ Float64Builder,
+ append_doubles_to_builder
+ );
+ }
+ // Map<Int32, Int32>
+ (DataType::Int32, DataType::Int32) => {
+ process_primitive_maps!(
+ Int32Builder,
+ append_ints_to_builder,
+ Int32Builder,
+ append_ints_to_builder
+ );
+ }
+ // Map<Int32, Int64>
+ (DataType::Int32, DataType::Int64) => {
+ process_primitive_maps!(
+ Int32Builder,
+ append_ints_to_builder,
+ Int64Builder,
+ append_longs_to_builder
+ );
+ }
+ // For other types, fall back to per-row dispatch
+ _ => {
+ for i in row_start..row_end {
+ read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i);
+
+ if row.is_null_at(column_idx) {
+ map_builder.append(false)?;
+ } else {
+ append_map_elements(field, map_builder,
&row.get_map(column_idx))?;
+ }
+ }
+ }
+ }
+
+ Ok(())
+}
+
+/// Appends struct fields to the struct builder using field-major order.
+/// This processes one field at a time across all rows, which moves type
dispatch
+/// outside the row loop (O(fields) dispatches instead of O(rows × fields)).
+#[allow(clippy::redundant_closure_call, clippy::too_many_arguments)]
+fn append_struct_fields_field_major(
+ row_addresses_ptr: *mut jlong,
+ row_sizes_ptr: *mut jint,
+ row_start: usize,
+ row_end: usize,
+ parent_row: &mut SparkUnsafeRow,
+ column_idx: usize,
+ struct_builder: &mut StructBuilder,
+ fields: &arrow::datatypes::Fields,
+) -> Result<(), CometError> {
+ let num_rows = row_end - row_start;
+ let num_fields = fields.len();
+
+ // First pass: Build struct validity and collect which structs are null
+ // We use a Vec<bool> for simplicity; could use a bitset for better memory
+ let mut struct_is_null = Vec::with_capacity(num_rows);
+
+ for i in row_start..row_end {
+ read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i);
+
+ let is_null = parent_row.is_null_at(column_idx);
+ struct_is_null.push(is_null);
+
+ if is_null {
+ struct_builder.append_null();
+ } else {
+ struct_builder.append(true);
+ }
+ }
+
+ // Helper macro for processing primitive fields
+ macro_rules! process_field {
+ ($builder_type:ty, $field_idx:expr, $get_value:expr) => {{
+ let field_builder = get_field_builder!(struct_builder,
$builder_type, $field_idx);
+
+ for (row_idx, i) in (row_start..row_end).enumerate() {
+ if struct_is_null[row_idx] {
+ // Struct is null, field is also null
+ field_builder.append_null();
+ } else {
+ read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr,
i);
+ let nested_row = parent_row.get_struct(column_idx,
num_fields);
+
+ if nested_row.is_null_at($field_idx) {
+ field_builder.append_null();
+ } else {
+ field_builder.append_value($get_value(&nested_row,
$field_idx));
+ }
+ }
+ }
+ }};
+ }
+
+ // Second pass: Process each field across all rows
+ for (field_idx, field) in fields.iter().enumerate() {
+ match field.data_type() {
+ DataType::Boolean => {
+ process_field!(BooleanBuilder, field_idx, |row:
&SparkUnsafeRow, idx| row
+ .get_boolean(idx));
+ }
+ DataType::Int8 => {
+ process_field!(Int8Builder, field_idx, |row: &SparkUnsafeRow,
idx| row
+ .get_byte(idx));
+ }
+ DataType::Int16 => {
+ process_field!(Int16Builder, field_idx, |row: &SparkUnsafeRow,
idx| row
+ .get_short(idx));
+ }
+ DataType::Int32 => {
+ process_field!(Int32Builder, field_idx, |row: &SparkUnsafeRow,
idx| row
+ .get_int(idx));
+ }
+ DataType::Int64 => {
+ process_field!(Int64Builder, field_idx, |row: &SparkUnsafeRow,
idx| row
+ .get_long(idx));
+ }
+ DataType::Float32 => {
+ process_field!(Float32Builder, field_idx, |row:
&SparkUnsafeRow, idx| row
+ .get_float(idx));
+ }
+ DataType::Float64 => {
+ process_field!(Float64Builder, field_idx, |row:
&SparkUnsafeRow, idx| row
+ .get_double(idx));
+ }
+ DataType::Date32 => {
+ process_field!(Date32Builder, field_idx, |row:
&SparkUnsafeRow, idx| row
+ .get_date(idx));
+ }
+ DataType::Timestamp(TimeUnit::Microsecond, _) => {
+ process_field!(
+ TimestampMicrosecondBuilder,
+ field_idx,
+ |row: &SparkUnsafeRow, idx| row.get_timestamp(idx)
+ );
+ }
+ DataType::Binary => {
+ let field_builder = get_field_builder!(struct_builder,
BinaryBuilder, field_idx);
+
+ for (row_idx, i) in (row_start..row_end).enumerate() {
+ if struct_is_null[row_idx] {
+ field_builder.append_null();
+ } else {
+ read_row_at!(parent_row, row_addresses_ptr,
row_sizes_ptr, i);
+ let nested_row = parent_row.get_struct(column_idx,
num_fields);
+
+ if nested_row.is_null_at(field_idx) {
+ field_builder.append_null();
+ } else {
+
field_builder.append_value(nested_row.get_binary(field_idx));
+ }
+ }
+ }
+ }
+ DataType::Utf8 => {
+ let field_builder = get_field_builder!(struct_builder,
StringBuilder, field_idx);
+
+ for (row_idx, i) in (row_start..row_end).enumerate() {
+ if struct_is_null[row_idx] {
+ field_builder.append_null();
+ } else {
+ read_row_at!(parent_row, row_addresses_ptr,
row_sizes_ptr, i);
+ let nested_row = parent_row.get_struct(column_idx,
num_fields);
+
+ if nested_row.is_null_at(field_idx) {
+ field_builder.append_null();
+ } else {
+
field_builder.append_value(nested_row.get_string(field_idx));
+ }
+ }
+ }
+ }
+ DataType::Decimal128(p, _) => {
+ let p = *p;
+ let field_builder =
+ get_field_builder!(struct_builder, Decimal128Builder,
field_idx);
+
+ for (row_idx, i) in (row_start..row_end).enumerate() {
+ if struct_is_null[row_idx] {
+ field_builder.append_null();
+ } else {
+ read_row_at!(parent_row, row_addresses_ptr,
row_sizes_ptr, i);
+ let nested_row = parent_row.get_struct(column_idx,
num_fields);
+
+ if nested_row.is_null_at(field_idx) {
+ field_builder.append_null();
+ } else {
+
field_builder.append_value(nested_row.get_decimal(field_idx, p));
+ }
+ }
+ }
+ }
+ // For nested structs, apply field-major processing recursively
+ DataType::Struct(nested_fields) => {
+ let nested_builder = get_field_builder!(struct_builder,
StructBuilder, field_idx);
+
+ // Collect nested struct addresses and sizes in one pass,
building validity
+ let mut nested_addresses: Vec<jlong> =
Vec::with_capacity(num_rows);
+ let mut nested_sizes: Vec<jint> = Vec::with_capacity(num_rows);
+ let mut nested_is_null: Vec<bool> =
Vec::with_capacity(num_rows);
+
+ for (row_idx, i) in (row_start..row_end).enumerate() {
+ if struct_is_null[row_idx] {
+ // Parent struct is null, nested struct is also null
+ nested_builder.append_null();
+ nested_is_null.push(true);
+ nested_addresses.push(0);
+ nested_sizes.push(0);
+ } else {
+ read_row_at!(parent_row, row_addresses_ptr,
row_sizes_ptr, i);
+ let parent_struct = parent_row.get_struct(column_idx,
num_fields);
+
+ if parent_struct.is_null_at(field_idx) {
+ nested_builder.append_null();
+ nested_is_null.push(true);
+ nested_addresses.push(0);
+ nested_sizes.push(0);
+ } else {
+ nested_builder.append(true);
+ nested_is_null.push(false);
+ // Get nested struct address and size
+ let nested_row =
+ parent_struct.get_struct(field_idx,
nested_fields.len());
+ nested_addresses.push(nested_row.get_row_addr());
+ nested_sizes.push(nested_row.get_row_size());
+ }
+ }
+ }
+
+ // Recursively process nested struct fields in field-major
order
+ append_nested_struct_fields_field_major(
+ &nested_addresses,
+ &nested_sizes,
+ &nested_is_null,
+ nested_builder,
+ nested_fields,
+ )?;
+ }
+ // For list and map, fall back to append_field since they have
variable-length elements
+ dt @ (DataType::List(_) | DataType::Map(_, _)) => {
+ for (row_idx, i) in (row_start..row_end).enumerate() {
+ if struct_is_null[row_idx] {
+ let null_row = SparkUnsafeRow::default();
+ append_field(dt, struct_builder, &null_row,
field_idx)?;
+ } else {
+ read_row_at!(parent_row, row_addresses_ptr,
row_sizes_ptr, i);
+ let nested_row = parent_row.get_struct(column_idx,
num_fields);
+ append_field(dt, struct_builder, &nested_row,
field_idx)?;
+ }
+ }
+ }
+ _ => {
+ unreachable!(
+ "Unsupported data type of struct field: {:?}",
+ field.data_type()
+ )
+ }
+ }
+ }
+
+ Ok(())
+}
+
/// Appends column of top rows to the given array builder.
+///
+/// # Safety
+///
+/// The caller must ensure:
+/// - `row_addresses_ptr` points to an array of at least `row_end` jlong values
+/// - `row_sizes_ptr` points to an array of at least `row_end` jint values
+/// - Each address in `row_addresses_ptr[row_start..row_end]` points to valid
Spark UnsafeRow data
+/// - The memory remains valid for the duration of this function call
+///
+/// These invariants are guaranteed when called from JNI with arrays provided
by the JVM.
#[allow(clippy::redundant_closure_call, clippy::too_many_arguments)]
fn append_columns(
row_addresses_ptr: *mut jlong,
@@ -523,23 +1215,7 @@ fn append_columns(
let mut row = SparkUnsafeRow::new(schema);
for i in row_start..row_end {
- // SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays
with at least
- // row_end elements. i is in [row_start, row_end) so the
offset is in bounds.
- debug_assert!(
- !row_addresses_ptr.is_null(),
- "append_columns: null row_addresses_ptr"
- );
- debug_assert!(
- !row_sizes_ptr.is_null(),
- "append_columns: null row_sizes_ptr"
- );
- debug_assert!(
- i < row_end,
- "append_columns: index {i} out of bounds
(row_end={row_end})"
- );
- let row_addr = unsafe { *row_addresses_ptr.add(i) };
- let row_size = unsafe { *row_sizes_ptr.add(i) };
- row.point_to(row_addr, row_size);
+ read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i);
let is_null = row.is_null_at(column_idx);
@@ -664,75 +1340,31 @@ fn append_columns(
MapBuilder<Box<dyn ArrayBuilder>, Box<dyn ArrayBuilder>>,
builder
);
- let mut row = SparkUnsafeRow::new(schema);
-
- for i in row_start..row_end {
- // SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays
with at least
- // row_end elements. i is in [row_start, row_end) so the
offset is in bounds.
- debug_assert!(
- !row_addresses_ptr.is_null(),
- "append_columns: null row_addresses_ptr"
- );
- debug_assert!(
- !row_sizes_ptr.is_null(),
- "append_columns: null row_sizes_ptr"
- );
- debug_assert!(
- i < row_end,
- "append_columns: index {i} out of bounds
(row_end={row_end})"
- );
- let row_addr = unsafe { *row_addresses_ptr.add(i) };
- let row_size = unsafe { *row_sizes_ptr.add(i) };
- row.point_to(row_addr, row_size);
-
- let is_null = row.is_null_at(column_idx);
-
- if is_null {
- // The map is null.
- // Append a null value to the map builder.
- map_builder.append(false)?;
- } else {
- append_map_elements(field, map_builder,
&row.get_map(column_idx))?
- }
- }
+ // Use batched processing for better performance
+ append_map_column_batch(
+ row_addresses_ptr,
+ row_sizes_ptr,
+ row_start,
+ row_end,
+ schema,
+ column_idx,
+ field,
+ map_builder,
+ )?;
}
DataType::List(field) => {
let list_builder = downcast_builder_ref!(ListBuilder<Box<dyn
ArrayBuilder>>, builder);
- let mut row = SparkUnsafeRow::new(schema);
-
- for i in row_start..row_end {
- // SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays
with at least
- // row_end elements. i is in [row_start, row_end) so the
offset is in bounds.
- debug_assert!(
- !row_addresses_ptr.is_null(),
- "append_columns: null row_addresses_ptr"
- );
- debug_assert!(
- !row_sizes_ptr.is_null(),
- "append_columns: null row_sizes_ptr"
- );
- debug_assert!(
- i < row_end,
- "append_columns: index {i} out of bounds
(row_end={row_end})"
- );
- let row_addr = unsafe { *row_addresses_ptr.add(i) };
- let row_size = unsafe { *row_sizes_ptr.add(i) };
- row.point_to(row_addr, row_size);
-
- let is_null = row.is_null_at(column_idx);
-
- if is_null {
- // The list is null.
- // Append a null value to the list builder.
- list_builder.append_null();
- } else {
- append_list_element(
- field.data_type(),
- list_builder,
- &row.get_array(column_idx),
- )?
- }
- }
+ // Use batched processing for better performance
+ append_list_column_batch(
+ row_addresses_ptr,
+ row_sizes_ptr,
+ row_start,
+ row_end,
+ schema,
+ column_idx,
+ field.data_type(),
+ list_builder,
+ )?;
}
DataType::Struct(fields) => {
let struct_builder = builder
@@ -741,41 +1373,17 @@ fn append_columns(
.expect("StructBuilder");
let mut row = SparkUnsafeRow::new(schema);
- for i in row_start..row_end {
- // SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays
with at least
- // row_end elements. i is in [row_start, row_end) so the
offset is in bounds.
- debug_assert!(
- !row_addresses_ptr.is_null(),
- "append_columns: null row_addresses_ptr"
- );
- debug_assert!(
- !row_sizes_ptr.is_null(),
- "append_columns: null row_sizes_ptr"
- );
- debug_assert!(
- i < row_end,
- "append_columns: index {i} out of bounds
(row_end={row_end})"
- );
- let row_addr = unsafe { *row_addresses_ptr.add(i) };
- let row_size = unsafe { *row_sizes_ptr.add(i) };
- row.point_to(row_addr, row_size);
-
- let is_null = row.is_null_at(column_idx);
-
- let nested_row = if is_null {
- // The struct is null.
- // Append a null value to the struct builder and field
builders.
- struct_builder.append_null();
- SparkUnsafeRow::default()
- } else {
- struct_builder.append(true);
- row.get_struct(column_idx, fields.len())
- };
-
- for (idx, field) in fields.into_iter().enumerate() {
- append_field(field.data_type(), struct_builder,
&nested_row, idx)?;
- }
- }
+ // Use field-major processing to avoid per-row type dispatch
+ append_struct_fields_field_major(
+ row_addresses_ptr,
+ row_sizes_ptr,
+ row_start,
+ row_end,
+ &mut row,
+ column_idx,
+ struct_builder,
+ fields,
+ )?;
}
_ => {
unreachable!("Unsupported data type of column: {:?}", dt)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]