andygrove commented on code in PR #2925:
URL: https://github.com/apache/datafusion-comet/pull/2925#discussion_r2625347126


##########
native/spark-expr/src/conversion_funcs/cast.rs:
##########
@@ -1976,6 +1975,363 @@ fn do_cast_string_to_int<
     Ok(Some(result))
 }
 
+fn cast_string_to_decimal(
+    array: &ArrayRef,
+    to_type: &DataType,
+    precision: &u8,
+    scale: &i8,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+    match to_type {
+        DataType::Decimal128(_, _) => {
+            cast_string_to_decimal128_impl(array, eval_mode, *precision, 
*scale)
+        }
+        DataType::Decimal256(_, _) => {
+            cast_string_to_decimal256_impl(array, eval_mode, *precision, 
*scale)
+        }
+        _ => Err(SparkError::Internal(format!(
+            "Unexpected type in cast_string_to_decimal: {:?}",
+            to_type
+        ))),
+    }
+}
+
+fn cast_string_to_decimal128_impl(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("Expected string 
array".to_string()))?;
+
+    let mut decimal_builder = 
Decimal128Builder::with_capacity(string_array.len());
+
+    for i in 0..string_array.len() {
+        if string_array.is_null(i) {
+            decimal_builder.append_null();
+        } else {
+            let str_value = string_array.value(i).trim();
+            match parse_string_to_decimal(str_value, precision, scale) {
+                Ok(Some(decimal_value)) => {
+                    decimal_builder.append_value(decimal_value);
+                }
+                Ok(None) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(
+                            string_array.value(i),
+                            "STRING",
+                            &format!("DECIMAL({},{})", precision, scale),
+                        ));
+                    }
+                    decimal_builder.append_null();
+                }
+                Err(e) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(e);
+                    }
+                    decimal_builder.append_null();
+                }
+            }
+        }
+    }
+
+    Ok(Arc::new(
+        decimal_builder
+            .with_precision_and_scale(precision, scale)?
+            .finish(),
+    ))
+}
+
+fn cast_string_to_decimal256_impl(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("Expected string 
array".to_string()))?;
+
+    let mut decimal_builder = 
PrimitiveBuilder::<Decimal256Type>::with_capacity(string_array.len());
+
+    for i in 0..string_array.len() {
+        if string_array.is_null(i) {
+            decimal_builder.append_null();
+        } else {
+            let str_value = string_array.value(i).trim();
+            match parse_string_to_decimal(str_value, precision, scale) {
+                Ok(Some(decimal_value)) => {
+                    // Convert i128 to i256
+                    let i256_value = i256::from_i128(decimal_value);
+                    decimal_builder.append_value(i256_value);
+                }
+                Ok(None) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(
+                            str_value,
+                            "STRING",
+                            &format!("DECIMAL({},{})", precision, scale),
+                        ));
+                    }
+                    decimal_builder.append_null();
+                }
+                Err(e) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(e);
+                    }
+                    decimal_builder.append_null();
+                }
+            }
+        }
+    }
+
+    Ok(Arc::new(
+        decimal_builder
+            .with_precision_and_scale(precision, scale)?
+            .finish(),
+    ))
+}
+
+/// Checks whether a string is a valid decimal literal, similar to BigDecimal
+fn is_valid_decimal_format(s: &str) -> bool {
+    if s.is_empty() {
+        return false;
+    }
+
+    let bytes = s.as_bytes();
+    let mut idx = 0;
+    let len = bytes.len();
+
+    // Skip leading +/- signs
+    if bytes[idx] == b'+' || bytes[idx] == b'-' {
+        idx += 1;
+        if idx >= len {
+            // Sign only. Fail early
+            return false;
+        }
+    }
+
+    // Check invalid cases like "++", "+-"
+    if bytes[idx] == b'+' || bytes[idx] == b'-' {
+        return false;
+    }
+
+    // Now we need at least one digit either before or after a decimal point
+    let mut has_digit = false;
+    let mut is_decimal_point_seen = false;
+
+    while idx < len {
+        let ch = bytes[idx];
+
+        if ch.is_ascii_digit() {
+            has_digit = true;
+            idx += 1;
+        } else if ch == b'.' {
+            if is_decimal_point_seen {
+                // Multiple decimal points or decimal after exponent
+                return false;
+            }
+            is_decimal_point_seen = true;
+            idx += 1;
+        } else if ch.eq_ignore_ascii_case(&b'e') {
+            if !has_digit {
+                // Exponent without any digits before it
+                return false;
+            }
+            idx += 1;
+            // Exponent part must have an optional sign followed by at least 
one digit
+            if idx >= len {
+                return false;
+            }
+
+            if bytes[idx] == b'+' || bytes[idx] == b'-' {
+                idx += 1;
+                if idx >= len {
+                    return false;
+                }
+            }
+
+            // Must have at least one digit in exponent
+            if !bytes[idx].is_ascii_digit() {
+                return false;
+            }
+
+            // Rest all should only be digits
+            while idx < len {
+                if !bytes[idx].is_ascii_digit() {
+                    return false;
+                }
+                idx += 1;
+            }
+            break;
+        } else {
+            // Invalid character found. Fail fast
+            return false;
+        }
+    }
+    has_digit
+}
+
+/// Parse a string to decimal following Spark's behavior
+fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> 
SparkResult<Option<i128>> {
+    if s.is_empty() {
+        return Ok(None);
+    }
+    // Handle special values (inf, nan, etc.)
+    if s.eq_ignore_ascii_case("inf")
+        || s.eq_ignore_ascii_case("+inf")
+        || s.eq_ignore_ascii_case("infinity")
+        || s.eq_ignore_ascii_case("+infinity")
+        || s.eq_ignore_ascii_case("-inf")
+        || s.eq_ignore_ascii_case("-infinity")
+        || s.eq_ignore_ascii_case("nan")
+    {
+        return Ok(None);
+    }
+
+    if !is_valid_decimal_format(s) {
+        return Ok(None);
+    }
+
+    match parse_decimal_str(s) {
+        Ok((mantissa, exponent)) => {
+            // Convert to target scale
+            let target_scale = scale as i32;
+            let scale_adjustment = target_scale - exponent;
+
+            let scaled_value = if scale_adjustment >= 0 {
+                // Need to multiply (increase scale) but return None if scale 
is too high to fit i128
+                if scale_adjustment > 38 {
+                    return Ok(None);
+                }
+                mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
+            } else {
+                // Need to divide (decrease scale); a shift of more than 38 
digits yields zero
+                let abs_scale_adjustment = (-scale_adjustment) as u32;
+                if abs_scale_adjustment > 38 {
+                    return Ok(Some(0));
+                }
+
+                let divisor = 10_i128.pow(abs_scale_adjustment);
+                let quotient_opt = mantissa.checked_div(divisor);
+                // Check if divisor is 0
+                if quotient_opt.is_none() {
+                    return Ok(None);
+                }
+                let quotient = quotient_opt.unwrap();
+                let remainder = mantissa % divisor;
+
+                // Round half up: if abs(remainder) >= divisor/2, round away 
from zero
+                let half_divisor = divisor / 2;
+                let rounded = if remainder.abs() >= half_divisor {
+                    if mantissa >= 0 {
+                        quotient + 1
+                    } else {
+                        quotient - 1
+                    }
+                } else {
+                    quotient
+                };
+                Some(rounded)
+            };
+
+            match scaled_value {
+                Some(value) => {
+                    // Check if it fits target precision
+                    if is_validate_decimal_precision(value, precision) {
+                        Ok(Some(value))
+                    } else {
+                        Ok(None)
+                    }
+                }
+                None => {
+                    // Overflow while scaling
+                    Ok(None)
+                }
+            }
+        }
+        Err(_) => Ok(None),
+    }
+}
+
+/// Parse a decimal string into mantissa and scale
+/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
+fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
+    let s = s.trim();
+    if s.is_empty() {
+        return Err("Empty string".to_string());
+    }
+
+    let (mantissa_str, exponent) = if let Some(e_pos) = s.find(|c| ['e', 
'E'].contains(&c)) {
+        let mantissa_part = &s[..e_pos];
+        let exponent_part = &s[e_pos + 1..];
+        // Parse exponent
+        let exp: i32 = exponent_part
+            .parse()
+            .map_err(|e| format!("Invalid exponent: {}", e))?;
+
+        (mantissa_part, exp)
+    } else {
+        (s, 0)
+    };
+
+    let negative = mantissa_str.starts_with('-');
+    let mantissa_str = if negative || mantissa_str.starts_with('+') {
+        &mantissa_str[1..]
+    } else {
+        mantissa_str
+    };
+
+    let split_by_dot: Vec<&str> = mantissa_str.split('.').collect();

Review Comment:
   Calling `split` seems expensive as well.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to