This is an automated email from the ASF dual-hosted git repository.

wangzhen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 726c4b77a fix: format decimal to string when casting decimal with overflow (#2916)
726c4b77a is described below

commit 726c4b77af7fa964f677388114df74f12ddcfdb1
Author: Manu Zhang <[email protected]>
AuthorDate: Mon Jan 5 13:35:44 2026 +0800

    fix: format decimal to string when casting decimal with overflow (#2916)
---
 native/spark-expr/src/conversion_funcs/cast.rs     | 59 +++++++++++++++++++---
 .../scala/org/apache/comet/CometCastSuite.scala    | 20 ++++++++
 2 files changed, 72 insertions(+), 7 deletions(-)

diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 501191708..a2e12168d 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -685,11 +685,18 @@ macro_rules! cast_decimal_to_int16_down {
                 .map(|value| match value {
                     Some(value) => {
                         let divisor = 10_i128.pow($scale as u32);
-                        let (truncated, decimal) = (value / divisor, (value % divisor).abs());
+                        let truncated = value / divisor;
                         let is_overflow = truncated.abs() > i32::MAX.into();
                         if is_overflow {
                             return Err(cast_overflow(
-                                &format!("{}.{}BD", truncated, decimal),
+                                &format!(
+                                    "{}BD",
+                                    format_decimal_str(
+                                        &value.to_string(),
+                                        $precision as usize,
+                                        $scale
+                                    )
+                                ),
                                 &format!("DECIMAL({},{})", $precision, $scale),
                                 $dest_type_str,
                             ));
@@ -698,7 +705,14 @@ macro_rules! cast_decimal_to_int16_down {
                         <$rust_dest_type>::try_from(i32_value)
                             .map_err(|_| {
                                 cast_overflow(
-                                    &format!("{}.{}BD", truncated, decimal),
+                                    &format!(
+                                        "{}BD",
+                                        format_decimal_str(
+                                            &value.to_string(),
+                                            $precision as usize,
+                                            $scale
+                                        )
+                                    ),
                                     &format!("DECIMAL({},{})", $precision, $scale),
                                     $dest_type_str,
                                 )
@@ -748,11 +762,18 @@ macro_rules! cast_decimal_to_int32_up {
                 .map(|value| match value {
                     Some(value) => {
                         let divisor = 10_i128.pow($scale as u32);
-                        let (truncated, decimal) = (value / divisor, (value % divisor).abs());
+                        let truncated = value / divisor;
                         let is_overflow = truncated.abs() > $max_dest_val.into();
                         if is_overflow {
                             return Err(cast_overflow(
-                                &format!("{}.{}BD", truncated, decimal),
+                                &format!(
+                                    "{}BD",
+                                    format_decimal_str(
+                                        &value.to_string(),
+                                        $precision as usize,
+                                        $scale
+                                    )
+                                ),
                                 &format!("DECIMAL({},{})", $precision, $scale),
                                 $dest_type_str,
                             ));
@@ -780,6 +801,30 @@ macro_rules! cast_decimal_to_int32_up {
     }};
 }
 
+// copied from arrow::datatypes::Decimal128Type since Decimal128Type::format_decimal can't be called directly
+fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
+    let (sign, rest) = match value_str.strip_prefix('-') {
+        Some(stripped) => ("-", stripped),
+        None => ("", value_str),
+    };
+    let bound = precision.min(rest.len()) + sign.len();
+    let value_str = &value_str[0..bound];
+
+    if scale == 0 {
+        value_str.to_string()
+    } else if scale < 0 {
+        let padding = value_str.len() + scale.unsigned_abs() as usize;
+        format!("{value_str:0<padding$}")
+    } else if rest.len() > scale as usize {
+        // Decimal separator is in the middle of the string
+        let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
+        format!("{whole}.{decimal}")
+    } else {
+        // String has to be padded
+        format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
+    }
+}
+
 impl Cast {
     pub fn new(
         child: Arc<dyn PhysicalExpr>,
@@ -1866,12 +1911,12 @@ fn spark_cast_nonintegral_numeric_to_integral(
         ),
         (DataType::Decimal128(precision, scale), DataType::Int8) => {
             cast_decimal_to_int16_down!(
-                array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
+                array, eval_mode, Int8Array, i8, "TINYINT", *precision, *scale
             )
         }
         (DataType::Decimal128(precision, scale), DataType::Int16) => {
             cast_decimal_to_int16_down!(
-                array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale
+                array, eval_mode, Int16Array, i16, "SMALLINT", *precision, *scale
             )
         }
         (DataType::Decimal128(precision, scale), DataType::Int32) => {
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 1892749be..8a68df382 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -529,6 +529,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   test("cast DecimalType(10,2) to ShortType") {
     castTest(generateDecimalsPrecision10Scale2(), DataTypes.ShortType)
+    castTest(
+      generateDecimalsPrecision10Scale2(Seq(BigDecimal("-96833550.07"))),
+      DataTypes.ShortType)
   }
 
   test("cast DecimalType(10,2) to IntegerType") {
@@ -553,14 +556,23 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   test("cast DecimalType(38,18) to ShortType") {
     castTest(generateDecimalsPrecision38Scale18(), DataTypes.ShortType)
+    castTest(
+      generateDecimalsPrecision38Scale18(Seq(BigDecimal("-99999999999999999999.07"))),
+      DataTypes.ShortType)
   }
 
   test("cast DecimalType(38,18) to IntegerType") {
     castTest(generateDecimalsPrecision38Scale18(), DataTypes.IntegerType)
+    castTest(
+      generateDecimalsPrecision38Scale18(Seq(BigDecimal("-99999999999999999999.07"))),
+      DataTypes.IntegerType)
   }
 
   test("cast DecimalType(38,18) to LongType") {
     castTest(generateDecimalsPrecision38Scale18(), DataTypes.LongType)
+    castTest(
+      generateDecimalsPrecision38Scale18(Seq(BigDecimal("-99999999999999999999.07"))),
+      DataTypes.LongType)
   }
 
   test("cast DecimalType(10,2) to StringType") {
@@ -1205,6 +1217,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       BigDecimal("32768.678"),
       BigDecimal("123456.789"),
       BigDecimal("99999999.999"))
+    generateDecimalsPrecision10Scale2(values)
+  }
+
+  private def generateDecimalsPrecision10Scale2(values: Seq[BigDecimal]): DataFrame = {
     withNulls(values).toDF("b").withColumn("a", col("b").cast(DecimalType(10, 2))).drop("b")
   }
 
@@ -1227,6 +1243,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       // Long Max
       BigDecimal("9223372036854775808.234567"),
       BigDecimal("99999999999999999999.999999999999"))
+    generateDecimalsPrecision38Scale18(values)
+  }
+
+  private def generateDecimalsPrecision38Scale18(values: Seq[BigDecimal]): DataFrame = {
     withNulls(values).toDF("a")
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to