This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new c63ca330d7 fix: increase ROUND decimal precision to prevent overflow
truncation (#19926)
c63ca330d7 is described below
commit c63ca330d7ae5a2e9e053c5fdf017b500b3e1c21
Author: Kumar Ujjawal <[email protected]>
AuthorDate: Fri Feb 27 19:40:20 2026 +0530
fix: increase ROUND decimal precision to prevent overflow truncation
(#19926)
## Which issue does this PR close?
<!--
We generally require a GitHub issue to be filed for all bug fixes and
enhancements and this helps us generate change logs for our releases.
You can link an issue to this PR using the GitHub syntax. For example
`Closes #123` indicates that this PR will close issue #123.
-->
- Closes #19921.
## Rationale for this change
`SELECT ROUND('999.9'::DECIMAL(4,1))` incorrectly returned `100.0`
instead of `1000`.
When rounding a decimal value causes a carry-over that increases the
number of digits (e.g., 999.9 → 1000.0), the result overflows the
original precision constraint. The overflow was silently truncated
during display, producing incorrect results.
<!--
Why are you proposing this change? If this is already explained clearly
in the issue then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand
your changes and offer better suggestions for fixes.
-->
## What changes are included in this PR?
- Fixes ROUND on DECIMAL so carry-over doesn’t silently truncate/produce
wrong results.
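- When decimal_places is a constant (including omitted/NULL), ROUND now
reduces the output scale to min(input_scale, max(decimal_places, 0))
(Spark/DuckDB-style), reclaiming precision for the integer part. For
inputs with a negative scale, the output scale is instead
min(input_scale, decimal_places), floored at -precision. For scale-0
inputs rounded to a negative decimal_places that lies within the
precision, the precision also grows by 1 (capped at the type's maximum)
to make room for a carry (e.g. 99 -> 100).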
- When decimal_places is not a constant (e.g. a column/array), ROUND
keeps the original scale and may increase precision by 1 (capped); if
precision is already max and the rounded
value can’t fit, it returns an overflow error instead of a wrong value.
- Adds/updates sqllogictests for these behaviors and edge cases.
<!--
There is no need to duplicate the description in the issue here but it
is sometimes worth providing a summary of the individual changes in this
PR.
-->
## Are these changes tested?
- Added new sqllogictest case for the specific bug scenario
- Updated 2 existing tests with new expected precision values
- All existing tests pass
<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code
If tests are not included in your PR, please explain why (for example,
are they covered by existing tests)?
-->
## Are there any user-facing changes?
| Aspect | Before | After |
| --- | --- | --- |
| ROUND('999.9'::DECIMAL(4,1)) | 100.0 (wrong) | 1000 (correct) |
| Return type (default/literal dp) | Decimal128(4, 1) | Decimal128(4, 0) |
- Return type for DECIMAL inputs can change: with literal dp it
generally reduces scale; with non-literal dp it keeps scale and may
increase precision by 1.
- New error case: when precision is already max and dp is non-literal,
ROUND may now error on overflow rather than return a truncated/wrong
decimal.
<!--
If there are user-facing changes then we may require documentation to be
updated before approving the PR.
-->
<!--
If there are any breaking changes to public APIs, please add the `api
change` label.
-->
---------
Co-authored-by: Martin Grigorov <[email protected]>
---
datafusion/functions/src/math/round.rs | 472 +++++++++++++++++++++----
datafusion/sqllogictest/test_files/decimal.slt | 2 +-
datafusion/sqllogictest/test_files/scalar.slt | 68 +++-
3 files changed, 477 insertions(+), 65 deletions(-)
diff --git a/datafusion/functions/src/math/round.rs
b/datafusion/functions/src/math/round.rs
index 8c25c57740..07cddf9341 100644
--- a/datafusion/functions/src/math/round.rs
+++ b/datafusion/functions/src/math/round.rs
@@ -25,8 +25,9 @@ use arrow::datatypes::DataType::{
};
use arrow::datatypes::{
ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
- Decimal256Type, Float32Type, Float64Type, Int32Type,
+ Decimal256Type, DecimalType, Float32Type, Float64Type, Int32Type,
};
+use arrow::datatypes::{Field, FieldRef};
use arrow::error::ArrowError;
use datafusion_common::types::{
NativeType, logical_float32, logical_float64, logical_int32,
@@ -34,10 +35,120 @@ use datafusion_common::types::{
use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
- Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl,
Signature,
- TypeSignature, TypeSignatureClass, Volatility,
+ Coercion, ColumnarValue, Documentation, ReturnFieldArgs,
ScalarFunctionArgs,
+ ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;
+use std::sync::Arc;
+
+fn output_scale_for_decimal(precision: u8, input_scale: i8, decimal_places:
i32) -> i8 {
+ // `decimal_places` controls the maximum output scale, but scale cannot
exceed the input scale.
+ //
+ // For negative-scale decimals, allow further scale reduction to match
negative `decimal_places`
+ // (e.g. scale -2 rounded to -3 becomes scale -3). This preserves fixed
precision by
+ // representing the rounded result at a coarser scale.
+ if input_scale < 0 {
+ // Decimal scales must be within [-precision, precision] and fit in
i8. For negative-scale
+ // decimals, allow rounding to move the output scale further negative,
but cap it at
+ // `-precision` (beyond that, the rounded result is always 0).
+ let min_scale = -i32::from(precision);
+ let new_scale =
i32::from(input_scale).min(decimal_places).max(min_scale);
+ return new_scale as i8;
+ }
+
+ // The `min` ensures the result is always within i8 range because
`input_scale` is i8.
+ let decimal_places = decimal_places.max(0);
+ i32::from(input_scale).min(decimal_places) as i8
+}
+
+fn normalize_decimal_places_for_decimal(
+ decimal_places: i32,
+ precision: u8,
+ scale: i8,
+) -> Option<i32> {
+ if decimal_places >= 0 {
+ return Some(decimal_places);
+ }
+
+ // For fixed precision decimals, the absolute value is strictly less than
10^(precision - scale).
+ // If the rounding position is beyond that (abs(decimal_places) >
precision - scale), the
+ // rounded result is always 0, and we can avoid overflow in intermediate
10^n computations.
+ let max_rounding_pow10 = i64::from(precision) - i64::from(scale);
+ if max_rounding_pow10 <= 0 {
+ return None;
+ }
+
+ let abs_decimal_places = i64::from(decimal_places.unsigned_abs());
+ (abs_decimal_places <= max_rounding_pow10).then_some(decimal_places)
+}
+
+fn validate_decimal_precision<T: DecimalType>(
+ value: T::Native,
+ precision: u8,
+ scale: i8,
+) -> Result<T::Native, ArrowError> {
+ T::validate_decimal_precision(value, precision, scale).map_err(|e| {
+ ArrowError::ComputeError(format!(
+ "Decimal overflow: rounded value exceeds precision {precision}:
{e}"
+ ))
+ })?;
+ Ok(value)
+}
+
+fn calculate_new_precision_scale<T: DecimalType>(
+ precision: u8,
+ scale: i8,
+ decimal_places: Option<i32>,
+) -> Result<DataType> {
+ if let Some(decimal_places) = decimal_places {
+ let new_scale = output_scale_for_decimal(precision, scale,
decimal_places);
+
+ // When rounding an integer decimal (scale == 0) to a negative
`decimal_places`, a carry can
+ // add an extra digit to the integer part (e.g. 99 -> 100 when
rounding to -1). This can
+ // only happen when the rounding position is within the existing
precision.
+ let abs_decimal_places = decimal_places.unsigned_abs();
+ let new_precision = if scale == 0
+ && decimal_places < 0
+ && abs_decimal_places <= u32::from(precision)
+ {
+ precision.saturating_add(1).min(T::MAX_PRECISION)
+ } else {
+ precision
+ };
+ Ok(T::TYPE_CONSTRUCTOR(new_precision, new_scale))
+ } else {
+ let new_precision = precision.saturating_add(1).min(T::MAX_PRECISION);
+ Ok(T::TYPE_CONSTRUCTOR(new_precision, scale))
+ }
+}
+
+fn decimal_places_from_scalar(scalar: &ScalarValue) -> Result<i32> {
+ let out_of_range = |value: String| {
+ datafusion_common::DataFusionError::Execution(format!(
+ "round decimal_places {value} is out of supported i32 range"
+ ))
+ };
+ match scalar {
+ ScalarValue::Int8(Some(v)) => Ok(i32::from(*v)),
+ ScalarValue::Int16(Some(v)) => Ok(i32::from(*v)),
+ ScalarValue::Int32(Some(v)) => Ok(*v),
+ ScalarValue::Int64(Some(v)) => {
+ i32::try_from(*v).map_err(|_| out_of_range(v.to_string()))
+ }
+ ScalarValue::UInt8(Some(v)) => Ok(i32::from(*v)),
+ ScalarValue::UInt16(Some(v)) => Ok(i32::from(*v)),
+ ScalarValue::UInt32(Some(v)) => {
+ i32::try_from(*v).map_err(|_| out_of_range(v.to_string()))
+ }
+ ScalarValue::UInt64(Some(v)) => {
+ i32::try_from(*v).map_err(|_| out_of_range(v.to_string()))
+ }
+ other => exec_err!(
+ "Unexpected datatype for decimal_places: {}",
+ other.data_type()
+ ),
+ }
+}
#[user_doc(
doc_section(label = "Math Functions"),
@@ -117,15 +228,59 @@ impl ScalarUDFImpl for RoundFunc {
&self.signature
}
- fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
- Ok(match arg_types[0].clone() {
- Float32 => Float32,
- dt @ Decimal128(_, _)
- | dt @ Decimal256(_, _)
- | dt @ Decimal32(_, _)
- | dt @ Decimal64(_, _) => dt,
- _ => Float64,
- })
+ fn return_field_from_args(&self, args: ReturnFieldArgs) ->
Result<FieldRef> {
+ let input_field = &args.arg_fields[0];
+ let input_type = input_field.data_type();
+
+ // If decimal_places is a scalar literal, we can incorporate it into
the output type
+ // (scale reduction). Otherwise, keep the input scale as we can't pick
a per-row scale.
+ //
+ // Note: `scalar_arguments` contains the original literal values
(pre-coercion), so
+ // integer literals may appear as Int64 even though the signature
coerces them to Int32.
+ let decimal_places: Option<i32> = match args.scalar_arguments.get(1) {
+ None => Some(0), // No dp argument means default to 0
+ Some(None) => None, // dp is not a literal (e.g. column)
+ Some(Some(scalar)) if scalar.is_null() => Some(0), // null dp =>
default to 0
+ Some(Some(scalar)) => Some(decimal_places_from_scalar(scalar)?),
+ };
+
+ // Calculate return type based on input type
+ // For decimals: reduce scale to decimal_places (reclaims precision
for integer part)
+ // This matches Spark/DuckDB behavior where ROUND adjusts the scale
+ // BUT only if dp is a scalar literal - otherwise keep original scale
and add
+ // extra precision to accommodate potential carry-over.
+ let return_type =
+ match input_type {
+ Float32 => Float32,
+ Decimal32(precision, scale) => calculate_new_precision_scale::<
+ Decimal32Type,
+ >(
+ *precision, *scale, decimal_places
+ )?,
+ Decimal64(precision, scale) => calculate_new_precision_scale::<
+ Decimal64Type,
+ >(
+ *precision, *scale, decimal_places
+ )?,
+ Decimal128(precision, scale) =>
calculate_new_precision_scale::<
+ Decimal128Type,
+ >(
+ *precision, *scale, decimal_places
+ )?,
+ Decimal256(precision, scale) =>
calculate_new_precision_scale::<
+ Decimal256Type,
+ >(
+ *precision, *scale, decimal_places
+ )?,
+ _ => Float64,
+ };
+
+ let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
+ Ok(Arc::new(Field::new(self.name(), return_type, nullable)))
+ }
+
+ fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+ internal_err!("use return_field_from_args instead")
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) ->
Result<ColumnarValue> {
@@ -141,7 +296,6 @@ impl ScalarUDFImpl for RoundFunc {
&default_decimal_places
};
- // Scalar fast path for float and decimal types - avoid array
conversion overhead
if let (ColumnarValue::Scalar(value_scalar),
ColumnarValue::Scalar(dp_scalar)) =
(&args.args[0], decimal_places)
{
@@ -159,48 +313,132 @@ impl ScalarUDFImpl for RoundFunc {
);
};
- match value_scalar {
- ScalarValue::Float32(Some(v)) => {
+ match (value_scalar, args.return_type()) {
+ (ScalarValue::Float32(Some(v)), _) => {
let rounded = round_float(*v, dp)?;
Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
}
- ScalarValue::Float64(Some(v)) => {
+ (ScalarValue::Float64(Some(v)), _) => {
let rounded = round_float(*v, dp)?;
Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
}
- ScalarValue::Decimal128(Some(v), precision, scale) => {
- let rounded = round_decimal(*v, *scale, dp)?;
+ (
+ ScalarValue::Decimal32(Some(v), in_precision, scale),
+ Decimal32(out_precision, out_scale),
+ ) => {
+ let rounded =
+ round_decimal_or_zero(*v, *in_precision, *scale,
*out_scale, dp)?;
+ let rounded = if *out_precision ==
Decimal32Type::MAX_PRECISION
+ && *scale == 0
+ && dp < 0
+ {
+ // With scale == 0 and negative dp, rounding can carry
into an additional
+ // digit (e.g. 99 -> 100). If we're already at max
precision we can't widen
+ // the type, so validate and error rather than
producing an invalid decimal.
+ validate_decimal_precision::<Decimal32Type>(
+ rounded,
+ *out_precision,
+ *out_scale,
+ )
+ } else {
+ Ok(rounded)
+ }?;
let scalar =
- ScalarValue::Decimal128(Some(rounded), *precision,
*scale);
+ ScalarValue::Decimal32(Some(rounded), *out_precision,
*out_scale);
Ok(ColumnarValue::Scalar(scalar))
}
- ScalarValue::Decimal256(Some(v), precision, scale) => {
- let rounded = round_decimal(*v, *scale, dp)?;
+ (
+ ScalarValue::Decimal64(Some(v), in_precision, scale),
+ Decimal64(out_precision, out_scale),
+ ) => {
+ let rounded =
+ round_decimal_or_zero(*v, *in_precision, *scale,
*out_scale, dp)?;
+ let rounded = if *out_precision ==
Decimal64Type::MAX_PRECISION
+ && *scale == 0
+ && dp < 0
+ {
+ // See Decimal32 branch for details.
+ validate_decimal_precision::<Decimal64Type>(
+ rounded,
+ *out_precision,
+ *out_scale,
+ )
+ } else {
+ Ok(rounded)
+ }?;
let scalar =
- ScalarValue::Decimal256(Some(rounded), *precision,
*scale);
+ ScalarValue::Decimal64(Some(rounded), *out_precision,
*out_scale);
Ok(ColumnarValue::Scalar(scalar))
}
- ScalarValue::Decimal64(Some(v), precision, scale) => {
- let rounded = round_decimal(*v, *scale, dp)?;
- let scalar =
- ScalarValue::Decimal64(Some(rounded), *precision,
*scale);
+ (
+ ScalarValue::Decimal128(Some(v), in_precision, scale),
+ Decimal128(out_precision, out_scale),
+ ) => {
+ let rounded =
+ round_decimal_or_zero(*v, *in_precision, *scale,
*out_scale, dp)?;
+ let rounded = if *out_precision ==
Decimal128Type::MAX_PRECISION
+ && *scale == 0
+ && dp < 0
+ {
+ // See Decimal32 branch for details.
+ validate_decimal_precision::<Decimal128Type>(
+ rounded,
+ *out_precision,
+ *out_scale,
+ )
+ } else {
+ Ok(rounded)
+ }?;
+ let scalar = ScalarValue::Decimal128(
+ Some(rounded),
+ *out_precision,
+ *out_scale,
+ );
Ok(ColumnarValue::Scalar(scalar))
}
- ScalarValue::Decimal32(Some(v), precision, scale) => {
- let rounded = round_decimal(*v, *scale, dp)?;
- let scalar =
- ScalarValue::Decimal32(Some(rounded), *precision,
*scale);
+ (
+ ScalarValue::Decimal256(Some(v), in_precision, scale),
+ Decimal256(out_precision, out_scale),
+ ) => {
+ let rounded =
+ round_decimal_or_zero(*v, *in_precision, *scale,
*out_scale, dp)?;
+ let rounded = if *out_precision ==
Decimal256Type::MAX_PRECISION
+ && *scale == 0
+ && dp < 0
+ {
+ // See Decimal32 branch for details.
+ validate_decimal_precision::<Decimal256Type>(
+ rounded,
+ *out_precision,
+ *out_scale,
+ )
+ } else {
+ Ok(rounded)
+ }?;
+ let scalar = ScalarValue::Decimal256(
+ Some(rounded),
+ *out_precision,
+ *out_scale,
+ );
Ok(ColumnarValue::Scalar(scalar))
}
- _ => {
+ (ScalarValue::Null, _) =>
ColumnarValue::Scalar(ScalarValue::Null)
+ .cast_to(args.return_type(), None),
+ (value_scalar, return_type) => {
internal_err!(
- "Unexpected datatype for value: {}",
- value_scalar.data_type()
+ "Unexpected datatype for round(value, decimal_places):
value {}, return type {}",
+ value_scalar.data_type(),
+ return_type
)
}
}
} else {
- round_columnar(&args.args[0], decimal_places, args.number_rows)
+ round_columnar(
+ &args.args[0],
+ decimal_places,
+ args.number_rows,
+ args.return_type(),
+ )
}
}
@@ -228,13 +466,15 @@ fn round_columnar(
value: &ColumnarValue,
decimal_places: &ColumnarValue,
number_rows: usize,
+ return_type: &DataType,
) -> Result<ColumnarValue> {
let value_array = value.to_array(number_rows)?;
let both_scalars = matches!(value, ColumnarValue::Scalar(_))
&& matches!(decimal_places, ColumnarValue::Scalar(_));
+ let decimal_places_is_array = matches!(decimal_places,
ColumnarValue::Array(_));
- let arr: ArrayRef = match value_array.data_type() {
- Float64 => {
+ let arr: ArrayRef = match (value_array.data_type(), return_type) {
+ (Float64, _) => {
let result = calculate_binary_math::<Float64Type, Int32Type,
Float64Type, _>(
value_array.as_ref(),
decimal_places,
@@ -242,7 +482,7 @@ fn round_columnar(
)?;
result as _
}
- Float32 => {
+ (Float32, _) => {
let result = calculate_binary_math::<Float32Type, Int32Type,
Float32Type, _>(
value_array.as_ref(),
decimal_places,
@@ -250,7 +490,8 @@ fn round_columnar(
)?;
result as _
}
- Decimal32(precision, scale) => {
+ (Decimal32(input_precision, scale), Decimal32(precision, new_scale))
=> {
+ // reduce scale to reclaim integer precision
let result = calculate_binary_decimal_math::<
Decimal32Type,
Int32Type,
@@ -259,13 +500,34 @@ fn round_columnar(
>(
value_array.as_ref(),
decimal_places,
- |v, dp| round_decimal(v, *scale, dp),
+ |v, dp| {
+ let rounded = round_decimal_or_zero(
+ v,
+ *input_precision,
+ *scale,
+ *new_scale,
+ dp,
+ )?;
+ if *precision == Decimal32Type::MAX_PRECISION
+ && (decimal_places_is_array || (*scale == 0 && dp < 0))
+ {
+ // If we're already at max precision, we can't widen
the result type. For
+ // dp arrays, or for scale == 0 with negative dp,
rounding can overflow the
+ // fixed-precision type. Validate per-row and return
an error instead of
+ // producing an invalid decimal that Arrow may display
incorrectly.
+ validate_decimal_precision::<Decimal32Type>(
+ rounded, *precision, *new_scale,
+ )
+ } else {
+ Ok(rounded)
+ }
+ },
*precision,
- *scale,
+ *new_scale,
)?;
result as _
}
- Decimal64(precision, scale) => {
+ (Decimal64(input_precision, scale), Decimal64(precision, new_scale))
=> {
let result = calculate_binary_decimal_math::<
Decimal64Type,
Int32Type,
@@ -274,13 +536,31 @@ fn round_columnar(
>(
value_array.as_ref(),
decimal_places,
- |v, dp| round_decimal(v, *scale, dp),
+ |v, dp| {
+ let rounded = round_decimal_or_zero(
+ v,
+ *input_precision,
+ *scale,
+ *new_scale,
+ dp,
+ )?;
+ if *precision == Decimal64Type::MAX_PRECISION
+ && (decimal_places_is_array || (*scale == 0 && dp < 0))
+ {
+ // See Decimal32 branch for details.
+ validate_decimal_precision::<Decimal64Type>(
+ rounded, *precision, *new_scale,
+ )
+ } else {
+ Ok(rounded)
+ }
+ },
*precision,
- *scale,
+ *new_scale,
)?;
result as _
}
- Decimal128(precision, scale) => {
+ (Decimal128(input_precision, scale), Decimal128(precision, new_scale))
=> {
let result = calculate_binary_decimal_math::<
Decimal128Type,
Int32Type,
@@ -289,13 +569,31 @@ fn round_columnar(
>(
value_array.as_ref(),
decimal_places,
- |v, dp| round_decimal(v, *scale, dp),
+ |v, dp| {
+ let rounded = round_decimal_or_zero(
+ v,
+ *input_precision,
+ *scale,
+ *new_scale,
+ dp,
+ )?;
+ if *precision == Decimal128Type::MAX_PRECISION
+ && (decimal_places_is_array || (*scale == 0 && dp < 0))
+ {
+ // See Decimal32 branch for details.
+ validate_decimal_precision::<Decimal128Type>(
+ rounded, *precision, *new_scale,
+ )
+ } else {
+ Ok(rounded)
+ }
+ },
*precision,
- *scale,
+ *new_scale,
)?;
result as _
}
- Decimal256(precision, scale) => {
+ (Decimal256(input_precision, scale), Decimal256(precision, new_scale))
=> {
let result = calculate_binary_decimal_math::<
Decimal256Type,
Int32Type,
@@ -304,13 +602,31 @@ fn round_columnar(
>(
value_array.as_ref(),
decimal_places,
- |v, dp| round_decimal(v, *scale, dp),
+ |v, dp| {
+ let rounded = round_decimal_or_zero(
+ v,
+ *input_precision,
+ *scale,
+ *new_scale,
+ dp,
+ )?;
+ if *precision == Decimal256Type::MAX_PRECISION
+ && (decimal_places_is_array || (*scale == 0 && dp < 0))
+ {
+ // See Decimal32 branch for details.
+ validate_decimal_precision::<Decimal256Type>(
+ rounded, *precision, *new_scale,
+ )
+ } else {
+ Ok(rounded)
+ }
+ },
*precision,
- *scale,
+ *new_scale,
)?;
result as _
}
- other => exec_err!("Unsupported data type {other:?} for function
round")?,
+ (other, _) => exec_err!("Unsupported data type {other:?} for function
round")?,
};
if both_scalars {
@@ -334,19 +650,17 @@ where
fn round_decimal<V: ArrowNativeTypeOp>(
value: V,
- scale: i8,
+ input_scale: i8,
+ output_scale: i8,
decimal_places: i32,
) -> Result<V, ArrowError> {
- let diff = i64::from(scale) - i64::from(decimal_places);
+ let diff = i64::from(input_scale) - i64::from(decimal_places);
if diff <= 0 {
return Ok(value);
}
- let diff: u32 = diff.try_into().map_err(|e| {
- ArrowError::ComputeError(format!(
- "Invalid value for decimal places: {decimal_places}: {e}"
- ))
- })?;
+ debug_assert!(diff <= i64::from(u32::MAX));
+ let diff = diff as u32;
let one = V::ONE;
let two = V::from_usize(2).ok_or_else(|| {
@@ -358,7 +672,7 @@ fn round_decimal<V: ArrowNativeTypeOp>(
let factor = ten.pow_checked(diff).map_err(|_| {
ArrowError::ComputeError(format!(
- "Overflow while rounding decimal with scale {scale} and decimal
places {decimal_places}"
+ "Overflow while rounding decimal with scale {input_scale} and
decimal places {decimal_places}"
))
})?;
@@ -377,11 +691,44 @@ fn round_decimal<V: ArrowNativeTypeOp>(
})?;
}
+ // `quotient` is the rounded value at scale `decimal_places`. Rescale to
the desired
+ // `output_scale` (which is always >= `decimal_places` in cases where diff
> 0).
+ let scale_shift = i64::from(output_scale) - i64::from(decimal_places);
+ if scale_shift == 0 {
+ return Ok(quotient);
+ }
+
+ debug_assert!(scale_shift > 0);
+ debug_assert!(scale_shift <= i64::from(u32::MAX));
+ let scale_shift = scale_shift as u32;
+ let shift_factor = ten.pow_checked(scale_shift).map_err(|_| {
+ ArrowError::ComputeError(format!(
+ "Overflow while rounding decimal with scale {input_scale} and
decimal places {decimal_places}"
+ ))
+ })?;
quotient
- .mul_checked(factor)
+ .mul_checked(shift_factor)
.map_err(|_| ArrowError::ComputeError("Overflow while rounding
decimal".into()))
}
+fn round_decimal_or_zero<V: ArrowNativeTypeOp>(
+ value: V,
+ precision: u8,
+ input_scale: i8,
+ output_scale: i8,
+ decimal_places: i32,
+) -> Result<V, ArrowError> {
+ if let Some(dp) =
+ normalize_decimal_places_for_decimal(decimal_places, precision,
input_scale)
+ {
+ round_decimal(value, input_scale, output_scale, dp)
+ } else {
+ V::from_usize(0).ok_or_else(|| {
+ ArrowError::ComputeError("Internal error: could not create
constant 0".into())
+ })
+ }
+}
+
#[cfg(test)]
mod test {
use std::sync::Arc;
@@ -397,12 +744,17 @@ mod test {
decimal_places: Option<ArrayRef>,
) -> Result<ArrayRef, DataFusionError> {
let number_rows = value.len();
+ // NOTE: For decimal inputs, the actual ROUND return type can differ
from the
+ // input type (scale reduction for literal `decimal_places`). These
unit tests
+ // only exercise Float32/Float64 behavior.
+ let return_type = value.data_type().clone();
let value = ColumnarValue::Array(value);
let decimal_places = decimal_places
.map(ColumnarValue::Array)
.unwrap_or_else(||
ColumnarValue::Scalar(ScalarValue::Int32(Some(0))));
- let result = super::round_columnar(&value, &decimal_places,
number_rows)?;
+ let result =
+ super::round_columnar(&value, &decimal_places, number_rows,
&return_type)?;
match result {
ColumnarValue::Array(array) => Ok(array),
ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1),
diff --git a/datafusion/sqllogictest/test_files/decimal.slt
b/datafusion/sqllogictest/test_files/decimal.slt
index f53f493929..eca2c88bb5 100644
--- a/datafusion/sqllogictest/test_files/decimal.slt
+++ b/datafusion/sqllogictest/test_files/decimal.slt
@@ -782,7 +782,7 @@ query TR
select arrow_typeof(round(173975140545.855, 2)),
round(173975140545.855, 2);
----
-Decimal128(15, 3) 173975140545.86
+Decimal128(15, 2) 173975140545.86
# smoke test for decimal parsing
query RT
diff --git a/datafusion/sqllogictest/test_files/scalar.slt
b/datafusion/sqllogictest/test_files/scalar.slt
index 7a4a81b5fa..681540a29d 100644
--- a/datafusion/sqllogictest/test_files/scalar.slt
+++ b/datafusion/sqllogictest/test_files/scalar.slt
@@ -923,7 +923,7 @@ select round(a), round(b), round(c) from small_floats;
# round with too large
# max Int32 is 2147483647
-query error Arrow error: Cast error: Can't cast value 2147483648 to type Int32
+query error round decimal_places 2147483648 is out of supported i32 range
select round(3.14, 2147483648);
# with array
@@ -931,11 +931,12 @@ query error Arrow error: Cast error: Can't cast value
2147483649 to type Int32
select round(column1, column2) from values (3.14, 2), (3.14, 3), (3.14,
2147483649);
# round decimal should not cast to float
+# scale reduces to match decimal_places
query TR
select arrow_typeof(round('173975140545.855'::decimal(38,10), 2)),
round('173975140545.855'::decimal(38,10), 2);
----
-Decimal128(38, 10) 173975140545.86
+Decimal128(38, 2) 173975140545.86
# round decimal ties away from zero
query RRRR
@@ -951,15 +952,74 @@ query TR
select arrow_typeof(round('12345.55'::decimal(10,2), -1)),
round('12345.55'::decimal(10,2), -1);
----
-Decimal128(10, 2) 12350
+Decimal128(10, 0) 12350
+
+# round decimal scale 0 negative places (carry can require extra precision)
+query TR
+select arrow_typeof(round('99'::decimal(2,0), -1)),
+ round('99'::decimal(2,0), -1);
+----
+Decimal128(3, 0) 100
# round decimal256 keeps decimals
query TR
select arrow_typeof(round('1234.5678'::decimal(50,4), 2)),
round('1234.5678'::decimal(50,4), 2);
----
-Decimal256(50, 4) 1234.57
+Decimal256(50, 2) 1234.57
+
+# round decimal with carry-over (reduce scale)
+# Scale reduces from 1 to 0, allowing extra digit for carry-over
+query TRRR
+select arrow_typeof(round('999.9'::decimal(4,1))),
+ round('999.9'::decimal(4,1)),
+ round('-999.9'::decimal(4,1)),
+ round('99.99'::decimal(4,2));
+----
+Decimal128(4, 0) 1000 -1000 100
+# round decimal with carry-over and non-literal decimal_places (increase
precision)
+# Scale can't be reduced when decimal_places isn't a constant, so precision
increases.
+query TR
+select arrow_typeof(round(val, dp)), round(val, dp)
+from (values (cast('999.9' as decimal(4,1)), 0)) as t(val, dp);
+----
+Decimal128(5, 1) 1000
+
+# round decimal at max precision now works (scale reduction handles overflow)
+query TR
+select
arrow_typeof(round('9999999999999999999999999999999999999.9'::decimal(38,1))),
+ round('9999999999999999999999999999999999999.9'::decimal(38,1));
+----
+Decimal128(38, 0) 10000000000000000000000000000000000000
+
+# round decimal at max precision with non-literal decimal_places can overflow
+query error Decimal overflow: rounded value exceeds precision 38
+select round(val, dp)
+from (values (cast('9999999999999999999999999999999999999.9' as
decimal(38,1)), 0)) as t(val, dp);
+
+# round decimal with negative scale
+query TRRR
+select arrow_typeof(round(cast(500 as decimal(10,-2)), -3)),
+ round(cast(500 as decimal(10,-2)), -3),
+ round(cast(400 as decimal(10,-2)), -3),
+ round(cast(-500 as decimal(10,-2)), -3);
+----
+Decimal128(10, -3) 1000 0 -1000
+
+# round decimal with negative scale and carry-over
+query TR
+select arrow_typeof(round(cast(999999999900 as decimal(10,-2)), -3)),
+ round(cast(999999999900 as decimal(10,-2)), -3);
+----
+Decimal128(10, -3) 1000000000000
+
+# round decimal with very small decimal_places (i32::MIN) should not error
+query TR
+select arrow_typeof(round('123.45'::decimal(5,2), -2147483648)),
+ round('123.45'::decimal(5,2), -2147483648);
+----
+Decimal128(5, 0) 0
## signum
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]