This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 51653dd68 feat: Native columnar to row conversion (Phase 2) (#3266)
51653dd68 is described below
commit 51653dd68e6b3e51b1396a72139abc402026695a
Author: Andy Grove <[email protected]>
AuthorDate: Mon Jan 26 18:58:42 2026 -0700
feat: Native columnar to row conversion (Phase 2) (#3266)
---
.../org/apache/spark/sql/comet/util/Utils.scala | 4 +-
native/core/src/execution/columnar_to_row.rs | 1407 +++++++++-----------
.../rules/EliminateRedundantTransitions.scala | 32 +-
.../sql/comet/CometNativeColumnarToRowExec.scala | 138 +-
.../org/apache/comet/CometExpressionSuite.scala | 10 +-
.../org/apache/comet/exec/CometExecSuite.scala | 31 +-
.../scala/org/apache/spark/sql/CometTestBase.scala | 1 +
.../apache/spark/sql/comet/CometPlanChecker.scala | 2 +-
8 files changed, 828 insertions(+), 797 deletions(-)
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
index cb0a51944..7662b219c 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
@@ -26,7 +26,7 @@ import java.nio.channels.Channels
import scala.jdk.CollectionConverters._
import org.apache.arrow.c.CDataDictionaryProvider
-import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, IntVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector, VectorSchemaRoot}
+import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector, VectorSchemaRoot}
import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.arrow.vector.ipc.ArrowStreamWriter
@@ -288,7 +288,7 @@ object Utils extends CometTypeShim {
              _: BigIntVector | _: Float4Vector | _: Float8Vector | _: VarCharVector |
              _: DecimalVector | _: DateDayVector | _: TimeStampMicroTZVector | _: VarBinaryVector |
              _: FixedSizeBinaryVector | _: TimeStampMicroVector | _: StructVector | _: ListVector |
-              _: MapVector) =>
+              _: MapVector | _: NullVector) =>
        v.asInstanceOf[FieldVector]
      case _ =>
        throw new SparkException(s"Unsupported Arrow Vector for $reason: ${valueVector.getClass}")
diff --git a/native/core/src/execution/columnar_to_row.rs b/native/core/src/execution/columnar_to_row.rs
index 78ab7637e..66b53af2b 100644
--- a/native/core/src/execution/columnar_to_row.rs
+++ b/native/core/src/execution/columnar_to_row.rs
@@ -41,16 +41,117 @@ use arrow::array::types::{
UInt64Type, UInt8Type,
};
use arrow::array::*;
+use arrow::compute::{cast_with_options, CastOptions};
use arrow::datatypes::{ArrowNativeType, DataType, TimeUnit};
use std::sync::Arc;
/// Maximum digits for decimal that can fit in a long (8 bytes).
const MAX_LONG_DIGITS: u8 = 18;
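// (For reference: i64::MAX is 9_223_372_036_854_775_807, a 19-digit number, so any
// 18-digit unscaled decimal is guaranteed to fit in a long, while 19-digit values
// may overflow.)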
+/// Helper macro for downcasting arrays with consistent error messages.
+macro_rules! downcast_array {
+ ($array:expr, $array_type:ty) => {
+ $array
+ .as_any()
+ .downcast_ref::<$array_type>()
+ .ok_or_else(|| {
+ CometError::Internal(format!(
+ "Failed to downcast to {}, actual type: {:?}",
+ stringify!($array_type),
+ $array.data_type()
+ ))
+ })
+ };
+}
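// Illustrative usage (not part of this change): the macro expands to a
// `CometResult<&ConcreteArray>`, so call sites below reduce to
//     let arr = downcast_array!(array, Int64Array)?;
// in place of the repeated `as_any().downcast_ref::<T>().ok_or_else(..)` chains.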
+
+/// Macro to implement is_null for typed array enums.
+/// Generates a complete match expression for all variants that have an array as first field.
+macro_rules! impl_is_null {
+    ($self:expr, $row_idx:expr, [$($variant:ident),+ $(,)?]) => {
+        match $self {
+            $(Self::$variant(arr, ..) => arr.is_null($row_idx),)+
+        }
+    };
+    // Version with special handling for the Null variant
+    ($self:expr, $row_idx:expr, null_always_true, [$($variant:ident),+ $(,)?]) => {
+ match $self {
+ Self::Null => true,
+ $(Self::$variant(arr, ..) => arr.is_null($row_idx),)+
+ }
+ };
+}
+
+/// Macro to generate TypedElements::from_array match arms for primitive types.
+macro_rules! typed_elements_from_primitive {
+    ($array:expr, $element_type:expr, $(($dt:pat, $variant:ident, $arr_type:ty)),+ $(,)?) => {
+        match $element_type {
+            $(
+                $dt => {
+                    if let Some(arr) = $array.as_any().downcast_ref::<$arr_type>() {
+ return TypedElements::$variant(arr);
+ }
+ }
+ )+
+ _ => {}
+ }
+ };
+}
+
+/// Macro for write_column_fixed_width arms - handles downcast + loop pattern.
+macro_rules! write_fixed_column_primitive {
+    ($self:expr, $array:expr, $row_size:expr, $field_offset:expr, $num_rows:expr,
+     $arr_type:ty, $to_i64:expr) => {{
+        let arr = downcast_array!($array, $arr_type)?;
+        for row_idx in 0..$num_rows {
+            if !arr.is_null(row_idx) {
+                let offset = row_idx * $row_size + $field_offset;
+                let value: i64 = $to_i64(arr.value(row_idx));
+                $self.buffer[offset..offset + 8].copy_from_slice(&value.to_le_bytes());
+ }
+ }
+ Ok(())
+ }};
+}
+
+/// Macro for get_field_value arms - handles downcast + value extraction.
+macro_rules! get_field_value_primitive {
+ ($array:expr, $row_idx:expr, $arr_type:ty, $to_i64:expr) => {{
+ let arr = downcast_array!($array, $arr_type)?;
+ Ok($to_i64(arr.value($row_idx)))
+ }};
+}
+
+/// Macro for write_struct_to_buffer fixed-width field extraction.
+macro_rules! extract_fixed_value {
+    ($column:expr, $row_idx:expr, $(($dt:pat, $arr_type:ty, $to_i64:expr)),+ $(,)?) => {
+ match $column.data_type() {
+ $(
+ $dt => {
+ let arr = downcast_array!($column, $arr_type)?;
+ Some($to_i64(arr.value($row_idx)))
+ }
+ )+
+ _ => None,
+ }
+ };
+}
+
+/// Writes bytes to buffer with 8-byte alignment padding.
+/// Returns the unpadded length.
+#[inline]
+fn write_bytes_padded(buffer: &mut Vec<u8>, bytes: &[u8]) -> usize {
+ let len = bytes.len();
+ buffer.extend_from_slice(bytes);
+ let padding = round_up_to_8(len) - len;
+ buffer.extend(std::iter::repeat_n(0u8, padding));
+ len
+}
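// Worked example (illustrative): writing a 3-byte value appends the 3 data bytes
// plus 5 zero bytes so the buffer stays 8-byte aligned, and returns 3 -- the
// unpadded length that callers pack into the offset-and-length word.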
+
/// Pre-downcast array reference to avoid type dispatch in inner loops.
/// This enum holds references to concrete array types, allowing direct access
/// without repeated downcast_ref calls.
enum TypedArray<'a> {
+ Null,
Boolean(&'a BooleanArray),
Int8(&'a Int8Array),
Int16(&'a Int16Array),
@@ -65,6 +166,7 @@ enum TypedArray<'a> {
LargeString(&'a LargeStringArray),
Binary(&'a BinaryArray),
LargeBinary(&'a LargeBinaryArray),
+ FixedSizeBinary(&'a FixedSizeBinaryArray),
Struct(
&'a StructArray,
arrow::datatypes::Fields,
@@ -78,119 +180,46 @@ enum TypedArray<'a> {
impl<'a> TypedArray<'a> {
/// Pre-downcast an ArrayRef to a TypedArray.
-    fn from_array(array: &'a ArrayRef, schema_type: &DataType) -> CometResult<Self> {
+    fn from_array(array: &'a ArrayRef) -> CometResult<Self> {
let actual_type = array.data_type();
match actual_type {
-            DataType::Boolean => Ok(TypedArray::Boolean(
-                array
-                    .as_any()
-                    .downcast_ref::<BooleanArray>()
-                    .ok_or_else(|| {
-                        CometError::Internal("Failed to downcast to BooleanArray".to_string())
-                    })?,
-            )),
-            DataType::Int8 => Ok(TypedArray::Int8(
-                array.as_any().downcast_ref::<Int8Array>().ok_or_else(|| {
-                    CometError::Internal("Failed to downcast to Int8Array".to_string())
-                })?,
-            )),
-            DataType::Int16 => Ok(TypedArray::Int16(
-                array.as_any().downcast_ref::<Int16Array>().ok_or_else(|| {
-                    CometError::Internal("Failed to downcast to Int16Array".to_string())
-                })?,
-            )),
-            DataType::Int32 => Ok(TypedArray::Int32(
-                array.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
-                    CometError::Internal("Failed to downcast to Int32Array".to_string())
-                })?,
-            )),
-            DataType::Int64 => Ok(TypedArray::Int64(
-                array.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
-                    CometError::Internal("Failed to downcast to Int64Array".to_string())
-                })?,
-            )),
-            DataType::Float32 => Ok(TypedArray::Float32(
-                array
-                    .as_any()
-                    .downcast_ref::<Float32Array>()
-                    .ok_or_else(|| {
-                        CometError::Internal("Failed to downcast to Float32Array".to_string())
-                    })?,
-            )),
-            DataType::Float64 => Ok(TypedArray::Float64(
-                array
-                    .as_any()
-                    .downcast_ref::<Float64Array>()
-                    .ok_or_else(|| {
-                        CometError::Internal("Failed to downcast to Float64Array".to_string())
-                    })?,
-            )),
-            DataType::Date32 => Ok(TypedArray::Date32(
-                array
-                    .as_any()
-                    .downcast_ref::<Date32Array>()
-                    .ok_or_else(|| {
-                        CometError::Internal("Failed to downcast to Date32Array".to_string())
-                    })?,
-            )),
+            DataType::Null => {
+                // Verify the array is actually a NullArray, but we don't need to store the reference
+                // since all values are null by definition
+                downcast_array!(array, NullArray)?;
+                Ok(TypedArray::Null)
+            }
+            DataType::Boolean => Ok(TypedArray::Boolean(downcast_array!(array, BooleanArray)?)),
+            DataType::Int8 => Ok(TypedArray::Int8(downcast_array!(array, Int8Array)?)),
+            DataType::Int16 => Ok(TypedArray::Int16(downcast_array!(array, Int16Array)?)),
+            DataType::Int32 => Ok(TypedArray::Int32(downcast_array!(array, Int32Array)?)),
+            DataType::Int64 => Ok(TypedArray::Int64(downcast_array!(array, Int64Array)?)),
+            DataType::Float32 => Ok(TypedArray::Float32(downcast_array!(array, Float32Array)?)),
+            DataType::Float64 => Ok(TypedArray::Float64(downcast_array!(array, Float64Array)?)),
+            DataType::Date32 => Ok(TypedArray::Date32(downcast_array!(array, Date32Array)?)),
            DataType::Timestamp(TimeUnit::Microsecond, _) => Ok(TypedArray::TimestampMicro(
-                array
-                    .as_any()
-                    .downcast_ref::<TimestampMicrosecondArray>()
-                    .ok_or_else(|| {
-                        CometError::Internal(
-                            "Failed to downcast to TimestampMicrosecondArray".to_string(),
-                        )
-                    })?,
+                downcast_array!(array, TimestampMicrosecondArray)?,
            )),
            DataType::Decimal128(p, _) => Ok(TypedArray::Decimal128(
-                array
-                    .as_any()
-                    .downcast_ref::<Decimal128Array>()
-                    .ok_or_else(|| {
-                        CometError::Internal("Failed to downcast to Decimal128Array".to_string())
-                    })?,
+                downcast_array!(array, Decimal128Array)?,
                *p,
            )),
-            DataType::Utf8 => Ok(TypedArray::String(
-                array
-                    .as_any()
-                    .downcast_ref::<StringArray>()
-                    .ok_or_else(|| {
-                        CometError::Internal("Failed to downcast to StringArray".to_string())
-                    })?,
-            )),
-            DataType::LargeUtf8 => Ok(TypedArray::LargeString(
-                array
-                    .as_any()
-                    .downcast_ref::<LargeStringArray>()
-                    .ok_or_else(|| {
-                        CometError::Internal("Failed to downcast to LargeStringArray".to_string())
-                    })?,
-            )),
-            DataType::Binary => Ok(TypedArray::Binary(
-                array
-                    .as_any()
-                    .downcast_ref::<BinaryArray>()
-                    .ok_or_else(|| {
-                        CometError::Internal("Failed to downcast to BinaryArray".to_string())
-                    })?,
-            )),
-            DataType::LargeBinary => Ok(TypedArray::LargeBinary(
-                array
-                    .as_any()
-                    .downcast_ref::<LargeBinaryArray>()
-                    .ok_or_else(|| {
-                        CometError::Internal("Failed to downcast to LargeBinaryArray".to_string())
-                    })?,
-            )),
+            DataType::Utf8 => Ok(TypedArray::String(downcast_array!(array, StringArray)?)),
+            DataType::LargeUtf8 => Ok(TypedArray::LargeString(downcast_array!(
+                array,
+                LargeStringArray
+            )?)),
+            DataType::Binary => Ok(TypedArray::Binary(downcast_array!(array, BinaryArray)?)),
+            DataType::LargeBinary => Ok(TypedArray::LargeBinary(downcast_array!(
+                array,
+                LargeBinaryArray
+            )?)),
+            DataType::FixedSizeBinary(_) => Ok(TypedArray::FixedSizeBinary(downcast_array!(
+                array,
+                FixedSizeBinaryArray
+            )?)),
            DataType::Struct(fields) => {
-                let struct_arr = array
-                    .as_any()
-                    .downcast_ref::<StructArray>()
-                    .ok_or_else(|| {
-                        CometError::Internal("Failed to downcast to StructArray".to_string())
-                    })?;
+                let struct_arr = downcast_array!(array, StructArray)?;
// Pre-downcast all struct fields once
let typed_fields: Vec<TypedElements> = fields
.iter()
@@ -202,27 +231,18 @@ impl<'a> TypedArray<'a> {
            Ok(TypedArray::Struct(struct_arr, fields.clone(), typed_fields))
}
DataType::List(field) => Ok(TypedArray::List(
-                array.as_any().downcast_ref::<ListArray>().ok_or_else(|| {
-                    CometError::Internal("Failed to downcast to ListArray".to_string())
-                })?,
+ downcast_array!(array, ListArray)?,
Arc::clone(field),
)),
DataType::LargeList(field) => Ok(TypedArray::LargeList(
-                array
-                    .as_any()
-                    .downcast_ref::<LargeListArray>()
-                    .ok_or_else(|| {
-                        CometError::Internal("Failed to downcast to LargeListArray".to_string())
-                    })?,
+ downcast_array!(array, LargeListArray)?,
Arc::clone(field),
)),
DataType::Map(field, _) => Ok(TypedArray::Map(
-                array.as_any().downcast_ref::<MapArray>().ok_or_else(|| {
-                    CometError::Internal("Failed to downcast to MapArray".to_string())
-                })?,
+ downcast_array!(array, MapArray)?,
Arc::clone(field),
)),
-            DataType::Dictionary(_, _) => Ok(TypedArray::Dictionary(array, schema_type.clone())),
+            DataType::Dictionary(_, _) => Ok(TypedArray::Dictionary(array, actual_type.clone())),
_ => Err(CometError::Internal(format!(
"Unsupported data type for pre-downcast: {:?}",
actual_type
@@ -233,27 +253,33 @@ impl<'a> TypedArray<'a> {
/// Check if the value at the given index is null.
#[inline]
fn is_null(&self, row_idx: usize) -> bool {
- match self {
- TypedArray::Boolean(arr) => arr.is_null(row_idx),
- TypedArray::Int8(arr) => arr.is_null(row_idx),
- TypedArray::Int16(arr) => arr.is_null(row_idx),
- TypedArray::Int32(arr) => arr.is_null(row_idx),
- TypedArray::Int64(arr) => arr.is_null(row_idx),
- TypedArray::Float32(arr) => arr.is_null(row_idx),
- TypedArray::Float64(arr) => arr.is_null(row_idx),
- TypedArray::Date32(arr) => arr.is_null(row_idx),
- TypedArray::TimestampMicro(arr) => arr.is_null(row_idx),
- TypedArray::Decimal128(arr, _) => arr.is_null(row_idx),
- TypedArray::String(arr) => arr.is_null(row_idx),
- TypedArray::LargeString(arr) => arr.is_null(row_idx),
- TypedArray::Binary(arr) => arr.is_null(row_idx),
- TypedArray::LargeBinary(arr) => arr.is_null(row_idx),
- TypedArray::Struct(arr, _, _) => arr.is_null(row_idx),
- TypedArray::List(arr, _) => arr.is_null(row_idx),
- TypedArray::LargeList(arr, _) => arr.is_null(row_idx),
- TypedArray::Map(arr, _) => arr.is_null(row_idx),
- TypedArray::Dictionary(arr, _) => arr.is_null(row_idx),
- }
+ impl_is_null!(
+ self,
+ row_idx,
+ null_always_true,
+ [
+ Boolean,
+ Int8,
+ Int16,
+ Int32,
+ Int64,
+ Float32,
+ Float64,
+ Date32,
+ TimestampMicro,
+ Decimal128,
+ String,
+ LargeString,
+ Binary,
+ LargeBinary,
+ FixedSizeBinary,
+ Struct,
+ List,
+ LargeList,
+ Map,
+ Dictionary
+ ]
+ )
}
/// Get the fixed-width value as i64 (for types that fit in 8 bytes).
@@ -291,7 +317,8 @@ impl<'a> TypedArray<'a> {
#[inline]
fn is_variable_length(&self) -> bool {
match self {
- TypedArray::Boolean(_)
+ TypedArray::Null
+ | TypedArray::Boolean(_)
| TypedArray::Int8(_)
| TypedArray::Int16(_)
| TypedArray::Int32(_)
@@ -309,44 +336,17 @@ impl<'a> TypedArray<'a> {
    fn write_variable_to_buffer(&self, buffer: &mut Vec<u8>, row_idx: usize) -> CometResult<usize> {
match self {
TypedArray::String(arr) => {
- let bytes = arr.value(row_idx).as_bytes();
- let len = bytes.len();
- buffer.extend_from_slice(bytes);
- let padding = round_up_to_8(len) - len;
- buffer.extend(std::iter::repeat_n(0u8, padding));
- Ok(len)
+ Ok(write_bytes_padded(buffer, arr.value(row_idx).as_bytes()))
}
TypedArray::LargeString(arr) => {
- let bytes = arr.value(row_idx).as_bytes();
- let len = bytes.len();
- buffer.extend_from_slice(bytes);
- let padding = round_up_to_8(len) - len;
- buffer.extend(std::iter::repeat_n(0u8, padding));
- Ok(len)
- }
- TypedArray::Binary(arr) => {
- let bytes = arr.value(row_idx);
- let len = bytes.len();
- buffer.extend_from_slice(bytes);
- let padding = round_up_to_8(len) - len;
- buffer.extend(std::iter::repeat_n(0u8, padding));
- Ok(len)
- }
- TypedArray::LargeBinary(arr) => {
- let bytes = arr.value(row_idx);
- let len = bytes.len();
- buffer.extend_from_slice(bytes);
- let padding = round_up_to_8(len) - len;
- buffer.extend(std::iter::repeat_n(0u8, padding));
- Ok(len)
+ Ok(write_bytes_padded(buffer, arr.value(row_idx).as_bytes()))
}
+            TypedArray::Binary(arr) => Ok(write_bytes_padded(buffer, arr.value(row_idx))),
+            TypedArray::LargeBinary(arr) => Ok(write_bytes_padded(buffer, arr.value(row_idx))),
+            TypedArray::FixedSizeBinary(arr) => Ok(write_bytes_padded(buffer, arr.value(row_idx))),
            TypedArray::Decimal128(arr, precision) if *precision > MAX_LONG_DIGITS => {
let bytes = i128_to_spark_decimal_bytes(arr.value(row_idx));
- let len = bytes.len();
- buffer.extend_from_slice(&bytes);
- let padding = round_up_to_8(len) - len;
- buffer.extend(std::iter::repeat_n(0u8, padding));
- Ok(len)
+ Ok(write_bytes_padded(buffer, &bytes))
}
TypedArray::Struct(arr, fields, typed_fields) => {
                write_struct_to_buffer_typed(buffer, arr, row_idx, fields, typed_fields)
@@ -394,6 +394,7 @@ enum TypedElements<'a> {
LargeString(&'a LargeStringArray),
Binary(&'a BinaryArray),
LargeBinary(&'a LargeBinaryArray),
+ FixedSizeBinary(&'a FixedSizeBinaryArray),
// For nested types, fall back to ArrayRef
Other(&'a ArrayRef, DataType),
}
@@ -401,47 +402,26 @@ enum TypedElements<'a> {
impl<'a> TypedElements<'a> {
/// Create from an ArrayRef and element type.
fn from_array(array: &'a ArrayRef, element_type: &DataType) -> Self {
+ // Try primitive types first using macro
+ typed_elements_from_primitive!(
+ array,
+ element_type,
+ (DataType::Boolean, Boolean, BooleanArray),
+ (DataType::Int8, Int8, Int8Array),
+ (DataType::Int16, Int16, Int16Array),
+ (DataType::Int32, Int32, Int32Array),
+ (DataType::Int64, Int64, Int64Array),
+ (DataType::Float32, Float32, Float32Array),
+ (DataType::Float64, Float64, Float64Array),
+ (DataType::Date32, Date32, Date32Array),
+ (DataType::Utf8, String, StringArray),
+ (DataType::LargeUtf8, LargeString, LargeStringArray),
+ (DataType::Binary, Binary, BinaryArray),
+ (DataType::LargeBinary, LargeBinary, LargeBinaryArray),
+ );
+
+ // Handle special cases that need extra processing
match element_type {
-            DataType::Boolean => {
-                if let Some(arr) = array.as_any().downcast_ref::<BooleanArray>() {
-                    return TypedElements::Boolean(arr);
-                }
-            }
-            DataType::Int8 => {
-                if let Some(arr) = array.as_any().downcast_ref::<Int8Array>() {
-                    return TypedElements::Int8(arr);
-                }
-            }
-            DataType::Int16 => {
-                if let Some(arr) = array.as_any().downcast_ref::<Int16Array>() {
-                    return TypedElements::Int16(arr);
-                }
-            }
-            DataType::Int32 => {
-                if let Some(arr) = array.as_any().downcast_ref::<Int32Array>() {
-                    return TypedElements::Int32(arr);
-                }
-            }
-            DataType::Int64 => {
-                if let Some(arr) = array.as_any().downcast_ref::<Int64Array>() {
-                    return TypedElements::Int64(arr);
-                }
-            }
-            DataType::Float32 => {
-                if let Some(arr) = array.as_any().downcast_ref::<Float32Array>() {
-                    return TypedElements::Float32(arr);
-                }
-            }
-            DataType::Float64 => {
-                if let Some(arr) = array.as_any().downcast_ref::<Float64Array>() {
-                    return TypedElements::Float64(arr);
-                }
-            }
-            DataType::Date32 => {
-                if let Some(arr) = array.as_any().downcast_ref::<Date32Array>() {
-                    return TypedElements::Date32(arr);
-                }
-            }
DataType::Timestamp(TimeUnit::Microsecond, _) => {
                if let Some(arr) = array.as_any().downcast_ref::<TimestampMicrosecondArray>() {
return TypedElements::TimestampMicro(arr);
@@ -452,24 +432,9 @@ impl<'a> TypedElements<'a> {
return TypedElements::Decimal128(arr, *p);
}
}
-            DataType::Utf8 => {
-                if let Some(arr) = array.as_any().downcast_ref::<StringArray>() {
-                    return TypedElements::String(arr);
-                }
-            }
-            DataType::LargeUtf8 => {
-                if let Some(arr) = array.as_any().downcast_ref::<LargeStringArray>() {
-                    return TypedElements::LargeString(arr);
-                }
-            }
-            DataType::Binary => {
-                if let Some(arr) = array.as_any().downcast_ref::<BinaryArray>() {
-                    return TypedElements::Binary(arr);
-                }
-            }
-            DataType::LargeBinary => {
-                if let Some(arr) = array.as_any().downcast_ref::<LargeBinaryArray>() {
-                    return TypedElements::LargeBinary(arr);
+            DataType::FixedSizeBinary(_) => {
+                if let Some(arr) = array.as_any().downcast_ref::<FixedSizeBinaryArray>() {
+                    return TypedElements::FixedSizeBinary(arr);
}
}
_ => {}
@@ -510,23 +475,28 @@ impl<'a> TypedElements<'a> {
/// Check if value at given index is null.
#[inline]
fn is_null_at(&self, idx: usize) -> bool {
- match self {
- TypedElements::Boolean(arr) => arr.is_null(idx),
- TypedElements::Int8(arr) => arr.is_null(idx),
- TypedElements::Int16(arr) => arr.is_null(idx),
- TypedElements::Int32(arr) => arr.is_null(idx),
- TypedElements::Int64(arr) => arr.is_null(idx),
- TypedElements::Float32(arr) => arr.is_null(idx),
- TypedElements::Float64(arr) => arr.is_null(idx),
- TypedElements::Date32(arr) => arr.is_null(idx),
- TypedElements::TimestampMicro(arr) => arr.is_null(idx),
- TypedElements::Decimal128(arr, _) => arr.is_null(idx),
- TypedElements::String(arr) => arr.is_null(idx),
- TypedElements::LargeString(arr) => arr.is_null(idx),
- TypedElements::Binary(arr) => arr.is_null(idx),
- TypedElements::LargeBinary(arr) => arr.is_null(idx),
- TypedElements::Other(arr, _) => arr.is_null(idx),
- }
+ impl_is_null!(
+ self,
+ idx,
+ [
+ Boolean,
+ Int8,
+ Int16,
+ Int32,
+ Int64,
+ Float32,
+ Float64,
+ Date32,
+ TimestampMicro,
+ Decimal128,
+ String,
+ LargeString,
+ Binary,
+ LargeBinary,
+ FixedSizeBinary,
+ Other
+ ]
+ )
}
/// Check if this is a fixed-width type (value fits in 8-byte slot).
@@ -572,55 +542,21 @@ impl<'a> TypedElements<'a> {
}
    /// Write variable-length data to buffer. Returns length written (0 for fixed-width).
- fn write_variable_value(
- &self,
- buffer: &mut Vec<u8>,
- idx: usize,
- base_offset: usize,
- ) -> CometResult<usize> {
+    fn write_variable_value(&self, buffer: &mut Vec<u8>, idx: usize) -> CometResult<usize> {
match self {
- TypedElements::String(arr) => {
- let bytes = arr.value(idx).as_bytes();
- let len = bytes.len();
- buffer.extend_from_slice(bytes);
- let padding = round_up_to_8(len) - len;
- buffer.extend(std::iter::repeat_n(0u8, padding));
- Ok(len)
- }
+            TypedElements::String(arr) => Ok(write_bytes_padded(buffer, arr.value(idx).as_bytes())),
TypedElements::LargeString(arr) => {
- let bytes = arr.value(idx).as_bytes();
- let len = bytes.len();
- buffer.extend_from_slice(bytes);
- let padding = round_up_to_8(len) - len;
- buffer.extend(std::iter::repeat_n(0u8, padding));
- Ok(len)
- }
- TypedElements::Binary(arr) => {
- let bytes = arr.value(idx);
- let len = bytes.len();
- buffer.extend_from_slice(bytes);
- let padding = round_up_to_8(len) - len;
- buffer.extend(std::iter::repeat_n(0u8, padding));
- Ok(len)
- }
- TypedElements::LargeBinary(arr) => {
- let bytes = arr.value(idx);
- let len = bytes.len();
- buffer.extend_from_slice(bytes);
- let padding = round_up_to_8(len) - len;
- buffer.extend(std::iter::repeat_n(0u8, padding));
- Ok(len)
+ Ok(write_bytes_padded(buffer, arr.value(idx).as_bytes()))
}
+            TypedElements::Binary(arr) => Ok(write_bytes_padded(buffer, arr.value(idx))),
+            TypedElements::LargeBinary(arr) => Ok(write_bytes_padded(buffer, arr.value(idx))),
+            TypedElements::FixedSizeBinary(arr) => Ok(write_bytes_padded(buffer, arr.value(idx))),
            TypedElements::Decimal128(arr, precision) if *precision > MAX_LONG_DIGITS => {
let bytes = i128_to_spark_decimal_bytes(arr.value(idx));
- let len = bytes.len();
- buffer.extend_from_slice(&bytes);
- let padding = round_up_to_8(len) - len;
- buffer.extend(std::iter::repeat_n(0u8, padding));
- Ok(len)
+ Ok(write_bytes_padded(buffer, &bytes))
}
TypedElements::Other(arr, element_type) => {
-                write_nested_variable_to_buffer(buffer, element_type, arr, idx, base_offset)
+ write_nested_variable_to_buffer(buffer, element_type, arr, idx)
}
_ => Ok(0), // Fixed-width types
}
@@ -771,11 +707,7 @@ impl<'a> TypedElements<'a> {
                        set_null_bit(buffer, null_bitset_start, i);
                    } else {
                        let bytes = i128_to_spark_decimal_bytes(arr.value(src_idx));
-                        let len = bytes.len();
-                        buffer.extend_from_slice(&bytes);
-                        let padding = round_up_to_8(len) - len;
-                        buffer.extend(std::iter::repeat_n(0u8, padding));
-
+                        let len = write_bytes_padded(buffer, &bytes);
                        let data_offset = buffer.len() - round_up_to_8(len) - array_start;
                        let offset_and_len = ((data_offset as i64) << 32) | (len as i64);
                        let slot_offset = elements_start + i * 8;
@@ -790,12 +722,7 @@ impl<'a> TypedElements<'a> {
                    if arr.is_null(src_idx) {
                        set_null_bit(buffer, null_bitset_start, i);
                    } else {
-                        let bytes = arr.value(src_idx).as_bytes();
-                        let len = bytes.len();
-                        buffer.extend_from_slice(bytes);
-                        let padding = round_up_to_8(len) - len;
-                        buffer.extend(std::iter::repeat_n(0u8, padding));
-
+                        let len = write_bytes_padded(buffer, arr.value(src_idx).as_bytes());
                        let data_offset = buffer.len() - round_up_to_8(len) - array_start;
                        let offset_and_len = ((data_offset as i64) << 32) | (len as i64);
                        let slot_offset = elements_start + i * 8;
@@ -810,12 +737,7 @@ impl<'a> TypedElements<'a> {
                    if arr.is_null(src_idx) {
                        set_null_bit(buffer, null_bitset_start, i);
                    } else {
-                        let bytes = arr.value(src_idx).as_bytes();
-                        let len = bytes.len();
-                        buffer.extend_from_slice(bytes);
-                        let padding = round_up_to_8(len) - len;
-                        buffer.extend(std::iter::repeat_n(0u8, padding));
-
+                        let len = write_bytes_padded(buffer, arr.value(src_idx).as_bytes());
                        let data_offset = buffer.len() - round_up_to_8(len) - array_start;
                        let offset_and_len = ((data_offset as i64) << 32) | (len as i64);
                        let slot_offset = elements_start + i * 8;
@@ -830,12 +752,7 @@ impl<'a> TypedElements<'a> {
                    if arr.is_null(src_idx) {
                        set_null_bit(buffer, null_bitset_start, i);
                    } else {
-                        let bytes = arr.value(src_idx);
-                        let len = bytes.len();
-                        buffer.extend_from_slice(bytes);
-                        let padding = round_up_to_8(len) - len;
-                        buffer.extend(std::iter::repeat_n(0u8, padding));
-
+                        let len = write_bytes_padded(buffer, arr.value(src_idx));
                        let data_offset = buffer.len() - round_up_to_8(len) - array_start;
                        let offset_and_len = ((data_offset as i64) << 32) | (len as i64);
                        let slot_offset = elements_start + i * 8;
@@ -850,12 +767,7 @@ impl<'a> TypedElements<'a> {
                    if arr.is_null(src_idx) {
                        set_null_bit(buffer, null_bitset_start, i);
                    } else {
-                        let bytes = arr.value(src_idx);
-                        let len = bytes.len();
-                        buffer.extend_from_slice(bytes);
-                        let padding = round_up_to_8(len) - len;
-                        buffer.extend(std::iter::repeat_n(0u8, padding));
-
+                        let len = write_bytes_padded(buffer, arr.value(src_idx));
                        let data_offset = buffer.len() - round_up_to_8(len) - array_start;
                        let offset_and_len = ((data_offset as i64) << 32) | (len as i64);
                        let slot_offset = elements_start + i * 8;
@@ -872,13 +784,8 @@ impl<'a> TypedElements<'a> {
                        set_null_bit(buffer, null_bitset_start, i);
                    } else {
                        let slot_offset = elements_start + i * element_size;
-                        let var_len = write_nested_variable_to_buffer(
-                            buffer,
-                            element_type,
-                            arr,
-                            src_idx,
-                            array_start,
-                        )?;
+                        let var_len =
+                            write_nested_variable_to_buffer(buffer, element_type, arr, src_idx)?;
if var_len > 0 {
let padded_len = round_up_to_8(var_len);
@@ -1035,6 +942,16 @@ impl ColumnarToRowContext {
)));
}
+ // Unpack any dictionary arrays to their underlying value type
+ // This is needed because Parquet may return dictionary-encoded arrays
+ // even when the schema expects a specific type like Decimal128
+ let arrays: Vec<ArrayRef> = arrays
+ .iter()
+ .zip(self.schema.iter())
+        .map(|(arr, schema_type)| Self::maybe_cast_to_schema_type(arr, schema_type))
+ .collect::<CometResult<Vec<_>>>()?;
+ let arrays = arrays.as_slice();
+
// Clear previous data
self.buffer.clear();
self.offsets.clear();
@@ -1052,8 +969,7 @@ impl ColumnarToRowContext {
// Pre-downcast all arrays to avoid type dispatch in inner loop
let typed_arrays: Vec<TypedArray> = arrays
.iter()
- .zip(self.schema.iter())
- .map(|(arr, dt)| TypedArray::from_array(arr, dt))
+ .map(TypedArray::from_array)
.collect::<CometResult<Vec<_>>>()?;
    // Pre-compute variable-length column indices (once per batch, not per row)
@@ -1079,6 +995,83 @@ impl ColumnarToRowContext {
Ok((self.buffer.as_ptr(), &self.offsets, &self.lengths))
}
+    /// Casts an array to match the expected schema type if needed.
+    /// This handles cases where:
+    /// 1. Parquet returns dictionary-encoded arrays but the schema expects a non-dictionary type
+    /// 2. Parquet returns NullArray when all values are null, but the schema expects a typed array
+    /// 3. Parquet returns Int32/Int64 for small-precision decimals but schema expects Decimal128
+ fn maybe_cast_to_schema_type(
+ array: &ArrayRef,
+ schema_type: &DataType,
+ ) -> CometResult<ArrayRef> {
+ let actual_type = array.data_type();
+
+ // If types already match, no cast needed
+ if actual_type == schema_type {
+ return Ok(Arc::clone(array));
+ }
+
+ match (actual_type, schema_type) {
+ (DataType::Dictionary(_, _), schema)
+ if !matches!(schema, DataType::Dictionary(_, _)) =>
+ {
+ // Unpack dictionary if the schema type is not a dictionary
+ let options = CastOptions::default();
+ cast_with_options(array, schema_type, &options).map_err(|e| {
+                    CometError::Internal(format!(
+                        "Failed to unpack dictionary array from {:?} to {:?}: {}",
+ actual_type, schema_type, e
+ ))
+ })
+ }
+ (DataType::Null, _) => {
+ // Cast NullArray to the expected schema type
+ // This happens when all values in a column are null
+ let options = CastOptions::default();
+ cast_with_options(array, schema_type, &options).map_err(|e| {
+ CometError::Internal(format!(
+ "Failed to cast NullArray to {:?}: {}",
+ schema_type, e
+ ))
+ })
+ }
+            (DataType::Int32, DataType::Decimal128(precision, scale)) => {
+                // Parquet stores small-precision decimals as Int32 for efficiency.
+                // When COMET_USE_DECIMAL_128 is false, BatchReader produces these types.
+                // The Int32 value is already scaled (e.g., -1 means -0.01 for scale 2).
+                // We need to reinterpret (not cast) to Decimal128 preserving the value.
+                let int_array = array.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
+                    CometError::Internal("Failed to downcast to Int32Array".to_string())
+                })?;
+ let decimal_array: Decimal128Array = int_array
+ .iter()
+ .map(|v| v.map(|x| x as i128))
+ .collect::<Decimal128Array>()
+ .with_precision_and_scale(*precision, *scale)
+                    .map_err(|e| {
+                        CometError::Internal(format!("Invalid decimal precision/scale: {}", e))
+ })?;
+ Ok(Arc::new(decimal_array))
+ }
+            (DataType::Int64, DataType::Decimal128(precision, scale)) => {
+                // Same as Int32 but for medium-precision decimals stored as Int64.
+                let int_array = array.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
+                    CometError::Internal("Failed to downcast to Int64Array".to_string())
+                })?;
+ let decimal_array: Decimal128Array = int_array
+ .iter()
+ .map(|v| v.map(|x| x as i128))
+ .collect::<Decimal128Array>()
+ .with_precision_and_scale(*precision, *scale)
+                    .map_err(|e| {
+                        CometError::Internal(format!("Invalid decimal precision/scale: {}", e))
+ })?;
+ Ok(Arc::new(decimal_array))
+ }
+ _ => Ok(Arc::clone(array)),
+ }
+ }
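    // Worked example for case 3 (illustrative, not part of this change): with a
    // schema type of Decimal128(5, 2), an Int32 value of -1 already encodes the
    // scaled value -0.01, so it is widened to the i128 value -1 and tagged with
    // precision 5 and scale 2; an arithmetic cast would instead scale it to -100
    // (i.e. -1.00) and corrupt the data.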
+
/// Fast path for schemas with only fixed-width columns.
/// Pre-allocates entire buffer and processes more efficiently.
fn convert_fixed_width(
@@ -1153,153 +1146,104 @@ impl ColumnarToRowContext {
// Write non-null values using type-specific fast paths
match data_type {
DataType::Boolean => {
- let arr = array
- .as_any()
- .downcast_ref::<BooleanArray>()
-                    .ok_or_else(|| {
-                        CometError::Internal("Failed to downcast to BooleanArray".to_string())
- })?;
+ // Boolean is special: writes single byte, not 8-byte i64
+ let arr = downcast_array!(array, BooleanArray)?;
for row_idx in 0..num_rows {
if !arr.is_null(row_idx) {
let offset = row_idx * row_size + field_offset_in_row;
                        self.buffer[offset] = if arr.value(row_idx) { 1 } else { 0 };
}
}
+ Ok(())
}
-            DataType::Int8 => {
-                let arr = array.as_any().downcast_ref::<Int8Array>().ok_or_else(|| {
-                    CometError::Internal("Failed to downcast to Int8Array".to_string())
-                })?;
-                for row_idx in 0..num_rows {
-                    if !arr.is_null(row_idx) {
-                        let offset = row_idx * row_size + field_offset_in_row;
-                        self.buffer[offset..offset + 8]
-                            .copy_from_slice(&(arr.value(row_idx) as i64).to_le_bytes());
-                    }
-                }
-            }
-            DataType::Int16 => {
-                let arr = array.as_any().downcast_ref::<Int16Array>().ok_or_else(|| {
-                    CometError::Internal("Failed to downcast to Int16Array".to_string())
-                })?;
-                for row_idx in 0..num_rows {
-                    if !arr.is_null(row_idx) {
-                        let offset = row_idx * row_size + field_offset_in_row;
-                        self.buffer[offset..offset + 8]
-                            .copy_from_slice(&(arr.value(row_idx) as i64).to_le_bytes());
-                    }
-                }
-            }
-            DataType::Int32 => {
-                let arr = array.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
-                    CometError::Internal("Failed to downcast to Int32Array".to_string())
-                })?;
-                for row_idx in 0..num_rows {
-                    if !arr.is_null(row_idx) {
-                        let offset = row_idx * row_size + field_offset_in_row;
-                        self.buffer[offset..offset + 8]
-                            .copy_from_slice(&(arr.value(row_idx) as i64).to_le_bytes());
-                    }
-                }
-            }
-            DataType::Int64 => {
-                let arr = array.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
-                    CometError::Internal("Failed to downcast to Int64Array".to_string())
-                })?;
-                for row_idx in 0..num_rows {
-                    if !arr.is_null(row_idx) {
-                        let offset = row_idx * row_size + field_offset_in_row;
-                        self.buffer[offset..offset + 8]
-                            .copy_from_slice(&arr.value(row_idx).to_le_bytes());
-                    }
-                }
-            }
-            DataType::Float32 => {
-                let arr = array
-                    .as_any()
-                    .downcast_ref::<Float32Array>()
-                    .ok_or_else(|| {
-                        CometError::Internal("Failed to downcast to Float32Array".to_string())
-                    })?;
-                for row_idx in 0..num_rows {
-                    if !arr.is_null(row_idx) {
-                        let offset = row_idx * row_size + field_offset_in_row;
-                        self.buffer[offset..offset + 8]
-                            .copy_from_slice(&(arr.value(row_idx).to_bits() as i64).to_le_bytes());
-                    }
-                }
-            }
-            DataType::Float64 => {
-                let arr = array
-                    .as_any()
-                    .downcast_ref::<Float64Array>()
-                    .ok_or_else(|| {
-                        CometError::Internal("Failed to downcast to Float64Array".to_string())
-                    })?;
-                for row_idx in 0..num_rows {
-                    if !arr.is_null(row_idx) {
-                        let offset = row_idx * row_size + field_offset_in_row;
-                        self.buffer[offset..offset + 8]
-                            .copy_from_slice(&(arr.value(row_idx).to_bits() as i64).to_le_bytes());
-                    }
-                }
-            }
-            DataType::Date32 => {
-                let arr = array
-                    .as_any()
-                    .downcast_ref::<Date32Array>()
-                    .ok_or_else(|| {
-                        CometError::Internal("Failed to downcast to Date32Array".to_string())
-                    })?;
-                for row_idx in 0..num_rows {
-                    if !arr.is_null(row_idx) {
-                        let offset = row_idx * row_size + field_offset_in_row;
-                        self.buffer[offset..offset + 8]
-                            .copy_from_slice(&(arr.value(row_idx) as i64).to_le_bytes());
-                    }
-                }
-            }
-            DataType::Timestamp(TimeUnit::Microsecond, _) => {
-                let arr = array
-                    .as_any()
-                    .downcast_ref::<TimestampMicrosecondArray>()
-                    .ok_or_else(|| {
-                        CometError::Internal(
-                            "Failed to downcast to TimestampMicrosecondArray".to_string(),
-                        )
-                    })?;
-                for row_idx in 0..num_rows {
-                    if !arr.is_null(row_idx) {
-                        let offset = row_idx * row_size + field_offset_in_row;
-                        self.buffer[offset..offset + 8]
-                            .copy_from_slice(&arr.value(row_idx).to_le_bytes());
-                    }
-                }
-            }
+ DataType::Int8 => write_fixed_column_primitive!(
+ self,
+ array,
+ row_size,
+ field_offset_in_row,
+ num_rows,
+ Int8Array,
+ |v: i8| v as i64
+ ),
+ DataType::Int16 => write_fixed_column_primitive!(
+ self,
+ array,
+ row_size,
+ field_offset_in_row,
+ num_rows,
+ Int16Array,
+ |v: i16| v as i64
+ ),
+ DataType::Int32 => write_fixed_column_primitive!(
+ self,
+ array,
+ row_size,
+ field_offset_in_row,
+ num_rows,
+ Int32Array,
+ |v: i32| v as i64
+ ),
+ DataType::Int64 => write_fixed_column_primitive!(
+ self,
+ array,
+ row_size,
+ field_offset_in_row,
+ num_rows,
+ Int64Array,
+ |v: i64| v
+ ),
+ DataType::Float32 => write_fixed_column_primitive!(
+ self,
+ array,
+ row_size,
+ field_offset_in_row,
+ num_rows,
+ Float32Array,
+ |v: f32| v.to_bits() as i64
+ ),
+ DataType::Float64 => write_fixed_column_primitive!(
+ self,
+ array,
+ row_size,
+ field_offset_in_row,
+ num_rows,
+ Float64Array,
+ |v: f64| v.to_bits() as i64
+ ),
+ DataType::Date32 => write_fixed_column_primitive!(
+ self,
+ array,
+ row_size,
+ field_offset_in_row,
+ num_rows,
+ Date32Array,
+ |v: i32| v as i64
+ ),
+            DataType::Timestamp(TimeUnit::Microsecond, _) => write_fixed_column_primitive!(
+ self,
+ array,
+ row_size,
+ field_offset_in_row,
+ num_rows,
+ TimestampMicrosecondArray,
+ |v: i64| v
+ ),
            DataType::Decimal128(precision, _) if *precision <= MAX_LONG_DIGITS => {
-                let arr = array
-                    .as_any()
-                    .downcast_ref::<Decimal128Array>()
-                    .ok_or_else(|| {
-                        CometError::Internal("Failed to downcast to Decimal128Array".to_string())
-                    })?;
-                for row_idx in 0..num_rows {
-                    if !arr.is_null(row_idx) {
-                        let offset = row_idx * row_size + field_offset_in_row;
-                        self.buffer[offset..offset + 8]
-                            .copy_from_slice(&(arr.value(row_idx) as i64).to_le_bytes());
-                    }
-                }
-            }
- _ => {
- return Err(CometError::Internal(format!(
- "Unexpected non-fixed-width type in fast path: {:?}",
- data_type
- )));
+ write_fixed_column_primitive!(
+ self,
+ array,
+ row_size,
+ field_offset_in_row,
+ num_rows,
+ Decimal128Array,
+ |v: i128| v as i64
+ )
}
+ _ => Err(CometError::Internal(format!(
+ "Unexpected non-fixed-width type in fast path: {:?}",
+ data_type
+ ))),
}
-
- Ok(())
}
/// Writes a complete row using pre-downcast TypedArrays.
@@ -1386,112 +1330,31 @@ fn get_field_value(data_type: &DataType, array: &ArrayRef, row_idx: usize) -> Co
match actual_type {
DataType::Boolean => {
- let arr = array
- .as_any()
- .downcast_ref::<BooleanArray>()
- .ok_or_else(|| {
- CometError::Internal(format!(
- "Failed to downcast to BooleanArray for type {:?}",
- actual_type
- ))
- })?;
+ let arr = downcast_array!(array, BooleanArray)?;
Ok(if arr.value(row_idx) { 1i64 } else { 0i64 })
}
-        DataType::Int8 => {
-            let arr = array.as_any().downcast_ref::<Int8Array>().ok_or_else(|| {
-                CometError::Internal(format!(
-                    "Failed to downcast to Int8Array for type {:?}",
-                    actual_type
-                ))
-            })?;
-            Ok(arr.value(row_idx) as i64)
-        }
+        DataType::Int8 => get_field_value_primitive!(array, row_idx, Int8Array, |v: i8| v as i64),
        DataType::Int16 => {
-            let arr = array.as_any().downcast_ref::<Int16Array>().ok_or_else(|| {
-                CometError::Internal(format!(
-                    "Failed to downcast to Int16Array for type {:?}",
-                    actual_type
-                ))
-            })?;
-            Ok(arr.value(row_idx) as i64)
+            get_field_value_primitive!(array, row_idx, Int16Array, |v: i16| v as i64)
        }
        DataType::Int32 => {
-            let arr = array.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
-                CometError::Internal(format!(
-                    "Failed to downcast to Int32Array for type {:?}",
-                    actual_type
-                ))
-            })?;
-            Ok(arr.value(row_idx) as i64)
-        }
-        DataType::Int64 => {
-            let arr = array.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
-                CometError::Internal(format!(
-                    "Failed to downcast to Int64Array for type {:?}",
-                    actual_type
-                ))
-            })?;
-            Ok(arr.value(row_idx))
+            get_field_value_primitive!(array, row_idx, Int32Array, |v: i32| v as i64)
        }
+        DataType::Int64 => get_field_value_primitive!(array, row_idx, Int64Array, |v: i64| v),
        DataType::Float32 => {
-            let arr = array
-                .as_any()
-                .downcast_ref::<Float32Array>()
-                .ok_or_else(|| {
-                    CometError::Internal(format!(
-                        "Failed to downcast to Float32Array for type {:?}",
-                        actual_type
-                    ))
-                })?;
-            Ok(arr.value(row_idx).to_bits() as i64)
+            get_field_value_primitive!(array, row_idx, Float32Array, |v: f32| v.to_bits() as i64)
        }
        DataType::Float64 => {
-            let arr = array
-                .as_any()
-                .downcast_ref::<Float64Array>()
-                .ok_or_else(|| {
-                    CometError::Internal(format!(
-                        "Failed to downcast to Float64Array for type {:?}",
-                        actual_type
-                    ))
-                })?;
-            Ok(arr.value(row_idx).to_bits() as i64)
+            get_field_value_primitive!(array, row_idx, Float64Array, |v: f64| v.to_bits() as i64)
        }
        DataType::Date32 => {
-            let arr = array
-                .as_any()
-                .downcast_ref::<Date32Array>()
-                .ok_or_else(|| {
-                    CometError::Internal(format!(
-                        "Failed to downcast to Date32Array for type {:?}",
-                        actual_type
-                    ))
-                })?;
-            Ok(arr.value(row_idx) as i64)
+            get_field_value_primitive!(array, row_idx, Date32Array, |v: i32| v as i64)
        }
        DataType::Timestamp(TimeUnit::Microsecond, _) => {
-            let arr = array
-                .as_any()
-                .downcast_ref::<TimestampMicrosecondArray>()
-                .ok_or_else(|| {
-                    CometError::Internal(format!(
-                        "Failed to downcast to TimestampMicrosecondArray for type {:?}",
-                        actual_type
-                    ))
-                })?;
-            Ok(arr.value(row_idx))
+            get_field_value_primitive!(array, row_idx, TimestampMicrosecondArray, |v: i64| v)
        }
        DataType::Decimal128(precision, _) if *precision <= MAX_LONG_DIGITS => {
-            let arr = array
-                .as_any()
-                .downcast_ref::<Decimal128Array>()
-                .ok_or_else(|| {
-                    CometError::Internal(format!(
-                        "Failed to downcast to Decimal128Array for type {:?}",
-                        actual_type
-                    ))
-                })?;
-            Ok(arr.value(row_idx) as i64)
+            get_field_value_primitive!(array, row_idx, Decimal128Array, |v: i128| v as i64)
        }
        // Variable-length types use placeholder (will be overwritten by get_variable_length_data)
DataType::Utf8
@@ -1605,72 +1468,26 @@ fn write_dictionary_to_buffer_with_key<K: ArrowDictionaryKeyType>(
match value_type {
DataType::Utf8 => {
-            let string_values = values
-                .as_any()
-                .downcast_ref::<StringArray>()
-                .ok_or_else(|| {
-                    CometError::Internal(format!(
-                        "Failed to downcast dictionary values to StringArray, actual type: {:?}",
-                        values.data_type()
-                    ))
-                })?;
-            let bytes = string_values.value(key_idx).as_bytes();
-            let len = bytes.len();
-            buffer.extend_from_slice(bytes);
-            let padding = round_up_to_8(len) - len;
-            buffer.extend(std::iter::repeat_n(0u8, padding));
-            Ok(len)
+            let string_values = downcast_array!(values, StringArray)?;
+            Ok(write_bytes_padded(
+                buffer,
+                string_values.value(key_idx).as_bytes(),
+            ))
}
DataType::LargeUtf8 => {
-            let string_values = values
-                .as_any()
-                .downcast_ref::<LargeStringArray>()
-                .ok_or_else(|| {
-                    CometError::Internal(format!(
-                        "Failed to downcast dictionary values to LargeStringArray, actual type: {:?}",
-                        values.data_type()
-                    ))
-                })?;
-            let bytes = string_values.value(key_idx).as_bytes();
-            let len = bytes.len();
-            buffer.extend_from_slice(bytes);
-            let padding = round_up_to_8(len) - len;
-            buffer.extend(std::iter::repeat_n(0u8, padding));
-            Ok(len)
+            let string_values = downcast_array!(values, LargeStringArray)?;
+            Ok(write_bytes_padded(
+                buffer,
+                string_values.value(key_idx).as_bytes(),
+            ))
}
DataType::Binary => {
-            let binary_values = values
-                .as_any()
-                .downcast_ref::<BinaryArray>()
-                .ok_or_else(|| {
-                    CometError::Internal(format!(
-                        "Failed to downcast dictionary values to BinaryArray, actual type: {:?}",
-                        values.data_type()
-                    ))
-                })?;
-            let bytes = binary_values.value(key_idx);
-            let len = bytes.len();
-            buffer.extend_from_slice(bytes);
-            let padding = round_up_to_8(len) - len;
-            buffer.extend(std::iter::repeat_n(0u8, padding));
-            Ok(len)
+            let binary_values = downcast_array!(values, BinaryArray)?;
+            Ok(write_bytes_padded(buffer, binary_values.value(key_idx)))
}
DataType::LargeBinary => {
-            let binary_values = values
-                .as_any()
-                .downcast_ref::<LargeBinaryArray>()
-                .ok_or_else(|| {
-                    CometError::Internal(format!(
-                        "Failed to downcast dictionary values to LargeBinaryArray, actual type: {:?}",
-                        values.data_type()
-                    ))
-                })?;
-            let bytes = binary_values.value(key_idx);
-            let len = bytes.len();
-            buffer.extend_from_slice(bytes);
-            let padding = round_up_to_8(len) - len;
-            buffer.extend(std::iter::repeat_n(0u8, padding));
-            Ok(len)
+            let binary_values = downcast_array!(values, LargeBinaryArray)?;
+            Ok(write_bytes_padded(buffer, binary_values.value(key_idx)))
}
_ => Err(CometError::Internal(format!(
"Unsupported dictionary value type for direct buffer write: {:?}",
@@ -1781,7 +1598,7 @@ fn write_struct_to_buffer_typed(
            buffer[field_offset..field_offset + 8].copy_from_slice(&value.to_le_bytes());
} else {
// Variable-length field - use pre-downcast writer
-            let var_len = typed_field.write_variable_value(buffer, row_idx, struct_start)?;
+            let var_len = typed_field.write_variable_value(buffer, row_idx)?;
if var_len > 0 {
let padded_len = round_up_to_8(var_len);
let data_offset = buffer.len() - padded_len - struct_start;
@@ -1829,51 +1646,35 @@ fn write_struct_to_buffer(
        let field_offset = struct_start + nested_bitset_width + field_idx * 8;
// Inline type dispatch for fixed-width types (most common case)
-        let value = match data_type {
-            DataType::Boolean => {
-                let arr = column.as_any().downcast_ref::<BooleanArray>().unwrap();
-                Some(if arr.value(row_idx) { 1i64 } else { 0i64 })
-            }
-            DataType::Int8 => {
-                let arr = column.as_any().downcast_ref::<Int8Array>().unwrap();
-                Some(arr.value(row_idx) as i64)
-            }
-            DataType::Int16 => {
-                let arr = column.as_any().downcast_ref::<Int16Array>().unwrap();
-                Some(arr.value(row_idx) as i64)
-            }
-            DataType::Int32 => {
-                let arr = column.as_any().downcast_ref::<Int32Array>().unwrap();
-                Some(arr.value(row_idx) as i64)
-            }
-            DataType::Int64 => {
-                let arr = column.as_any().downcast_ref::<Int64Array>().unwrap();
-                Some(arr.value(row_idx))
-            }
-            DataType::Float32 => {
-                let arr = column.as_any().downcast_ref::<Float32Array>().unwrap();
-                Some((arr.value(row_idx).to_bits() as i32) as i64)
-            }
-            DataType::Float64 => {
-                let arr = column.as_any().downcast_ref::<Float64Array>().unwrap();
-                Some(arr.value(row_idx).to_bits() as i64)
-            }
-            DataType::Date32 => {
-                let arr = column.as_any().downcast_ref::<Date32Array>().unwrap();
-                Some(arr.value(row_idx) as i64)
-            }
-            DataType::Timestamp(TimeUnit::Microsecond, _) => {
-                let arr = column
-                    .as_any()
-                    .downcast_ref::<TimestampMicrosecondArray>()
-                    .unwrap();
-                Some(arr.value(row_idx))
-            }
-            DataType::Decimal128(p, _) if *p <= MAX_LONG_DIGITS => {
-                let arr = column.as_any().downcast_ref::<Decimal128Array>().unwrap();
+ let value: Option<i64> = extract_fixed_value!(
+ column,
+ row_idx,
+ (DataType::Boolean, BooleanArray, |v: bool| if v {
+ 1i64
+ } else {
+ 0i64
+ }),
+ (DataType::Int8, Int8Array, |v: i8| v as i64),
+ (DataType::Int16, Int16Array, |v: i16| v as i64),
+ (DataType::Int32, Int32Array, |v: i32| v as i64),
+ (DataType::Int64, Int64Array, |v: i64| v),
+ (DataType::Float32, Float32Array, |v: f32| v.to_bits() as i64),
+ (DataType::Float64, Float64Array, |v: f64| v.to_bits() as i64),
+ (DataType::Date32, Date32Array, |v: i32| v as i64),
+ (
+ DataType::Timestamp(TimeUnit::Microsecond, _),
+ TimestampMicrosecondArray,
+ |v: i64| v
+ ),
+ );
+ // Handle Decimal128 with precision guard separately
+ let value: Option<i64> = match (value, data_type) {
+ (Some(v), _) => Some(v),
+            (None, DataType::Decimal128(p, _)) if *p <= MAX_LONG_DIGITS => {
+ let arr = downcast_array!(column, Decimal128Array)?;
Some(arr.value(row_idx) as i64)
}
- _ => None, // Variable-length type
+ _ => None,
};
if let Some(v) = value {
@@ -1881,13 +1682,7 @@ fn write_struct_to_buffer(
            buffer[field_offset..field_offset + 8].copy_from_slice(&v.to_le_bytes());
} else {
// Variable-length field
- let var_len = write_nested_variable_to_buffer(
- buffer,
- data_type,
- column,
- row_idx,
- struct_start,
- )?;
+            let var_len = write_nested_variable_to_buffer(buffer, data_type, column, row_idx)?;
if var_len > 0 {
let padded_len = round_up_to_8(var_len);
let data_offset = buffer.len() - padded_len - struct_start;
@@ -2016,136 +1811,45 @@ fn write_nested_variable_to_buffer(
data_type: &DataType,
array: &ArrayRef,
row_idx: usize,
- _base_offset: usize,
) -> CometResult<usize> {
let actual_type = array.data_type();
match actual_type {
DataType::Utf8 => {
- let arr = array
- .as_any()
- .downcast_ref::<StringArray>()
- .ok_or_else(|| {
- CometError::Internal(format!(
- "Failed to downcast to StringArray for type {:?}",
- actual_type
- ))
- })?;
- let bytes = arr.value(row_idx).as_bytes();
- let len = bytes.len();
- buffer.extend_from_slice(bytes);
- let padding = round_up_to_8(len) - len;
- buffer.extend(std::iter::repeat_n(0u8, padding));
- Ok(len)
+ let arr = downcast_array!(array, StringArray)?;
+ Ok(write_bytes_padded(buffer, arr.value(row_idx).as_bytes()))
}
DataType::LargeUtf8 => {
- let arr = array
- .as_any()
- .downcast_ref::<LargeStringArray>()
- .ok_or_else(|| {
- CometError::Internal(format!(
- "Failed to downcast to LargeStringArray for type {:?}",
- actual_type
- ))
- })?;
- let bytes = arr.value(row_idx).as_bytes();
- let len = bytes.len();
- buffer.extend_from_slice(bytes);
- let padding = round_up_to_8(len) - len;
- buffer.extend(std::iter::repeat_n(0u8, padding));
- Ok(len)
+ let arr = downcast_array!(array, LargeStringArray)?;
+ Ok(write_bytes_padded(buffer, arr.value(row_idx).as_bytes()))
}
DataType::Binary => {
- let arr = array
- .as_any()
- .downcast_ref::<BinaryArray>()
- .ok_or_else(|| {
- CometError::Internal(format!(
- "Failed to downcast to BinaryArray for type {:?}",
- actual_type
- ))
- })?;
- let bytes = arr.value(row_idx);
- let len = bytes.len();
- buffer.extend_from_slice(bytes);
- let padding = round_up_to_8(len) - len;
- buffer.extend(std::iter::repeat_n(0u8, padding));
- Ok(len)
+ let arr = downcast_array!(array, BinaryArray)?;
+ Ok(write_bytes_padded(buffer, arr.value(row_idx)))
}
DataType::LargeBinary => {
- let arr = array
- .as_any()
- .downcast_ref::<LargeBinaryArray>()
- .ok_or_else(|| {
- CometError::Internal(format!(
- "Failed to downcast to LargeBinaryArray for type {:?}",
- actual_type
- ))
- })?;
- let bytes = arr.value(row_idx);
- let len = bytes.len();
- buffer.extend_from_slice(bytes);
- let padding = round_up_to_8(len) - len;
- buffer.extend(std::iter::repeat_n(0u8, padding));
- Ok(len)
+ let arr = downcast_array!(array, LargeBinaryArray)?;
+ Ok(write_bytes_padded(buffer, arr.value(row_idx)))
}
DataType::Decimal128(precision, _) if *precision > MAX_LONG_DIGITS => {
- let arr = array
- .as_any()
- .downcast_ref::<Decimal128Array>()
- .ok_or_else(|| {
- CometError::Internal(format!(
- "Failed to downcast to Decimal128Array for type {:?}",
- actual_type
- ))
- })?;
+ let arr = downcast_array!(array, Decimal128Array)?;
let bytes = i128_to_spark_decimal_bytes(arr.value(row_idx));
- let len = bytes.len();
- buffer.extend_from_slice(&bytes);
- let padding = round_up_to_8(len) - len;
- buffer.extend(std::iter::repeat_n(0u8, padding));
- Ok(len)
+ Ok(write_bytes_padded(buffer, &bytes))
}
DataType::Struct(fields) => {
- let struct_array = array
- .as_any()
- .downcast_ref::<StructArray>()
- .ok_or_else(|| {
- CometError::Internal(format!(
- "Failed to downcast to StructArray for type {:?}",
- actual_type
- ))
- })?;
+ let struct_array = downcast_array!(array, StructArray)?;
write_struct_to_buffer(buffer, struct_array, row_idx, fields)
}
DataType::List(field) => {
-            let list_array = array.as_any().downcast_ref::<ListArray>().ok_or_else(|| {
- CometError::Internal(format!(
- "Failed to downcast to ListArray for type {:?}",
- actual_type
- ))
- })?;
+ let list_array = downcast_array!(array, ListArray)?;
write_list_to_buffer(buffer, list_array, row_idx, field)
}
DataType::LargeList(field) => {
- let list_array = array
- .as_any()
- .downcast_ref::<LargeListArray>()
- .ok_or_else(|| {
- CometError::Internal(format!(
- "Failed to downcast to LargeListArray for type {:?}",
- actual_type
- ))
- })?;
+ let list_array = downcast_array!(array, LargeListArray)?;
write_large_list_to_buffer(buffer, list_array, row_idx, field)
}
DataType::Map(field, _) => {
-            let map_array = array.as_any().downcast_ref::<MapArray>().ok_or_else(|| {
- CometError::Internal(format!(
- "Failed to downcast to MapArray for type {:?}",
- actual_type
- ))
- })?;
+ let map_array = downcast_array!(array, MapArray)?;
write_map_to_buffer(buffer, map_array, row_idx, field)
}
DataType::Dictionary(key_type, value_type) => {
@@ -2748,4 +2452,163 @@ mod tests {
assert_eq!(value, i as i32, "element {} should be {}", i, i);
}
}
+
+ #[test]
+ fn test_convert_fixed_size_binary_array() {
+ // FixedSizeBinary(3) - each value is exactly 3 bytes
+ let schema = vec![DataType::FixedSizeBinary(3)];
+ let mut ctx = ColumnarToRowContext::new(schema, 100);
+
+ let array: ArrayRef = Arc::new(FixedSizeBinaryArray::from(vec![
+ Some(&[1u8, 2, 3][..]),
+ Some(&[4u8, 5, 6][..]),
+ None, // Test null handling
+ ]));
+ let arrays = vec![array];
+
+ let (ptr, offsets, lengths) = ctx.convert(&arrays, 3).unwrap();
+
+ assert!(!ptr.is_null());
+ assert_eq!(offsets.len(), 3);
+ assert_eq!(lengths.len(), 3);
+
+ // Row 0: 8 (bitset) + 8 (field slot) + 8 (aligned 3-byte data) = 24
+ // Row 1: 8 (bitset) + 8 (field slot) + 8 (aligned 3-byte data) = 24
+ // Row 2: 8 (bitset) + 8 (field slot) = 16 (null, no variable data)
+ assert_eq!(lengths[0], 24);
+ assert_eq!(lengths[1], 24);
+ assert_eq!(lengths[2], 16);
+
+ // Verify the data is correct for non-null rows
+ unsafe {
+ let row0 =
+                std::slice::from_raw_parts(ptr.add(offsets[0] as usize), lengths[0] as usize);
+ // Variable data starts at offset 16 (8 bitset + 8 field slot)
+ assert_eq!(&row0[16..19], &[1u8, 2, 3]);
+
+ let row1 =
+                std::slice::from_raw_parts(ptr.add(offsets[1] as usize), lengths[1] as usize);
+ assert_eq!(&row1[16..19], &[4u8, 5, 6]);
+ }
+ }
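    // Illustrative check (assuming the row layout verified above): the 8-byte field
    // slot at bytes 8..16 of row 0 holds ((16i64 << 32) | 3), i.e. the data offset 16
    // from the row start in the upper 32 bits and the unpadded length 3 in the lower
    // 32 bits.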
+
+ #[test]
+ fn test_convert_dictionary_decimal_array() {
+        // Test that dictionary-encoded decimals are correctly unpacked and converted
+ // This tests the fix for casting to schema_type instead of value_type
+ use arrow::datatypes::Int8Type;
+
+ // Create a dictionary array with Decimal128 values
+        // Values: [-0.01, -0.02, -0.03] represented as [-1, -2, -3] with scale 2
+ let values = Decimal128Array::from(vec![-1i128, -2, -3])
+ .with_precision_and_scale(5, 2)
+ .unwrap();
+
+ // Keys: [0, 1, 2, 0, 1, 2] - each value appears twice
+ let keys = Int8Array::from(vec![0i8, 1, 2, 0, 1, 2]);
+
+ let dict_array: ArrayRef =
+            Arc::new(DictionaryArray::<Int8Type>::try_new(keys, Arc::new(values)).unwrap());
+
+ // Schema expects Decimal128(5, 2) - not a dictionary type
+ let schema = vec![DataType::Decimal128(5, 2)];
+ let mut ctx = ColumnarToRowContext::new(schema, 100);
+
+ let arrays = vec![dict_array];
+ let (ptr, offsets, lengths) = ctx.convert(&arrays, 6).unwrap();
+
+ assert!(!ptr.is_null());
+ assert_eq!(offsets.len(), 6);
+ assert_eq!(lengths.len(), 6);
+
+        // Verify the decimal values are correct (not doubled or otherwise corrupted)
+ // Fixed-width decimal is stored directly in the 8-byte field slot
+ unsafe {
+            for (i, expected) in [-1i64, -2, -3, -1, -2, -3].iter().enumerate() {
+ let row =
+                    std::slice::from_raw_parts(ptr.add(offsets[i] as usize), lengths[i] as usize);
+ // Field value starts at offset 8 (after null bitset)
+ let value = i64::from_le_bytes(row[8..16].try_into().unwrap());
+ assert_eq!(
+ value, *expected,
+ "Row {} should have value {}, got {}",
+ i, expected, value
+ );
+ }
+ }
+ }
+
+ #[test]
+ fn test_convert_int32_to_decimal128() {
+        // Test that Int32 arrays are correctly cast to Decimal128 when schema expects Decimal128.
+        // This can happen when COMET_USE_DECIMAL_128 is false and the parquet reader produces
+        // Int32 for small-precision decimals.
+
+        // Create an Int32 array representing decimals: [-1, -2, -3] which at scale 2 means
+        // [-0.01, -0.02, -0.03]
+        let int_array: ArrayRef = Arc::new(Int32Array::from(vec![-1i32, -2, -3]));
+
+ // Schema expects Decimal128(5, 2)
+ let schema = vec![DataType::Decimal128(5, 2)];
+ let mut ctx = ColumnarToRowContext::new(schema, 100);
+
+ let arrays = vec![int_array];
+ let (ptr, offsets, lengths) = ctx.convert(&arrays, 3).unwrap();
+
+ assert!(!ptr.is_null());
+ assert_eq!(offsets.len(), 3);
+ assert_eq!(lengths.len(), 3);
+
+ // Verify the decimal values are correct after casting
+ // Fixed-width decimal is stored directly in the 8-byte field slot
+ unsafe {
+ for (i, expected) in [-1i64, -2, -3].iter().enumerate() {
+ let row =
+                    std::slice::from_raw_parts(ptr.add(offsets[i] as usize), lengths[i] as usize);
+ // Field value starts at offset 8 (after null bitset)
+ let value = i64::from_le_bytes(row[8..16].try_into().unwrap());
+ assert_eq!(
+ value, *expected,
+ "Row {} should have value {}, got {}",
+ i, expected, value
+ );
+ }
+ }
+ }
+
+ #[test]
+ fn test_convert_int64_to_decimal128() {
+        // Test that Int64 arrays are correctly cast to Decimal128 when schema expects Decimal128.
+        // This can happen when COMET_USE_DECIMAL_128 is false and the parquet reader produces
+        // Int64 for medium-precision decimals.
+
+ // Create an Int64 array representing decimals
+        let int_array: ArrayRef = Arc::new(Int64Array::from(vec![-100i64, -200, -300]));
+
+ // Schema expects Decimal128(10, 2)
+ let schema = vec![DataType::Decimal128(10, 2)];
+ let mut ctx = ColumnarToRowContext::new(schema, 100);
+
+ let arrays = vec![int_array];
+ let (ptr, offsets, lengths) = ctx.convert(&arrays, 3).unwrap();
+
+ assert!(!ptr.is_null());
+ assert_eq!(offsets.len(), 3);
+ assert_eq!(lengths.len(), 3);
+
+ // Verify the decimal values are correct after casting
+ unsafe {
+ for (i, expected) in [-100i64, -200, -300].iter().enumerate() {
+ let row =
+                    std::slice::from_raw_parts(ptr.add(offsets[i] as usize), lengths[i] as usize);
+ // Field value starts at offset 8 (after null bitset)
+ let value = i64::from_le_bytes(row[8..16].try_into().unwrap());
+ assert_eq!(
+ value, *expected,
+ "Row {} should have value {}, got {}",
+ i, expected, value
+ );
+ }
+ }
+ }
}
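Note on the encoding used throughout columnar_to_row.rs: a variable-length value is
referenced from its 8-byte field slot by a packed word whose upper 32 bits carry the
data offset relative to the enclosing row (or nested struct/array) start and whose
lower 32 bits carry the unpadded length. A minimal sketch of that packing, matching
the `((data_offset as i64) << 32) | (len as i64)` expressions above (the helper names
are illustrative, not part of the commit):

    /// Pack a relative data offset and an unpadded byte length into the i64 word
    /// stored in a variable-length field slot.
    fn pack_offset_and_len(data_offset: usize, len: usize) -> i64 {
        ((data_offset as i64) << 32) | (len as i64)
    }

    /// Recover (offset, len) from a packed field-slot word.
    fn unpack_offset_and_len(word: i64) -> (usize, usize) {
        (((word >> 32) & 0xFFFF_FFFF) as usize, (word & 0xFFFF_FFFF) as usize)
    }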
diff --git a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala
index 7402a8324..d1c3b0767 100644
--- a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala
@@ -22,13 +22,14 @@ package org.apache.comet.rules
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.sideBySide
-import org.apache.spark.sql.comet.{CometCollectLimitExec, CometColumnarToRowExec, CometNativeColumnarToRowExec, CometNativeWriteExec, CometPlan, CometSparkToColumnarExec}
+import org.apache.spark.sql.comet.{CometBatchScanExec, CometCollectLimitExec, CometColumnarToRowExec, CometNativeColumnarToRowExec, CometNativeWriteExec, CometPlan, CometScanExec, CometSparkToColumnarExec}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.QueryStageExec
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.comet.CometConf
+import org.apache.comet.parquet.CometParquetScan
// This rule is responsible for eliminating redundant transitions between row-based and
// columnar-based operators for Comet. Currently, three potential redundant transitions are:
@@ -139,7 +140,8 @@ case class EliminateRedundantTransitions(session: SparkSession) extends Rule[Spa
private def createColumnarToRowExec(child: SparkPlan): SparkPlan = {
val schema = child.schema
val useNative = CometConf.COMET_NATIVE_COLUMNAR_TO_ROW_ENABLED.get() &&
- CometNativeColumnarToRowExec.supportsSchema(schema)
+ CometNativeColumnarToRowExec.supportsSchema(schema) &&
+ !hasScanUsingMutableBuffers(child)
if (useNative) {
CometNativeColumnarToRowExec(child)
@@ -147,4 +149,30 @@ case class EliminateRedundantTransitions(session: SparkSession) extends Rule[Spa
CometColumnarToRowExec(child)
}
}
+
+ /**
+ * Checks if the plan contains a scan that uses mutable buffers. Native C2R is not compatible
+ * with such scans because the buffers may be modified after C2R reads them.
+ *
+ * This includes:
+ * - CometScanExec with native_comet scan implementation (V1 path) - uses BatchReader
+ * - CometScanExec with native_iceberg_compat and partition columns - uses
+ * ConstantColumnReader
+ * - CometBatchScanExec with CometParquetScan (V2 Parquet path) - uses BatchReader
+ */
+ private def hasScanUsingMutableBuffers(op: SparkPlan): Boolean = {
+ op match {
+ case c: QueryStageExec => hasScanUsingMutableBuffers(c.plan)
+ case c: ReusedExchangeExec => hasScanUsingMutableBuffers(c.child)
+ case _ =>
+ op.exists {
+ case scan: CometScanExec =>
+ scan.scanImpl == CometConf.SCAN_NATIVE_COMET ||
+ (scan.scanImpl == CometConf.SCAN_NATIVE_ICEBERG_COMPAT &&
+ scan.relation.partitionSchema.nonEmpty)
+ case scan: CometBatchScanExec => scan.scan.isInstanceOf[CometParquetScan]
+ case _ => false
+ }
+ }
+ }
}
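
The doc comment above captures why the fallback exists: the listed readers reuse
their backing buffers across batches, while the native converter can hand out
rows that still reference those buffers. A contrived sketch of the hazard
(illustrative only; nothing below is project code):

    import java.nio.ByteBuffer

    // A "batch" buffer that the scan will recycle for the next batch.
    val batchBuffer = ByteBuffer.allocate(8)
    batchBuffer.putLong(0, 1L)
    // A lazily evaluated "row" that reads from the buffer instead of copying.
    val firstRow: () => Long = () => batchBuffer.getLong(0)
    batchBuffer.putLong(0, 2L) // scan reuses the buffer for batch 2
    assert(firstRow() == 2L)   // the row from batch 1 silently changed
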
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeColumnarToRowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeColumnarToRowExec.scala
index 93526573c..a520098ed 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeColumnarToRowExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeColumnarToRowExec.scala
@@ -19,15 +19,25 @@
package org.apache.spark.sql.comet
-import org.apache.spark.TaskContext
+import java.util.UUID
+import java.util.concurrent.{Future, TimeoutException, TimeUnit}
+
+import scala.concurrent.Promise
+import scala.util.control.NonFatal
+
+import org.apache.spark.{broadcast, SparkException, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
+import org.apache.spark.sql.comet.util.{Utils => CometUtils}
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
+import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SparkFatalException, Utils}
import org.apache.comet.{CometConf, NativeColumnarToRowConverter}
@@ -64,6 +74,116 @@ case class CometNativeColumnarToRowExec(child: SparkPlan)
"numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of
input batches"),
"convertTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time in
conversion"))
+ @transient
+ private lazy val promise = Promise[broadcast.Broadcast[Any]]()
+
+ @transient
+ private val timeout: Long = conf.broadcastTimeout
+
+ private val runId: UUID = UUID.randomUUID
+
+ private lazy val cometBroadcastExchange = findCometBroadcastExchange(child)
+
+ @transient
+ lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
+ SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
+ session,
+ CometBroadcastExchangeExec.executionContext) {
+ try {
+ // Setup a job group here so later it may get cancelled by groupId if necessary.
+ sparkContext.setJobGroup(
+ runId.toString,
+ s"CometNativeColumnarToRow broadcast exchange (runId $runId)",
+ interruptOnCancel = true)
+
+ val numOutputRows = longMetric("numOutputRows")
+ val numInputBatches = longMetric("numInputBatches")
+ val localSchema = this.schema
+ val batchSize = CometConf.COMET_BATCH_SIZE.get()
+ val broadcastColumnar = child.executeBroadcast()
+ val serializedBatches =
+ broadcastColumnar.value.asInstanceOf[Array[org.apache.spark.util.io.ChunkedByteBuffer]]
+
+ // Use native converter to convert columnar data to rows
+ val converter = new NativeColumnarToRowConverter(localSchema, batchSize)
+ try {
+ val rows = serializedBatches.iterator
+ .flatMap(CometUtils.decodeBatches(_, this.getClass.getSimpleName))
+ .flatMap { batch =>
+ numInputBatches += 1
+ numOutputRows += batch.numRows()
+ val result = converter.convert(batch)
+ // Wrap iterator to close batch after consumption
+ new Iterator[InternalRow] {
+ override def hasNext: Boolean = {
+ val hasMore = result.hasNext
+ if (!hasMore) {
+ batch.close()
+ }
+ hasMore
+ }
+ override def next(): InternalRow = result.next()
+ }
+ }
+
+ val mode = cometBroadcastExchange.get.mode
+ val relation = mode.transform(rows, Some(numOutputRows.value))
+ val broadcasted = sparkContext.broadcastInternal(relation, serializedOnly = true)
+ val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+ promise.trySuccess(broadcasted)
+ broadcasted
+ } finally {
+ converter.close()
+ }
+ } catch {
+ // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
+ // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
+ // will catch this exception and re-throw the wrapped fatal throwable.
+ case oe: OutOfMemoryError =>
+ val ex = new SparkFatalException(oe)
+ promise.tryFailure(ex)
+ throw ex
+ case e if !NonFatal(e) =>
+ val ex = new SparkFatalException(e)
+ promise.tryFailure(ex)
+ throw ex
+ case e: Throwable =>
+ promise.tryFailure(e)
+ throw e
+ }
+ }
+ }
+
+ override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
+ if (cometBroadcastExchange.isEmpty) {
+ throw new SparkException(
+ "CometNativeColumnarToRowExec only supports doExecuteBroadcast when
child contains a " +
+ "CometBroadcastExchange, but got " + child)
+ }
+
+ try {
+ relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
+ } catch {
+ case ex: TimeoutException =>
+ logError(s"Could not execute broadcast in $timeout secs.", ex)
+ if (!relationFuture.isDone) {
+ sparkContext.cancelJobGroup(runId.toString)
+ relationFuture.cancel(true)
+ }
+ throw QueryExecutionErrors.executeBroadcastTimeoutError(timeout, Some(ex))
+ }
+ }
+
+ private def findCometBroadcastExchange(op: SparkPlan): Option[CometBroadcastExchangeExec] = {
+ op match {
+ case b: CometBroadcastExchangeExec => Some(b)
+ case b: BroadcastQueryStageExec => findCometBroadcastExchange(b.plan)
+ case b: ReusedExchangeExec => findCometBroadcastExchange(b.child)
+ case _ => op.children.collectFirst(Function.unlift(findCometBroadcastExchange))
+ }
+ }
+
override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
val numInputBatches = longMetric("numInputBatches")
@@ -91,7 +211,17 @@ case class CometNativeColumnarToRowExec(child: SparkPlan)
val result = converter.convert(batch)
convertTime += System.nanoTime() - startTime
- result
+ // Wrap iterator to close batch after consumption
+ new Iterator[InternalRow] {
+ override def hasNext: Boolean = {
+ val hasMore = result.hasNext
+ if (!hasMore) {
+ batch.close()
+ }
+ hasMore
+ }
+ override def next(): InternalRow = result.next()
+ }
}
}
}
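
Both hunks above use the same close-on-exhaustion idiom: wrap the converter's
iterator so the source batch is closed the first time hasNext returns false. A
generic sketch of the pattern (illustrative; this helper is not part of the
commit), with a guard flag so close() cannot run twice if hasNext is called
again after exhaustion:

    def closeOnExhaustion[T](underlying: Iterator[T], close: () => Unit): Iterator[T] =
      new Iterator[T] {
        private var closed = false
        override def hasNext: Boolean = {
          val more = underlying.hasNext
          if (!more && !closed) { closed = true; close() }
          more
        }
        override def next(): T = underlying.next()
      }

One caveat with this idiom: if a consumer abandons the iterator before
exhaustion (e.g. under a limit), close() never runs, so such wrappers are
commonly backed by a task-completion listener.
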
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index e0a5c43ae..fe5ea77a8 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -30,8 +30,8 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, FromUnixTime, Literal, TruncDate, TruncTimestamp}
import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps
-import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometProjectExec}
-import org.apache.spark.sql.execution.{InputAdapter, ProjectExec, SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.comet.{CometNativeColumnarToRowExec, CometProjectExec}
+import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -1020,11 +1020,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
val query = sql(s"select cast(id as string) from $table")
val (_, cometPlan) = checkSparkAnswerAndOperator(query)
val project = cometPlan
- .asInstanceOf[WholeStageCodegenExec]
- .child
- .asInstanceOf[CometColumnarToRowExec]
- .child
- .asInstanceOf[InputAdapter]
+ .asInstanceOf[CometNativeColumnarToRowExec]
.child
.asInstanceOf[CometProjectExec]
val id = project.expressions.head
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 1b2373ad7..696a12d4a 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, He
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, BloomFilterAggregate}
import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
-import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
+import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SparkPlan, SQLExecution, UnionExec}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
@@ -864,9 +864,11 @@ class CometExecSuite extends CometTestBase {
checkSparkAnswerAndOperator(df)
// Before AQE: one CometBroadcastExchange, no CometColumnarToRow
- var columnarToRowExec = stripAQEPlan(df.queryExecution.executedPlan).collect {
- case s: CometColumnarToRowExec => s
- }
+ var columnarToRowExec: Seq[SparkPlan] =
+ stripAQEPlan(df.queryExecution.executedPlan).collect {
+ case s: CometColumnarToRowExec => s
+ case s: CometNativeColumnarToRowExec => s
+ }
assert(columnarToRowExec.isEmpty)
// Disable CometExecRule after the initial plan is generated. The CometSortMergeJoin and
@@ -880,14 +882,25 @@ class CometExecSuite extends CometTestBase {
// After AQE: CometBroadcastExchange has to be converted to rows to conform to Spark
// BroadcastHashJoin.
val plan = stripAQEPlan(df.queryExecution.executedPlan)
- columnarToRowExec = plan.collect { case s: CometColumnarToRowExec =>
- s
+ columnarToRowExec = plan.collect {
+ case s: CometColumnarToRowExec => s
+ case s: CometNativeColumnarToRowExec => s
}
assert(columnarToRowExec.length == 1)
- // This ColumnarToRowExec should be the immediate child of BroadcastHashJoinExec
- val parent = plan.find(_.children.contains(columnarToRowExec.head))
- assert(parent.get.isInstanceOf[BroadcastHashJoinExec])
+ // This ColumnarToRowExec should be a descendant of BroadcastHashJoinExec (possibly
+ // wrapped by InputAdapter for codegen).
+ val broadcastJoins = plan.collect { case b: BroadcastHashJoinExec => b }
+ assert(broadcastJoins.nonEmpty, s"Expected BroadcastHashJoinExec in plan:\n$plan")
+ val hasC2RDescendant = broadcastJoins.exists { join =>
+ join.find {
+ case _: CometColumnarToRowExec | _: CometNativeColumnarToRowExec => true
+ case _ => false
+ }.isDefined
+ }
+ assert(
+ hasC2RDescendant,
+ "BroadcastHashJoinExec should have a columnar-to-row descendant")
// There should be a CometBroadcastExchangeExec under CometColumnarToRowExec
val broadcastQueryStage =
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 89249240c..8a2f8af5c 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -80,6 +80,7 @@ abstract class CometTestBase
conf.set(CometConf.COMET_ONHEAP_ENABLED.key, "true")
conf.set(CometConf.COMET_EXEC_ENABLED.key, "true")
conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true")
+ conf.set(CometConf.COMET_NATIVE_COLUMNAR_TO_ROW_ENABLED.key, "true")
conf.set(CometConf.COMET_RESPECT_PARQUET_FILTER_PUSHDOWN.key, "true")
conf.set(CometConf.COMET_SPARK_TO_ARROW_ENABLED.key, "true")
conf.set(CometConf.COMET_NATIVE_SCAN_ENABLED.key, "true")
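
With this setting, every suite extending CometTestBase exercises the native
columnar-to-row path by default. The same flag can be flipped per session; a
minimal sketch (local SparkSession assumed for illustration):

    import org.apache.comet.CometConf
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    // Using the ConfigEntry's key avoids hard-coding the config string.
    spark.conf.set(CometConf.COMET_NATIVE_COLUMNAR_TO_ROW_ENABLED.key, "true")
    // Eligible plans (per EliminateRedundantTransitions) will then use
    // CometNativeColumnarToRowExec instead of CometColumnarToRowExec.
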
diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanChecker.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanChecker.scala
index 7caac7135..c8c4baff4 100644
--- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanChecker.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanChecker.scala
@@ -46,7 +46,7 @@ trait CometPlanChecker {
case _: CometNativeScanExec | _: CometScanExec | _: CometBatchScanExec |
_: CometIcebergNativeScanExec =>
case _: CometSinkPlaceHolder | _: CometScanWrapper =>
- case _: CometColumnarToRowExec =>
+ case _: CometColumnarToRowExec | _: CometNativeColumnarToRowExec =>
case _: CometSparkToColumnarExec =>
case _: CometExec | _: CometShuffleExchangeExec =>
case _: CometBroadcastExchangeExec =>
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]