This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new bc15cbdfc Fix timestamp handling in cast kernel (#1936) (#4033) (#4034)
bc15cbdfc is described below

commit bc15cbdfc1ada7b729eda5cdfb09fc7eda0c90ce
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Fri Apr 7 16:08:54 2023 +0100

    Fix timestamp handling in cast kernel (#1936) (#4033) (#4034)
---
 arrow-array/src/types.rs |  26 ++++++++-
 arrow-cast/src/cast.rs   | 139 ++++++++++++++++++++++++++++++-----------------
 2 files changed, 113 insertions(+), 52 deletions(-)

diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs
index 827729ca6..e2d7a2492 100644
--- a/arrow-array/src/types.rs
+++ b/arrow-array/src/types.rs
@@ -26,7 +26,7 @@ use arrow_schema::{
     DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
     DECIMAL_DEFAULT_SCALE,
 };
-use chrono::{Duration, NaiveDate};
+use chrono::{Duration, NaiveDate, NaiveDateTime};
 use half::f16;
 use std::marker::PhantomData;
 use std::ops::{Add, Sub};
@@ -311,19 +311,43 @@ pub trait ArrowTimestampType: ArrowTemporalType<Native = i64> {
     fn get_time_unit() -> TimeUnit {
         Self::UNIT
     }
+
+    /// Creates a ArrowTimestampType::Native from the provided [`NaiveDateTime`]
+    ///
+    /// See [`DataType::Timestamp`] for more information on timezone handling
+    fn make_value(naive: NaiveDateTime) -> Option<i64>;
 }
 
 impl ArrowTimestampType for TimestampSecondType {
     const UNIT: TimeUnit = TimeUnit::Second;
+
+    fn make_value(naive: NaiveDateTime) -> Option<i64> {
+        Some(naive.timestamp())
+    }
 }
 impl ArrowTimestampType for TimestampMillisecondType {
     const UNIT: TimeUnit = TimeUnit::Millisecond;
+
+    fn make_value(naive: NaiveDateTime) -> Option<i64> {
+        let millis = naive.timestamp().checked_mul(1_000)?;
+        millis.checked_add(naive.timestamp_subsec_millis() as i64)
+    }
 }
 impl ArrowTimestampType for TimestampMicrosecondType {
     const UNIT: TimeUnit = TimeUnit::Microsecond;
+
+    fn make_value(naive: NaiveDateTime) -> Option<i64> {
+        let micros = naive.timestamp().checked_mul(1_000_000)?;
+        micros.checked_add(naive.timestamp_subsec_micros() as i64)
+    }
 }
 impl ArrowTimestampType for TimestampNanosecondType {
     const UNIT: TimeUnit = TimeUnit::Nanosecond;
+
+    fn make_value(naive: NaiveDateTime) -> Option<i64> {
+        let nanos = naive.timestamp().checked_mul(1_000_000_000)?;
+        nanos.checked_add(naive.timestamp_subsec_nanos() as i64)
+    }
 }
 
 impl IntervalYearMonthType {
diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index 372fcc1a3..05b56a0e8 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -35,14 +35,14 @@
 //! assert_eq!(7.0, c.value(2));
 //! ```
 
-use chrono::{NaiveTime, Timelike};
+use chrono::{NaiveTime, TimeZone, Timelike, Utc};
 use std::cmp::Ordering;
 use std::sync::Arc;
 
 use crate::display::{array_value_to_string, ArrayFormatter, FormatOptions};
 use crate::parse::{
     parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month,
-    string_to_timestamp_nanos,
+    string_to_datetime,
 };
 use arrow_array::{
     builder::*, cast::*, temporal_conversions::*, timezone::Tz, types::*, *,
@@ -1233,16 +1233,16 @@ pub fn cast_with_options(
                 cast_string_to_time64nanosecond::<i64>(array, cast_options)
             }
             Timestamp(TimeUnit::Second, to_tz) => {
-                cast_string_to_timestamp::<i64, TimestampSecondType>(array, to_tz,cast_options)
+                cast_string_to_timestamp::<i64, TimestampSecondType>(array, to_tz, cast_options)
             }
             Timestamp(TimeUnit::Millisecond, to_tz) => {
-                cast_string_to_timestamp::<i64, TimestampMillisecondType>(array, to_tz,cast_options)
+                cast_string_to_timestamp::<i64, TimestampMillisecondType>(array, to_tz, cast_options)
             }
             Timestamp(TimeUnit::Microsecond, to_tz) => {
-                cast_string_to_timestamp::<i64, TimestampMicrosecondType>(array, to_tz,cast_options)
+                cast_string_to_timestamp::<i64, TimestampMicrosecondType>(array, to_tz, cast_options)
             }
             Timestamp(TimeUnit::Nanosecond, to_tz) => {
-                cast_string_to_timestamp::<i64, TimestampNanosecondType>(array, to_tz,cast_options)
+                cast_string_to_timestamp::<i64, TimestampNanosecondType>(array, to_tz, cast_options)
             }
             Interval(IntervalUnit::YearMonth) => {
                 cast_string_to_year_month_interval::<i64>(array, cast_options)
@@ -2653,45 +2653,58 @@ fn cast_string_to_time64nanosecond<Offset: OffsetSizeTrait>(
 }
 
 /// Casts generic string arrays to an ArrowTimestampType (TimeStampNanosecondArray, etc.)
-fn cast_string_to_timestamp<
-    Offset: OffsetSizeTrait,
-    TimestampType: ArrowTimestampType<Native = i64>,
->(
+fn cast_string_to_timestamp<O: OffsetSizeTrait, T: ArrowTimestampType>(
     array: &dyn Array,
     to_tz: &Option<Arc<str>>,
     cast_options: &CastOptions,
 ) -> Result<ArrayRef, ArrowError> {
-    let string_array = array
-        .as_any()
-        .downcast_ref::<GenericStringArray<Offset>>()
-        .unwrap();
-
-    let scale_factor = match TimestampType::UNIT {
-        TimeUnit::Second => 1_000_000_000,
-        TimeUnit::Millisecond => 1_000_000,
-        TimeUnit::Microsecond => 1_000,
-        TimeUnit::Nanosecond => 1,
+    let array = array.as_string::<O>();
+    let out: PrimitiveArray<T> = match to_tz {
+        Some(tz) => {
+            let tz: Tz = tz.as_ref().parse()?;
+            cast_string_to_timestamp_impl(array, &tz, cast_options)?
+        }
+        None => cast_string_to_timestamp_impl(array, &Utc, cast_options)?,
     };
+    Ok(Arc::new(out.with_timezone_opt(to_tz.clone())))
+}
 
-    let array = if cast_options.safe {
-        let iter = string_array.iter().map(|v| {
-            v.and_then(|v| string_to_timestamp_nanos(v).ok().map(|t| t / scale_factor))
+fn cast_string_to_timestamp_impl<
+    O: OffsetSizeTrait,
+    T: ArrowTimestampType,
+    Tz: TimeZone,
+>(
+    array: &GenericStringArray<O>,
+    tz: &Tz,
+    cast_options: &CastOptions,
+) -> Result<PrimitiveArray<T>, ArrowError> {
+    if cast_options.safe {
+        let iter = array.iter().map(|v| {
+            v.and_then(|v| {
+                let naive = string_to_datetime(tz, v).ok()?.naive_utc();
+                T::make_value(naive)
+            })
         });
         // Benefit:
         //     20% performance improvement
         // Soundness:
         //     The iterator is trustedLen because it comes from an `StringArray`.
 
-        unsafe {
-            PrimitiveArray::<TimestampType>::from_trusted_len_iter(iter)
-                .with_timezone_opt(to_tz.clone())
-        }
+        Ok(unsafe { PrimitiveArray::from_trusted_len_iter(iter) })
     } else {
-        let vec = string_array
+        let vec = array
             .iter()
             .map(|v| {
-                v.map(|v| string_to_timestamp_nanos(v).map(|t| t / scale_factor))
-                    .transpose()
+                v.map(|v| {
+                    let naive = string_to_datetime(tz, v)?.naive_utc();
+                    T::make_value(naive).ok_or_else(|| {
+                        ArrowError::CastError(format!(
+                            "Overflow converting {naive} to {:?}",
+                            T::UNIT
+                        ))
+                    })
+                })
+                .transpose()
             })
             .collect::<Result<Vec<Option<i64>>, _>>()?;
 
@@ -2699,13 +2712,8 @@ fn cast_string_to_timestamp<
         //     20% performance improvement
         // Soundness:
         //     The iterator is trustedLen because it comes from an `StringArray`.
-        unsafe {
-            PrimitiveArray::<TimestampType>::from_trusted_len_iter(vec.iter())
-                .with_timezone_opt(to_tz.clone())
-        }
-    };
-
-    Ok(Arc::new(array) as ArrayRef)
+        Ok(unsafe { PrimitiveArray::from_trusted_len_iter(vec.iter()) })
+    }
 }
 
 fn cast_string_to_year_month_interval<Offset: OffsetSizeTrait>(
@@ -5018,6 +5026,14 @@ mod tests {
         }
     }
 
+    #[test]
+    fn test_cast_string_to_timestamp_overflow() {
+        let array = StringArray::from(vec!["9800-09-08T12:00:00.123456789"]);
+        let result = cast(&array, &DataType::Timestamp(TimeUnit::Second, None)).unwrap();
+        let result = result.as_primitive::<TimestampSecondType>();
+        assert_eq!(result.values(), &[247112596800]);
+    }
+
     #[test]
     fn test_cast_string_to_date32() {
         let a1 = Arc::new(StringArray::from(vec![
@@ -8079,24 +8095,45 @@ mod tests {
             let array = Arc::new(valid) as ArrayRef;
             let b = cast_with_options(
                 &array,
-                &DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)),
+                &DataType::Timestamp(TimeUnit::Nanosecond, Some(tz.clone())),
                 &CastOptions { safe: false },
             )
             .unwrap();
 
-            let c = b
-                .as_any()
-                .downcast_ref::<TimestampNanosecondArray>()
-                .unwrap();
-            assert_eq!(1672574706789000000, c.value(0));
-            assert_eq!(1672571106789000000, c.value(1));
-            assert_eq!(1672574706789000000, c.value(2));
-            assert_eq!(1672574706789000000, c.value(3));
-            assert_eq!(1672518906000000000, c.value(4));
-            assert_eq!(1672518906000000000, c.value(5));
-            assert_eq!(1672545906789000000, c.value(6));
-            assert_eq!(1672545906000000000, c.value(7));
-            assert_eq!(1672531200000000000, c.value(8));
+            let tz = tz.as_ref().parse().unwrap();
+
+            let as_tz = |v: i64| {
+                as_datetime_with_timezone::<TimestampNanosecondType>(v, tz).unwrap()
+            };
+
+            let as_utc = |v: &i64| as_tz(*v).naive_utc().to_string();
+            let as_local = |v: &i64| as_tz(*v).naive_local().to_string();
+
+            let values = b.as_primitive::<TimestampNanosecondType>().values();
+            let utc_results: Vec<_> = values.iter().map(as_utc).collect();
+            let local_results: Vec<_> = values.iter().map(as_local).collect();
+
+            // Absolute timestamps should be parsed preserving the same UTC instant
+            assert_eq!(
+                &utc_results[..6],
+                &[
+                    "2023-01-01 12:05:06.789".to_string(),
+                    "2023-01-01 11:05:06.789".to_string(),
+                    "2023-01-01 12:05:06.789".to_string(),
+                    "2023-01-01 12:05:06.789".to_string(),
+                    "2022-12-31 20:35:06".to_string(),
+                    "2022-12-31 20:35:06".to_string(),
+                ]
+            );
+            // Non-absolute timestamps should be parsed preserving the same local instant
+            assert_eq!(
+                &local_results[6..],
+                &[
+                    "2023-01-01 04:05:06.789".to_string(),
+                    "2023-01-01 04:05:06".to_string(),
+                    "2023-01-01 00:00:00".to_string()
+                ]
+            )
         }
 
         test_tz("+00:00".into());

Reply via email to