This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 069681a16 perf: Improve performance of CAST from string to int (#3017)
069681a16 is described below

commit 069681a163f9bad8dea26795163bae1c81303005
Author: B Vadlamani <[email protected]>
AuthorDate: Tue Jan 6 11:40:05 2026 -0800

    perf: Improve performance of CAST from string to int (#3017)
---
 native/spark-expr/src/conversion_funcs/cast.rs | 286 +++++++++++++++++++------
 1 file changed, 217 insertions(+), 69 deletions(-)

diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index a2e12168d..314beb18c 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -54,8 +54,8 @@ use datafusion::common::{
 use datafusion::physical_expr::PhysicalExpr;
 use datafusion::physical_plan::ColumnarValue;
 use num::{
-    cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num,
-    ToPrimitive, Zero,
+    cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, ToPrimitive,
+    Zero,
 };
 use regex::Regex;
 use std::str::FromStr;
@@ -389,13 +389,23 @@ macro_rules! cast_utf8_to_int {
     ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
         let len = $array.len();
         let mut cast_array = PrimitiveArray::<$array_type>::builder(len);
-        for i in 0..len {
-            if $array.is_null(i) {
-                cast_array.append_null()
-            } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
-                cast_array.append_value(cast_value);
-            } else {
-                cast_array.append_null()
+        if $array.null_count() == 0 {
+            for i in 0..len {
+                if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
+                    cast_array.append_value(cast_value);
+                } else {
+                    cast_array.append_null()
+                }
+            }
+        } else {
+            for i in 0..len {
+                if $array.is_null(i) {
+                    cast_array.append_null()
+                } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
+                    cast_array.append_value(cast_value);
+                } else {
+                    cast_array.append_null()
+                }
             }
         }
         let result: SparkResult<ArrayRef> = Ok(Arc::new(cast_array.finish()) as ArrayRef);
@@ -1999,82 +2009,216 @@ fn cast_string_to_int_with_range_check(
     }
 }
 
+// Returns (start, end) indices after trimming whitespace
+fn trim_whitespace(bytes: &[u8]) -> (usize, usize) {
+    let mut start = 0;
+    let mut end = bytes.len();
+
+    while start < end && bytes[start].is_ascii_whitespace() {
+        start += 1;
+    }
+    while end > start && bytes[end - 1].is_ascii_whitespace() {
+        end -= 1;
+    }
+
+    (start, end)
+}
+
+// Parses sign and returns (is_negative, start_idx after sign)
+// Returns None if invalid (e.g., just "+" or "-")
+fn parse_sign(trimmed_bytes: &[u8]) -> Option<(bool, usize)> {
+    let len = trimmed_bytes.len();
+    if len == 0 {
+        return None;
+    }
+
+    let first_char = trimmed_bytes[0];
+    let negative = first_char == b'-';
+
+    if negative || first_char == b'+' {
+        if len == 1 {
+            return None;
+        }
+        Some((negative, 1))
+    } else {
+        Some((false, 0))
+    }
+}
+
 /// Equivalent to
 /// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal)
 /// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal)
-fn do_cast_string_to_int<
-    T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From<i32> + Copy,
->(
+fn do_parse_string_to_int_legacy<T: Integer + CheckedSub + CheckedNeg + From<u8> + Copy>(
     str: &str,
-    eval_mode: EvalMode,
-    type_name: &str,
     min_value: T,
 ) -> SparkResult<Option<T>> {
-    let trimmed_str = str.trim();
-    if trimmed_str.is_empty() {
-        return none_or_err(eval_mode, type_name, str);
+    let bytes = str.as_bytes();
+    let (start, end) = trim_whitespace(bytes);
+
+    if start == end {
+        return Ok(None);
     }
-    let len = trimmed_str.len();
+    let trimmed_bytes = &bytes[start..end];
+
+    let (negative, idx) = match parse_sign(trimmed_bytes) {
+        Some(result) => result,
+        None => return Ok(None),
+    };
+
     let mut result: T = T::zero();
-    let mut negative = false;
-    let radix = T::from(10);
+
+    let radix = T::from(10_u8);
     let stop_value = min_value / radix;
     let mut parse_sign_and_digits = true;
 
-    for (i, ch) in trimmed_str.char_indices() {
+    for &ch in &trimmed_bytes[idx..] {
         if parse_sign_and_digits {
-            if i == 0 {
-                negative = ch == '-';
-                let positive = ch == '+';
-                if negative || positive {
-                    if i + 1 == len {
-                        // input string is just "+" or "-"
-                        return none_or_err(eval_mode, type_name, str);
-                    }
-                    // consume this char
-                    continue;
-                }
+            if ch == b'.' {
+                // truncate decimal in legacy mode
+                parse_sign_and_digits = false;
+                continue;
             }
 
-            if ch == '.' {
-                if eval_mode == EvalMode::Legacy {
-                    // truncate decimal in legacy mode
-                    parse_sign_and_digits = false;
-                    continue;
-                } else {
-                    return none_or_err(eval_mode, type_name, str);
-                }
+            if !ch.is_ascii_digit() {
+                return Ok(None);
             }
 
-            let digit = if ch.is_ascii_digit() {
-                (ch as u32) - ('0' as u32)
-            } else {
-                return none_or_err(eval_mode, type_name, str);
-            };
+            let digit: T = T::from(ch - b'0');
 
-            // We are going to process the new digit and accumulate the result. However, before
-            // doing this, if the result is already smaller than the
-            // stopValue(Integer.MIN_VALUE / radix), then result * 10 will definitely be
-            // smaller than minValue, and we can stop
             if result < stop_value {
-                return none_or_err(eval_mode, type_name, str);
+                return Ok(None);
             }
-
-            // Since the previous result is greater than or equal to stopValue(Integer.MIN_VALUE /
-            // radix), we can just use `result > 0` to check overflow. If result
-            // overflows, we should stop
             let v = result * radix;
-            let digit = (digit as i32).into();
             match v.checked_sub(&digit) {
                 Some(x) if x <= T::zero() => result = x,
                 _ => {
-                    return none_or_err(eval_mode, type_name, str);
+                    return Ok(None);
                 }
             }
         } else {
-            // make sure fractional digits are valid digits but ignore them
+            // in legacy mode we still process chars after the dot and make sure the chars are digits
             if !ch.is_ascii_digit() {
-                return none_or_err(eval_mode, type_name, str);
+                return Ok(None);
+            }
+        }
+    }
+
+    if !negative {
+        if let Some(neg) = result.checked_neg() {
+            if neg < T::zero() {
+                return Ok(None);
+            }
+            result = neg;
+        } else {
+            return Ok(None);
+        }
+    }
+
+    Ok(Some(result))
+}
+
+fn do_parse_string_to_int_ansi<T: Integer + CheckedSub + CheckedNeg + From<u8> + Copy>(
+    str: &str,
+    type_name: &str,
+    min_value: T,
+) -> SparkResult<Option<T>> {
+    let bytes = str.as_bytes();
+    let (start, end) = trim_whitespace(bytes);
+
+    if start == end {
+        return Err(invalid_value(str, "STRING", type_name));
+    }
+    let trimmed_bytes = &bytes[start..end];
+
+    let (negative, idx) = match parse_sign(trimmed_bytes) {
+        Some(result) => result,
+        None => return Err(invalid_value(str, "STRING", type_name)),
+    };
+
+    let mut result: T = T::zero();
+
+    let radix = T::from(10_u8);
+    let stop_value = min_value / radix;
+
+    for &ch in &trimmed_bytes[idx..] {
+        if ch == b'.' {
+            return Err(invalid_value(str, "STRING", type_name));
+        }
+
+        if !ch.is_ascii_digit() {
+            return Err(invalid_value(str, "STRING", type_name));
+        }
+
+        let digit: T = T::from(ch - b'0');
+
+        if result < stop_value {
+            return Err(invalid_value(str, "STRING", type_name));
+        }
+        let v = result * radix;
+        match v.checked_sub(&digit) {
+            Some(x) if x <= T::zero() => result = x,
+            _ => {
+                return Err(invalid_value(str, "STRING", type_name));
+            }
+        }
+    }
+
+    if !negative {
+        if let Some(neg) = result.checked_neg() {
+            if neg < T::zero() {
+                return Err(invalid_value(str, "STRING", type_name));
+            }
+            result = neg;
+        } else {
+            return Err(invalid_value(str, "STRING", type_name));
+        }
+    }
+
+    Ok(Some(result))
+}
+
+fn do_parse_string_to_int_try<T: Integer + CheckedSub + CheckedNeg + From<u8> + Copy>(
+    str: &str,
+    min_value: T,
+) -> SparkResult<Option<T>> {
+    let bytes = str.as_bytes();
+    let (start, end) = trim_whitespace(bytes);
+
+    if start == end {
+        return Ok(None);
+    }
+    let trimmed_bytes = &bytes[start..end];
+
+    let (negative, idx) = match parse_sign(trimmed_bytes) {
+        Some(result) => result,
+        None => return Ok(None),
+    };
+
+    let mut result: T = T::zero();
+
+    let radix = T::from(10_u8);
+    let stop_value = min_value / radix;
+
+    // we don't have to go beyond decimal point in try eval mode - early return NULL
+    for &ch in &trimmed_bytes[idx..] {
+        if ch == b'.' {
+            return Ok(None);
+        }
+
+        if !ch.is_ascii_digit() {
+            return Ok(None);
+        }
+
+        let digit: T = T::from(ch - b'0');
+
+        if result < stop_value {
+            return Ok(None);
+        }
+        let v = result * radix;
+        match v.checked_sub(&digit) {
+            Some(x) if x <= T::zero() => result = x,
+            _ => {
+                return Ok(None);
             }
         }
     }
@@ -2082,17 +2226,30 @@ fn do_cast_string_to_int<
     if !negative {
         if let Some(neg) = result.checked_neg() {
             if neg < T::zero() {
-                return none_or_err(eval_mode, type_name, str);
+                return Ok(None);
             }
             result = neg;
         } else {
-            return none_or_err(eval_mode, type_name, str);
+            return Ok(None);
         }
     }
 
     Ok(Some(result))
 }
 
+fn do_cast_string_to_int<T: Integer + CheckedSub + CheckedNeg + From<u8> + Copy>(
+    str: &str,
+    eval_mode: EvalMode,
+    type_name: &str,
+    min_value: T,
+) -> SparkResult<Option<T>> {
+    match eval_mode {
+        EvalMode::Legacy => do_parse_string_to_int_legacy(str, min_value),
+        EvalMode::Ansi => do_parse_string_to_int_ansi(str, type_name, min_value),
+        EvalMode::Try => do_parse_string_to_int_try(str, min_value),
+    }
+}
+
 fn cast_string_to_decimal(
     array: &ArrayRef,
     to_type: &DataType,
@@ -2393,15 +2550,6 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
     Ok((final_mantissa, final_scale))
 }
 
-/// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode
-#[inline]
-fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult<Option<T>> {
-    match eval_mode {
-        EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
-        _ => Ok(None),
-    }
-}
-
 #[inline]
 fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError {
     SparkError::CastInvalidValue {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to