This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new 77df2ee42d [Variant] add strict mode to cast_to_variant (#8233)
77df2ee42d is described below
commit 77df2ee42d8ca1d1557a64681b240b8409deef01
Author: Yan Tingwang <[email protected]>
AuthorDate: Tue Sep 9 22:31:38 2025 +0800
[Variant] add strict mode to cast_to_variant (#8233)
# Which issue does this PR close?
- Closes #8155.
# Rationale for this change
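`cast_to_variant` panics on Date64, Timestamp, and Time32/Time64 values that
cannot be represented as a chrono `NaiveDate` / `NaiveDateTime` / `NaiveTime`,
because each of these conversions calls `unwrap()` on an `Option`.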
# What changes are included in this PR?
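1. Add a new API:
`pub fn cast_to_variant_with_options(input: &dyn Array, options: &CastOptions) ->
Result<VariantArray, ArrowError>`
where `CastOptions` has a single `strict: bool` field:
- `strict = true` (the default): return an error on conversion failure
- `strict = false`: insert null for values that fail to convert
The existing `cast_to_variant(input)` is kept and delegates to the new function
with `CastOptions::default()`.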
2. Add tests for non-strict mode (Date64, Time32, and Timestamp inputs).
3. Refactor: eliminate duplicated timestamp-conversion code with the
`timestamp_to_variant_timestamp!` macro.
# Are these changes tested?
Yes.
# Are there any user-facing changes?
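Yes. `cast_to_variant_with_options` and `CastOptions` are new public items
re-exported from the crate root. `cast_to_variant` now returns an error instead
of panicking when a Date64, Timestamp, or Time value cannot be converted.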
---------
Signed-off-by: codephage2020 <[email protected]>
Co-authored-by: Ryan Johnson <[email protected]>
---
parquet-variant-compute/src/cast_to_variant.rs | 328 ++++++++++++++++---------
parquet-variant-compute/src/lib.rs | 3 +-
parquet-variant-compute/src/type_conversion.rs | 48 ++++
3 files changed, 264 insertions(+), 115 deletions(-)
diff --git a/parquet-variant-compute/src/cast_to_variant.rs
b/parquet-variant-compute/src/cast_to_variant.rs
index 412f207cfe..231d36f96e 100644
--- a/parquet-variant-compute/src/cast_to_variant.rs
+++ b/parquet-variant-compute/src/cast_to_variant.rs
@@ -20,7 +20,7 @@ use std::sync::Arc;
use crate::type_conversion::{
decimal_to_variant_decimal, generic_conversion_array,
non_generic_conversion_array,
- primitive_conversion_array,
+ primitive_conversion_array, timestamp_to_variant_timestamp,
};
use crate::{VariantArray, VariantArrayBuilder};
use arrow::array::{
@@ -46,6 +46,101 @@ use parquet_variant::{
Variant, VariantBuilder, VariantDecimal16, VariantDecimal4,
VariantDecimal8,
};
+/// Options for controlling the behavior of `cast_to_variant_with_options`.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct CastOptions {
+ /// If true, return error on conversion failure. If false, insert null for
failed conversions.
+ pub strict: bool,
+}
+
+impl Default for CastOptions {
+ fn default() -> Self {
+ Self { strict: true }
+ }
+}
+
+fn convert_timestamp_with_options(
+ time_unit: &TimeUnit,
+ time_zone: &Option<Arc<str>>,
+ input: &dyn Array,
+ builder: &mut VariantArrayBuilder,
+ options: &CastOptions,
+) -> Result<(), ArrowError> {
+ let native_datetimes: Vec<Option<NaiveDateTime>> = match time_unit {
+ arrow_schema::TimeUnit::Second => {
+ let ts_array = input
+ .as_any()
+ .downcast_ref::<TimestampSecondArray>()
+ .expect("Array is not TimestampSecondArray");
+ timestamp_to_variant_timestamp!(
+ ts_array,
+ timestamp_s_to_datetime,
+ "seconds",
+ options.strict
+ )
+ }
+ arrow_schema::TimeUnit::Millisecond => {
+ let ts_array = input
+ .as_any()
+ .downcast_ref::<TimestampMillisecondArray>()
+ .expect("Array is not TimestampMillisecondArray");
+ timestamp_to_variant_timestamp!(
+ ts_array,
+ timestamp_ms_to_datetime,
+ "milliseconds",
+ options.strict
+ )
+ }
+ arrow_schema::TimeUnit::Microsecond => {
+ let ts_array = input
+ .as_any()
+ .downcast_ref::<TimestampMicrosecondArray>()
+ .expect("Array is not TimestampMicrosecondArray");
+ timestamp_to_variant_timestamp!(
+ ts_array,
+ timestamp_us_to_datetime,
+ "microseconds",
+ options.strict
+ )
+ }
+ arrow_schema::TimeUnit::Nanosecond => {
+ let ts_array = input
+ .as_any()
+ .downcast_ref::<TimestampNanosecondArray>()
+ .expect("Array is not TimestampNanosecondArray");
+ timestamp_to_variant_timestamp!(
+ ts_array,
+ timestamp_ns_to_datetime,
+ "nanoseconds",
+ options.strict
+ )
+ }
+ };
+
+ for (i, x) in native_datetimes.iter().enumerate() {
+ match x {
+ Some(ndt) => {
+ if time_zone.is_none() {
+ builder.append_variant((*ndt).into());
+ } else {
+ let utc_dt: DateTime<Utc> = Utc.from_utc_datetime(ndt);
+ builder.append_variant(utc_dt.into());
+ }
+ }
+ None if options.strict && input.is_valid(i) => {
+ return Err(ArrowError::ComputeError(format!(
+ "Failed to convert timestamp at index {}: invalid
timestamp value",
+ i
+ )));
+ }
+ None => {
+ builder.append_null();
+ }
+ }
+ }
+ Ok(())
+}
+
/// Casts a typed arrow [`Array`] to a [`VariantArray`]. This is useful when
you
/// need to convert a specific data type
///
@@ -75,7 +170,14 @@ use parquet_variant::{
/// `1970-01-01T00:00:01.234567890Z`
/// will be truncated to
/// `1970-01-01T00:00:01.234567Z`
-pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
+///
+/// # Arguments
+/// * `input` - The array to convert to VariantArray
+/// * `options` - Options controlling conversion behavior
+pub fn cast_to_variant_with_options(
+ input: &dyn Array,
+ options: &CastOptions,
+) -> Result<VariantArray, ArrowError> {
let mut builder = VariantArrayBuilder::new(input.len());
let input_type = input.data_type();
@@ -167,25 +269,7 @@ pub fn cast_to_variant(input: &dyn Array) ->
Result<VariantArray, ArrowError> {
);
}
DataType::Timestamp(time_unit, time_zone) => {
- convert_timestamp(time_unit, time_zone, input, &mut builder);
- }
- DataType::Date32 => {
- generic_conversion_array!(
- Date32Type,
- as_primitive,
- |v: i32| -> NaiveDate { Date32Type::to_naive_date(v) },
- input,
- builder
- );
- }
- DataType::Date64 => {
- generic_conversion_array!(
- Date64Type,
- as_primitive,
- |v: i64| { Date64Type::to_naive_date_opt(v).unwrap() },
- input,
- builder
- );
+ convert_timestamp_with_options(time_unit, time_zone, input, &mut
builder, options)?;
}
DataType::Time32(unit) => {
match *unit {
@@ -194,10 +278,11 @@ pub fn cast_to_variant(input: &dyn Array) ->
Result<VariantArray, ArrowError> {
Time32SecondType,
as_primitive,
// nano second are always 0
- |v| NaiveTime::from_num_seconds_from_midnight_opt(v as
u32, 0u32).unwrap(),
+ |v| NaiveTime::from_num_seconds_from_midnight_opt(v as
u32, 0u32),
input,
- builder
- );
+ builder,
+ options.strict
+ )?;
}
TimeUnit::Millisecond => {
generic_conversion_array!(
@@ -206,11 +291,11 @@ pub fn cast_to_variant(input: &dyn Array) ->
Result<VariantArray, ArrowError> {
|v| NaiveTime::from_num_seconds_from_midnight_opt(
v as u32 / 1000,
(v as u32 % 1000) * 1_000_000
- )
- .unwrap(),
+ ),
input,
- builder
- );
+ builder,
+ options.strict
+ )?;
}
_ => {
return Err(ArrowError::CastError(format!(
@@ -229,11 +314,11 @@ pub fn cast_to_variant(input: &dyn Array) ->
Result<VariantArray, ArrowError> {
|v| NaiveTime::from_num_seconds_from_midnight_opt(
(v / 1_000_000) as u32,
(v % 1_000_000 * 1_000) as u32
- )
- .unwrap(),
+ ),
input,
- builder
- );
+ builder,
+ options.strict
+ )?;
}
TimeUnit::Nanosecond => {
generic_conversion_array!(
@@ -242,11 +327,11 @@ pub fn cast_to_variant(input: &dyn Array) ->
Result<VariantArray, ArrowError> {
|v| NaiveTime::from_num_seconds_from_midnight_opt(
(v / 1_000_000_000) as u32,
(v % 1_000_000_000) as u32
- )
- .unwrap(),
+ ),
input,
- builder
- );
+ builder,
+ options.strict
+ )?;
}
_ => {
return Err(ArrowError::CastError(format!(
@@ -284,6 +369,25 @@ pub fn cast_to_variant(input: &dyn Array) ->
Result<VariantArray, ArrowError> {
DataType::Utf8View => {
non_generic_conversion_array!(input.as_string_view(), |v| v,
builder);
}
+ DataType::Date32 => {
+ generic_conversion_array!(
+ Date32Type,
+ as_primitive,
+ |v: i32| -> NaiveDate { Date32Type::to_naive_date(v) },
+ input,
+ builder
+ );
+ }
+ DataType::Date64 => {
+ generic_conversion_array!(
+ Date64Type,
+ as_primitive,
+ |v: i64| Date64Type::to_naive_date_opt(v),
+ input,
+ builder,
+ options.strict
+ )?;
+ }
DataType::List(_) => convert_list::<i32>(input, &mut builder)?,
DataType::LargeList(_) => convert_list::<i64>(input, &mut builder)?,
DataType::Struct(_) => convert_struct(input, &mut builder)?,
@@ -310,79 +414,6 @@ pub fn cast_to_variant(input: &dyn Array) ->
Result<VariantArray, ArrowError> {
Ok(builder.build())
}
-// TODO do we need a cast_with_options to allow specifying conversion behavior,
-// e.g. how to handle overflows, whether to convert to Variant::Null or return
-// an error, etc. ?
-
-/// Convert timestamp arrays to native datetimes
-fn convert_timestamp(
- time_unit: &TimeUnit,
- time_zone: &Option<Arc<str>>,
- input: &dyn Array,
- builder: &mut VariantArrayBuilder,
-) {
- let native_datetimes: Vec<Option<NaiveDateTime>> = match time_unit {
- arrow_schema::TimeUnit::Second => {
- let ts_array = input
- .as_any()
- .downcast_ref::<TimestampSecondArray>()
- .expect("Array is not TimestampSecondArray");
-
- ts_array
- .iter()
- .map(|x| x.map(|y| timestamp_s_to_datetime(y).unwrap()))
- .collect()
- }
- arrow_schema::TimeUnit::Millisecond => {
- let ts_array = input
- .as_any()
- .downcast_ref::<TimestampMillisecondArray>()
- .expect("Array is not TimestampMillisecondArray");
-
- ts_array
- .iter()
- .map(|x| x.map(|y| timestamp_ms_to_datetime(y).unwrap()))
- .collect()
- }
- arrow_schema::TimeUnit::Microsecond => {
- let ts_array = input
- .as_any()
- .downcast_ref::<TimestampMicrosecondArray>()
- .expect("Array is not TimestampMicrosecondArray");
- ts_array
- .iter()
- .map(|x| x.map(|y| timestamp_us_to_datetime(y).unwrap()))
- .collect()
- }
- arrow_schema::TimeUnit::Nanosecond => {
- let ts_array = input
- .as_any()
- .downcast_ref::<TimestampNanosecondArray>()
- .expect("Array is not TimestampNanosecondArray");
- ts_array
- .iter()
- .map(|x| x.map(|y| timestamp_ns_to_datetime(y).unwrap()))
- .collect()
- }
- };
-
- for x in native_datetimes {
- match x {
- Some(ndt) => {
- if time_zone.is_none() {
- builder.append_variant(ndt.into());
- } else {
- let utc_dt: DateTime<Utc> = Utc.from_utc_datetime(&ndt);
- builder.append_variant(utc_dt.into());
- }
- }
- None => {
- builder.append_null();
- }
- }
- }
-}
-
/// Generic function to convert list arrays (both List and LargeList) to
variant arrays
fn convert_list<O: OffsetSizeTrait>(
input: &dyn Array,
@@ -525,6 +556,15 @@ fn convert_map(
Ok(())
}
+/// Convert an array to a `VariantArray` with strict mode enabled (returns
errors on conversion failures).
+///
+/// This function provides backward compatibility. For non-strict behavior,
+/// use `cast_to_variant_with_options` with `CastOptions { strict: false }`.
+pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
+ cast_to_variant_with_options(input, &CastOptions::default())
+}
+
+/// Convert union arrays
fn convert_union(
fields: &UnionFields,
input: &dyn Array,
@@ -634,8 +674,8 @@ mod tests {
IntervalDayTimeArray, IntervalMonthDayNanoArray,
IntervalYearMonthArray, LargeListArray,
LargeStringArray, ListArray, MapArray, NullArray, StringArray,
StringRunBuilder,
StringViewArray, StructArray, Time32MillisecondArray,
Time32SecondArray,
- Time64MicrosecondArray, Time64NanosecondArray, UInt16Array,
UInt32Array, UInt64Array,
- UInt8Array, UnionArray,
+ Time64MicrosecondArray, Time64NanosecondArray, TimestampSecondArray,
UInt16Array,
+ UInt32Array, UInt64Array, UInt8Array, UnionArray,
};
use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
use arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano};
@@ -2349,9 +2389,9 @@ mod tests {
/// Converts the given `Array` to a `VariantArray` and tests the conversion
/// against the expected values. It also tests the handling of nulls by
/// setting one element to null and verifying the output.
- fn run_test(values: ArrayRef, expected: Vec<Option<Variant>>) {
- // test without nulls
- let variant_array = cast_to_variant(&values).unwrap();
+ fn run_test_with_options(values: ArrayRef, expected: Vec<Option<Variant>>,
strict: bool) {
+ let options = CastOptions { strict };
+ let variant_array = cast_to_variant_with_options(&values,
&options).unwrap();
assert_eq!(variant_array.len(), expected.len());
for (i, expected_value) in expected.iter().enumerate() {
match expected_value {
@@ -2365,4 +2405,64 @@ mod tests {
}
}
}
+
+ fn run_test(values: ArrayRef, expected: Vec<Option<Variant>>) {
+ run_test_with_options(values, expected, true);
+ }
+
+ fn run_test_non_strict(values: ArrayRef, expected: Vec<Option<Variant>>) {
+ run_test_with_options(values, expected, false);
+ }
+
+ #[test]
+ fn test_cast_to_variant_non_strict_mode_date64() {
+ let date64_values = Date64Array::from(vec![Some(i64::MAX), Some(0),
Some(i64::MIN)]);
+
+ let values = Arc::new(date64_values);
+ run_test_non_strict(
+ values,
+ vec![
+ None,
+ Some(Variant::Date(Date64Type::to_naive_date_opt(0).unwrap())),
+ None,
+ ],
+ );
+ }
+
+ #[test]
+ fn test_cast_to_variant_non_strict_mode_time32() {
+ let time32_array = Time32SecondArray::from(vec![Some(90000),
Some(3600), Some(-1)]);
+
+ let values = Arc::new(time32_array);
+ run_test_non_strict(
+ values,
+ vec![
+ None,
+ Some(Variant::Time(
+ NaiveTime::from_num_seconds_from_midnight_opt(3600,
0).unwrap(),
+ )),
+ None,
+ ],
+ );
+ }
+
+ #[test]
+ fn test_cast_to_variant_non_strict_mode_timestamp() {
+ let ts_array = TimestampSecondArray::from(vec![Some(i64::MAX),
Some(0), Some(1609459200)])
+ .with_timezone_opt(None::<&str>);
+
+ let values = Arc::new(ts_array);
+ run_test_non_strict(
+ values,
+ vec![
+ None, // Invalid timestamp becomes null
+ Some(Variant::TimestampNtzMicros(
+ timestamp_s_to_datetime(0).unwrap(),
+ )),
+ Some(Variant::TimestampNtzMicros(
+ timestamp_s_to_datetime(1609459200).unwrap(),
+ )),
+ ],
+ );
+ }
}
diff --git a/parquet-variant-compute/src/lib.rs
b/parquet-variant-compute/src/lib.rs
index ef674d9614..3c928636ac 100644
--- a/parquet-variant-compute/src/lib.rs
+++ b/parquet-variant-compute/src/lib.rs
@@ -22,7 +22,7 @@
//! - [`VariantArrayBuilder`]: For building [`VariantArray`]
//! - [`json_to_variant`]: Function to convert a batch of JSON strings to a
`VariantArray`.
//! - [`variant_to_json`]: Function to convert a `VariantArray` to a batch of
JSON strings.
-//! - [`cast_to_variant`]: Module to cast other Arrow arrays to `VariantArray`.
+//! - [`mod@cast_to_variant`]: Module to cast other Arrow arrays to
`VariantArray`.
//! - [`variant_get`]: Module to get values from a `VariantArray` using a
specified [`VariantPath`]
//!
//! ## 🚧 Work In Progress
@@ -46,5 +46,6 @@ pub mod variant_get;
pub use variant_array::{ShreddingState, VariantArray};
pub use variant_array_builder::{VariantArrayBuilder,
VariantArrayVariantBuilder};
+pub use cast_to_variant::{cast_to_variant, cast_to_variant_with_options,
CastOptions};
pub use from_json::json_to_variant;
pub use to_json::variant_to_json;
diff --git a/parquet-variant-compute/src/type_conversion.rs
b/parquet-variant-compute/src/type_conversion.rs
index 647d2c705f..aa60b425a1 100644
--- a/parquet-variant-compute/src/type_conversion.rs
+++ b/parquet-variant-compute/src/type_conversion.rs
@@ -20,6 +20,7 @@
/// Convert the input array to a `VariantArray` row by row, using `method`
/// not requiring a generic type to downcast the generic array to a specific
/// array type and `cast_fn` to transform each element to a type compatible
with Variant
+/// If `strict` is true(default), return error on conversion failure. If
false, insert null.
macro_rules! non_generic_conversion_array {
($array:expr, $cast_fn:expr, $builder:expr) => {{
let array = $array;
@@ -32,6 +33,28 @@ macro_rules! non_generic_conversion_array {
$builder.append_variant(Variant::from(cast_value));
}
}};
+ ($array:expr, $cast_fn:expr, $builder:expr, $strict:expr) => {{
+ let array = $array;
+ for i in 0..array.len() {
+ if array.is_null(i) {
+ $builder.append_null();
+ continue;
+ }
+ match $cast_fn(array.value(i)) {
+ Some(cast_value) => {
+ $builder.append_variant(Variant::from(cast_value));
+ }
+ None if $strict => {
+ return Err(ArrowError::ComputeError(format!(
+ "Failed to convert value at index {}: conversion
failed",
+ i
+ )));
+ }
+ None => $builder.append_null(),
+ }
+ }
+ Ok::<(), ArrowError>(())
+ }};
}
pub(crate) use non_generic_conversion_array;
@@ -52,6 +75,7 @@ pub(crate) use non_generic_conversion_single_value;
/// Convert the input array to a `VariantArray` row by row, using `method`
/// requiring a generic type to downcast the generic array to a specific
/// array type and `cast_fn` to transform each element to a type compatible
with Variant
+/// If `strict` is true(default), return error on conversion failure. If
false, insert null.
macro_rules! generic_conversion_array {
($t:ty, $method:ident, $cast_fn:expr, $input:expr, $builder:expr) => {{
$crate::type_conversion::non_generic_conversion_array!(
@@ -60,6 +84,14 @@ macro_rules! generic_conversion_array {
$builder
)
}};
+ ($t:ty, $method:ident, $cast_fn:expr, $input:expr, $builder:expr,
$strict:expr) => {{
+ $crate::type_conversion::non_generic_conversion_array!(
+ $input.$method::<$t>(),
+ $cast_fn,
+ $builder,
+ $strict
+ )
+ }};
}
pub(crate) use generic_conversion_array;
@@ -123,3 +155,19 @@ macro_rules! decimal_to_variant_decimal {
}};
}
pub(crate) use decimal_to_variant_decimal;
+
+/// Convert a timestamp value to a `VariantTimestamp`
+macro_rules! timestamp_to_variant_timestamp {
+ ($ts_array:expr, $converter:expr, $unit_name:expr, $strict:expr) => {
+ if $strict {
+ let error =
+ || ArrowError::ComputeError(format!("Invalid timestamp {}
value", $unit_name));
+ let converter = |x| $converter(x).ok_or_else(error);
+ let iter = $ts_array.iter().map(|x| x.map(converter).transpose());
+ iter.collect::<Result<Vec<_>, ArrowError>>()?
+ } else {
+ $ts_array.iter().map(|x| x.and_then($converter)).collect()
+ }
+ };
+}
+pub(crate) use timestamp_to_variant_timestamp;