This is an automated email from the ASF dual-hosted git repository.
wangzhen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 726c4b77a fix: format decimal to string when casting decimal with
overflow (#2916)
726c4b77a is described below
commit 726c4b77af7fa964f677388114df74f12ddcfdb1
Author: Manu Zhang <[email protected]>
AuthorDate: Mon Jan 5 13:35:44 2026 +0800
fix: format decimal to string when casting decimal with overflow (#2916)
---
native/spark-expr/src/conversion_funcs/cast.rs | 59 +++++++++++++++++++---
.../scala/org/apache/comet/CometCastSuite.scala | 20 ++++++++
2 files changed, 72 insertions(+), 7 deletions(-)
diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 501191708..a2e12168d 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -685,11 +685,18 @@ macro_rules! cast_decimal_to_int16_down {
.map(|value| match value {
Some(value) => {
let divisor = 10_i128.pow($scale as u32);
- let (truncated, decimal) = (value / divisor, (value %
divisor).abs());
+ let truncated = value / divisor;
let is_overflow = truncated.abs() > i32::MAX.into();
if is_overflow {
return Err(cast_overflow(
- &format!("{}.{}BD", truncated, decimal),
+ &format!(
+ "{}BD",
+ format_decimal_str(
+ &value.to_string(),
+ $precision as usize,
+ $scale
+ )
+ ),
&format!("DECIMAL({},{})", $precision, $scale),
$dest_type_str,
));
@@ -698,7 +705,14 @@ macro_rules! cast_decimal_to_int16_down {
<$rust_dest_type>::try_from(i32_value)
.map_err(|_| {
cast_overflow(
- &format!("{}.{}BD", truncated, decimal),
+ &format!(
+ "{}BD",
+ format_decimal_str(
+ &value.to_string(),
+ $precision as usize,
+ $scale
+ )
+ ),
&format!("DECIMAL({},{})", $precision,
$scale),
$dest_type_str,
)
@@ -748,11 +762,18 @@ macro_rules! cast_decimal_to_int32_up {
.map(|value| match value {
Some(value) => {
let divisor = 10_i128.pow($scale as u32);
- let (truncated, decimal) = (value / divisor, (value %
divisor).abs());
+ let truncated = value / divisor;
let is_overflow = truncated.abs() >
$max_dest_val.into();
if is_overflow {
return Err(cast_overflow(
- &format!("{}.{}BD", truncated, decimal),
+ &format!(
+ "{}BD",
+ format_decimal_str(
+ &value.to_string(),
+ $precision as usize,
+ $scale
+ )
+ ),
&format!("DECIMAL({},{})", $precision, $scale),
$dest_type_str,
));
@@ -780,6 +801,30 @@ macro_rules! cast_decimal_to_int32_up {
}};
}
+// copied from arrow::dataTypes::Decimal128Type since Decimal128Type::format_decimal can't be called directly
+/// Renders the text of a decimal's unscaled integer as a plain decimal number
+/// with `scale` fractional digits.
+///
+/// * `value_str` - unscaled value as text, optionally starting with `-`
+///   (the cast macros in this file pass `&value.to_string()`).
+/// * `precision` - maximum number of digits kept from the front of `value_str`.
+/// * `scale` - number of fractional digits; a negative scale appends zeros.
+///
+/// For example `("-9683355007", 10, 2)` yields `"-96833550.07"` and
+/// `("5", 10, 2)` yields `"0.05"`.
+fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
+ // Separate an optional leading minus sign from the digits.
+ let (sign, rest) = match value_str.strip_prefix('-') {
+ Some(stripped) => ("-", stripped),
+ None => ("", value_str),
+ };
+ // Keep at most `precision` digits, plus the sign character if present.
+ let bound = precision.min(rest.len()) + sign.len();
+ let value_str = &value_str[0..bound];
+
+ if scale == 0 {
+ // No fractional part: the (possibly truncated) integer text is the result.
+ value_str.to_string()
+ } else if scale < 0 {
+ // Negative scale: right-pad with `|scale|` zeros.
+ let padding = value_str.len() + scale.unsigned_abs() as usize;
+ format!("{value_str:0<padding$}")
+ } else if rest.len() > scale as usize {
+ // Decimal separator is in the middle of the string
+ // NOTE(review): the check uses the untruncated `rest`, while the split uses
+ // the truncated `value_str`; if `precision < scale` the subtraction below
+ // could underflow. Presumably callers guarantee `scale <= precision`; confirm.
+ let (whole, decimal) = value_str.split_at(value_str.len() - scale as
usize);
+ format!("{whole}.{decimal}")
+ } else {
+ // String has to be padded
+ // Fewer digits than `scale`: emit `0.` then the digits left-padded with zeros.
+ format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
+ }
+}
+
impl Cast {
pub fn new(
child: Arc<dyn PhysicalExpr>,
@@ -1866,12 +1911,12 @@ fn spark_cast_nonintegral_numeric_to_integral(
),
(DataType::Decimal128(precision, scale), DataType::Int8) => {
cast_decimal_to_int16_down!(
- array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
+ array, eval_mode, Int8Array, i8, "TINYINT", *precision, *scale
)
}
(DataType::Decimal128(precision, scale), DataType::Int16) => {
cast_decimal_to_int16_down!(
- array, eval_mode, Int16Array, i16, "SMALLINT", precision,
*scale
+ array, eval_mode, Int16Array, i16, "SMALLINT", *precision,
*scale
)
}
(DataType::Decimal128(precision, scale), DataType::Int32) => {
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 1892749be..8a68df382 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -529,6 +529,9 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
test("cast DecimalType(10,2) to ShortType") {
castTest(generateDecimalsPrecision10Scale2(), DataTypes.ShortType)
+ castTest(
+ generateDecimalsPrecision10Scale2(Seq(BigDecimal("-96833550.07"))),
+ DataTypes.ShortType)
}
test("cast DecimalType(10,2) to IntegerType") {
@@ -553,14 +556,23 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
test("cast DecimalType(38,18) to ShortType") {
castTest(generateDecimalsPrecision38Scale18(), DataTypes.ShortType)
+ castTest(
+
generateDecimalsPrecision38Scale18(Seq(BigDecimal("-99999999999999999999.07"))),
+ DataTypes.ShortType)
}
test("cast DecimalType(38,18) to IntegerType") {
castTest(generateDecimalsPrecision38Scale18(), DataTypes.IntegerType)
+ castTest(
+
generateDecimalsPrecision38Scale18(Seq(BigDecimal("-99999999999999999999.07"))),
+ DataTypes.IntegerType)
}
test("cast DecimalType(38,18) to LongType") {
castTest(generateDecimalsPrecision38Scale18(), DataTypes.LongType)
+ castTest(
+
generateDecimalsPrecision38Scale18(Seq(BigDecimal("-99999999999999999999.07"))),
+ DataTypes.LongType)
}
test("cast DecimalType(10,2) to StringType") {
@@ -1205,6 +1217,10 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
BigDecimal("32768.678"),
BigDecimal("123456.789"),
BigDecimal("99999999.999"))
+ generateDecimalsPrecision10Scale2(values)
+ }
+
+ private def generateDecimalsPrecision10Scale2(values: Seq[BigDecimal]): DataFrame = {
withNulls(values).toDF("b").withColumn("a", col("b").cast(DecimalType(10,
2))).drop("b")
}
@@ -1227,6 +1243,10 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
// Long Max
BigDecimal("9223372036854775808.234567"),
BigDecimal("99999999999999999999.999999999999"))
+ generateDecimalsPrecision38Scale18(values)
+ }
+
+ private def generateDecimalsPrecision38Scale18(values: Seq[BigDecimal]): DataFrame = {
withNulls(values).toDF("a")
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]