This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new acadfbf25f Minor: dont panic with bad arguments to round (#10899)
acadfbf25f is described below
commit acadfbf25feec2f2d18322bc03c53b61c1292d9b
Author: tmi <[email protected]>
AuthorDate: Fri Jun 28 00:51:23 2024 +0200
Minor: dont panic with bad arguments to round (#10899)
* Minor: dont panic with bad arguments to round
* Minor: add more safety casts to round func
* Remove panic + add sqllogictest
* clippy
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/functions/src/math/round.rs | 109 ++++++++++++++++++--------
datafusion/sqllogictest/test_files/scalar.slt | 10 +++
2 files changed, 86 insertions(+), 33 deletions(-)
diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs
index 1bab2953e4..71ab7c1b43 100644
--- a/datafusion/functions/src/math/round.rs
+++ b/datafusion/functions/src/math/round.rs
@@ -20,10 +20,13 @@ use std::sync::Arc;
use crate::utils::make_scalar_function;
-use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
+use arrow::array::{ArrayRef, Float32Array, Float64Array, Int32Array};
+use arrow::compute::{cast_with_options, CastOptions};
use arrow::datatypes::DataType;
-use arrow::datatypes::DataType::{Float32, Float64};
-use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
+use arrow::datatypes::DataType::{Float32, Float64, Int32};
+use datafusion_common::{
+ exec_datafusion_err, exec_err, DataFusionError, Result, ScalarValue,
+};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
@@ -114,7 +117,11 @@ pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
match args[0].data_type() {
DataType::Float64 => match decimal_places {
ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) =>
{
- let decimal_places = decimal_places.try_into().unwrap();
+ let decimal_places: i32 =
decimal_places.try_into().map_err(|e| {
+ exec_datafusion_err!(
+ "Invalid value for decimal places: {decimal_places}:
{e}"
+ )
+ })?;
Ok(Arc::new(make_function_scalar_inputs!(
&args[0],
@@ -128,21 +135,30 @@ pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
}
)) as ArrayRef)
}
- ColumnarValue::Array(decimal_places) =>
Ok(Arc::new(make_function_inputs2!(
- &args[0],
- decimal_places,
- "value",
- "decimal_places",
- Float64Array,
- Int64Array,
- {
- |value: f64, decimal_places: i64| {
- (value *
10.0_f64.powi(decimal_places.try_into().unwrap()))
- .round()
- / 10.0_f64.powi(decimal_places.try_into().unwrap())
+ ColumnarValue::Array(decimal_places) => {
+ let options = CastOptions {
+ safe: false, // raise error if the cast is not possible
+ ..Default::default()
+ };
+ let decimal_places = cast_with_options(&decimal_places,
&Int32, &options)
+ .map_err(|e| {
+ exec_datafusion_err!("Invalid values for decimal
places: {e}")
+ })?;
+ Ok(Arc::new(make_function_inputs2!(
+ &args[0],
+ decimal_places,
+ "value",
+ "decimal_places",
+ Float64Array,
+ Int32Array,
+ {
+ |value: f64, decimal_places: i32| {
+ (value * 10.0_f64.powi(decimal_places)).round()
+ / 10.0_f64.powi(decimal_places)
+ }
}
- }
- )) as ArrayRef),
+ )) as ArrayRef)
+ }
_ => {
exec_err!("round function requires a scalar or array for
decimal_places")
}
@@ -150,7 +166,11 @@ pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
DataType::Float32 => match decimal_places {
ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) =>
{
- let decimal_places = decimal_places.try_into().unwrap();
+ let decimal_places: i32 =
decimal_places.try_into().map_err(|e| {
+ exec_datafusion_err!(
+ "Invalid value for decimal places: {decimal_places}:
{e}"
+ )
+ })?;
Ok(Arc::new(make_function_scalar_inputs!(
&args[0],
@@ -164,21 +184,30 @@ pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
}
)) as ArrayRef)
}
- ColumnarValue::Array(decimal_places) =>
Ok(Arc::new(make_function_inputs2!(
- &args[0],
- decimal_places,
- "value",
- "decimal_places",
- Float32Array,
- Int64Array,
- {
- |value: f32, decimal_places: i64| {
- (value *
10.0_f32.powi(decimal_places.try_into().unwrap()))
- .round()
- / 10.0_f32.powi(decimal_places.try_into().unwrap())
+ ColumnarValue::Array(_) => {
+ let ColumnarValue::Array(decimal_places) =
+ decimal_places.cast_to(&Int32, None).map_err(|e| {
+ exec_datafusion_err!("Invalid values for decimal
places: {e}")
+ })?
+ else {
+ panic!("Unexpected result of ColumnarValue::Array.cast")
+ };
+
+ Ok(Arc::new(make_function_inputs2!(
+ &args[0],
+ decimal_places,
+ "value",
+ "decimal_places",
+ Float32Array,
+ Int32Array,
+ {
+ |value: f32, decimal_places: i32| {
+ (value * 10.0_f32.powi(decimal_places)).round()
+ / 10.0_f32.powi(decimal_places)
+ }
}
- }
- )) as ArrayRef),
+ )) as ArrayRef)
+ }
_ => {
exec_err!("round function requires a scalar or array for
decimal_places")
}
@@ -196,6 +225,7 @@ mod test {
use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
use datafusion_common::cast::{as_float32_array, as_float64_array};
+ use datafusion_common::DataFusionError;
#[test]
fn test_round_f32() {
@@ -262,4 +292,17 @@ mod test {
assert_eq!(floats, &expected);
}
+
+ #[test]
+ fn test_round_f32_cast_fail() {
+ let args: Vec<ArrayRef> = vec![
+ Arc::new(Float64Array::from(vec![125.2345])), // input
+ Arc::new(Int64Array::from(vec![2147483648])), // decimal_places
+ ];
+
+ let result = round(&args);
+
+ assert!(result.is_err());
+ assert!(matches!(result, Err(DataFusionError::Execution { .. })));
+ }
}
diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt
index a68d1cc7a7..e152269812 100644
--- a/datafusion/sqllogictest/test_files/scalar.slt
+++ b/datafusion/sqllogictest/test_files/scalar.slt
@@ -823,6 +823,16 @@ select round(a), round(b), round(c) from small_floats;
0 0 1
1 0 0
+# round with too large
+# max Int32 is 2147483647
+query error DataFusion error: Execution error: Invalid values for decimal places: Cast error: Can't cast value 2147483648 to type Int32
+select round(3.14, 2147483648);
+
+# with array
+query error DataFusion error: Execution error: Invalid values for decimal places: Cast error: Can't cast value 2147483649 to type Int32
+select round(column1, column2) from values (3.14, 2), (3.14, 3), (3.14, 2147483649);
+
+
## signum
# signum scalar function
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]