This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 26ea71cef Verify ArrayData::data_type compatible in PrimitiveArray::from (#3440)
26ea71cef is described below
commit 26ea71cefab55990ffb8197707f2a8518e41412d
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Tue Jan 3 19:04:57 2023 +0000
Verify ArrayData::data_type compatible in PrimitiveArray::from (#3440)
---
arrow-arith/src/arity.rs | 8 +++-----
arrow-array/src/array/primitive_array.rs | 30 ++++++++++++++++++++++++++----
arrow-row/src/dictionary.rs | 5 +----
arrow-row/src/fixed.rs | 5 +----
4 files changed, 31 insertions(+), 17 deletions(-)
diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs
index e89fe7b91..3e7a81862 100644
--- a/arrow-arith/src/arity.rs
+++ b/arrow-arith/src/arity.rs
@@ -114,9 +114,7 @@ where
T: ArrowPrimitiveType,
F: Fn(T::Native) -> Result<T::Native, ArrowError>,
{
- if std::mem::discriminant(&array.value_type())
- != std::mem::discriminant(&T::DATA_TYPE)
- {
+ if !PrimitiveArray::<T>::is_compatible(&array.value_type()) {
return Err(ArrowError::CastError(format!(
"Cannot perform the unary operation of type {} on dictionary array
of value type {}",
T::DATA_TYPE,
@@ -138,7 +136,7 @@ where
downcast_dictionary_array! {
array => unary_dict::<_, F, T>(array, op),
t => {
- if std::mem::discriminant(t) ==
std::mem::discriminant(&T::DATA_TYPE) {
+ if PrimitiveArray::<T>::is_compatible(t) {
Ok(Arc::new(unary::<T, F, T>(
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
op,
@@ -170,7 +168,7 @@ where
)))
},
t => {
- if std::mem::discriminant(t) ==
std::mem::discriminant(&T::DATA_TYPE) {
+ if PrimitiveArray::<T>::is_compatible(t) {
Ok(Arc::new(try_unary::<T, F, T>(
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
op,
diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs
index 4ff0ed4d9..01eda724b 100644
--- a/arrow-array/src/array/primitive_array.rs
+++ b/arrow-array/src/array/primitive_array.rs
@@ -297,6 +297,21 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
PrimitiveBuilder::<T>::with_capacity(capacity)
}
+ /// Returns if this [`PrimitiveArray`] is compatible with the provided
[`DataType`]
+ ///
+ /// This is equivalent to `data_type == T::DATA_TYPE`, however ignores
timestamp
+ /// timezones and decimal precision and scale
+ pub fn is_compatible(data_type: &DataType) -> bool {
+ match T::DATA_TYPE {
+ DataType::Timestamp(t1, _) => {
+ matches!(data_type, DataType::Timestamp(t2, _) if &t1 == t2)
+ }
+ DataType::Decimal128(_, _) => matches!(data_type,
DataType::Decimal128(_, _)),
+ DataType::Decimal256(_, _) => matches!(data_type,
DataType::Decimal256(_, _)),
+ _ => T::DATA_TYPE.eq(data_type),
+ }
+ }
+
/// Returns the primitive value at index `i`.
///
/// # Safety
@@ -1042,10 +1057,8 @@ impl<T: ArrowTimestampType> PrimitiveArray<T> {
/// Constructs a `PrimitiveArray` from an array data reference.
impl<T: ArrowPrimitiveType> From<ArrayData> for PrimitiveArray<T> {
fn from(data: ArrayData) -> Self {
- // Use discriminant to allow for decimals
- assert_eq!(
- std::mem::discriminant(&T::DATA_TYPE),
- std::mem::discriminant(data.data_type()),
+ assert!(
+ Self::is_compatible(data.data_type()),
"PrimitiveArray expected ArrayData with type {} got {}",
T::DATA_TYPE,
data.data_type()
@@ -2205,4 +2218,13 @@ mod tests {
let c = array.unary_mut(|x| x * 2 + 1).unwrap();
assert_eq!(c, Int32Array::from(vec![Some(11), Some(15), None]));
}
+
+ #[test]
+ #[should_panic(
+ expected = "PrimitiveArray expected ArrayData with type
Interval(MonthDayNano) got Interval(DayTime)"
+ )]
+ fn test_invalid_interval_type() {
+ let array = IntervalDayTimeArray::from(vec![1, 2, 3]);
+ let _ = IntervalMonthDayNanoArray::from(array.into_data());
+ }
}
diff --git a/arrow-row/src/dictionary.rs b/arrow-row/src/dictionary.rs
index 0da6c68d1..e332e1131 100644
--- a/arrow-row/src/dictionary.rs
+++ b/arrow-row/src/dictionary.rs
@@ -270,10 +270,7 @@ fn decode_primitive<T: ArrowPrimitiveType>(
where
T::Native: FixedLengthEncoding,
{
- assert_eq!(
- std::mem::discriminant(&T::DATA_TYPE),
- std::mem::discriminant(&data_type),
- );
+ assert!(PrimitiveArray::<T>::is_compatible(&data_type));
// SAFETY:
// Validated data type above
diff --git a/arrow-row/src/fixed.rs b/arrow-row/src/fixed.rs
index 159eba9ad..d4b82c2a3 100644
--- a/arrow-row/src/fixed.rs
+++ b/arrow-row/src/fixed.rs
@@ -343,10 +343,7 @@ pub fn decode_primitive<T: ArrowPrimitiveType>(
where
T::Native: FixedLengthEncoding,
{
- assert_eq!(
- std::mem::discriminant(&T::DATA_TYPE),
- std::mem::discriminant(&data_type),
- );
+ assert!(PrimitiveArray::<T>::is_compatible(&data_type));
// SAFETY:
// Validated data type above
unsafe { decode_fixed::<T::Native>(rows, data_type, options).into() }