This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 65dd364be4 Improve unparser MySQL compatibility (#11589)
65dd364be4 is described below

commit 65dd364be438c2febb10f68efab0fd7a63586055
Author: Sergei Grebnov <[email protected]>
AuthorDate: Tue Jul 23 10:49:43 2024 -0700

    Improve unparser MySQL compatibility (#11589)
    
    * Configurable date field extraction style for unparsing (#21)
    
    * Add support for IntervalStyle::MySQL (#18)
    
    * Support alternate format for Int64 unparsing (SIGNED for MySQL) (#22)
    
    * Alternate format support for Timestamp casting (DATETIME for MySQL) (#23)
    
    * Improve
    
    * Fix clippy and docs
---
 datafusion/sql/src/unparser/dialect.rs | 155 ++++++++++++-
 datafusion/sql/src/unparser/expr.rs    | 397 ++++++++++++++++++++++++++++++---
 2 files changed, 507 insertions(+), 45 deletions(-)

diff --git a/datafusion/sql/src/unparser/dialect.rs 
b/datafusion/sql/src/unparser/dialect.rs
index ed0cfddc38..7eca326386 100644
--- a/datafusion/sql/src/unparser/dialect.rs
+++ b/datafusion/sql/src/unparser/dialect.rs
@@ -15,8 +15,14 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::sync::Arc;
+
+use arrow_schema::TimeUnit;
 use regex::Regex;
-use sqlparser::{ast, keywords::ALL_KEYWORDS};
+use sqlparser::{
+    ast::{self, Ident, ObjectName, TimezoneInfo},
+    keywords::ALL_KEYWORDS,
+};
 
 /// `Dialect` to use for Unparsing
 ///
@@ -36,8 +42,8 @@ pub trait Dialect: Send + Sync {
         true
     }
 
-    // Does the dialect use TIMESTAMP to represent Date64 rather than DATETIME?
-    // E.g. Trino, Athena and Dremio does not have DATETIME data type
+    /// Does the dialect use TIMESTAMP to represent Date64 rather than 
DATETIME?
+    /// E.g. Trino, Athena and Dremio do not have a DATETIME data type
     fn use_timestamp_for_date64(&self) -> bool {
         false
     }
@@ -46,23 +52,50 @@ pub trait Dialect: Send + Sync {
         IntervalStyle::PostgresVerbose
     }
 
-    // Does the dialect use DOUBLE PRECISION to represent Float64 rather than 
DOUBLE?
-    // E.g. Postgres uses DOUBLE PRECISION instead of DOUBLE
+    /// Does the dialect use DOUBLE PRECISION to represent Float64 rather than 
DOUBLE?
+    /// E.g. Postgres uses DOUBLE PRECISION instead of DOUBLE
     fn float64_ast_dtype(&self) -> sqlparser::ast::DataType {
         sqlparser::ast::DataType::Double
     }
 
-    // The SQL type to use for Arrow Utf8 unparsing
-    // Most dialects use VARCHAR, but some, like MySQL, require CHAR
+    /// The SQL type to use for Arrow Utf8 unparsing
+    /// Most dialects use VARCHAR, but some, like MySQL, require CHAR
     fn utf8_cast_dtype(&self) -> ast::DataType {
         ast::DataType::Varchar(None)
     }
 
-    // The SQL type to use for Arrow LargeUtf8 unparsing
-    // Most dialects use TEXT, but some, like MySQL, require CHAR
+    /// The SQL type to use for Arrow LargeUtf8 unparsing
+    /// Most dialects use TEXT, but some, like MySQL, require CHAR
     fn large_utf8_cast_dtype(&self) -> ast::DataType {
         ast::DataType::Text
     }
+
+    /// The date field extract style to use: `DateFieldExtractStyle`
+    fn date_field_extract_style(&self) -> DateFieldExtractStyle {
+        DateFieldExtractStyle::DatePart
+    }
+
+    /// The SQL type to use for Arrow Int64 unparsing
+    /// Most dialects use BigInt, but some, like MySQL, require SIGNED
+    fn int64_cast_dtype(&self) -> ast::DataType {
+        ast::DataType::BigInt(None)
+    }
+
+    /// The SQL type to use for Timestamp unparsing
+    /// Most dialects use Timestamp, but some, like MySQL, require Datetime
+    /// Some dialects, like Dremio, do not support WithTimeZone and always 
require Timestamp
+    fn timestamp_cast_dtype(
+        &self,
+        _time_unit: &TimeUnit,
+        tz: &Option<Arc<str>>,
+    ) -> ast::DataType {
+        let tz_info = match tz {
+            Some(_) => TimezoneInfo::WithTimeZone,
+            None => TimezoneInfo::None,
+        };
+
+        ast::DataType::Timestamp(None, tz_info)
+    }
 }
 
 /// `IntervalStyle` to use for unparsing
@@ -80,6 +113,19 @@ pub enum IntervalStyle {
     MySQL,
 }
 
+/// Datetime subfield extraction style for unparsing
+///
+/// 
`<https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT>`
+/// Different DBMSs follow different standards; popular ones are:
+/// date_part('YEAR', date '2001-02-16')
+/// EXTRACT(YEAR from date '2001-02-16')
+/// Some DBMSs, like Postgres, support both, whereas others like MySQL require 
EXTRACT.
+#[derive(Clone, Copy, PartialEq)]
+pub enum DateFieldExtractStyle {
+    DatePart,
+    Extract,
+}
+
 pub struct DefaultDialect {}
 
 impl Dialect for DefaultDialect {
@@ -133,6 +179,22 @@ impl Dialect for MySqlDialect {
     fn large_utf8_cast_dtype(&self) -> ast::DataType {
         ast::DataType::Char(None)
     }
+
+    fn date_field_extract_style(&self) -> DateFieldExtractStyle {
+        DateFieldExtractStyle::Extract
+    }
+
+    fn int64_cast_dtype(&self) -> ast::DataType {
+        ast::DataType::Custom(ObjectName(vec![Ident::new("SIGNED")]), vec![])
+    }
+
+    fn timestamp_cast_dtype(
+        &self,
+        _time_unit: &TimeUnit,
+        _tz: &Option<Arc<str>>,
+    ) -> ast::DataType {
+        ast::DataType::Datetime(None)
+    }
 }
 
 pub struct SqliteDialect {}
@@ -151,6 +213,10 @@ pub struct CustomDialect {
     float64_ast_dtype: sqlparser::ast::DataType,
     utf8_cast_dtype: ast::DataType,
     large_utf8_cast_dtype: ast::DataType,
+    date_field_extract_style: DateFieldExtractStyle,
+    int64_cast_dtype: ast::DataType,
+    timestamp_cast_dtype: ast::DataType,
+    timestamp_tz_cast_dtype: ast::DataType,
 }
 
 impl Default for CustomDialect {
@@ -163,6 +229,13 @@ impl Default for CustomDialect {
             float64_ast_dtype: sqlparser::ast::DataType::Double,
             utf8_cast_dtype: ast::DataType::Varchar(None),
             large_utf8_cast_dtype: ast::DataType::Text,
+            date_field_extract_style: DateFieldExtractStyle::DatePart,
+            int64_cast_dtype: ast::DataType::BigInt(None),
+            timestamp_cast_dtype: ast::DataType::Timestamp(None, 
TimezoneInfo::None),
+            timestamp_tz_cast_dtype: ast::DataType::Timestamp(
+                None,
+                TimezoneInfo::WithTimeZone,
+            ),
         }
     }
 }
@@ -206,6 +279,26 @@ impl Dialect for CustomDialect {
     fn large_utf8_cast_dtype(&self) -> ast::DataType {
         self.large_utf8_cast_dtype.clone()
     }
+
+    fn date_field_extract_style(&self) -> DateFieldExtractStyle {
+        self.date_field_extract_style
+    }
+
+    fn int64_cast_dtype(&self) -> ast::DataType {
+        self.int64_cast_dtype.clone()
+    }
+
+    fn timestamp_cast_dtype(
+        &self,
+        _time_unit: &TimeUnit,
+        tz: &Option<Arc<str>>,
+    ) -> ast::DataType {
+        if tz.is_some() {
+            self.timestamp_tz_cast_dtype.clone()
+        } else {
+            self.timestamp_cast_dtype.clone()
+        }
+    }
 }
 
 /// `CustomDialectBuilder` to build `CustomDialect` using builder pattern
@@ -230,6 +323,10 @@ pub struct CustomDialectBuilder {
     float64_ast_dtype: sqlparser::ast::DataType,
     utf8_cast_dtype: ast::DataType,
     large_utf8_cast_dtype: ast::DataType,
+    date_field_extract_style: DateFieldExtractStyle,
+    int64_cast_dtype: ast::DataType,
+    timestamp_cast_dtype: ast::DataType,
+    timestamp_tz_cast_dtype: ast::DataType,
 }
 
 impl Default for CustomDialectBuilder {
@@ -248,6 +345,13 @@ impl CustomDialectBuilder {
             float64_ast_dtype: sqlparser::ast::DataType::Double,
             utf8_cast_dtype: ast::DataType::Varchar(None),
             large_utf8_cast_dtype: ast::DataType::Text,
+            date_field_extract_style: DateFieldExtractStyle::DatePart,
+            int64_cast_dtype: ast::DataType::BigInt(None),
+            timestamp_cast_dtype: ast::DataType::Timestamp(None, 
TimezoneInfo::None),
+            timestamp_tz_cast_dtype: ast::DataType::Timestamp(
+                None,
+                TimezoneInfo::WithTimeZone,
+            ),
         }
     }
 
@@ -260,6 +364,10 @@ impl CustomDialectBuilder {
             float64_ast_dtype: self.float64_ast_dtype,
             utf8_cast_dtype: self.utf8_cast_dtype,
             large_utf8_cast_dtype: self.large_utf8_cast_dtype,
+            date_field_extract_style: self.date_field_extract_style,
+            int64_cast_dtype: self.int64_cast_dtype,
+            timestamp_cast_dtype: self.timestamp_cast_dtype,
+            timestamp_tz_cast_dtype: self.timestamp_tz_cast_dtype,
         }
     }
 
@@ -293,6 +401,7 @@ impl CustomDialectBuilder {
         self
     }
 
+    /// Customize the dialect with a specific SQL type for Float64 casting: 
DOUBLE, DOUBLE PRECISION, etc.
     pub fn with_float64_ast_dtype(
         mut self,
         float64_ast_dtype: sqlparser::ast::DataType,
@@ -301,11 +410,13 @@ impl CustomDialectBuilder {
         self
     }
 
+    /// Customize the dialect with a specific SQL type for Utf8 casting: 
VARCHAR, CHAR, etc.
     pub fn with_utf8_cast_dtype(mut self, utf8_cast_dtype: ast::DataType) -> 
Self {
         self.utf8_cast_dtype = utf8_cast_dtype;
         self
     }
 
+    /// Customize the dialect with a specific SQL type for LargeUtf8 casting: 
TEXT, CHAR, etc.
     pub fn with_large_utf8_cast_dtype(
         mut self,
         large_utf8_cast_dtype: ast::DataType,
@@ -313,4 +424,30 @@ impl CustomDialectBuilder {
         self.large_utf8_cast_dtype = large_utf8_cast_dtype;
         self
     }
+
+    /// Customize the dialect with a specific date field extract style listed 
in `DateFieldExtractStyle`
+    pub fn with_date_field_extract_style(
+        mut self,
+        date_field_extract_style: DateFieldExtractStyle,
+    ) -> Self {
+        self.date_field_extract_style = date_field_extract_style;
+        self
+    }
+
+    /// Customize the dialect with a specific SQL type for Int64 casting: 
BigInt, SIGNED, etc.
+    pub fn with_int64_cast_dtype(mut self, int64_cast_dtype: ast::DataType) -> 
Self {
+        self.int64_cast_dtype = int64_cast_dtype;
+        self
+    }
+
+    /// Customize the dialect with a specific SQL type for Timestamp casting: 
Timestamp, Datetime, etc.
+    pub fn with_timestamp_cast_dtype(
+        mut self,
+        timestamp_cast_dtype: ast::DataType,
+        timestamp_tz_cast_dtype: ast::DataType,
+    ) -> Self {
+        self.timestamp_cast_dtype = timestamp_cast_dtype;
+        self.timestamp_tz_cast_dtype = timestamp_tz_cast_dtype;
+        self
+    }
 }
diff --git a/datafusion/sql/src/unparser/expr.rs 
b/datafusion/sql/src/unparser/expr.rs
index 2f7854c1a1..f4ea44f37d 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -16,6 +16,13 @@
 // under the License.
 
 use core::fmt;
+
+use datafusion_expr::ScalarUDF;
+use sqlparser::ast::Value::SingleQuotedString;
+use sqlparser::ast::{
+    self, BinaryOperator, Expr as AstExpr, Function, FunctionArg, Ident, 
Interval,
+    TimezoneInfo, UnaryOperator,
+};
 use std::sync::Arc;
 use std::{fmt::Display, vec};
 
@@ -28,12 +35,6 @@ use arrow_array::types::{
 };
 use arrow_array::{Date32Array, Date64Array, PrimitiveArray};
 use arrow_schema::DataType;
-use sqlparser::ast::Value::SingleQuotedString;
-use sqlparser::ast::{
-    self, BinaryOperator, Expr as AstExpr, Function, FunctionArg, Ident, 
Interval,
-    TimezoneInfo, UnaryOperator,
-};
-
 use datafusion_common::{
     internal_datafusion_err, internal_err, not_impl_err, plan_err, Column, 
Result,
     ScalarValue,
@@ -43,7 +44,7 @@ use datafusion_expr::{
     Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Operator, 
TryCast,
 };
 
-use super::dialect::IntervalStyle;
+use super::dialect::{DateFieldExtractStyle, IntervalStyle};
 use super::Unparser;
 
 /// DataFusion's Exprs can represent either an `Expr` or an `OrderByExpr`
@@ -149,6 +150,12 @@ impl Unparser<'_> {
             Expr::ScalarFunction(ScalarFunction { func, args }) => {
                 let func_name = func.name();
 
+                if let Some(expr) =
+                    self.scalar_function_to_sql_overrides(func_name, func, 
args)
+                {
+                    return Ok(expr);
+                }
+
                 let args = args
                     .iter()
                     .map(|e| {
@@ -545,6 +552,38 @@ impl Unparser<'_> {
         }
     }
 
+    fn scalar_function_to_sql_overrides(
+        &self,
+        func_name: &str,
+        _func: &Arc<ScalarUDF>,
+        args: &[Expr],
+    ) -> Option<ast::Expr> {
+        if func_name.to_lowercase() == "date_part"
+            && self.dialect.date_field_extract_style() == 
DateFieldExtractStyle::Extract
+            && args.len() == 2
+        {
+            let date_expr = self.expr_to_sql(&args[1]).ok()?;
+
+            if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &args[0] {
+                let field = match field.to_lowercase().as_str() {
+                    "year" => ast::DateTimeField::Year,
+                    "month" => ast::DateTimeField::Month,
+                    "day" => ast::DateTimeField::Day,
+                    "hour" => ast::DateTimeField::Hour,
+                    "minute" => ast::DateTimeField::Minute,
+                    "second" => ast::DateTimeField::Second,
+                    _ => return None,
+                };
+
+                return Some(ast::Expr::Extract {
+                    field,
+                    expr: Box::new(date_expr),
+                });
+            }
+        }
+        None
+    }
+
     fn ast_type_for_date64_in_cast(&self) -> ast::DataType {
         if self.dialect.use_timestamp_for_date64() {
             ast::DataType::Timestamp(None, ast::TimezoneInfo::None)
@@ -1105,6 +1144,131 @@ impl Unparser<'_> {
         }
     }
 
+    /// MySQL requires INTERVAL sql to be in the format: INTERVAL 1 YEAR + 
INTERVAL 1 MONTH + INTERVAL 1 DAY etc
+    /// 
`<https://dev.mysql.com/doc/refman/8.4/en/expressions.html#temporal-intervals>`
+    /// Interval sequence can't be wrapped in brackets - (INTERVAL 1 YEAR + 
INTERVAL 1 MONTH ...) so we need to generate
+    /// a single INTERVAL expression so it works correctly for interval 
subtraction cases
+    /// MySQL supports the DAY_MICROSECOND unit type (format is DAYS 
HOURS:MINUTES:SECONDS.MICROSECONDS), but it is not supported by sqlparser
+    /// so we calculate the best single interval to represent the provided 
duration
+    fn interval_to_mysql_expr(
+        &self,
+        months: i32,
+        days: i32,
+        microseconds: i64,
+    ) -> Result<ast::Expr> {
+        // MONTH only
+        if months != 0 && days == 0 && microseconds == 0 {
+            let interval = Interval {
+                value: Box::new(ast::Expr::Value(ast::Value::Number(
+                    months.to_string(),
+                    false,
+                ))),
+                leading_field: Some(ast::DateTimeField::Month),
+                leading_precision: None,
+                last_field: None,
+                fractional_seconds_precision: None,
+            };
+            return Ok(ast::Expr::Interval(interval));
+        } else if months != 0 {
+            return not_impl_err!("Unsupported Interval scalar with both Month 
and DayTime for IntervalStyle::MySQL");
+        }
+
+        // DAY only
+        if microseconds == 0 {
+            let interval = Interval {
+                value: Box::new(ast::Expr::Value(ast::Value::Number(
+                    days.to_string(),
+                    false,
+                ))),
+                leading_field: Some(ast::DateTimeField::Day),
+                leading_precision: None,
+                last_field: None,
+                fractional_seconds_precision: None,
+            };
+            return Ok(ast::Expr::Interval(interval));
+        }
+
+        // calculate the best single interval to represent the provided days 
and microseconds
+
+        let microseconds = microseconds + (days as i64 * 24 * 60 * 60 * 
1_000_000);
+
+        if microseconds % 1_000_000 != 0 {
+            let interval = Interval {
+                value: Box::new(ast::Expr::Value(ast::Value::Number(
+                    microseconds.to_string(),
+                    false,
+                ))),
+                leading_field: Some(ast::DateTimeField::Microsecond),
+                leading_precision: None,
+                last_field: None,
+                fractional_seconds_precision: None,
+            };
+            return Ok(ast::Expr::Interval(interval));
+        }
+
+        let secs = microseconds / 1_000_000;
+
+        if secs % 60 != 0 {
+            let interval = Interval {
+                value: Box::new(ast::Expr::Value(ast::Value::Number(
+                    secs.to_string(),
+                    false,
+                ))),
+                leading_field: Some(ast::DateTimeField::Second),
+                leading_precision: None,
+                last_field: None,
+                fractional_seconds_precision: None,
+            };
+            return Ok(ast::Expr::Interval(interval));
+        }
+
+        let mins = secs / 60;
+
+        if mins % 60 != 0 {
+            let interval = Interval {
+                value: Box::new(ast::Expr::Value(ast::Value::Number(
+                    mins.to_string(),
+                    false,
+                ))),
+                leading_field: Some(ast::DateTimeField::Minute),
+                leading_precision: None,
+                last_field: None,
+                fractional_seconds_precision: None,
+            };
+            return Ok(ast::Expr::Interval(interval));
+        }
+
+        let hours = mins / 60;
+
+        if hours % 24 != 0 {
+            let interval = Interval {
+                value: Box::new(ast::Expr::Value(ast::Value::Number(
+                    hours.to_string(),
+                    false,
+                ))),
+                leading_field: Some(ast::DateTimeField::Hour),
+                leading_precision: None,
+                last_field: None,
+                fractional_seconds_precision: None,
+            };
+            return Ok(ast::Expr::Interval(interval));
+        }
+
+        let days = hours / 24;
+
+        let interval = Interval {
+            value: Box::new(ast::Expr::Value(ast::Value::Number(
+                days.to_string(),
+                false,
+            ))),
+            leading_field: Some(ast::DateTimeField::Day),
+            leading_precision: None,
+            last_field: None,
+            fractional_seconds_precision: None,
+        };
+        Ok(ast::Expr::Interval(interval))
+    }
+
     fn interval_scalar_to_sql(&self, v: &ScalarValue) -> Result<ast::Expr> {
         match self.dialect.interval_style() {
             IntervalStyle::PostgresVerbose => {
@@ -1127,10 +1291,7 @@ impl Unparser<'_> {
             }
             // If the interval standard is SQLStandard, implement a simple 
unparse logic
             IntervalStyle::SQLStandard => match v {
-                ScalarValue::IntervalYearMonth(v) => {
-                    let Some(v) = v else {
-                        return Ok(ast::Expr::Value(ast::Value::Null));
-                    };
+                ScalarValue::IntervalYearMonth(Some(v)) => {
                     let interval = Interval {
                         value: Box::new(ast::Expr::Value(
                             ast::Value::SingleQuotedString(v.to_string()),
@@ -1142,10 +1303,7 @@ impl Unparser<'_> {
                     };
                     Ok(ast::Expr::Interval(interval))
                 }
-                ScalarValue::IntervalDayTime(v) => {
-                    let Some(v) = v else {
-                        return Ok(ast::Expr::Value(ast::Value::Null));
-                    };
+                ScalarValue::IntervalDayTime(Some(v)) => {
                     let days = v.days;
                     let secs = v.milliseconds / 1_000;
                     let mins = secs / 60;
@@ -1168,11 +1326,7 @@ impl Unparser<'_> {
                     };
                     Ok(ast::Expr::Interval(interval))
                 }
-                ScalarValue::IntervalMonthDayNano(v) => {
-                    let Some(v) = v else {
-                        return Ok(ast::Expr::Value(ast::Value::Null));
-                    };
-
+                ScalarValue::IntervalMonthDayNano(Some(v)) => {
                     if v.months >= 0 && v.days == 0 && v.nanoseconds == 0 {
                         let interval = Interval {
                             value: Box::new(ast::Expr::Value(
@@ -1184,10 +1338,7 @@ impl Unparser<'_> {
                             fractional_seconds_precision: None,
                         };
                         Ok(ast::Expr::Interval(interval))
-                    } else if v.months == 0
-                        && v.days >= 0
-                        && v.nanoseconds % 1_000_000 == 0
-                    {
+                    } else if v.months == 0 && v.nanoseconds % 1_000_000 == 0 {
                         let days = v.days;
                         let secs = v.nanoseconds / 1_000_000_000;
                         let mins = secs / 60;
@@ -1214,11 +1365,29 @@ impl Unparser<'_> {
                         not_impl_err!("Unsupported IntervalMonthDayNano scalar 
with both Month and DayTime for IntervalStyle::SQLStandard")
                     }
                 }
-                _ => Ok(ast::Expr::Value(ast::Value::Null)),
+                _ => not_impl_err!(
+                    "Unsupported ScalarValue for Interval conversion: {v:?}"
+                ),
+            },
+            IntervalStyle::MySQL => match v {
+                ScalarValue::IntervalYearMonth(Some(v)) => {
+                    self.interval_to_mysql_expr(*v, 0, 0)
+                }
+                ScalarValue::IntervalDayTime(Some(v)) => {
+                    self.interval_to_mysql_expr(0, v.days, v.milliseconds as 
i64 * 1_000)
+                }
+                ScalarValue::IntervalMonthDayNano(Some(v)) => {
+                    if v.nanoseconds % 1_000 != 0 {
+                        return not_impl_err!(
+                            "Unsupported IntervalMonthDayNano scalar with 
nanoseconds precision for IntervalStyle::MySQL"
+                        );
+                    }
+                    self.interval_to_mysql_expr(v.months, v.days, 
v.nanoseconds / 1_000)
+                }
+                _ => not_impl_err!(
+                    "Unsupported ScalarValue for Interval conversion: {v:?}"
+                ),
             },
-            IntervalStyle::MySQL => {
-                not_impl_err!("Unsupported interval scalar for 
IntervalStyle::MySQL")
-            }
         }
     }
 
@@ -1231,7 +1400,7 @@ impl Unparser<'_> {
             DataType::Int8 => Ok(ast::DataType::TinyInt(None)),
             DataType::Int16 => Ok(ast::DataType::SmallInt(None)),
             DataType::Int32 => Ok(ast::DataType::Integer(None)),
-            DataType::Int64 => Ok(ast::DataType::BigInt(None)),
+            DataType::Int64 => Ok(self.dialect.int64_cast_dtype()),
             DataType::UInt8 => Ok(ast::DataType::UnsignedTinyInt(None)),
             DataType::UInt16 => Ok(ast::DataType::UnsignedSmallInt(None)),
             DataType::UInt32 => Ok(ast::DataType::UnsignedInteger(None)),
@@ -1241,13 +1410,8 @@ impl Unparser<'_> {
             }
             DataType::Float32 => Ok(ast::DataType::Float(None)),
             DataType::Float64 => Ok(self.dialect.float64_ast_dtype()),
-            DataType::Timestamp(_, tz) => {
-                let tz_info = match tz {
-                    Some(_) => TimezoneInfo::WithTimeZone,
-                    None => TimezoneInfo::None,
-                };
-
-                Ok(ast::DataType::Timestamp(None, tz_info))
+            DataType::Timestamp(time_unit, tz) => {
+                Ok(self.dialect.timestamp_cast_dtype(time_unit, tz))
             }
             DataType::Date32 => Ok(ast::DataType::Date),
             DataType::Date64 => Ok(self.ast_type_for_date64_in_cast()),
@@ -1335,6 +1499,7 @@ mod tests {
     use arrow::datatypes::TimeUnit;
     use arrow::datatypes::{Field, Schema};
     use arrow_schema::DataType::Int8;
+    use ast::ObjectName;
     use datafusion_common::TableReference;
     use datafusion_expr::{
         case, col, cube, exists, grouping_set, interval_datetime_lit,
@@ -1885,6 +2050,11 @@ mod tests {
                 IntervalStyle::SQLStandard,
                 "INTERVAL '1 12:0:0.000' DAY TO SECOND",
             ),
+            (
+                interval_month_day_nano_lit("-1.5 DAY"),
+                IntervalStyle::SQLStandard,
+                "INTERVAL '-1 -12:0:0.000' DAY TO SECOND",
+            ),
             (
                 interval_month_day_nano_lit("1.51234 DAY"),
                 IntervalStyle::SQLStandard,
@@ -1949,6 +2119,46 @@ mod tests {
                 IntervalStyle::PostgresVerbose,
                 r#"INTERVAL '1 YEARS 7 MONS 0 DAYS 0 HOURS 0 MINS 0.00 SECS'"#,
             ),
+            (
+                interval_year_month_lit("1 YEAR 1 MONTH"),
+                IntervalStyle::MySQL,
+                r#"INTERVAL 13 MONTH"#,
+            ),
+            (
+                interval_month_day_nano_lit("1 YEAR -1 MONTH"),
+                IntervalStyle::MySQL,
+                r#"INTERVAL 11 MONTH"#,
+            ),
+            (
+                interval_month_day_nano_lit("15 DAY"),
+                IntervalStyle::MySQL,
+                r#"INTERVAL 15 DAY"#,
+            ),
+            (
+                interval_month_day_nano_lit("-40 HOURS"),
+                IntervalStyle::MySQL,
+                r#"INTERVAL -40 HOUR"#,
+            ),
+            (
+                interval_datetime_lit("-1.5 DAY 1 HOUR"),
+                IntervalStyle::MySQL,
+                "INTERVAL -35 HOUR",
+            ),
+            (
+                interval_datetime_lit("1000000 DAY 1.5 HOUR 10 MINUTE 20 
SECOND"),
+                IntervalStyle::MySQL,
+                r#"INTERVAL 86400006020 SECOND"#,
+            ),
+            (
+                interval_year_month_lit("0 DAY 0 HOUR"),
+                IntervalStyle::MySQL,
+                r#"INTERVAL 0 DAY"#,
+            ),
+            (
+                interval_month_day_nano_lit("-1296000000 SECOND"),
+                IntervalStyle::MySQL,
+                r#"INTERVAL -15000 DAY"#,
+            ),
         ];
 
         for (value, style, expected) in tests {
@@ -1994,4 +2204,119 @@ mod tests {
         }
         Ok(())
     }
+
+    #[test]
+    fn custom_dialect_with_date_field_extract_style() -> Result<()> {
+        for (extract_style, unit, expected) in [
+            (
+                DateFieldExtractStyle::DatePart,
+                "YEAR",
+                "date_part('YEAR', x)",
+            ),
+            (
+                DateFieldExtractStyle::Extract,
+                "YEAR",
+                "EXTRACT(YEAR FROM x)",
+            ),
+            (
+                DateFieldExtractStyle::DatePart,
+                "MONTH",
+                "date_part('MONTH', x)",
+            ),
+            (
+                DateFieldExtractStyle::Extract,
+                "MONTH",
+                "EXTRACT(MONTH FROM x)",
+            ),
+            (
+                DateFieldExtractStyle::DatePart,
+                "DAY",
+                "date_part('DAY', x)",
+            ),
+            (DateFieldExtractStyle::Extract, "DAY", "EXTRACT(DAY FROM x)"),
+        ] {
+            let dialect = CustomDialectBuilder::new()
+                .with_date_field_extract_style(extract_style)
+                .build();
+
+            let unparser = Unparser::new(&dialect);
+            let expr = ScalarUDF::new_from_impl(
+                datafusion_functions::datetime::date_part::DatePartFunc::new(),
+            )
+            .call(vec![Expr::Literal(ScalarValue::new_utf8(unit)), col("x")]);
+
+            let ast = unparser.expr_to_sql(&expr)?;
+            let actual = format!("{}", ast);
+
+            assert_eq!(actual, expected);
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn custom_dialect_with_int64_cast_dtype() -> Result<()> {
+        let default_dialect = CustomDialectBuilder::new().build();
+        let mysql_dialect = CustomDialectBuilder::new()
+            .with_int64_cast_dtype(ast::DataType::Custom(
+                ObjectName(vec![Ident::new("SIGNED")]),
+                vec![],
+            ))
+            .build();
+
+        for (dialect, identifier) in
+            [(default_dialect, "BIGINT"), (mysql_dialect, "SIGNED")]
+        {
+            let unparser = Unparser::new(&dialect);
+            let expr = Expr::Cast(Cast {
+                expr: Box::new(col("a")),
+                data_type: DataType::Int64,
+            });
+            let ast = unparser.expr_to_sql(&expr)?;
+
+            let actual = format!("{}", ast);
+            let expected = format!(r#"CAST(a AS {identifier})"#);
+
+            assert_eq!(actual, expected);
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn custom_dialect_with_teimstamp_cast_dtype() -> Result<()> {
+        let default_dialect = CustomDialectBuilder::new().build();
+        let mysql_dialect = CustomDialectBuilder::new()
+            .with_timestamp_cast_dtype(
+                ast::DataType::Datetime(None),
+                ast::DataType::Datetime(None),
+            )
+            .build();
+
+        let timestamp = DataType::Timestamp(TimeUnit::Nanosecond, None);
+        let timestamp_with_tz =
+            DataType::Timestamp(TimeUnit::Nanosecond, Some("+08:00".into()));
+
+        for (dialect, data_type, identifier) in [
+            (&default_dialect, &timestamp, "TIMESTAMP"),
+            (
+                &default_dialect,
+                &timestamp_with_tz,
+                "TIMESTAMP WITH TIME ZONE",
+            ),
+            (&mysql_dialect, &timestamp, "DATETIME"),
+            (&mysql_dialect, &timestamp_with_tz, "DATETIME"),
+        ] {
+            let unparser = Unparser::new(dialect);
+            let expr = Expr::Cast(Cast {
+                expr: Box::new(col("a")),
+                data_type: data_type.clone(),
+            });
+            let ast = unparser.expr_to_sql(&expr)?;
+
+            let actual = format!("{}", ast);
+            let expected = format!(r#"CAST(a AS {identifier})"#);
+
+            assert_eq!(actual, expected);
+        }
+        Ok(())
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to