viirya commented on code in PR #2740: URL: https://github.com/apache/arrow-rs/pull/2740#discussion_r973524711
########## arrow/src/compute/kernels/arithmetic.rs: ########## @@ -522,67 +548,86 @@ macro_rules! typed_dict_math_op { }}; } -/// Helper function to perform math lambda function on values from two dictionary arrays, this -/// version does not attempt to use SIMD explicitly (though the compiler may auto vectorize) -macro_rules! math_dict_op { - ($left: expr, $right:expr, $op:expr, $value_ty:ty) => {{ - if $left.len() != $right.len() { - return Err(ArrowError::ComputeError(format!( - "Cannot perform operation on arrays of different length ({}, {})", - $left.len(), - $right.len() - ))); - } +/// Perform given operation on two `DictionaryArray`s. +/// Returns an error if the two arrays have different value type +fn math_op_dict<K, T, F>( + left: &DictionaryArray<K>, + right: &DictionaryArray<K>, + op: F, +) -> Result<PrimitiveArray<T>> +where + K: ArrowNumericType, + T: ArrowNumericType, + F: Fn(T::Native, T::Native) -> T::Native, + T::Native: ArrowNativeTypeOp, +{ + if left.len() != right.len() { + return Err(ArrowError::ComputeError(format!( + "Cannot perform operation on arrays of different length ({}, {})", + left.len(), + right.len() + ))); + } - // Safety justification: Since the inputs are valid Arrow arrays, all values are - // valid indexes into the dictionary (which is verified during construction) - - let left_iter = unsafe { - $left - .values() - .as_any() - .downcast_ref::<$value_ty>() - .unwrap() - .take_iter_unchecked($left.keys_iter()) - }; - - let right_iter = unsafe { - $right - .values() - .as_any() - .downcast_ref::<$value_ty>() - .unwrap() - .take_iter_unchecked($right.keys_iter()) - }; - - let result = left_iter - .zip(right_iter) - .map(|(left_value, right_value)| { - if let (Some(left), Some(right)) = (left_value, right_value) { - Some($op(left, right)) - } else { - None - } - }) - .collect(); + // Safety justification: Since the inputs are valid Arrow arrays, all values are + // valid indexes into the dictionary (which is verified during construction) - Ok(result) - }}; + let left_iter = unsafe { + left.values() + .as_any() + .downcast_ref::<PrimitiveArray<T>>() + .unwrap() + .take_iter_unchecked(left.keys_iter()) + }; + + let right_iter = unsafe { + right + .values() + .as_any() + .downcast_ref::<PrimitiveArray<T>>() + .unwrap() + .take_iter_unchecked(right.keys_iter()) + }; + + let result = left_iter + .zip(right_iter) + .map(|(left_value, right_value)| { + if let (Some(left), Some(right)) = (left_value, right_value) { + Some(op(left, right)) + } else { + None + } + }) + .collect(); + + Ok(result) } /// Perform given operation on two `DictionaryArray`s. /// Returns an error if the two arrays have different value type -fn math_op_dict<K, T, F>( +fn math_checked_op_dict<K, T, F>( left: &DictionaryArray<K>, right: &DictionaryArray<K>, op: F, ) -> Result<PrimitiveArray<T>> where K: ArrowNumericType, T: ArrowNumericType, - F: Fn(T::Native, T::Native) -> T::Native, + F: Fn(T::Native, T::Native) -> Result<T::Native>, + T::Native: ArrowNativeTypeOp, { - math_dict_op!(left, right, op, PrimitiveArray<T>) + if left.len() != right.len() { Review Comment: Removed length check. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org