This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 064cb472 feat: Implement Spark-compatible CAST from string to
timestamp types (#335)
064cb472 is described below
commit 064cb472a7f55b6b67f44c22d6a8110320b330e7
Author: Vipul Vaibhaw <[email protected]>
AuthorDate: Thu May 2 21:44:00 2024 +0530
feat: Implement Spark-compatible CAST from string to timestamp types (#335)
* casting str to timestamp
* fix format
* fixing failed tests, using char as pattern
* bug fixes
* handling microsecond
* make format
* bug fixes and core refactor
* format code
* removing print statements
* clippy error
* enabling cast timestamp test case
* code refactor
* comet spark test case
* adding all the supported format in test
* fallback spark when timestamp not utc
* bug fix
* bug fix
* adding an explainer commit
* fix test case
* bug fix
* bug fix
* better error handling for unwrap in fn parse_str_to_time_only_timestamp
* remove unwrap from macro
* improving error handling
* adding tests for invalid inputs
* removed all unwraps from timestamp cast functions
* code format
---
core/src/execution/datafusion/expressions/cast.rs | 316 ++++++++++++++++++++-
.../org/apache/comet/serde/QueryPlanSerde.scala | 12 +-
.../scala/org/apache/comet/CometCastSuite.scala | 71 ++++-
3 files changed, 391 insertions(+), 8 deletions(-)
diff --git a/core/src/execution/datafusion/expressions/cast.rs
b/core/src/execution/datafusion/expressions/cast.rs
index f5839fd4..7560e0c2 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -25,6 +25,7 @@ use std::{
use crate::errors::{CometError, CometResult};
use arrow::{
compute::{cast_with_options, CastOptions},
+ datatypes::TimestampMicrosecondType,
record_batch::RecordBatch,
util::display::FormatOptions,
};
@@ -33,10 +34,12 @@ use arrow_array::{
Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait,
PrimitiveArray,
};
use arrow_schema::{DataType, Schema};
+use chrono::{TimeZone, Timelike};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
use datafusion_physical_expr::PhysicalExpr;
use num::{traits::CheckedNeg, CheckedSub, Integer, Num};
+use regex::Regex;
use crate::execution::datafusion::expressions::utils::{
array_with_timezone, down_cast_any_ref, spark_cast,
@@ -86,6 +89,24 @@ macro_rules! cast_utf8_to_int {
}};
}
+macro_rules! cast_utf8_to_timestamp {
+ ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
+ let len = $array.len();
+ let mut cast_array =
PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC");
+ for i in 0..len {
+ if $array.is_null(i) {
+ cast_array.append_null()
+ } else if let Ok(Some(cast_value)) =
$cast_method($array.value(i).trim(), $eval_mode) {
+ cast_array.append_value(cast_value);
+ } else {
+ cast_array.append_null()
+ }
+ }
+ let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef;
+ result
+ }};
+}
+
impl Cast {
pub fn new(
child: Arc<dyn PhysicalExpr>,
@@ -125,6 +146,9 @@ impl Cast {
(DataType::LargeUtf8, DataType::Boolean) => {
Self::spark_cast_utf8_to_boolean::<i64>(&array,
self.eval_mode)?
}
+ (DataType::Utf8, DataType::Timestamp(_, _)) => {
+ Self::cast_string_to_timestamp(&array, to_type,
self.eval_mode)?
+ }
(
DataType::Utf8,
DataType::Int8 | DataType::Int16 | DataType::Int32 |
DataType::Int64,
@@ -200,6 +224,30 @@ impl Cast {
Ok(cast_array)
}
+ fn cast_string_to_timestamp(
+ array: &ArrayRef,
+ to_type: &DataType,
+ eval_mode: EvalMode,
+ ) -> CometResult<ArrayRef> {
+ let string_array = array
+ .as_any()
+ .downcast_ref::<GenericStringArray<i32>>()
+ .expect("Expected a string array");
+
+ let cast_array: ArrayRef = match to_type {
+ DataType::Timestamp(_, _) => {
+ cast_utf8_to_timestamp!(
+ string_array,
+ eval_mode,
+ TimestampMicrosecondType,
+ timestamp_parser
+ )
+ }
+ _ => unreachable!("Invalid data type {:?} in cast from string",
to_type),
+ };
+ Ok(cast_array)
+ }
+
fn spark_cast_utf8_to_boolean<OffsetSize>(
from: &dyn Array,
eval_mode: EvalMode,
@@ -510,9 +558,273 @@ impl PhysicalExpr for Cast {
}
}
+fn timestamp_parser(value: &str, eval_mode: EvalMode) ->
CometResult<Option<i64>> {
+ let value = value.trim();
+ if value.is_empty() {
+ return Ok(None);
+ }
+ // Define regex patterns and corresponding parsing functions
+ let patterns = &[
+ (
+ Regex::new(r"^\d{4}$").unwrap(),
+ parse_str_to_year_timestamp as fn(&str) ->
CometResult<Option<i64>>,
+ ),
+ (
+ Regex::new(r"^\d{4}-\d{2}$").unwrap(),
+ parse_str_to_month_timestamp,
+ ),
+ (
+ Regex::new(r"^\d{4}-\d{2}-\d{2}$").unwrap(),
+ parse_str_to_day_timestamp,
+ ),
+ (
+ Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
+ parse_str_to_hour_timestamp,
+ ),
+ (
+ Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
+ parse_str_to_minute_timestamp,
+ ),
+ (
+ Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
+ parse_str_to_second_timestamp,
+ ),
+ (
+
Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
+ parse_str_to_microsecond_timestamp,
+ ),
+ (
+ Regex::new(r"^T\d{1,2}$").unwrap(),
+ parse_str_to_time_only_timestamp,
+ ),
+ ];
+
+ let mut timestamp = None;
+
+ // Iterate through patterns and try matching
+ for (pattern, parse_func) in patterns {
+ if pattern.is_match(value) {
+ timestamp = parse_func(value)?;
+ break;
+ }
+ }
+
+ if timestamp.is_none() {
+ if eval_mode == EvalMode::Ansi {
+ return Err(CometError::CastInvalidValue {
+ value: value.to_string(),
+ from_type: "STRING".to_string(),
+ to_type: "TIMESTAMP".to_string(),
+ });
+ } else {
+ return Ok(None);
+ }
+ }
+
+ match timestamp {
+ Some(ts) => Ok(Some(ts)),
+ None => Err(CometError::Internal(
+ "Failed to parse timestamp".to_string(),
+ )),
+ }
+}
+
+fn parse_ymd_timestamp(year: i32, month: u32, day: u32) ->
CometResult<Option<i64>> {
+ let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, 0, 0, 0);
+
+ // Check if datetime is not None
+ let utc_datetime = match datetime.single() {
+ Some(dt) => dt.with_timezone(&chrono::Utc),
+ None => {
+ return Err(CometError::Internal(
+ "Failed to parse timestamp".to_string(),
+ ));
+ }
+ };
+
+ Ok(Some(utc_datetime.timestamp_micros()))
+}
+
+fn parse_hms_timestamp(
+ year: i32,
+ month: u32,
+ day: u32,
+ hour: u32,
+ minute: u32,
+ second: u32,
+ microsecond: u32,
+) -> CometResult<Option<i64>> {
+ let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, hour,
minute, second);
+
+ // Check if datetime is not None
+ let utc_datetime = match datetime.single() {
+ Some(dt) => dt
+ .with_timezone(&chrono::Utc)
+ .with_nanosecond(microsecond * 1000),
+ None => {
+ return Err(CometError::Internal(
+ "Failed to parse timestamp".to_string(),
+ ));
+ }
+ };
+
+ let result = match utc_datetime {
+ Some(dt) => dt.timestamp_micros(),
+ None => {
+ return Err(CometError::Internal(
+ "Failed to parse timestamp".to_string(),
+ ));
+ }
+ };
+
+ Ok(Some(result))
+}
+
+fn get_timestamp_values(value: &str, timestamp_type: &str) ->
CometResult<Option<i64>> {
+ let values: Vec<_> = value
+ .split(|c| c == 'T' || c == '-' || c == ':' || c == '.')
+ .collect();
+ let year = values[0].parse::<i32>().unwrap_or_default();
+ let month = values.get(1).map_or(1, |m| m.parse::<u32>().unwrap_or(1));
+ let day = values.get(2).map_or(1, |d| d.parse::<u32>().unwrap_or(1));
+ let hour = values.get(3).map_or(0, |h| h.parse::<u32>().unwrap_or(0));
+ let minute = values.get(4).map_or(0, |m| m.parse::<u32>().unwrap_or(0));
+ let second = values.get(5).map_or(0, |s| s.parse::<u32>().unwrap_or(0));
+ let microsecond = values.get(6).map_or(0, |ms|
ms.parse::<u32>().unwrap_or(0));
+
+ match timestamp_type {
+ "year" => parse_ymd_timestamp(year, 1, 1),
+ "month" => parse_ymd_timestamp(year, month, 1),
+ "day" => parse_ymd_timestamp(year, month, day),
+ "hour" => parse_hms_timestamp(year, month, day, hour, 0, 0, 0),
+ "minute" => parse_hms_timestamp(year, month, day, hour, minute, 0, 0),
+ "second" => parse_hms_timestamp(year, month, day, hour, minute,
second, 0),
+ "microsecond" => parse_hms_timestamp(year, month, day, hour, minute,
second, microsecond),
+ _ => Err(CometError::CastInvalidValue {
+ value: value.to_string(),
+ from_type: "STRING".to_string(),
+ to_type: "TIMESTAMP".to_string(),
+ }),
+ }
+}
+
+fn parse_str_to_year_timestamp(value: &str) -> CometResult<Option<i64>> {
+ get_timestamp_values(value, "year")
+}
+
+fn parse_str_to_month_timestamp(value: &str) -> CometResult<Option<i64>> {
+ get_timestamp_values(value, "month")
+}
+
+fn parse_str_to_day_timestamp(value: &str) -> CometResult<Option<i64>> {
+ get_timestamp_values(value, "day")
+}
+
+fn parse_str_to_hour_timestamp(value: &str) -> CometResult<Option<i64>> {
+ get_timestamp_values(value, "hour")
+}
+
+fn parse_str_to_minute_timestamp(value: &str) -> CometResult<Option<i64>> {
+ get_timestamp_values(value, "minute")
+}
+
+fn parse_str_to_second_timestamp(value: &str) -> CometResult<Option<i64>> {
+ get_timestamp_values(value, "second")
+}
+
+fn parse_str_to_microsecond_timestamp(value: &str) -> CometResult<Option<i64>>
{
+ get_timestamp_values(value, "microsecond")
+}
+
+fn parse_str_to_time_only_timestamp(value: &str) -> CometResult<Option<i64>> {
+ let values: Vec<&str> = value.split('T').collect();
+ let time_values: Vec<u32> = values[1]
+ .split(':')
+ .map(|v| v.parse::<u32>().unwrap_or(0))
+ .collect();
+
+ let datetime = chrono::Utc::now();
+ let timestamp = datetime
+ .with_hour(time_values.first().copied().unwrap_or_default())
+ .and_then(|dt| dt.with_minute(*time_values.get(1).unwrap_or(&0)))
+ .and_then(|dt| dt.with_second(*time_values.get(2).unwrap_or(&0)))
+ .and_then(|dt| dt.with_nanosecond(*time_values.get(3).unwrap_or(&0) *
1_000))
+ .map(|dt| dt.to_utc().timestamp_micros())
+ .unwrap_or_default();
+
+ Ok(Some(timestamp))
+}
+
#[cfg(test)]
-mod test {
- use super::{cast_string_to_i8, EvalMode};
+mod tests {
+ use super::*;
+ use arrow::datatypes::TimestampMicrosecondType;
+ use arrow_array::StringArray;
+ use arrow_schema::TimeUnit;
+
+ #[test]
+ fn timestamp_parser_test() {
+ // write for all formats
+ assert_eq!(
+ timestamp_parser("2020", EvalMode::Legacy).unwrap(),
+ Some(1577836800000000) // this is in microseconds
+ );
+ assert_eq!(
+ timestamp_parser("2020-01", EvalMode::Legacy).unwrap(),
+ Some(1577836800000000)
+ );
+ assert_eq!(
+ timestamp_parser("2020-01-01", EvalMode::Legacy).unwrap(),
+ Some(1577836800000000)
+ );
+ assert_eq!(
+ timestamp_parser("2020-01-01T12", EvalMode::Legacy).unwrap(),
+ Some(1577880000000000)
+ );
+ assert_eq!(
+ timestamp_parser("2020-01-01T12:34", EvalMode::Legacy).unwrap(),
+ Some(1577882040000000)
+ );
+ assert_eq!(
+ timestamp_parser("2020-01-01T12:34:56", EvalMode::Legacy).unwrap(),
+ Some(1577882096000000)
+ );
+ assert_eq!(
+ timestamp_parser("2020-01-01T12:34:56.123456",
EvalMode::Legacy).unwrap(),
+ Some(1577882096123456)
+ );
+ // assert_eq!(
+ // timestamp_parser("T2", EvalMode::Legacy).unwrap(),
+ // Some(1714356000000000) // this value needs to change everyday.
+ // );
+ }
+
+ #[test]
+ fn test_cast_string_to_timestamp() {
+ let array: ArrayRef = Arc::new(StringArray::from(vec![
+ Some("2020-01-01T12:34:56.123456"),
+ Some("T2"),
+ ]));
+
+ let string_array = array
+ .as_any()
+ .downcast_ref::<GenericStringArray<i32>>()
+ .expect("Expected a string array");
+
+ let eval_mode = EvalMode::Legacy;
+ let result = cast_utf8_to_timestamp!(
+ &string_array,
+ eval_mode,
+ TimestampMicrosecondType,
+ timestamp_parser
+ );
+
+ assert_eq!(
+ result.data_type(),
+ &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()))
+ );
+ assert_eq!(result.len(), 2);
+ }
#[test]
fn test_cast_string_as_i8() {
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 6eda0547..c07b2b3c 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -585,6 +585,15 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde {
// Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY
evalMode.toString
}
+
+ val supportedTimezone = (child.dataType, dt) match {
+ case (DataTypes.StringType, DataTypes.TimestampType)
+ if !timeZoneId.contains("UTC") =>
+ withInfo(expr, s"Unsupported timezone ${timeZoneId} for
timestamp cast")
+ false
+ case _ => true
+ }
+
val supportedCast = (child.dataType, dt) match {
case (DataTypes.StringType, DataTypes.TimestampType)
if !CometConf.COMET_CAST_STRING_TO_TIMESTAMP.get() =>
@@ -593,7 +602,8 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde {
false
case _ => true
}
- if (supportedCast) {
+
+ if (supportedCast && supportedTimezone) {
castToProto(timeZoneId, dt, childExpr, evalModeStr)
} else {
// no need to call withInfo here since it was called when
determining
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 1bddedde..a31f4e68 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -528,14 +528,37 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
"spark.comet.cast.stringToTimestamp is disabled")
}
- ignore("cast StringType to TimestampType") {
- // https://github.com/apache/datafusion-comet/issues/328
- withSQLConf((CometConf.COMET_CAST_STRING_TO_TIMESTAMP.key, "true")) {
- val values = Seq("2020-01-01T12:34:56.123456", "T2") ++
generateStrings(timestampPattern, 8)
- castTest(values.toDF("a"), DataTypes.TimestampType)
+ test("cast StringType to TimestampType") {
+ withSQLConf(
+ SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC",
+ CometConf.COMET_CAST_STRING_TO_TIMESTAMP.key -> "true") {
+ val values = Seq(
+ "2020",
+ "2020-01",
+ "2020-01-01",
+ "2020-01-01T12",
+ "2020-01-01T12:34",
+ "2020-01-01T12:34:56",
+ "2020-01-01T12:34:56.123456",
+ "T2",
+ "-9?")
+ castTimestampTest(values.toDF("a"), DataTypes.TimestampType)
+ }
+
+ // test for invalid inputs
+ withSQLConf(
+ SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC",
+ CometConf.COMET_CAST_STRING_TO_TIMESTAMP.key -> "true") {
+ val values = Seq("-9?", "1-", "0.5")
+ castTimestampTest(values.toDF("a"), DataTypes.TimestampType)
}
}
+ test("cast StringType to TimestampType with invalid timezone") {
+ val values = Seq("2020-01-01T12:34:56.123456", "T2")
+ castFallbackTestTimezone(values.toDF("a"), DataTypes.TimestampType,
"Unsupported timezone")
+ }
+
// CAST from DateType
ignore("cast DateType to BooleanType") {
@@ -763,6 +786,44 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
}
}
+ private def castFallbackTestTimezone(
+ input: DataFrame,
+ toType: DataType,
+ expectedMessage: String): Unit = {
+ withTempPath { dir =>
+ val data = roundtripParquet(input, dir).coalesce(1)
+ data.createOrReplaceTempView("t")
+
+ withSQLConf(
+ (SQLConf.ANSI_ENABLED.key, "false"),
+ (CometConf.COMET_CAST_STRING_TO_TIMESTAMP.key, "true"),
+ (SQLConf.SESSION_LOCAL_TIMEZONE.key, "America/Los_Angeles")) {
+ val df = data.withColumn("converted", col("a").cast(toType))
+ df.collect()
+ val str =
+ new
ExtendedExplainInfo().generateExtendedInfo(df.queryExecution.executedPlan)
+ assert(str.contains(expectedMessage))
+ }
+ }
+ }
+
+ private def castTimestampTest(input: DataFrame, toType: DataType) = {
+ withTempPath { dir =>
+ val data = roundtripParquet(input, dir).coalesce(1)
+ data.createOrReplaceTempView("t")
+
+ withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
+ // cast() should return null for invalid inputs when ansi mode is
disabled
+ val df = data.withColumn("converted", col("a").cast(toType))
+ checkSparkAnswer(df)
+
+ // try_cast() should always return null for invalid inputs
+ val df2 = spark.sql(s"select try_cast(a as ${toType.sql}) from t")
+ checkSparkAnswer(df2)
+ }
+ }
+ }
+
private def castTest(input: DataFrame, toType: DataType): Unit = {
withTempPath { dir =>
val data = roundtripParquet(input, dir).coalesce(1)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]