scovich commented on code in PR #8299: URL: https://github.com/apache/arrow-rs/pull/8299#discussion_r2334832862
########## parquet-variant-compute/src/arrow_to_variant.rs: ########## @@ -0,0 +1,2424 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; + +use crate::type_conversion::decimal_to_variant_decimal; +use arrow::array::{ + Array, AsArray, GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, +}; +use arrow::compute::kernels::cast; +use arrow::datatypes::{ + ArrowNativeType, ArrowPrimitiveType, ArrowTemporalType, ArrowTimestampType, Date32Type, + Date64Type, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + RunEndIndexType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use arrow::temporal_conversions::{as_date, as_datetime, as_time}; +use arrow_schema::{ArrowError, DataType, TimeUnit}; +use chrono::{DateTime, TimeZone, Utc}; +use parquet_variant::{ + ObjectFieldBuilder, Variant, VariantBuilderExt, VariantDecimal16, VariantDecimal4, + VariantDecimal8, +}; + +// ============================================================================ +// 
Row-oriented builders for efficient Arrow-to-Variant conversion +// ============================================================================ + +/// Row builder for converting Arrow arrays to VariantArray row by row +pub(crate) enum ArrowToVariantRowBuilder<'a> { + Null(NullArrowToVariantBuilder), + Boolean(BooleanArrowToVariantBuilder<'a>), + PrimitiveInt8(PrimitiveArrowToVariantBuilder<'a, Int8Type>), + PrimitiveInt16(PrimitiveArrowToVariantBuilder<'a, Int16Type>), + PrimitiveInt32(PrimitiveArrowToVariantBuilder<'a, Int32Type>), + PrimitiveInt64(PrimitiveArrowToVariantBuilder<'a, Int64Type>), + PrimitiveUInt8(PrimitiveArrowToVariantBuilder<'a, UInt8Type>), + PrimitiveUInt16(PrimitiveArrowToVariantBuilder<'a, UInt16Type>), + PrimitiveUInt32(PrimitiveArrowToVariantBuilder<'a, UInt32Type>), + PrimitiveUInt64(PrimitiveArrowToVariantBuilder<'a, UInt64Type>), + PrimitiveFloat16(PrimitiveArrowToVariantBuilder<'a, Float16Type>), + PrimitiveFloat32(PrimitiveArrowToVariantBuilder<'a, Float32Type>), + PrimitiveFloat64(PrimitiveArrowToVariantBuilder<'a, Float64Type>), + Decimal32(Decimal32ArrowToVariantBuilder<'a>), + Decimal64(Decimal64ArrowToVariantBuilder<'a>), + Decimal128(Decimal128ArrowToVariantBuilder<'a>), + Decimal256(Decimal256ArrowToVariantBuilder<'a>), + TimestampSecond(TimestampArrowToVariantBuilder<'a, TimestampSecondType>), + TimestampMillisecond(TimestampArrowToVariantBuilder<'a, TimestampMillisecondType>), + TimestampMicrosecond(TimestampArrowToVariantBuilder<'a, TimestampMicrosecondType>), + TimestampNanosecond(TimestampArrowToVariantBuilder<'a, TimestampNanosecondType>), + Date32(DateArrowToVariantBuilder<'a, Date32Type>), + Date64(DateArrowToVariantBuilder<'a, Date64Type>), + Time32Second(TimeArrowToVariantBuilder<'a, Time32SecondType>), + Time32Millisecond(TimeArrowToVariantBuilder<'a, Time32MillisecondType>), + Time64Microsecond(TimeArrowToVariantBuilder<'a, Time64MicrosecondType>), + Time64Nanosecond(TimeArrowToVariantBuilder<'a, 
Time64NanosecondType>), + Binary(BinaryArrowToVariantBuilder<'a, i32>), + LargeBinary(BinaryArrowToVariantBuilder<'a, i64>), + BinaryView(BinaryViewArrowToVariantBuilder<'a>), + FixedSizeBinary(FixedSizeBinaryArrowToVariantBuilder<'a>), + Utf8(StringArrowToVariantBuilder<'a, i32>), + LargeUtf8(StringArrowToVariantBuilder<'a, i64>), + Utf8View(StringViewArrowToVariantBuilder<'a>), + List(ListArrowToVariantBuilder<'a, i32>), + LargeList(ListArrowToVariantBuilder<'a, i64>), + Struct(StructArrowToVariantBuilder<'a>), + Map(MapArrowToVariantBuilder<'a>), + Union(UnionArrowToVariantBuilder<'a>), + Dictionary(DictionaryArrowToVariantBuilder<'a>), + RunEndEncodedInt16(RunEndEncodedArrowToVariantBuilder<'a, Int16Type>), + RunEndEncodedInt32(RunEndEncodedArrowToVariantBuilder<'a, Int32Type>), + RunEndEncodedInt64(RunEndEncodedArrowToVariantBuilder<'a, Int64Type>), +} + +impl<'a> ArrowToVariantRowBuilder<'a> { + pub fn append_row( + &mut self, + index: usize, + builder: &mut impl VariantBuilderExt, + ) -> Result<(), ArrowError> { + match self { + ArrowToVariantRowBuilder::Null(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::Boolean(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::PrimitiveInt8(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::PrimitiveInt16(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::PrimitiveInt32(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::PrimitiveInt64(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::PrimitiveUInt8(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::PrimitiveUInt16(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::PrimitiveUInt32(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::PrimitiveUInt64(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::PrimitiveFloat16(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::PrimitiveFloat32(b) => b.append_row(index, builder), + 
ArrowToVariantRowBuilder::PrimitiveFloat64(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::Decimal32(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::Decimal64(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::Decimal128(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::Decimal256(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::TimestampSecond(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::TimestampMillisecond(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::TimestampMicrosecond(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::TimestampNanosecond(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::Date32(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::Date64(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::Time32Second(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::Time32Millisecond(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::Time64Microsecond(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::Time64Nanosecond(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::Binary(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::LargeBinary(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::BinaryView(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::FixedSizeBinary(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::Utf8(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::LargeUtf8(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::Utf8View(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::List(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::LargeList(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::Struct(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::Map(b) => b.append_row(index, builder), + 
ArrowToVariantRowBuilder::Union(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::Dictionary(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::RunEndEncodedInt16(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::RunEndEncodedInt32(b) => b.append_row(index, builder), + ArrowToVariantRowBuilder::RunEndEncodedInt64(b) => b.append_row(index, builder), + } + } +} + +/// Factory function to create the appropriate row builder for a given DataType +pub(crate) fn make_arrow_to_variant_row_builder<'a>( + data_type: &'a DataType, + array: &'a dyn Array, +) -> Result<ArrowToVariantRowBuilder<'a>, ArrowError> { + let builder = match data_type { + DataType::Null => ArrowToVariantRowBuilder::Null(NullArrowToVariantBuilder), + DataType::Boolean => { + ArrowToVariantRowBuilder::Boolean(BooleanArrowToVariantBuilder::new(array)) + } + DataType::Int8 => { + ArrowToVariantRowBuilder::PrimitiveInt8(PrimitiveArrowToVariantBuilder::new(array)) + } + DataType::Int16 => { + ArrowToVariantRowBuilder::PrimitiveInt16(PrimitiveArrowToVariantBuilder::new(array)) + } + DataType::Int32 => { + ArrowToVariantRowBuilder::PrimitiveInt32(PrimitiveArrowToVariantBuilder::new(array)) + } + DataType::Int64 => { + ArrowToVariantRowBuilder::PrimitiveInt64(PrimitiveArrowToVariantBuilder::new(array)) + } + DataType::UInt8 => { + ArrowToVariantRowBuilder::PrimitiveUInt8(PrimitiveArrowToVariantBuilder::new(array)) + } + DataType::UInt16 => { + ArrowToVariantRowBuilder::PrimitiveUInt16(PrimitiveArrowToVariantBuilder::new(array)) + } + DataType::UInt32 => { + ArrowToVariantRowBuilder::PrimitiveUInt32(PrimitiveArrowToVariantBuilder::new(array)) + } + DataType::UInt64 => { + ArrowToVariantRowBuilder::PrimitiveUInt64(PrimitiveArrowToVariantBuilder::new(array)) + } + DataType::Float16 => { + ArrowToVariantRowBuilder::PrimitiveFloat16(PrimitiveArrowToVariantBuilder::new(array)) + } + DataType::Float32 => { + 
ArrowToVariantRowBuilder::PrimitiveFloat32(PrimitiveArrowToVariantBuilder::new(array)) + } + DataType::Float64 => { + ArrowToVariantRowBuilder::PrimitiveFloat64(PrimitiveArrowToVariantBuilder::new(array)) + } + DataType::Decimal32(_, scale) => { + ArrowToVariantRowBuilder::Decimal32(Decimal32ArrowToVariantBuilder::new(array, *scale)) + } + DataType::Decimal64(_, scale) => { + ArrowToVariantRowBuilder::Decimal64(Decimal64ArrowToVariantBuilder::new(array, *scale)) + } + DataType::Decimal128(_, scale) => ArrowToVariantRowBuilder::Decimal128( + Decimal128ArrowToVariantBuilder::new(array, *scale), + ), + DataType::Decimal256(_, scale) => ArrowToVariantRowBuilder::Decimal256( + Decimal256ArrowToVariantBuilder::new(array, *scale), + ), + DataType::Timestamp(time_unit, time_zone) => match time_unit { + TimeUnit::Second => ArrowToVariantRowBuilder::TimestampSecond( + TimestampArrowToVariantBuilder::new(array, time_zone.is_some()), + ), + TimeUnit::Millisecond => ArrowToVariantRowBuilder::TimestampMillisecond( + TimestampArrowToVariantBuilder::new(array, time_zone.is_some()), + ), + TimeUnit::Microsecond => ArrowToVariantRowBuilder::TimestampMicrosecond( + TimestampArrowToVariantBuilder::new(array, time_zone.is_some()), + ), + TimeUnit::Nanosecond => ArrowToVariantRowBuilder::TimestampNanosecond( + TimestampArrowToVariantBuilder::new(array, time_zone.is_some()), + ), + }, + DataType::Date32 => ArrowToVariantRowBuilder::Date32(DateArrowToVariantBuilder::new(array)), + DataType::Date64 => ArrowToVariantRowBuilder::Date64(DateArrowToVariantBuilder::new(array)), + DataType::Time32(time_unit) => match time_unit { + TimeUnit::Second => { + ArrowToVariantRowBuilder::Time32Second(TimeArrowToVariantBuilder::new(array)) + } + TimeUnit::Millisecond => { + ArrowToVariantRowBuilder::Time32Millisecond(TimeArrowToVariantBuilder::new(array)) + } + _ => { + return Err(ArrowError::CastError(format!( + "Unsupported Time32 unit: {time_unit:?}" + ))) + } + }, + DataType::Time64(time_unit) => 
match time_unit { + TimeUnit::Microsecond => { + ArrowToVariantRowBuilder::Time64Microsecond(TimeArrowToVariantBuilder::new(array)) + } + TimeUnit::Nanosecond => { + ArrowToVariantRowBuilder::Time64Nanosecond(TimeArrowToVariantBuilder::new(array)) + } + _ => { + return Err(ArrowError::CastError(format!( + "Unsupported Time64 unit: {time_unit:?}" + ))) + } + }, + DataType::Duration(_) | DataType::Interval(_) => { + return Err(ArrowError::InvalidArgumentError( + "Casting duration/interval types to Variant is not supported. \ + The Variant format does not define duration/interval types." + .to_string(), + )) + } + DataType::Binary => { + ArrowToVariantRowBuilder::Binary(BinaryArrowToVariantBuilder::new(array)) + } + DataType::LargeBinary => { + ArrowToVariantRowBuilder::LargeBinary(BinaryArrowToVariantBuilder::new(array)) + } + DataType::BinaryView => { + ArrowToVariantRowBuilder::BinaryView(BinaryViewArrowToVariantBuilder::new(array)) + } + DataType::FixedSizeBinary(_) => ArrowToVariantRowBuilder::FixedSizeBinary( + FixedSizeBinaryArrowToVariantBuilder::new(array), + ), + DataType::Utf8 => ArrowToVariantRowBuilder::Utf8(StringArrowToVariantBuilder::new(array)), + DataType::LargeUtf8 => { + ArrowToVariantRowBuilder::LargeUtf8(StringArrowToVariantBuilder::new(array)) + } + DataType::Utf8View => { + ArrowToVariantRowBuilder::Utf8View(StringViewArrowToVariantBuilder::new(array)) + } + DataType::List(_) => ArrowToVariantRowBuilder::List(ListArrowToVariantBuilder::new(array)?), + DataType::LargeList(_) => { + ArrowToVariantRowBuilder::LargeList(ListArrowToVariantBuilder::new(array)?) + } + DataType::Struct(_) => { + ArrowToVariantRowBuilder::Struct(StructArrowToVariantBuilder::new(array.as_struct())?) + } + DataType::Map(_, _) => ArrowToVariantRowBuilder::Map(MapArrowToVariantBuilder::new(array)?), + DataType::Union(_, _) => { + ArrowToVariantRowBuilder::Union(UnionArrowToVariantBuilder::new(array)?) 
+ } + DataType::Dictionary(_, _) => { + ArrowToVariantRowBuilder::Dictionary(DictionaryArrowToVariantBuilder::new(array)?) + } + DataType::RunEndEncoded(run_ends, _) => match run_ends.data_type() { + DataType::Int16 => ArrowToVariantRowBuilder::RunEndEncodedInt16( + RunEndEncodedArrowToVariantBuilder::new(array)?, + ), + DataType::Int32 => ArrowToVariantRowBuilder::RunEndEncodedInt32( + RunEndEncodedArrowToVariantBuilder::new(array)?, + ), + DataType::Int64 => ArrowToVariantRowBuilder::RunEndEncodedInt64( + RunEndEncodedArrowToVariantBuilder::new(array)?, + ), + _ => { + return Err(ArrowError::CastError(format!( + "Unsupported run ends type: {:?}", + run_ends.data_type() + ))); + } + }, + dt => { + return Err(ArrowError::CastError(format!( + "Unsupported data type for casting to Variant: {dt:?}", + ))); + } + }; + Ok(builder) +} + +/// Macro to define (possibly generic) row builders with consistent structure and behavior. +/// Supports optional extra fields that are passed to the constructor. +macro_rules! define_row_builder { + ( + struct $name:ident<$lifetime:lifetime $(, $generic:ident: $($bound:path)+)?> + $(where $where_path:path: $where_bound:path)? + $({ $($field:ident: $field_type:ty),* $(,)? })?, + |$array_param:ident| -> $array_type:ty { $init_expr:expr }, + |$value:ident| $value_transform:expr + ) => { + pub(crate) struct $name<$lifetime $(, $generic: $($bound)+)?> + $(where $where_path: $where_bound)? + { + array: &$lifetime $array_type, + $($($field: $field_type,)*)? + } + + impl<$lifetime $(, $generic: $($bound)+)?> $name<$lifetime $(, $generic)?> + $(where $where_path: $where_bound)? + { + pub(crate) fn new($array_param: &$lifetime dyn Array $(, $($field: $field_type),*)?) -> Self { + Self { + array: $init_expr, + $($($field,)*)? 
+ } + } + + fn append_row(&self, index: usize, builder: &mut impl VariantBuilderExt) -> Result<(), ArrowError> { + if self.array.is_null(index) { + builder.append_null(); + } else { + let $value = self.array.value(index); + // Capture fields as variables the transform can access (hygiene) + $($(let $field = &self.$field;)*)? + builder.append_value($value_transform); + } + Ok(()) + } + } + }; +} + +define_row_builder!( + struct BooleanArrowToVariantBuilder<'a>, + |array| -> arrow::array::BooleanArray { array.as_boolean() }, + |value| value +); + +define_row_builder!( + struct PrimitiveArrowToVariantBuilder<'a, T: ArrowPrimitiveType> + where T::Native: Into<Variant<'a, 'a>>, + |array| -> PrimitiveArray<T> { array.as_primitive() }, + |value| value +); + +define_row_builder!( + struct Decimal32ArrowToVariantBuilder<'a> { + scale: i8, + }, + |array| -> arrow::array::Decimal32Array { array.as_primitive() }, + |value| decimal_to_variant_decimal!(value, scale, i32, VariantDecimal4) +); + +define_row_builder!( + struct Decimal64ArrowToVariantBuilder<'a> { + scale: i8, + }, + |array| -> arrow::array::Decimal64Array { array.as_primitive() }, + |value| decimal_to_variant_decimal!(value, scale, i64, VariantDecimal8) +); + +define_row_builder!( + struct Decimal128ArrowToVariantBuilder<'a> { + scale: i8, + }, + |array| -> arrow::array::Decimal128Array { array.as_primitive() }, + |value| decimal_to_variant_decimal!(value, scale, i128, VariantDecimal16) +); + +define_row_builder!( + struct Decimal256ArrowToVariantBuilder<'a> { + scale: i8, + }, + |array| -> arrow::array::Decimal256Array { array.as_primitive() }, + |value| { + // Decimal256 needs special handling - convert to i128 if possible + match value.to_i128() { + Some(i128_val) => decimal_to_variant_decimal!(i128_val, scale, i128, VariantDecimal16), + None => Variant::Null, // Value too large for i128 + } + } +); + +define_row_builder!( + struct TimestampArrowToVariantBuilder<'a, T: ArrowTimestampType> { + has_time_zone: 
bool, + }, + |array| -> arrow::array::PrimitiveArray<T> { array.as_primitive() }, + |value| { + // Convert using Arrow's temporal conversion functions + let Some(naive_datetime) = as_datetime::<T>(value) else { + return Err(ArrowError::CastError( Review Comment: NOTE: Now that strict casting has been merged in, the error case is now handled in the builder itself -- not the value transform. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org
