This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 02f7e1f713 feat: Introduce convert Expr to SQL string API and basic feature (#9517)
02f7e1f713 is described below

commit 02f7e1f7132d6bf938724accc8aabccffb92f476
Author: Michiel De Backker <[email protected]>
AuthorDate: Mon Mar 11 22:03:34 2024 +0100

    feat: Introduce convert Expr to SQL string API and basic feature (#9517)
    
    * feat: convert Expr to SQL string
    
    * fix: add license headers
    
    * fix: make Unparser and Dialect public
    
    * Update datafusion/sql/src/unparser/dialect.rs
    
    * fmt
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/sql/src/lib.rs                      |   1 +
 datafusion/sql/src/unparser/dialect.rs         |  73 +++++
 datafusion/sql/src/unparser/expr.rs            | 355 +++++++++++++++++++++++++
 datafusion/sql/src/{lib.rs => unparser/mod.rs} |  47 ++--
 4 files changed, 451 insertions(+), 25 deletions(-)

diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs
index d805f61397..da66ee197a 100644
--- a/datafusion/sql/src/lib.rs
+++ b/datafusion/sql/src/lib.rs
@@ -36,6 +36,7 @@ mod relation;
 mod select;
 mod set_expr;
 mod statement;
+pub mod unparser;
 pub mod utils;
 mod values;
 
diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs
new file mode 100644
index 0000000000..3af33ad0af
--- /dev/null
+++ b/datafusion/sql/src/unparser/dialect.rs
@@ -0,0 +1,73 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Dialect is used to capture dialect specific syntax.
+/// Note: this trait will eventually be replaced by the Dialect in the SQLparser package
+///
+/// See <https://github.com/sqlparser-rs/sqlparser-rs/pull/1170>
+pub trait Dialect {
+    /// Returns the character used to quote identifiers in generated SQL,
+    /// or `None` when identifiers are emitted without a quote style.
+    fn identifier_quote_style(&self) -> Option<char>;
+}
+/// Dialect that applies no identifier quote style.
+pub struct DefaultDialect {}
+
+impl Dialect for DefaultDialect {
+    /// Always `None`: identifiers get no quote character.
+    fn identifier_quote_style(&self) -> Option<char> {
+        None
+    }
+}
+
+/// Dialect that quotes identifiers with double quotes (`"`).
+pub struct PostgreSqlDialect {}
+
+impl Dialect for PostgreSqlDialect {
+    /// Always `Some('"')`.
+    fn identifier_quote_style(&self) -> Option<char> {
+        Some('"')
+    }
+}
+
+/// Dialect that quotes identifiers with backticks.
+pub struct MySqlDialect {}
+
+impl Dialect for MySqlDialect {
+    /// Always the backtick character.
+    fn identifier_quote_style(&self) -> Option<char> {
+        Some('`')
+    }
+}
+
+/// Dialect that quotes identifiers with backticks (same quote character as
+/// [`MySqlDialect`]).
+pub struct SqliteDialect {}
+
+impl Dialect for SqliteDialect {
+    /// Always the backtick character.
+    fn identifier_quote_style(&self) -> Option<char> {
+        Some('`')
+    }
+}
+
+/// Dialect whose identifier quote style is chosen by the caller at
+/// construction time.
+pub struct CustomDialect {
+    /// Quote character handed back by `identifier_quote_style`; `None` means
+    /// no quote style.
+    identifier_quote_style: Option<char>,
+}
+
+impl CustomDialect {
+    /// Creates a dialect that reports `identifier_quote_style` as its quote
+    /// character.
+    pub fn new(identifier_quote_style: Option<char>) -> Self {
+        Self {
+            identifier_quote_style,
+        }
+    }
+}
+
+impl Dialect for CustomDialect {
+    /// Returns the value supplied to [`CustomDialect::new`].
+    fn identifier_quote_style(&self) -> Option<char> {
+        self.identifier_quote_style
+    }
+}
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
new file mode 100644
index 0000000000..bb14c8a707
--- /dev/null
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -0,0 +1,355 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_common::{not_impl_err, Column, Result, ScalarValue};
+use datafusion_expr::{
+    expr::{Alias, InList, ScalarFunction, WindowFunction},
+    Between, BinaryExpr, Case, Cast, Expr, Like, Operator,
+};
+use sqlparser::ast;
+
+use super::Unparser;
+
+/// Unparse a DataFusion [`Expr`] into a `sqlparser::ast::Expr` using
+/// `Unparser::default()`.
+///
+/// This is the inverse direction of `SqlToRel::sql_to_expr`; formatting the
+/// returned AST node yields a SQL string for the expression.
+///
+/// # Errors
+/// Propagates the error returned by [`Unparser::expr_to_sql`] when the
+/// expression cannot be converted.
+///
+/// # Example
+/// ```
+/// use datafusion_expr::{col, lit};
+/// use datafusion_sql::unparser::expr_to_sql;
+/// let expr = col("a").gt(lit(4));
+/// let sql = expr_to_sql(&expr).unwrap();
+///
+/// assert_eq!(format!("{}", sql), "a > 4")
+/// ```
+pub fn expr_to_sql(expr: &Expr) -> Result<ast::Expr> {
+    Unparser::default().expr_to_sql(expr)
+}
+
+impl Unparser<'_> {
+    /// Convert a DataFusion [`Expr`] into a `sqlparser` [`ast::Expr`].
+    ///
+    /// Supported today: columns, binary expressions, literals, and aliases
+    /// (an alias is unparsed as its underlying expression; the alias name is
+    /// not emitted).
+    ///
+    /// # Errors
+    /// Returns a "not implemented" error naming the offending expression for
+    /// every other variant, and propagates errors from converting operands,
+    /// operators, or literals.
+    pub fn expr_to_sql(&self, expr: &Expr) -> Result<ast::Expr> {
+        match expr {
+            Expr::Column(col) => self.col_to_sql(col),
+            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
+                let l = self.expr_to_sql(left.as_ref())?;
+                let r = self.expr_to_sql(right.as_ref())?;
+                let op = self.op_to_sql(op)?;
+
+                Ok(self.binary_op_to_sql(l, r, op))
+            }
+            Expr::Literal(value) => Ok(ast::Expr::Value(self.scalar_to_sql(value)?)),
+            Expr::Alias(Alias { expr, .. }) => self.expr_to_sql(expr),
+            // Known-but-unimplemented variants. Their fields are deliberately
+            // not bound: binding an inner `expr` would shadow the parameter
+            // and make the error report only a child expression instead of
+            // the expression that is actually unsupported.
+            Expr::InList(InList { .. })
+            | Expr::ScalarFunction(ScalarFunction { .. })
+            | Expr::Between(Between { .. })
+            | Expr::Case(Case { .. })
+            | Expr::Cast(Cast { .. })
+            | Expr::WindowFunction(WindowFunction { .. })
+            | Expr::Like(Like { .. }) => {
+                not_impl_err!("Unsupported expression: {expr:?}")
+            }
+            _ => not_impl_err!("Unsupported expression: {expr:?}"),
+        }
+    }
+
+    /// Convert a [`Column`] into an identifier expression.
+    ///
+    /// A column without a relation becomes a single identifier; a column with
+    /// a relation becomes a compound identifier made of the relation's parts
+    /// followed by the column name. Every part goes through `new_ident`.
+    fn col_to_sql(&self, col: &Column) -> Result<ast::Expr> {
+        match &col.relation {
+            None => Ok(ast::Expr::Identifier(self.new_ident(col.name.clone()))),
+            Some(table_ref) => {
+                let mut parts = table_ref.to_vec();
+                parts.push(col.name.clone());
+
+                let idents: Vec<ast::Ident> = parts
+                    .into_iter()
+                    .map(|part| self.new_ident(part))
+                    .collect();
+                Ok(ast::Expr::CompoundIdentifier(idents))
+            }
+        }
+    }
+
+    /// Build an [`ast::Ident`] for `str`, taking the quote style from this
+    /// unparser's dialect.
+    fn new_ident(&self, str: String) -> ast::Ident {
+        let quote_style = self.dialect.identifier_quote_style();
+        ast::Ident {
+            value: str,
+            quote_style,
+        }
+    }
+
+    /// Assemble an [`ast::Expr::BinaryOp`] from already-converted operands
+    /// and operator.
+    fn binary_op_to_sql(
+        &self,
+        lhs: ast::Expr,
+        rhs: ast::Expr,
+        op: ast::BinaryOperator,
+    ) -> ast::Expr {
+        let (left, right) = (Box::new(lhs), Box::new(rhs));
+        ast::Expr::BinaryOp { left, op, right }
+    }
+
+    /// Map a DataFusion [`Operator`] onto the corresponding
+    /// [`ast::BinaryOperator`].
+    ///
+    /// # Errors
+    /// Returns a "not implemented" error for `IsDistinctFrom`,
+    /// `IsNotDistinctFrom`, `AtArrow` and `ArrowAt`.
+    fn op_to_sql(&self, op: &Operator) -> Result<ast::BinaryOperator> {
+        use ast::BinaryOperator as SqlOp;
+
+        let sql_op = match op {
+            // Comparison
+            Operator::Eq => SqlOp::Eq,
+            Operator::NotEq => SqlOp::NotEq,
+            Operator::Lt => SqlOp::Lt,
+            Operator::LtEq => SqlOp::LtEq,
+            Operator::Gt => SqlOp::Gt,
+            Operator::GtEq => SqlOp::GtEq,
+            // Arithmetic
+            Operator::Plus => SqlOp::Plus,
+            Operator::Minus => SqlOp::Minus,
+            Operator::Multiply => SqlOp::Multiply,
+            Operator::Divide => SqlOp::Divide,
+            Operator::Modulo => SqlOp::Modulo,
+            // Logical
+            Operator::And => SqlOp::And,
+            Operator::Or => SqlOp::Or,
+            // Regex and LIKE-style matching
+            Operator::RegexMatch => SqlOp::PGRegexMatch,
+            Operator::RegexIMatch => SqlOp::PGRegexIMatch,
+            Operator::RegexNotMatch => SqlOp::PGRegexNotMatch,
+            Operator::RegexNotIMatch => SqlOp::PGRegexNotIMatch,
+            Operator::ILikeMatch => SqlOp::PGILikeMatch,
+            Operator::NotLikeMatch => SqlOp::PGNotLikeMatch,
+            Operator::LikeMatch => SqlOp::PGLikeMatch,
+            Operator::NotILikeMatch => SqlOp::PGNotILikeMatch,
+            // Bitwise
+            Operator::BitwiseAnd => SqlOp::BitwiseAnd,
+            Operator::BitwiseOr => SqlOp::BitwiseOr,
+            Operator::BitwiseXor => SqlOp::BitwiseXor,
+            Operator::BitwiseShiftRight => SqlOp::PGBitwiseShiftRight,
+            Operator::BitwiseShiftLeft => SqlOp::PGBitwiseShiftLeft,
+            // Strings
+            Operator::StringConcat => SqlOp::StringConcat,
+            // Not supported yet
+            Operator::IsDistinctFrom
+            | Operator::IsNotDistinctFrom
+            | Operator::AtArrow
+            | Operator::ArrowAt => {
+                return not_impl_err!("unsupported operation: {op:?}");
+            }
+        };
+
+        Ok(sql_op)
+    }
+
+    /// Convert a DataFusion [`ScalarValue`] literal into an [`ast::Value`].
+    ///
+    /// `ScalarValue::Null` and the typed nulls listed in the first arm become
+    /// `ast::Value::Null`. Booleans, integers, floats and UTF-8 strings
+    /// become the matching literal.
+    ///
+    /// # Errors
+    /// Returns a "not implemented" error for every other non-null scalar
+    /// (decimals, binary, lists, dates, times, timestamps, intervals,
+    /// durations, structs, dictionaries) and for any `FixedSizeBinary`,
+    /// whether or not it holds a value.
+    fn scalar_to_sql(&self, v: &ScalarValue) -> Result<ast::Value> {
+        // Numeric literals carry the value's `to_string` form and `false` as
+        // the second field of `ast::Value::Number`.
+        fn number(n: impl ToString) -> ast::Value {
+            ast::Value::Number(n.to_string(), false)
+        }
+
+        match v {
+            // The untyped NULL and these typed nulls share one representation.
+            ScalarValue::Null
+            | ScalarValue::Boolean(None)
+            | ScalarValue::Float32(None)
+            | ScalarValue::Float64(None)
+            | ScalarValue::Decimal128(None, ..)
+            | ScalarValue::Decimal256(None, ..)
+            | ScalarValue::Int8(None)
+            | ScalarValue::Int16(None)
+            | ScalarValue::Int32(None)
+            | ScalarValue::Int64(None)
+            | ScalarValue::UInt8(None)
+            | ScalarValue::UInt16(None)
+            | ScalarValue::UInt32(None)
+            | ScalarValue::UInt64(None)
+            | ScalarValue::Utf8(None)
+            | ScalarValue::LargeUtf8(None)
+            | ScalarValue::Binary(None)
+            | ScalarValue::LargeBinary(None)
+            | ScalarValue::Date32(None)
+            | ScalarValue::Date64(None)
+            | ScalarValue::Time32Second(None)
+            | ScalarValue::Time32Millisecond(None)
+            | ScalarValue::Time64Microsecond(None)
+            | ScalarValue::Time64Nanosecond(None)
+            | ScalarValue::TimestampSecond(None, _)
+            | ScalarValue::TimestampMillisecond(None, _)
+            | ScalarValue::TimestampMicrosecond(None, _)
+            | ScalarValue::TimestampNanosecond(None, _)
+            | ScalarValue::IntervalYearMonth(None)
+            | ScalarValue::IntervalDayTime(None)
+            | ScalarValue::IntervalMonthDayNano(None)
+            | ScalarValue::DurationSecond(None)
+            | ScalarValue::DurationMillisecond(None)
+            | ScalarValue::DurationMicrosecond(None)
+            | ScalarValue::DurationNanosecond(None) => Ok(ast::Value::Null),
+
+            // Literals that can be rendered today.
+            ScalarValue::Boolean(Some(b)) => Ok(ast::Value::Boolean(*b)),
+            ScalarValue::Float32(Some(n)) => Ok(number(n)),
+            ScalarValue::Float64(Some(n)) => Ok(number(n)),
+            ScalarValue::Int8(Some(n)) => Ok(number(n)),
+            ScalarValue::Int16(Some(n)) => Ok(number(n)),
+            ScalarValue::Int32(Some(n)) => Ok(number(n)),
+            ScalarValue::Int64(Some(n)) => Ok(number(n)),
+            ScalarValue::UInt8(Some(n)) => Ok(number(n)),
+            ScalarValue::UInt16(Some(n)) => Ok(number(n)),
+            ScalarValue::UInt32(Some(n)) => Ok(number(n)),
+            ScalarValue::UInt64(Some(n)) => Ok(number(n)),
+            ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => {
+                Ok(ast::Value::SingleQuotedString(s.clone()))
+            }
+
+            // Everything else is not implemented yet. The variants are listed
+            // explicitly (no `_` arm) so the match stays exhaustive.
+            ScalarValue::Decimal128(Some(_), ..)
+            | ScalarValue::Decimal256(Some(_), ..)
+            | ScalarValue::Binary(Some(_))
+            | ScalarValue::FixedSizeBinary(..)
+            | ScalarValue::LargeBinary(Some(_))
+            | ScalarValue::FixedSizeList(_)
+            | ScalarValue::List(_)
+            | ScalarValue::LargeList(_)
+            | ScalarValue::Date32(Some(_))
+            | ScalarValue::Date64(Some(_))
+            | ScalarValue::Time32Second(Some(_))
+            | ScalarValue::Time32Millisecond(Some(_))
+            | ScalarValue::Time64Microsecond(Some(_))
+            | ScalarValue::Time64Nanosecond(Some(_))
+            | ScalarValue::TimestampSecond(Some(_), _)
+            | ScalarValue::TimestampMillisecond(Some(_), _)
+            | ScalarValue::TimestampMicrosecond(Some(_), _)
+            | ScalarValue::TimestampNanosecond(Some(_), _)
+            | ScalarValue::IntervalYearMonth(Some(_))
+            | ScalarValue::IntervalDayTime(Some(_))
+            | ScalarValue::IntervalMonthDayNano(Some(_))
+            | ScalarValue::DurationSecond(Some(_))
+            | ScalarValue::DurationMillisecond(Some(_))
+            | ScalarValue::DurationMicrosecond(Some(_))
+            | ScalarValue::DurationNanosecond(Some(_))
+            | ScalarValue::Struct(_)
+            | ScalarValue::Dictionary(..) => {
+                not_impl_err!("Unsupported scalar: {v:?}")
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use datafusion_common::TableReference;
+    use datafusion_expr::{col, lit};
+
+    use crate::unparser::dialect::CustomDialect;
+
+    use super::*;
+
+    /// The default unparser renders unqualified and qualified columns without
+    /// any quote characters.
+    #[test]
+    fn expr_to_sql_ok() -> Result<()> {
+        // Unqualified column compared against a literal.
+        let unqualified = col("a").gt(lit(4));
+        // Column carrying a two-part (partial) table reference.
+        let qualified = Expr::Column(Column {
+            relation: Some(TableReference::partial("a", "b")),
+            name: "c".to_string(),
+        })
+        .gt(lit(4));
+
+        let cases = [(unqualified, r#"a > 4"#), (qualified, r#"a.b.c > 4"#)];
+
+        for (expr, expected) in cases {
+            let actual = expr_to_sql(&expr)?.to_string();
+            assert_eq!(actual, expected);
+        }
+
+        Ok(())
+    }
+
+    /// An unparser built on a custom dialect wraps identifiers in that
+    /// dialect's quote character.
+    #[test]
+    fn custom_dialect() -> Result<()> {
+        let single_quote_dialect = CustomDialect::new(Some('\''));
+        let unparser = Unparser::new(&single_quote_dialect);
+
+        let rendered = unparser.expr_to_sql(&col("a").gt(lit(4)))?.to_string();
+
+        assert_eq!(rendered, r#"'a' > 4"#);
+
+        Ok(())
+    }
+}
diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/unparser/mod.rs
similarity index 54%
copy from datafusion/sql/src/lib.rs
copy to datafusion/sql/src/unparser/mod.rs
index d805f61397..77a9de0975 100644
--- a/datafusion/sql/src/lib.rs
+++ b/datafusion/sql/src/unparser/mod.rs
@@ -15,30 +15,27 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! This module provides:
-//!
-//! 1. A SQL parser, [`DFParser`], that translates SQL query text into
-//! an abstract syntax tree (AST), [`Statement`].
-//!
-//! 2. A SQL query planner [`SqlToRel`] that creates [`LogicalPlan`]s
-//! from [`Statement`]s.
-//!
-//! [`DFParser`]: parser::DFParser
-//! [`Statement`]: parser::Statement
-//! [`SqlToRel`]: planner::SqlToRel
-//! [`LogicalPlan`]: datafusion_expr::logical_plan::LogicalPlan
-
 mod expr;
-pub mod parser;
-pub mod planner;
-mod query;
-mod relation;
-mod select;
-mod set_expr;
-mod statement;
-pub mod utils;
-mod values;
 
-pub use datafusion_common::{ResolvedTableReference, TableReference};
-pub use expr::arrow_cast::parse_data_type;
-pub use sqlparser;
+pub use expr::expr_to_sql;
+
+use self::dialect::{DefaultDialect, Dialect};
+pub mod dialect;
+
+/// Converts DataFusion expressions into `sqlparser` AST nodes, consulting a
+/// borrowed [`Dialect`] for dialect-specific syntax.
+pub struct Unparser<'a> {
+    /// Dialect consulted while unparsing; borrowed for the lifetime `'a`.
+    dialect: &'a dyn Dialect,
+}
+
+impl<'a> Unparser<'a> {
+    /// Creates an unparser that uses the given `dialect`.
+    pub fn new(dialect: &'a dyn Dialect) -> Self {
+        Self { dialect }
+    }
+}
+
+impl<'a> Default for Unparser<'a> {
+    /// Builds an unparser backed by [`DefaultDialect`].
+    fn default() -> Self {
+        // `&DefaultDialect {}` is a constant expression, so the reference is
+        // valid for any `'a` just as in a direct struct literal.
+        Self::new(&DefaultDialect {})
+    }
+}

Reply via email to