This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 018ffbe143 preserve array type / timezone in `date_bin` and `date_trunc` functions (#7729)
018ffbe143 is described below

commit 018ffbe14383256868e7298b8f3fe4d3ab841639
Author: Martin Hilton <[email protected]>
AuthorDate: Wed Oct 4 15:09:53 2023 +0100

    preserve array type / timezone in `date_bin` and `date_trunc` functions (#7729)
    
    * preserve array type in date_bin and date_trunc functions
    
    Previously, the result type of date_bin and date_trunc never included any
    timezone information. Change this so that the timezone of the
    resulting array from these functions is copied from the input array.
    
    * Update datafusion/expr/src/built_in_function.rs
    
    Co-authored-by: Alex Huang <[email protected]>
    
    * fix: syntax error
    
    * fix: datafusion-cli cargo update
    
    * review suggestions
    
    Add some additional tests suggested in code reviews.
    
    * fix formatting
    
    ---------
    
    Co-authored-by: Alex Huang <[email protected]>
---
 datafusion/expr/src/built_in_function.rs           |  17 +-
 .../physical-expr/src/datetime_expressions.rs      | 299 ++++++++++++++++++++-
 datafusion/sqllogictest/test_files/timestamps.slt  |  56 ++++
 3 files changed, 355 insertions(+), 17 deletions(-)

diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs
index 58d84545db..70514f52d5 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -618,13 +618,20 @@ impl BuiltinScalarFunction {
             BuiltinScalarFunction::ConcatWithSeparator => Ok(Utf8),
             BuiltinScalarFunction::DatePart => Ok(Float64),
             BuiltinScalarFunction::DateBin | BuiltinScalarFunction::DateTrunc => {
-                match input_expr_types[1] {
-                    Timestamp(Nanosecond, _) | Utf8 | Null => {
+                match &input_expr_types[1] {
+                    Timestamp(Nanosecond, None) | Utf8 | Null => {
                         Ok(Timestamp(Nanosecond, None))
                     }
-                    Timestamp(Microsecond, _) => Ok(Timestamp(Microsecond, None)),
-                    Timestamp(Millisecond, _) => Ok(Timestamp(Millisecond, None)),
-                    Timestamp(Second, _) => Ok(Timestamp(Second, None)),
+                    Timestamp(Nanosecond, tz_opt) => {
+                        Ok(Timestamp(Nanosecond, tz_opt.clone()))
+                    }
+                    Timestamp(Microsecond, tz_opt) => {
+                        Ok(Timestamp(Microsecond, tz_opt.clone()))
+                    }
+                    Timestamp(Millisecond, tz_opt) => {
+                        Ok(Timestamp(Millisecond, tz_opt.clone()))
+                    }
+                    Timestamp(Second, tz_opt) => Ok(Timestamp(Second, tz_opt.clone())),
                     _ => plan_err!(
                     "The {self} function can only accept timestamp as the second arg."
                 ),
diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs
index 5ce71f4584..5cf1c21df5 100644
--- a/datafusion/physical-expr/src/datetime_expressions.rs
+++ b/datafusion/physical-expr/src/datetime_expressions.rs
@@ -433,7 +433,8 @@ pub fn date_trunc(args: &[ColumnarValue]) -> Result<ColumnarValue> {
                                 granularity.as_str(),
                             )
                         })
-                        .collect::<Result<TimestampSecondArray>>()?;
+                        .collect::<Result<TimestampSecondArray>>()?
+                        .with_timezone_opt(tz_opt.clone());
                     ColumnarValue::Array(Arc::new(array))
                 }
                 DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => {
@@ -449,7 +450,8 @@ pub fn date_trunc(args: &[ColumnarValue]) -> Result<ColumnarValue> {
                                 granularity.as_str(),
                             )
                         })
-                        .collect::<Result<TimestampMillisecondArray>>()?;
+                        .collect::<Result<TimestampMillisecondArray>>()?
+                        .with_timezone_opt(tz_opt.clone());
                     ColumnarValue::Array(Arc::new(array))
                 }
                 DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
@@ -465,7 +467,25 @@ pub fn date_trunc(args: &[ColumnarValue]) -> Result<ColumnarValue> {
                                 granularity.as_str(),
                             )
                         })
-                        .collect::<Result<TimestampMicrosecondArray>>()?;
+                        .collect::<Result<TimestampMicrosecondArray>>()?
+                        .with_timezone_opt(tz_opt.clone());
+                    ColumnarValue::Array(Arc::new(array))
+                }
+                DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => {
+                    let parsed_tz = parse_tz(tz_opt)?;
+                    let array = as_timestamp_nanosecond_array(array)?;
+                    let array = array
+                        .iter()
+                        .map(|x| {
+                            _date_trunc(
+                                TimeUnit::Nanosecond,
+                                &x,
+                                parsed_tz,
+                                granularity.as_str(),
+                            )
+                        })
+                        .collect::<Result<TimestampNanosecondArray>>()?
+                        .with_timezone_opt(tz_opt.clone());
                     ColumnarValue::Array(Arc::new(array))
                 }
                 _ => {
@@ -713,35 +733,39 @@ fn date_bin_impl(
             ))
         }
         ColumnarValue::Array(array) => match array.data_type() {
-            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+            DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => {
                 let array = as_timestamp_nanosecond_array(array)?
                     .iter()
                     .map(f_nanos)
-                    .collect::<TimestampNanosecondArray>();
+                    .collect::<TimestampNanosecondArray>()
+                    .with_timezone_opt(tz_opt.clone());
 
                 ColumnarValue::Array(Arc::new(array))
             }
-            DataType::Timestamp(TimeUnit::Microsecond, _) => {
+            DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
                 let array = as_timestamp_microsecond_array(array)?
                     .iter()
                     .map(f_micros)
-                    .collect::<TimestampMicrosecondArray>();
+                    .collect::<TimestampMicrosecondArray>()
+                    .with_timezone_opt(tz_opt.clone());
 
                 ColumnarValue::Array(Arc::new(array))
             }
-            DataType::Timestamp(TimeUnit::Millisecond, _) => {
+            DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => {
                 let array = as_timestamp_millisecond_array(array)?
                     .iter()
                     .map(f_millis)
-                    .collect::<TimestampMillisecondArray>();
+                    .collect::<TimestampMillisecondArray>()
+                    .with_timezone_opt(tz_opt.clone());
 
                 ColumnarValue::Array(Arc::new(array))
             }
-            DataType::Timestamp(TimeUnit::Second, _) => {
+            DataType::Timestamp(TimeUnit::Second, tz_opt) => {
                 let array = as_timestamp_second_array(array)?
                     .iter()
                     .map(f_secs)
-                    .collect::<TimestampSecondArray>();
+                    .collect::<TimestampSecondArray>()
+                    .with_timezone_opt(tz_opt.clone());
 
                 ColumnarValue::Array(Arc::new(array))
             }
@@ -925,7 +949,9 @@ where
 mod tests {
     use std::sync::Arc;
 
-    use arrow::array::{ArrayRef, Int64Array, IntervalDayTimeArray, StringBuilder};
+    use arrow::array::{
+        as_primitive_array, ArrayRef, Int64Array, IntervalDayTimeArray, StringBuilder,
+    };
 
     use super::*;
 
@@ -1051,6 +1077,125 @@ mod tests {
         });
     }
 
+    #[test]
+    fn test_date_trunc_timezones() {
+        let cases = vec![
+            (
+                vec![
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T01:00:00Z",
+                    "2020-09-08T02:00:00Z",
+                    "2020-09-08T03:00:00Z",
+                    "2020-09-08T04:00:00Z",
+                ],
+                Some("+00".into()),
+                vec![
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                ],
+            ),
+            (
+                vec![
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T01:00:00Z",
+                    "2020-09-08T02:00:00Z",
+                    "2020-09-08T03:00:00Z",
+                    "2020-09-08T04:00:00Z",
+                ],
+                None,
+                vec![
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                ],
+            ),
+            (
+                vec![
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T01:00:00Z",
+                    "2020-09-08T02:00:00Z",
+                    "2020-09-08T03:00:00Z",
+                    "2020-09-08T04:00:00Z",
+                ],
+                Some("-02".into()),
+                vec![
+                    "2020-09-07T02:00:00Z",
+                    "2020-09-07T02:00:00Z",
+                    "2020-09-08T02:00:00Z",
+                    "2020-09-08T02:00:00Z",
+                    "2020-09-08T02:00:00Z",
+                ],
+            ),
+            (
+                vec![
+                    "2020-09-08T00:00:00+05",
+                    "2020-09-08T01:00:00+05",
+                    "2020-09-08T02:00:00+05",
+                    "2020-09-08T03:00:00+05",
+                    "2020-09-08T04:00:00+05",
+                ],
+                Some("+05".into()),
+                vec![
+                    "2020-09-08T00:00:00+05",
+                    "2020-09-08T00:00:00+05",
+                    "2020-09-08T00:00:00+05",
+                    "2020-09-08T00:00:00+05",
+                    "2020-09-08T00:00:00+05",
+                ],
+            ),
+            (
+                vec![
+                    "2020-09-08T00:00:00+08",
+                    "2020-09-08T01:00:00+08",
+                    "2020-09-08T02:00:00+08",
+                    "2020-09-08T03:00:00+08",
+                    "2020-09-08T04:00:00+08",
+                ],
+                Some("+08".into()),
+                vec![
+                    "2020-09-08T00:00:00+08",
+                    "2020-09-08T00:00:00+08",
+                    "2020-09-08T00:00:00+08",
+                    "2020-09-08T00:00:00+08",
+                    "2020-09-08T00:00:00+08",
+                ],
+            ),
+        ];
+
+        cases.iter().for_each(|(original, tz_opt, expected)| {
+            let input = original
+                .iter()
+                .map(|s| Some(string_to_timestamp_nanos(s).unwrap()))
+                .collect::<TimestampNanosecondArray>()
+                .with_timezone_opt(tz_opt.clone());
+            let right = expected
+                .iter()
+                .map(|s| Some(string_to_timestamp_nanos(s).unwrap()))
+                .collect::<TimestampNanosecondArray>()
+                .with_timezone_opt(tz_opt.clone());
+            let result = date_trunc(&[
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some("day".to_string()))),
+                ColumnarValue::Array(Arc::new(input)),
+            ])
+            .unwrap();
+            if let ColumnarValue::Array(result) = result {
+                assert_eq!(
+                    result.data_type(),
+                    &DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone())
+                );
+                let left = as_primitive_array::<TimestampNanosecondType>(&result);
+                assert_eq!(left, &right);
+            } else {
+                panic!("unexpected column type");
+            }
+        });
+    }
+
     #[test]
     fn test_date_bin_single() {
         use chrono::Duration;
@@ -1252,6 +1397,136 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_date_bin_timezones() {
+        let cases = vec![
+            (
+                vec![
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T01:00:00Z",
+                    "2020-09-08T02:00:00Z",
+                    "2020-09-08T03:00:00Z",
+                    "2020-09-08T04:00:00Z",
+                ],
+                Some("+00".into()),
+                "1970-01-01T00:00:00Z",
+                vec![
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                ],
+            ),
+            (
+                vec![
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T01:00:00Z",
+                    "2020-09-08T02:00:00Z",
+                    "2020-09-08T03:00:00Z",
+                    "2020-09-08T04:00:00Z",
+                ],
+                None,
+                "1970-01-01T00:00:00Z",
+                vec![
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                ],
+            ),
+            (
+                vec![
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T01:00:00Z",
+                    "2020-09-08T02:00:00Z",
+                    "2020-09-08T03:00:00Z",
+                    "2020-09-08T04:00:00Z",
+                ],
+                Some("-02".into()),
+                "1970-01-01T00:00:00Z",
+                vec![
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                    "2020-09-08T00:00:00Z",
+                ],
+            ),
+            (
+                vec![
+                    "2020-09-08T00:00:00+05",
+                    "2020-09-08T01:00:00+05",
+                    "2020-09-08T02:00:00+05",
+                    "2020-09-08T03:00:00+05",
+                    "2020-09-08T04:00:00+05",
+                ],
+                Some("+05".into()),
+                "1970-01-01T00:00:00+05",
+                vec![
+                    "2020-09-08T00:00:00+05",
+                    "2020-09-08T00:00:00+05",
+                    "2020-09-08T00:00:00+05",
+                    "2020-09-08T00:00:00+05",
+                    "2020-09-08T00:00:00+05",
+                ],
+            ),
+            (
+                vec![
+                    "2020-09-08T00:00:00+08",
+                    "2020-09-08T01:00:00+08",
+                    "2020-09-08T02:00:00+08",
+                    "2020-09-08T03:00:00+08",
+                    "2020-09-08T04:00:00+08",
+                ],
+                Some("+08".into()),
+                "1970-01-01T00:00:00+08",
+                vec![
+                    "2020-09-08T00:00:00+08",
+                    "2020-09-08T00:00:00+08",
+                    "2020-09-08T00:00:00+08",
+                    "2020-09-08T00:00:00+08",
+                    "2020-09-08T00:00:00+08",
+                ],
+            ),
+        ];
+
+        cases
+            .iter()
+            .for_each(|(original, tz_opt, origin, expected)| {
+                let input = original
+                    .iter()
+                    .map(|s| Some(string_to_timestamp_nanos(s).unwrap()))
+                    .collect::<TimestampNanosecondArray>()
+                    .with_timezone_opt(tz_opt.clone());
+                let right = expected
+                    .iter()
+                    .map(|s| Some(string_to_timestamp_nanos(s).unwrap()))
+                    .collect::<TimestampNanosecondArray>()
+                    .with_timezone_opt(tz_opt.clone());
+                let result = date_bin(&[
+                    ColumnarValue::Scalar(ScalarValue::new_interval_dt(1, 0)),
+                    ColumnarValue::Array(Arc::new(input)),
+                    ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(
+                        Some(string_to_timestamp_nanos(origin).unwrap()),
+                        tz_opt.clone(),
+                    )),
+                ])
+                .unwrap();
+                if let ColumnarValue::Array(result) = result {
+                    assert_eq!(
+                        result.data_type(),
+                        &DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone())
+                    );
+                    let left = as_primitive_array::<TimestampNanosecondType>(&result);
+                    assert_eq!(left, &right);
+                } else {
+                    panic!("unexpected column type");
+                }
+            });
+    }
+
     #[test]
     fn to_timestamp_invalid_input_type() -> Result<()> {
         // pass the wrong type of input array to to_timestamp and test
diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt
index bb06c569f0..edafe18caa 100644
--- a/datafusion/sqllogictest/test_files/timestamps.slt
+++ b/datafusion/sqllogictest/test_files/timestamps.slt
@@ -1702,3 +1702,59 @@ SELECT TIMESTAMPTZ '2023-03-11 02:00:00 America/Los_Angeles' as ts_geo
 # postgresql: accepts
 statement error
 SELECT TIMESTAMPTZ '2023-03-12 02:00:00 America/Los_Angeles' as ts_geo
+
+
+
+##########
+## Timezone column tests
+##########
+
+# create a table with a non-UTC time zone.
+statement ok
+SET TIME ZONE = '+05:00'
+
+statement ok
+CREATE TABLE foo (time TIMESTAMPTZ) AS VALUES
+    ('2020-01-01T00:00:00+05:00'), 
+    ('2020-01-01T01:00:00+05:00'),
+    ('2020-01-01T02:00:00+05:00'),
+    ('2020-01-01T03:00:00+05:00')
+
+statement ok
+SET TIME ZONE = '+00'
+
+# verify column type
+query T
+SELECT arrow_typeof(time) FROM foo LIMIT 1
+----
+Timestamp(Nanosecond, Some("+05:00"))
+
+# check date_trunc
+query P
+SELECT date_trunc('day', time) FROM foo
+----
+2020-01-01T00:00:00+05:00
+2020-01-01T00:00:00+05:00
+2020-01-01T00:00:00+05:00
+2020-01-01T00:00:00+05:00
+
+# verify date_trunc column type
+query T
+SELECT arrow_typeof(date_trunc('day', time)) FROM foo LIMIT 1
+----
+Timestamp(Nanosecond, Some("+05:00"))
+
+# check date_bin
+query P
+SELECT date_bin(INTERVAL '1 day', time, '1970-01-01T00:00:00+05:00') FROM foo
+----
+2020-01-01T00:00:00+05:00
+2020-01-01T00:00:00+05:00
+2020-01-01T00:00:00+05:00
+2020-01-01T00:00:00+05:00
+
+# verify date_bin column type
+query T
+SELECT arrow_typeof(date_bin(INTERVAL '1 day', time, '1970-01-01T00:00:00+05:00')) FROM foo LIMIT 1
+----
+Timestamp(Nanosecond, Some("+05:00"))

Reply via email to