This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 064cb472 feat: Implement Spark-compatible CAST from string to 
timestamp types (#335)
064cb472 is described below

commit 064cb472a7f55b6b67f44c22d6a8110320b330e7
Author: Vipul Vaibhaw <[email protected]>
AuthorDate: Thu May 2 21:44:00 2024 +0530

    feat: Implement Spark-compatible CAST from string to timestamp types (#335)
    
    * casting str to timestamp
    
    * fix format
    
    * fixing failed tests, using char as pattern
    
    * bug fixes
    
    * handling microsecond
    
    * make format
    
    * bug fixes and core refactor
    
    * format code
    
    * removing print statements
    
    * clippy error
    
    * enabling cast timestamp test case
    
    * code refactor
    
    * comet spark test case
    
    * adding all the supported format in test
    
    * fallback spark when timestamp not utc
    
    * bug fix
    
    * bug fix
    
    * adding an explainer commit
    
    * fix test case
    
    * bug fix
    
    * bug fix
    
    * better error handling for unwrap in fn parse_str_to_time_only_timestamp
    
    * remove unwrap from macro
    
    * improving error handling
    
    * adding tests for invalid inputs
    
    * removed all unwraps from timestamp cast functions
    
    * code format
---
 core/src/execution/datafusion/expressions/cast.rs  | 316 ++++++++++++++++++++-
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  12 +-
 .../scala/org/apache/comet/CometCastSuite.scala    |  71 ++++-
 3 files changed, 391 insertions(+), 8 deletions(-)

diff --git a/core/src/execution/datafusion/expressions/cast.rs 
b/core/src/execution/datafusion/expressions/cast.rs
index f5839fd4..7560e0c2 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -25,6 +25,7 @@ use std::{
 use crate::errors::{CometError, CometResult};
 use arrow::{
     compute::{cast_with_options, CastOptions},
+    datatypes::TimestampMicrosecondType,
     record_batch::RecordBatch,
     util::display::FormatOptions,
 };
@@ -33,10 +34,12 @@ use arrow_array::{
     Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait, 
PrimitiveArray,
 };
 use arrow_schema::{DataType, Schema};
+use chrono::{TimeZone, Timelike};
 use datafusion::logical_expr::ColumnarValue;
 use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
 use datafusion_physical_expr::PhysicalExpr;
 use num::{traits::CheckedNeg, CheckedSub, Integer, Num};
+use regex::Regex;
 
 use crate::execution::datafusion::expressions::utils::{
     array_with_timezone, down_cast_any_ref, spark_cast,
@@ -86,6 +89,24 @@ macro_rules! cast_utf8_to_int {
     }};
 }
 
+macro_rules! cast_utf8_to_timestamp {
+    ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
+        let len = $array.len();
+        let mut cast_array = 
PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC");
+        for i in 0..len {
+            if $array.is_null(i) {
+                cast_array.append_null()
+            } else if let Ok(Some(cast_value)) = 
$cast_method($array.value(i).trim(), $eval_mode) {
+                cast_array.append_value(cast_value);
+            } else {
+                cast_array.append_null()
+            }
+        }
+        let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef;
+        result
+    }};
+}
+
 impl Cast {
     pub fn new(
         child: Arc<dyn PhysicalExpr>,
@@ -125,6 +146,9 @@ impl Cast {
             (DataType::LargeUtf8, DataType::Boolean) => {
                 Self::spark_cast_utf8_to_boolean::<i64>(&array, 
self.eval_mode)?
             }
+            (DataType::Utf8, DataType::Timestamp(_, _)) => {
+                Self::cast_string_to_timestamp(&array, to_type, 
self.eval_mode)?
+            }
             (
                 DataType::Utf8,
                 DataType::Int8 | DataType::Int16 | DataType::Int32 | 
DataType::Int64,
@@ -200,6 +224,30 @@ impl Cast {
         Ok(cast_array)
     }
 
+    fn cast_string_to_timestamp(
+        array: &ArrayRef,
+        to_type: &DataType,
+        eval_mode: EvalMode,
+    ) -> CometResult<ArrayRef> {
+        let string_array = array
+            .as_any()
+            .downcast_ref::<GenericStringArray<i32>>()
+            .expect("Expected a string array");
+
+        let cast_array: ArrayRef = match to_type {
+            DataType::Timestamp(_, _) => {
+                cast_utf8_to_timestamp!(
+                    string_array,
+                    eval_mode,
+                    TimestampMicrosecondType,
+                    timestamp_parser
+                )
+            }
+            _ => unreachable!("Invalid data type {:?} in cast from string", 
to_type),
+        };
+        Ok(cast_array)
+    }
+
     fn spark_cast_utf8_to_boolean<OffsetSize>(
         from: &dyn Array,
         eval_mode: EvalMode,
@@ -510,9 +558,273 @@ impl PhysicalExpr for Cast {
     }
 }
 
+fn timestamp_parser(value: &str, eval_mode: EvalMode) -> 
CometResult<Option<i64>> {
+    let value = value.trim();
+    if value.is_empty() {
+        return Ok(None);
+    }
+    // Define regex patterns and corresponding parsing functions
+    let patterns = &[
+        (
+            Regex::new(r"^\d{4}$").unwrap(),
+            parse_str_to_year_timestamp as fn(&str) -> 
CometResult<Option<i64>>,
+        ),
+        (
+            Regex::new(r"^\d{4}-\d{2}$").unwrap(),
+            parse_str_to_month_timestamp,
+        ),
+        (
+            Regex::new(r"^\d{4}-\d{2}-\d{2}$").unwrap(),
+            parse_str_to_day_timestamp,
+        ),
+        (
+            Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
+            parse_str_to_hour_timestamp,
+        ),
+        (
+            Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
+            parse_str_to_minute_timestamp,
+        ),
+        (
+            Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
+            parse_str_to_second_timestamp,
+        ),
+        (
+            
Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
+            parse_str_to_microsecond_timestamp,
+        ),
+        (
+            Regex::new(r"^T\d{1,2}$").unwrap(),
+            parse_str_to_time_only_timestamp,
+        ),
+    ];
+
+    let mut timestamp = None;
+
+    // Iterate through patterns and try matching
+    for (pattern, parse_func) in patterns {
+        if pattern.is_match(value) {
+            timestamp = parse_func(value)?;
+            break;
+        }
+    }
+
+    if timestamp.is_none() {
+        if eval_mode == EvalMode::Ansi {
+            return Err(CometError::CastInvalidValue {
+                value: value.to_string(),
+                from_type: "STRING".to_string(),
+                to_type: "TIMESTAMP".to_string(),
+            });
+        } else {
+            return Ok(None);
+        }
+    }
+
+    match timestamp {
+        Some(ts) => Ok(Some(ts)),
+        None => Err(CometError::Internal(
+            "Failed to parse timestamp".to_string(),
+        )),
+    }
+}
+
+fn parse_ymd_timestamp(year: i32, month: u32, day: u32) -> 
CometResult<Option<i64>> {
+    let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, 0, 0, 0);
+
+    // Check if datetime is not None
+    let utc_datetime = match datetime.single() {
+        Some(dt) => dt.with_timezone(&chrono::Utc),
+        None => {
+            return Err(CometError::Internal(
+                "Failed to parse timestamp".to_string(),
+            ));
+        }
+    };
+
+    Ok(Some(utc_datetime.timestamp_micros()))
+}
+
+fn parse_hms_timestamp(
+    year: i32,
+    month: u32,
+    day: u32,
+    hour: u32,
+    minute: u32,
+    second: u32,
+    microsecond: u32,
+) -> CometResult<Option<i64>> {
+    let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, hour, 
minute, second);
+
+    // Check if datetime is not None
+    let utc_datetime = match datetime.single() {
+        Some(dt) => dt
+            .with_timezone(&chrono::Utc)
+            .with_nanosecond(microsecond * 1000),
+        None => {
+            return Err(CometError::Internal(
+                "Failed to parse timestamp".to_string(),
+            ));
+        }
+    };
+
+    let result = match utc_datetime {
+        Some(dt) => dt.timestamp_micros(),
+        None => {
+            return Err(CometError::Internal(
+                "Failed to parse timestamp".to_string(),
+            ));
+        }
+    };
+
+    Ok(Some(result))
+}
+
+fn get_timestamp_values(value: &str, timestamp_type: &str) -> 
CometResult<Option<i64>> {
+    let values: Vec<_> = value
+        .split(|c| c == 'T' || c == '-' || c == ':' || c == '.')
+        .collect();
+    let year = values[0].parse::<i32>().unwrap_or_default();
+    let month = values.get(1).map_or(1, |m| m.parse::<u32>().unwrap_or(1));
+    let day = values.get(2).map_or(1, |d| d.parse::<u32>().unwrap_or(1));
+    let hour = values.get(3).map_or(0, |h| h.parse::<u32>().unwrap_or(0));
+    let minute = values.get(4).map_or(0, |m| m.parse::<u32>().unwrap_or(0));
+    let second = values.get(5).map_or(0, |s| s.parse::<u32>().unwrap_or(0));
+    let microsecond = values.get(6).map_or(0, |ms| 
ms.parse::<u32>().unwrap_or(0));
+
+    match timestamp_type {
+        "year" => parse_ymd_timestamp(year, 1, 1),
+        "month" => parse_ymd_timestamp(year, month, 1),
+        "day" => parse_ymd_timestamp(year, month, day),
+        "hour" => parse_hms_timestamp(year, month, day, hour, 0, 0, 0),
+        "minute" => parse_hms_timestamp(year, month, day, hour, minute, 0, 0),
+        "second" => parse_hms_timestamp(year, month, day, hour, minute, 
second, 0),
+        "microsecond" => parse_hms_timestamp(year, month, day, hour, minute, 
second, microsecond),
+        _ => Err(CometError::CastInvalidValue {
+            value: value.to_string(),
+            from_type: "STRING".to_string(),
+            to_type: "TIMESTAMP".to_string(),
+        }),
+    }
+}
+
+fn parse_str_to_year_timestamp(value: &str) -> CometResult<Option<i64>> {
+    get_timestamp_values(value, "year")
+}
+
+fn parse_str_to_month_timestamp(value: &str) -> CometResult<Option<i64>> {
+    get_timestamp_values(value, "month")
+}
+
+fn parse_str_to_day_timestamp(value: &str) -> CometResult<Option<i64>> {
+    get_timestamp_values(value, "day")
+}
+
+fn parse_str_to_hour_timestamp(value: &str) -> CometResult<Option<i64>> {
+    get_timestamp_values(value, "hour")
+}
+
+fn parse_str_to_minute_timestamp(value: &str) -> CometResult<Option<i64>> {
+    get_timestamp_values(value, "minute")
+}
+
+fn parse_str_to_second_timestamp(value: &str) -> CometResult<Option<i64>> {
+    get_timestamp_values(value, "second")
+}
+
+fn parse_str_to_microsecond_timestamp(value: &str) -> CometResult<Option<i64>> 
{
+    get_timestamp_values(value, "microsecond")
+}
+
+fn parse_str_to_time_only_timestamp(value: &str) -> CometResult<Option<i64>> {
+    let values: Vec<&str> = value.split('T').collect();
+    let time_values: Vec<u32> = values[1]
+        .split(':')
+        .map(|v| v.parse::<u32>().unwrap_or(0))
+        .collect();
+
+    let datetime = chrono::Utc::now();
+    let timestamp = datetime
+        .with_hour(time_values.first().copied().unwrap_or_default())
+        .and_then(|dt| dt.with_minute(*time_values.get(1).unwrap_or(&0)))
+        .and_then(|dt| dt.with_second(*time_values.get(2).unwrap_or(&0)))
+        .and_then(|dt| dt.with_nanosecond(*time_values.get(3).unwrap_or(&0) * 
1_000))
+        .map(|dt| dt.to_utc().timestamp_micros())
+        .unwrap_or_default();
+
+    Ok(Some(timestamp))
+}
+
 #[cfg(test)]
-mod test {
-    use super::{cast_string_to_i8, EvalMode};
+mod tests {
+    use super::*;
+    use arrow::datatypes::TimestampMicrosecondType;
+    use arrow_array::StringArray;
+    use arrow_schema::TimeUnit;
+
+    #[test]
+    fn timestamp_parser_test() {
+        // write for all formats
+        assert_eq!(
+            timestamp_parser("2020", EvalMode::Legacy).unwrap(),
+            Some(1577836800000000) // this is in microseconds
+        );
+        assert_eq!(
+            timestamp_parser("2020-01", EvalMode::Legacy).unwrap(),
+            Some(1577836800000000)
+        );
+        assert_eq!(
+            timestamp_parser("2020-01-01", EvalMode::Legacy).unwrap(),
+            Some(1577836800000000)
+        );
+        assert_eq!(
+            timestamp_parser("2020-01-01T12", EvalMode::Legacy).unwrap(),
+            Some(1577880000000000)
+        );
+        assert_eq!(
+            timestamp_parser("2020-01-01T12:34", EvalMode::Legacy).unwrap(),
+            Some(1577882040000000)
+        );
+        assert_eq!(
+            timestamp_parser("2020-01-01T12:34:56", EvalMode::Legacy).unwrap(),
+            Some(1577882096000000)
+        );
+        assert_eq!(
+            timestamp_parser("2020-01-01T12:34:56.123456", 
EvalMode::Legacy).unwrap(),
+            Some(1577882096123456)
+        );
+        // assert_eq!(
+        //     timestamp_parser("T2",  EvalMode::Legacy).unwrap(),
+        //     Some(1714356000000000) // this value needs to change everyday.
+        // );
+    }
+
+    #[test]
+    fn test_cast_string_to_timestamp() {
+        let array: ArrayRef = Arc::new(StringArray::from(vec![
+            Some("2020-01-01T12:34:56.123456"),
+            Some("T2"),
+        ]));
+
+        let string_array = array
+            .as_any()
+            .downcast_ref::<GenericStringArray<i32>>()
+            .expect("Expected a string array");
+
+        let eval_mode = EvalMode::Legacy;
+        let result = cast_utf8_to_timestamp!(
+            &string_array,
+            eval_mode,
+            TimestampMicrosecondType,
+            timestamp_parser
+        );
+
+        assert_eq!(
+            result.data_type(),
+            &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()))
+        );
+        assert_eq!(result.len(), 2);
+    }
 
     #[test]
     fn test_cast_string_as_i8() {
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 6eda0547..c07b2b3c 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -585,6 +585,15 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde {
               // Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY
               evalMode.toString
             }
+
+            val supportedTimezone = (child.dataType, dt) match {
+              case (DataTypes.StringType, DataTypes.TimestampType)
+                  if !timeZoneId.contains("UTC") =>
+                withInfo(expr, s"Unsupported timezone ${timeZoneId} for 
timestamp cast")
+                false
+              case _ => true
+            }
+
             val supportedCast = (child.dataType, dt) match {
               case (DataTypes.StringType, DataTypes.TimestampType)
                   if !CometConf.COMET_CAST_STRING_TO_TIMESTAMP.get() =>
@@ -593,7 +602,8 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde {
                 false
               case _ => true
             }
-            if (supportedCast) {
+
+            if (supportedCast && supportedTimezone) {
               castToProto(timeZoneId, dt, childExpr, evalModeStr)
             } else {
               // no need to call withInfo here since it was called when 
determining
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 1bddedde..a31f4e68 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -528,14 +528,37 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
       "spark.comet.cast.stringToTimestamp is disabled")
   }
 
-  ignore("cast StringType to TimestampType") {
-    // https://github.com/apache/datafusion-comet/issues/328
-    withSQLConf((CometConf.COMET_CAST_STRING_TO_TIMESTAMP.key, "true")) {
-      val values = Seq("2020-01-01T12:34:56.123456", "T2") ++ 
generateStrings(timestampPattern, 8)
-      castTest(values.toDF("a"), DataTypes.TimestampType)
+  test("cast StringType to TimestampType") {
+    withSQLConf(
+      SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC",
+      CometConf.COMET_CAST_STRING_TO_TIMESTAMP.key -> "true") {
+      val values = Seq(
+        "2020",
+        "2020-01",
+        "2020-01-01",
+        "2020-01-01T12",
+        "2020-01-01T12:34",
+        "2020-01-01T12:34:56",
+        "2020-01-01T12:34:56.123456",
+        "T2",
+        "-9?")
+      castTimestampTest(values.toDF("a"), DataTypes.TimestampType)
+    }
+
+    // test for invalid inputs
+    withSQLConf(
+      SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC",
+      CometConf.COMET_CAST_STRING_TO_TIMESTAMP.key -> "true") {
+      val values = Seq("-9?", "1-", "0.5")
+      castTimestampTest(values.toDF("a"), DataTypes.TimestampType)
     }
   }
 
+  test("cast StringType to TimestampType with invalid timezone") {
+    val values = Seq("2020-01-01T12:34:56.123456", "T2")
+    castFallbackTestTimezone(values.toDF("a"), DataTypes.TimestampType, 
"Unsupported timezone")
+  }
+
   // CAST from DateType
 
   ignore("cast DateType to BooleanType") {
@@ -763,6 +786,44 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     }
   }
 
+  private def castFallbackTestTimezone(
+      input: DataFrame,
+      toType: DataType,
+      expectedMessage: String): Unit = {
+    withTempPath { dir =>
+      val data = roundtripParquet(input, dir).coalesce(1)
+      data.createOrReplaceTempView("t")
+
+      withSQLConf(
+        (SQLConf.ANSI_ENABLED.key, "false"),
+        (CometConf.COMET_CAST_STRING_TO_TIMESTAMP.key, "true"),
+        (SQLConf.SESSION_LOCAL_TIMEZONE.key, "America/Los_Angeles")) {
+        val df = data.withColumn("converted", col("a").cast(toType))
+        df.collect()
+        val str =
+          new 
ExtendedExplainInfo().generateExtendedInfo(df.queryExecution.executedPlan)
+        assert(str.contains(expectedMessage))
+      }
+    }
+  }
+
+  private def castTimestampTest(input: DataFrame, toType: DataType) = {
+    withTempPath { dir =>
+      val data = roundtripParquet(input, dir).coalesce(1)
+      data.createOrReplaceTempView("t")
+
+      withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
+        // cast() should return null for invalid inputs when ansi mode is 
disabled
+        val df = data.withColumn("converted", col("a").cast(toType))
+        checkSparkAnswer(df)
+
+        // try_cast() should always return null for invalid inputs
+        val df2 = spark.sql(s"select try_cast(a as ${toType.sql}) from t")
+        checkSparkAnswer(df2)
+      }
+    }
+  }
+
   private def castTest(input: DataFrame, toType: DataType): Unit = {
     withTempPath { dir =>
       val data = roundtripParquet(input, dir).coalesce(1)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to