This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new ecf5323eaa Fix unparser invalid sql for query with order (#11527)
ecf5323eaa is described below
commit ecf5323eaa38869ed2f911b02f98e17aa6db639a
Author: yfu <[email protected]>
AuthorDate: Mon Jul 22 21:04:36 2024 +1000
Fix unparser invalid sql for query with order (#11527)
* wip
* fix wrong unparsed query for original query that has derived table with
limit/sort/distinct; fix wrong unparsed query for original query with sort
column that is not in select
* clippy
* addressed the comments, also fix one issue when selected column is
aliased - see test
---
datafusion/sql/src/unparser/plan.rs | 67 ++++++++++++---------
datafusion/sql/src/unparser/rewrite.rs | 80 ++++++++++++++++++++++++-
datafusion/sql/tests/cases/plan_to_sql.rs | 98 +++++++++++++++++++++++++++++++
3 files changed, 215 insertions(+), 30 deletions(-)
diff --git a/datafusion/sql/src/unparser/plan.rs
b/datafusion/sql/src/unparser/plan.rs
index 7f050d8a06..59660f4f04 100644
--- a/datafusion/sql/src/unparser/plan.rs
+++ b/datafusion/sql/src/unparser/plan.rs
@@ -29,6 +29,7 @@ use super::{
SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder,
},
rewrite::normalize_union_schema,
+ rewrite::rewrite_plan_for_sort_on_non_projected_fields,
utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant},
Unparser,
};
@@ -199,33 +200,21 @@ impl Unparser<'_> {
Ok(())
}
- fn projection_to_sql(
- &self,
- plan: &LogicalPlan,
- p: &Projection,
- query: &mut Option<QueryBuilder>,
- select: &mut SelectBuilder,
- relation: &mut RelationBuilder,
- ) -> Result<()> {
- // A second projection implies a derived tablefactor
- if !select.already_projected() {
- self.reconstruct_select_statement(plan, p, select)?;
- self.select_to_sql_recursively(p.input.as_ref(), query, select,
relation)
- } else {
- let mut derived_builder = DerivedRelationBuilder::default();
- derived_builder.lateral(false).alias(None).subquery({
- let inner_statement = self.plan_to_sql(plan)?;
- if let ast::Statement::Query(inner_query) = inner_statement {
- inner_query
- } else {
- return internal_err!(
- "Subquery must be a Query, but found
{inner_statement:?}"
- );
- }
- });
- relation.derived(derived_builder);
- Ok(())
- }
+ fn derive(&self, plan: &LogicalPlan, relation: &mut RelationBuilder) ->
Result<()> {
+ let mut derived_builder = DerivedRelationBuilder::default();
+ derived_builder.lateral(false).alias(None).subquery({
+ let inner_statement = self.plan_to_sql(plan)?;
+ if let ast::Statement::Query(inner_query) = inner_statement {
+ inner_query
+ } else {
+ return internal_err!(
+ "Subquery must be a Query, but found {inner_statement:?}"
+ );
+ }
+ });
+ relation.derived(derived_builder);
+
+ Ok(())
}
fn select_to_sql_recursively(
@@ -256,7 +245,17 @@ impl Unparser<'_> {
Ok(())
}
LogicalPlan::Projection(p) => {
- self.projection_to_sql(plan, p, query, select, relation)
+ if let Some(new_plan) =
rewrite_plan_for_sort_on_non_projected_fields(p) {
+ return self
+ .select_to_sql_recursively(&new_plan, query, select,
relation);
+ }
+
+ // Projection can be top-level plan for derived table
+ if select.already_projected() {
+ return self.derive(plan, relation);
+ }
+ self.reconstruct_select_statement(plan, p, select)?;
+ self.select_to_sql_recursively(p.input.as_ref(), query,
select, relation)
}
LogicalPlan::Filter(filter) => {
if let Some(AggVariant::Aggregate(agg)) =
@@ -278,6 +277,10 @@ impl Unparser<'_> {
)
}
LogicalPlan::Limit(limit) => {
+ // Limit can be top-level plan for derived table
+ if select.already_projected() {
+ return self.derive(plan, relation);
+ }
if let Some(fetch) = limit.fetch {
let Some(query) = query.as_mut() else {
return internal_err!(
@@ -298,6 +301,10 @@ impl Unparser<'_> {
)
}
LogicalPlan::Sort(sort) => {
+ // Sort can be top-level plan for derived table
+ if select.already_projected() {
+ return self.derive(plan, relation);
+ }
if let Some(query_ref) = query {
query_ref.order_by(self.sort_to_sql(sort.expr.clone())?);
} else {
@@ -323,6 +330,10 @@ impl Unparser<'_> {
)
}
LogicalPlan::Distinct(distinct) => {
+ // Distinct can be top-level plan for derived table
+ if select.already_projected() {
+ return self.derive(plan, relation);
+ }
let (select_distinct, input) = match distinct {
Distinct::All(input) => (ast::Distinct::Distinct,
input.as_ref()),
Distinct::On(on) => {
diff --git a/datafusion/sql/src/unparser/rewrite.rs
b/datafusion/sql/src/unparser/rewrite.rs
index a73fce30ce..fba95ad48f 100644
--- a/datafusion/sql/src/unparser/rewrite.rs
+++ b/datafusion/sql/src/unparser/rewrite.rs
@@ -15,13 +15,16 @@
// specific language governing permissions and limitations
// under the License.
-use std::sync::Arc;
+use std::{
+ collections::{HashMap, HashSet},
+ sync::Arc,
+};
use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeIterator},
Result,
};
-use datafusion_expr::{Expr, LogicalPlan, Sort};
+use datafusion_expr::{Expr, LogicalPlan, Projection, Sort};
/// Normalize the schema of a union plan to remove qualifiers from the schema
fields and sort expressions.
///
@@ -99,3 +102,76 @@ fn rewrite_sort_expr_for_union(exprs: Vec<Expr>) ->
Result<Vec<Expr>> {
Ok(sort_exprs)
}
+
+// Rewrite the logical plan for a query whose ORDER BY columns are not in the projection
+// Plan before rewrite:
+//
+// Projection: j1.j1_string, j2.j2_string
+// Sort: j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST
+// Projection: j1.j1_string, j2.j2_string, j1.j1_id, j2.j2_id
+// Inner Join: Filter: j1.j1_id = j2.j2_id
+// TableScan: j1
+// TableScan: j2
+//
+// Plan after rewrite
+//
+// Sort: j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST
+// Projection: j1.j1_string, j2.j2_string
+// Inner Join: Filter: j1.j1_id = j2.j2_id
+// TableScan: j1
+// TableScan: j2
+//
+// This prevents the original plan from generating a query with a derived
table that is missing an alias.
+pub(super) fn rewrite_plan_for_sort_on_non_projected_fields(
+ p: &Projection,
+) -> Option<LogicalPlan> {
+ let LogicalPlan::Sort(sort) = p.input.as_ref() else {
+ return None;
+ };
+
+ let LogicalPlan::Projection(inner_p) = sort.input.as_ref() else {
+ return None;
+ };
+
+ let mut map = HashMap::new();
+ let inner_exprs = inner_p
+ .expr
+ .iter()
+ .map(|f| {
+ if let Expr::Alias(alias) = f {
+ let a = Expr::Column(alias.name.clone().into());
+ map.insert(a.clone(), f.clone());
+ a
+ } else {
+ f.clone()
+ }
+ })
+ .collect::<Vec<_>>();
+
+ let mut collects = p.expr.clone();
+ for expr in &sort.expr {
+ if let Expr::Sort(s) = expr {
+ collects.push(s.expr.as_ref().clone());
+ }
+ }
+
+ if collects.iter().collect::<HashSet<_>>()
+ == inner_exprs.iter().collect::<HashSet<_>>()
+ {
+ let mut sort = sort.clone();
+ let mut inner_p = inner_p.clone();
+
+ let new_exprs = p
+ .expr
+ .iter()
+ .map(|e| map.get(e).unwrap_or(e).clone())
+ .collect::<Vec<_>>();
+
+ inner_p.expr.clone_from(&new_exprs);
+ sort.input = Arc::new(LogicalPlan::Projection(inner_p));
+
+ Some(LogicalPlan::Sort(sort))
+ } else {
+ None
+ }
+}
diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs
b/datafusion/sql/tests/cases/plan_to_sql.rs
index e9c4114353..aada560fd8 100644
--- a/datafusion/sql/tests/cases/plan_to_sql.rs
+++ b/datafusion/sql/tests/cases/plan_to_sql.rs
@@ -244,6 +244,50 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
+ // Test queries with derived tables that put distinct, sort, limit on the
wrong level
+ TestStatementWithDialect {
+ sql: "SELECT j1_string from j1 order by j1_id",
+ expected: r#"SELECT j1.j1_string FROM j1 ORDER BY j1.j1_id ASC
NULLS LAST"#,
+ parser_dialect: Box::new(GenericDialect {}),
+ unparser_dialect: Box::new(UnparserDefaultDialect {}),
+ },
+ TestStatementWithDialect {
+ sql: "SELECT j1_string AS a from j1 order by j1_id",
+ expected: r#"SELECT j1.j1_string AS a FROM j1 ORDER BY j1.j1_id
ASC NULLS LAST"#,
+ parser_dialect: Box::new(GenericDialect {}),
+ unparser_dialect: Box::new(UnparserDefaultDialect {}),
+ },
+ TestStatementWithDialect {
+ sql: "SELECT j1_string from j1 join j2 on j1.j1_id = j2.j2_id
order by j1_id",
+ expected: r#"SELECT j1.j1_string FROM j1 JOIN j2 ON (j1.j1_id =
j2.j2_id) ORDER BY j1.j1_id ASC NULLS LAST"#,
+ parser_dialect: Box::new(GenericDialect {}),
+ unparser_dialect: Box::new(UnparserDefaultDialect {}),
+ },
+ TestStatementWithDialect {
+ sql: "
+ SELECT
+ j1_string,
+ j2_string
+ FROM
+ (
+ SELECT
+ distinct j1_id,
+ j1_string,
+ j2_string
+ from
+ j1
+ INNER join j2 ON j1.j1_id = j2.j2_id
+ order by
+ j1.j1_id desc
+ limit
+ 10
+ ) abc
+ ORDER BY
+ abc.j2_string",
+ expected: r#"SELECT abc.j1_string, abc.j2_string FROM (SELECT
DISTINCT j1.j1_id, j1.j1_string, j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id =
j2.j2_id) ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY
abc.j2_string ASC NULLS LAST"#,
+ parser_dialect: Box::new(GenericDialect {}),
+ unparser_dialect: Box::new(UnparserDefaultDialect {}),
+ },
// more tests around subquery/derived table roundtrip
TestStatementWithDialect {
sql: "SELECT string_count FROM (
@@ -261,6 +305,60 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
+ TestStatementWithDialect {
+ sql: "
+ SELECT
+ j1_string,
+ j2_string
+ FROM
+ (
+ SELECT
+ j1_id,
+ j1_string,
+ j2_string
+ from
+ j1
+ INNER join j2 ON j1.j1_id = j2.j2_id
+ group by
+ j1_id,
+ j1_string,
+ j2_string
+ order by
+ j1.j1_id desc
+ limit
+ 10
+ ) abc
+ ORDER BY
+ abc.j2_string",
+ expected: r#"SELECT abc.j1_string, abc.j2_string FROM (SELECT
j1.j1_id, j1.j1_string, j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id)
GROUP BY j1.j1_id, j1.j1_string, j2.j2_string ORDER BY j1.j1_id DESC NULLS
FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#,
+ parser_dialect: Box::new(GenericDialect {}),
+ unparser_dialect: Box::new(UnparserDefaultDialect {}),
+ },
+ // Test query whose ORDER BY columns are not in the select columns
+ TestStatementWithDialect {
+ sql: "
+ SELECT
+ j1_string
+ FROM
+ (
+ SELECT
+ j1_string,
+ j2_string
+ from
+ j1
+ INNER join j2 ON j1.j1_id = j2.j2_id
+ order by
+ j1.j1_id desc,
+ j2.j2_id desc
+ limit
+ 10
+ ) abc
+ ORDER BY
+ j2_string",
+ expected: r#"SELECT abc.j1_string FROM (SELECT j1.j1_string,
j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id DESC
NULLS FIRST, j2.j2_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string
ASC NULLS LAST"#,
+ parser_dialect: Box::new(GenericDialect {}),
+ unparser_dialect: Box::new(UnparserDefaultDialect {}),
+ },
TestStatementWithDialect {
sql: "SELECT id FROM (SELECT j1_id from j1) AS c (id)",
expected: r#"SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c
(id)"#,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]