This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 26ea71cef Verify ArrayData::data_type compatible in PrimitiveArray::from (#3440)
26ea71cef is described below

commit 26ea71cefab55990ffb8197707f2a8518e41412d
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Tue Jan 3 19:04:57 2023 +0000

    Verify ArrayData::data_type compatible in PrimitiveArray::from (#3440)
---
 arrow-arith/src/arity.rs                 |  8 +++-----
 arrow-array/src/array/primitive_array.rs | 30 ++++++++++++++++++++++++++----
 arrow-row/src/dictionary.rs              |  5 +----
 arrow-row/src/fixed.rs                   |  5 +----
 4 files changed, 31 insertions(+), 17 deletions(-)

diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs
index e89fe7b91..3e7a81862 100644
--- a/arrow-arith/src/arity.rs
+++ b/arrow-arith/src/arity.rs
@@ -114,9 +114,7 @@ where
     T: ArrowPrimitiveType,
     F: Fn(T::Native) -> Result<T::Native, ArrowError>,
 {
-    if std::mem::discriminant(&array.value_type())
-        != std::mem::discriminant(&T::DATA_TYPE)
-    {
+    if !PrimitiveArray::<T>::is_compatible(&array.value_type()) {
         return Err(ArrowError::CastError(format!(
             "Cannot perform the unary operation of type {} on dictionary array 
of value type {}",
             T::DATA_TYPE,
@@ -138,7 +136,7 @@ where
     downcast_dictionary_array! {
         array => unary_dict::<_, F, T>(array, op),
         t => {
-            if std::mem::discriminant(t) == 
std::mem::discriminant(&T::DATA_TYPE) {
+            if PrimitiveArray::<T>::is_compatible(t) {
                 Ok(Arc::new(unary::<T, F, T>(
                     
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
                     op,
@@ -170,7 +168,7 @@ where
             )))
         },
         t => {
-            if std::mem::discriminant(t) == 
std::mem::discriminant(&T::DATA_TYPE) {
+            if PrimitiveArray::<T>::is_compatible(t) {
                 Ok(Arc::new(try_unary::<T, F, T>(
                     
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
                     op,
diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs
index 4ff0ed4d9..01eda724b 100644
--- a/arrow-array/src/array/primitive_array.rs
+++ b/arrow-array/src/array/primitive_array.rs
@@ -297,6 +297,21 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
         PrimitiveBuilder::<T>::with_capacity(capacity)
     }
 
+    /// Returns if this [`PrimitiveArray`] is compatible with the provided 
[`DataType`]
+    ///
+    /// This is equivalent to `data_type == T::DATA_TYPE`, however ignores 
timestamp
+    /// timezones and decimal precision and scale
+    pub fn is_compatible(data_type: &DataType) -> bool {
+        match T::DATA_TYPE {
+            DataType::Timestamp(t1, _) => {
+                matches!(data_type, DataType::Timestamp(t2, _) if &t1 == t2)
+            }
+            DataType::Decimal128(_, _) => matches!(data_type, 
DataType::Decimal128(_, _)),
+            DataType::Decimal256(_, _) => matches!(data_type, 
DataType::Decimal256(_, _)),
+            _ => T::DATA_TYPE.eq(data_type),
+        }
+    }
+
     /// Returns the primitive value at index `i`.
     ///
     /// # Safety
@@ -1042,10 +1057,8 @@ impl<T: ArrowTimestampType> PrimitiveArray<T> {
 /// Constructs a `PrimitiveArray` from an array data reference.
 impl<T: ArrowPrimitiveType> From<ArrayData> for PrimitiveArray<T> {
     fn from(data: ArrayData) -> Self {
-        // Use discriminant to allow for decimals
-        assert_eq!(
-            std::mem::discriminant(&T::DATA_TYPE),
-            std::mem::discriminant(data.data_type()),
+        assert!(
+            Self::is_compatible(data.data_type()),
             "PrimitiveArray expected ArrayData with type {} got {}",
             T::DATA_TYPE,
             data.data_type()
@@ -2205,4 +2218,13 @@ mod tests {
         let c = array.unary_mut(|x| x * 2 + 1).unwrap();
         assert_eq!(c, Int32Array::from(vec![Some(11), Some(15), None]));
     }
+
+    #[test]
+    #[should_panic(
+        expected = "PrimitiveArray expected ArrayData with type 
Interval(MonthDayNano) got Interval(DayTime)"
+    )]
+    fn test_invalid_interval_type() {
+        let array = IntervalDayTimeArray::from(vec![1, 2, 3]);
+        let _ = IntervalMonthDayNanoArray::from(array.into_data());
+    }
 }
diff --git a/arrow-row/src/dictionary.rs b/arrow-row/src/dictionary.rs
index 0da6c68d1..e332e1131 100644
--- a/arrow-row/src/dictionary.rs
+++ b/arrow-row/src/dictionary.rs
@@ -270,10 +270,7 @@ fn decode_primitive<T: ArrowPrimitiveType>(
 where
     T::Native: FixedLengthEncoding,
 {
-    assert_eq!(
-        std::mem::discriminant(&T::DATA_TYPE),
-        std::mem::discriminant(&data_type),
-    );
+    assert!(PrimitiveArray::<T>::is_compatible(&data_type));
 
     // SAFETY:
     // Validated data type above
diff --git a/arrow-row/src/fixed.rs b/arrow-row/src/fixed.rs
index 159eba9ad..d4b82c2a3 100644
--- a/arrow-row/src/fixed.rs
+++ b/arrow-row/src/fixed.rs
@@ -343,10 +343,7 @@ pub fn decode_primitive<T: ArrowPrimitiveType>(
 where
     T::Native: FixedLengthEncoding,
 {
-    assert_eq!(
-        std::mem::discriminant(&T::DATA_TYPE),
-        std::mem::discriminant(&data_type),
-    );
+    assert!(PrimitiveArray::<T>::is_compatible(&data_type));
     // SAFETY:
     // Validated data type above
     unsafe { decode_fixed::<T::Native>(rows, data_type, options).into() }

Reply via email to