nevi-me commented on a change in pull request #8517:
URL: https://github.com/apache/arrow/pull/8517#discussion_r511642958



##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,280 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = 
Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed 
fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
 
-impl<T: OrdArray> OrdArray for Box<T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
+/// compares two floats, placing NaNs at last
+fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
+    match (a, b) {
+        (x, y) if x.is_nan() && y.is_nan() => Ordering::Equal,
+        (x, _) if x.is_nan() => Ordering::Greater,
+        (_, y) if y.is_nan() => Ordering::Less,
+        (_, _) => a.partial_cmp(b).unwrap(),
     }
 }
 
-impl<T: OrdArray> OrdArray for &T {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
-    }
+fn compare_primitives<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
+where
+    T::Native: Ord,
+{
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl<T: ArrowPrimitiveType> OrdArray for PrimitiveArray<T>
+fn compare_float<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T::Native: Float,
 {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(&self.value(j))
-    }
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j)))
 }
 
-impl OrdArray for StringArray {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(self.value(j))
-    }
+fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> 
DynComparator<'a>
+where
+    T: StringOffsetSizeTrait,
+{
+    let left = left
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    let right = right
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl OrdArray for NullArray {
-    fn cmp_value(&self, _i: usize, _j: usize) -> Ordering {
-        Ordering::Equal
-    }
+fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> 
DynComparator<'a>
+where
+    T: ArrowDictionaryKeyType,
+{
+    let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let left_keys = left.keys_array();
+    let right_keys = right.keys_array();
+
+    let left_values = StringArray::from(left.values().data());
+    let right_values = StringArray::from(left.values().data());
+
+    Box::new(move |i: usize, j: usize| {
+        let key_left = left_keys.value(i).to_usize().unwrap();
+        let key_right = right_keys.value(j).to_usize().unwrap();
+        let left = left_values.value(key_left);
+        let right = right_values.value(key_right);
+        left.cmp(&right)
+    })
 }
 
-macro_rules! float_ord_cmp {
-    ($NAME: ident, $T: ty) => {
-        #[inline]
-        fn $NAME(a: $T, b: $T) -> Ordering {
-            if a < b {
-                return Ordering::Less;
-            }
-            if a > b {
-                return Ordering::Greater;
+/// returns a comparison function that compares two values at two different 
positions
+/// between the two arrays.
+/// The arrays' types must be equal.
+/// # Example
+/// ```
+/// use arrow::array::{build_compare, Int32Array};
+///
+/// # fn main() -> arrow::error::Result<()> {
+/// let array1 = Int32Array::from(vec![1, 2]);
+/// let array2 = Int32Array::from(vec![3, 4]);
+///
+/// let cmp = build_compare(&array1, &array2)?;
+///
+/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
+/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
+/// # Ok(())
+/// # }
+/// ```
+// This is a factory of comparisons.
+// The lifetime 'a enforces that we cannot use the closure beyond any of the 
array's lifetime.
+pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> 
Result<DynComparator<'a>> {
+    use DataType::*;
+    use IntervalUnit::*;
+    use TimeUnit::*;
+    Ok(match (left.data_type(), right.data_type()) {
+        (a, b) if a != b => {
+            return Err(ArrowError::InvalidArgumentError(
+                "Can't compare arrays of different types".to_string(),
+            ));
+        }
+        (Boolean, Boolean) => compare_primitives::<BooleanType>(left, right),
+        (UInt8, UInt8) => compare_primitives::<UInt8Type>(left, right),
+        (UInt16, UInt16) => compare_primitives::<UInt16Type>(left, right),
+        (UInt32, UInt32) => compare_primitives::<UInt32Type>(left, right),
+        (UInt64, UInt64) => compare_primitives::<UInt64Type>(left, right),
+        (Int8, Int8) => compare_primitives::<Int8Type>(left, right),
+        (Int16, Int16) => compare_primitives::<Int16Type>(left, right),
+        (Int32, Int32) => compare_primitives::<Int32Type>(left, right),
+        (Int64, Int64) => compare_primitives::<Int64Type>(left, right),
+        (Float32, Float32) => compare_float::<Float32Type>(left, right),
+        (Float64, Float64) => compare_float::<Float64Type>(left, right),
+        (Date32(_), Date32(_)) => compare_primitives::<Date32Type>(left, 
right),
+        (Date64(_), Date64(_)) => compare_primitives::<Date64Type>(left, 
right),
+        (Time32(Second), Time32(Second)) => {
+            compare_primitives::<Time32SecondType>(left, right)
+        }
+        (Time32(Millisecond), Time32(Millisecond)) => {
+            compare_primitives::<Time32MillisecondType>(left, right)
+        }
+        (Time64(Microsecond), Time64(Microsecond)) => {
+            compare_primitives::<Time64MicrosecondType>(left, right)
+        }
+        (Time64(Nanosecond), Time64(Nanosecond)) => {
+            compare_primitives::<Time64NanosecondType>(left, right)
+        }
+        (Timestamp(Second, _), Timestamp(Second, _)) => {
+            compare_primitives::<TimestampSecondType>(left, right)
+        }
+        (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => {
+            compare_primitives::<TimestampMillisecondType>(left, right)
+        }
+        (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) => {
+            compare_primitives::<TimestampMicrosecondType>(left, right)
+        }
+        (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => {
+            compare_primitives::<TimestampNanosecondType>(left, right)
+        }
+        (Interval(YearMonth), Interval(YearMonth)) => {
+            compare_primitives::<IntervalYearMonthType>(left, right)
+        }
+        (Interval(DayTime), Interval(DayTime)) => {
+            compare_primitives::<IntervalDayTimeType>(left, right)
+        }
+        (Duration(Second), Duration(Second)) => {
+            compare_primitives::<DurationSecondType>(left, right)
+        }
+        (Duration(Millisecond), Duration(Millisecond)) => {
+            compare_primitives::<DurationMillisecondType>(left, right)
+        }
+        (Duration(Microsecond), Duration(Microsecond)) => {
+            compare_primitives::<DurationMicrosecondType>(left, right)
+        }
+        (Duration(Nanosecond), Duration(Nanosecond)) => {
+            compare_primitives::<DurationNanosecondType>(left, right)
+        }
+        (Utf8, Utf8) => compare_string::<i32>(left, right),
+        (LargeUtf8, LargeUtf8) => compare_string::<i64>(left, right),
+        (
+            Dictionary(key_type_lhs, value_type_lhs),
+            Dictionary(key_type_rhs, value_type_rhs),
+        ) => {
+            if value_type_lhs.as_ref() != &DataType::Utf8
+                || value_type_rhs.as_ref() != &DataType::Utf8
+            {
+                return Err(ArrowError::InvalidArgumentError(
+                    "Arrow still does not support comparisons of non-string 
dictionary arrays"

Review comment:
       Would we incur a high cost if we cast dictionaries to primitives, then 
compared the primitives?

##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,280 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = 
Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed 
fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
 
-impl<T: OrdArray> OrdArray for Box<T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
+/// compares two floats, placing NaNs at last
+fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
+    match (a, b) {
+        (x, y) if x.is_nan() && y.is_nan() => Ordering::Equal,
+        (x, _) if x.is_nan() => Ordering::Greater,
+        (_, y) if y.is_nan() => Ordering::Less,
+        (_, _) => a.partial_cmp(b).unwrap(),
     }
 }
 
-impl<T: OrdArray> OrdArray for &T {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
-    }
+fn compare_primitives<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
+where
+    T::Native: Ord,
+{
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl<T: ArrowPrimitiveType> OrdArray for PrimitiveArray<T>
+fn compare_float<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T::Native: Float,
 {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(&self.value(j))
-    }
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j)))
 }
 
-impl OrdArray for StringArray {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(self.value(j))
-    }
+fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> 
DynComparator<'a>
+where
+    T: StringOffsetSizeTrait,
+{
+    let left = left
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    let right = right
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl OrdArray for NullArray {
-    fn cmp_value(&self, _i: usize, _j: usize) -> Ordering {
-        Ordering::Equal
-    }
+fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> 
DynComparator<'a>
+where
+    T: ArrowDictionaryKeyType,
+{
+    let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let left_keys = left.keys_array();
+    let right_keys = right.keys_array();
+
+    let left_values = StringArray::from(left.values().data());
+    let right_values = StringArray::from(left.values().data());
+
+    Box::new(move |i: usize, j: usize| {
+        let key_left = left_keys.value(i).to_usize().unwrap();
+        let key_right = right_keys.value(j).to_usize().unwrap();
+        let left = left_values.value(key_left);
+        let right = right_values.value(key_right);
+        left.cmp(&right)
+    })
 }
 
-macro_rules! float_ord_cmp {
-    ($NAME: ident, $T: ty) => {
-        #[inline]
-        fn $NAME(a: $T, b: $T) -> Ordering {
-            if a < b {
-                return Ordering::Less;
-            }
-            if a > b {
-                return Ordering::Greater;
+/// returns a comparison function that compares two values at two different 
positions
+/// between the two arrays.
+/// The arrays' types must be equal.
+/// # Example
+/// ```
+/// use arrow::array::{build_compare, Int32Array};
+///
+/// # fn main() -> arrow::error::Result<()> {
+/// let array1 = Int32Array::from(vec![1, 2]);
+/// let array2 = Int32Array::from(vec![3, 4]);
+///
+/// let cmp = build_compare(&array1, &array2)?;
+///
+/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
+/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
+/// # Ok(())
+/// # }
+/// ```
+// This is a factory of comparisons.
+// The lifetime 'a enforces that we cannot use the closure beyond any of the 
array's lifetime.
+pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> 
Result<DynComparator<'a>> {
+    use DataType::*;
+    use IntervalUnit::*;
+    use TimeUnit::*;
+    Ok(match (left.data_type(), right.data_type()) {
+        (a, b) if a != b => {
+            return Err(ArrowError::InvalidArgumentError(
+                "Can't compare arrays of different types".to_string(),
+            ));
+        }
+        (Boolean, Boolean) => compare_primitives::<BooleanType>(left, right),
+        (UInt8, UInt8) => compare_primitives::<UInt8Type>(left, right),
+        (UInt16, UInt16) => compare_primitives::<UInt16Type>(left, right),
+        (UInt32, UInt32) => compare_primitives::<UInt32Type>(left, right),
+        (UInt64, UInt64) => compare_primitives::<UInt64Type>(left, right),
+        (Int8, Int8) => compare_primitives::<Int8Type>(left, right),
+        (Int16, Int16) => compare_primitives::<Int16Type>(left, right),
+        (Int32, Int32) => compare_primitives::<Int32Type>(left, right),
+        (Int64, Int64) => compare_primitives::<Int64Type>(left, right),
+        (Float32, Float32) => compare_float::<Float32Type>(left, right),
+        (Float64, Float64) => compare_float::<Float64Type>(left, right),
+        (Date32(_), Date32(_)) => compare_primitives::<Date32Type>(left, 
right),
+        (Date64(_), Date64(_)) => compare_primitives::<Date64Type>(left, 
right),
+        (Time32(Second), Time32(Second)) => {
+            compare_primitives::<Time32SecondType>(left, right)
+        }
+        (Time32(Millisecond), Time32(Millisecond)) => {
+            compare_primitives::<Time32MillisecondType>(left, right)
+        }
+        (Time64(Microsecond), Time64(Microsecond)) => {
+            compare_primitives::<Time64MicrosecondType>(left, right)
+        }
+        (Time64(Nanosecond), Time64(Nanosecond)) => {
+            compare_primitives::<Time64NanosecondType>(left, right)
+        }
+        (Timestamp(Second, _), Timestamp(Second, _)) => {
+            compare_primitives::<TimestampSecondType>(left, right)
+        }
+        (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => {
+            compare_primitives::<TimestampMillisecondType>(left, right)
+        }
+        (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) => {
+            compare_primitives::<TimestampMicrosecondType>(left, right)
+        }
+        (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => {
+            compare_primitives::<TimestampNanosecondType>(left, right)
+        }
+        (Interval(YearMonth), Interval(YearMonth)) => {
+            compare_primitives::<IntervalYearMonthType>(left, right)
+        }
+        (Interval(DayTime), Interval(DayTime)) => {
+            compare_primitives::<IntervalDayTimeType>(left, right)
+        }
+        (Duration(Second), Duration(Second)) => {
+            compare_primitives::<DurationSecondType>(left, right)
+        }
+        (Duration(Millisecond), Duration(Millisecond)) => {
+            compare_primitives::<DurationMillisecondType>(left, right)
+        }
+        (Duration(Microsecond), Duration(Microsecond)) => {
+            compare_primitives::<DurationMicrosecondType>(left, right)
+        }
+        (Duration(Nanosecond), Duration(Nanosecond)) => {
+            compare_primitives::<DurationNanosecondType>(left, right)
+        }
+        (Utf8, Utf8) => compare_string::<i32>(left, right),
+        (LargeUtf8, LargeUtf8) => compare_string::<i64>(left, right),
+        (
+            Dictionary(key_type_lhs, value_type_lhs),
+            Dictionary(key_type_rhs, value_type_rhs),
+        ) => {
+            if value_type_lhs.as_ref() != &DataType::Utf8
+                || value_type_rhs.as_ref() != &DataType::Utf8
+            {
+                return Err(ArrowError::InvalidArgumentError(
+                    "Arrow still does not support comparisons of non-string 
dictionary arrays"
+                        .to_string(),
+                ));
             }
-
-            // convert to bits with canonical pattern for NaN
-            let a = if a.is_nan() {
-                <$T>::NAN.to_bits()
-            } else {
-                a.to_bits()
-            };
-            let b = if b.is_nan() {
-                <$T>::NAN.to_bits()
-            } else {
-                b.to_bits()
-            };
-
-            if a == b {
-                // Equal or both NaN
-                Ordering::Equal
-            } else if a < b {
-                // (-0.0, 0.0) or (!NaN, NaN)
-                Ordering::Less
-            } else {
-                // (0.0, -0.0) or (NaN, !NaN)
-                Ordering::Greater
+            match (key_type_lhs.as_ref(), key_type_rhs.as_ref()) {
+                (a, b) if a != b => {
+                    return Err(ArrowError::InvalidArgumentError(
+                        "Can't compare arrays of different types".to_string(),
+                    ));
+                }
+                (UInt8, UInt8) => compare_dict_string::<UInt8Type>(left, 
right),
+                (UInt16, UInt16) => compare_dict_string::<UInt16Type>(left, 
right),
+                (UInt32, UInt32) => compare_dict_string::<UInt32Type>(left, 
right),
+                (UInt64, UInt64) => compare_dict_string::<UInt64Type>(left, 
right),
+                (Int8, Int8) => compare_dict_string::<Int8Type>(left, right),
+                (Int16, Int16) => compare_dict_string::<Int16Type>(left, 
right),
+                (Int32, Int32) => compare_dict_string::<Int32Type>(left, 
right),
+                (Int64, Int64) => compare_dict_string::<Int64Type>(left, 
right),
+                _ => todo!(),
             }
         }
-    };
+        _ => todo!(),

Review comment:
       We can return a helpful error here instead of panicking via `todo!()`.

##########
File path: rust/arrow/src/compute/kernels/sort.rs
##########
@@ -453,49 +466,46 @@ pub fn lexsort(columns: &[SortColumn]) -> 
Result<Vec<ArrayRef>> {
 /// Sort elements lexicographically from a list of `ArrayRef` into an unsigned 
integer
 /// (`UInt32Array`) of indices.
 pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result<UInt32Array> {
+    if columns.len() == 0 {
+        return Err(ArrowError::InvalidArgumentError(
+            "Sort requires at least one column".to_string(),
+        ));
+    }
     if columns.len() == 1 {
         // fallback to non-lexical sort
         let column = &columns[0];
         return sort_to_indices(&column.values, column.options);
     }
 
-    let mut row_count = None;
+    let row_count = columns[0].values.len();
+    if columns.iter().any(|item| item.values.len() != row_count) {
+        return Err(ArrowError::ComputeError(
+            "lexical sort columns have different row counts".to_string(),
+        ));
+    };
+
     // convert ArrayRefs to OrdArray trait objects and perform row count check
     let flat_columns = columns
         .iter()
-        .map(|column| -> Result<(&Array, Box<OrdArray>, SortOptions)> {
-            // row count check
-            let curr_row_count = column.values.len() - column.values.offset();
-            match row_count {
-                None => {
-                    row_count = Some(curr_row_count);
-                }
-                Some(cnt) => {
-                    if curr_row_count != cnt {
-                        return Err(ArrowError::ComputeError(
-                            "lexical sort columns have different row 
counts".to_string(),
-                        ));
-                    }
-                }
-            }
-            // flatten and convert to OrdArray
+        .map(|column| -> Result<(&Array, DynComparator, SortOptions)> {
+            // flatten and convert build comparators
             Ok((
                 column.values.as_ref(),
-                as_ordarray(&column.values)?,
+                build_compare(column.values.as_ref(), column.values.as_ref())?,

Review comment:
       I'm happy with this approach; it's very creative.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to