tustvold commented on code in PR #5166: URL: https://github.com/apache/arrow-datafusion/pull/5166#discussion_r1129692579
########## datafusion/sql/src/expr/arrow_cast.rs: ########## @@ -0,0 +1,673 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Implementation of the `arrow_cast` function that allows +//! casting to arbitrary arrow types (rather than SQL types) + +use std::{fmt::Display, iter::Peekable, str::Chars}; + +use arrow_schema::{DataType, IntervalUnit, TimeUnit}; +use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; + +use datafusion_expr::{Expr, ExprSchemable}; + +pub const ARROW_CAST_NAME: &str = "arrow_cast"; + +/// Create an [`Expr`] that evaluates the `arrow_cast` function +/// +/// This function is not a [`BuiltInScalarFunction`] because the +/// return type of [`BuiltInScalarFunction`] depends only on the +/// *types* of the arguments. However, the type of `arrow_type` depends on +/// the *value* of its second argument. +/// +/// Use the `cast` function to cast to SQL type (which is then mapped +/// to the corresponding arrow type). For example to cast to `int` +/// (which is then mapped to the arrow type `Int32`) +/// +/// ```sql +/// select cast(column_x as int) ... 
+/// ``` +/// +/// Use the `arrow_cast` function to cast to a specific arrow type +/// +/// For example +/// ```sql +/// select arrow_cast(column_x, 'Float64') +/// ``` +pub fn create_arrow_cast(mut args: Vec<Expr>, schema: &DFSchema) -> Result<Expr> { + if args.len() != 2 { + return Err(DataFusionError::Plan(format!( + "arrow_cast needs 2 arguments, {} provided", + args.len() + ))); + } + let arg1 = args.pop().unwrap(); + let arg0 = args.pop().unwrap(); + + // arg1 must be a string + let data_type_string = if let Expr::Literal(ScalarValue::Utf8(Some(v))) = arg1 { + v + } else { + return Err(DataFusionError::Plan(format!( + "arrow_cast requires its second argument to be a constant string, got {arg1}" + ))); + }; + + // do the actual lookup to the appropriate data type + let data_type = parse_data_type(&data_type_string)?; + + arg0.cast_to(&data_type, schema) +} + +/// Parses `str` into a `DataType`. +/// +/// `parse_data_type` is the reverse of [`DataType`]'s `Display` +/// impl, and maintains the invariant that +/// `parse_data_type(data_type.to_string()) == data_type` +/// +/// Example: +/// ``` +/// # use datafusion_sql::parse_data_type; +/// # use arrow_schema::DataType; +/// let display_value = "Int32"; +/// +/// // "Int32" is the Display value of `DataType` +/// assert_eq!(display_value, &format!("{}", DataType::Int32)); +/// +/// // parse_data_type converts "Int32" back to `DataType`: +/// let data_type = parse_data_type(display_value).unwrap(); +/// assert_eq!(data_type, DataType::Int32); +/// ``` +/// +/// TODO file a ticket about bringing this into arrow possibly Review Comment: I think having FromStr and Display implementations for DataType would be very compelling :+1: ########## datafusion/sql/src/expr/arrow_cast.rs: ########## @@ -0,0 +1,673 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Implementation of the `arrow_cast` function that allows +//! casting to arbitrary arrow types (rather than SQL types) + +use std::{fmt::Display, iter::Peekable, str::Chars}; + +use arrow_schema::{DataType, IntervalUnit, TimeUnit}; +use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; + +use datafusion_expr::{Expr, ExprSchemable}; + +pub const ARROW_CAST_NAME: &str = "arrow_cast"; + +/// Create an [`Expr`] that evaluates the `arrow_cast` function +/// +/// This function is not a [`BuiltInScalarFunction`] because the +/// return type of [`BuiltInScalarFunction`] depends only on the +/// *types* of the arguments. However, the type of `arrow_type` depends on +/// the *value* of its second argument. +/// +/// Use the `cast` function to cast to SQL type (which is then mapped +/// to the corresponding arrow type). For example to cast to `int` +/// (which is then mapped to the arrow type `Int32`) +/// +/// ```sql +/// select cast(column_x as int) ... 
+/// ``` +/// +/// Use the `arrow_cast` function to cast to a specific arrow type +/// +/// For example +/// ```sql +/// select arrow_cast(column_x, 'Float64') +/// ``` +pub fn create_arrow_cast(mut args: Vec<Expr>, schema: &DFSchema) -> Result<Expr> { + if args.len() != 2 { + return Err(DataFusionError::Plan(format!( + "arrow_cast needs 2 arguments, {} provided", + args.len() + ))); + } + let arg1 = args.pop().unwrap(); + let arg0 = args.pop().unwrap(); + + // arg1 must be a string + let data_type_string = if let Expr::Literal(ScalarValue::Utf8(Some(v))) = arg1 { + v + } else { + return Err(DataFusionError::Plan(format!( + "arrow_cast requires its second argument to be a constant string, got {arg1}" + ))); + }; + + // do the actual lookup to the appropriate data type + let data_type = parse_data_type(&data_type_string)?; + + arg0.cast_to(&data_type, schema) +} + +/// Parses `str` into a `DataType`. +/// +/// `parse_data_type` is the reverse of [`DataType`]'s `Display` +/// impl, and maintains the invariant that +/// `parse_data_type(data_type.to_string()) == data_type` +/// +/// Example: +/// ``` +/// # use datafusion_sql::parse_data_type; +/// # use arrow_schema::DataType; +/// let display_value = "Int32"; +/// +/// // "Int32" is the Display value of `DataType` +/// assert_eq!(display_value, &format!("{}", DataType::Int32)); +/// +/// // parse_data_type converts "Int32" back to `DataType`: +/// let data_type = parse_data_type(display_value).unwrap(); +/// assert_eq!(data_type, DataType::Int32); +/// ``` +/// +/// TODO file a ticket about bringing this into arrow possibly +pub fn parse_data_type(val: &str) -> Result<DataType> { + Parser::new(val).parse() +} + +fn make_error(val: &str, msg: &str) -> DataFusionError { + DataFusionError::Plan( + format!("Unsupported type '{val}'. Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'. 
Error {msg}" ) + ) +} + +fn make_error_expected(val: &str, expected: &Token, actual: &Token) -> DataFusionError { + make_error(val, &format!("Expected '{expected}', got '{actual}'")) +} + +#[derive(Debug)] +/// Implementation of `parse_data_type`, modeled after <https://github.com/sqlparser-rs/sqlparser-rs> +struct Parser<'a> { + val: &'a str, + tokenizer: Tokenizer<'a>, +} + +impl<'a> Parser<'a> { + fn new(val: &'a str) -> Self { + Self { + val, + tokenizer: Tokenizer::new(val), + } + } + + fn parse(mut self) -> Result<DataType> { + let data_type = self.parse_next_type()?; + // ensure that there is no trailing content + if self.tokenizer.peek_next_char().is_some() { + return Err(make_error( + self.val, + &format!("checking trailing content after parsing '{data_type}'"), + )); + } else { + Ok(data_type) + } + } + + /// parses the next full DataType + fn parse_next_type(&mut self) -> Result<DataType> { + match self.next_token()? { + Token::SimpleType(data_type) => Ok(data_type), + Token::Timestamp => self.parse_timestamp(), + Token::Time32 => self.parse_time32(), + Token::Time64 => self.parse_time64(), + Token::Duration => self.parse_duration(), + Token::Interval => self.parse_interval(), + Token::FixedSizeBinary => self.parse_fixed_size_binary(), + Token::Decimal128 => self.parse_decimal_128(), + Token::Decimal256 => self.parse_decimal_256(), + Token::Dictionary => self.parse_dictionary(), + tok => Err(make_error( + self.val, + &format!("finding next type, got unexpected '{tok}'"), + )), + } + } + + /// Parses the next timeunit + fn parse_time_unit(&mut self, context: &str) -> Result<TimeUnit> { + match self.next_token()? { + Token::TimeUnit(time_unit) => Ok(time_unit), + tok => Err(make_error( + self.val, + &format!("finding TimeUnit for {context}, got {tok}"), + )), + } + } + + /// Parses the next integer value + fn parse_i64(&mut self, context: &str) -> Result<i64> { + match self.next_token()? 
{ + Token::Integer(v) => Ok(v), + tok => Err(make_error( + self.val, + &format!("finding i64 for {context}, got '{tok}'"), + )), + } + } + + /// Parses the next i32 integer value + fn parse_i32(&mut self, context: &str) -> Result<i32> { + let length = self.parse_i64(context)?; + length.try_into().map_err(|e| { + make_error( + self.val, + &format!("converting {length} into i32 for {context}: {e}"), + ) + }) + } + + /// Parses the next i8 integer value + fn parse_i8(&mut self, context: &str) -> Result<i8> { + let length = self.parse_i64(context)?; + length.try_into().map_err(|e| { + make_error( + self.val, + &format!("converting {length} into i8 for {context}: {e}"), + ) + }) + } + + /// Parses the next u8 integer value + fn parse_u8(&mut self, context: &str) -> Result<u8> { + let length = self.parse_i64(context)?; + length.try_into().map_err(|e| { + make_error( + self.val, + &format!("converting {length} into u8 for {context}: {e}"), + ) + }) + } + + /// Parses the next timestamp (called after `Timestamp` has been consumed) + fn parse_timestamp(&mut self) -> Result<DataType> { + self.expect_token(Token::LParen)?; + let time_unit = self.parse_time_unit("Timestamp")?; + self.expect_token(Token::Comma)?; + // TODO Support timezones other than None + self.expect_token(Token::None)?; + let timezone = None; + + self.expect_token(Token::RParen)?; + Ok(DataType::Timestamp(time_unit, timezone)) + } + + /// Parses the next Time32 (called after `Time32` has been consumed) + fn parse_time32(&mut self) -> Result<DataType> { + self.expect_token(Token::LParen)?; + let time_unit = self.parse_time_unit("Time32")?; + self.expect_token(Token::RParen)?; + Ok(DataType::Time32(time_unit)) + } + + /// Parses the next Time64 (called after `Time64` has been consumed) + fn parse_time64(&mut self) -> Result<DataType> { + self.expect_token(Token::LParen)?; + let time_unit = self.parse_time_unit("Time64")?; + self.expect_token(Token::RParen)?; + Ok(DataType::Time64(time_unit)) + } + + /// 
Parses the next Duration (called after `Duration` has been consumed) + fn parse_duration(&mut self) -> Result<DataType> { + self.expect_token(Token::LParen)?; + let time_unit = self.parse_time_unit("Duration")?; + self.expect_token(Token::RParen)?; + Ok(DataType::Duration(time_unit)) + } + + /// Parses the next Interval (called after `Interval` has been consumed) + fn parse_interval(&mut self) -> Result<DataType> { + self.expect_token(Token::LParen)?; + let interval_unit = match self.next_token()? { + Token::IntervalUnit(interval_unit) => interval_unit, + tok => { + return Err(make_error( + self.val, + &format!("finding IntervalUnit for Interval, got {tok}"), + )) + } + }; + self.expect_token(Token::RParen)?; + Ok(DataType::Interval(interval_unit)) + } + + /// Parses the next FixedSizeBinary (called after `FixedSizeBinary` has been consumed) + fn parse_fixed_size_binary(&mut self) -> Result<DataType> { + self.expect_token(Token::LParen)?; + let length = self.parse_i32("FixedSizeBinary")?; + self.expect_token(Token::RParen)?; + Ok(DataType::FixedSizeBinary(length)) + } + + /// Parses the next Decimal128 (called after `Decimal128` has been consumed) + fn parse_decimal_128(&mut self) -> Result<DataType> { + self.expect_token(Token::LParen)?; + let precision = self.parse_u8("Decimal128")?; + self.expect_token(Token::Comma)?; + let scale = self.parse_i8("Decimal128")?; + self.expect_token(Token::RParen)?; + Ok(DataType::Decimal128(precision, scale)) + } + + /// Parses the next Decimal256 (called after `Decimal256` has been consumed) + fn parse_decimal_256(&mut self) -> Result<DataType> { + self.expect_token(Token::LParen)?; + let precision = self.parse_u8("Decimal256")?; + self.expect_token(Token::Comma)?; + let scale = self.parse_i8("Decimal256")?; + self.expect_token(Token::RParen)?; + Ok(DataType::Decimal256(precision, scale)) + } + + /// Parses the next Dictionary (called after `Dictionary` has been consumed) + fn parse_dictionary(&mut self) -> Result<DataType> { + 
self.expect_token(Token::LParen)?; + let key_type = self.parse_next_type()?; + self.expect_token(Token::Comma)?; + let value_type = self.parse_next_type()?; + self.expect_token(Token::RParen)?; + Ok(DataType::Dictionary( + Box::new(key_type), + Box::new(value_type), + )) + } + + /// return the next token, or an error if there are none left + fn next_token(&mut self) -> Result<Token> { + match self.tokenizer.next() { + None => Err(make_error(self.val, "finding next token")), + Some(token) => token, + } + } + + /// consume the next token, returning Ok(()) if it matches tok, and Err if not + fn expect_token(&mut self, tok: Token) -> Result<()> { + let next_token = self.next_token()?; + if next_token == tok { + Ok(()) + } else { + Err(make_error_expected(self.val, &tok, &next_token)) + } + } +} + +/// returns true if this character is a separator +fn is_separator(c: char) -> bool { + c == '(' || c == ')' || c == ',' || c == ' ' +} + +#[derive(Debug)] +/// Splits a string like Dictionary(Int32, Int64) into tokens suitable for parsing +/// +/// For example the string "Timestamp(Nanosecond, None)" would be parsed into: +/// +/// * Token::Timestamp +/// * Token::LParen +/// * Token::TimeUnit(TimeUnit::Nanosecond) +/// * Token::Comma, +/// * Token::None, +/// * Token::RParen, +struct Tokenizer<'a> { + val: &'a str, + chars: Peekable<Chars<'a>>, Review Comment: Unless I am mistaken, all tokens are ASCII, and so I think this could use bytes directly without needing to worry about UTF-8 shenanigans... -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org