This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new a6cfadb2a feat: Improve compatibility of string to decimal cast (#2925)
a6cfadb2a is described below

commit a6cfadb2a32ab2f02e47462e03219f89459edb5c
Author: B Vadlamani <[email protected]>
AuthorDate: Mon Dec 22 12:32:35 2025 -0800

    feat: Improve compatibility of string to decimal cast (#2925)
---
 docs/source/user-guide/latest/compatibility.md     |   3 +-
 native/spark-expr/src/conversion_funcs/cast.rs     | 319 ++++++++++++++++++++-
 .../org/apache/comet/expressions/CometCast.scala   |   5 +-
 .../scala/org/apache/comet/CometCastSuite.scala    |  92 +++++-
 4 files changed, 395 insertions(+), 24 deletions(-)

diff --git a/docs/source/user-guide/latest/compatibility.md 
b/docs/source/user-guide/latest/compatibility.md
index 60e2234f5..58dd8d6ab 100644
--- a/docs/source/user-guide/latest/compatibility.md
+++ b/docs/source/user-guide/latest/compatibility.md
@@ -183,7 +183,8 @@ The following cast operations are not compatible with Spark 
for all inputs and a
 | double | decimal  | There can be rounding differences |
 | string | float  | Does not support inputs ending with 'd' or 'f'. Does not 
support 'inf'. Does not support ANSI mode. |
 | string | double  | Does not support inputs ending with 'd' or 'f'. Does not 
support 'inf'. Does not support ANSI mode. |
-| string | decimal  | Does not support inputs ending with 'd' or 'f'. Does not 
support 'inf'. Does not support ANSI mode. Returns 0.0 instead of null if input 
contains no digits |
+| string | decimal  | Does not support fullwidth unicode digits (e.g \\uFF10)
+or strings containing null bytes (e.g \\u0000) |
 | string | timestamp  | Not all valid formats are supported |
 <!-- prettier-ignore-end -->
 <!--END:INCOMPAT_CAST_TABLE-->
diff --git a/native/spark-expr/src/conversion_funcs/cast.rs 
b/native/spark-expr/src/conversion_funcs/cast.rs
index 12a147c6e..6b69c7288 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -20,12 +20,13 @@ use crate::{timezone, BinaryOutputStyle};
 use crate::{EvalMode, SparkError, SparkResult};
 use arrow::array::builder::StringBuilder;
 use arrow::array::{
-    BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, 
ListArray, StringArray,
-    StructArray,
+    BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, 
ListArray,
+    PrimitiveBuilder, StringArray, StructArray,
 };
 use arrow::compute::can_cast_types;
 use arrow::datatypes::{
-    ArrowDictionaryKeyType, ArrowNativeType, DataType, GenericBinaryType, 
Schema,
+    i256, ArrowDictionaryKeyType, ArrowNativeType, DataType, Decimal256Type, 
GenericBinaryType,
+    Schema,
 };
 use arrow::{
     array::{
@@ -224,9 +225,7 @@ fn can_cast_from_string(to_type: &DataType, options: 
&SparkCastOptions) -> bool
         }
         Decimal128(_, _) => {
             // https://github.com/apache/datafusion-comet/issues/325
-            // Does not support inputs ending with 'd' or 'f'. Does not 
support 'inf'.
-            // Does not support ANSI mode. Returns 0.0 instead of null if 
input contains no digits
-
+            // Does not support fullwidth digits and null byte handling.
             options.allow_incompat
         }
         Date32 | Date64 => {
@@ -976,6 +975,12 @@ fn cast_array(
             cast_string_to_timestamp(&array, to_type, eval_mode, 
&cast_options.timezone)
         }
         (Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
+        (Utf8 | LargeUtf8, Decimal128(precision, scale)) => {
+            cast_string_to_decimal(&array, to_type, precision, scale, 
eval_mode)
+        }
+        (Utf8 | LargeUtf8, Decimal256(precision, scale)) => {
+            cast_string_to_decimal(&array, to_type, precision, scale, 
eval_mode)
+        }
         (Int64, Int32)
         | (Int64, Int16)
         | (Int64, Int8)
@@ -1187,7 +1192,7 @@ fn is_datafusion_spark_compatible(
         ),
         DataType::Utf8 if allow_incompat => matches!(
             to_type,
-            DataType::Binary | DataType::Float32 | DataType::Float64 | 
DataType::Decimal128(_, _)
+            DataType::Binary | DataType::Float32 | DataType::Float64
         ),
         DataType::Utf8 => matches!(to_type, DataType::Binary),
         DataType::Date32 => matches!(to_type, DataType::Utf8),
@@ -1976,6 +1981,306 @@ fn do_cast_string_to_int<
     Ok(Some(result))
 }
 
+fn cast_string_to_decimal(
+    array: &ArrayRef,
+    to_type: &DataType,
+    precision: &u8,
+    scale: &i8,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+    match to_type {
+        DataType::Decimal128(_, _) => {
+            cast_string_to_decimal128_impl(array, eval_mode, *precision, 
*scale)
+        }
+        DataType::Decimal256(_, _) => {
+            cast_string_to_decimal256_impl(array, eval_mode, *precision, 
*scale)
+        }
+        _ => Err(SparkError::Internal(format!(
+            "Unexpected type in cast_string_to_decimal: {:?}",
+            to_type
+        ))),
+    }
+}
+
+fn cast_string_to_decimal128_impl(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("Expected string 
array".to_string()))?;
+
+    let mut decimal_builder = 
Decimal128Builder::with_capacity(string_array.len());
+
+    for i in 0..string_array.len() {
+        if string_array.is_null(i) {
+            decimal_builder.append_null();
+        } else {
+            let str_value = string_array.value(i);
+            match parse_string_to_decimal(str_value, precision, scale) {
+                Ok(Some(decimal_value)) => {
+                    decimal_builder.append_value(decimal_value);
+                }
+                Ok(None) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(
+                            string_array.value(i),
+                            "STRING",
+                            &format!("DECIMAL({},{})", precision, scale),
+                        ));
+                    }
+                    decimal_builder.append_null();
+                }
+                Err(e) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(e);
+                    }
+                    decimal_builder.append_null();
+                }
+            }
+        }
+    }
+
+    Ok(Arc::new(
+        decimal_builder
+            .with_precision_and_scale(precision, scale)?
+            .finish(),
+    ))
+}
+
+fn cast_string_to_decimal256_impl(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("Expected string 
array".to_string()))?;
+
+    let mut decimal_builder = 
PrimitiveBuilder::<Decimal256Type>::with_capacity(string_array.len());
+
+    for i in 0..string_array.len() {
+        if string_array.is_null(i) {
+            decimal_builder.append_null();
+        } else {
+            let str_value = string_array.value(i);
+            match parse_string_to_decimal(str_value, precision, scale) {
+                Ok(Some(decimal_value)) => {
+                    // Convert i128 to i256
+                    let i256_value = i256::from_i128(decimal_value);
+                    decimal_builder.append_value(i256_value);
+                }
+                Ok(None) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(
+                            str_value,
+                            "STRING",
+                            &format!("DECIMAL({},{})", precision, scale),
+                        ));
+                    }
+                    decimal_builder.append_null();
+                }
+                Err(e) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(e);
+                    }
+                    decimal_builder.append_null();
+                }
+            }
+        }
+    }
+
+    Ok(Arc::new(
+        decimal_builder
+            .with_precision_and_scale(precision, scale)?
+            .finish(),
+    ))
+}
+
+/// Parse a string to decimal following Spark's behavior
+fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> 
SparkResult<Option<i128>> {
+    let string_bytes = s.as_bytes();
+    let mut start = 0;
+    let mut end = string_bytes.len();
+
+    // trim whitespaces
+    while start < end && string_bytes[start].is_ascii_whitespace() {
+        start += 1;
+    }
+    while end > start && string_bytes[end - 1].is_ascii_whitespace() {
+        end -= 1;
+    }
+
+    let trimmed = &s[start..end];
+
+    if trimmed.is_empty() {
+        return Ok(None);
+    }
+    // Handle special values (inf, nan, etc.)
+    if trimmed.eq_ignore_ascii_case("inf")
+        || trimmed.eq_ignore_ascii_case("+inf")
+        || trimmed.eq_ignore_ascii_case("infinity")
+        || trimmed.eq_ignore_ascii_case("+infinity")
+        || trimmed.eq_ignore_ascii_case("-inf")
+        || trimmed.eq_ignore_ascii_case("-infinity")
+        || trimmed.eq_ignore_ascii_case("nan")
+    {
+        return Ok(None);
+    }
+
+    // validate and parse mantissa and exponent
+    match parse_decimal_str(trimmed) {
+        Ok((mantissa, exponent)) => {
+            // Convert to target scale
+            let target_scale = scale as i32;
+            let scale_adjustment = target_scale - exponent;
+
+            let scaled_value = if scale_adjustment >= 0 {
+                // Need to multiply (increase scale) but return None if scale 
is too high to fit i128
+                if scale_adjustment > 38 {
+                    return Ok(None);
+                }
+                mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
+            } else {
+                // Need to multiply (increase scale) but return None if scale 
is too high to fit i128
+                let abs_scale_adjustment = (-scale_adjustment) as u32;
+                if abs_scale_adjustment > 38 {
+                    return Ok(Some(0));
+                }
+
+                let divisor = 10_i128.pow(abs_scale_adjustment);
+                let quotient_opt = mantissa.checked_div(divisor);
+                // Check if divisor is 0
+                if quotient_opt.is_none() {
+                    return Ok(None);
+                }
+                let quotient = quotient_opt.unwrap();
+                let remainder = mantissa % divisor;
+
+                // Round half up: if abs(remainder) >= divisor/2, round away 
from zero
+                let half_divisor = divisor / 2;
+                let rounded = if remainder.abs() >= half_divisor {
+                    if mantissa >= 0 {
+                        quotient + 1
+                    } else {
+                        quotient - 1
+                    }
+                } else {
+                    quotient
+                };
+                Some(rounded)
+            };
+
+            match scaled_value {
+                Some(value) => {
+                    // Check if it fits target precision
+                    if is_validate_decimal_precision(value, precision) {
+                        Ok(Some(value))
+                    } else {
+                        Ok(None)
+                    }
+                }
+                None => {
+                    // Overflow while scaling
+                    Ok(None)
+                }
+            }
+        }
+        Err(_) => Ok(None),
+    }
+}
+
+/// Parse a decimal string into mantissa and scale
+/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
+fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
+    if s.is_empty() {
+        return Err("Empty string".to_string());
+    }
+
+    let (mantissa_str, exponent) = if let Some(e_pos) = s.find(|c| ['e', 
'E'].contains(&c)) {
+        let mantissa_part = &s[..e_pos];
+        let exponent_part = &s[e_pos + 1..];
+        // Parse exponent
+        let exp: i32 = exponent_part
+            .parse()
+            .map_err(|e| format!("Invalid exponent: {}", e))?;
+
+        (mantissa_part, exp)
+    } else {
+        (s, 0)
+    };
+
+    let negative = mantissa_str.starts_with('-');
+    let mantissa_str = if negative || mantissa_str.starts_with('+') {
+        &mantissa_str[1..]
+    } else {
+        mantissa_str
+    };
+
+    if mantissa_str.starts_with('+') || mantissa_str.starts_with('-') {
+        return Err("Invalid sign format".to_string());
+    }
+
+    let (integral_part, fractional_part) = match mantissa_str.find('.') {
+        Some(dot_pos) => {
+            if mantissa_str[dot_pos + 1..].contains('.') {
+                return Err("Multiple decimal points".to_string());
+            }
+            (&mantissa_str[..dot_pos], &mantissa_str[dot_pos + 1..])
+        }
+        None => (mantissa_str, ""),
+    };
+
+    if integral_part.is_empty() && fractional_part.is_empty() {
+        return Err("No digits found".to_string());
+    }
+
+    if !integral_part.is_empty() && !integral_part.bytes().all(|b| 
b.is_ascii_digit()) {
+        return Err("Invalid integral part".to_string());
+    }
+
+    if !fractional_part.is_empty() && !fractional_part.bytes().all(|b| 
b.is_ascii_digit()) {
+        return Err("Invalid fractional part".to_string());
+    }
+
+    // Parse integral part
+    let integral_value: i128 = if integral_part.is_empty() {
+        // Empty integral part is valid (e.g., ".5" or "-.7e9")
+        0
+    } else {
+        integral_part
+            .parse()
+            .map_err(|_| "Invalid integral part".to_string())?
+    };
+
+    // Parse fractional part
+    let fractional_scale = fractional_part.len() as i32;
+    let fractional_value: i128 = if fractional_part.is_empty() {
+        0
+    } else {
+        fractional_part
+            .parse()
+            .map_err(|_| "Invalid fractional part".to_string())?
+    };
+
+    // Combine: value = integral * 10^fractional_scale + fractional
+    let mantissa = integral_value
+        .checked_mul(10_i128.pow(fractional_scale as u32))
+        .and_then(|v| v.checked_add(fractional_value))
+        .ok_or("Overflow in mantissa calculation")?;
+
+    let final_mantissa = if negative { -mantissa } else { mantissa };
+    // final scale = fractional_scale - exponent
+    // For example : "1.23E-5" has fractional_scale=2, exponent=-5, so scale = 
2 - (-5) = 7
+    let final_scale = fractional_scale - exponent;
+    Ok((final_mantissa, final_scale))
+}
+
 /// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on 
the evaluation mode
 #[inline]
 fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> 
SparkResult<Option<T>> {
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala 
b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 98ce8ac44..14db7c278 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -192,9 +192,8 @@ object CometCast extends CometExpressionSerde[Cast] with 
CometExprShim {
               "Does not support ANSI mode."))
       case _: DecimalType =>
         // https://github.com/apache/datafusion-comet/issues/325
-        Incompatible(
-          Some("Does not support inputs ending with 'd' or 'f'. Does not 
support 'inf'. " +
-            "Does not support ANSI mode. Returns 0.0 instead of null if input 
contains no digits"))
+        Incompatible(Some("""Does not support fullwidth unicode digits (e.g 
\\uFF10)
+            |or strings containing null bytes (e.g \\u0000)""".stripMargin))
       case DataTypes.DateType =>
         // https://github.com/apache/datafusion-comet/issues/327
         Compatible(Some("Only supports years between 262143 BC and 262142 AD"))
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 90386a979..a7bd6febf 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -661,7 +661,6 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     // https://github.com/apache/datafusion-comet/issues/326
     castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), 
DataTypes.DoubleType)
   }
-
   test("cast StringType to DoubleType (partial support)") {
     withSQLConf(
       CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
@@ -673,21 +672,88 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     }
   }
 
+//  This is to pass the first `all cast combinations are covered`
   ignore("cast StringType to DecimalType(10,2)") {
-    // https://github.com/apache/datafusion-comet/issues/325
-    val values = gen.generateStrings(dataSize, numericPattern, 8).toDF("a")
-    castTest(values, DataTypes.createDecimalType(10, 2))
+    val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
+    castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = false)
   }
 
-  test("cast StringType to DecimalType(10,2) (partial support)") {
-    withSQLConf(
-      CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
-      SQLConf.ANSI_ENABLED.key -> "false") {
-      val values = gen
-        .generateStrings(dataSize, "0123456789.", 8)
-        .filter(_.exists(_.isDigit))
-        .toDF("a")
-      castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = false)
+  test("cast StringType to DecimalType(10,2) (does not support fullwidth 
unicode digits)") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> 
"true") {
+      // TODO fix for Spark 4.0.0
+      assume(!isSpark40Plus)
+      val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = 
ansiEnabled))
+    }
+  }
+
+  test("cast StringType to DecimalType(2,2)") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> 
"true") {
+      // TODO fix for Spark 4.0.0
+      assume(!isSpark40Plus)
+      val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = 
ansiEnabled))
+    }
+  }
+
+  test("cast StringType to DecimalType(38,10) high precision") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> 
"true") {
+      // TODO fix for Spark 4.0.0
+      assume(!isSpark40Plus)
+      val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = 
ansiEnabled))
+    }
+  }
+
+  test("cast StringType to DecimalType(10,2) basic values") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> 
"true") {
+      // TODO fix for Spark 4.0.0
+      assume(!isSpark40Plus)
+      val values = Seq(
+        "123.45",
+        "-67.89",
+        "-67.89",
+        "-67.895",
+        "67.895",
+        "0.001",
+        "999.99",
+        "123.456",
+        "123.45D",
+        ".5",
+        "5.",
+        "+123.45",
+        "  123.45  ",
+        "inf",
+        "",
+        "abc",
+        null).toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = 
ansiEnabled))
+    }
+  }
+
+  test("cast StringType to Decimal type scientific notation") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> 
"true") {
+      // TODO fix for Spark 4.0.0
+      assume(!isSpark40Plus)
+      val values = Seq(
+        "1.23E-5",
+        "1.23e10",
+        "1.23E+10",
+        "-1.23e-5",
+        "1e5",
+        "1E-2",
+        "-1.5e3",
+        "1.23E0",
+        "0e0",
+        "1.23e",
+        "e5",
+        null).toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(23, 8), testAnsi = 
ansiEnabled))
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to