This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 93b3d9cbfa Handle alias when parsing sql(parse_sql_expr) (#12939)
93b3d9cbfa is described below

commit 93b3d9cbfa1ba8ed237ca59f686c32b94ee4bc0a
Author: Eason <[email protected]>
AuthorDate: Wed Dec 11 20:21:17 2024 +0800

    Handle alias when parsing sql(parse_sql_expr) (#12939)
    
    * fix: Fix parse_sql_expr not handling alias
    
    * cargo fmt
    
    * fix parse_sql_expr example(remove alias)
    
    * add testing
    
    * add SUM udaf to TestContextProvider and modify test_sql_to_expr_with_alias for function
    
    * revert change on example `parse_sql_expr`
---
 datafusion-examples/examples/parse_sql_expr.rs | 10 ++---
 datafusion/core/src/execution/session_state.rs | 21 ++++++---
 datafusion/sql/src/expr/mod.rs                 | 60 ++++++++++++++++++++++++--
 datafusion/sql/src/parser.rs                   |  9 ++--
 4 files changed, 82 insertions(+), 18 deletions(-)

diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/parse_sql_expr.rs
index e23e5accae..d8f0778e19 100644
--- a/datafusion-examples/examples/parse_sql_expr.rs
+++ b/datafusion-examples/examples/parse_sql_expr.rs
@@ -121,11 +121,11 @@ async fn query_parquet_demo() -> Result<()> {
 
     assert_batches_eq!(
         &[
-            "+------------+----------------------+",
-            "| double_col | sum(?table?.int_col) |",
-            "+------------+----------------------+",
-            "| 10.1       | 4                    |",
-            "+------------+----------------------+",
+            "+------------+-------------+",
+            "| double_col | sum_int_col |",
+            "+------------+-------------+",
+            "| 10.1       | 4           |",
+            "+------------+-------------+",
         ],
         &result
     );
diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs
index 4ccad5ffd3..cef5d4c1ee 100644
--- a/datafusion/core/src/execution/session_state.rs
+++ b/datafusion/core/src/execution/session_state.rs
@@ -68,7 +68,7 @@ use datafusion_sql::planner::{ContextProvider, ParserOptions, PlannerContext, Sq
 use itertools::Itertools;
 use log::{debug, info};
 use object_store::ObjectStore;
-use sqlparser::ast::Expr as SQLExpr;
+use sqlparser::ast::{Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias};
 use sqlparser::dialect::dialect_from_str;
 use std::any::Any;
 use std::collections::hash_map::Entry;
@@ -500,11 +500,22 @@ impl SessionState {
         sql: &str,
         dialect: &str,
     ) -> datafusion_common::Result<SQLExpr> {
+        self.sql_to_expr_with_alias(sql, dialect).map(|x| x.expr)
+    }
+
+    /// parse a sql string into a sqlparser-rs AST [`SQLExprWithAlias`].
+    ///
+    /// See [`Self::create_logical_expr`] for parsing sql to [`Expr`].
+    pub fn sql_to_expr_with_alias(
+        &self,
+        sql: &str,
+        dialect: &str,
+    ) -> datafusion_common::Result<SQLExprWithAlias> {
         let dialect = dialect_from_str(dialect).ok_or_else(|| {
             plan_datafusion_err!(
                 "Unsupported SQL dialect: {dialect}. Available dialects: \
-                     Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
-                     MsSQL, ClickHouse, BigQuery, Ansi."
+                         Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
+                         MsSQL, ClickHouse, BigQuery, Ansi."
             )
         })?;
 
@@ -603,7 +614,7 @@ impl SessionState {
     ) -> datafusion_common::Result<Expr> {
         let dialect = self.config.options().sql_parser.dialect.as_str();
 
-        let sql_expr = self.sql_to_expr(sql, dialect)?;
+        let sql_expr = self.sql_to_expr_with_alias(sql, dialect)?;
 
         let provider = SessionContextProvider {
             state: self,
@@ -611,7 +622,7 @@ impl SessionState {
         };
 
         let query = SqlToRel::new_with_options(&provider, self.get_parser_options());
-        query.sql_to_expr(sql_expr, df_schema, &mut PlannerContext::new())
+        query.sql_to_expr_with_alias(sql_expr, df_schema, &mut PlannerContext::new())
     }
 
     /// Returns the [`Analyzer`] for this session
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index 57ac96951f..e8ec8d7b7d 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -23,7 +23,8 @@ use datafusion_expr::planner::{
 use recursive::recursive;
 use sqlparser::ast::{
     BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, DictionaryField,
-    Expr as SQLExpr, MapEntry, StructField, Subscript, TrimWhereField, Value,
+    Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias, MapEntry, StructField, Subscript,
+    TrimWhereField, Value,
 };
 
 use datafusion_common::{
@@ -50,6 +51,19 @@ mod unary_op;
 mod value;
 
 impl<S: ContextProvider> SqlToRel<'_, S> {
+    pub(crate) fn sql_expr_to_logical_expr_with_alias(
+        &self,
+        sql: SQLExprWithAlias,
+        schema: &DFSchema,
+        planner_context: &mut PlannerContext,
+    ) -> Result<Expr> {
+        let mut expr =
+            self.sql_expr_to_logical_expr(sql.expr, schema, planner_context)?;
+        if let Some(alias) = sql.alias {
+            expr = expr.alias(alias.value);
+        }
+        Ok(expr)
+    }
     pub(crate) fn sql_expr_to_logical_expr(
         &self,
         sql: SQLExpr,
@@ -131,6 +145,20 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
         )))
     }
 
+    pub fn sql_to_expr_with_alias(
+        &self,
+        sql: SQLExprWithAlias,
+        schema: &DFSchema,
+        planner_context: &mut PlannerContext,
+    ) -> Result<Expr> {
+        let mut expr =
+            self.sql_expr_to_logical_expr_with_alias(sql, schema, planner_context)?;
+        expr = self.rewrite_partial_qualifier(expr, schema);
+        self.validate_schema_satisfies_exprs(schema, &[expr.clone()])?;
+        let (expr, _) = expr.infer_placeholder_types(schema)?;
+        Ok(expr)
+    }
+
     /// Generate a relational expression from a SQL expression
     pub fn sql_to_expr(
         &self,
@@ -1091,8 +1119,11 @@ mod tests {
             None
         }
 
-        fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
-            None
+        fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
+            match name {
+                "sum" => Some(datafusion_functions_aggregate::sum::sum_udaf()),
+                _ => None,
+            }
         }
 
         fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
@@ -1112,7 +1143,7 @@ mod tests {
         }
 
         fn udaf_names(&self) -> Vec<String> {
-            Vec::new()
+            vec!["sum".to_string()]
         }
 
         fn udwf_names(&self) -> Vec<String> {
@@ -1167,4 +1198,25 @@ mod tests {
     test_stack_overflow!(2048);
     test_stack_overflow!(4096);
     test_stack_overflow!(8192);
+    #[test]
+    fn test_sql_to_expr_with_alias() {
+        let schema = DFSchema::empty();
+        let mut planner_context = PlannerContext::default();
+
+        let expr_str = "SUM(int_col) as sum_int_col";
+
+        let dialect = GenericDialect {};
+        let mut parser = Parser::new(&dialect).try_with_sql(expr_str).unwrap();
+        // from sqlparser
+        let sql_expr = parser.parse_expr_with_alias().unwrap();
+
+        let context_provider = TestContextProvider::new();
+        let sql_to_rel = SqlToRel::new(&context_provider);
+
+        let expr = sql_to_rel
+            .sql_expr_to_logical_expr_with_alias(sql_expr, &schema, &mut planner_context)
+            .unwrap();
+
+        assert!(matches!(expr, Expr::Alias(_)));
+    }
 }
diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs
index bd1ed3145e..efec602064 100644
--- a/datafusion/sql/src/parser.rs
+++ b/datafusion/sql/src/parser.rs
@@ -20,9 +20,10 @@
 use std::collections::VecDeque;
 use std::fmt;
 
+use sqlparser::ast::ExprWithAlias;
 use sqlparser::{
     ast::{
-        ColumnDef, ColumnOptionDef, Expr, ObjectName, OrderByExpr, Query,
+        ColumnDef, ColumnOptionDef, ObjectName, OrderByExpr, Query,
         Statement as SQLStatement, TableConstraint, Value,
     },
     dialect::{keywords::Keyword, Dialect, GenericDialect},
@@ -328,7 +329,7 @@ impl<'a> DFParser<'a> {
     pub fn parse_sql_into_expr_with_dialect(
         sql: &str,
         dialect: &dyn Dialect,
-    ) -> Result<Expr, ParserError> {
+    ) -> Result<ExprWithAlias, ParserError> {
         let mut parser = DFParser::new_with_dialect(sql, dialect)?;
         parser.parse_expr()
     }
@@ -377,7 +378,7 @@ impl<'a> DFParser<'a> {
         }
     }
 
-    pub fn parse_expr(&mut self) -> Result<Expr, ParserError> {
+    pub fn parse_expr(&mut self) -> Result<ExprWithAlias, ParserError> {
         if let Token::Word(w) = self.parser.peek_token().token {
             match w.keyword {
                 Keyword::CREATE | Keyword::COPY | Keyword::EXPLAIN => {
@@ -387,7 +388,7 @@ impl<'a> DFParser<'a> {
             }
         }
 
-        self.parser.parse_expr()
+        self.parser.parse_expr_with_alias()
     }
 
     /// Parse a SQL `COPY TO` statement


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to