This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new e46924d80 feat: add `arrow_cast` function to support
arbitrary arrow types (#5166)
e46924d80 is described below

commit e46924d80fddbed0faf35edc85b3bba6f050b344
Author: Andrew Lamb <[email protected]>
AuthorDate: Wed Mar 8 23:32:59 2023 +0100

    feat: add `arrow_cast` function to support arbitrary arrow types (#5166)
    
    * Add `arrow_cast` function
    
    * prettier
    
    * Update datafusion/sql/src/expr/arrow_cast.rs
    
    Co-authored-by: Wei-Ting Kuo <[email protected]>
    
    * Apply suggestions from code review
    
    Co-authored-by: Wei-Ting Kuo <[email protected]>
    
    * Clarify intent of tests
    
    * Add more error tests
    
    * More tests
    
    * fix test
    
    * reuse buffer to avoid an allocation per word
    
    * add ticket link
    
    * allow trailing whitespace, add tests for whitespace
    
    ---------
    
    Co-authored-by: Wei-Ting Kuo <[email protected]>
---
 .../sqllogictests/test_files/arrow_typeof.slt      | 267 +++++++-
 datafusion/proto/src/logical_plan/mod.rs           |   5 +-
 datafusion/sql/src/expr/arrow_cast.rs              | 719 +++++++++++++++++++++
 datafusion/sql/src/expr/function.rs                |  39 +-
 datafusion/sql/src/expr/mod.rs                     |   1 +
 datafusion/sql/src/lib.rs                          |   1 +
 datafusion/sql/tests/integration_test.rs           |  11 +-
 docs/source/user-guide/sql/data_types.md           |  66 +-
 8 files changed, 1056 insertions(+), 53 deletions(-)

diff --git a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt 
b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt
index 8f1c00651..fee24740a 100644
--- a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt
@@ -52,31 +52,242 @@ SELECT arrow_typeof(1.0::float)
 Float32
 
 # arrow_typeof_decimal
-# query T
-# SELECT arrow_typeof(1::Decimal)
-# ----
-# Decimal128(38, 10)
-
-# # arrow_typeof_timestamp
-# query T
-# SELECT arrow_typeof(now()::timestamp)
-# ----
-# Timestamp(Nanosecond, None)
-
-# # arrow_typeof_timestamp_utc
-# query T
-# SELECT arrow_typeof(now())
-# ----
-# Timestamp(Nanosecond, Some(\"+00:00\"))
-
-# # arrow_typeof_timestamp_date32(
-# query T
-# SELECT arrow_typeof(now()::date)
-# ----
-# Date32
-
-# # arrow_typeof_utf8
-# query T
-# SELECT arrow_typeof('1')
-# ----
-# Utf8
+query T
+SELECT arrow_typeof(1::Decimal)
+----
+Decimal128(38, 10)
+
+# arrow_typeof_timestamp
+query T
+SELECT arrow_typeof(now()::timestamp)
+----
+Timestamp(Nanosecond, None)
+
+# arrow_typeof_timestamp_utc
+query T
+SELECT arrow_typeof(now())
+----
+Timestamp(Nanosecond, Some("+00:00"))
+
+# arrow_typeof_timestamp_date32
+query T
+SELECT arrow_typeof(now()::date)
+----
+Date32
+
+# arrow_typeof_utf8
+query T
+SELECT arrow_typeof('1')
+----
+Utf8
+
+
+#### arrow_cast (in some ways opposite of arrow_typeof)
+
+# Basic tests
+
+query I
+SELECT arrow_cast('1', 'Int16')
+----
+1
+
+# Basic error tests: wrong number of arguments, a non-string type
+# argument, and an unrecognized type name
+query error Error during planning: arrow_cast needs 2 arguments, 1 provided
+SELECT arrow_cast('1')
+
+query error Error during planning: arrow_cast requires its second argument to be a constant string, got Int64\(43\)
+SELECT arrow_cast('1', 43)
+
+query error Error unrecognized word: unknown
+SELECT arrow_cast('1', 'unknown')
+
+# Round trip tests: arrow_typeof(arrow_cast(x, 'T')) should report T
+query TTTTTTTTTTTTTTTTTTT
+SELECT
+  arrow_typeof(arrow_cast(1, 'Int8')) as col_i8,
+  arrow_typeof(arrow_cast(1, 'Int16')) as col_i16,
+  arrow_typeof(arrow_cast(1, 'Int32')) as col_i32,
+  arrow_typeof(arrow_cast(1, 'Int64')) as col_i64,
+  arrow_typeof(arrow_cast(1, 'UInt8')) as col_u8,
+  arrow_typeof(arrow_cast(1, 'UInt16')) as col_u16,
+  arrow_typeof(arrow_cast(1, 'UInt32')) as col_u32,
+  arrow_typeof(arrow_cast(1, 'UInt64')) as col_u64,
+  -- can't seem to cast to Float16 for some reason
+  -- arrow_typeof(arrow_cast(1, 'Float16')) as col_f16,
+  arrow_typeof(arrow_cast(1, 'Float32')) as col_f32,
+  arrow_typeof(arrow_cast(1, 'Float64')) as col_f64,
+  arrow_typeof(arrow_cast('foo', 'Utf8')) as col_utf8,
+  arrow_typeof(arrow_cast('foo', 'LargeUtf8')) as col_large_utf8,
+  arrow_typeof(arrow_cast('foo', 'Binary')) as col_binary,
+  arrow_typeof(arrow_cast('foo', 'LargeBinary')) as col_large_binary,
+  arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)')) as col_ts_s,
+  arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Millisecond, None)')) as col_ts_ms,
+  arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Microsecond, None)')) as col_ts_us,
+  arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Nanosecond, None)')) as col_ts_ns,
+  arrow_typeof(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) as col_dict
+----
+Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float32 Float64 Utf8 LargeUtf8 Binary LargeBinary Timestamp(Second, None) Timestamp(Millisecond, None) Timestamp(Microsecond, None) Timestamp(Nanosecond, None) Dictionary(Int32, Utf8)
+
+
+
+
+## Basic Types: Create a table with one column per integer and float type
+
+statement ok
+create table foo as select
+  arrow_cast(1, 'Int8') as col_i8,
+  arrow_cast(1, 'Int16') as col_i16,
+  arrow_cast(1, 'Int32') as col_i32,
+  arrow_cast(1, 'Int64') as col_i64,
+  arrow_cast(1, 'UInt8') as col_u8,
+  arrow_cast(1, 'UInt16') as col_u16,
+  arrow_cast(1, 'UInt32') as col_u32,
+  arrow_cast(1, 'UInt64') as col_u64,
+  -- NOTE: Float16 is skipped; casting to it did not appear to work (cause unknown)
+  -- arrow_cast(1.0, 'Float16') as col_f16,
+  arrow_cast(1.0, 'Float32') as col_f32,
+  arrow_cast(1.0, 'Float64') as col_f64
+;
+
+## Ensure each column in the table has the expected type
+
+query TTTTTTTTTT
+SELECT
+  arrow_typeof(col_i8),
+  arrow_typeof(col_i16),
+  arrow_typeof(col_i32),
+  arrow_typeof(col_i64),
+  arrow_typeof(col_u8),
+  arrow_typeof(col_u16),
+  arrow_typeof(col_u32),
+  arrow_typeof(col_u64),
+  -- arrow_typeof(col_f16),
+  arrow_typeof(col_f32),
+  arrow_typeof(col_f64)
+  FROM foo;
+----
+Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float32 Float64
+
+
+# Clean up
+statement ok
+drop table foo
+
+## Decimals: Create a table
+# NOTE(review): 100 is outside the range of Decimal128(3,2) (max 9.99);
+# only the resulting column type is checked below, not the value.
+
+statement ok
+create table foo as select
+  arrow_cast(100, 'Decimal128(3,2)') as col_d128
+  -- Can't make a Decimal256 column yet:
+  -- This feature is not implemented: Can't create a scalar from array of type "Decimal256(3, 2)"
+  --arrow_cast(100, 'Decimal256(3,2)') as col_d256
+;
+
+
+## Ensure each column in the table has the expected type
+
+query T
+SELECT
+  arrow_typeof(col_d128)
+  -- arrow_typeof(col_d256),
+  FROM foo;
+----
+Decimal128(3, 2)
+
+
+# Clean up
+statement ok
+drop table foo
+
+## Strings, Binary: Create a table with one column per string/binary type
+
+statement ok
+create table foo as select
+  arrow_cast('foo', 'Utf8') as col_utf8,
+  arrow_cast('foo', 'LargeUtf8') as col_large_utf8,
+  arrow_cast('foo', 'Binary') as col_binary,
+  arrow_cast('foo', 'LargeBinary') as col_large_binary
+;
+
+## Ensure each column in the table has the expected type
+
+query TTTT
+SELECT
+  arrow_typeof(col_utf8),
+  arrow_typeof(col_large_utf8),
+  arrow_typeof(col_binary),
+  arrow_typeof(col_large_binary)
+  FROM foo;
+----
+Utf8 LargeUtf8 Binary LargeBinary
+
+
+# Clean up
+statement ok
+drop table foo
+
+
+## Timestamps: Create a table, casting the same instant to each TimeUnit
+
+statement ok
+create table foo as select
+  arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)') as col_ts_s,
+  arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Millisecond, None)') as col_ts_ms,
+  arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Microsecond, None)') as col_ts_us,
+  arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Nanosecond, None)') as col_ts_ns
+;
+
+## Ensure each column in the table has the expected type
+
+query TTTT
+SELECT
+  arrow_typeof(col_ts_s),
+  arrow_typeof(col_ts_ms),
+  arrow_typeof(col_ts_us),
+  arrow_typeof(col_ts_ns)
+  FROM foo;
+----
+Timestamp(Second, None) Timestamp(Millisecond, None) Timestamp(Microsecond, None) Timestamp(Nanosecond, None)
+
+
+# Clean up
+statement ok
+drop table foo
+
+## Dictionaries: Create a table
+
+statement ok
+create table foo as select
+  arrow_cast('foo', 'Dictionary(Int32, Utf8)') as col_dict_int32_utf8,
+  arrow_cast('foo', 'Dictionary(Int8, LargeUtf8)') as col_dict_int8_largeutf8
+;
+
+## Ensure each column in the table has the expected type
+
+query TT
+SELECT
+  arrow_typeof(col_dict_int32_utf8),
+  arrow_typeof(col_dict_int8_largeutf8)
+  FROM foo;
+----
+Dictionary(Int32, Utf8) Dictionary(Int8, LargeUtf8)
+
+
+# Clean up
+statement ok
+drop table foo
+
+
+## Intervals: these casts are currently expected to fail
+
+query error Cannot automatically convert Interval\(DayTime\) to Interval\(MonthDayNano\)
+select arrow_cast(interval '30 minutes', 'Interval(MonthDayNano)');
+
+query error DataFusion error: Error during planning: Cannot automatically convert Utf8 to Interval\(MonthDayNano\)
+select arrow_cast('30 minutes', 'Interval(MonthDayNano)');
+
+
+## Duration: these casts are currently expected to fail
+
+query error Cannot automatically convert Interval\(DayTime\) to Duration\(Second\)
+select arrow_cast(interval '30 minutes', 'Duration(Second)');
+
+query error DataFusion error: Error during planning: Cannot automatically convert Utf8 to Duration\(Second\)
+select arrow_cast('30 minutes', 'Duration(Second)');
diff --git a/datafusion/proto/src/logical_plan/mod.rs 
b/datafusion/proto/src/logical_plan/mod.rs
index 706128259..802242b3e 100644
--- a/datafusion/proto/src/logical_plan/mod.rs
+++ b/datafusion/proto/src/logical_plan/mod.rs
@@ -2120,8 +2120,11 @@ mod roundtrip_tests {
             DataType::Float16,
             DataType::Float32,
             DataType::Float64,
-            // Add more timestamp tests
+            DataType::Timestamp(TimeUnit::Second, None),
             DataType::Timestamp(TimeUnit::Millisecond, None),
+            DataType::Timestamp(TimeUnit::Microsecond, None),
+            DataType::Timestamp(TimeUnit::Nanosecond, None),
+            DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
             DataType::Date32,
             DataType::Date64,
             DataType::Time32(TimeUnit::Second),
diff --git a/datafusion/sql/src/expr/arrow_cast.rs 
b/datafusion/sql/src/expr/arrow_cast.rs
new file mode 100644
index 000000000..bc1313e2c
--- /dev/null
+++ b/datafusion/sql/src/expr/arrow_cast.rs
@@ -0,0 +1,719 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Implementation of the `arrow_cast` function that allows
+//! casting to arbitrary arrow types (rather than SQL types)
+
+use std::{fmt::Display, iter::Peekable, str::Chars};
+
+use arrow_schema::{DataType, IntervalUnit, TimeUnit};
+use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue};
+
+use datafusion_expr::{Expr, ExprSchemable};
+
+/// Name of the `arrow_cast` SQL function (presumably matched against the
+/// called function name by the SQL planner -- confirm at the call site)
+pub const ARROW_CAST_NAME: &str = "arrow_cast";
+
+/// Create an [`Expr`] that evaluates the `arrow_cast` function
+///
+/// This function is not a [`BuiltInScalarFunction`] because the
+/// return type of [`BuiltInScalarFunction`] depends only on the
+/// *types* of the arguments. However, the type of `arrow_cast` depends on
+/// the *value* of its second argument.
+///
+/// Use the `cast` function to cast to SQL type (which is then mapped
+/// to the corresponding arrow type). For example to cast to `int`
+/// (which is then mapped to the arrow type `Int32`)
+///
+/// ```sql
+/// select cast(column_x as int) ...
+/// ```
+///
+/// Use the `arrow_cast` function to cast to a specific arrow type
+///
+/// For example
+/// ```sql
+/// select arrow_cast(column_x, 'Float64')
+/// ```
+///
+/// # Arguments
+///
+/// * `args` - exactly two expressions: the value to cast, and a constant
+///   string literal naming the target arrow type
+/// * `schema` - schema passed to [`ExprSchemable::cast_to`] for the value
+///
+/// # Errors
+///
+/// Returns [`DataFusionError::Plan`] if `args` does not have exactly two
+/// elements, if the second argument is not a non-null `Utf8` literal, or if
+/// the type name cannot be parsed by [`parse_data_type`]. Any error from
+/// [`ExprSchemable::cast_to`] is returned unchanged.
+pub fn create_arrow_cast(mut args: Vec<Expr>, schema: &DFSchema) -> Result<Expr> {
+    if args.len() != 2 {
+        return Err(DataFusionError::Plan(format!(
+            "arrow_cast needs 2 arguments, {} provided",
+            args.len()
+        )));
+    }
+    // the length check above guarantees both pops succeed; pop the
+    // arguments off the end, so the second one comes out first
+    let arg1 = args.pop().unwrap();
+    let arg0 = args.pop().unwrap();
+
+    // arg1 must be a constant (non-null) string naming the target type
+    let data_type_string = match arg1 {
+        Expr::Literal(ScalarValue::Utf8(Some(v))) => v,
+        other => {
+            return Err(DataFusionError::Plan(format!(
+                "arrow_cast requires its second argument to be a constant string, got {other}"
+            )))
+        }
+    };
+
+    // do the actual lookup to the appropriate data type
+    let data_type = parse_data_type(&data_type_string)?;
+
+    arg0.cast_to(&data_type, schema)
+}
+
+/// Parses `val` into a `DataType`.
+///
+/// `parse_data_type` is the reverse of [`DataType`]'s `Display`
+/// impl, and maintains the invariant that
+/// `parse_data_type(data_type.to_string()) == data_type`
+/// for every type it supports. Not yet supported: timestamps with a
+/// timezone other than `None`, and nested types such as `List` or `Struct`
+/// (`Dictionary` is supported).
+///
+/// Example:
+/// ```
+/// # use datafusion_sql::parse_data_type;
+/// # use arrow_schema::DataType;
+/// let display_value = "Int32";
+///
+/// // "Int32" is the Display value of `DataType`
+/// assert_eq!(display_value, &format!("{}", DataType::Int32));
+///
+/// // parse_data_type converts "Int32" back to `DataType`:
+/// let data_type = parse_data_type(display_value).unwrap();
+/// assert_eq!(data_type, DataType::Int32);
+/// ```
+///
+/// # Errors
+///
+/// Returns [`DataFusionError::Plan`] if `val` contains an unknown word, is
+/// malformed, or has trailing content after a complete type.
+///
+/// Remove if added to arrow: <https://github.com/apache/arrow-rs/issues/3821>
+pub fn parse_data_type(val: &str) -> Result<DataType> {
+    Parser::new(val).parse()
+}
+
+fn make_error(val: &str, msg: &str) -> DataFusionError {
+    DataFusionError::Plan(
+        format!("Unsupported type '{val}'. Must be a supported arrow type name 
such as 'Int32' or 'Timestamp(Nanosecond, None)'. Error {msg}" )
+    )
+}
+
+fn make_error_expected(val: &str, expected: &Token, actual: &Token) -> 
DataFusionError {
+    make_error(val, &format!("Expected '{expected}', got '{actual}'"))
+}
+
+#[derive(Debug)]
+/// Implementation of `parse_data_type`, modeled after <https://github.com/sqlparser-rs/sqlparser-rs>
+///
+/// A recursive-descent parser: it pulls tokens from `tokenizer` one at a
+/// time and builds a single `DataType` (recursing for `Dictionary`).
+struct Parser<'a> {
+    /// The complete input string; used to build error messages
+    val: &'a str,
+    /// Source of tokens for `val`
+    tokenizer: Tokenizer<'a>,
+}
+
+impl<'a> Parser<'a> {
+    fn new(val: &'a str) -> Self {
+        Self {
+            val,
+            tokenizer: Tokenizer::new(val),
+        }
+    }
+
+    fn parse(mut self) -> Result<DataType> {
+        let data_type = self.parse_next_type()?;
+        // ensure that there is no trailing content
+        if self.tokenizer.next().is_some() {
+            return Err(make_error(
+                self.val,
+                &format!("checking trailing content after parsing 
'{data_type}'"),
+            ));
+        } else {
+            Ok(data_type)
+        }
+    }
+
+    /// parses the next full DataType
+    fn parse_next_type(&mut self) -> Result<DataType> {
+        match self.next_token()? {
+            Token::SimpleType(data_type) => Ok(data_type),
+            Token::Timestamp => self.parse_timestamp(),
+            Token::Time32 => self.parse_time32(),
+            Token::Time64 => self.parse_time64(),
+            Token::Duration => self.parse_duration(),
+            Token::Interval => self.parse_interval(),
+            Token::FixedSizeBinary => self.parse_fixed_size_binary(),
+            Token::Decimal128 => self.parse_decimal_128(),
+            Token::Decimal256 => self.parse_decimal_256(),
+            Token::Dictionary => self.parse_dictionary(),
+            tok => Err(make_error(
+                self.val,
+                &format!("finding next type, got unexpected '{tok}'"),
+            )),
+        }
+    }
+
+    /// Parses the next timeunit
+    fn parse_time_unit(&mut self, context: &str) -> Result<TimeUnit> {
+        match self.next_token()? {
+            Token::TimeUnit(time_unit) => Ok(time_unit),
+            tok => Err(make_error(
+                self.val,
+                &format!("finding TimeUnit for {context}, got {tok}"),
+            )),
+        }
+    }
+
+    /// Parses the next integer value
+    fn parse_i64(&mut self, context: &str) -> Result<i64> {
+        match self.next_token()? {
+            Token::Integer(v) => Ok(v),
+            tok => Err(make_error(
+                self.val,
+                &format!("finding i64 for {context}, got '{tok}'"),
+            )),
+        }
+    }
+
+    /// Parses the next i32 integer value
+    fn parse_i32(&mut self, context: &str) -> Result<i32> {
+        let length = self.parse_i64(context)?;
+        length.try_into().map_err(|e| {
+            make_error(
+                self.val,
+                &format!("converting {length} into i32 for {context}: {e}"),
+            )
+        })
+    }
+
+    /// Parses the next i8 integer value
+    fn parse_i8(&mut self, context: &str) -> Result<i8> {
+        let length = self.parse_i64(context)?;
+        length.try_into().map_err(|e| {
+            make_error(
+                self.val,
+                &format!("converting {length} into i8 for {context}: {e}"),
+            )
+        })
+    }
+
+    /// Parses the next u8 integer value
+    fn parse_u8(&mut self, context: &str) -> Result<u8> {
+        let length = self.parse_i64(context)?;
+        length.try_into().map_err(|e| {
+            make_error(
+                self.val,
+                &format!("converting {length} into u8 for {context}: {e}"),
+            )
+        })
+    }
+
+    /// Parses the next timestamp (called after `Timestamp` has been consumed)
+    fn parse_timestamp(&mut self) -> Result<DataType> {
+        self.expect_token(Token::LParen)?;
+        let time_unit = self.parse_time_unit("Timestamp")?;
+        self.expect_token(Token::Comma)?;
+        // TODO Support timezones other than None
+        self.expect_token(Token::None)?;
+        let timezone = None;
+
+        self.expect_token(Token::RParen)?;
+        Ok(DataType::Timestamp(time_unit, timezone))
+    }
+
+    /// Parses the next Time32 (called after `Time32` has been consumed)
+    fn parse_time32(&mut self) -> Result<DataType> {
+        self.expect_token(Token::LParen)?;
+        let time_unit = self.parse_time_unit("Time32")?;
+        self.expect_token(Token::RParen)?;
+        Ok(DataType::Time32(time_unit))
+    }
+
+    /// Parses the next Time64 (called after `Time64` has been consumed)
+    fn parse_time64(&mut self) -> Result<DataType> {
+        self.expect_token(Token::LParen)?;
+        let time_unit = self.parse_time_unit("Time64")?;
+        self.expect_token(Token::RParen)?;
+        Ok(DataType::Time64(time_unit))
+    }
+
+    /// Parses the next Duration (called after `Duration` has been consumed)
+    fn parse_duration(&mut self) -> Result<DataType> {
+        self.expect_token(Token::LParen)?;
+        let time_unit = self.parse_time_unit("Duration")?;
+        self.expect_token(Token::RParen)?;
+        Ok(DataType::Duration(time_unit))
+    }
+
+    /// Parses the next Interval (called after `Interval` has been consumed)
+    fn parse_interval(&mut self) -> Result<DataType> {
+        self.expect_token(Token::LParen)?;
+        let interval_unit = match self.next_token()? {
+            Token::IntervalUnit(interval_unit) => interval_unit,
+            tok => {
+                return Err(make_error(
+                    self.val,
+                    &format!("finding IntervalUnit for Interval, got {tok}"),
+                ))
+            }
+        };
+        self.expect_token(Token::RParen)?;
+        Ok(DataType::Interval(interval_unit))
+    }
+
+    /// Parses the next FixedSizeBinary (called after `FixedSizeBinary` has 
been consumed)
+    fn parse_fixed_size_binary(&mut self) -> Result<DataType> {
+        self.expect_token(Token::LParen)?;
+        let length = self.parse_i32("FixedSizeBinary")?;
+        self.expect_token(Token::RParen)?;
+        Ok(DataType::FixedSizeBinary(length))
+    }
+
+    /// Parses the next Decimal128 (called after `Decimal128` has been 
consumed)
+    fn parse_decimal_128(&mut self) -> Result<DataType> {
+        self.expect_token(Token::LParen)?;
+        let precision = self.parse_u8("Decimal128")?;
+        self.expect_token(Token::Comma)?;
+        let scale = self.parse_i8("Decimal128")?;
+        self.expect_token(Token::RParen)?;
+        Ok(DataType::Decimal128(precision, scale))
+    }
+
+    /// Parses the next Decimal256 (called after `Decimal256` has been 
consumed)
+    fn parse_decimal_256(&mut self) -> Result<DataType> {
+        self.expect_token(Token::LParen)?;
+        let precision = self.parse_u8("Decimal256")?;
+        self.expect_token(Token::Comma)?;
+        let scale = self.parse_i8("Decimal256")?;
+        self.expect_token(Token::RParen)?;
+        Ok(DataType::Decimal256(precision, scale))
+    }
+
+    /// Parses the next Dictionary (called after `Dictionary` has been 
consumed)
+    fn parse_dictionary(&mut self) -> Result<DataType> {
+        self.expect_token(Token::LParen)?;
+        let key_type = self.parse_next_type()?;
+        self.expect_token(Token::Comma)?;
+        let value_type = self.parse_next_type()?;
+        self.expect_token(Token::RParen)?;
+        Ok(DataType::Dictionary(
+            Box::new(key_type),
+            Box::new(value_type),
+        ))
+    }
+
+    /// return the next token, or an error if there are none left
+    fn next_token(&mut self) -> Result<Token> {
+        match self.tokenizer.next() {
+            None => Err(make_error(self.val, "finding next token")),
+            Some(token) => token,
+        }
+    }
+
+    /// consume the next token, returning OK(()) if it matches tok, and Err if 
not
+    fn expect_token(&mut self, tok: Token) -> Result<()> {
+        let next_token = self.next_token()?;
+        if next_token == tok {
+            Ok(())
+        } else {
+            Err(make_error_expected(self.val, &tok, &next_token))
+        }
+    }
+}
+
+/// returns true if this character ends a word: `(`, `)`, `,`, or any
+/// whitespace character (space, tab, newline, ...)
+fn is_separator(c: char) -> bool {
+    matches!(c, '(' | ')' | ',') || c.is_whitespace()
+}
+
+/// Splits a string like `Dictionary(Int32, Int64)` into tokens suitable for
+/// parsing
+///
+/// For example the string "Timestamp(Nanosecond, None)" would be parsed into:
+///
+/// * Token::Timestamp
+/// * Token::LParen
+/// * Token::TimeUnit(TimeUnit::Nanosecond)
+/// * Token::Comma,
+/// * Token::None,
+/// * Token::RParen,
+///
+/// Whitespace between tokens is skipped and never produces a token.
+#[derive(Debug)]
+struct Tokenizer<'a> {
+    /// The complete input string; used to build error messages
+    val: &'a str,
+    /// Remaining characters of `val`, with one character of lookahead
+    chars: Peekable<Chars<'a>>,
+    // temporary buffer for parsing words, reused to avoid an allocation
+    // per word
+    word: String,
+}
+
+impl<'a> Tokenizer<'a> {
+    /// Creates a tokenizer over `val`
+    fn new(val: &'a str) -> Self {
+        Self {
+            val,
+            chars: val.chars().peekable(),
+            word: String::new(),
+        }
+    }
+
+    /// returns the next char, without consuming it
+    fn peek_next_char(&mut self) -> Option<char> {
+        self.chars.peek().copied()
+    }
+
+    /// returns the next char, and consumes it
+    fn next_char(&mut self) -> Option<char> {
+        self.chars.next()
+    }
+
+    /// Reads characters from the current position up to (not including) the
+    /// next separator (see [`is_separator`]) or the end of input, and
+    /// converts the resulting word into a [`Token`].
+    ///
+    /// Words starting with `-` or an ASCII digit must parse as an `i64` and
+    /// become [`Token::Integer`]; every other word must be a known keyword.
+    fn parse_word(&mut self) -> Result<Token> {
+        // reset temp space
+        self.word.clear();
+        while let Some(c) = self.peek_next_char() {
+            if is_separator(c) {
+                break;
+            }
+            self.next_char();
+            self.word.push(c);
+        }
+
+        // if it starts like a number, try parsing it as an integer.
+        // `i64::from_str` only accepts ASCII digits, so `is_ascii_digit` is
+        // used rather than `is_numeric` (which also matches non-ASCII
+        // numerals that could never parse)
+        if let Some(c) = self.word.chars().next() {
+            if c == '-' || c.is_ascii_digit() {
+                let val: i64 = self.word.parse().map_err(|e| {
+                    make_error(
+                        self.val,
+                        &format!("parsing {} as integer: {e}", self.word),
+                    )
+                })?;
+                return Ok(Token::Integer(val));
+            }
+        }
+
+        // figure out what the word was
+        let token = match self.word.as_str() {
+            "Null" => Token::SimpleType(DataType::Null),
+            "Boolean" => Token::SimpleType(DataType::Boolean),
+
+            "Int8" => Token::SimpleType(DataType::Int8),
+            "Int16" => Token::SimpleType(DataType::Int16),
+            "Int32" => Token::SimpleType(DataType::Int32),
+            "Int64" => Token::SimpleType(DataType::Int64),
+
+            "UInt8" => Token::SimpleType(DataType::UInt8),
+            "UInt16" => Token::SimpleType(DataType::UInt16),
+            "UInt32" => Token::SimpleType(DataType::UInt32),
+            "UInt64" => Token::SimpleType(DataType::UInt64),
+
+            "Utf8" => Token::SimpleType(DataType::Utf8),
+            "LargeUtf8" => Token::SimpleType(DataType::LargeUtf8),
+            "Binary" => Token::SimpleType(DataType::Binary),
+            "LargeBinary" => Token::SimpleType(DataType::LargeBinary),
+
+            "Float16" => Token::SimpleType(DataType::Float16),
+            "Float32" => Token::SimpleType(DataType::Float32),
+            "Float64" => Token::SimpleType(DataType::Float64),
+
+            "Date32" => Token::SimpleType(DataType::Date32),
+            "Date64" => Token::SimpleType(DataType::Date64),
+
+            "Second" => Token::TimeUnit(TimeUnit::Second),
+            "Millisecond" => Token::TimeUnit(TimeUnit::Millisecond),
+            "Microsecond" => Token::TimeUnit(TimeUnit::Microsecond),
+            "Nanosecond" => Token::TimeUnit(TimeUnit::Nanosecond),
+
+            "Timestamp" => Token::Timestamp,
+            "Time32" => Token::Time32,
+            "Time64" => Token::Time64,
+            "Duration" => Token::Duration,
+            "Interval" => Token::Interval,
+            "Dictionary" => Token::Dictionary,
+
+            "FixedSizeBinary" => Token::FixedSizeBinary,
+            "Decimal128" => Token::Decimal128,
+            "Decimal256" => Token::Decimal256,
+
+            "YearMonth" => Token::IntervalUnit(IntervalUnit::YearMonth),
+            "DayTime" => Token::IntervalUnit(IntervalUnit::DayTime),
+            "MonthDayNano" => Token::IntervalUnit(IntervalUnit::MonthDayNano),
+
+            "None" => Token::None,
+
+            _ => {
+                return Err(make_error(
+                    self.val,
+                    &format!("unrecognized word: {}", self.word),
+                ))
+            }
+        };
+        Ok(token)
+    }
+}
+
+impl<'a> Iterator for Tokenizer<'a> {
+    type Item = Result<Token>;
+
+    /// Returns the next token, `None` at end of input, or `Some(Err(..))`
+    /// if the next word is not recognized.
+    fn next(&mut self) -> Option<Self::Item> {
+        loop {
+            match self.peek_next_char()? {
+                c if c.is_whitespace() => {
+                    // skip any whitespace (spaces, tabs, newlines, ...)
+                    self.next_char();
+                    continue;
+                }
+                '(' => {
+                    self.next_char();
+                    return Some(Ok(Token::LParen));
+                }
+                ')' => {
+                    self.next_char();
+                    return Some(Ok(Token::RParen));
+                }
+                ',' => {
+                    self.next_char();
+                    return Some(Ok(Token::Comma));
+                }
+                _ => return Some(self.parse_word()),
+            }
+        }
+    }
+}
+
+/// A lexical token of the type-string syntax accepted by `parse_data_type`,
+/// as produced by `Tokenizer`.
+///
+/// Grammar (informal), as implemented by `Parser`:
+///
+/// ```text
+/// type := SimpleType
+///       | Timestamp '(' TimeUnit ',' None ')'
+///       | (Time32 | Time64 | Duration) '(' TimeUnit ')'
+///       | Interval '(' IntervalUnit ')'
+///       | FixedSizeBinary '(' Integer ')'
+///       | (Decimal128 | Decimal256) '(' Integer ',' Integer ')'
+///       | Dictionary '(' type ',' type ')'
+/// ```
+#[derive(Debug, PartialEq)]
+enum Token {
+    /// A type with no parameters, such as Null or Int32
+    SimpleType(DataType),
+    Timestamp,
+    Time32,
+    Time64,
+    Duration,
+    Interval,
+    FixedSizeBinary,
+    Decimal128,
+    Decimal256,
+    Dictionary,
+    /// A time unit such as Second or Nanosecond
+    TimeUnit(TimeUnit),
+    /// An interval unit: YearMonth, DayTime or MonthDayNano
+    IntervalUnit(IntervalUnit),
+    LParen,
+    RParen,
+    Comma,
+    /// The word `None` (the only timezone `Timestamp` currently accepts)
+    None,
+    /// A signed integer, e.g. a FixedSizeBinary width or decimal precision/scale
+    Integer(i64),
+}
+
+impl Display for Token {
+    /// Renders keyword and punctuation tokens as their literal text; tokens
+    /// carrying a unit or integer are wrapped in their variant name (e.g.
+    /// `TimeUnit(Second)`, `Integer(5)`) so they read clearly in errors.
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let text = match self {
+            // tokens with a payload are formatted individually
+            Token::SimpleType(t) => return write!(f, "{t}"),
+            Token::TimeUnit(u) => return write!(f, "TimeUnit({u:?})"),
+            Token::IntervalUnit(u) => return write!(f, "IntervalUnit({u:?})"),
+            Token::Integer(v) => return write!(f, "Integer({v})"),
+            // keyword and punctuation tokens map to fixed text
+            Token::Timestamp => "Timestamp",
+            Token::Time32 => "Time32",
+            Token::Time64 => "Time64",
+            Token::Duration => "Duration",
+            Token::Interval => "Interval",
+            Token::FixedSizeBinary => "FixedSizeBinary",
+            Token::Decimal128 => "Decimal128",
+            Token::Decimal256 => "Decimal256",
+            Token::Dictionary => "Dictionary",
+            Token::LParen => "(",
+            Token::RParen => ")",
+            Token::Comma => ",",
+            Token::None => "None",
+        };
+        f.write_str(text)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use arrow_schema::{IntervalUnit, TimeUnit};
+
+    use super::*;
+
+    #[test]
+    fn test_parse_data_type() {
+        // this ensures types can be parsed correctly from their string 
representations
+        for dt in list_datatypes() {
+            round_trip(dt)
+        }
+    }
+
+    /// convert data_type to a string, and then parse it as a type
+    /// verifying it is the same
+    fn round_trip(data_type: DataType) {
+        let data_type_string = data_type.to_string();
+        println!("Input '{data_type_string}' ({data_type:?})");
+        let parsed_type = parse_data_type(&data_type_string).unwrap();
+        assert_eq!(
+            data_type, parsed_type,
+            "Mismatch parsing {data_type_string}"
+        );
+    }
+
+    /// Every `DataType` that is expected to survive a round trip through its
+    /// `Display` text and back through `parse_data_type`, in a fixed order.
+    ///
+    /// NOTE(review): some entries (e.g. `FixedSizeBinary(-432)`,
+    /// `Decimal128(7, 12)`, `Time32(Nanosecond)`) are presumably not valid
+    /// Arrow types; they are listed only to exercise the textual round trip.
+    fn list_datatypes() -> Vec<DataType> {
+        // All four time units, in declaration order.
+        fn every_unit() -> [TimeUnit; 4] {
+            [
+                TimeUnit::Second,
+                TimeUnit::Millisecond,
+                TimeUnit::Microsecond,
+                TimeUnit::Nanosecond,
+            ]
+        }
+
+        // ---------
+        // Non Nested types
+        // ---------
+        let mut types = vec![
+            DataType::Null,
+            DataType::Boolean,
+            DataType::Int8,
+            DataType::Int16,
+            DataType::Int32,
+            DataType::Int64,
+            DataType::UInt8,
+            DataType::UInt16,
+            DataType::UInt32,
+            DataType::UInt64,
+            DataType::Float16,
+            DataType::Float32,
+            DataType::Float64,
+        ];
+
+        // TODO support timezones, e.g. Timestamp(Nanosecond, Some("UTC"))
+        types.extend(every_unit().map(|unit| DataType::Timestamp(unit, None)));
+        types.push(DataType::Date32);
+        types.push(DataType::Date64);
+        types.extend(every_unit().map(DataType::Time32));
+        types.extend(every_unit().map(DataType::Time64));
+        types.extend(every_unit().map(DataType::Duration));
+        types.extend([
+            DataType::Interval(IntervalUnit::YearMonth),
+            DataType::Interval(IntervalUnit::DayTime),
+            DataType::Interval(IntervalUnit::MonthDayNano),
+            DataType::Binary,
+            DataType::FixedSizeBinary(0),
+            DataType::FixedSizeBinary(1234),
+            DataType::FixedSizeBinary(-432),
+            DataType::LargeBinary,
+            DataType::Utf8,
+            DataType::LargeUtf8,
+            DataType::Decimal128(7, 12),
+            DataType::Decimal256(6, 13),
+        ]);
+
+        // ---------
+        // Nested types
+        // ---------
+        let dictionary = |key: DataType, value: DataType| {
+            DataType::Dictionary(Box::new(key), Box::new(value))
+        };
+        types.extend([
+            dictionary(DataType::Int32, DataType::Utf8),
+            dictionary(DataType::Int8, DataType::Utf8),
+            dictionary(
+                DataType::Int8,
+                DataType::Timestamp(TimeUnit::Nanosecond, None),
+            ),
+            dictionary(DataType::Int8, DataType::FixedSizeBinary(23)),
+            // nested dictionaries are probably a bad idea but they are possible
+            dictionary(DataType::Int8, dictionary(DataType::Int8, DataType::Utf8)),
+        ]);
+        // TODO support more structured types (List, LargeList, Struct, Union, Map, RunEndEncoded, etc)
+
+        types
+    }
+
+    /// `parse_data_type` must ignore extra whitespace before, between and
+    /// after the tokens of a type name.
+    #[test]
+    fn test_parse_data_type_whitespace_tolerance() {
+        // (string to parse, expected DataType)
+        let cases = [
+            ("Int8", DataType::Int8),
+            // extra spaces between tokens
+            (
+                "Timestamp        (Nanosecond,      None)",
+                DataType::Timestamp(TimeUnit::Nanosecond, None),
+            ),
+            // ... plus trailing whitespace
+            (
+                "Timestamp        (Nanosecond,      None)  ",
+                DataType::Timestamp(TimeUnit::Nanosecond, None),
+            ),
+            // leading whitespace and whitespace before the closing `)`
+            (
+                "          Timestamp        (Nanosecond,      None             
  )",
+                DataType::Timestamp(TimeUnit::Nanosecond, None),
+            ),
+            // whitespace before the closing `)` plus trailing whitespace
+            (
+                "Timestamp        (Nanosecond,      None               )  ",
+                DataType::Timestamp(TimeUnit::Nanosecond, None),
+            ),
+        ];
+
+        // Every case must parse successfully and yield exactly the expected type.
+        for (data_type_string, expected_data_type) in cases {
+            println!("Parsing '{data_type_string}', expecting 
'{expected_data_type:?}'");
+            let parsed_data_type = parse_data_type(data_type_string).unwrap();
+            assert_eq!(parsed_data_type, expected_data_type);
+        }
+    }
+
+    /// Malformed type strings must be rejected, and every error message must
+    /// contain both the case-specific detail and the common help text.
+    #[test]
+    fn parse_data_type_errors() {
+        // (string to parse, expected error message)
+        let cases = [
+            // The empty string is listed twice on purpose: each case checks a
+            // different substring of the same error message.
+            ("", "Unsupported type ''"),
+            ("", "Error finding next token"),
+            ("null", "Unsupported type 'null'"),
+            ("Nu", "Unsupported type 'Nu'"),
+            // TODO support timezones
+            (
+                r#"Timestamp(Nanosecond, Some("UTC"))"#,
+                "Error unrecognized word: Some",
+            ),
+            ("Timestamp(Nanosecond, ", "Error finding next token"),
+            // anything left over after a complete type is rejected
+            (
+                "Float32 Float32",
+                "trailing content after parsing 'Float32'",
+            ),
+            ("Int32, ", "trailing content after parsing 'Int32'"),
+            ("Int32(3), ", "trailing content after parsing 'Int32'"),
+            // FixedSizeBinary length must be an integer literal
+            ("FixedSizeBinary(Int32), ", "Error finding i64 for 
FixedSizeBinary, got 'Int32'"),
+            ("FixedSizeBinary(3.0), ", "Error parsing 3.0 as integer: invalid 
digit found in string"),
+            // too large for i32
+            ("FixedSizeBinary(4000000000), ", "Error converting 4000000000 
into i32 for FixedSizeBinary: out of range integral type conversion attempted"),
+            // can't have negative precision
+            ("Decimal128(-3, 5)", "Error converting -3 into u8 for Decimal128: 
out of range integral type conversion attempted"),
+            ("Decimal256(-3, 5)", "Error converting -3 into u8 for Decimal256: 
out of range integral type conversion attempted"),
+            // scale must fit in an i8
+            ("Decimal128(3, 500)", "Error converting 500 into i8 for 
Decimal128: out of range integral type conversion attempted"),
+            ("Decimal256(3, 500)", "Error converting 500 into i8 for 
Decimal256: out of range integral type conversion attempted"),
+
+        ];
+
+        for (data_type_string, expected_message) in cases {
+            // `print!` (no newline) so the trailing " Ok" lands on the same line
+            print!("Parsing '{data_type_string}', expecting 
'{expected_message}'");
+            match parse_data_type(data_type_string) {
+                Ok(d) => panic!(
+                    "Expected error while parsing '{data_type_string}', but 
got '{d}'"
+                ),
+                Err(e) => {
+                    let message = e.to_string();
+                    assert!(
+                        message.contains(expected_message),
+                        "\n\ndid not find expected in actual.\n\nexpected: 
{expected_message}\nactual:{message}\n"
+                    );
+                    // errors should also contain a help message
+                    assert!(message.contains("Must be a supported arrow type 
name such as 'Int32' or 'Timestamp(Nanosecond, None)'"));
+                }
+            }
+            println!(" Ok");
+        }
+    }
+}
diff --git a/datafusion/sql/src/expr/function.rs 
b/datafusion/sql/src/expr/function.rs
index c5f23213a..68a5df054 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -29,6 +29,8 @@ use sqlparser::ast::{
 };
 use std::str::FromStr;
 
+use super::arrow_cast::ARROW_CAST_NAME;
+
 impl<'a, S: ContextProvider> SqlToRel<'a, S> {
     pub(super) fn sql_function_to_expr(
         &self,
@@ -110,24 +112,29 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         };
 
         // finally, user-defined functions (UDF) and UDAF
-        match self.schema_provider.get_function_meta(&name) {
-            Some(fm) => {
-                let args = self.function_args_to_expr(function.args, schema)?;
+        if let Some(fm) = self.schema_provider.get_function_meta(&name) {
+            let args = self.function_args_to_expr(function.args, schema)?;
+            return Ok(Expr::ScalarUDF { fun: fm, args });
+        }
 
-                Ok(Expr::ScalarUDF { fun: fm, args })
-            }
-            None => match self.schema_provider.get_aggregate_meta(&name) {
-                Some(fm) => {
-                    let args = self.function_args_to_expr(function.args, 
schema)?;
-                    Ok(Expr::AggregateUDF {
-                        fun: fm,
-                        args,
-                        filter: None,
-                    })
-                }
-                _ => Err(DataFusionError::Plan(format!("Invalid function 
'{name}'"))),
-            },
+        // User defined aggregate functions
+        if let Some(fm) = self.schema_provider.get_aggregate_meta(&name) {
+            let args = self.function_args_to_expr(function.args, schema)?;
+            return Ok(Expr::AggregateUDF {
+                fun: fm,
+                args,
+                filter: None,
+            });
         }
+
+        // Special case arrow_cast (as its type is dependent on its argument 
value)
+        if name == ARROW_CAST_NAME {
+            let args = self.function_args_to_expr(function.args, schema)?;
+            return super::arrow_cast::create_arrow_cast(args, schema);
+        }
+
+        // Could not find the relevant function, so return an error
+        Err(DataFusionError::Plan(format!("Invalid function '{name}'")))
     }
 
     pub(super) fn sql_named_function_to_expr(
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index f22692451..ad05fbcc1 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+pub(crate) mod arrow_cast;
 mod binary_op;
 mod function;
 mod grouping_set;
diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs
index efe239f45..c0c1a4ac9 100644
--- a/datafusion/sql/src/lib.rs
+++ b/datafusion/sql/src/lib.rs
@@ -30,4 +30,5 @@ pub mod utils;
 mod values;
 
 pub use datafusion_common::{ResolvedTableReference, TableReference};
+pub use expr::arrow_cast::parse_data_type;
 pub use sqlparser;
diff --git a/datafusion/sql/tests/integration_test.rs 
b/datafusion/sql/tests/integration_test.rs
index 71f5bf05e..660959907 100644
--- a/datafusion/sql/tests/integration_test.rs
+++ b/datafusion/sql/tests/integration_test.rs
@@ -2311,6 +2311,15 @@ fn approx_median_window() {
     quick_test(sql, expected);
 }
 
+/// `arrow_cast(expr, '<type name>')` is planned as an ordinary `CAST` of
+/// `expr` to the named Arrow type.
+#[test]
+fn select_arrow_cast() {
+    // The second argument is the Arrow type name, given as a string literal.
+    let sql = "SELECT arrow_cast(1234, 'Float64'), arrow_cast('foo', 
'LargeUtf8')";
+    let expected = "\
+    Projection: CAST(Int64(1234) AS Float64), CAST(Utf8(\"foo\") AS LargeUtf8)\
+    \n  EmptyRelation";
+    quick_test(sql, expected);
+}
+
 #[test]
 fn select_typed_date_string() {
     let sql = "SELECT date '2020-12-10' AS date";
@@ -2534,7 +2543,7 @@ impl ContextProvider for MockContextProvider {
     }
 
     fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> {
-        unimplemented!()
+        // The mock registers no scalar UDFs. Returning `None` instead of
+        // panicking lets the planner fall through to its remaining lookups,
+        // such as aggregate UDFs and the `arrow_cast` special case.
+        None
     }
 
     fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
diff --git a/docs/source/user-guide/sql/data_types.md 
b/docs/source/user-guide/sql/data_types.md
index 968dcda53..9f0ca8f89 100644
--- a/docs/source/user-guide/sql/data_types.md
+++ b/docs/source/user-guide/sql/data_types.md
@@ -37,6 +37,18 @@ the `arrow_typeof` function. For example:
 +-------------------------------------+
 ```
 
+You can cast a SQL expression to a specific Arrow type using the `arrow_cast` function.
+For example, to cast the output of `now()` to a `Timestamp` with second precision:
+
+```sql
+❯ select arrow_cast(now(), 'Timestamp(Second, None)');
++---------------------+
+| now()               |
++---------------------+
+| 2023-03-03T17:19:21 |
++---------------------+
+```
+
 ## Character Types
 
 | SQL DataType | Arrow DataType |
@@ -65,12 +77,12 @@ the `arrow_typeof` function. For example:
 
 ## Date/Time Types
 
-| SQL DataType | Arrow DataType                                                
           |
-| ------------ | 
:----------------------------------------------------------------------- |
-| `DATE`       | `Date32`                                                      
           |
-| `TIME`       | `Time64(TimeUnit::Nanosecond)`                                
           |
-| `TIMESTAMP`  | `Timestamp(TimeUnit::Nanosecond, None)`                       
           |
-| `INTERVAL`   | `Interval(IntervalUnit::YearMonth)` or 
`Interval(IntervalUnit::DayTime)` |
+| SQL DataType | Arrow DataType                                  |
+| ------------ | :---------------------------------------------- |
+| `DATE`       | `Date32`                                        |
+| `TIME`       | `Time64(Nanosecond)`                            |
+| `TIMESTAMP`  | `Timestamp(Nanosecond, None)`                   |
+| `INTERVAL`   | `Interval(YearMonth)` or `Interval(DayTime)`    |
 
 ## Boolean Types
 
@@ -84,7 +96,7 @@ the `arrow_typeof` function. For example:
 | ------------ | :------------- |
 | `BYTEA`      | `Binary`       |
 
-## Unsupported Types
+## Unsupported SQL Types
 
 | SQL Data Type | Arrow DataType      |
 | ------------- | :------------------ |
@@ -100,3 +112,43 @@ the `arrow_typeof` function. For example:
 | `ENUM`        | _Not yet supported_ |
 | `SET`         | _Not yet supported_ |
 | `DATETIME`    | _Not yet supported_ |
+
+## Supported Arrow Types
+
+The following types are supported by the `arrow_cast` function:
+
+| Arrow Type                                                  |
+| ----------------------------------------------------------- |
+| `Null`                                                      |
+| `Boolean`                                                   |
+| `Int8`                                                      |
+| `Int16`                                                     |
+| `Int32`                                                     |
+| `Int64`                                                     |
+| `UInt8`                                                     |
+| `UInt16`                                                    |
+| `UInt32`                                                    |
+| `UInt64`                                                    |
+| `Float16`                                                   |
+| `Float32`                                                   |
+| `Float64`                                                   |
+| `Utf8`                                                      |
+| `LargeUtf8`                                                 |
+| `Binary`                                                    |
+| `Timestamp(Second, None)`                                   |
+| `Timestamp(Millisecond, None)`                              |
+| `Timestamp(Microsecond, None)`                              |
+| `Timestamp(Nanosecond, None)`                               |
+| `Time32(<unit>)` e.g. `Time32(Second)`                      |
+| `Time64(<unit>)` e.g. `Time64(Nanosecond)`                  |
+| `Duration(Second)`                                          |
+| `Duration(Millisecond)`                                     |
+| `Duration(Microsecond)`                                     |
+| `Duration(Nanosecond)`                                      |
+| `Interval(YearMonth)`                                       |
+| `Interval(DayTime)`                                         |
+| `Interval(MonthDayNano)`                                    |
+| `FixedSizeBinary(<len>)` (e.g. `FixedSizeBinary(16)`)       |
+| `Decimal128(<precision>, <scale>)` e.g. `Decimal128(10, 3)` |
+| `Decimal256(<precision>, <scale>)` e.g. `Decimal256(10, 3)` |


Reply via email to