This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 02f7e1f713 feat: Introduce convert Expr to SQL string API and basic
feature (#9517)
02f7e1f713 is described below
commit 02f7e1f7132d6bf938724accc8aabccffb92f476
Author: Michiel De Backker <[email protected]>
AuthorDate: Mon Mar 11 22:03:34 2024 +0100
feat: Introduce convert Expr to SQL string API and basic feature (#9517)
* feat: convert Expr to SQL string
* fix: add license headers
* fix: make Unparser and Dialect public
* Update datafusion/sql/src/unparser/dialect.rs
* fmt
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/sql/src/lib.rs | 1 +
datafusion/sql/src/unparser/dialect.rs | 73 +++++
datafusion/sql/src/unparser/expr.rs | 355 +++++++++++++++++++++++++
datafusion/sql/src/{lib.rs => unparser/mod.rs} | 47 ++--
4 files changed, 451 insertions(+), 25 deletions(-)
diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs
index d805f61397..da66ee197a 100644
--- a/datafusion/sql/src/lib.rs
+++ b/datafusion/sql/src/lib.rs
@@ -36,6 +36,7 @@ mod relation;
mod select;
mod set_expr;
mod statement;
+pub mod unparser;
pub mod utils;
mod values;
diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs
new file mode 100644
index 0000000000..3af33ad0af
--- /dev/null
+++ b/datafusion/sql/src/unparser/dialect.rs
@@ -0,0 +1,73 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Dialect is used to capture dialect specific syntax.
+/// Note: this trait will eventually be replaced by the Dialect in the SQLparser package
+///
+/// See <https://github.com/sqlparser-rs/sqlparser-rs/pull/1170>
+pub trait Dialect {
+ /// Quote character placed around identifiers produced for this dialect,
+ /// or `None` when identifiers are emitted without quoting.
+ fn identifier_quote_style(&self) -> Option<char>;
+}
+/// Dialect that leaves identifiers unquoted.
+#[derive(Debug, Clone, Copy, Default)]
+pub struct DefaultDialect {}
+
+impl Dialect for DefaultDialect {
+    /// Always `None`: identifiers are never quoted.
+    fn identifier_quote_style(&self) -> Option<char> {
+        None
+    }
+}
+
+/// Dialect that quotes identifiers with double quotes (`"`).
+#[derive(Debug, Clone, Copy, Default)]
+pub struct PostgreSqlDialect {}
+
+impl Dialect for PostgreSqlDialect {
+    /// Always `Some('"')`.
+    fn identifier_quote_style(&self) -> Option<char> {
+        Some('"')
+    }
+}
+
+/// Dialect that quotes identifiers with backticks.
+#[derive(Debug, Clone, Copy, Default)]
+pub struct MySqlDialect {}
+
+impl Dialect for MySqlDialect {
+    /// Always a backtick.
+    fn identifier_quote_style(&self) -> Option<char> {
+        Some('`')
+    }
+}
+
+/// Dialect that quotes identifiers with backticks.
+#[derive(Debug, Clone, Copy, Default)]
+pub struct SqliteDialect {}
+
+impl Dialect for SqliteDialect {
+    /// Always a backtick.
+    fn identifier_quote_style(&self) -> Option<char> {
+        Some('`')
+    }
+}
+
+/// Dialect whose identifier quote character is chosen by the caller.
+#[derive(Debug, Clone, Copy)]
+pub struct CustomDialect {
+    /// Quote character returned by `identifier_quote_style`; `None` disables quoting.
+    identifier_quote_style: Option<char>,
+}
+
+impl CustomDialect {
+    /// Creates a dialect that reports `identifier_quote_style` as its
+    /// identifier quote character.
+    pub fn new(identifier_quote_style: Option<char>) -> Self {
+        Self {
+            identifier_quote_style,
+        }
+    }
+}
+
+impl Dialect for CustomDialect {
+    /// Returns the value supplied to [`CustomDialect::new`].
+    fn identifier_quote_style(&self) -> Option<char> {
+        self.identifier_quote_style
+    }
+}
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
new file mode 100644
index 0000000000..bb14c8a707
--- /dev/null
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -0,0 +1,355 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_common::{not_impl_err, Column, Result, ScalarValue};
+use datafusion_expr::{
+ expr::{Alias, InList, ScalarFunction, WindowFunction},
+ Between, BinaryExpr, Case, Cast, Expr, Like, Operator,
+};
+use sqlparser::ast;
+
+use super::Unparser;
+
+/// Translate a DataFusion [`Expr`] into a `sqlparser::ast::Expr`.
+///
+/// This is the reverse direction of `SqlToRel::sql_to_expr`. The returned
+/// AST node can be formatted to obtain SQL text, among other uses. The
+/// conversion is performed by a default-constructed [`Unparser`].
+///
+/// # Example
+/// ```
+/// use datafusion_expr::{col, lit};
+/// use datafusion_sql::unparser::expr_to_sql;
+/// let expr = col("a").gt(lit(4));
+/// let sql = expr_to_sql(&expr).unwrap();
+///
+/// assert_eq!(format!("{}", sql), "a > 4")
+/// ```
+pub fn expr_to_sql(expr: &Expr) -> Result<ast::Expr> {
+    Unparser::default().expr_to_sql(expr)
+}
+
+impl Unparser<'_> {
+    /// Converts a DataFusion [`Expr`] into a `sqlparser` AST expression.
+    ///
+    /// Handled today:
+    /// * columns (via `col_to_sql`),
+    /// * literals (via `scalar_to_sql`),
+    /// * binary expressions (operands converted recursively, operator via
+    ///   `op_to_sql`),
+    /// * aliases, which are converted as their inner expression; the alias
+    ///   name itself is not emitted.
+    ///
+    /// # Errors
+    /// Returns a "not implemented" error for every other expression kind, and
+    /// propagates any error raised while converting a sub-expression,
+    /// operator or literal.
+    pub fn expr_to_sql(&self, expr: &Expr) -> Result<ast::Expr> {
+        match expr {
+            Expr::Column(col) => self.col_to_sql(col),
+            Expr::Literal(value) => Ok(ast::Expr::Value(self.scalar_to_sql(value)?)),
+            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
+                let l = self.expr_to_sql(left.as_ref())?;
+                let r = self.expr_to_sql(right.as_ref())?;
+                let op = self.op_to_sql(op)?;
+
+                Ok(self.binary_op_to_sql(l, r, op))
+            }
+            Expr::Alias(Alias { expr: inner, .. }) => self.expr_to_sql(inner),
+            // Known expression kinds that are not supported yet. The patterns
+            // bind no fields on purpose: binding a field called `expr` would
+            // shadow the parameter and the message would then show only a
+            // sub-expression instead of the expression that was rejected.
+            Expr::InList(InList { .. })
+            | Expr::ScalarFunction(ScalarFunction { .. })
+            | Expr::Between(Between { .. })
+            | Expr::Case(Case { .. })
+            | Expr::Cast(Cast { .. })
+            | Expr::WindowFunction(WindowFunction { .. })
+            | Expr::Like(Like { .. }) => {
+                not_impl_err!("Unsupported expression: {expr:?}")
+            }
+            _ => not_impl_err!("Unsupported expression: {expr:?}"),
+        }
+    }
+
+    /// Builds the identifier expression for a column reference.
+    ///
+    /// A column without a relation becomes a plain identifier. A qualified
+    /// column becomes a compound identifier made of every part of the table
+    /// reference followed by the column name. Each part goes through
+    /// `new_ident`, so the dialect's quote style is applied to all of them.
+    fn col_to_sql(&self, col: &Column) -> Result<ast::Expr> {
+        match &col.relation {
+            None => Ok(ast::Expr::Identifier(self.new_ident(col.name.to_string()))),
+            Some(table_ref) => {
+                let idents = table_ref
+                    .to_vec()
+                    .into_iter()
+                    .chain(std::iter::once(col.name.to_string()))
+                    .map(|part| self.new_ident(part))
+                    .collect();
+                Ok(ast::Expr::CompoundIdentifier(idents))
+            }
+        }
+    }
+
+    /// Wraps `str` in an `ast::Ident` carrying the quote style reported by
+    /// the configured dialect.
+    fn new_ident(&self, str: String) -> ast::Ident {
+        let quote_style = self.dialect.identifier_quote_style();
+        ast::Ident {
+            value: str,
+            quote_style,
+        }
+    }
+
+    /// Assembles `lhs op rhs` as an `ast::Expr::BinaryOp`.
+    ///
+    /// An operand that is itself a binary operation is wrapped in
+    /// `ast::Expr::Nested` so that the grouping of the expression tree
+    /// survives formatting: sqlparser prints a `BinaryOp` as `left op right`
+    /// without adding parentheses, so `(a + b) * c` would otherwise be
+    /// rendered as `a + b * c`. Other operands are used unchanged, which
+    /// keeps simple comparisons such as `a > 4` free of parentheses.
+    fn binary_op_to_sql(
+        &self,
+        lhs: ast::Expr,
+        rhs: ast::Expr,
+        op: ast::BinaryOperator,
+    ) -> ast::Expr {
+        // Parenthesise binary operands; box everything else as-is.
+        fn group(operand: ast::Expr) -> Box<ast::Expr> {
+            match operand {
+                binary @ ast::Expr::BinaryOp { .. } => {
+                    Box::new(ast::Expr::Nested(Box::new(binary)))
+                }
+                other => Box::new(other),
+            }
+        }
+
+        ast::Expr::BinaryOp {
+            left: group(lhs),
+            op,
+            right: group(rhs),
+        }
+    }
+
+    /// Maps a DataFusion binary [`Operator`] to the matching
+    /// `ast::BinaryOperator`.
+    ///
+    /// # Errors
+    /// Returns a "not implemented" error for `IsDistinctFrom`,
+    /// `IsNotDistinctFrom`, `AtArrow` and `ArrowAt`.
+    fn op_to_sql(&self, op: &Operator) -> Result<ast::BinaryOperator> {
+        let sql_op = match op {
+            // Comparison
+            Operator::Eq => ast::BinaryOperator::Eq,
+            Operator::NotEq => ast::BinaryOperator::NotEq,
+            Operator::Lt => ast::BinaryOperator::Lt,
+            Operator::LtEq => ast::BinaryOperator::LtEq,
+            Operator::Gt => ast::BinaryOperator::Gt,
+            Operator::GtEq => ast::BinaryOperator::GtEq,
+            // Arithmetic
+            Operator::Plus => ast::BinaryOperator::Plus,
+            Operator::Minus => ast::BinaryOperator::Minus,
+            Operator::Multiply => ast::BinaryOperator::Multiply,
+            Operator::Divide => ast::BinaryOperator::Divide,
+            Operator::Modulo => ast::BinaryOperator::Modulo,
+            // Logical
+            Operator::And => ast::BinaryOperator::And,
+            Operator::Or => ast::BinaryOperator::Or,
+            // Regular-expression and LIKE-style matching
+            Operator::RegexMatch => ast::BinaryOperator::PGRegexMatch,
+            Operator::RegexIMatch => ast::BinaryOperator::PGRegexIMatch,
+            Operator::RegexNotMatch => ast::BinaryOperator::PGRegexNotMatch,
+            Operator::RegexNotIMatch => ast::BinaryOperator::PGRegexNotIMatch,
+            Operator::ILikeMatch => ast::BinaryOperator::PGILikeMatch,
+            Operator::NotLikeMatch => ast::BinaryOperator::PGNotLikeMatch,
+            Operator::LikeMatch => ast::BinaryOperator::PGLikeMatch,
+            Operator::NotILikeMatch => ast::BinaryOperator::PGNotILikeMatch,
+            // Bitwise
+            Operator::BitwiseAnd => ast::BinaryOperator::BitwiseAnd,
+            Operator::BitwiseOr => ast::BinaryOperator::BitwiseOr,
+            Operator::BitwiseXor => ast::BinaryOperator::BitwiseXor,
+            Operator::BitwiseShiftRight => ast::BinaryOperator::PGBitwiseShiftRight,
+            Operator::BitwiseShiftLeft => ast::BinaryOperator::PGBitwiseShiftLeft,
+            // String
+            Operator::StringConcat => ast::BinaryOperator::StringConcat,
+            // Not supported yet
+            Operator::IsDistinctFrom
+            | Operator::IsNotDistinctFrom
+            | Operator::AtArrow
+            | Operator::ArrowAt => {
+                return not_impl_err!("unsupported operation: {op:?}");
+            }
+        };
+
+        Ok(sql_op)
+    }
+
+    /// Converts a literal [`ScalarValue`] into an `ast::Value`.
+    ///
+    /// * Booleans become `ast::Value::Boolean`.
+    /// * Integer and floating-point values become `ast::Value::Number`,
+    ///   using the value's `to_string` output.
+    /// * `Utf8` / `LargeUtf8` values become single-quoted strings.
+    /// * `ScalarValue::Null` and the `None` form of every type listed in the
+    ///   null arm below become `ast::Value::Null`.
+    ///
+    /// # Errors
+    /// Returns a "not implemented" error for non-null decimal, binary, date,
+    /// time, timestamp, interval and duration values, and for fixed-size
+    /// binary, list, struct and dictionary values.
+    fn scalar_to_sql(&self, v: &ScalarValue) -> Result<ast::Value> {
+        // All numeric literals are built the same way from their text form.
+        let number = |digits: String| ast::Value::Number(digits, false);
+
+        match v {
+            // ---- values with a SQL literal form ----
+            ScalarValue::Boolean(Some(b)) => Ok(ast::Value::Boolean(*b)),
+            ScalarValue::Float32(Some(f)) => Ok(number(f.to_string())),
+            ScalarValue::Float64(Some(f)) => Ok(number(f.to_string())),
+            ScalarValue::Int8(Some(i)) => Ok(number(i.to_string())),
+            ScalarValue::Int16(Some(i)) => Ok(number(i.to_string())),
+            ScalarValue::Int32(Some(i)) => Ok(number(i.to_string())),
+            ScalarValue::Int64(Some(i)) => Ok(number(i.to_string())),
+            ScalarValue::UInt8(Some(ui)) => Ok(number(ui.to_string())),
+            ScalarValue::UInt16(Some(ui)) => Ok(number(ui.to_string())),
+            ScalarValue::UInt32(Some(ui)) => Ok(number(ui.to_string())),
+            ScalarValue::UInt64(Some(ui)) => Ok(number(ui.to_string())),
+            ScalarValue::Utf8(Some(str)) | ScalarValue::LargeUtf8(Some(str)) => {
+                Ok(ast::Value::SingleQuotedString(str.to_string()))
+            }
+
+            // ---- nulls of any type ----
+            ScalarValue::Null
+            | ScalarValue::Boolean(None)
+            | ScalarValue::Float32(None)
+            | ScalarValue::Float64(None)
+            | ScalarValue::Decimal128(None, ..)
+            | ScalarValue::Decimal256(None, ..)
+            | ScalarValue::Int8(None)
+            | ScalarValue::Int16(None)
+            | ScalarValue::Int32(None)
+            | ScalarValue::Int64(None)
+            | ScalarValue::UInt8(None)
+            | ScalarValue::UInt16(None)
+            | ScalarValue::UInt32(None)
+            | ScalarValue::UInt64(None)
+            | ScalarValue::Utf8(None)
+            | ScalarValue::LargeUtf8(None)
+            | ScalarValue::Binary(None)
+            | ScalarValue::LargeBinary(None)
+            | ScalarValue::Date32(None)
+            | ScalarValue::Date64(None)
+            | ScalarValue::Time32Second(None)
+            | ScalarValue::Time32Millisecond(None)
+            | ScalarValue::Time64Microsecond(None)
+            | ScalarValue::Time64Nanosecond(None)
+            | ScalarValue::TimestampSecond(None, _)
+            | ScalarValue::TimestampMillisecond(None, _)
+            | ScalarValue::TimestampMicrosecond(None, _)
+            | ScalarValue::TimestampNanosecond(None, _)
+            | ScalarValue::IntervalYearMonth(None)
+            | ScalarValue::IntervalDayTime(None)
+            | ScalarValue::IntervalMonthDayNano(None)
+            | ScalarValue::DurationSecond(None)
+            | ScalarValue::DurationMillisecond(None)
+            | ScalarValue::DurationMicrosecond(None)
+            | ScalarValue::DurationNanosecond(None) => Ok(ast::Value::Null),
+
+            // ---- not supported yet ----
+            ScalarValue::Decimal128(Some(_), ..)
+            | ScalarValue::Decimal256(Some(_), ..)
+            | ScalarValue::Binary(Some(_))
+            | ScalarValue::FixedSizeBinary(..)
+            | ScalarValue::LargeBinary(Some(_))
+            | ScalarValue::FixedSizeList(_)
+            | ScalarValue::List(_)
+            | ScalarValue::LargeList(_)
+            | ScalarValue::Date32(Some(_))
+            | ScalarValue::Date64(Some(_))
+            | ScalarValue::Time32Second(Some(_))
+            | ScalarValue::Time32Millisecond(Some(_))
+            | ScalarValue::Time64Microsecond(Some(_))
+            | ScalarValue::Time64Nanosecond(Some(_))
+            | ScalarValue::TimestampSecond(Some(_), _)
+            | ScalarValue::TimestampMillisecond(Some(_), _)
+            | ScalarValue::TimestampMicrosecond(Some(_), _)
+            | ScalarValue::TimestampNanosecond(Some(_), _)
+            | ScalarValue::IntervalYearMonth(Some(_))
+            | ScalarValue::IntervalDayTime(Some(_))
+            | ScalarValue::IntervalMonthDayNano(Some(_))
+            | ScalarValue::DurationSecond(Some(_))
+            | ScalarValue::DurationMillisecond(Some(_))
+            | ScalarValue::DurationMicrosecond(Some(_))
+            | ScalarValue::DurationNanosecond(Some(_))
+            | ScalarValue::Struct(_)
+            | ScalarValue::Dictionary(..) => {
+                not_impl_err!("Unsupported scalar: {v:?}")
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use datafusion_common::TableReference;
+    use datafusion_expr::{col, lit};
+
+    use crate::unparser::dialect::CustomDialect;
+
+    use super::*;
+
+    /// The default unparser renders unqualified and qualified column
+    /// comparisons without quoting.
+    #[test]
+    fn expr_to_sql_ok() -> Result<()> {
+        let qualified_column = Expr::Column(Column {
+            relation: Some(TableReference::partial("a", "b")),
+            name: "c".to_string(),
+        });
+
+        let cases: Vec<(Expr, &str)> = vec![
+            (col("a").gt(lit(4)), r#"a > 4"#),
+            (qualified_column.gt(lit(4)), r#"a.b.c > 4"#),
+        ];
+
+        for (expr, expected) in cases {
+            let rendered = expr_to_sql(&expr)?.to_string();
+            assert_eq!(rendered, expected);
+        }
+
+        Ok(())
+    }
+
+    /// A custom dialect's quote character is applied to identifiers.
+    #[test]
+    fn custom_dialect() -> Result<()> {
+        let dialect = CustomDialect::new(Some('\''));
+        let unparser = Unparser::new(&dialect);
+
+        let rendered = unparser.expr_to_sql(&col("a").gt(lit(4)))?.to_string();
+        assert_eq!(rendered, r#"'a' > 4"#);
+
+        Ok(())
+    }
+}
diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/unparser/mod.rs
similarity index 54%
copy from datafusion/sql/src/lib.rs
copy to datafusion/sql/src/unparser/mod.rs
index d805f61397..77a9de0975 100644
--- a/datafusion/sql/src/lib.rs
+++ b/datafusion/sql/src/unparser/mod.rs
@@ -15,30 +15,27 @@
// specific language governing permissions and limitations
// under the License.
-//! This module provides:
-//!
-//! 1. A SQL parser, [`DFParser`], that translates SQL query text into
-//! an abstract syntax tree (AST), [`Statement`].
-//!
-//! 2. A SQL query planner [`SqlToRel`] that creates [`LogicalPlan`]s
-//! from [`Statement`]s.
-//!
-//! [`DFParser`]: parser::DFParser
-//! [`Statement`]: parser::Statement
-//! [`SqlToRel`]: planner::SqlToRel
-//! [`LogicalPlan`]: datafusion_expr::logical_plan::LogicalPlan
-
mod expr;
-pub mod parser;
-pub mod planner;
-mod query;
-mod relation;
-mod select;
-mod set_expr;
-mod statement;
-pub mod utils;
-mod values;
-pub use datafusion_common::{ResolvedTableReference, TableReference};
-pub use expr::arrow_cast::parse_data_type;
-pub use sqlparser;
+pub use expr::expr_to_sql;
+
+use self::dialect::{DefaultDialect, Dialect};
+pub mod dialect;
+
+/// Converts DataFusion expressions to `sqlparser` AST nodes, using a
+/// [`Dialect`] to decide dialect-specific details such as identifier quoting.
+pub struct Unparser<'a> {
+    /// Dialect consulted while building the AST.
+    dialect: &'a dyn Dialect,
+}
+
+impl<'a> Unparser<'a> {
+    /// Creates an unparser that borrows `dialect` for its whole lifetime.
+    pub fn new(dialect: &'a dyn Dialect) -> Self {
+        Unparser { dialect }
+    }
+}
+
+impl<'a> Default for Unparser<'a> {
+    /// Creates an unparser backed by [`DefaultDialect`].
+    fn default() -> Self {
+        Self::new(&DefaultDialect {})
+    }
+}