This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new acadfbf25f Minor: dont panic with bad arguments to round (#10899)
acadfbf25f is described below

commit acadfbf25feec2f2d18322bc03c53b61c1292d9b
Author: tmi <[email protected]>
AuthorDate: Fri Jun 28 00:51:23 2024 +0200

    Minor: dont panic with bad arguments to round (#10899)
    
    * Minor: dont panic with bad arguments to round
    
    * Minor: add more safety casts to round func
    
    * Remove panic + add sqllogictest
    
    * clippy
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/functions/src/math/round.rs        | 109 ++++++++++++++++++--------
 datafusion/sqllogictest/test_files/scalar.slt |  10 +++
 2 files changed, 86 insertions(+), 33 deletions(-)

diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs
index 1bab2953e4..71ab7c1b43 100644
--- a/datafusion/functions/src/math/round.rs
+++ b/datafusion/functions/src/math/round.rs
@@ -20,10 +20,13 @@ use std::sync::Arc;
 
 use crate::utils::make_scalar_function;
 
-use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
+use arrow::array::{ArrayRef, Float32Array, Float64Array, Int32Array};
+use arrow::compute::{cast_with_options, CastOptions};
 use arrow::datatypes::DataType;
-use arrow::datatypes::DataType::{Float32, Float64};
-use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
+use arrow::datatypes::DataType::{Float32, Float64, Int32};
+use datafusion_common::{
+    exec_datafusion_err, exec_err, DataFusionError, Result, ScalarValue,
+};
 use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
 use datafusion_expr::TypeSignature::Exact;
 use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
@@ -114,7 +117,11 @@ pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
     match args[0].data_type() {
         DataType::Float64 => match decimal_places {
             ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => {
-                let decimal_places = decimal_places.try_into().unwrap();
+                let decimal_places: i32 = decimal_places.try_into().map_err(|e| {
+                    exec_datafusion_err!(
+                        "Invalid value for decimal places: {decimal_places}: {e}"
+                    )
+                })?;
 
                 Ok(Arc::new(make_function_scalar_inputs!(
                     &args[0],
@@ -128,21 +135,30 @@ pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
                     }
                 )) as ArrayRef)
             }
-            ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!(
-                &args[0],
-                decimal_places,
-                "value",
-                "decimal_places",
-                Float64Array,
-                Int64Array,
-                {
-                    |value: f64, decimal_places: i64| {
-                        (value * 10.0_f64.powi(decimal_places.try_into().unwrap()))
-                            .round()
-                            / 10.0_f64.powi(decimal_places.try_into().unwrap())
+            ColumnarValue::Array(decimal_places) => {
+                let options = CastOptions {
+                    safe: false, // raise error if the cast is not possible
+                    ..Default::default()
+                };
+                let decimal_places = cast_with_options(&decimal_places, &Int32, &options)
+                    .map_err(|e| {
+                        exec_datafusion_err!("Invalid values for decimal places: {e}")
+                    })?;
+                Ok(Arc::new(make_function_inputs2!(
+                    &args[0],
+                    decimal_places,
+                    "value",
+                    "decimal_places",
+                    Float64Array,
+                    Int32Array,
+                    {
+                        |value: f64, decimal_places: i32| {
+                            (value * 10.0_f64.powi(decimal_places)).round()
+                                / 10.0_f64.powi(decimal_places)
+                        }
                     }
-                }
-            )) as ArrayRef),
+                )) as ArrayRef)
+            }
             _ => {
                 exec_err!("round function requires a scalar or array for decimal_places")
             }
@@ -150,7 +166,11 @@ pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
 
         DataType::Float32 => match decimal_places {
             ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => {
-                let decimal_places = decimal_places.try_into().unwrap();
+                let decimal_places: i32 = decimal_places.try_into().map_err(|e| {
+                    exec_datafusion_err!(
+                        "Invalid value for decimal places: {decimal_places}: {e}"
+                    )
+                })?;
 
                 Ok(Arc::new(make_function_scalar_inputs!(
                     &args[0],
@@ -164,21 +184,30 @@ pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
                     }
                 )) as ArrayRef)
             }
-            ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!(
-                &args[0],
-                decimal_places,
-                "value",
-                "decimal_places",
-                Float32Array,
-                Int64Array,
-                {
-                    |value: f32, decimal_places: i64| {
-                        (value * 10.0_f32.powi(decimal_places.try_into().unwrap()))
-                            .round()
-                            / 10.0_f32.powi(decimal_places.try_into().unwrap())
+            ColumnarValue::Array(_) => {
+                let ColumnarValue::Array(decimal_places) =
+                    decimal_places.cast_to(&Int32, None).map_err(|e| {
+                        exec_datafusion_err!("Invalid values for decimal places: {e}")
+                    })?
+                else {
+                    panic!("Unexpected result of ColumnarValue::Array.cast")
+                };
+
+                Ok(Arc::new(make_function_inputs2!(
+                    &args[0],
+                    decimal_places,
+                    "value",
+                    "decimal_places",
+                    Float32Array,
+                    Int32Array,
+                    {
+                        |value: f32, decimal_places: i32| {
+                            (value * 10.0_f32.powi(decimal_places)).round()
+                                / 10.0_f32.powi(decimal_places)
+                        }
                     }
-                }
-            )) as ArrayRef),
+                )) as ArrayRef)
+            }
             _ => {
                 exec_err!("round function requires a scalar or array for decimal_places")
             }
@@ -196,6 +225,7 @@ mod test {
 
     use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
     use datafusion_common::cast::{as_float32_array, as_float64_array};
+    use datafusion_common::DataFusionError;
 
     #[test]
     fn test_round_f32() {
@@ -262,4 +292,17 @@ mod test {
 
         assert_eq!(floats, &expected);
     }
+
+    #[test]
+    fn test_round_f32_cast_fail() {
+        let args: Vec<ArrayRef> = vec![
+            Arc::new(Float64Array::from(vec![125.2345])), // input
+            Arc::new(Int64Array::from(vec![2147483648])), // decimal_places
+        ];
+
+        let result = round(&args);
+
+        assert!(result.is_err());
+        assert!(matches!(result, Err(DataFusionError::Execution { .. })));
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt
index a68d1cc7a7..e152269812 100644
--- a/datafusion/sqllogictest/test_files/scalar.slt
+++ b/datafusion/sqllogictest/test_files/scalar.slt
@@ -823,6 +823,16 @@ select round(a), round(b), round(c) from small_floats;
 0 0 1
 1 0 0
 
+# round with too large
+#  max Int32 is 2147483647
+query error DataFusion error: Execution error: Invalid values for decimal places: Cast error: Can't cast value 2147483648 to type Int32
+select round(3.14, 2147483648);
+
+# with array
+query error DataFusion error: Execution error: Invalid values for decimal places: Cast error: Can't cast value 2147483649 to type Int32
+select round(column1, column2) from values (3.14, 2), (3.14, 3), (3.14, 2147483649);
+
+
 ## signum
 
 # signum scalar function


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to