This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new c63ca330d7 fix: increase ROUND decimal precision to prevent overflow truncation (#19926)
c63ca330d7 is described below

commit c63ca330d7ae5a2e9e053c5fdf017b500b3e1c21
Author: Kumar Ujjawal <[email protected]>
AuthorDate: Fri Feb 27 19:40:20 2026 +0530

    fix: increase ROUND decimal precision to prevent overflow truncation (#19926)
    
    ## Which issue does this PR close?
    
    - Closes #19921.
    
    ## Rationale for this change
    
    `SELECT ROUND('999.9'::DECIMAL(4,1))` incorrectly returned `100.0`
    instead of `1000`.
    
    When rounding a decimal value causes a carry-over that adds a digit
    (e.g., 999.9 → 1000.0), the result no longer fits the original
    precision: 1000.0 needs five digits, but DECIMAL(4,1) allows only four.
    The out-of-range value was then silently truncated when displayed,
    producing an incorrect result.
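
    To make the carry-over concrete: Arrow stores '999.9'::DECIMAL(4,1) as
    the unscaled integer 9999 at scale 1. The sketch below is a simplified
    illustration, not the DataFusion implementation; the helpers
    `round_half_away` and `digits` are hypothetical. It rounds half away from
    zero (the tie-breaking ROUND uses for decimals) and counts digits,
    showing that the rounded value needs four digits at scale 0 but five at
    the original scale 1, one more than DECIMAL(4,1) allows.

```rust
/// Hypothetical helper: round the unscaled decimal `value` (meaning
/// value * 10^-scale) to `dp` fractional digits, half away from zero.
/// Assumes 0 <= dp <= scale and no overflow; the result is unscaled at scale `dp`.
fn round_half_away(value: i128, scale: u32, dp: u32) -> i128 {
    let factor = 10i128.pow(scale - dp);
    let quotient = value / factor; // truncates toward zero
    let remainder = value % factor; // has the same sign as `value`
    // Round away from zero when the dropped part is at least half of `factor`.
    if remainder.abs() >= factor - remainder.abs() {
        quotient + value.signum()
    } else {
        quotient
    }
}

/// Hypothetical helper: number of decimal digits in an unscaled value.
fn digits(value: i128) -> u32 {
    let mut v = value.unsigned_abs();
    let mut n = 1;
    while v >= 10 {
        v /= 10;
        n += 1;
    }
    n
}

fn main() {
    // '999.9'::DECIMAL(4,1) is stored as 9999 at scale 1.
    let rounded = round_half_away(9999, 1, 0); // 1000 at scale 0
    // Kept at the old scale 1, the same value would be stored as 10000.
    let at_old_scale = rounded * 10;
    println!(
        "rounded={rounded} digits@scale0={} digits@scale1={}",
        digits(rounded),
        digits(at_old_scale)
    );
}
```

    It prints `rounded=1000 digits@scale0=4 digits@scale1=5`: dropping the
    scale to 0 frees exactly the digit the carry consumes.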
    
    ## What changes are included in this PR?
    
    - Fixes ROUND on DECIMAL so that a carry-over no longer silently
      truncates or produces a wrong result.
    - When decimal_places is a constant (including omitted or NULL), ROUND
      now reduces the output scale to min(input_scale, max(decimal_places, 0))
      (Spark/DuckDB-style), reclaiming precision for the integer part. For a
      scale-0 input rounded to a negative decimal_places, precision also
      grows by 1 (capped) so that a carry such as 99 → 100 fits.
    - When decimal_places is not a constant (e.g. a column/array), ROUND
      keeps the original scale and increases precision by 1 (capped at the
      type's maximum); if precision is already at the maximum and the rounded
      value can't fit, it returns an overflow error instead of a wrong value.
    - Adds/updates sqllogictests for these behaviors and edge cases.
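
    The return-type rule can be modeled as a small pure function. This is a
    hypothetical sketch of the behavior described in this PR, not
    DataFusion's `calculate_new_precision_scale`: it covers only
    non-negative input scales, `dp` is `Some(n)` for a literal
    decimal_places (omitted or NULL count as 0) and `None` for a per-row
    one, and `max_precision` stands for the decimal type's maximum (38 for
    Decimal128). It includes the scale-0, negative-dp case where precision
    grows by 1 to hold a carry.

```rust
/// Hypothetical model of ROUND's decimal return type, for input scales >= 0.
/// Returns the (precision, scale) of the result.
fn round_return_type(precision: u8, scale: i8, dp: Option<i32>, max_precision: u8) -> (u8, i8) {
    match dp {
        Some(dp) => {
            // Output scale: min(input_scale, max(dp, 0)).
            let new_scale = i32::from(scale).min(dp.max(0)) as i8;
            // With scale 0 and a negative dp, a carry (99 -> 100) can add an
            // integer digit, so widen precision by one (capped).
            let new_precision =
                if scale == 0 && dp < 0 && dp.unsigned_abs() <= u32::from(precision) {
                    precision.saturating_add(1).min(max_precision)
                } else {
                    precision
                };
            (new_precision, new_scale)
        }
        // Per-row dp: keep the scale, add one digit of headroom (capped).
        None => (precision.saturating_add(1).min(max_precision), scale),
    }
}

fn main() {
    let cases = [
        ((4, 1), Some(0)), // ROUND('999.9'::DECIMAL(4,1))
        ((38, 10), Some(2)),
        ((10, 2), Some(-1)),
        ((2, 0), Some(-1)),
        ((4, 1), None), // decimal_places taken from a column
    ];
    for ((p, s), dp) in cases {
        let (np, ns) = round_return_type(p, s, dp, 38);
        println!("Decimal128({p}, {s}) dp={dp:?} -> Decimal128({np}, {ns})");
    }
}
```

    The literal cases reproduce the types expected by the updated
    sqllogictests, e.g. Decimal128(38, 2) for
    round('173975140545.855'::decimal(38,10), 2) and Decimal128(3, 0) for
    round('99'::decimal(2,0), -1).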
    
    ## Are these changes tested?
    
    - Added new sqllogictest case for the specific bug scenario
    - Updated 2 existing tests with new expected precision values
    - All existing tests pass
    
    ## Are there any user-facing changes?
    
    | Aspect | Before | After |
    | --- | --- | --- |
    | ROUND('999.9'::DECIMAL(4,1)) | 100.0 (wrong) | 1000 (correct) |
    | Return type (default/literal dp) | Decimal128(4, 1) | Decimal128(4, 0) |
    
    - Return type for DECIMAL inputs can change: with a literal dp the scale
      generally shrinks; with a non-literal dp the scale is kept and
      precision grows by 1 (capped).
    - New error case: when the result type is already at maximum precision,
      ROUND may now return an overflow error instead of a truncated/wrong
      decimal. This applies when dp is non-literal, and when a scale-0 input
      is rounded to a negative dp.
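
    When the result type is already at maximum precision there is no room
    to widen it, so the rounded value has to be checked instead. A minimal
    sketch of that check, assuming Decimal128 (maximum precision 38) and a
    scale-0 input rounded to a negative dp. The helpers `round_to_tens` and
    `check_precision` are hypothetical; they only mirror the idea behind
    the commit's `validate_decimal_precision` call: an unscaled value fits
    `precision` digits exactly when its absolute value is below
    10^precision.

```rust
/// Hypothetical helper: round a scale-0 unscaled value to `places` digits
/// left of the decimal point (dp = -places), half away from zero.
/// Returns None if an intermediate result overflows i128.
fn round_to_tens(value: i128, places: u32) -> Option<i128> {
    let factor = 10i128.checked_pow(places)?;
    let quotient = value / factor;
    let remainder = value % factor;
    let quotient = if remainder.abs() >= factor - remainder.abs() {
        quotient + value.signum()
    } else {
        quotient
    };
    quotient.checked_mul(factor)
}

/// Hypothetical precision check: |value| must be < 10^precision.
fn check_precision(value: i128, precision: u32) -> Result<i128, String> {
    match 10u128.checked_pow(precision) {
        Some(bound) if value.unsigned_abs() >= bound => Err(format!(
            "Decimal overflow: rounded value {value} exceeds precision {precision}"
        )),
        // Within the bound, or 10^precision exceeds u128::MAX and thus any i128.
        _ => Ok(value),
    }
}

fn main() {
    // round('99'::decimal(2,0), -1) -> 100: needs precision 3, not 2.
    let r = round_to_tens(99, 1).expect("no i128 overflow");
    println!(
        "{r}: fits p=2? {} fits p=3? {}",
        check_precision(r, 2).is_ok(),
        check_precision(r, 3).is_ok()
    );

    // A DECIMAL(38,0) holding 38 nines cannot be widened past precision 38,
    // so the carry to 10^38 is reported as an error instead of a wrong value.
    let max38 = 10i128.pow(38) - 1;
    let r = round_to_tens(max38, 1).expect("no i128 overflow");
    match check_precision(r, 38) {
        Ok(v) => println!("ok: {v}"),
        Err(e) => println!("error: {e}"),
    }
}
```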
    
    ---------
    
    Co-authored-by: Martin Grigorov <[email protected]>
---
 datafusion/functions/src/math/round.rs         | 472 +++++++++++++++++++++----
 datafusion/sqllogictest/test_files/decimal.slt |   2 +-
 datafusion/sqllogictest/test_files/scalar.slt  |  68 +++-
 3 files changed, 477 insertions(+), 65 deletions(-)

diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs
index 8c25c57740..07cddf9341 100644
--- a/datafusion/functions/src/math/round.rs
+++ b/datafusion/functions/src/math/round.rs
@@ -25,8 +25,9 @@ use arrow::datatypes::DataType::{
 };
 use arrow::datatypes::{
     ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
-    Decimal256Type, Float32Type, Float64Type, Int32Type,
+    Decimal256Type, DecimalType, Float32Type, Float64Type, Int32Type,
 };
+use arrow::datatypes::{Field, FieldRef};
 use arrow::error::ArrowError;
 use datafusion_common::types::{
     NativeType, logical_float32, logical_float64, logical_int32,
@@ -34,10 +35,120 @@ use datafusion_common::types::{
 use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
 use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
 use datafusion_expr::{
-    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
-    TypeSignature, TypeSignatureClass, Volatility,
+    Coercion, ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs,
+    ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility,
 };
 use datafusion_macros::user_doc;
+use std::sync::Arc;
+
+fn output_scale_for_decimal(precision: u8, input_scale: i8, decimal_places: i32) -> i8 {
+    // `decimal_places` controls the maximum output scale, but scale cannot exceed the input scale.
+    //
+    // For negative-scale decimals, allow further scale reduction to match negative `decimal_places`
+    // (e.g. scale -2 rounded to -3 becomes scale -3). This preserves fixed precision by
+    // representing the rounded result at a coarser scale.
+    if input_scale < 0 {
+        // Decimal scales must be within [-precision, precision] and fit in i8. For negative-scale
+        // decimals, allow rounding to move the output scale further negative, but cap it at
+        // `-precision` (beyond that, the rounded result is always 0).
+        let min_scale = -i32::from(precision);
+        let new_scale = i32::from(input_scale).min(decimal_places).max(min_scale);
+        return new_scale as i8;
+    }
+
+    // The `min` ensures the result is always within i8 range because `input_scale` is i8.
+    let decimal_places = decimal_places.max(0);
+    i32::from(input_scale).min(decimal_places) as i8
+}
+
+fn normalize_decimal_places_for_decimal(
+    decimal_places: i32,
+    precision: u8,
+    scale: i8,
+) -> Option<i32> {
+    if decimal_places >= 0 {
+        return Some(decimal_places);
+    }
+
+    // For fixed precision decimals, the absolute value is strictly less than 10^(precision - scale).
+    // If the rounding position is beyond that (abs(decimal_places) > precision - scale), the
+    // rounded result is always 0, and we can avoid overflow in intermediate 10^n computations.
+    let max_rounding_pow10 = i64::from(precision) - i64::from(scale);
+    if max_rounding_pow10 <= 0 {
+        return None;
+    }
+
+    let abs_decimal_places = i64::from(decimal_places.unsigned_abs());
+    (abs_decimal_places <= max_rounding_pow10).then_some(decimal_places)
+}
+
+fn validate_decimal_precision<T: DecimalType>(
+    value: T::Native,
+    precision: u8,
+    scale: i8,
+) -> Result<T::Native, ArrowError> {
+    T::validate_decimal_precision(value, precision, scale).map_err(|e| {
+        ArrowError::ComputeError(format!(
+            "Decimal overflow: rounded value exceeds precision {precision}: {e}"
+        ))
+    })?;
+    Ok(value)
+}
+
+fn calculate_new_precision_scale<T: DecimalType>(
+    precision: u8,
+    scale: i8,
+    decimal_places: Option<i32>,
+) -> Result<DataType> {
+    if let Some(decimal_places) = decimal_places {
+        let new_scale = output_scale_for_decimal(precision, scale, decimal_places);
+
+        // When rounding an integer decimal (scale == 0) to a negative `decimal_places`, a carry can
+        // add an extra digit to the integer part (e.g. 99 -> 100 when rounding to -1). This can
+        // only happen when the rounding position is within the existing precision.
+        let abs_decimal_places = decimal_places.unsigned_abs();
+        let new_precision = if scale == 0
+            && decimal_places < 0
+            && abs_decimal_places <= u32::from(precision)
+        {
+            precision.saturating_add(1).min(T::MAX_PRECISION)
+        } else {
+            precision
+        };
+        Ok(T::TYPE_CONSTRUCTOR(new_precision, new_scale))
+    } else {
+        let new_precision = precision.saturating_add(1).min(T::MAX_PRECISION);
+        Ok(T::TYPE_CONSTRUCTOR(new_precision, scale))
+    }
+}
+
+fn decimal_places_from_scalar(scalar: &ScalarValue) -> Result<i32> {
+    let out_of_range = |value: String| {
+        datafusion_common::DataFusionError::Execution(format!(
+            "round decimal_places {value} is out of supported i32 range"
+        ))
+    };
+    match scalar {
+        ScalarValue::Int8(Some(v)) => Ok(i32::from(*v)),
+        ScalarValue::Int16(Some(v)) => Ok(i32::from(*v)),
+        ScalarValue::Int32(Some(v)) => Ok(*v),
+        ScalarValue::Int64(Some(v)) => {
+            i32::try_from(*v).map_err(|_| out_of_range(v.to_string()))
+        }
+        ScalarValue::UInt8(Some(v)) => Ok(i32::from(*v)),
+        ScalarValue::UInt16(Some(v)) => Ok(i32::from(*v)),
+        ScalarValue::UInt32(Some(v)) => {
+            i32::try_from(*v).map_err(|_| out_of_range(v.to_string()))
+        }
+        ScalarValue::UInt64(Some(v)) => {
+            i32::try_from(*v).map_err(|_| out_of_range(v.to_string()))
+        }
+        other => exec_err!(
+            "Unexpected datatype for decimal_places: {}",
+            other.data_type()
+        ),
+    }
+}
 
 #[user_doc(
     doc_section(label = "Math Functions"),
@@ -117,15 +228,59 @@ impl ScalarUDFImpl for RoundFunc {
         &self.signature
     }
 
-    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        Ok(match arg_types[0].clone() {
-            Float32 => Float32,
-            dt @ Decimal128(_, _)
-            | dt @ Decimal256(_, _)
-            | dt @ Decimal32(_, _)
-            | dt @ Decimal64(_, _) => dt,
-            _ => Float64,
-        })
+    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
+        let input_field = &args.arg_fields[0];
+        let input_type = input_field.data_type();
+
+        // If decimal_places is a scalar literal, we can incorporate it into the output type
+        // (scale reduction). Otherwise, keep the input scale as we can't pick a per-row scale.
+        //
+        // Note: `scalar_arguments` contains the original literal values (pre-coercion), so
+        // integer literals may appear as Int64 even though the signature coerces them to Int32.
+        let decimal_places: Option<i32> = match args.scalar_arguments.get(1) {
+            None => Some(0),    // No dp argument means default to 0
+            Some(None) => None, // dp is not a literal (e.g. column)
+            Some(Some(scalar)) if scalar.is_null() => Some(0), // null dp => default to 0
+            Some(Some(scalar)) => Some(decimal_places_from_scalar(scalar)?),
+        };
+
+        // Calculate return type based on input type
+        // For decimals: reduce scale to decimal_places (reclaims precision for integer part)
+        // This matches Spark/DuckDB behavior where ROUND adjusts the scale
+        // BUT only if dp is a scalar literal - otherwise keep original scale and add
+        // extra precision to accommodate potential carry-over.
+        let return_type =
+            match input_type {
+                Float32 => Float32,
+                Decimal32(precision, scale) => calculate_new_precision_scale::<
+                    Decimal32Type,
+                >(
+                    *precision, *scale, decimal_places
+                )?,
+                Decimal64(precision, scale) => calculate_new_precision_scale::<
+                    Decimal64Type,
+                >(
+                    *precision, *scale, decimal_places
+                )?,
+                Decimal128(precision, scale) => calculate_new_precision_scale::<
+                    Decimal128Type,
+                >(
+                    *precision, *scale, decimal_places
+                )?,
+                Decimal256(precision, scale) => calculate_new_precision_scale::<
+                    Decimal256Type,
+                >(
+                    *precision, *scale, decimal_places
+                )?,
+                _ => Float64,
+            };
+
+        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
+        Ok(Arc::new(Field::new(self.name(), return_type, nullable)))
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        internal_err!("use return_field_from_args instead")
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
@@ -141,7 +296,6 @@ impl ScalarUDFImpl for RoundFunc {
             &default_decimal_places
         };
 
-        // Scalar fast path for float and decimal types - avoid array conversion overhead
         if let (ColumnarValue::Scalar(value_scalar), ColumnarValue::Scalar(dp_scalar)) =
             (&args.args[0], decimal_places)
         {
@@ -159,48 +313,132 @@ impl ScalarUDFImpl for RoundFunc {
                 );
             };
 
-            match value_scalar {
-                ScalarValue::Float32(Some(v)) => {
+            match (value_scalar, args.return_type()) {
+                (ScalarValue::Float32(Some(v)), _) => {
                     let rounded = round_float(*v, dp)?;
                     Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
                 }
-                ScalarValue::Float64(Some(v)) => {
+                (ScalarValue::Float64(Some(v)), _) => {
                     let rounded = round_float(*v, dp)?;
                     Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
                 }
-                ScalarValue::Decimal128(Some(v), precision, scale) => {
-                    let rounded = round_decimal(*v, *scale, dp)?;
+                (
+                    ScalarValue::Decimal32(Some(v), in_precision, scale),
+                    Decimal32(out_precision, out_scale),
+                ) => {
+                    let rounded =
+                        round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
+                    let rounded = if *out_precision == Decimal32Type::MAX_PRECISION
+                        && *scale == 0
+                        && dp < 0
+                    {
+                        // With scale == 0 and negative dp, rounding can carry into an additional
+                        // digit (e.g. 99 -> 100). If we're already at max precision we can't widen
+                        // the type, so validate and error rather than producing an invalid decimal.
+                        validate_decimal_precision::<Decimal32Type>(
+                            rounded,
+                            *out_precision,
+                            *out_scale,
+                        )
+                    } else {
+                        Ok(rounded)
+                    }?;
                     let scalar =
-                        ScalarValue::Decimal128(Some(rounded), *precision, *scale);
+                        ScalarValue::Decimal32(Some(rounded), *out_precision, *out_scale);
                     Ok(ColumnarValue::Scalar(scalar))
                 }
-                ScalarValue::Decimal256(Some(v), precision, scale) => {
-                    let rounded = round_decimal(*v, *scale, dp)?;
+                (
+                    ScalarValue::Decimal64(Some(v), in_precision, scale),
+                    Decimal64(out_precision, out_scale),
+                ) => {
+                    let rounded =
+                        round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
+                    let rounded = if *out_precision == Decimal64Type::MAX_PRECISION
+                        && *scale == 0
+                        && dp < 0
+                    {
+                        // See Decimal32 branch for details.
+                        validate_decimal_precision::<Decimal64Type>(
+                            rounded,
+                            *out_precision,
+                            *out_scale,
+                        )
+                    } else {
+                        Ok(rounded)
+                    }?;
                     let scalar =
-                        ScalarValue::Decimal256(Some(rounded), *precision, *scale);
+                        ScalarValue::Decimal64(Some(rounded), *out_precision, *out_scale);
                     Ok(ColumnarValue::Scalar(scalar))
                 }
-                ScalarValue::Decimal64(Some(v), precision, scale) => {
-                    let rounded = round_decimal(*v, *scale, dp)?;
-                    let scalar =
-                        ScalarValue::Decimal64(Some(rounded), *precision, *scale);
+                (
+                    ScalarValue::Decimal128(Some(v), in_precision, scale),
+                    Decimal128(out_precision, out_scale),
+                ) => {
+                    let rounded =
+                        round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
+                    let rounded = if *out_precision == Decimal128Type::MAX_PRECISION
+                        && *scale == 0
+                        && dp < 0
+                    {
+                        // See Decimal32 branch for details.
+                        validate_decimal_precision::<Decimal128Type>(
+                            rounded,
+                            *out_precision,
+                            *out_scale,
+                        )
+                    } else {
+                        Ok(rounded)
+                    }?;
+                    let scalar = ScalarValue::Decimal128(
+                        Some(rounded),
+                        *out_precision,
+                        *out_scale,
+                    );
                     Ok(ColumnarValue::Scalar(scalar))
                 }
-                ScalarValue::Decimal32(Some(v), precision, scale) => {
-                    let rounded = round_decimal(*v, *scale, dp)?;
-                    let scalar =
-                        ScalarValue::Decimal32(Some(rounded), *precision, *scale);
+                (
+                    ScalarValue::Decimal256(Some(v), in_precision, scale),
+                    Decimal256(out_precision, out_scale),
+                ) => {
+                    let rounded =
+                        round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?;
+                    let rounded = if *out_precision == Decimal256Type::MAX_PRECISION
+                        && *scale == 0
+                        && dp < 0
+                    {
+                        // See Decimal32 branch for details.
+                        validate_decimal_precision::<Decimal256Type>(
+                            rounded,
+                            *out_precision,
+                            *out_scale,
+                        )
+                    } else {
+                        Ok(rounded)
+                    }?;
+                    let scalar = ScalarValue::Decimal256(
+                        Some(rounded),
+                        *out_precision,
+                        *out_scale,
+                    );
                     Ok(ColumnarValue::Scalar(scalar))
                 }
-                _ => {
+                (ScalarValue::Null, _) => ColumnarValue::Scalar(ScalarValue::Null)
+                    .cast_to(args.return_type(), None),
+                (value_scalar, return_type) => {
                     internal_err!(
-                        "Unexpected datatype for value: {}",
-                        value_scalar.data_type()
+                        "Unexpected datatype for round(value, decimal_places): value {}, return type {}",
+                        value_scalar.data_type(),
+                        return_type
                     )
                 }
             }
         } else {
-            round_columnar(&args.args[0], decimal_places, args.number_rows)
+            round_columnar(
+                &args.args[0],
+                decimal_places,
+                args.number_rows,
+                args.return_type(),
+            )
         }
     }
 
@@ -228,13 +466,15 @@ fn round_columnar(
     value: &ColumnarValue,
     decimal_places: &ColumnarValue,
     number_rows: usize,
+    return_type: &DataType,
 ) -> Result<ColumnarValue> {
     let value_array = value.to_array(number_rows)?;
     let both_scalars = matches!(value, ColumnarValue::Scalar(_))
         && matches!(decimal_places, ColumnarValue::Scalar(_));
+    let decimal_places_is_array = matches!(decimal_places, ColumnarValue::Array(_));
 
-    let arr: ArrayRef = match value_array.data_type() {
-        Float64 => {
+    let arr: ArrayRef = match (value_array.data_type(), return_type) {
+        (Float64, _) => {
             let result = calculate_binary_math::<Float64Type, Int32Type, Float64Type, _>(
                 value_array.as_ref(),
                 decimal_places,
@@ -242,7 +482,7 @@ fn round_columnar(
             )?;
             result as _
         }
-        Float32 => {
+        (Float32, _) => {
             let result = calculate_binary_math::<Float32Type, Int32Type, Float32Type, _>(
                 value_array.as_ref(),
                 decimal_places,
@@ -250,7 +490,8 @@ fn round_columnar(
             )?;
             result as _
         }
-        Decimal32(precision, scale) => {
+        (Decimal32(input_precision, scale), Decimal32(precision, new_scale)) => {
+            // reduce scale to reclaim integer precision
             let result = calculate_binary_decimal_math::<
                 Decimal32Type,
                 Int32Type,
@@ -259,13 +500,34 @@ fn round_columnar(
             >(
                 value_array.as_ref(),
                 decimal_places,
-                |v, dp| round_decimal(v, *scale, dp),
+                |v, dp| {
+                    let rounded = round_decimal_or_zero(
+                        v,
+                        *input_precision,
+                        *scale,
+                        *new_scale,
+                        dp,
+                    )?;
+                    if *precision == Decimal32Type::MAX_PRECISION
+                        && (decimal_places_is_array || (*scale == 0 && dp < 0))
+                    {
+                        // If we're already at max precision, we can't widen the result type. For
+                        // dp arrays, or for scale == 0 with negative dp, rounding can overflow the
+                        // fixed-precision type. Validate per-row and return an error instead of
+                        // producing an invalid decimal that Arrow may display incorrectly.
+                        validate_decimal_precision::<Decimal32Type>(
+                            rounded, *precision, *new_scale,
+                        )
+                    } else {
+                        Ok(rounded)
+                    }
+                },
                 *precision,
-                *scale,
+                *new_scale,
             )?;
             result as _
         }
-        Decimal64(precision, scale) => {
+        (Decimal64(input_precision, scale), Decimal64(precision, new_scale)) => {
             let result = calculate_binary_decimal_math::<
                 Decimal64Type,
                 Int32Type,
@@ -274,13 +536,31 @@ fn round_columnar(
             >(
                 value_array.as_ref(),
                 decimal_places,
-                |v, dp| round_decimal(v, *scale, dp),
+                |v, dp| {
+                    let rounded = round_decimal_or_zero(
+                        v,
+                        *input_precision,
+                        *scale,
+                        *new_scale,
+                        dp,
+                    )?;
+                    if *precision == Decimal64Type::MAX_PRECISION
+                        && (decimal_places_is_array || (*scale == 0 && dp < 0))
+                    {
+                        // See Decimal32 branch for details.
+                        validate_decimal_precision::<Decimal64Type>(
+                            rounded, *precision, *new_scale,
+                        )
+                    } else {
+                        Ok(rounded)
+                    }
+                },
                 *precision,
-                *scale,
+                *new_scale,
             )?;
             result as _
         }
-        Decimal128(precision, scale) => {
+        (Decimal128(input_precision, scale), Decimal128(precision, new_scale)) => {
             let result = calculate_binary_decimal_math::<
                 Decimal128Type,
                 Int32Type,
@@ -289,13 +569,31 @@ fn round_columnar(
             >(
                 value_array.as_ref(),
                 decimal_places,
-                |v, dp| round_decimal(v, *scale, dp),
+                |v, dp| {
+                    let rounded = round_decimal_or_zero(
+                        v,
+                        *input_precision,
+                        *scale,
+                        *new_scale,
+                        dp,
+                    )?;
+                    if *precision == Decimal128Type::MAX_PRECISION
+                        && (decimal_places_is_array || (*scale == 0 && dp < 0))
+                    {
+                        // See Decimal32 branch for details.
+                        validate_decimal_precision::<Decimal128Type>(
+                            rounded, *precision, *new_scale,
+                        )
+                    } else {
+                        Ok(rounded)
+                    }
+                },
                 *precision,
-                *scale,
+                *new_scale,
             )?;
             result as _
         }
-        Decimal256(precision, scale) => {
+        (Decimal256(input_precision, scale), Decimal256(precision, new_scale)) => {
             let result = calculate_binary_decimal_math::<
                 Decimal256Type,
                 Int32Type,
@@ -304,13 +602,31 @@ fn round_columnar(
             >(
                 value_array.as_ref(),
                 decimal_places,
-                |v, dp| round_decimal(v, *scale, dp),
+                |v, dp| {
+                    let rounded = round_decimal_or_zero(
+                        v,
+                        *input_precision,
+                        *scale,
+                        *new_scale,
+                        dp,
+                    )?;
+                    if *precision == Decimal256Type::MAX_PRECISION
+                        && (decimal_places_is_array || (*scale == 0 && dp < 0))
+                    {
+                        // See Decimal32 branch for details.
+                        validate_decimal_precision::<Decimal256Type>(
+                            rounded, *precision, *new_scale,
+                        )
+                    } else {
+                        Ok(rounded)
+                    }
+                },
                 *precision,
-                *scale,
+                *new_scale,
             )?;
             result as _
         }
-        other => exec_err!("Unsupported data type {other:?} for function round")?,
+        (other, _) => exec_err!("Unsupported data type {other:?} for function round")?,
     };
 
     if both_scalars {
@@ -334,19 +650,17 @@ where
 
 fn round_decimal<V: ArrowNativeTypeOp>(
     value: V,
-    scale: i8,
+    input_scale: i8,
+    output_scale: i8,
     decimal_places: i32,
 ) -> Result<V, ArrowError> {
-    let diff = i64::from(scale) - i64::from(decimal_places);
+    let diff = i64::from(input_scale) - i64::from(decimal_places);
     if diff <= 0 {
         return Ok(value);
     }
 
-    let diff: u32 = diff.try_into().map_err(|e| {
-        ArrowError::ComputeError(format!(
-            "Invalid value for decimal places: {decimal_places}: {e}"
-        ))
-    })?;
+    debug_assert!(diff <= i64::from(u32::MAX));
+    let diff = diff as u32;
 
     let one = V::ONE;
     let two = V::from_usize(2).ok_or_else(|| {
@@ -358,7 +672,7 @@ fn round_decimal<V: ArrowNativeTypeOp>(
 
     let factor = ten.pow_checked(diff).map_err(|_| {
         ArrowError::ComputeError(format!(
-            "Overflow while rounding decimal with scale {scale} and decimal places {decimal_places}"
+            "Overflow while rounding decimal with scale {input_scale} and decimal places {decimal_places}"
         ))
     })?;
 
@@ -377,11 +691,44 @@ fn round_decimal<V: ArrowNativeTypeOp>(
         })?;
     }
 
+    // `quotient` is the rounded value at scale `decimal_places`. Rescale to the desired
+    // `output_scale` (which is always >= `decimal_places` in cases where diff > 0).
+    let scale_shift = i64::from(output_scale) - i64::from(decimal_places);
+    if scale_shift == 0 {
+        return Ok(quotient);
+    }
+
+    debug_assert!(scale_shift > 0);
+    debug_assert!(scale_shift <= i64::from(u32::MAX));
+    let scale_shift = scale_shift as u32;
+    let shift_factor = ten.pow_checked(scale_shift).map_err(|_| {
+        ArrowError::ComputeError(format!(
+            "Overflow while rounding decimal with scale {input_scale} and decimal places {decimal_places}"
+        ))
+    })?;
     quotient
-        .mul_checked(factor)
+        .mul_checked(shift_factor)
         .map_err(|_| ArrowError::ComputeError("Overflow while rounding decimal".into()))
 }
 
+fn round_decimal_or_zero<V: ArrowNativeTypeOp>(
+    value: V,
+    precision: u8,
+    input_scale: i8,
+    output_scale: i8,
+    decimal_places: i32,
+) -> Result<V, ArrowError> {
+    if let Some(dp) =
+        normalize_decimal_places_for_decimal(decimal_places, precision, input_scale)
+    {
+        round_decimal(value, input_scale, output_scale, dp)
+    } else {
+        V::from_usize(0).ok_or_else(|| {
+            ArrowError::ComputeError("Internal error: could not create constant 0".into())
+        })
+    }
+}
+
 #[cfg(test)]
 mod test {
     use std::sync::Arc;
@@ -397,12 +744,17 @@ mod test {
         decimal_places: Option<ArrayRef>,
     ) -> Result<ArrayRef, DataFusionError> {
         let number_rows = value.len();
+        // NOTE: For decimal inputs, the actual ROUND return type can differ from the
+        // input type (scale reduction for literal `decimal_places`). These unit tests
+        // only exercise Float32/Float64 behavior.
+        let return_type = value.data_type().clone();
         let value = ColumnarValue::Array(value);
         let decimal_places = decimal_places
             .map(ColumnarValue::Array)
             .unwrap_or_else(|| ColumnarValue::Scalar(ScalarValue::Int32(Some(0))));
 
-        let result = super::round_columnar(&value, &decimal_places, number_rows)?;
+        let result =
+            super::round_columnar(&value, &decimal_places, number_rows, &return_type)?;
         match result {
             ColumnarValue::Array(array) => Ok(array),
             ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1),
diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt
index f53f493929..eca2c88bb5 100644
--- a/datafusion/sqllogictest/test_files/decimal.slt
+++ b/datafusion/sqllogictest/test_files/decimal.slt
@@ -782,7 +782,7 @@ query TR
 select arrow_typeof(round(173975140545.855, 2)),
        round(173975140545.855, 2);
 ----
-Decimal128(15, 3) 173975140545.86
+Decimal128(15, 2) 173975140545.86
 
 # smoke test for decimal parsing
 query RT
diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt
index 7a4a81b5fa..681540a29d 100644
--- a/datafusion/sqllogictest/test_files/scalar.slt
+++ b/datafusion/sqllogictest/test_files/scalar.slt
@@ -923,7 +923,7 @@ select round(a), round(b), round(c) from small_floats;
 
 # round with too large
 #  max Int32 is 2147483647
-query error Arrow error: Cast error: Can't cast value 2147483648 to type Int32
+query error round decimal_places 2147483648 is out of supported i32 range
 select round(3.14, 2147483648);
 
 # with array
@@ -931,11 +931,12 @@ query error Arrow error: Cast error: Can't cast value 2147483649 to type Int32
 select round(column1, column2) from values (3.14, 2), (3.14, 3), (3.14, 2147483649);
 
 # round decimal should not cast to float
+# scale reduces to match decimal_places
 query TR
 select arrow_typeof(round('173975140545.855'::decimal(38,10), 2)),
        round('173975140545.855'::decimal(38,10), 2);
 ----
-Decimal128(38, 10) 173975140545.86
+Decimal128(38, 2) 173975140545.86
 
 # round decimal ties away from zero
 query RRRR
@@ -951,15 +952,74 @@ query TR
 select arrow_typeof(round('12345.55'::decimal(10,2), -1)),
        round('12345.55'::decimal(10,2), -1);
 ----
-Decimal128(10, 2) 12350
+Decimal128(10, 0) 12350
+
+# round decimal scale 0 negative places (carry can require extra precision)
+query TR
+select arrow_typeof(round('99'::decimal(2,0), -1)),
+       round('99'::decimal(2,0), -1);
+----
+Decimal128(3, 0) 100
 
 # round decimal256 keeps decimals
 query TR
 select arrow_typeof(round('1234.5678'::decimal(50,4), 2)),
        round('1234.5678'::decimal(50,4), 2);
 ----
-Decimal256(50, 4) 1234.57
+Decimal256(50, 2) 1234.57
+
+# round decimal with carry-over (reduce scale)
+# Scale reduces from 1 to 0, allowing extra digit for carry-over
+query TRRR
+select arrow_typeof(round('999.9'::decimal(4,1))),
+       round('999.9'::decimal(4,1)),
+       round('-999.9'::decimal(4,1)),
+       round('99.99'::decimal(4,2));
+----
+Decimal128(4, 0) 1000 -1000 100
 
+# round decimal with carry-over and non-literal decimal_places (increase precision)
+# Scale can't be reduced when decimal_places isn't a constant, so precision increases.
+query TR
+select arrow_typeof(round(val, dp)), round(val, dp)
+from (values (cast('999.9' as decimal(4,1)), 0)) as t(val, dp);
+----
+Decimal128(5, 1) 1000
+
+# round decimal at max precision now works (scale reduction handles overflow)
+query TR
+select arrow_typeof(round('9999999999999999999999999999999999999.9'::decimal(38,1))),
+       round('9999999999999999999999999999999999999.9'::decimal(38,1));
+----
+Decimal128(38, 0) 10000000000000000000000000000000000000
+
+# round decimal at max precision with non-literal decimal_places can overflow
+query error Decimal overflow: rounded value exceeds precision 38
+select round(val, dp)
+from (values (cast('9999999999999999999999999999999999999.9' as decimal(38,1)), 0)) as t(val, dp);
+
+# round decimal with negative scale
+query TRRR
+select arrow_typeof(round(cast(500 as decimal(10,-2)), -3)),
+       round(cast(500 as decimal(10,-2)), -3),
+       round(cast(400 as decimal(10,-2)), -3),
+       round(cast(-500 as decimal(10,-2)), -3);
+----
+Decimal128(10, -3) 1000 0 -1000
+
+# round decimal with negative scale and carry-over
+query TR
+select arrow_typeof(round(cast(999999999900 as decimal(10,-2)), -3)),
+       round(cast(999999999900 as decimal(10,-2)), -3);
+----
+Decimal128(10, -3) 1000000000000
+
+# round decimal with very small decimal_places (i32::MIN) should not error
+query TR
+select arrow_typeof(round('123.45'::decimal(5,2), -2147483648)),
+       round('123.45'::decimal(5,2), -2147483648);
+----
+Decimal128(5, 0) 0
 
 ## signum
 
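The expected types in the sqllogictest changes follow a pattern: with a constant `decimal_places`, the scale drops to `min(input_scale, max(decimal_places, 0))` and the precision is kept; when the scale cannot drop (a non-constant `decimal_places`, or a case like `round('99'::decimal(2,0), -1)`), the precision grows by one, up to the type's maximum. The sketch below is a hypothetical reconstruction of that rule from the test expectations, not DataFusion's return-type code. It models only non-negative input scales (the negative-scale cases above behave differently, e.g. `Decimal128(10, -3)`), and its output for cases the tests do not cover, such as `decimal_places` at or above the input scale, is a guess. `DecimalType` and `round_return_type` are invented names.

```rust
/// Hypothetical decimal type descriptor used only for this sketch.
#[derive(Debug, PartialEq)]
struct DecimalType {
    precision: u8,
    scale: i8,
}

/// Hypothetical reconstruction of ROUND's decimal return type, inferred from
/// the sqllogictest expectations; not DataFusion's actual implementation.
///
/// `literal_dp` is `Some(dp)` when `decimal_places` is a constant (an omitted
/// argument is modelled as `Some(0)`) and `None` when it comes from a column.
/// `max_precision` is 38 for Decimal128 and 76 for Decimal256.
fn round_return_type(
    input: DecimalType,
    max_precision: u8,
    literal_dp: Option<i32>,
) -> DecimalType {
    // Negative input scales follow a different rule in the tests; not modelled.
    assert!(input.scale >= 0, "negative input scales are not modelled");

    // Spark/DuckDB-style scale: min(input_scale, max(decimal_places, 0)).
    let target_scale = match literal_dp {
        Some(dp) => i32::from(input.scale).min(dp.max(0)) as i8,
        None => input.scale,
    };

    if target_scale < input.scale {
        // Fractional digits are dropped, which frees room for a carry into
        // the integer part, so the precision can stay the same.
        DecimalType { precision: input.precision, scale: target_scale }
    } else {
        // The scale cannot shrink, so reserve one extra digit for a possible
        // carry, without exceeding the type's maximum precision.
        DecimalType {
            precision: (input.precision + 1).min(max_precision),
            scale: input.scale,
        }
    }
}
```

Under this rule, `DECIMAL(38,1)` with a column-valued `decimal_places` keeps type `(38, 1)` because the precision is already at the maximum, so a carry at runtime still overflows, matching the `query error` case above.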

