This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new ecf5323eaa Fix unparser invalid sql for query with order (#11527)
ecf5323eaa is described below

commit ecf5323eaa38869ed2f911b02f98e17aa6db639a
Author: yfu <[email protected]>
AuthorDate: Mon Jul 22 21:04:36 2024 +1000

    Fix unparser invalid sql for query with order (#11527)
    
    * wip
    
    * fix wrong unparsed query for an original query that has a derived table
    with limit/sort/distinct; fix wrong unparsed query for an original query
    whose sort column is not in the select list
    
    * clippy
    
    * addressed the review comments; also fix an issue when a selected column
    is aliased (see test)
---
 datafusion/sql/src/unparser/plan.rs       | 67 ++++++++++++---------
 datafusion/sql/src/unparser/rewrite.rs    | 80 ++++++++++++++++++++++++-
 datafusion/sql/tests/cases/plan_to_sql.rs | 98 +++++++++++++++++++++++++++++++
 3 files changed, 215 insertions(+), 30 deletions(-)

diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs
index 7f050d8a06..59660f4f04 100644
--- a/datafusion/sql/src/unparser/plan.rs
+++ b/datafusion/sql/src/unparser/plan.rs
@@ -29,6 +29,7 @@ use super::{
         SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder,
     },
     rewrite::normalize_union_schema,
+    rewrite::rewrite_plan_for_sort_on_non_projected_fields,
     utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant},
     Unparser,
 };
@@ -199,33 +200,21 @@ impl Unparser<'_> {
         Ok(())
     }
 
-    fn projection_to_sql(
-        &self,
-        plan: &LogicalPlan,
-        p: &Projection,
-        query: &mut Option<QueryBuilder>,
-        select: &mut SelectBuilder,
-        relation: &mut RelationBuilder,
-    ) -> Result<()> {
-        // A second projection implies a derived tablefactor
-        if !select.already_projected() {
-            self.reconstruct_select_statement(plan, p, select)?;
-            self.select_to_sql_recursively(p.input.as_ref(), query, select, 
relation)
-        } else {
-            let mut derived_builder = DerivedRelationBuilder::default();
-            derived_builder.lateral(false).alias(None).subquery({
-                let inner_statement = self.plan_to_sql(plan)?;
-                if let ast::Statement::Query(inner_query) = inner_statement {
-                    inner_query
-                } else {
-                    return internal_err!(
-                        "Subquery must be a Query, but found 
{inner_statement:?}"
-                    );
-                }
-            });
-            relation.derived(derived_builder);
-            Ok(())
-        }
+    /// Unparses `plan` as a complete query of its own and attaches it to
+    /// `relation` as a non-lateral derived table without an alias.
+    ///
+    /// Returns an internal error if `plan` does not unparse to a
+    /// `Statement::Query`.
+    fn derive(&self, plan: &LogicalPlan, relation: &mut RelationBuilder) -> Result<()> {
+        let subquery = match self.plan_to_sql(plan)? {
+            ast::Statement::Query(inner_query) => inner_query,
+            inner_statement => {
+                return internal_err!(
+                    "Subquery must be a Query, but found {inner_statement:?}"
+                )
+            }
+        };
+
+        let mut derived_builder = DerivedRelationBuilder::default();
+        derived_builder.lateral(false).alias(None).subquery(subquery);
+        relation.derived(derived_builder);
+        Ok(())
     }
 
     fn select_to_sql_recursively(
@@ -256,7 +245,17 @@ impl Unparser<'_> {
                 Ok(())
             }
             LogicalPlan::Projection(p) => {
-                self.projection_to_sql(plan, p, query, select, relation)
+                if let Some(new_plan) = 
rewrite_plan_for_sort_on_non_projected_fields(p) {
+                    return self
+                        .select_to_sql_recursively(&new_plan, query, select, 
relation);
+                }
+
+                // Projection can be top-level plan for derived table
+                if select.already_projected() {
+                    return self.derive(plan, relation);
+                }
+                self.reconstruct_select_statement(plan, p, select)?;
+                self.select_to_sql_recursively(p.input.as_ref(), query, 
select, relation)
             }
             LogicalPlan::Filter(filter) => {
                 if let Some(AggVariant::Aggregate(agg)) =
@@ -278,6 +277,10 @@ impl Unparser<'_> {
                 )
             }
             LogicalPlan::Limit(limit) => {
+                // Limit can be top-level plan for derived table
+                if select.already_projected() {
+                    return self.derive(plan, relation);
+                }
                 if let Some(fetch) = limit.fetch {
                     let Some(query) = query.as_mut() else {
                         return internal_err!(
@@ -298,6 +301,10 @@ impl Unparser<'_> {
                 )
             }
             LogicalPlan::Sort(sort) => {
+                // Sort can be top-level plan for derived table
+                if select.already_projected() {
+                    return self.derive(plan, relation);
+                }
                 if let Some(query_ref) = query {
                     query_ref.order_by(self.sort_to_sql(sort.expr.clone())?);
                 } else {
@@ -323,6 +330,10 @@ impl Unparser<'_> {
                 )
             }
             LogicalPlan::Distinct(distinct) => {
+                // Distinct can be top-level plan for derived table
+                if select.already_projected() {
+                    return self.derive(plan, relation);
+                }
                 let (select_distinct, input) = match distinct {
                     Distinct::All(input) => (ast::Distinct::Distinct, 
input.as_ref()),
                     Distinct::On(on) => {
diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs
index a73fce30ce..fba95ad48f 100644
--- a/datafusion/sql/src/unparser/rewrite.rs
+++ b/datafusion/sql/src/unparser/rewrite.rs
@@ -15,13 +15,16 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::sync::Arc;
+use std::{
+    collections::{HashMap, HashSet},
+    sync::Arc,
+};
 
 use datafusion_common::{
     tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeIterator},
     Result,
 };
-use datafusion_expr::{Expr, LogicalPlan, Sort};
+use datafusion_expr::{Expr, LogicalPlan, Projection, Sort};
 
 /// Normalize the schema of a union plan to remove qualifiers from the schema fields and sort expressions.
 ///
@@ -99,3 +102,76 @@ fn rewrite_sort_expr_for_union(exprs: Vec<Expr>) -> Result<Vec<Expr>> {
 
     Ok(sort_exprs)
 }
+
+/// Rewrites a plan whose `ORDER BY` uses columns that are not in the final
+/// projection, so that it can be unparsed without a derived table.
+///
+/// Plan before rewrite:
+///
+/// ```text
+/// Projection: j1.j1_string, j2.j2_string
+///   Sort: j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST
+///     Projection: j1.j1_string, j2.j2_string, j1.j1_id, j2.j2_id
+///       Inner Join:  Filter: j1.j1_id = j2.j2_id
+///         TableScan: j1
+///         TableScan: j2
+/// ```
+///
+/// Plan after rewrite:
+///
+/// ```text
+/// Sort: j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST
+///   Projection: j1.j1_string, j2.j2_string
+///     Inner Join:  Filter: j1.j1_id = j2.j2_id
+///       TableScan: j1
+///       TableScan: j2
+/// ```
+///
+/// Without this rewrite, the original plan would be unparsed into a query
+/// containing a derived table that has no alias.
+///
+/// Returns `None` (no rewrite) unless `p` is a `Projection` over a `Sort`
+/// over a `Projection`, and the outer projection's expressions together with
+/// the sort keys are exactly the inner projection's expressions (aliased inner
+/// expressions being compared by their alias column).
+pub(super) fn rewrite_plan_for_sort_on_non_projected_fields(
+    p: &Projection,
+) -> Option<LogicalPlan> {
+    let LogicalPlan::Sort(sort) = p.input.as_ref() else {
+        return None;
+    };
+
+    let LogicalPlan::Projection(inner_p) = sort.input.as_ref() else {
+        return None;
+    };
+
+    // Expressions above the inner projection see an aliased expression only
+    // as a column named after the alias. Record which aliased expression each
+    // such column stands for, so it can be restored when inlining below.
+    let mut alias_targets = HashMap::new();
+    let inner_exprs = inner_p
+        .expr
+        .iter()
+        .map(|f| {
+            if let Expr::Alias(alias) = f {
+                // Build the column verbatim from the alias' relation and name.
+                // `Column::from(String)` parses its input as a (possibly
+                // qualified) SQL identifier, which may split on `.` and
+                // normalize case, so aliases such as `"Total"` or `"a.b"`
+                // would no longer match the outer reference.
+                let a = Expr::Column(datafusion_common::Column::new(
+                    alias.relation.clone(),
+                    alias.name.clone(),
+                ));
+                alias_targets.insert(a.clone(), f.clone());
+                a
+            } else {
+                f.clone()
+            }
+        })
+        .collect::<Vec<_>>();
+
+    // The outer projection plus the sort keys must use exactly the
+    // expressions produced by the inner projection, i.e. the inner
+    // projection's only extra job is carrying the sort keys.
+    let outer_exprs = p
+        .expr
+        .iter()
+        .chain(sort.expr.iter().filter_map(|e| match e {
+            Expr::Sort(s) => Some(s.expr.as_ref()),
+            _ => None,
+        }))
+        .collect::<HashSet<_>>();
+    if outer_exprs != inner_exprs.iter().collect::<HashSet<_>>() {
+        return None;
+    }
+
+    // Inline the outer projection into the inner one (restoring aliased
+    // expressions) and move the sort on top of it.
+    // NOTE(review): `inner_p.schema` is left as-is and still describes the
+    // untrimmed expression list — confirm the unparser only reads `expr`.
+    let mut inner_p = inner_p.clone();
+    inner_p.expr = p
+        .expr
+        .iter()
+        .map(|e| alias_targets.get(e).unwrap_or(e).clone())
+        .collect();
+
+    let mut sort = sort.clone();
+    sort.input = Arc::new(LogicalPlan::Projection(inner_p));
+
+    Some(LogicalPlan::Sort(sort))
+}
diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs
index e9c4114353..aada560fd8 100644
--- a/datafusion/sql/tests/cases/plan_to_sql.rs
+++ b/datafusion/sql/tests/cases/plan_to_sql.rs
@@ -244,6 +244,50 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
             parser_dialect: Box::new(GenericDialect {}),
             unparser_dialect: Box::new(UnparserDefaultDialect {}),
         },
+        // Test query with derived tables that put distinct,sort,limit on the 
wrong level
+        TestStatementWithDialect {
+            sql: "SELECT j1_string from j1 order by j1_id",
+            expected: r#"SELECT j1.j1_string FROM j1 ORDER BY j1.j1_id ASC 
NULLS LAST"#,
+            parser_dialect: Box::new(GenericDialect {}),
+            unparser_dialect: Box::new(UnparserDefaultDialect {}),
+        },
+        TestStatementWithDialect {
+            sql: "SELECT j1_string AS a from j1 order by j1_id",
+            expected: r#"SELECT j1.j1_string AS a FROM j1 ORDER BY j1.j1_id 
ASC NULLS LAST"#,
+            parser_dialect: Box::new(GenericDialect {}),
+            unparser_dialect: Box::new(UnparserDefaultDialect {}),
+        },
+        TestStatementWithDialect {
+            sql: "SELECT j1_string from j1 join j2 on j1.j1_id = j2.j2_id 
order by j1_id",
+            expected: r#"SELECT j1.j1_string FROM j1 JOIN j2 ON (j1.j1_id = 
j2.j2_id) ORDER BY j1.j1_id ASC NULLS LAST"#,
+            parser_dialect: Box::new(GenericDialect {}),
+            unparser_dialect: Box::new(UnparserDefaultDialect {}),
+        },
+        TestStatementWithDialect {
+            sql: "
+                SELECT
+                  j1_string,
+                  j2_string
+                FROM
+                  (
+                    SELECT
+                      distinct j1_id,
+                      j1_string,
+                      j2_string
+                    from
+                      j1
+                      INNER join j2 ON j1.j1_id = j2.j2_id
+                    order by
+                      j1.j1_id desc
+                    limit
+                      10
+                  ) abc
+                ORDER BY
+                  abc.j2_string",
+            expected: r#"SELECT abc.j1_string, abc.j2_string FROM (SELECT 
DISTINCT j1.j1_id, j1.j1_string, j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = 
j2.j2_id) ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY 
abc.j2_string ASC NULLS LAST"#,
+            parser_dialect: Box::new(GenericDialect {}),
+            unparser_dialect: Box::new(UnparserDefaultDialect {}),
+        },
         // more tests around subquery/derived table roundtrip
         TestStatementWithDialect {
             sql: "SELECT string_count FROM (
@@ -261,6 +305,60 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
             parser_dialect: Box::new(GenericDialect {}),
             unparser_dialect: Box::new(UnparserDefaultDialect {}),
         },
+        TestStatementWithDialect {
+            sql: "
+                SELECT
+                  j1_string,
+                  j2_string
+                FROM
+                  (
+                    SELECT
+                      j1_id,
+                      j1_string,
+                      j2_string
+                    from
+                      j1
+                      INNER join j2 ON j1.j1_id = j2.j2_id
+                    group by
+                      j1_id,
+                      j1_string,
+                      j2_string
+                    order by
+                      j1.j1_id desc
+                    limit
+                      10
+                  ) abc
+                ORDER BY
+                  abc.j2_string",
+            expected: r#"SELECT abc.j1_string, abc.j2_string FROM (SELECT 
j1.j1_id, j1.j1_string, j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) 
GROUP BY j1.j1_id, j1.j1_string, j2.j2_string ORDER BY j1.j1_id DESC NULLS 
FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#,
+            parser_dialect: Box::new(GenericDialect {}),
+            unparser_dialect: Box::new(UnparserDefaultDialect {}),
+        },
+        // Test query that order by columns are not in select columns
+        TestStatementWithDialect {
+            sql: "
+                SELECT
+                  j1_string
+                FROM
+                  (
+                    SELECT
+                      j1_string,
+                      j2_string
+                    from
+                      j1
+                      INNER join j2 ON j1.j1_id = j2.j2_id
+                    order by
+                      j1.j1_id desc,
+                      j2.j2_id desc
+                    limit
+                      10
+                  ) abc
+                ORDER BY
+                  j2_string",
+            expected: r#"SELECT abc.j1_string FROM (SELECT j1.j1_string, 
j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id DESC 
NULLS FIRST, j2.j2_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string 
ASC NULLS LAST"#,
+            parser_dialect: Box::new(GenericDialect {}),
+            unparser_dialect: Box::new(UnparserDefaultDialect {}),
+        },
         TestStatementWithDialect {
             sql: "SELECT id FROM (SELECT j1_id from j1) AS c (id)",
             expected: r#"SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c 
(id)"#,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to