This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new aed319c10c Scalar arithmetic should return error when overflows. 
(#5811)
aed319c10c is described below

commit aed319c10c87ab6f0490cc2e8657d98e454d7f80
Author: Zhiyuan Zheng <[email protected]>
AuthorDate: Wed Apr 12 02:07:12 2023 +0800

    Scalar arithmetic should return error when overflows. (#5811)
    
    * Scalar arithmetic should return error when overflows.
    
    * Add a few tests.
    
    * add new checked_* ops.
    
    * fix test failure.
---
 datafusion/common/src/scalar.rs | 185 +++++++++++++++++++++++++++++++++++++++-
 1 file changed, 183 insertions(+), 2 deletions(-)

diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 8c26ea9b04..63b6caa623 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -579,6 +579,41 @@ macro_rules! primitive_op {
         }
     };
 }
+macro_rules! primitive_checked_op {
+    ($LEFT:expr, $RIGHT:expr, $SCALAR:ident, $FUNCTION:ident, $OPERATION:tt) 
=> {
+        match ($LEFT, $RIGHT) {
+            (lhs, None) => Ok(ScalarValue::$SCALAR(*lhs)),
+            #[allow(unused_variables)]
+            (None, Some(b)) => {
+                primitive_checked_right!(*b, $OPERATION, $SCALAR)
+            }
+            (Some(a), Some(b)) => {
+                if let Some(value) = (*a).$FUNCTION(*b) {
+                    Ok(ScalarValue::$SCALAR(Some(value)))
+                } else {
+                    Err(DataFusionError::Execution(
+                        "Overflow while calculating ScalarValue.".to_string(),
+                    ))
+                }
+            }
+        }
+    };
+}
+
+macro_rules! primitive_checked_right {
+    ($TERM:expr, -, $SCALAR:ident) => {
+        if let Some(value) = $TERM.checked_neg() {
+            Ok(ScalarValue::$SCALAR(Some(value)))
+        } else {
+            Err(DataFusionError::Execution(
+                "Overflow while calculating ScalarValue.".to_string(),
+            ))
+        }
+    };
+    ($TERM:expr, $OPERATION:tt, $SCALAR:ident) => {
+        primitive_right!($TERM, $OPERATION, $SCALAR)
+    };
+}
 
 macro_rules! primitive_right {
     ($TERM:expr, +, $SCALAR:ident) => {
@@ -625,6 +660,41 @@ macro_rules! unsigned_subtraction_error {
     }};
 }
 
+macro_rules! impl_checked_op {
+    ($LHS:expr, $RHS:expr, $FUNCTION:ident, $OPERATION:tt) => {
+        // Only covering primitive types that support checked_* operations, and fall back to raw operation for other types.
+        match ($LHS, $RHS) {
+            (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
+                primitive_checked_op!(lhs, rhs, UInt64, $FUNCTION, $OPERATION)
+            },
+            (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => {
+                primitive_checked_op!(lhs, rhs, Int64, $FUNCTION, $OPERATION)
+            },
+            (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => {
+                primitive_checked_op!(lhs, rhs, UInt32, $FUNCTION, $OPERATION)
+            },
+            (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => {
+                primitive_checked_op!(lhs, rhs, Int32, $FUNCTION, $OPERATION)
+            },
+            (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => {
+                primitive_checked_op!(lhs, rhs, UInt16, $FUNCTION, $OPERATION)
+            },
+            (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => {
+                primitive_checked_op!(lhs, rhs, Int16, $FUNCTION, $OPERATION)
+            },
+            (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => {
+                primitive_checked_op!(lhs, rhs, UInt8, $FUNCTION, $OPERATION)
+            },
+            (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => {
+                primitive_checked_op!(lhs, rhs, Int8, $FUNCTION, $OPERATION)
+            },
+            _ => {
+                impl_op!($LHS, $RHS, $OPERATION)
+            }
+        }
+    };
+}
+
 macro_rules! impl_op {
     ($LHS:expr, $RHS:expr, +) => {
         impl_op_arithmetic!($LHS, $RHS, +)
@@ -1855,11 +1925,21 @@ impl ScalarValue {
         impl_op!(self, rhs, +)
     }
 
+    pub fn add_checked<T: Borrow<ScalarValue>>(&self, other: T) -> 
Result<ScalarValue> {
+        let rhs = other.borrow();
+        impl_checked_op!(self, rhs, checked_add, +)
+    }
+
     pub fn sub<T: Borrow<ScalarValue>>(&self, other: T) -> Result<ScalarValue> 
{
         let rhs = other.borrow();
         impl_op!(self, rhs, -)
     }
 
+    pub fn sub_checked<T: Borrow<ScalarValue>>(&self, other: T) -> 
Result<ScalarValue> {
+        let rhs = other.borrow();
+        impl_checked_op!(self, rhs, checked_sub, -)
+    }
+
     pub fn is_unsigned(&self) -> bool {
         matches!(
             self,
@@ -1926,9 +2006,9 @@ impl ScalarValue {
         }
 
         let distance = if self > other {
-            self.sub(other).ok()?
+            self.sub_checked(other).ok()?
         } else {
-            other.sub(self).ok()?
+            other.sub_checked(self).ok()?
         };
 
         match distance {
@@ -3625,8 +3705,10 @@ mod tests {
     use std::cmp::Ordering;
     use std::sync::Arc;
 
+    use arrow::compute;
     use arrow::compute::kernels;
     use arrow::datatypes::ArrowPrimitiveType;
+    use arrow_array::ArrowNumericType;
     use rand::Rng;
 
     use crate::cast::{as_string_array, as_uint32_array, as_uint64_array};
@@ -3664,6 +3746,100 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn scalar_sub_trait_int32_test() -> Result<()> {
+        let int_value = ScalarValue::Int32(Some(42));
+        let int_value_2 = ScalarValue::Int32(Some(100));
+        assert_eq!(int_value.sub(&int_value_2)?, 
ScalarValue::Int32(Some(-58)));
+        assert_eq!(int_value_2.sub(int_value)?, ScalarValue::Int32(Some(58)));
+        Ok(())
+    }
+
+    #[test]
+    fn scalar_sub_trait_int32_overflow_test() -> Result<()> {
+        let int_value = ScalarValue::Int32(Some(i32::MAX));
+        let int_value_2 = ScalarValue::Int32(Some(i32::MIN));
+        assert!(matches!(
+            int_value.sub_checked(&int_value_2),
+            Err(DataFusionError::Execution(msg)) if msg == "Overflow while 
calculating ScalarValue."
+        ));
+        Ok(())
+    }
+
+    #[test]
+    fn scalar_sub_trait_int64_test() -> Result<()> {
+        let int_value = ScalarValue::Int64(Some(42));
+        let int_value_2 = ScalarValue::Int64(Some(100));
+        assert_eq!(int_value.sub(&int_value_2)?, 
ScalarValue::Int64(Some(-58)));
+        assert_eq!(int_value_2.sub(int_value)?, ScalarValue::Int64(Some(58)));
+        Ok(())
+    }
+
+    #[test]
+    fn scalar_sub_trait_int64_overflow_test() -> Result<()> {
+        let int_value = ScalarValue::Int64(Some(i64::MAX));
+        let int_value_2 = ScalarValue::Int64(Some(i64::MIN));
+        assert!(matches!(
+            int_value.sub_checked(&int_value_2),
+            Err(DataFusionError::Execution(msg)) if msg == "Overflow while 
calculating ScalarValue."
+        ));
+        Ok(())
+    }
+
+    #[test]
+    fn scalar_add_overflow_test() -> Result<()> {
+        check_scalar_add_overflow::<Int8Type>(
+            ScalarValue::Int8(Some(i8::MAX)),
+            ScalarValue::Int8(Some(i8::MAX)),
+        );
+        check_scalar_add_overflow::<UInt8Type>(
+            ScalarValue::UInt8(Some(u8::MAX)),
+            ScalarValue::UInt8(Some(u8::MAX)),
+        );
+        check_scalar_add_overflow::<Int16Type>(
+            ScalarValue::Int16(Some(i16::MAX)),
+            ScalarValue::Int16(Some(i16::MAX)),
+        );
+        check_scalar_add_overflow::<UInt16Type>(
+            ScalarValue::UInt16(Some(u16::MAX)),
+            ScalarValue::UInt16(Some(u16::MAX)),
+        );
+        check_scalar_add_overflow::<Int32Type>(
+            ScalarValue::Int32(Some(i32::MAX)),
+            ScalarValue::Int32(Some(i32::MAX)),
+        );
+        check_scalar_add_overflow::<UInt32Type>(
+            ScalarValue::UInt32(Some(u32::MAX)),
+            ScalarValue::UInt32(Some(u32::MAX)),
+        );
+        check_scalar_add_overflow::<Int64Type>(
+            ScalarValue::Int64(Some(i64::MAX)),
+            ScalarValue::Int64(Some(i64::MAX)),
+        );
+        check_scalar_add_overflow::<UInt64Type>(
+            ScalarValue::UInt64(Some(u64::MAX)),
+            ScalarValue::UInt64(Some(u64::MAX)),
+        );
+
+        Ok(())
+    }
+
+    // Verifies that ScalarValue has the same behavior as the compute kernel when it overflows.
+    fn check_scalar_add_overflow<T>(left: ScalarValue, right: ScalarValue)
+    where
+        T: ArrowNumericType,
+    {
+        let scalar_result = left.add_checked(&right);
+
+        let left_array = left.to_array();
+        let right_array = right.to_array();
+        let arrow_left_array = left_array.as_primitive::<T>();
+        let arrow_right_array = right_array.as_primitive::<T>();
+        let arrow_result = compute::add_checked(arrow_left_array, 
arrow_right_array);
+
+        assert_eq!(scalar_result.is_ok(), arrow_result.is_ok());
+    }
+
     #[test]
     fn test_interval_add_timestamp() -> Result<()> {
         let interval = ScalarValue::IntervalMonthDayNano(Some(123));
@@ -5231,6 +5407,11 @@ mod tests {
                 ScalarValue::Decimal128(Some(123), 5, 5),
                 ScalarValue::Decimal128(Some(120), 5, 5),
             ),
+            // Overflows
+            (
+                ScalarValue::Int8(Some(i8::MAX)),
+                ScalarValue::Int8(Some(i8::MIN)),
+            ),
         ];
         for (lhs, rhs) in cases {
             let distance = lhs.distance(&rhs);

Reply via email to