This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 2d28010ad Add try_unary, binary, try_binary kernels (#2666)
2d28010ad is described below

commit 2d28010ad2691bfdd7429f98848f4be32538bd6f
Author: Raphael Taylor-Davies <1781103+tustv...@users.noreply.github.com>
AuthorDate: Sun Sep 11 07:34:12 2022 +0100

    Add try_unary, binary, try_binary kernels (#2666)
---
 arrow/benches/arithmetic_kernels.rs     | 143 +++++++++----------------
 arrow/src/array/iterator.rs             |  47 ++++++++-
 arrow/src/compute/kernels/arithmetic.rs | 182 ++++++++------------------------
 arrow/src/compute/kernels/arity.rs      | 178 +++++++++++++++++++++++++++----
 arrow/src/util/bit_iterator.rs          |  42 ++++++++
 5 files changed, 336 insertions(+), 256 deletions(-)

diff --git a/arrow/benches/arithmetic_kernels.rs b/arrow/benches/arithmetic_kernels.rs
index 10af0b543..2aa2e7191 100644
--- a/arrow/benches/arithmetic_kernels.rs
+++ b/arrow/benches/arithmetic_kernels.rs
@@ -20,107 +20,62 @@ extern crate criterion;
 use criterion::Criterion;
 use rand::Rng;
 
-use std::sync::Arc;
-
 extern crate arrow;
 
+use arrow::datatypes::Float32Type;
 use arrow::util::bench_util::*;
-use arrow::{array::*, datatypes::Float32Type};
 use arrow::{compute::kernels::arithmetic::*, util::test_util::seedable_rng};
 
-fn create_array(size: usize, with_nulls: bool) -> ArrayRef {
-    let null_density = if with_nulls { 0.5 } else { 0.0 };
-    let array = create_primitive_array::<Float32Type>(size, null_density);
-    Arc::new(array)
-}
-
-fn bench_add(arr_a: &ArrayRef, arr_b: &ArrayRef) {
-    let arr_a = arr_a.as_any().downcast_ref::<Float32Array>().unwrap();
-    let arr_b = arr_b.as_any().downcast_ref::<Float32Array>().unwrap();
-    criterion::black_box(add(arr_a, arr_b).unwrap());
-}
-
-fn bench_subtract(arr_a: &ArrayRef, arr_b: &ArrayRef) {
-    let arr_a = arr_a.as_any().downcast_ref::<Float32Array>().unwrap();
-    let arr_b = arr_b.as_any().downcast_ref::<Float32Array>().unwrap();
-    criterion::black_box(subtract(arr_a, arr_b).unwrap());
-}
-
-fn bench_multiply(arr_a: &ArrayRef, arr_b: &ArrayRef) {
-    let arr_a = arr_a.as_any().downcast_ref::<Float32Array>().unwrap();
-    let arr_b = arr_b.as_any().downcast_ref::<Float32Array>().unwrap();
-    criterion::black_box(multiply(arr_a, arr_b).unwrap());
-}
-
-fn bench_divide(arr_a: &ArrayRef, arr_b: &ArrayRef) {
-    let arr_a = arr_a.as_any().downcast_ref::<Float32Array>().unwrap();
-    let arr_b = arr_b.as_any().downcast_ref::<Float32Array>().unwrap();
-    criterion::black_box(divide_checked(arr_a, arr_b).unwrap());
-}
-
-fn bench_divide_unchecked(arr_a: &ArrayRef, arr_b: &ArrayRef) {
-    let arr_a = arr_a.as_any().downcast_ref::<Float32Array>().unwrap();
-    let arr_b = arr_b.as_any().downcast_ref::<Float32Array>().unwrap();
-    criterion::black_box(divide(arr_a, arr_b).unwrap());
-}
-
-fn bench_divide_scalar(array: &ArrayRef, divisor: f32) {
-    let array = array.as_any().downcast_ref::<Float32Array>().unwrap();
-    criterion::black_box(divide_scalar(array, divisor).unwrap());
-}
-
-fn bench_modulo(arr_a: &ArrayRef, arr_b: &ArrayRef) {
-    let arr_a = arr_a.as_any().downcast_ref::<Float32Array>().unwrap();
-    let arr_b = arr_b.as_any().downcast_ref::<Float32Array>().unwrap();
-    criterion::black_box(modulus(arr_a, arr_b).unwrap());
-}
-
-fn bench_modulo_scalar(array: &ArrayRef, divisor: f32) {
-    let array = array.as_any().downcast_ref::<Float32Array>().unwrap();
-    criterion::black_box(modulus_scalar(array, divisor).unwrap());
-}
-
 fn add_benchmark(c: &mut Criterion) {
     const BATCH_SIZE: usize = 64 * 1024;
-    let arr_a = create_array(BATCH_SIZE, false);
-    let arr_b = create_array(BATCH_SIZE, false);
-    let scalar = seedable_rng().gen();
-
-    c.bench_function("add", |b| b.iter(|| bench_add(&arr_a, &arr_b)));
-    c.bench_function("subtract", |b| b.iter(|| bench_subtract(&arr_a, 
&arr_b)));
-    c.bench_function("multiply", |b| b.iter(|| bench_multiply(&arr_a, 
&arr_b)));
-    c.bench_function("divide", |b| b.iter(|| bench_divide(&arr_a, &arr_b)));
-    c.bench_function("divide_unchecked", |b| {
-        b.iter(|| bench_divide_unchecked(&arr_a, &arr_b))
-    });
-    c.bench_function("divide_scalar", |b| {
-        b.iter(|| bench_divide_scalar(&arr_a, scalar))
-    });
-    c.bench_function("modulo", |b| b.iter(|| bench_modulo(&arr_a, &arr_b)));
-    c.bench_function("modulo_scalar", |b| {
-        b.iter(|| bench_modulo_scalar(&arr_a, scalar))
-    });
-
-    let arr_a_nulls = create_array(BATCH_SIZE, true);
-    let arr_b_nulls = create_array(BATCH_SIZE, true);
-    c.bench_function("add_nulls", |b| {
-        b.iter(|| bench_add(&arr_a_nulls, &arr_b_nulls))
-    });
-    c.bench_function("divide_nulls", |b| {
-        b.iter(|| bench_divide(&arr_a_nulls, &arr_b_nulls))
-    });
-    c.bench_function("divide_nulls_unchecked", |b| {
-        b.iter(|| bench_divide_unchecked(&arr_a_nulls, &arr_b_nulls))
-    });
-    c.bench_function("divide_scalar_nulls", |b| {
-        b.iter(|| bench_divide_scalar(&arr_a_nulls, scalar))
-    });
-    c.bench_function("modulo_nulls", |b| {
-        b.iter(|| bench_modulo(&arr_a_nulls, &arr_b_nulls))
-    });
-    c.bench_function("modulo_scalar_nulls", |b| {
-        b.iter(|| bench_modulo_scalar(&arr_a_nulls, scalar))
-    });
+    for null_density in [0., 0.1, 0.5, 0.9, 1.0] {
+        let arr_a = create_primitive_array::<Float32Type>(BATCH_SIZE, 
null_density);
+        let arr_b = create_primitive_array::<Float32Type>(BATCH_SIZE, 
null_density);
+        let scalar = seedable_rng().gen();
+
+        c.bench_function(&format!("add({})", null_density), |b| {
+            b.iter(|| criterion::black_box(add(&arr_a, &arr_b).unwrap()))
+        });
+        c.bench_function(&format!("add_checked({})", null_density), |b| {
+            b.iter(|| criterion::black_box(add_checked(&arr_a, 
&arr_b).unwrap()))
+        });
+        c.bench_function(&format!("add_scalar({})", null_density), |b| {
+            b.iter(|| criterion::black_box(add_scalar(&arr_a, 
scalar).unwrap()))
+        });
+        c.bench_function(&format!("subtract({})", null_density), |b| {
+            b.iter(|| criterion::black_box(subtract(&arr_a, &arr_b).unwrap()))
+        });
+        c.bench_function(&format!("subtract_checked({})", null_density), |b| {
+            b.iter(|| criterion::black_box(subtract_checked(&arr_a, 
&arr_b).unwrap()))
+        });
+        c.bench_function(&format!("subtract_scalar({})", null_density), |b| {
+            b.iter(|| criterion::black_box(subtract_scalar(&arr_a, 
scalar).unwrap()))
+        });
+        c.bench_function(&format!("multiply({})", null_density), |b| {
+            b.iter(|| criterion::black_box(multiply(&arr_a, &arr_b).unwrap()))
+        });
+        c.bench_function(&format!("multiply_checked({})", null_density), |b| {
+            b.iter(|| criterion::black_box(multiply_checked(&arr_a, 
&arr_b).unwrap()))
+        });
+        c.bench_function(&format!("multiply_scalar({})", null_density), |b| {
+            b.iter(|| criterion::black_box(multiply_scalar(&arr_a, 
scalar).unwrap()))
+        });
+        c.bench_function(&format!("divide({})", null_density), |b| {
+            b.iter(|| criterion::black_box(divide(&arr_a, &arr_b).unwrap()))
+        });
+        c.bench_function(&format!("divide_checked({})", null_density), |b| {
+            b.iter(|| criterion::black_box(divide_checked(&arr_a, 
&arr_b).unwrap()))
+        });
+        c.bench_function(&format!("divide_scalar({})", null_density), |b| {
+            b.iter(|| criterion::black_box(divide_scalar(&arr_a, 
scalar).unwrap()))
+        });
+        c.bench_function(&format!("modulo({})", null_density), |b| {
+            b.iter(|| criterion::black_box(modulus(&arr_a, &arr_b).unwrap()))
+        });
+        c.bench_function(&format!("modulo_scalar({})", null_density), |b| {
+            b.iter(|| criterion::black_box(modulus_scalar(&arr_a, 
scalar).unwrap()))
+        });
+    }
 }
 
 criterion_group!(benches, add_benchmark);
diff --git a/arrow/src/array/iterator.rs b/arrow/src/array/iterator.rs
index 4269e9962..e64712fa8 100644
--- a/arrow/src/array/iterator.rs
+++ b/arrow/src/array/iterator.rs
@@ -24,8 +24,51 @@ use super::{
     PrimitiveArray,
 };
 
-/// an iterator that returns Some(T) or None, that can be used on any [`ArrayAccessor`]
-// Note: This implementation is based on std's [Vec]s' [IntoIter].
+/// An iterator that returns `Some(T)` or `None`, and that can be used on any [`ArrayAccessor`]
+///
+/// # Performance
+///
+/// [`ArrayIter`] provides an idiomatic way to iterate over an array; however, this
+/// comes at the cost of performance. In particular, the interleaved handling of
+/// the null mask is often sub-optimal.
+///
+/// If performing an infallible operation, it is typically faster to perform the operation
+/// on every index of the array and handle the null mask separately. For [`PrimitiveArray`]
+/// this functionality is provided by [`compute::unary`]:
+///
+/// ```
+/// # use arrow::array::PrimitiveArray;
+/// # use arrow::compute::unary;
+/// # use arrow::datatypes::Int32Type;
+///
+/// fn add(a: &PrimitiveArray<Int32Type>, b: i32) -> PrimitiveArray<Int32Type> 
{
+///     unary(a, |a| a + b)
+/// }
+/// ```
+///
+/// If performing a fallible operation, it isn't possible to perform the operation independently
+/// of the null mask, as this might result in a spurious failure on a null index. However,
+/// there are more efficient ways to iterate over just the non-null indices; this functionality
+/// is provided by [`compute::try_unary`]:
+///
+/// ```
+/// # use arrow::array::PrimitiveArray;
+/// # use arrow::compute::try_unary;
+/// # use arrow::datatypes::Int32Type;
+/// # use arrow::error::{ArrowError, Result};
+///
+/// fn checked_add(a: &PrimitiveArray<Int32Type>, b: i32) -> 
Result<PrimitiveArray<Int32Type>> {
+///     try_unary(a, |a| {
+///         a.checked_add(b).ok_or_else(|| {
+///             ArrowError::CastError(format!("overflow adding {} to {}", a, 
b))
+///         })
+///     })
+/// }
+/// ```
+///
+/// [`PrimitiveArray`]: crate::array::PrimitiveArray
+/// [`compute::unary`]: crate::compute::unary
+/// [`compute::try_unary`]: crate::compute::try_unary
 #[derive(Debug)]
 pub struct ArrayIter<T: ArrayAccessor> {
     array: T,
diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs
index 9bf4b00c3..17850f2a8 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -31,12 +31,11 @@ use crate::buffer::Buffer;
 #[cfg(feature = "simd")]
 use crate::buffer::MutableBuffer;
 use crate::compute::kernels::arity::unary;
-use crate::compute::unary_dyn;
 use crate::compute::util::combine_option_bitmap;
+use crate::compute::{binary, try_binary, unary_dyn};
 use crate::datatypes::{
-    native_op::ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, 
DataType,
-    Date32Type, Date64Type, IntervalDayTimeType, IntervalMonthDayNanoType, 
IntervalUnit,
-    IntervalYearMonthType,
+    native_op::ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, 
Date64Type,
+    IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, 
IntervalYearMonthType,
 };
 use crate::datatypes::{
     Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, 
UInt16Type,
@@ -74,33 +73,7 @@ where
         ));
     }
 
-    let null_bit_buffer =
-        combine_option_bitmap(&[left.data_ref(), right.data_ref()], 
left.len())?;
-
-    let values = left
-        .values()
-        .iter()
-        .zip(right.values().iter())
-        .map(|(l, r)| op(*l, *r));
-    // JUSTIFICATION
-    //  Benefit
-    //      ~60% speedup
-    //  Soundness
-    //      `values` is an iterator with a known size from a PrimitiveArray
-    let buffer = unsafe { Buffer::from_trusted_len_iter(values) };
-
-    let data = unsafe {
-        ArrayData::new_unchecked(
-            LT::DATA_TYPE,
-            left.len(),
-            None,
-            null_bit_buffer,
-            0,
-            vec![buffer],
-            vec![],
-        )
-    };
-    Ok(PrimitiveArray::<LT>::from(data))
+    Ok(binary(left, right, op))
 }
 
 /// This is similar to `math_op` as it performs given operation between two 
input primitive arrays.
@@ -122,85 +95,11 @@ where
         ));
     }
 
-    let left_iter = ArrayIter::new(left);
-    let right_iter = ArrayIter::new(right);
-
-    let values: Result<Vec<Option<<LT as ArrowPrimitiveType>::Native>>> = 
left_iter
-        .into_iter()
-        .zip(right_iter.into_iter())
-        .map(|(l, r)| {
-            if let (Some(l), Some(r)) = (l, r) {
-                let result = op(l, r);
-                if let Some(r) = result {
-                    Ok(Some(r))
-                } else {
-                    // Overflow
-                    Err(ArrowError::ComputeError(format!(
-                        "Overflow happened on: {:?}, {:?}",
-                        l, r
-                    )))
-                }
-            } else {
-                Ok(None)
-            }
-        })
-        .collect();
-
-    let values = values?;
-
-    Ok(PrimitiveArray::<LT>::from_iter(values))
-}
-
-/// This is similar to `math_checked_op` but just for divide op.
-fn math_checked_divide<LT, RT, F>(
-    left: &PrimitiveArray<LT>,
-    right: &PrimitiveArray<RT>,
-    op: F,
-) -> Result<PrimitiveArray<LT>>
-where
-    LT: ArrowNumericType,
-    RT: ArrowNumericType,
-    RT::Native: One + Zero,
-    F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
-{
-    if left.len() != right.len() {
-        return Err(ArrowError::ComputeError(
-            "Cannot perform math operation on arrays of different 
length".to_string(),
-        ));
-    }
-
-    let left_iter = ArrayIter::new(left);
-    let right_iter = ArrayIter::new(right);
-
-    let values: Result<Vec<Option<<LT as ArrowPrimitiveType>::Native>>> = 
left_iter
-        .into_iter()
-        .zip(right_iter.into_iter())
-        .map(|(l, r)| {
-            if let (Some(l), Some(r)) = (l, r) {
-                let result = op(l, r);
-                if let Some(r) = result {
-                    Ok(Some(r))
-                } else if r.is_zero() {
-                    Err(ArrowError::ComputeError(format!(
-                        "DivideByZero on: {:?}, {:?}",
-                        l, r
-                    )))
-                } else {
-                    // Overflow
-                    Err(ArrowError::ComputeError(format!(
-                        "Overflow happened on: {:?}, {:?}",
-                        l, r
-                    )))
-                }
-            } else {
-                Ok(None)
-            }
+    try_binary(left, right, |a, b| {
+        op(a, b).ok_or_else(|| {
+            ArrowError::ComputeError(format!("Overflow happened on: {:?}, 
{:?}", a, b))
         })
-        .collect();
-
-    let values = values?;
-
-    Ok(PrimitiveArray::<LT>::from_iter(values))
+    })
 }
 
 /// Helper function for operations where a valid `0` on the right array should
@@ -211,15 +110,16 @@ where
 /// This function errors if:
 /// * the arrays have different lengths
 /// * there is an element where both left and right values are valid and the 
right value is `0`
-fn math_checked_divide_op<T, F>(
-    left: &PrimitiveArray<T>,
-    right: &PrimitiveArray<T>,
+fn math_checked_divide_op<LT, RT, F>(
+    left: &PrimitiveArray<LT>,
+    right: &PrimitiveArray<RT>,
     op: F,
-) -> Result<PrimitiveArray<T>>
+) -> Result<PrimitiveArray<LT>>
 where
-    T: ArrowNumericType,
-    T::Native: One + Zero,
-    F: Fn(T::Native, T::Native) -> T::Native,
+    LT: ArrowNumericType,
+    RT: ArrowNumericType,
+    RT::Native: One + Zero,
+    F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
 {
     if left.len() != right.len() {
         return Err(ArrowError::ComputeError(
@@ -227,16 +127,18 @@ where
         ));
     }
 
-    let null_bit_buffer =
-        combine_option_bitmap(&[left.data_ref(), right.data_ref()], 
left.len())?;
-
-    math_checked_divide_op_on_iters(
-        left.into_iter(),
-        right.into_iter(),
-        op,
-        left.len(),
-        null_bit_buffer,
-    )
+    try_binary(left, right, |l, r| {
+        if r.is_zero() {
+            Err(ArrowError::DivideByZero)
+        } else {
+            op(l, r).ok_or_else(|| {
+                ArrowError::ComputeError(format!(
+                    "Overflow happened on: {:?}, {:?}",
+                    l, r
+                ))
+            })
+        }
+    })
 }
 
 /// Helper function for operations where a valid `0` on the right array should
@@ -900,7 +802,7 @@ pub fn add_scalar<T>(
     scalar: T::Native,
 ) -> Result<PrimitiveArray<T>>
 where
-    T: datatypes::ArrowNumericType,
+    T: ArrowNumericType,
     T::Native: Add<Output = T::Native>,
 {
     Ok(unary(array, |value| value + scalar))
@@ -911,7 +813,7 @@ where
 /// the scalar, or a `DictionaryArray` of the value type same as the scalar.
 pub fn add_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> 
Result<ArrayRef>
 where
-    T: datatypes::ArrowNumericType,
+    T: ArrowNumericType,
     T::Native: Add<Output = T::Native>,
 {
     unary_dyn::<_, T>(array, |value| value + scalar)
@@ -927,7 +829,7 @@ pub fn subtract<T>(
     right: &PrimitiveArray<T>,
 ) -> Result<PrimitiveArray<T>>
 where
-    T: datatypes::ArrowNumericType,
+    T: ArrowNumericType,
     T::Native: ArrowNativeTypeOp,
 {
     math_op(left, right, |a, b| a.sub_wrapping(b))
@@ -943,7 +845,7 @@ pub fn subtract_checked<T>(
     right: &PrimitiveArray<T>,
 ) -> Result<PrimitiveArray<T>>
 where
-    T: datatypes::ArrowNumericType,
+    T: ArrowNumericType,
     T::Native: ArrowNativeTypeOp,
 {
     math_checked_op(left, right, |a, b| a.sub_checked(b))
@@ -1033,7 +935,7 @@ pub fn multiply<T>(
     right: &PrimitiveArray<T>,
 ) -> Result<PrimitiveArray<T>>
 where
-    T: datatypes::ArrowNumericType,
+    T: ArrowNumericType,
     T::Native: ArrowNativeTypeOp,
 {
     math_op(left, right, |a, b| a.mul_wrapping(b))
@@ -1049,7 +951,7 @@ pub fn multiply_checked<T>(
     right: &PrimitiveArray<T>,
 ) -> Result<PrimitiveArray<T>>
 where
-    T: datatypes::ArrowNumericType,
+    T: ArrowNumericType,
     T::Native: ArrowNativeTypeOp,
 {
     math_checked_op(left, right, |a, b| a.mul_checked(b))
@@ -1100,7 +1002,7 @@ where
 /// the scalar, or a `DictionaryArray` of the value type same as the scalar.
 pub fn multiply_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> 
Result<ArrayRef>
 where
-    T: datatypes::ArrowNumericType,
+    T: ArrowNumericType,
     T::Native: Add<Output = T::Native>
         + Sub<Output = T::Native>
         + Mul<Output = T::Native>
@@ -1120,7 +1022,7 @@ pub fn modulus<T>(
     right: &PrimitiveArray<T>,
 ) -> Result<PrimitiveArray<T>>
 where
-    T: datatypes::ArrowNumericType,
+    T: ArrowNumericType,
     T::Native: Rem<Output = T::Native> + Zero + One,
 {
     #[cfg(feature = "simd")]
@@ -1128,7 +1030,7 @@ where
         a % b
     });
     #[cfg(not(feature = "simd"))]
-    return math_checked_divide_op(left, right, |a, b| a % b);
+    return math_checked_divide_op(left, right, |a, b| Some(a % b));
 }
 
 /// Perform `left / right` operation on two arrays. If either left or right 
value is null
@@ -1148,7 +1050,7 @@ where
     #[cfg(feature = "simd")]
     return simd_checked_divide_op(&left, &right, simd_checked_divide::<T>, |a, 
b| a / b);
     #[cfg(not(feature = "simd"))]
-    return math_checked_divide(left, right, |a, b| a.div_checked(b));
+    return math_checked_divide_op(left, right, |a, b| a.div_checked(b));
 }
 
 /// Perform `left / right` operation on two arrays. If either left or right 
value is null
@@ -1162,7 +1064,7 @@ pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> 
Result<ArrayRef> {
         _ => {
             downcast_primitive_array!(
                 (left, right) => {
-                    math_checked_divide_op(left, right, |a, b| a / b).map(|a| 
Arc::new(a) as ArrayRef)
+                    math_checked_divide_op(left, right, |a, b| Some(a / 
b)).map(|a| Arc::new(a) as ArrayRef)
                 }
                 _ => Err(ArrowError::CastError(format!(
                     "Unsupported data type {}, {}",
@@ -1199,7 +1101,7 @@ pub fn modulus_scalar<T>(
     modulo: T::Native,
 ) -> Result<PrimitiveArray<T>>
 where
-    T: datatypes::ArrowNumericType,
+    T: ArrowNumericType,
     T::Native: Rem<Output = T::Native> + Zero,
 {
     if modulo.is_zero() {
@@ -1217,7 +1119,7 @@ pub fn divide_scalar<T>(
     divisor: T::Native,
 ) -> Result<PrimitiveArray<T>>
 where
-    T: datatypes::ArrowNumericType,
+    T: ArrowNumericType,
     T::Native: Div<Output = T::Native> + Zero,
 {
     if divisor.is_zero() {
@@ -1232,7 +1134,7 @@ where
 /// same as the scalar, or a `DictionaryArray` of the value type same as the 
scalar.
 pub fn divide_scalar_dyn<T>(array: &dyn Array, divisor: T::Native) -> 
Result<ArrayRef>
 where
-    T: datatypes::ArrowNumericType,
+    T: ArrowNumericType,
     T::Native: Div<Output = T::Native> + Zero,
 {
     if divisor.is_zero() {
diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs
index 1251baf52..ee3ff5e23 100644
--- a/arrow/src/compute/kernels/arity.rs
+++ b/arrow/src/compute/kernels/arity.rs
@@ -17,37 +17,41 @@
 
 //! Defines kernels suitable to perform operations to primitive arrays.
 
-use crate::array::{Array, ArrayData, ArrayRef, DictionaryArray, 
PrimitiveArray};
+use crate::array::{
+    Array, ArrayData, ArrayRef, BufferBuilder, DictionaryArray, PrimitiveArray,
+};
 use crate::buffer::Buffer;
+use crate::compute::util::combine_option_bitmap;
 use crate::datatypes::{ArrowNumericType, ArrowPrimitiveType};
 use crate::downcast_dictionary_array;
 use crate::error::{ArrowError, Result};
+use crate::util::bit_iterator::try_for_each_valid_idx;
 use std::sync::Arc;
 
 #[inline]
-fn into_primitive_array_data<I: ArrowPrimitiveType, O: ArrowPrimitiveType>(
-    array: &PrimitiveArray<I>,
+unsafe fn build_primitive_array<O: ArrowPrimitiveType>(
+    len: usize,
     buffer: Buffer,
-) -> ArrayData {
-    let data = array.data();
-    unsafe {
-        ArrayData::new_unchecked(
-            O::DATA_TYPE,
-            array.len(),
-            Some(data.null_count()),
-            data.null_buffer()
-                .map(|b| b.bit_slice(array.offset(), array.len())),
-            0,
-            vec![buffer],
-            vec![],
-        )
-    }
+    null_count: usize,
+    null_buffer: Option<Buffer>,
+) -> PrimitiveArray<O> {
+    PrimitiveArray::from(ArrayData::new_unchecked(
+        O::DATA_TYPE,
+        len,
+        Some(null_count),
+        null_buffer,
+        0,
+        vec![buffer],
+        vec![],
+    ))
 }
 
 /// Applies an unary and infallible function to a primitive array.
 /// This is the fastest way to perform an operation on a primitive array when
-/// the benefits of a vectorized operation outweights the cost of branching nulls and non-nulls.
+/// the benefits of a vectorized operation outweigh the cost of branching nulls and non-nulls.
+///
 /// # Implementation
+///
 /// This will apply the function for all values, including those on null slots.
 /// This implies that the operation must be infallible for any value of the corresponding type
 /// or this function may panic.
@@ -68,6 +72,14 @@ where
     O: ArrowPrimitiveType,
     F: Fn(I::Native) -> O::Native,
 {
+    let data = array.data();
+    let len = data.len();
+    let null_count = data.null_count();
+
+    let null_buffer = data
+        .null_buffer()
+        .map(|b| b.bit_slice(data.offset(), data.len()));
+
     let values = array.values().iter().map(|v| op(*v));
     // JUSTIFICATION
     //  Benefit
@@ -75,9 +87,40 @@ where
     //  Soundness
     //      `values` is an iterator with a known size because arrays are sized.
     let buffer = unsafe { Buffer::from_trusted_len_iter(values) };
+    unsafe { build_primitive_array(len, buffer, null_count, null_buffer) }
+}
+
+/// Applies a unary and fallible function to all valid values in a primitive array
+///
+/// Null slots are skipped, so `op` is never evaluated on them. Evaluation stops at,
+/// and returns, the first error produced by `op`.
+///
+/// This is unlike [`unary`], which applies an infallible function to all rows regardless
+/// of validity; [`unary`] is in many cases significantly faster and should be preferred
+/// if `op` is infallible.
+///
+/// Note: LLVM is currently unable to effectively vectorize fallible operations
+pub fn try_unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> 
Result<PrimitiveArray<O>>
+where
+    I: ArrowPrimitiveType,
+    O: ArrowPrimitiveType,
+    F: Fn(I::Native) -> Result<O::Native>,
+{
+    let len = array.len();
+    let null_count = array.null_count();
+
+    let mut buffer = BufferBuilder::<O::Native>::new(len);
+    buffer.append_n_zeroed(array.len());
+    let slice = buffer.as_slice_mut();
+
+    let null_buffer = array
+        .data_ref()
+        .null_buffer()
+        .map(|b| b.bit_slice(array.offset(), array.len()));
 
-    let data = into_primitive_array_data::<_, O>(array, buffer);
-    PrimitiveArray::<O>::from(data)
+    try_for_each_valid_idx(array.len(), 0, null_count, null_buffer.as_deref(), 
|idx| {
+        unsafe { *slice.get_unchecked_mut(idx) = 
op(array.value_unchecked(idx))? };
+        Ok::<_, ArrowError>(())
+    })?;
+
+    Ok(unsafe { build_primitive_array(len, buffer.finish(), null_count, 
null_buffer) })
 }
 
 /// A helper function that applies an unary function to a dictionary array 
with primitive value type.
@@ -119,6 +162,101 @@ where
     }
 }
 
+/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, collecting
+/// the results in a [`PrimitiveArray`]. If any index is null in either `a` or `b`, the
+/// corresponding index in the result will also be null
+///
+/// Like [`unary`], the provided function is evaluated for every index, ignoring validity. This
+/// is beneficial when the cost of the operation is low compared to the cost of branching, and
+/// especially when the operation can be vectorised; however, it requires `op` to be infallible
+/// for all possible values of its inputs
+///
+/// # Panic
+///
+/// Panics if the arrays have different lengths
+pub fn binary<A, B, F, O>(
+    a: &PrimitiveArray<A>,
+    b: &PrimitiveArray<B>,
+    op: F,
+) -> PrimitiveArray<O>
+where
+    A: ArrowPrimitiveType,
+    B: ArrowPrimitiveType,
+    O: ArrowPrimitiveType,
+    F: Fn(A::Native, B::Native) -> O::Native,
+{
+    assert_eq!(a.len(), b.len());
+    let len = a.len();
+
+    if a.is_empty() {
+        return PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE));
+    }
+
+    let null_buffer = combine_option_bitmap(&[a.data(), b.data()], 
len).unwrap();
+    let null_count = null_buffer
+        .as_ref()
+        .map(|x| len - x.count_set_bits())
+        .unwrap_or_default();
+
+    let values = a.values().iter().zip(b.values()).map(|(l, r)| op(*l, *r));
+    // JUSTIFICATION
+    //  Benefit
+    //      ~60% speedup
+    //  Soundness
+    //      `values` is an iterator with a known size from a PrimitiveArray
+    let buffer = unsafe { Buffer::from_trusted_len_iter(values) };
+
+    unsafe { build_primitive_array(len, buffer, null_count, null_buffer) }
+}
+
+/// Applies the provided fallible binary operation across `a` and `b`, returning the first
+/// error produced by `op`, and otherwise collecting the results into a [`PrimitiveArray`].
+/// If any index is null in either `a` or `b`, the corresponding index in the result will
+/// also be null
+///
+/// Like [`try_unary`], the function is only evaluated for non-null indices
+///
+/// # Panic
+///
+/// Panics if the arrays have different lengths
+pub fn try_binary<A, B, F, O>(
+    a: &PrimitiveArray<A>,
+    b: &PrimitiveArray<B>,
+    op: F,
+) -> Result<PrimitiveArray<O>>
+where
+    A: ArrowPrimitiveType,
+    B: ArrowPrimitiveType,
+    O: ArrowPrimitiveType,
+    F: Fn(A::Native, B::Native) -> Result<O::Native>,
+{
+    assert_eq!(a.len(), b.len());
+    let len = a.len();
+
+    if a.is_empty() {
+        return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE)));
+    }
+
+    let null_buffer = combine_option_bitmap(&[a.data(), b.data()], 
len).unwrap();
+    let null_count = null_buffer
+        .as_ref()
+        .map(|x| len - x.count_set_bits())
+        .unwrap_or_default();
+
+    let mut buffer = BufferBuilder::<O::Native>::new(len);
+    buffer.append_n_zeroed(len);
+    let slice = buffer.as_slice_mut();
+
+    try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| {
+        unsafe {
+            *slice.get_unchecked_mut(idx) =
+                op(a.value_unchecked(idx), b.value_unchecked(idx))?
+        };
+        Ok::<_, ArrowError>(())
+    })?;
+
+    Ok(unsafe { build_primitive_array(len, buffer.finish(), null_count, 
null_buffer) })
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/arrow/src/util/bit_iterator.rs b/arrow/src/util/bit_iterator.rs
index bba9dac60..ceefaa860 100644
--- a/arrow/src/util/bit_iterator.rs
+++ b/arrow/src/util/bit_iterator.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use crate::util::bit_chunk_iterator::{UnalignedBitChunk, 
UnalignedBitChunkIterator};
+use std::result::Result;
 
 /// Iterator of contiguous ranges of set bits within a provided packed bitmask
 ///
@@ -157,4 +158,45 @@ impl<'a> Iterator for BitIndexIterator<'a> {
     }
 }
 
+/// Calls the provided closure for each index in the provided null mask that is set,
+/// using an adaptive strategy based on the null count. Iteration stops at, and returns,
+/// the first error produced by the closure.
+///
+/// Ideally this would be encapsulated in an [`Iterator`] that would determine the optimal
+/// strategy up front, and then yield indexes based on this.
+///
+/// Unfortunately, external iteration based on the resulting [`Iterator`] would match the
+/// strategy variant on each call to [`Iterator::next`], and LLVM generally cannot eliminate this.
+///
+/// One solution to this might be internal iteration, e.g. [`Iterator::try_fold`]; however,
+/// it is currently [not possible] to override this for custom iterators in stable Rust.
+///
+/// As such, this is the next best option.
+///
+/// # Panics
+///
+/// Panics if `nulls` is `None` while `0 < null_count < len`
+///
+/// [not possible]: https://github.com/rust-lang/rust/issues/69595
+#[inline]
+pub fn try_for_each_valid_idx<E, F: FnMut(usize) -> Result<(), E>>(
+    len: usize,
+    offset: usize,
+    null_count: usize,
+    nulls: Option<&[u8]>,
+    f: F,
+) -> Result<(), E> {
+    let valid_count = len - null_count;
+
+    if valid_count == len {
+        (0..len).try_for_each(f)
+    } else if null_count != len {
+        let selectivity = valid_count as f64 / len as f64;
+        if selectivity > 0.8 {
+            BitSliceIterator::new(nulls.unwrap(), offset, len)
+                .flat_map(|(start, end)| start..end)
+                .try_for_each(f)
+        } else {
+            BitIndexIterator::new(nulls.unwrap(), offset, len).try_for_each(f)
+        }
+    } else {
+        Ok(())
+    }
+}
+
 // Note: tests located in filter module

Reply via email to