shepmaster commented on a change in pull request #1074:
URL: https://github.com/apache/arrow-rs/pull/1074#discussion_r774003760



##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +898,126 @@ pub fn gt_eq_utf8_scalar<OffsetSize: 
StringOffsetSizeTrait>(
     compare_op_scalar!(left, right, |a, b| a >= b)
 }
 
+macro_rules! dyn_cmp_scalar {
+    ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{

Review comment:
       Should `$T` and `$TT` be [`ty`](https://doc.rust-lang.org/stable/reference/macros-by-example.html#metavariables) fragments instead of `ident` and `tt`?

##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +898,126 @@ pub fn gt_eq_utf8_scalar<OffsetSize: 
StringOffsetSizeTrait>(
     compare_op_scalar!(left, right, |a, b| a >= b)
 }
 
+macro_rules! dyn_cmp_scalar {
+    ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{
+        let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Left array cannot be cast to {}",
+                type_name::<$T>()
+            ))
+        })?;
+        let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Right array cannot be cast to {}",
+                type_name::<$T>(),
+            ))
+        })?;
+        $OP::<$TT>(left, right)

Review comment:
       `$OP::<$TT>` could probably be fused into a single `expr` macro argument (i.e. pass the already-instantiated function path and call it as `$OP(left, right)`).

##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +898,126 @@ pub fn gt_eq_utf8_scalar<OffsetSize: 
StringOffsetSizeTrait>(
     compare_op_scalar!(left, right, |a, b| a >= b)
 }
 
+macro_rules! dyn_cmp_scalar {
+    ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{
+        let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Left array cannot be cast to {}",
+                type_name::<$T>()
+            ))
+        })?;
+        let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Right array cannot be cast to {}",
+                type_name::<$T>(),
+            ))
+        })?;
+        $OP::<$TT>(left, right)
+    }};
+}
+
+macro_rules! dyn_compare_scalar {
+    ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{
+        let right = $RIGHT.try_into().map_err(|_| {
+            ArrowError::ComputeError(format!(
+                "Can not convert scalar {:?} to i128",
+                $RIGHT
+            ))
+        });
+        match ($LEFT.data_type(), $RIGHT::Arrow) {
+            // (DataType::Boolean, DataType::Boolean) => {
+            //     typed_cmp_scalar!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
+            // }
+            (DataType::Int8, DataType::Int8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int8Array, $OP, Int8Type)
+            }
+            (DataType::Int16, DataType::Int16) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int16Array, $OP, Int16Type)
+            }
+            (DataType::Int32, DataType::Int32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int32Array, $OP, Int32Type)
+            }
+            (DataType::Int64, DataType::Int64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int64Array, $OP, Int64Type)
+            }
+            (DataType::UInt8, DataType::UInt8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt8Array, $OP, UInt8Type)
+            }
+            (DataType::UInt16, DataType::UInt16) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt16Array, $OP, UInt16Type)
+            }
+            (DataType::UInt32, DataType::UInt32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt32Array, $OP, UInt32Type)
+            }
+            (DataType::UInt64, DataType::UInt64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt64Array, $OP, UInt64Type)
+            }
+            (DataType::Float32, DataType::Float32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Float32Array, $OP, Float32Type)
+            }
+            (DataType::Float64, DataType::Float64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Float64Array, $OP, Float64Type)
+            }
+            (DataType::Utf8, DataType::Utf8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, StringArray, $OP, i32)
+            }
+            (DataType::LargeUtf8, DataType::LargeUtf8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, LargeStringArray, $OP, i64)
+            }
+            (DataType::Dictionary(DataType::UInt8, DataType::UInt8), 
DataType::UInt8) => {
+                let values_comp =
+                    typed_compare_scalar!($LEFT.values(), $RIGHT, eq_scalar);
+                unpack_dict_comparison($LEFT, values_comp)
+            }
+            (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
+                "Comparing arrays of type {} is not yet implemented",
+                t1
+            ))),
+            (t1, t2) => Err(ArrowError::CastError(format!(
+                "Cannot compare an array with a scalar of different type ({} 
and {})",
+                t1, t2
+            ))),
+        }
+    }};
+}
+
+/// Perform `left == right` operation on an array and a numeric scalar
+/// value. Supports PrimtiveArrays, and DictionaryArrays that have primitive 
values

Review comment:
       ```suggestion
       /// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values
       ```

##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +898,126 @@ pub fn gt_eq_utf8_scalar<OffsetSize: 
StringOffsetSizeTrait>(
     compare_op_scalar!(left, right, |a, b| a >= b)
 }
 
+macro_rules! dyn_cmp_scalar {
+    ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{
+        let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Left array cannot be cast to {}",
+                type_name::<$T>()
+            ))
+        })?;
+        let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Right array cannot be cast to {}",
+                type_name::<$T>(),
+            ))
+        })?;
+        $OP::<$TT>(left, right)
+    }};
+}
+
+macro_rules! dyn_compare_scalar {
+    ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{
+        let right = $RIGHT.try_into().map_err(|_| {
+            ArrowError::ComputeError(format!(
+                "Can not convert scalar {:?} to i128",
+                $RIGHT
+            ))
+        });
+        match ($LEFT.data_type(), $RIGHT::Arrow) {
+            // (DataType::Boolean, DataType::Boolean) => {
+            //     typed_cmp_scalar!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
+            // }
+            (DataType::Int8, DataType::Int8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int8Array, $OP, Int8Type)
+            }
+            (DataType::Int16, DataType::Int16) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int16Array, $OP, Int16Type)
+            }
+            (DataType::Int32, DataType::Int32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int32Array, $OP, Int32Type)
+            }
+            (DataType::Int64, DataType::Int64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int64Array, $OP, Int64Type)
+            }
+            (DataType::UInt8, DataType::UInt8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt8Array, $OP, UInt8Type)
+            }
+            (DataType::UInt16, DataType::UInt16) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt16Array, $OP, UInt16Type)
+            }
+            (DataType::UInt32, DataType::UInt32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt32Array, $OP, UInt32Type)
+            }
+            (DataType::UInt64, DataType::UInt64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt64Array, $OP, UInt64Type)
+            }
+            (DataType::Float32, DataType::Float32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Float32Array, $OP, Float32Type)
+            }
+            (DataType::Float64, DataType::Float64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Float64Array, $OP, Float64Type)
+            }
+            (DataType::Utf8, DataType::Utf8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, StringArray, $OP, i32)
+            }
+            (DataType::LargeUtf8, DataType::LargeUtf8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, LargeStringArray, $OP, i64)
+            }
+            (DataType::Dictionary(DataType::UInt8, DataType::UInt8), 
DataType::UInt8) => {
+                let values_comp =
+                    typed_compare_scalar!($LEFT.values(), $RIGHT, eq_scalar);
+                unpack_dict_comparison($LEFT, values_comp)
+            }
+            (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
+                "Comparing arrays of type {} is not yet implemented",
+                t1
+            ))),
+            (t1, t2) => Err(ArrowError::CastError(format!(
+                "Cannot compare an array with a scalar of different type ({} 
and {})",
+                t1, t2
+            ))),
+        }
+    }};
+}
+
+/// Perform `left == right` operation on an array and a numeric scalar
+/// value. Supports PrimtiveArrays, and DictionaryArrays that have primitive 
values
+pub fn eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
+where
+    T: IntoArrowNumericType + TryInto<i128> + Copy + std::fmt::Debug,
+{
+    dyn_compare_scalar!(left, right, eq_scalar)

Review comment:
       I admit this is a drive-by review, but I don't yet see the benefit of the macros here: they don't reduce any repetition. It looks like `dyn_compare_scalar` could be inlined into `eq_dyn_scalar`, and `dyn_cmp_scalar` could be a regular (generic) function.

##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +898,126 @@ pub fn gt_eq_utf8_scalar<OffsetSize: 
StringOffsetSizeTrait>(
     compare_op_scalar!(left, right, |a, b| a >= b)
 }
 
+macro_rules! dyn_cmp_scalar {
+    ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{
+        let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Left array cannot be cast to {}",
+                type_name::<$T>()
+            ))
+        })?;
+        let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Right array cannot be cast to {}",
+                type_name::<$T>(),
+            ))
+        })?;
+        $OP::<$TT>(left, right)
+    }};
+}
+
+macro_rules! dyn_compare_scalar {
+    ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{
+        let right = $RIGHT.try_into().map_err(|_| {

Review comment:
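       I can't find where this `right` value is used. It looks like it is computed and then never read, since the `match` below and the `dyn_cmp_scalar!` calls still use `$RIGHT`.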

##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +898,126 @@ pub fn gt_eq_utf8_scalar<OffsetSize: 
StringOffsetSizeTrait>(
     compare_op_scalar!(left, right, |a, b| a >= b)
 }
 
+macro_rules! dyn_cmp_scalar {
+    ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{
+        let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Left array cannot be cast to {}",
+                type_name::<$T>()
+            ))
+        })?;
+        let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Right array cannot be cast to {}",
+                type_name::<$T>(),
+            ))
+        })?;
+        $OP::<$TT>(left, right)
+    }};
+}
+
+macro_rules! dyn_compare_scalar {
+    ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{
+        let right = $RIGHT.try_into().map_err(|_| {
+            ArrowError::ComputeError(format!(
+                "Can not convert scalar {:?} to i128",
+                $RIGHT
+            ))
+        });
+        match ($LEFT.data_type(), $RIGHT::Arrow) {
+            // (DataType::Boolean, DataType::Boolean) => {
+            //     typed_cmp_scalar!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
+            // }
+            (DataType::Int8, DataType::Int8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int8Array, $OP, Int8Type)
+            }
+            (DataType::Int16, DataType::Int16) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int16Array, $OP, Int16Type)
+            }
+            (DataType::Int32, DataType::Int32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int32Array, $OP, Int32Type)
+            }
+            (DataType::Int64, DataType::Int64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int64Array, $OP, Int64Type)
+            }
+            (DataType::UInt8, DataType::UInt8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt8Array, $OP, UInt8Type)
+            }
+            (DataType::UInt16, DataType::UInt16) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt16Array, $OP, UInt16Type)
+            }
+            (DataType::UInt32, DataType::UInt32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt32Array, $OP, UInt32Type)
+            }
+            (DataType::UInt64, DataType::UInt64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt64Array, $OP, UInt64Type)
+            }
+            (DataType::Float32, DataType::Float32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Float32Array, $OP, Float32Type)
+            }
+            (DataType::Float64, DataType::Float64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Float64Array, $OP, Float64Type)
+            }
+            (DataType::Utf8, DataType::Utf8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, StringArray, $OP, i32)
+            }
+            (DataType::LargeUtf8, DataType::LargeUtf8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, LargeStringArray, $OP, i64)
+            }
+            (DataType::Dictionary(DataType::UInt8, DataType::UInt8), 
DataType::UInt8) => {
+                let values_comp =
+                    typed_compare_scalar!($LEFT.values(), $RIGHT, eq_scalar);
+                unpack_dict_comparison($LEFT, values_comp)
+            }
+            (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
+                "Comparing arrays of type {} is not yet implemented",
+                t1
+            ))),
+            (t1, t2) => Err(ArrowError::CastError(format!(
+                "Cannot compare an array with a scalar of different type ({} 
and {})",
+                t1, t2
+            ))),
+        }
+    }};
+}
+
+/// Perform `left == right` operation on an array and a numeric scalar
+/// value. Supports PrimtiveArrays, and DictionaryArrays that have primitive 
values
+pub fn eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
+where
+    T: IntoArrowNumericType + TryInto<i128> + Copy + std::fmt::Debug,
+{
+    dyn_compare_scalar!(left, right, eq_scalar)
+}
+
+/// unpacks the results of comparing left.values (as a boolean)
+///
+/// TODO add example
+///
+fn unpack_dict_comparison<K>(
+    left: &DictionaryArray<K>,
+    dict_comparison: BooleanArray,
+) -> Result<BooleanArray>
+where
+    K: ArrowNumericType,
+{
+    assert_eq!(dict_comparison.len(), left.values().len());

Review comment:
       Why an assertion here, rather than returning an error or treating the length mismatch as a "does not match" result?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to