Jefffrey commented on code in PR #19926:
URL: https://github.com/apache/datafusion/pull/19926#discussion_r2741494306
##########
datafusion/functions/src/math/round.rs:
##########
@@ -159,48 +279,122 @@ impl ScalarUDFImpl for RoundFunc {
);
};
- match value_scalar {
- ScalarValue::Float32(Some(v)) => {
+ match (value_scalar, args.return_type()) {
+ (ScalarValue::Float32(Some(v)), _) => {
let rounded = round_float(*v, dp)?;
Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
}
- ScalarValue::Float64(Some(v)) => {
+ (ScalarValue::Float64(Some(v)), _) => {
let rounded = round_float(*v, dp)?;
Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
}
- ScalarValue::Decimal128(Some(v), precision, scale) => {
- let rounded = round_decimal(*v, *scale, dp)?;
+ (
+ ScalarValue::Decimal32(Some(v), _precision, scale),
+ Decimal32(out_precision, out_scale),
+ ) => {
+ let rounded = round_decimal(*v, *scale, *out_scale, dp)?;
+ let rounded = if *out_precision ==
Decimal32Type::MAX_PRECISION
+ && *scale == 0
+ && dp < 0
Review Comment:
Same here: could you add an explanation of these checks?
##########
datafusion/functions/src/math/round.rs:
##########
@@ -25,19 +25,96 @@ use arrow::datatypes::DataType::{
};
use arrow::datatypes::{
ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
- Decimal256Type, Float32Type, Float64Type, Int32Type,
+ Decimal256Type, DecimalType, Float32Type, Float64Type, Int32Type,
};
+use arrow::datatypes::{Field, FieldRef};
use arrow::error::ArrowError;
use datafusion_common::types::{
NativeType, logical_float32, logical_float64, logical_int32,
};
use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
- Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl,
Signature,
- TypeSignature, TypeSignatureClass, Volatility,
+ Coercion, ColumnarValue, Documentation, ReturnFieldArgs,
ScalarFunctionArgs,
+ ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;
+use std::sync::Arc;
+
+fn output_scale_for_decimal(input_scale: i8, decimal_places: i32) ->
Result<i8> {
+ let new_scale = i32::from(input_scale).min(decimal_places.max(0));
+ i8::try_from(new_scale).map_err(|_| {
+ datafusion_common::DataFusionError::Internal(format!(
+ "Computed decimal scale {new_scale} is out of range for i8"
+ ))
+ })
Review Comment:
I feel we could clamp `decimal_places` to the max precision, which is well
within an `i8`. That way we can avoid all these error paths, especially since
such a large `decimal_places` would return 0 (or some other value) instead of erroring.
##########
datafusion/functions/src/math/round.rs:
##########
@@ -25,19 +25,96 @@ use arrow::datatypes::DataType::{
};
use arrow::datatypes::{
ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
- Decimal256Type, Float32Type, Float64Type, Int32Type,
+ Decimal256Type, DecimalType, Float32Type, Float64Type, Int32Type,
};
+use arrow::datatypes::{Field, FieldRef};
use arrow::error::ArrowError;
use datafusion_common::types::{
NativeType, logical_float32, logical_float64, logical_int32,
};
use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
- Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl,
Signature,
- TypeSignature, TypeSignatureClass, Volatility,
+ Coercion, ColumnarValue, Documentation, ReturnFieldArgs,
ScalarFunctionArgs,
+ ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;
+use std::sync::Arc;
+
+fn output_scale_for_decimal(input_scale: i8, decimal_places: i32) ->
Result<i8> {
+ let new_scale = i32::from(input_scale).min(decimal_places.max(0));
+ i8::try_from(new_scale).map_err(|_| {
+ datafusion_common::DataFusionError::Internal(format!(
+ "Computed decimal scale {new_scale} is out of range for i8"
+ ))
+ })
+}
+
+fn validate_decimal_precision<T: DecimalType>(
+ value: T::Native,
+ precision: u8,
+ scale: i8,
+) -> Result<T::Native, ArrowError> {
+ T::validate_decimal_precision(value, precision, scale).map_err(|_| {
+ ArrowError::ComputeError(format!(
+ "Decimal overflow: rounded value exceeds precision {precision}"
+ ))
+ })?;
+ Ok(value)
+}
+
+fn calculate_new_precision_scale<T: DecimalType>(
+ precision: u8,
+ scale: i8,
+ decimal_places: Option<i32>,
+) -> Result<DataType> {
+ if let Some(decimal_places) = decimal_places {
+ let new_scale = output_scale_for_decimal(scale, decimal_places)?;
+ let new_precision = if scale == 0
+ && decimal_places < 0
+ && decimal_places
+ .checked_neg()
+ .map(|abs| abs <= i32::from(precision))
+ .unwrap_or(false)
Review Comment:
This checked negation of `decimal_places` is also something we can avoid if
we clamp it earlier on.
Also, could we add some explanation here for these checks? To me it's not
obvious why we check for `scale == 0` and `abs(decimal_places) <=
precision`.
##########
datafusion/functions/src/math/round.rs:
##########
@@ -228,29 +422,32 @@ fn round_columnar(
value: &ColumnarValue,
decimal_places: &ColumnarValue,
number_rows: usize,
+ return_type: &DataType,
) -> Result<ColumnarValue> {
let value_array = value.to_array(number_rows)?;
let both_scalars = matches!(value, ColumnarValue::Scalar(_))
Review Comment:
I realized this `both_scalars` check is redundant since we have a fast path
above; it can be refactored away in a follow-up PR.
##########
datafusion/sqllogictest/test_files/scalar.slt:
##########
@@ -951,15 +952,51 @@ query TR
select arrow_typeof(round('12345.55'::decimal(10,2), -1)),
round('12345.55'::decimal(10,2), -1);
----
-Decimal128(10, 2) 12350
+Decimal128(10, 0) 12350
+
+# round decimal scale 0 negative places (carry can require extra precision)
+query TR
+select arrow_typeof(round('99'::decimal(2,0), -1)),
+ round('99'::decimal(2,0), -1);
+----
+Decimal128(3, 0) 100
# round decimal256 keeps decimals
query TR
select arrow_typeof(round('1234.5678'::decimal(50,4), 2)),
round('1234.5678'::decimal(50,4), 2);
----
-Decimal256(50, 4) 1234.57
+Decimal256(50, 2) 1234.57
+
+# round decimal with carry-over (reduce scale)
Review Comment:
Do we have tests for decimals with negative scales?
##########
datafusion/functions/src/math/round.rs:
##########
@@ -377,9 +619,45 @@ fn round_decimal<V: ArrowNativeTypeOp>(
})?;
}
- quotient
- .mul_checked(factor)
- .map_err(|_| ArrowError::ComputeError("Overflow while rounding
decimal".into()))
+ // Determine how to scale the result based on output_scale vs computed
scale
+ // computed_scale = max(0, min(input_scale, decimal_places))
+ let computed_scale = if decimal_places >= 0 {
+ let new_scale = i32::from(input_scale).min(decimal_places).max(0);
+ i8::try_from(new_scale).map_err(|_| {
+ ArrowError::ComputeError(format!(
+ "Computed decimal scale {new_scale} is out of range for i8"
+ ))
+ })?
+ } else {
+ 0
+ };
+
+ if output_scale == computed_scale {
+ // scale reduction, return quotient directly (or shifted for negative
dp)
+ if decimal_places >= 0 {
+ Ok(quotient)
+ } else {
+ // For negative decimal_places, multiply by 10^(-decimal_places)
to shift left
+ let neg_dp: u32 = (-decimal_places).try_into().map_err(|_| {
+ ArrowError::ComputeError(format!(
+ "Invalid negative decimal places: {decimal_places}"
Review Comment:
I don't think we can have an invalid `decimal_places`? At least, we shouldn't
throw an error here; what are some examples of invalid values?
##########
datafusion/functions/src/math/round.rs:
##########
@@ -259,13 +456,24 @@ fn round_columnar(
>(
value_array.as_ref(),
decimal_places,
- |v, dp| round_decimal(v, *scale, dp),
+ |v, dp| {
+ let rounded = round_decimal(v, *scale, *new_scale, dp)?;
+ if *precision == Decimal32Type::MAX_PRECISION
+ && (decimal_places_is_array || (*scale == 0 && dp < 0))
+ {
+ validate_decimal_precision::<Decimal32Type>(
+ rounded, *precision, *new_scale,
+ )
+ } else {
+ Ok(rounded)
+ }
Review Comment:
I feel these arms can be deduplicated using the `DecimalType` generic,
potentially together with the scalar fast path above?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]