This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new bcd8af8e7 Add support for `DataType::Timestamp` casts in 
`unwrap_cast_in_comparison` optimizer pass (#4148)
bcd8af8e7 is described below

commit bcd8af8e7cfdcafa948340a7de9a2101837eeaaf
Author: Andrew Lamb <[email protected]>
AuthorDate: Sat Nov 12 06:16:43 2022 -0500

    Add support for `DataType::Timestamp` casts in `unwrap_cast_in_comparison` 
optimizer pass (#4148)
    
    * Add support for timestamp casts in unwrap_cast_in_comparison optimizer pass
    
    * correct comment in test
    
    * Update datafusion/optimizer/src/unwrap_cast_in_comparison.rs
---
 .../optimizer/src/unwrap_cast_in_comparison.rs     | 156 ++++++++++++++++++++-
 datafusion/optimizer/tests/integration-test.rs     |   3 +-
 2 files changed, 156 insertions(+), 3 deletions(-)

diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs 
b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
index 7ac91ae3c..28b085684 100644
--- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
+++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
@@ -21,7 +21,7 @@
 use crate::utils::rewrite_preserving_name;
 use crate::{OptimizerConfig, OptimizerRule};
 use arrow::datatypes::{
-    DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
+    DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, 
MIN_DECIMAL_FOR_EACH_PRECISION,
 };
 use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, 
ScalarValue};
 use datafusion_expr::expr::{BinaryExpr, Cast};
@@ -288,6 +288,7 @@ fn is_support_data_type(data_type: &DataType) -> bool {
             | DataType::Int32
             | DataType::Int64
             | DataType::Decimal128(_, _)
+            | DataType::Timestamp(_, _)
     )
 }
 
@@ -306,6 +307,7 @@ fn try_cast_literal_to_type(
     }
     let mul = match target_type {
         DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 
=> 1_i128,
+        DataType::Timestamp(_, _) => 1_i128,
         DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
         other_type => {
             return Err(DataFusionError::Internal(format!(
@@ -319,6 +321,7 @@ fn try_cast_literal_to_type(
         DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
         DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
         DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
+        DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128),
         DataType::Decimal128(precision, _) => (
             // Different precision for decimal128 can store different range of 
value.
             // For example, the precision is 3, the max of value is `999` and 
the min
@@ -338,6 +341,10 @@ fn try_cast_literal_to_type(
         ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul),
         ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul),
         ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul),
+        ScalarValue::TimestampSecond(Some(v), _) => (*v as 
i128).checked_mul(mul),
+        ScalarValue::TimestampMillisecond(Some(v), _) => (*v as 
i128).checked_mul(mul),
+        ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as 
i128).checked_mul(mul),
+        ScalarValue::TimestampNanosecond(Some(v), _) => (*v as 
i128).checked_mul(mul),
         ScalarValue::Decimal128(Some(v), _, scale) => {
             let lit_scale_mul = 10_i128.pow(*scale as u32);
             if mul >= lit_scale_mul {
@@ -376,6 +383,18 @@ fn try_cast_literal_to_type(
                     DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
                     DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
                     DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
+                    DataType::Timestamp(TimeUnit::Second, tz) => {
+                        ScalarValue::TimestampSecond(Some(value as i64), 
tz.clone())
+                    }
+                    DataType::Timestamp(TimeUnit::Millisecond, tz) => {
+                        ScalarValue::TimestampMillisecond(Some(value as i64), 
tz.clone())
+                    }
+                    DataType::Timestamp(TimeUnit::Microsecond, tz) => {
+                        ScalarValue::TimestampMicrosecond(Some(value as i64), 
tz.clone())
+                    }
+                    DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
+                        ScalarValue::TimestampNanosecond(Some(value as i64), 
tz.clone())
+                    }
                     DataType::Decimal128(p, s) => {
                         ScalarValue::Decimal128(Some(value), *p, *s)
                     }
@@ -629,6 +648,18 @@ mod tests {
         assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
     }
 
+    #[test]
+    /// Unwrapping `try_cast(ts_nano_none AS Timestamp(Nanosecond, UTC)) < <UTC literal>` yields `ts_nano_none < <timezone-less literal>`: the types differ only in timezone, so the cast is removed
+    fn test_unwrap_cast_with_timestamp_nanos() {
+        let schema = expr_test_schema();
+        // try_cast(ts_nano_none AS Timestamp(Nanosecond, UTC)) < 
1666612093000000000::Timestamp(Nanosecond, Utc))
+        let expr_lt = try_cast(col("ts_nano_none"), timestamp_nano_utc_type())
+            .lt(lit_timestamp_nano_utc(1666612093000000000));
+        let expected =
+            
col("ts_nano_none").lt(lit_timestamp_nano_none(1666612093000000000));
+        assert_eq!(optimize_test(expr_lt, &schema), expected);
+    }
+
     fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
         let mut expr_rewriter = UnwrapCastExprRewriter {
             schema: schema.clone(),
@@ -646,6 +677,8 @@ mod tests {
                     DFField::new(None, "c4", DataType::Decimal128(38, 37), 
false),
                     DFField::new(None, "c5", DataType::Float32, false),
                     DFField::new(None, "c6", DataType::UInt32, false),
+                    DFField::new(None, "ts_nano_none", 
timestamp_nano_none_type(), false),
+                    DFField::new(None, "ts_nano_utf", 
timestamp_nano_utc_type(), false),
                 ],
                 HashMap::new(),
             )
@@ -669,13 +702,32 @@ mod tests {
         lit(ScalarValue::Decimal128(Some(value), precision, scale))
     }
 
+    fn lit_timestamp_nano_none(ts: i64) -> Expr {
+        lit(ScalarValue::TimestampNanosecond(Some(ts), None))
+    }
+
+    fn lit_timestamp_nano_utc(ts: i64) -> Expr {
+        let utc = Some("+0:00".to_string());
+        lit(ScalarValue::TimestampNanosecond(Some(ts), utc))
+    }
+
     fn null_decimal(precision: u8, scale: u8) -> Expr {
         lit(ScalarValue::Decimal128(None, precision, scale))
     }
+
+    fn timestamp_nano_none_type() -> DataType {
+        DataType::Timestamp(TimeUnit::Nanosecond, None)
+    }
+
+    // Nanosecond timestamp type with tz "+0:00"; presumably the type now() returns. NOTE(review): the integration-test plan prints UTC as "+00:00" -- confirm which spelling now() actually uses
+    fn timestamp_nano_utc_type() -> DataType {
+        let utc = Some("+0:00".to_string());
+        DataType::Timestamp(TimeUnit::Nanosecond, utc)
+    }
+
     #[test]
     fn test_try_cast_to_type_nulls() {
-        // test values that can be cast to/from all integer types
+        // test that nulls can be cast to/from all integer types
         let scalars = vec![
             ScalarValue::Int8(None),
             ScalarValue::Int16(None),
@@ -783,6 +835,106 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_try_cast_to_type_timestamps() {
+        for time_unit in [
+            TimeUnit::Second,
+            TimeUnit::Millisecond,
+            TimeUnit::Microsecond,
+            TimeUnit::Nanosecond,
+        ] {
+            let utc = Some("+0:00".to_string());
+            // The same raw value (12345) in `time_unit`, without a timezone and with a UTC timezone. NOTE(review): only same-unit casts are exercised; cross-unit casts (e.g. Second -> Nanosecond) are not covered -- confirm whether the value should be rescaled for them
+            let (lit_tz_none, lit_tz_utc) = match time_unit {
+                TimeUnit::Second => (
+                    ScalarValue::TimestampSecond(Some(12345), None),
+                    ScalarValue::TimestampSecond(Some(12345), utc),
+                ),
+
+                TimeUnit::Millisecond => (
+                    ScalarValue::TimestampMillisecond(Some(12345), None),
+                    ScalarValue::TimestampMillisecond(Some(12345), utc),
+                ),
+
+                TimeUnit::Microsecond => (
+                    ScalarValue::TimestampMicrosecond(Some(12345), None),
+                    ScalarValue::TimestampMicrosecond(Some(12345), utc),
+                ),
+
+                TimeUnit::Nanosecond => (
+                    ScalarValue::TimestampNanosecond(Some(12345), None),
+                    ScalarValue::TimestampNanosecond(Some(12345), utc),
+                ),
+            };
+
+            // ScalarValue equality ignores the timezone of timestamp values,
+            // so the None and UTC literals compare equal. NOTE(review): if expect_cast checks results with ==, the timezone of each cast result below is not actually verified
+            assert_eq!(lit_tz_none, lit_tz_utc);
+
+            // e.g. DataType::Timestamp(_, None)
+            let dt_tz_none = lit_tz_none.get_datatype();
+
+            // e.g. DataType::Timestamp(_, Some(utc))
+            let dt_tz_utc = lit_tz_utc.get_datatype();
+
+            // literal tz None -> target tz None
+            expect_cast(
+                lit_tz_none.clone(),
+                dt_tz_none.clone(),
+                ExpectedCast::Value(lit_tz_none.clone()),
+            );
+
+            // literal tz None -> target tz UTC
+            expect_cast(
+                lit_tz_none.clone(),
+                dt_tz_utc.clone(),
+                ExpectedCast::Value(lit_tz_utc.clone()),
+            );
+
+            // literal tz UTC -> target tz None
+            expect_cast(
+                lit_tz_utc.clone(),
+                dt_tz_none.clone(),
+                ExpectedCast::Value(lit_tz_none.clone()),
+            );
+
+            // literal tz UTC -> target tz UTC
+            expect_cast(
+                lit_tz_utc.clone(),
+                dt_tz_utc.clone(),
+                ExpectedCast::Value(lit_tz_utc.clone()),
+            );
+
+            // UTC timestamp -> Int64 keeps the raw value
+            expect_cast(
+                lit_tz_utc.clone(),
+                DataType::Int64,
+                ExpectedCast::Value(ScalarValue::Int64(Some(12345))),
+            );
+
+            // Int64 -> timestamp without a timezone
+            expect_cast(
+                ScalarValue::Int64(Some(12345)),
+                dt_tz_none.clone(),
+                ExpectedCast::Value(lit_tz_none.clone()),
+            );
+
+            // Int64 -> timestamp with a UTC timezone
+            expect_cast(
+                ScalarValue::Int64(Some(12345)),
+                dt_tz_utc.clone(),
+                ExpectedCast::Value(lit_tz_utc.clone()),
+            );
+
+            // timestamp -> LargeUtf8 is not supported yet, so no value is produced
+            expect_cast(
+                lit_tz_utc.clone(),
+                DataType::LargeUtf8,
+                ExpectedCast::NoValue,
+            );
+        }
+    }
+
     #[test]
     fn test_try_cast_to_type_unsupported() {
         // int64 to list
diff --git a/datafusion/optimizer/tests/integration-test.rs 
b/datafusion/optimizer/tests/integration-test.rs
index be62ba2a5..48cd831bd 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -236,7 +236,8 @@ fn timestamp_nano_ts_none_predicates() -> Result<()> {
     // constant and compared to the column without a cast so it can be
     // pushed down / pruned
     let expected =
-        "Projection: test.col_int32\n  Filter: CAST(test.col_ts_nano_none AS 
Timestamp(Nanosecond, Some(\"+00:00\"))) < 
TimestampNanosecond(1666612093000000000, Some(\"+00:00\"))\
+        "Projection: test.col_int32\
+         \n  Filter: test.col_ts_nano_none < 
TimestampNanosecond(1666612093000000000, None)\
          \n    TableScan: test projection=[col_int32, col_ts_nano_none]";
     assert_eq!(expected, format!("{:?}", plan));
     Ok(())

Reply via email to