jorgecarleitao commented on a change in pull request #8517:
URL: https://github.com/apache/arrow/pull/8517#discussion_r513164636



##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,280 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = 
Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed 
fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
 
-impl<T: OrdArray> OrdArray for Box<T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
+/// compares two floats, placing NaNs at last
+fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
+    match (a, b) {
+        (x, y) if x.is_nan() && y.is_nan() => Ordering::Equal,
+        (x, _) if x.is_nan() => Ordering::Greater,
+        (_, y) if y.is_nan() => Ordering::Less,
+        (_, _) => a.partial_cmp(b).unwrap(),
     }
 }
 
-impl<T: OrdArray> OrdArray for &T {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
-    }
+fn compare_primitives<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
+where
+    T::Native: Ord,
+{
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl<T: ArrowPrimitiveType> OrdArray for PrimitiveArray<T>
+fn compare_float<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T::Native: Float,
 {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(&self.value(j))
-    }
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j)))
 }
 
-impl OrdArray for StringArray {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(self.value(j))
-    }
+fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> 
DynComparator<'a>
+where
+    T: StringOffsetSizeTrait,
+{
+    let left = left
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    let right = right
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl OrdArray for NullArray {
-    fn cmp_value(&self, _i: usize, _j: usize) -> Ordering {
-        Ordering::Equal
-    }
+fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> 
DynComparator<'a>
+where
+    T: ArrowDictionaryKeyType,
+{
+    let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let left_keys = left.keys_array();
+    let right_keys = right.keys_array();
+
+    let left_values = StringArray::from(left.values().data());
+    let right_values = StringArray::from(left.values().data());
+
+    Box::new(move |i: usize, j: usize| {
+        let key_left = left_keys.value(i).to_usize().unwrap();
+        let key_right = right_keys.value(j).to_usize().unwrap();
+        let left = left_values.value(key_left);
+        let right = right_values.value(key_right);
+        left.cmp(&right)
+    })
 }
 
-macro_rules! float_ord_cmp {
-    ($NAME: ident, $T: ty) => {
-        #[inline]
-        fn $NAME(a: $T, b: $T) -> Ordering {
-            if a < b {
-                return Ordering::Less;
-            }
-            if a > b {
-                return Ordering::Greater;
+/// returns a comparison function that compares two values at two different 
positions
+/// between the two arrays.
+/// The arrays' types must be equal.
+/// # Example
+/// ```
+/// use arrow::array::{build_compare, Int32Array};
+///
+/// # fn main() -> arrow::error::Result<()> {
+/// let array1 = Int32Array::from(vec![1, 2]);
+/// let array2 = Int32Array::from(vec![3, 4]);
+///
+/// let cmp = build_compare(&array1, &array2)?;
+///
+/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
+/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
+/// # Ok(())
+/// # }
+/// ```
+// This is a factory of comparisons.
+// The lifetime 'a enforces that we cannot use the closure beyond any of the 
array's lifetime.
+pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> 
Result<DynComparator<'a>> {
+    use DataType::*;
+    use IntervalUnit::*;
+    use TimeUnit::*;
+    Ok(match (left.data_type(), right.data_type()) {
+        (a, b) if a != b => {
+            return Err(ArrowError::InvalidArgumentError(
+                "Can't compare arrays of different types".to_string(),
+            ));
+        }
+        (Boolean, Boolean) => compare_primitives::<BooleanType>(left, right),
+        (UInt8, UInt8) => compare_primitives::<UInt8Type>(left, right),
+        (UInt16, UInt16) => compare_primitives::<UInt16Type>(left, right),
+        (UInt32, UInt32) => compare_primitives::<UInt32Type>(left, right),
+        (UInt64, UInt64) => compare_primitives::<UInt64Type>(left, right),
+        (Int8, Int8) => compare_primitives::<Int8Type>(left, right),
+        (Int16, Int16) => compare_primitives::<Int16Type>(left, right),
+        (Int32, Int32) => compare_primitives::<Int32Type>(left, right),
+        (Int64, Int64) => compare_primitives::<Int64Type>(left, right),
+        (Float32, Float32) => compare_float::<Float32Type>(left, right),
+        (Float64, Float64) => compare_float::<Float64Type>(left, right),
+        (Date32(_), Date32(_)) => compare_primitives::<Date32Type>(left, 
right),
+        (Date64(_), Date64(_)) => compare_primitives::<Date64Type>(left, 
right),
+        (Time32(Second), Time32(Second)) => {
+            compare_primitives::<Time32SecondType>(left, right)
+        }
+        (Time32(Millisecond), Time32(Millisecond)) => {
+            compare_primitives::<Time32MillisecondType>(left, right)
+        }
+        (Time64(Microsecond), Time64(Microsecond)) => {
+            compare_primitives::<Time64MicrosecondType>(left, right)
+        }
+        (Time64(Nanosecond), Time64(Nanosecond)) => {
+            compare_primitives::<Time64NanosecondType>(left, right)
+        }
+        (Timestamp(Second, _), Timestamp(Second, _)) => {
+            compare_primitives::<TimestampSecondType>(left, right)
+        }
+        (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => {
+            compare_primitives::<TimestampMillisecondType>(left, right)
+        }
+        (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) => {
+            compare_primitives::<TimestampMicrosecondType>(left, right)
+        }
+        (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => {
+            compare_primitives::<TimestampNanosecondType>(left, right)
+        }
+        (Interval(YearMonth), Interval(YearMonth)) => {
+            compare_primitives::<IntervalYearMonthType>(left, right)
+        }
+        (Interval(DayTime), Interval(DayTime)) => {
+            compare_primitives::<IntervalDayTimeType>(left, right)
+        }
+        (Duration(Second), Duration(Second)) => {
+            compare_primitives::<DurationSecondType>(left, right)
+        }
+        (Duration(Millisecond), Duration(Millisecond)) => {
+            compare_primitives::<DurationMillisecondType>(left, right)
+        }
+        (Duration(Microsecond), Duration(Microsecond)) => {
+            compare_primitives::<DurationMicrosecondType>(left, right)
+        }
+        (Duration(Nanosecond), Duration(Nanosecond)) => {
+            compare_primitives::<DurationNanosecondType>(left, right)
+        }
+        (Utf8, Utf8) => compare_string::<i32>(left, right),
+        (LargeUtf8, LargeUtf8) => compare_string::<i64>(left, right),
+        (
+            Dictionary(key_type_lhs, value_type_lhs),
+            Dictionary(key_type_rhs, value_type_rhs),
+        ) => {
+            if value_type_lhs.as_ref() != &DataType::Utf8
+                || value_type_rhs.as_ref() != &DataType::Utf8
+            {
+                return Err(ArrowError::InvalidArgumentError(
+                    "Arrow still does not support comparisons of non-string 
dictionary arrays"

Review comment:
       This code is unchanged from how it was before.
   
   FYI, I did try to make this work for arbitrary types, and I was very (very) 
close to having it done, but it requires some unsafe usage, so I left it for a 
separate PR. The relevant thread is here: 
https://users.rust-lang.org/t/how-to-move-values-to-closure-that-indirectly-depends-on-them/50586

##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,280 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = 
Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed 
fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
 
-impl<T: OrdArray> OrdArray for Box<T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
+/// compares two floats, placing NaNs at last
+fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
+    match (a, b) {
+        (x, y) if x.is_nan() && y.is_nan() => Ordering::Equal,
+        (x, _) if x.is_nan() => Ordering::Greater,
+        (_, y) if y.is_nan() => Ordering::Less,
+        (_, _) => a.partial_cmp(b).unwrap(),
     }
 }
 
-impl<T: OrdArray> OrdArray for &T {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
-    }
+fn compare_primitives<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
+where
+    T::Native: Ord,
+{
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl<T: ArrowPrimitiveType> OrdArray for PrimitiveArray<T>
+fn compare_float<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T::Native: Float,
 {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(&self.value(j))
-    }
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j)))
 }
 
-impl OrdArray for StringArray {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(self.value(j))
-    }
+fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> 
DynComparator<'a>
+where
+    T: StringOffsetSizeTrait,
+{
+    let left = left
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    let right = right
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl OrdArray for NullArray {
-    fn cmp_value(&self, _i: usize, _j: usize) -> Ordering {
-        Ordering::Equal
-    }
+fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> 
DynComparator<'a>
+where
+    T: ArrowDictionaryKeyType,
+{
+    let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let left_keys = left.keys_array();
+    let right_keys = right.keys_array();
+
+    let left_values = StringArray::from(left.values().data());
+    let right_values = StringArray::from(left.values().data());
+
+    Box::new(move |i: usize, j: usize| {
+        let key_left = left_keys.value(i).to_usize().unwrap();
+        let key_right = right_keys.value(j).to_usize().unwrap();
+        let left = left_values.value(key_left);
+        let right = right_values.value(key_right);
+        left.cmp(&right)
+    })
 }
 
-macro_rules! float_ord_cmp {
-    ($NAME: ident, $T: ty) => {
-        #[inline]
-        fn $NAME(a: $T, b: $T) -> Ordering {
-            if a < b {
-                return Ordering::Less;
-            }
-            if a > b {
-                return Ordering::Greater;
+/// returns a comparison function that compares two values at two different 
positions
+/// between the two arrays.
+/// The arrays' types must be equal.
+/// # Example
+/// ```
+/// use arrow::array::{build_compare, Int32Array};
+///
+/// # fn main() -> arrow::error::Result<()> {
+/// let array1 = Int32Array::from(vec![1, 2]);
+/// let array2 = Int32Array::from(vec![3, 4]);
+///
+/// let cmp = build_compare(&array1, &array2)?;
+///
+/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
+/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
+/// # Ok(())
+/// # }
+/// ```
+// This is a factory of comparisons.
+// The lifetime 'a enforces that we cannot use the closure beyond any of the 
array's lifetime.
+pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> 
Result<DynComparator<'a>> {
+    use DataType::*;
+    use IntervalUnit::*;
+    use TimeUnit::*;
+    Ok(match (left.data_type(), right.data_type()) {
+        (a, b) if a != b => {
+            return Err(ArrowError::InvalidArgumentError(
+                "Can't compare arrays of different types".to_string(),
+            ));
+        }
+        (Boolean, Boolean) => compare_primitives::<BooleanType>(left, right),
+        (UInt8, UInt8) => compare_primitives::<UInt8Type>(left, right),
+        (UInt16, UInt16) => compare_primitives::<UInt16Type>(left, right),
+        (UInt32, UInt32) => compare_primitives::<UInt32Type>(left, right),
+        (UInt64, UInt64) => compare_primitives::<UInt64Type>(left, right),
+        (Int8, Int8) => compare_primitives::<Int8Type>(left, right),
+        (Int16, Int16) => compare_primitives::<Int16Type>(left, right),
+        (Int32, Int32) => compare_primitives::<Int32Type>(left, right),
+        (Int64, Int64) => compare_primitives::<Int64Type>(left, right),
+        (Float32, Float32) => compare_float::<Float32Type>(left, right),
+        (Float64, Float64) => compare_float::<Float64Type>(left, right),
+        (Date32(_), Date32(_)) => compare_primitives::<Date32Type>(left, 
right),
+        (Date64(_), Date64(_)) => compare_primitives::<Date64Type>(left, 
right),
+        (Time32(Second), Time32(Second)) => {
+            compare_primitives::<Time32SecondType>(left, right)
+        }
+        (Time32(Millisecond), Time32(Millisecond)) => {
+            compare_primitives::<Time32MillisecondType>(left, right)
+        }
+        (Time64(Microsecond), Time64(Microsecond)) => {
+            compare_primitives::<Time64MicrosecondType>(left, right)
+        }
+        (Time64(Nanosecond), Time64(Nanosecond)) => {
+            compare_primitives::<Time64NanosecondType>(left, right)
+        }
+        (Timestamp(Second, _), Timestamp(Second, _)) => {
+            compare_primitives::<TimestampSecondType>(left, right)
+        }
+        (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => {
+            compare_primitives::<TimestampMillisecondType>(left, right)
+        }
+        (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) => {
+            compare_primitives::<TimestampMicrosecondType>(left, right)
+        }
+        (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => {
+            compare_primitives::<TimestampNanosecondType>(left, right)
+        }
+        (Interval(YearMonth), Interval(YearMonth)) => {
+            compare_primitives::<IntervalYearMonthType>(left, right)
+        }
+        (Interval(DayTime), Interval(DayTime)) => {
+            compare_primitives::<IntervalDayTimeType>(left, right)
+        }
+        (Duration(Second), Duration(Second)) => {
+            compare_primitives::<DurationSecondType>(left, right)
+        }
+        (Duration(Millisecond), Duration(Millisecond)) => {
+            compare_primitives::<DurationMillisecondType>(left, right)
+        }
+        (Duration(Microsecond), Duration(Microsecond)) => {
+            compare_primitives::<DurationMicrosecondType>(left, right)
+        }
+        (Duration(Nanosecond), Duration(Nanosecond)) => {
+            compare_primitives::<DurationNanosecondType>(left, right)
+        }
+        (Utf8, Utf8) => compare_string::<i32>(left, right),
+        (LargeUtf8, LargeUtf8) => compare_string::<i64>(left, right),
+        (
+            Dictionary(key_type_lhs, value_type_lhs),
+            Dictionary(key_type_rhs, value_type_rhs),
+        ) => {
+            if value_type_lhs.as_ref() != &DataType::Utf8
+                || value_type_rhs.as_ref() != &DataType::Utf8
+            {
+                return Err(ArrowError::InvalidArgumentError(
+                    "Arrow still does not support comparisons of non-string 
dictionary arrays"

Review comment:
       This code is unchanged from how it was before.
   
   FYI, I did try to make this work for arbitrary types, and I was very (very) 
close to having it done, but it requires some unsafe usage, so I left it for a 
separate PR. The relevant thread is here: 
https://users.rust-lang.org/t/how-to-move-values-to-closure-that-indirectly-depends-on-them/50586




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to