This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 256ea91b1 Allow sorting by aggregated groups (#3280)
256ea91b1 is described below
commit 256ea91b1c0864449f9b41520c808695aa00460b
Author: Batuhan Taskaya <[email protected]>
AuthorDate: Tue Aug 30 23:51:05 2022 +0300
Allow sorting by aggregated groups (#3280)
* Add the test for mix of order by/group by on a complex expr
* Allow sorting by aggregated groups
* Prevent duplicate sort expressions with mismatched aliases from being included
---
datafusion/core/tests/sql/group_by.rs | 72 +++++++++++++++++++++++++++++
datafusion/expr/src/expr_rewriter.rs | 16 +++++--
datafusion/expr/src/logical_plan/builder.rs | 11 +++--
3 files changed, 91 insertions(+), 8 deletions(-)
diff --git a/datafusion/core/tests/sql/group_by.rs
b/datafusion/core/tests/sql/group_by.rs
index e3da1b021..2e1007be8 100644
--- a/datafusion/core/tests/sql/group_by.rs
+++ b/datafusion/core/tests/sql/group_by.rs
@@ -681,3 +681,75 @@ async fn group_by_dictionary() {
run_test_case::<UInt32Type>().await;
run_test_case::<UInt64Type>().await;
}
+
+#[tokio::test]
+async fn csv_query_group_by_order_by_substr() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT substr(c1, 1, 1), avg(c12) \
+ FROM aggregate_test_100 \
+ GROUP BY substr(c1, 1, 1) \
+ ORDER BY substr(c1, 1, 1)";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+
"+-------------------------------------------------+-----------------------------+",
+ "| substr(aggregate_test_100.c1,Int64(1),Int64(1)) |
AVG(aggregate_test_100.c12) |",
+
"+-------------------------------------------------+-----------------------------+",
+ "| a |
0.48754517466109415 |",
+ "| b |
0.41040709263815384 |",
+ "| c |
0.6600456536439784 |",
+ "| d |
0.48855379387549824 |",
+ "| e |
0.48600669271341534 |",
+
"+-------------------------------------------------+-----------------------------+",
+ ];
+ assert_batches_sorted_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn csv_query_group_by_order_by_substr_aliased_projection() -> Result<()>
{
+ let ctx = SessionContext::new();
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT substr(c1, 1, 1) as name, avg(c12) as average \
+ FROM aggregate_test_100 \
+ GROUP BY substr(c1, 1, 1) \
+ ORDER BY substr(c1, 1, 1)";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+------+---------------------+",
+ "| name | average |",
+ "+------+---------------------+",
+ "| a | 0.48754517466109415 |",
+ "| b | 0.41040709263815384 |",
+ "| c | 0.6600456536439784 |",
+ "| d | 0.48855379387549824 |",
+ "| e | 0.48600669271341534 |",
+ "+------+---------------------+",
+ ];
+ assert_batches_sorted_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn csv_query_group_by_order_by_avg_group_by_substr() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT substr(c1, 1, 1) as name, avg(c12) as average \
+ FROM aggregate_test_100 \
+ GROUP BY substr(c1, 1, 1) \
+ ORDER BY avg(c12)";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+------+---------------------+",
+ "| name | average |",
+ "+------+---------------------+",
+ "| b | 0.41040709263815384 |",
+ "| e | 0.48600669271341534 |",
+ "| a | 0.48754517466109415 |",
+ "| d | 0.48855379387549824 |",
+ "| c | 0.6600456536439784 |",
+ "+------+---------------------+",
+ ];
+ assert_batches_sorted_eq!(expected, &actual);
+ Ok(())
+}
diff --git a/datafusion/expr/src/expr_rewriter.rs
b/datafusion/expr/src/expr_rewriter.rs
index e8cf049dd..9e8fa8a7e 100644
--- a/datafusion/expr/src/expr_rewriter.rs
+++ b/datafusion/expr/src/expr_rewriter.rs
@@ -19,6 +19,7 @@
use crate::expr::GroupingSet;
use crate::logical_plan::Aggregate;
+use crate::utils::grouping_set_to_exprlist;
use crate::{Expr, ExprSchemable, LogicalPlan};
use datafusion_common::Result;
use datafusion_common::{Column, DFSchema};
@@ -325,12 +326,16 @@ pub fn rewrite_sort_cols_by_aggs(
fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
match plan {
LogicalPlan::Aggregate(Aggregate {
- input, aggr_expr, ..
+ input,
+ aggr_expr,
+ group_expr,
+ ..
}) => {
struct Rewriter<'a> {
plan: &'a LogicalPlan,
input: &'a LogicalPlan,
aggr_expr: &'a Vec<Expr>,
+ distinct_group_exprs: &'a Vec<Expr>,
}
impl<'a> ExprRewriter for Rewriter<'a> {
@@ -341,8 +346,11 @@ fn rewrite_sort_col_by_aggs(expr: Expr, plan:
&LogicalPlan) -> Result<Expr> {
return Ok(expr);
}
let normalized_expr = normalized_expr.unwrap();
- if let Some(found_agg) =
- self.aggr_expr.iter().find(|a| (**a) ==
normalized_expr)
+ if let Some(found_agg) = self
+ .aggr_expr
+ .iter()
+ .chain(self.distinct_group_exprs)
+ .find(|a| (**a) == normalized_expr)
{
let agg = normalize_col(found_agg.clone(), self.plan)?;
let col = Expr::Column(
@@ -356,10 +364,12 @@ fn rewrite_sort_col_by_aggs(expr: Expr, plan:
&LogicalPlan) -> Result<Expr> {
}
}
+ let distinct_group_exprs =
grouping_set_to_exprlist(group_expr.as_slice())?;
expr.rewrite(&mut Rewriter {
plan,
input,
aggr_expr,
+ distinct_group_exprs: &distinct_group_exprs,
})
}
LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr,
plan.inputs()[0]),
diff --git a/datafusion/expr/src/logical_plan/builder.rs
b/datafusion/expr/src/logical_plan/builder.rs
index 9eb379142..2946a74af 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -335,16 +335,17 @@ impl LogicalPlanBuilder {
.iter()
.all(|c| input.schema().field_from_column(c).is_ok()) =>
{
- let missing_exprs = missing_cols
+ let mut missing_exprs = missing_cols
.iter()
.map(|c| normalize_col(Expr::Column(c.clone()), &input))
.collect::<Result<Vec<_>>>()?;
+ // Do not let duplicate columns be added: some of the
+ // missing_cols may already be present, but without the new
+ // projected alias.
+ missing_exprs.retain(|e| !expr.contains(e));
expr.extend(missing_exprs);
-
- Ok(LogicalPlan::Projection(Projection::try_new(
- expr, input, alias,
- )?))
+ Ok(project_with_alias((*input).clone(), expr, alias)?)
}
_ => {
let new_inputs = curr_plan