This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 77df2ee42d [Variant] add strict mode to cast_to_variant (#8233)
77df2ee42d is described below

commit 77df2ee42d8ca1d1557a64681b240b8409deef01
Author: Yan Tingwang <[email protected]>
AuthorDate: Tue Sep 9 22:31:38 2025 +0800

    [Variant] add strict mode to cast_to_variant (#8233)
    
    # Which issue does this PR close?
    
    - Closes #8155.
    
    # Rationale for this change
    
    `cast_to_variant` panics on Date64 / Timestamp (and Time32 / Time64) values
    that cannot be converted to a chrono `NaiveDate` / `NaiveDateTime` / `NaiveTime`.
    
    # What changes are included in this PR?
    
    1. Add a new API:
    `pub fn cast_to_variant_with_options(input: &dyn Array, options: &CastOptions) ->
    Result<VariantArray, ArrowError>`
      - `CastOptions { strict: true }` (the default): return an error on conversion failures
      - `CastOptions { strict: false }`: return null for values that fail to convert
    2. Add tests for non-strict mode.
    3. Refactor: eliminate duplication in timestamp conversion using a macro.
    
    # Are these changes tested?
    
    Yes.
    
    # Are there any user-facing changes?
    
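    No breaking changes. This adds the public `CastOptions` type and the
    `cast_to_variant_with_options` function (both re-exported from the crate root);
    `cast_to_variant` keeps its signature but now returns an error instead of
    panicking on Date64 / Timestamp / Time32 / Time64 values that cannot be converted.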
    
    ---------
    
    Signed-off-by: codephage2020 <[email protected]>
    Co-authored-by: Ryan Johnson <[email protected]>
---
 parquet-variant-compute/src/cast_to_variant.rs | 328 ++++++++++++++++---------
 parquet-variant-compute/src/lib.rs             |   3 +-
 parquet-variant-compute/src/type_conversion.rs |  48 ++++
 3 files changed, 264 insertions(+), 115 deletions(-)

diff --git a/parquet-variant-compute/src/cast_to_variant.rs 
b/parquet-variant-compute/src/cast_to_variant.rs
index 412f207cfe..231d36f96e 100644
--- a/parquet-variant-compute/src/cast_to_variant.rs
+++ b/parquet-variant-compute/src/cast_to_variant.rs
@@ -20,7 +20,7 @@ use std::sync::Arc;
 
 use crate::type_conversion::{
     decimal_to_variant_decimal, generic_conversion_array, 
non_generic_conversion_array,
-    primitive_conversion_array,
+    primitive_conversion_array, timestamp_to_variant_timestamp,
 };
 use crate::{VariantArray, VariantArrayBuilder};
 use arrow::array::{
@@ -46,6 +46,101 @@ use parquet_variant::{
     Variant, VariantBuilder, VariantDecimal16, VariantDecimal4, 
VariantDecimal8,
 };
 
+/// Options for controlling the behavior of `cast_to_variant_with_options`.
+///
+/// The [`Default`] value enables strict mode; `cast_to_variant` uses this
+/// default.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct CastOptions {
+    /// If true, return an error on conversion failure. If false, insert a null
+    /// for each value that fails to convert.
+    pub strict: bool,
+}
+
+impl Default for CastOptions {
+    /// Strict mode (`strict: true`): conversion failures are reported as errors.
+    fn default() -> Self {
+        Self { strict: true }
+    }
+}
+
+/// Convert a timestamp array into `builder`, appending one entry per input row.
+///
+/// Each value is converted to a `NaiveDateTime` with the
+/// `timestamp_*_to_datetime` helper that matches `time_unit`. When `time_zone`
+/// is `None`, the `NaiveDateTime` is appended directly; otherwise it is
+/// interpreted as UTC and appended as a `DateTime<Utc>`. Null input rows are
+/// appended as nulls.
+///
+/// # Errors
+/// In strict mode (`options.strict`), a value that cannot be converted makes
+/// this function return an `ArrowError::ComputeError`. In non-strict mode such
+/// values are appended as nulls instead.
+///
+/// # Panics
+/// Panics if the concrete array type of `input` does not match `time_unit`.
+fn convert_timestamp_with_options(
+    time_unit: &TimeUnit,
+    time_zone: &Option<Arc<str>>,
+    input: &dyn Array,
+    builder: &mut VariantArrayBuilder,
+    options: &CastOptions,
+) -> Result<(), ArrowError> {
+    // In strict mode the macro returns early (via `?`) on the first value that
+    // fails to convert; in non-strict mode failures are collected as `None`.
+    let native_datetimes: Vec<Option<NaiveDateTime>> = match time_unit {
+        arrow_schema::TimeUnit::Second => {
+            let ts_array = input
+                .as_any()
+                .downcast_ref::<TimestampSecondArray>()
+                .expect("Array is not TimestampSecondArray");
+            timestamp_to_variant_timestamp!(
+                ts_array,
+                timestamp_s_to_datetime,
+                "seconds",
+                options.strict
+            )
+        }
+        arrow_schema::TimeUnit::Millisecond => {
+            let ts_array = input
+                .as_any()
+                .downcast_ref::<TimestampMillisecondArray>()
+                .expect("Array is not TimestampMillisecondArray");
+            timestamp_to_variant_timestamp!(
+                ts_array,
+                timestamp_ms_to_datetime,
+                "milliseconds",
+                options.strict
+            )
+        }
+        arrow_schema::TimeUnit::Microsecond => {
+            let ts_array = input
+                .as_any()
+                .downcast_ref::<TimestampMicrosecondArray>()
+                .expect("Array is not TimestampMicrosecondArray");
+            timestamp_to_variant_timestamp!(
+                ts_array,
+                timestamp_us_to_datetime,
+                "microseconds",
+                options.strict
+            )
+        }
+        arrow_schema::TimeUnit::Nanosecond => {
+            let ts_array = input
+                .as_any()
+                .downcast_ref::<TimestampNanosecondArray>()
+                .expect("Array is not TimestampNanosecondArray");
+            timestamp_to_variant_timestamp!(
+                ts_array,
+                timestamp_ns_to_datetime,
+                "nanoseconds",
+                options.strict
+            )
+        }
+    };
+
+    for (i, x) in native_datetimes.iter().enumerate() {
+        match x {
+            Some(ndt) => {
+                if time_zone.is_none() {
+                    builder.append_variant((*ndt).into());
+                } else {
+                    // A time zone is present: treat the stored value as a UTC instant.
+                    let utc_dt: DateTime<Utc> = Utc.from_utc_datetime(ndt);
+                    builder.append_variant(utc_dt.into());
+                }
+            }
+            // Defensive: in strict mode the macro above has already returned an
+            // error for any unconvertible value, so a `None` here is expected to
+            // come only from a null input row.
+            None if options.strict && input.is_valid(i) => {
+                return Err(ArrowError::ComputeError(format!(
+                    "Failed to convert timestamp at index {}: invalid 
timestamp value",
+                    i
+                )));
+            }
+            None => {
+                builder.append_null();
+            }
+        }
+    }
+    Ok(())
+}
+
 /// Casts a typed arrow [`Array`] to a [`VariantArray`]. This is useful when 
you
 /// need to convert a specific data type
 ///
@@ -75,7 +170,14 @@ use parquet_variant::{
 /// `1970-01-01T00:00:01.234567890Z`
 /// will be truncated to
 /// `1970-01-01T00:00:01.234567Z`
-pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
+///
+/// # Arguments
+/// * `input` - The array to convert to VariantArray
+/// * `options` - Options controlling conversion behavior
+pub fn cast_to_variant_with_options(
+    input: &dyn Array,
+    options: &CastOptions,
+) -> Result<VariantArray, ArrowError> {
     let mut builder = VariantArrayBuilder::new(input.len());
 
     let input_type = input.data_type();
@@ -167,25 +269,7 @@ pub fn cast_to_variant(input: &dyn Array) -> 
Result<VariantArray, ArrowError> {
             );
         }
         DataType::Timestamp(time_unit, time_zone) => {
-            convert_timestamp(time_unit, time_zone, input, &mut builder);
-        }
-        DataType::Date32 => {
-            generic_conversion_array!(
-                Date32Type,
-                as_primitive,
-                |v: i32| -> NaiveDate { Date32Type::to_naive_date(v) },
-                input,
-                builder
-            );
-        }
-        DataType::Date64 => {
-            generic_conversion_array!(
-                Date64Type,
-                as_primitive,
-                |v: i64| { Date64Type::to_naive_date_opt(v).unwrap() },
-                input,
-                builder
-            );
+            convert_timestamp_with_options(time_unit, time_zone, input, &mut 
builder, options)?;
         }
         DataType::Time32(unit) => {
             match *unit {
@@ -194,10 +278,11 @@ pub fn cast_to_variant(input: &dyn Array) -> 
Result<VariantArray, ArrowError> {
                         Time32SecondType,
                         as_primitive,
                         // nano second are always 0
-                        |v| NaiveTime::from_num_seconds_from_midnight_opt(v as 
u32, 0u32).unwrap(),
+                        |v| NaiveTime::from_num_seconds_from_midnight_opt(v as 
u32, 0u32),
                         input,
-                        builder
-                    );
+                        builder,
+                        options.strict
+                    )?;
                 }
                 TimeUnit::Millisecond => {
                     generic_conversion_array!(
@@ -206,11 +291,11 @@ pub fn cast_to_variant(input: &dyn Array) -> 
Result<VariantArray, ArrowError> {
                         |v| NaiveTime::from_num_seconds_from_midnight_opt(
                             v as u32 / 1000,
                             (v as u32 % 1000) * 1_000_000
-                        )
-                        .unwrap(),
+                        ),
                         input,
-                        builder
-                    );
+                        builder,
+                        options.strict
+                    )?;
                 }
                 _ => {
                     return Err(ArrowError::CastError(format!(
@@ -229,11 +314,11 @@ pub fn cast_to_variant(input: &dyn Array) -> 
Result<VariantArray, ArrowError> {
                         |v| NaiveTime::from_num_seconds_from_midnight_opt(
                             (v / 1_000_000) as u32,
                             (v % 1_000_000 * 1_000) as u32
-                        )
-                        .unwrap(),
+                        ),
                         input,
-                        builder
-                    );
+                        builder,
+                        options.strict
+                    )?;
                 }
                 TimeUnit::Nanosecond => {
                     generic_conversion_array!(
@@ -242,11 +327,11 @@ pub fn cast_to_variant(input: &dyn Array) -> 
Result<VariantArray, ArrowError> {
                         |v| NaiveTime::from_num_seconds_from_midnight_opt(
                             (v / 1_000_000_000) as u32,
                             (v % 1_000_000_000) as u32
-                        )
-                        .unwrap(),
+                        ),
                         input,
-                        builder
-                    );
+                        builder,
+                        options.strict
+                    )?;
                 }
                 _ => {
                     return Err(ArrowError::CastError(format!(
@@ -284,6 +369,25 @@ pub fn cast_to_variant(input: &dyn Array) -> 
Result<VariantArray, ArrowError> {
         DataType::Utf8View => {
             non_generic_conversion_array!(input.as_string_view(), |v| v, 
builder);
         }
+        DataType::Date32 => {
+            generic_conversion_array!(
+                Date32Type,
+                as_primitive,
+                |v: i32| -> NaiveDate { Date32Type::to_naive_date(v) },
+                input,
+                builder
+            );
+        }
+        DataType::Date64 => {
+            generic_conversion_array!(
+                Date64Type,
+                as_primitive,
+                |v: i64| Date64Type::to_naive_date_opt(v),
+                input,
+                builder,
+                options.strict
+            )?;
+        }
         DataType::List(_) => convert_list::<i32>(input, &mut builder)?,
         DataType::LargeList(_) => convert_list::<i64>(input, &mut builder)?,
         DataType::Struct(_) => convert_struct(input, &mut builder)?,
@@ -310,79 +414,6 @@ pub fn cast_to_variant(input: &dyn Array) -> 
Result<VariantArray, ArrowError> {
     Ok(builder.build())
 }
 
-// TODO do we need a cast_with_options to allow specifying conversion behavior,
-// e.g. how to handle overflows, whether to convert to Variant::Null or return
-// an error, etc. ?
-
-/// Convert timestamp arrays to native datetimes
-fn convert_timestamp(
-    time_unit: &TimeUnit,
-    time_zone: &Option<Arc<str>>,
-    input: &dyn Array,
-    builder: &mut VariantArrayBuilder,
-) {
-    let native_datetimes: Vec<Option<NaiveDateTime>> = match time_unit {
-        arrow_schema::TimeUnit::Second => {
-            let ts_array = input
-                .as_any()
-                .downcast_ref::<TimestampSecondArray>()
-                .expect("Array is not TimestampSecondArray");
-
-            ts_array
-                .iter()
-                .map(|x| x.map(|y| timestamp_s_to_datetime(y).unwrap()))
-                .collect()
-        }
-        arrow_schema::TimeUnit::Millisecond => {
-            let ts_array = input
-                .as_any()
-                .downcast_ref::<TimestampMillisecondArray>()
-                .expect("Array is not TimestampMillisecondArray");
-
-            ts_array
-                .iter()
-                .map(|x| x.map(|y| timestamp_ms_to_datetime(y).unwrap()))
-                .collect()
-        }
-        arrow_schema::TimeUnit::Microsecond => {
-            let ts_array = input
-                .as_any()
-                .downcast_ref::<TimestampMicrosecondArray>()
-                .expect("Array is not TimestampMicrosecondArray");
-            ts_array
-                .iter()
-                .map(|x| x.map(|y| timestamp_us_to_datetime(y).unwrap()))
-                .collect()
-        }
-        arrow_schema::TimeUnit::Nanosecond => {
-            let ts_array = input
-                .as_any()
-                .downcast_ref::<TimestampNanosecondArray>()
-                .expect("Array is not TimestampNanosecondArray");
-            ts_array
-                .iter()
-                .map(|x| x.map(|y| timestamp_ns_to_datetime(y).unwrap()))
-                .collect()
-        }
-    };
-
-    for x in native_datetimes {
-        match x {
-            Some(ndt) => {
-                if time_zone.is_none() {
-                    builder.append_variant(ndt.into());
-                } else {
-                    let utc_dt: DateTime<Utc> = Utc.from_utc_datetime(&ndt);
-                    builder.append_variant(utc_dt.into());
-                }
-            }
-            None => {
-                builder.append_null();
-            }
-        }
-    }
-}
-
 /// Generic function to convert list arrays (both List and LargeList) to 
variant arrays
 fn convert_list<O: OffsetSizeTrait>(
     input: &dyn Array,
@@ -525,6 +556,15 @@ fn convert_map(
     Ok(())
 }
 
+/// Convert an array to a `VariantArray` with strict mode enabled (returns
+/// errors on conversion failures).
+///
+/// Equivalent to calling `cast_to_variant_with_options` with
+/// `CastOptions::default()`; it keeps the original signature for backward
+/// compatibility. For non-strict behavior (failed conversions become nulls),
+/// use `cast_to_variant_with_options` with `CastOptions { strict: false }`.
+///
+/// # Errors
+/// Returns an `ArrowError` when `input` cannot be converted; see
+/// `cast_to_variant_with_options`.
+pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
+    cast_to_variant_with_options(input, &CastOptions::default())
+}
+
+/// Convert union arrays
 fn convert_union(
     fields: &UnionFields,
     input: &dyn Array,
@@ -634,8 +674,8 @@ mod tests {
         IntervalDayTimeArray, IntervalMonthDayNanoArray, 
IntervalYearMonthArray, LargeListArray,
         LargeStringArray, ListArray, MapArray, NullArray, StringArray, 
StringRunBuilder,
         StringViewArray, StructArray, Time32MillisecondArray, 
Time32SecondArray,
-        Time64MicrosecondArray, Time64NanosecondArray, UInt16Array, 
UInt32Array, UInt64Array,
-        UInt8Array, UnionArray,
+        Time64MicrosecondArray, Time64NanosecondArray, TimestampSecondArray, 
UInt16Array,
+        UInt32Array, UInt64Array, UInt8Array, UnionArray,
     };
     use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
     use arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano};
@@ -2349,9 +2389,9 @@ mod tests {
     /// Converts the given `Array` to a `VariantArray` and tests the conversion
     /// against the expected values. It also tests the handling of nulls by
     /// setting one element to null and verifying the output.
-    fn run_test(values: ArrayRef, expected: Vec<Option<Variant>>) {
-        // test without nulls
-        let variant_array = cast_to_variant(&values).unwrap();
+    fn run_test_with_options(values: ArrayRef, expected: Vec<Option<Variant>>, 
strict: bool) {
+        let options = CastOptions { strict };
+        let variant_array = cast_to_variant_with_options(&values, 
&options).unwrap();
         assert_eq!(variant_array.len(), expected.len());
         for (i, expected_value) in expected.iter().enumerate() {
             match expected_value {
@@ -2365,4 +2405,64 @@ mod tests {
             }
         }
     }
+
+    /// Runs `run_test_with_options` in strict mode; the conversion result is
+    /// unwrapped, so any conversion error fails the test.
+    fn run_test(values: ArrayRef, expected: Vec<Option<Variant>>) {
+        run_test_with_options(values, expected, true);
+    }
+
+    /// Runs `run_test_with_options` in non-strict mode; values that fail to
+    /// convert should be listed as `None` in `expected`.
+    fn run_test_non_strict(values: ArrayRef, expected: Vec<Option<Variant>>) {
+        run_test_with_options(values, expected, false);
+    }
+
+    #[test]
+    fn test_cast_to_variant_non_strict_mode_date64() {
+        // 0 ms is the Unix epoch; i64::MAX / i64::MIN ms are expected to be out of
+        // range for `Date64Type::to_naive_date_opt`, so non-strict mode yields nulls.
+        let date64_values = Date64Array::from(vec![Some(i64::MAX), Some(0), 
Some(i64::MIN)]);
+
+        let values = Arc::new(date64_values);
+        run_test_non_strict(
+            values,
+            vec![
+                None,
+                Some(Variant::Date(Date64Type::to_naive_date_opt(0).unwrap())),
+                None,
+            ],
+        );
+    }
+
+    #[test]
+    fn test_cast_to_variant_non_strict_mode_time32() {
+        // 90000 s is past the end of a day and -1 becomes a huge value after
+        // `as u32`, so both fail `from_num_seconds_from_midnight_opt`; 3600 s is 01:00.
+        let time32_array = Time32SecondArray::from(vec![Some(90000), 
Some(3600), Some(-1)]);
+
+        let values = Arc::new(time32_array);
+        run_test_non_strict(
+            values,
+            vec![
+                None,
+                Some(Variant::Time(
+                    NaiveTime::from_num_seconds_from_midnight_opt(3600, 
0).unwrap(),
+                )),
+                None,
+            ],
+        );
+    }
+
+    #[test]
+    fn test_cast_to_variant_non_strict_mode_timestamp() {
+        // No time zone, so valid rows are expected as `TimestampNtzMicros`;
+        // i64::MAX seconds cannot be converted and becomes null.
+        let ts_array = TimestampSecondArray::from(vec![Some(i64::MAX), 
Some(0), Some(1609459200)])
+            .with_timezone_opt(None::<&str>);
+
+        let values = Arc::new(ts_array);
+        run_test_non_strict(
+            values,
+            vec![
+                None, // Invalid timestamp becomes null
+                Some(Variant::TimestampNtzMicros(
+                    timestamp_s_to_datetime(0).unwrap(),
+                )),
+                Some(Variant::TimestampNtzMicros(
+                    timestamp_s_to_datetime(1609459200).unwrap(),
+                )),
+            ],
+        );
+    }
 }
diff --git a/parquet-variant-compute/src/lib.rs 
b/parquet-variant-compute/src/lib.rs
index ef674d9614..3c928636ac 100644
--- a/parquet-variant-compute/src/lib.rs
+++ b/parquet-variant-compute/src/lib.rs
@@ -22,7 +22,7 @@
 //! - [`VariantArrayBuilder`]: For building [`VariantArray`]
 //! - [`json_to_variant`]: Function to convert a batch of JSON strings to a 
`VariantArray`.
 //! - [`variant_to_json`]: Function to convert a `VariantArray` to a batch of 
JSON strings.
-//! - [`cast_to_variant`]: Module to cast other Arrow arrays to `VariantArray`.
+//! - [`mod@cast_to_variant`]: Module to cast other Arrow arrays to 
`VariantArray`.
 //! - [`variant_get`]: Module to get values from a `VariantArray` using a 
specified [`VariantPath`]
 //!
 //! ## 🚧 Work In Progress
@@ -46,5 +46,6 @@ pub mod variant_get;
 pub use variant_array::{ShreddingState, VariantArray};
 pub use variant_array_builder::{VariantArrayBuilder, 
VariantArrayVariantBuilder};
 
+pub use cast_to_variant::{cast_to_variant, cast_to_variant_with_options, 
CastOptions};
 pub use from_json::json_to_variant;
 pub use to_json::variant_to_json;
diff --git a/parquet-variant-compute/src/type_conversion.rs 
b/parquet-variant-compute/src/type_conversion.rs
index 647d2c705f..aa60b425a1 100644
--- a/parquet-variant-compute/src/type_conversion.rs
+++ b/parquet-variant-compute/src/type_conversion.rs
@@ -20,6 +20,7 @@
 /// Convert the input array to a `VariantArray` row by row, using `method`
 /// not requiring a generic type to downcast the generic array to a specific
 /// array type and `cast_fn` to transform each element to a type compatible 
with Variant
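+/// In the four-argument form, `cast_fn` returns an `Option`: if `$strict` is true,
+/// a `None` makes the enclosing function return an `ArrowError::ComputeError`;
+/// otherwise a null is appended.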
 macro_rules! non_generic_conversion_array {
     ($array:expr, $cast_fn:expr, $builder:expr) => {{
         let array = $array;
@@ -32,6 +33,28 @@ macro_rules! non_generic_conversion_array {
             $builder.append_variant(Variant::from(cast_value));
         }
     }};
+    ($array:expr, $cast_fn:expr, $builder:expr, $strict:expr) => {{
+        let array = $array;
+        for i in 0..array.len() {
+            if array.is_null(i) {
+                $builder.append_null();
+                continue;
+            }
+            match $cast_fn(array.value(i)) {
+                Some(cast_value) => {
+                    $builder.append_variant(Variant::from(cast_value));
+                }
+                None if $strict => {
+                    return Err(ArrowError::ComputeError(format!(
+                        "Failed to convert value at index {}: conversion 
failed",
+                        i
+                    )));
+                }
+                None => $builder.append_null(),
+            }
+        }
+        Ok::<(), ArrowError>(())
+    }};
 }
 pub(crate) use non_generic_conversion_array;
 
@@ -52,6 +75,7 @@ pub(crate) use non_generic_conversion_single_value;
 /// Convert the input array to a `VariantArray` row by row, using `method`
 /// requiring a generic type to downcast the generic array to a specific
 /// array type and `cast_fn` to transform each element to a type compatible 
with Variant
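+/// In the six-argument form, `cast_fn` returns an `Option`: if `$strict` is true,
+/// a `None` makes the enclosing function return an `ArrowError::ComputeError`;
+/// otherwise a null is appended.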
 macro_rules! generic_conversion_array {
     ($t:ty, $method:ident, $cast_fn:expr, $input:expr, $builder:expr) => {{
         $crate::type_conversion::non_generic_conversion_array!(
@@ -60,6 +84,14 @@ macro_rules! generic_conversion_array {
             $builder
         )
     }};
+    ($t:ty, $method:ident, $cast_fn:expr, $input:expr, $builder:expr, 
$strict:expr) => {{
+        $crate::type_conversion::non_generic_conversion_array!(
+            $input.$method::<$t>(),
+            $cast_fn,
+            $builder,
+            $strict
+        )
+    }};
 }
 pub(crate) use generic_conversion_array;
 
@@ -123,3 +155,19 @@ macro_rules! decimal_to_variant_decimal {
     }};
 }
 pub(crate) use decimal_to_variant_decimal;
+
+/// Convert every element of a timestamp array with `$converter` (e.g.
+/// `timestamp_s_to_datetime`), collecting the `Option` results.
+///
+/// Null elements become `None`. If `$strict` is true, an element for which
+/// `$converter` returns `None` makes the enclosing function return an
+/// `ArrowError::ComputeError` (via `?`) whose message names `$unit_name`;
+/// otherwise such elements also become `None`.
+///
+/// The call site must fix the collection type (e.g. `Vec<Option<NaiveDateTime>>`),
+/// the enclosing function must return `Result<_, ArrowError>`, and `ArrowError`
+/// must be in scope because the expansion names it unqualified.
+macro_rules! timestamp_to_variant_timestamp {
+    ($ts_array:expr, $converter:expr, $unit_name:expr, $strict:expr) => {
+        if $strict {
+            let error =
+                || ArrowError::ComputeError(format!("Invalid timestamp {} 
value", $unit_name));
+            let converter = |x| $converter(x).ok_or_else(error);
+            // Option<Result<_>> -> Result<Option<_>>: nulls stay `None`, failed
+            // conversions become `Err`, and `collect` stops at the first `Err`.
+            let iter = $ts_array.iter().map(|x| x.map(converter).transpose());
+            iter.collect::<Result<Vec<_>, ArrowError>>()?
+        } else {
+            // Nulls and failed conversions both collapse to `None`.
+            $ts_array.iter().map(|x| x.and_then($converter)).collect()
+        }
+    };
+}
+pub(crate) use timestamp_to_variant_timestamp;

Reply via email to