parthchandra commented on code in PR #307:
URL: https://github.com/apache/datafusion-comet/pull/307#discussion_r1586606093


##########
core/src/execution/datafusion/expressions/cast.rs:
##########
@@ -142,6 +233,202 @@ impl Cast {
     }
 }
 
+/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte
+fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult<Option<i8>> {
+    Ok(cast_string_to_int_with_range_check(
+        str,
+        eval_mode,
+        "TINYINT",
+        i8::MIN as i32,
+        i8::MAX as i32,
+    )?
+    .map(|v| v as i8))
+}
+
+/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toShort
+fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult<Option<i16>> {
+    Ok(cast_string_to_int_with_range_check(
+        str,
+        eval_mode,
+        "SMALLINT",
+        i16::MIN as i32,
+        i16::MAX as i32,
+    )?
+    .map(|v| v as i16))
+}
+
+/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper)
+fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult<Option<i32>> {
+    do_cast_string_to_int::<i32>(str, eval_mode, "INT", i32::MIN)
+}
+
+/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper intWrapper)
+fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult<Option<i64>> {
+    do_cast_string_to_int::<i64>(str, eval_mode, "BIGINT", i64::MIN)
+}
+
+fn cast_string_to_int_with_range_check(
+    str: &str,
+    eval_mode: EvalMode,
+    type_name: &str,
+    min: i32,
+    max: i32,
+) -> CometResult<Option<i32>> {
+    match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? {
+        None => Ok(None),
+        Some(v) if v >= min && v <= max => Ok(Some(v)),
+        _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
+        _ => Ok(None),
+    }
+}
+
+#[derive(PartialEq)]
+enum State {
+    SkipLeadingWhiteSpace,
+    SkipTrailingWhiteSpace,
+    ParseSignAndDigits,
+    ParseFractionalDigits,
+}
+
+/// Equivalent to
+/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal)
+/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal)
+fn do_cast_string_to_int<
+    T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From<i32> + Copy,
+>(
+    str: &str,
+    eval_mode: EvalMode,
+    type_name: &str,
+    min_value: T,
+) -> CometResult<Option<T>> {
+    let len = str.len();
+    if str.is_empty() {
+        return none_or_err(eval_mode, type_name, str);
+    }
+
+    let mut result: T = T::zero();
+    let mut negative = false;
+    let radix = T::from(10);
+    let stop_value = min_value / radix;
+    let mut state = State::SkipLeadingWhiteSpace;
+    let mut parsed_sign = false;
+
+    for (i, ch) in str.char_indices() {
+        // skip leading whitespace
+        if state == State::SkipLeadingWhiteSpace {
+            if ch.is_whitespace() {
+                // consume this char
+                continue;
+            }
+            // change state and fall through to next section
+            state = State::ParseSignAndDigits;
+        }
+
+        if state == State::ParseSignAndDigits {
+            if !parsed_sign {
+                negative = ch == '-';
+                let positive = ch == '+';
+                parsed_sign = true;
+                if negative || positive {
+                    if i + 1 == len {
+                        // input string is just "+" or "-"
+                        return none_or_err(eval_mode, type_name, str);
+                    }
+                    // consume this char
+                    continue;
+                }
+            }
+
+            if ch == '.' {
+                if eval_mode == EvalMode::Legacy {
+                    // truncate decimal in legacy mode
+                    state = State::ParseFractionalDigits;
+                    continue;
+                } else {
+                    return none_or_err(eval_mode, type_name, str);
+                }
+            }
+
+            let digit = if ch.is_ascii_digit() {
+                (ch as u32) - ('0' as u32)
+            } else {
+                return none_or_err(eval_mode, type_name, str);
+            };
+
+            // We are going to process the new digit and accumulate the result. However, before

Review Comment:
   A comment to explain why we're using subtraction instead of addition would 
make it easier to understand this part of the code. 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to