This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 069681a16 perf: Improve performance of CAST from string to int (#3017)
069681a16 is described below
commit 069681a163f9bad8dea26795163bae1c81303005
Author: B Vadlamani <[email protected]>
AuthorDate: Tue Jan 6 11:40:05 2026 -0800
perf: Improve performance of CAST from string to int (#3017)
---
native/spark-expr/src/conversion_funcs/cast.rs | 286 +++++++++++++++++++------
1 file changed, 217 insertions(+), 69 deletions(-)
diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index a2e12168d..314beb18c 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -54,8 +54,8 @@ use datafusion::common::{
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_plan::ColumnarValue;
use num::{
- cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num,
- ToPrimitive, Zero,
+ cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, ToPrimitive,
+ Zero,
};
use regex::Regex;
use std::str::FromStr;
@@ -389,13 +389,23 @@ macro_rules! cast_utf8_to_int {
($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
let len = $array.len();
let mut cast_array = PrimitiveArray::<$array_type>::builder(len);
- for i in 0..len {
- if $array.is_null(i) {
- cast_array.append_null()
- } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
- cast_array.append_value(cast_value);
- } else {
- cast_array.append_null()
+ if $array.null_count() == 0 {
+ for i in 0..len {
+ if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
+ cast_array.append_value(cast_value);
+ } else {
+ cast_array.append_null()
+ }
+ }
+ } else {
+ for i in 0..len {
+ if $array.is_null(i) {
+ cast_array.append_null()
+ } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
+ cast_array.append_value(cast_value);
+ } else {
+ cast_array.append_null()
+ }
}
}
let result: SparkResult<ArrayRef> = Ok(Arc::new(cast_array.finish()) as ArrayRef);
@@ -1999,82 +2009,216 @@ fn cast_string_to_int_with_range_check(
}
}
+// Returns (start, end) indices after trimming whitespace
+fn trim_whitespace(bytes: &[u8]) -> (usize, usize) {
+ let mut start = 0;
+ let mut end = bytes.len();
+
+ while start < end && bytes[start].is_ascii_whitespace() {
+ start += 1;
+ }
+ while end > start && bytes[end - 1].is_ascii_whitespace() {
+ end -= 1;
+ }
+
+ (start, end)
+}
+
+// Parses sign and returns (is_negative, start_idx after sign)
+// Returns None if invalid (e.g., just "+" or "-")
+fn parse_sign(trimmed_bytes: &[u8]) -> Option<(bool, usize)> {
+ let len = trimmed_bytes.len();
+ if len == 0 {
+ return None;
+ }
+
+ let first_char = trimmed_bytes[0];
+ let negative = first_char == b'-';
+
+ if negative || first_char == b'+' {
+ if len == 1 {
+ return None;
+ }
+ Some((negative, 1))
+ } else {
+ Some((false, 0))
+ }
+}
+
/// Equivalent to
/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal)
/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal)
-fn do_cast_string_to_int<
- T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From<i32> + Copy,
->(
+fn do_parse_string_to_int_legacy<T: Integer + CheckedSub + CheckedNeg + From<u8> + Copy>(
str: &str,
- eval_mode: EvalMode,
- type_name: &str,
min_value: T,
) -> SparkResult<Option<T>> {
- let trimmed_str = str.trim();
- if trimmed_str.is_empty() {
- return none_or_err(eval_mode, type_name, str);
+ let bytes = str.as_bytes();
+ let (start, end) = trim_whitespace(bytes);
+
+ if start == end {
+ return Ok(None);
}
- let len = trimmed_str.len();
+ let trimmed_bytes = &bytes[start..end];
+
+ let (negative, idx) = match parse_sign(trimmed_bytes) {
+ Some(result) => result,
+ None => return Ok(None),
+ };
+
let mut result: T = T::zero();
- let mut negative = false;
- let radix = T::from(10);
+
+ let radix = T::from(10_u8);
let stop_value = min_value / radix;
let mut parse_sign_and_digits = true;
- for (i, ch) in trimmed_str.char_indices() {
+ for &ch in &trimmed_bytes[idx..] {
if parse_sign_and_digits {
- if i == 0 {
- negative = ch == '-';
- let positive = ch == '+';
- if negative || positive {
- if i + 1 == len {
- // input string is just "+" or "-"
- return none_or_err(eval_mode, type_name, str);
- }
- // consume this char
- continue;
- }
+ if ch == b'.' {
+ // truncate decimal in legacy mode
+ parse_sign_and_digits = false;
+ continue;
}
- if ch == '.' {
- if eval_mode == EvalMode::Legacy {
- // truncate decimal in legacy mode
- parse_sign_and_digits = false;
- continue;
- } else {
- return none_or_err(eval_mode, type_name, str);
- }
+ if !ch.is_ascii_digit() {
+ return Ok(None);
}
- let digit = if ch.is_ascii_digit() {
- (ch as u32) - ('0' as u32)
- } else {
- return none_or_err(eval_mode, type_name, str);
- };
+ let digit: T = T::from(ch - b'0');
- // We are going to process the new digit and accumulate the result. However, before
- // doing this, if the result is already smaller than the
- // stopValue(Integer.MIN_VALUE / radix), then result * 10 will definitely be
- // smaller than minValue, and we can stop
if result < stop_value {
- return none_or_err(eval_mode, type_name, str);
+ return Ok(None);
}
-
- // Since the previous result is greater than or equal to stopValue(Integer.MIN_VALUE /
- // radix), we can just use `result > 0` to check overflow. If result
- // overflows, we should stop
let v = result * radix;
- let digit = (digit as i32).into();
match v.checked_sub(&digit) {
Some(x) if x <= T::zero() => result = x,
_ => {
- return none_or_err(eval_mode, type_name, str);
+ return Ok(None);
}
}
} else {
- // make sure fractional digits are valid digits but ignore them
+ // in legacy mode we still process chars after the dot and make sure the chars are digits
if !ch.is_ascii_digit() {
- return none_or_err(eval_mode, type_name, str);
+ return Ok(None);
+ }
+ }
+ }
+
+ if !negative {
+ if let Some(neg) = result.checked_neg() {
+ if neg < T::zero() {
+ return Ok(None);
+ }
+ result = neg;
+ } else {
+ return Ok(None);
+ }
+ }
+
+ Ok(Some(result))
+}
+
+fn do_parse_string_to_int_ansi<T: Integer + CheckedSub + CheckedNeg + From<u8> + Copy>(
+ str: &str,
+ type_name: &str,
+ min_value: T,
+) -> SparkResult<Option<T>> {
+ let bytes = str.as_bytes();
+ let (start, end) = trim_whitespace(bytes);
+
+ if start == end {
+ return Err(invalid_value(str, "STRING", type_name));
+ }
+ let trimmed_bytes = &bytes[start..end];
+
+ let (negative, idx) = match parse_sign(trimmed_bytes) {
+ Some(result) => result,
+ None => return Err(invalid_value(str, "STRING", type_name)),
+ };
+
+ let mut result: T = T::zero();
+
+ let radix = T::from(10_u8);
+ let stop_value = min_value / radix;
+
+ for &ch in &trimmed_bytes[idx..] {
+ if ch == b'.' {
+ return Err(invalid_value(str, "STRING", type_name));
+ }
+
+ if !ch.is_ascii_digit() {
+ return Err(invalid_value(str, "STRING", type_name));
+ }
+
+ let digit: T = T::from(ch - b'0');
+
+ if result < stop_value {
+ return Err(invalid_value(str, "STRING", type_name));
+ }
+ let v = result * radix;
+ match v.checked_sub(&digit) {
+ Some(x) if x <= T::zero() => result = x,
+ _ => {
+ return Err(invalid_value(str, "STRING", type_name));
+ }
+ }
+ }
+
+ if !negative {
+ if let Some(neg) = result.checked_neg() {
+ if neg < T::zero() {
+ return Err(invalid_value(str, "STRING", type_name));
+ }
+ result = neg;
+ } else {
+ return Err(invalid_value(str, "STRING", type_name));
+ }
+ }
+
+ Ok(Some(result))
+}
+
+fn do_parse_string_to_int_try<T: Integer + CheckedSub + CheckedNeg + From<u8> + Copy>(
+ str: &str,
+ min_value: T,
+) -> SparkResult<Option<T>> {
+ let bytes = str.as_bytes();
+ let (start, end) = trim_whitespace(bytes);
+
+ if start == end {
+ return Ok(None);
+ }
+ let trimmed_bytes = &bytes[start..end];
+
+ let (negative, idx) = match parse_sign(trimmed_bytes) {
+ Some(result) => result,
+ None => return Ok(None),
+ };
+
+ let mut result: T = T::zero();
+
+ let radix = T::from(10_u8);
+ let stop_value = min_value / radix;
+
+ // we don't have to go beyond decimal point in try eval mode - early return NULL
+ for &ch in &trimmed_bytes[idx..] {
+ if ch == b'.' {
+ return Ok(None);
+ }
+
+ if !ch.is_ascii_digit() {
+ return Ok(None);
+ }
+
+ let digit: T = T::from(ch - b'0');
+
+ if result < stop_value {
+ return Ok(None);
+ }
+ let v = result * radix;
+ match v.checked_sub(&digit) {
+ Some(x) if x <= T::zero() => result = x,
+ _ => {
+ return Ok(None);
}
}
}
@@ -2082,17 +2226,30 @@ fn do_cast_string_to_int<
if !negative {
if let Some(neg) = result.checked_neg() {
if neg < T::zero() {
- return none_or_err(eval_mode, type_name, str);
+ return Ok(None);
}
result = neg;
} else {
- return none_or_err(eval_mode, type_name, str);
+ return Ok(None);
}
}
Ok(Some(result))
}
+fn do_cast_string_to_int<T: Integer + CheckedSub + CheckedNeg + From<u8> + Copy>(
+ str: &str,
+ eval_mode: EvalMode,
+ type_name: &str,
+ min_value: T,
+) -> SparkResult<Option<T>> {
+ match eval_mode {
+ EvalMode::Legacy => do_parse_string_to_int_legacy(str, min_value),
+ EvalMode::Ansi => do_parse_string_to_int_ansi(str, type_name, min_value),
+ EvalMode::Try => do_parse_string_to_int_try(str, min_value),
+ }
+}
+
fn cast_string_to_decimal(
array: &ArrayRef,
to_type: &DataType,
@@ -2393,15 +2550,6 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
Ok((final_mantissa, final_scale))
}
-/// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode
-#[inline]
-fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult<Option<T>> {
- match eval_mode {
- EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
- _ => Ok(None),
- }
-}
-
#[inline]
fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError {
SparkError::CastInvalidValue {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]