This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 30e8ceff8 Add logical plan support for aggregate expressions with filters (and upgrade to sqlparser 0.23) (#3405)
30e8ceff8 is described below
commit 30e8ceff80dcb5d95d8f399917ac6c846986bdf7
Author: Andy Grove <[email protected]>
AuthorDate: Mon Sep 12 12:19:29 2022 -0600
Add logical plan support for aggregate expressions with filters (and upgrade to sqlparser 0.23) (#3405)
* Use sqlparser-0.23
* Add filter to aggregate expressions
* clippy
* implement protobuf serde
* clippy
* fix error message
* Update datafusion/expr/src/expr.rs
Co-authored-by: Andrew Lamb <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/core/src/physical_plan/planner.rs | 7 +++-
datafusion/expr/src/expr.rs | 46 ++++++++++++++++++----
datafusion/expr/src/expr_fn.rs | 10 +++++
datafusion/expr/src/expr_rewriter.rs | 5 ++-
datafusion/expr/src/udaf.rs | 1 +
.../optimizer/src/single_distinct_to_groupby.rs | 6 ++-
datafusion/proto/proto/datafusion.proto | 2 +
datafusion/proto/src/from_proto.rs | 9 +++--
datafusion/proto/src/lib.rs | 4 ++
datafusion/proto/src/to_proto.rs | 35 ++++++++++------
datafusion/sql/src/planner.rs | 29 ++++++++++++--
datafusion/sql/src/utils.rs | 5 ++-
12 files changed, 130 insertions(+), 29 deletions(-)
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 34cadd5b4..9f1e488ef 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -194,7 +194,12 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
args,
..
} => create_function_physical_name(&fun.to_string(), *distinct, args),
- Expr::AggregateUDF { fun, args } => {
+ Expr::AggregateUDF { fun, args, filter } => {
+ if filter.is_some() {
+ return Err(DataFusionError::Execution(
+ "aggregate expression with filter is not supported".to_string(),
+ ));
+ }
let mut names = Vec::with_capacity(args.len());
for e in args {
names.push(create_physical_name(e, false)?);
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index ab45dd67d..8b90fb9e4 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -231,6 +231,8 @@ pub enum Expr {
args: Vec<Expr>,
/// Whether this is a DISTINCT aggregation or not
distinct: bool,
+ /// Optional filter
+ filter: Option<Box<Expr>>,
},
/// Represents the call of a window function with arguments.
WindowFunction {
@@ -251,6 +253,8 @@ pub enum Expr {
fun: Arc<AggregateUDF>,
/// List of expressions to feed to the functions as arguments
args: Vec<Expr>,
+ /// Optional filter applied prior to aggregating
+ filter: Option<Box<Expr>>,
},
/// Returns whether the list contains the expr value.
InList {
@@ -668,10 +672,26 @@ impl fmt::Debug for Expr {
fun,
distinct,
ref args,
+ filter,
..
- } => fmt_function(f, &fun.to_string(), *distinct, args, true),
- Expr::AggregateUDF { fun, ref args, .. } => {
- fmt_function(f, &fun.name, false, args, false)
+ } => {
+ fmt_function(f, &fun.to_string(), *distinct, args, true)?;
+ if let Some(fe) = filter {
+ write!(f, " FILTER (WHERE {})", fe)?;
+ }
+ Ok(())
+ }
+ Expr::AggregateUDF {
+ fun,
+ ref args,
+ filter,
+ ..
+ } => {
+ fmt_function(f, &fun.name, false, args, false)?;
+ if let Some(fe) = filter {
+ write!(f, " FILTER (WHERE {})", fe)?;
+ }
+ Ok(())
}
Expr::Between {
expr,
@@ -1010,14 +1030,26 @@ fn create_name(e: &Expr) -> Result<String> {
fun,
distinct,
args,
- ..
- } => create_function_name(&fun.to_string(), *distinct, args),
- Expr::AggregateUDF { fun, args } => {
+ filter,
+ } => {
+ let name = create_function_name(&fun.to_string(), *distinct, args)?;
+ if let Some(fe) = filter {
+ Ok(format!("{} FILTER (WHERE {})", name, fe))
+ } else {
+ Ok(name)
+ }
+ }
+ Expr::AggregateUDF { fun, args, filter } => {
let mut names = Vec::with_capacity(args.len());
for e in args {
names.push(create_name(e)?);
}
- Ok(format!("{}({})", fun.name, names.join(",")))
+ let filter = if let Some(fe) = filter {
+ format!(" FILTER (WHERE {})", fe)
+ } else {
+ "".to_string()
+ };
+ Ok(format!("{}({}){}", fun.name, names.join(","), filter))
}
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => {
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index f7eaec39b..6c5cc0ecc 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -66,6 +66,7 @@ pub fn min(expr: Expr) -> Expr {
fun: aggregate_function::AggregateFunction::Min,
distinct: false,
args: vec![expr],
+ filter: None,
}
}
@@ -75,6 +76,7 @@ pub fn max(expr: Expr) -> Expr {
fun: aggregate_function::AggregateFunction::Max,
distinct: false,
args: vec![expr],
+ filter: None,
}
}
@@ -84,6 +86,7 @@ pub fn sum(expr: Expr) -> Expr {
fun: aggregate_function::AggregateFunction::Sum,
distinct: false,
args: vec![expr],
+ filter: None,
}
}
@@ -93,6 +96,7 @@ pub fn avg(expr: Expr) -> Expr {
fun: aggregate_function::AggregateFunction::Avg,
distinct: false,
args: vec![expr],
+ filter: None,
}
}
@@ -102,6 +106,7 @@ pub fn count(expr: Expr) -> Expr {
fun: aggregate_function::AggregateFunction::Count,
distinct: false,
args: vec![expr],
+ filter: None,
}
}
@@ -111,6 +116,7 @@ pub fn count_distinct(expr: Expr) -> Expr {
fun: aggregate_function::AggregateFunction::Count,
distinct: true,
args: vec![expr],
+ filter: None,
}
}
@@ -163,6 +169,7 @@ pub fn approx_distinct(expr: Expr) -> Expr {
fun: aggregate_function::AggregateFunction::ApproxDistinct,
distinct: false,
args: vec![expr],
+ filter: None,
}
}
@@ -172,6 +179,7 @@ pub fn approx_median(expr: Expr) -> Expr {
fun: aggregate_function::AggregateFunction::ApproxMedian,
distinct: false,
args: vec![expr],
+ filter: None,
}
}
@@ -181,6 +189,7 @@ pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
fun: aggregate_function::AggregateFunction::ApproxPercentileCont,
distinct: false,
args: vec![expr, percentile],
+ filter: None,
}
}
@@ -194,6 +203,7 @@ pub fn approx_percentile_cont_with_weight(
fun: aggregate_function::AggregateFunction::ApproxPercentileContWithWeight,
distinct: false,
args: vec![expr, weight_expr, percentile],
+ filter: None,
}
}
diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs
index b8b9fced9..533f31ce1 100644
--- a/datafusion/expr/src/expr_rewriter.rs
+++ b/datafusion/expr/src/expr_rewriter.rs
@@ -250,10 +250,12 @@ impl ExprRewritable for Expr {
args,
fun,
distinct,
+ filter,
} => Expr::AggregateFunction {
args: rewrite_vec(args, rewriter)?,
fun,
distinct,
+ filter,
},
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => {
@@ -271,9 +273,10 @@ impl ExprRewritable for Expr {
))
}
},
- Expr::AggregateUDF { args, fun } => Expr::AggregateUDF {
+ Expr::AggregateUDF { args, fun, filter } => Expr::AggregateUDF {
args: rewrite_vec(args, rewriter)?,
fun,
+ filter,
},
Expr::InList {
expr,
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index 00f48dda2..0ecb5280a 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -89,6 +89,7 @@ impl AggregateUDF {
Expr::AggregateUDF {
fun: Arc::new(self.clone()),
args,
+ filter: None,
}
}
}
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 656d3967e..f1982bcf1 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -87,7 +87,9 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
let new_aggr_exprs = aggr_expr
.iter()
.map(|aggr_expr| match aggr_expr {
- Expr::AggregateFunction { fun, args, .. } => {
+ Expr::AggregateFunction {
+ fun, args, filter, ..
+ } => {
// is_single_distinct_agg ensure args.len=1
if group_fields_set.insert(args[0].name()?) {
inner_group_exprs
@@ -97,6 +99,7 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
fun: fun.clone(),
args: vec![col(SINGLE_DISTINCT_ALIAS)],
distinct: false, // intentional to remove distinct here
+ filter: filter.clone(),
})
}
_ => Ok(aggr_expr.clone()),
@@ -402,6 +405,7 @@ mod tests {
fun: AggregateFunction::Max,
distinct: true,
args: vec![col("b")],
+ filter: None,
},
],
)?
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 8d4da0250..baabc04cf 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -504,11 +504,13 @@ message AggregateExprNode {
AggregateFunction aggr_function = 1;
repeated LogicalExprNode expr = 2;
bool distinct = 3;
+ LogicalExprNode filter = 4;
}
message AggregateUDFExprNode {
string fun_name = 1;
repeated LogicalExprNode args = 2;
+ LogicalExprNode filter = 3;
}
message ScalarUDFExprNode {
diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs
index c93d3a877..5402b03ce 100644
--- a/datafusion/proto/src/from_proto.rs
+++ b/datafusion/proto/src/from_proto.rs
@@ -891,6 +891,7 @@ pub fn parse_expr(
.map(|e| parse_expr(e, registry))
.collect::<Result<Vec<_>, _>>()?,
distinct: expr.distinct,
+ filter: parse_optional_expr(&expr.filter, registry)?.map(Box::new),
})
}
ExprType::Alias(alias) => Ok(Expr::Alias(
@@ -1194,15 +1195,17 @@ pub fn parse_expr(
.collect::<Result<Vec<_>, Error>>()?,
})
}
- ExprType::AggregateUdfExpr(protobuf::AggregateUdfExprNode { fun_name, args }) => {
- let agg_fn = registry.udaf(fun_name.as_str())?;
+ ExprType::AggregateUdfExpr(pb) => {
+ let agg_fn = registry.udaf(pb.fun_name.as_str())?;
Ok(Expr::AggregateUDF {
fun: agg_fn,
- args: args
+ args: pb
+ .args
.iter()
.map(|expr| parse_expr(expr, registry))
.collect::<Result<Vec<_>, Error>>()?,
+ filter: parse_optional_expr(&pb.filter, registry)?.map(Box::new),
})
}
diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs
index 8e9475329..cce778be2 100644
--- a/datafusion/proto/src/lib.rs
+++ b/datafusion/proto/src/lib.rs
@@ -1023,6 +1023,7 @@ mod roundtrip_tests {
fun: AggregateFunction::Count,
args: vec![col("bananas")],
distinct: false,
+ filter: None,
};
let ctx = SessionContext::new();
roundtrip_expr_test(test_expr, ctx);
@@ -1034,6 +1035,7 @@ mod roundtrip_tests {
fun: AggregateFunction::Count,
args: vec![col("bananas")],
distinct: true,
+ filter: None,
};
let ctx = SessionContext::new();
roundtrip_expr_test(test_expr, ctx);
@@ -1045,6 +1047,7 @@ mod roundtrip_tests {
fun: AggregateFunction::ApproxPercentileCont,
args: vec![col("bananas"), lit(0.42_f32)],
distinct: false,
+ filter: None,
};
let ctx = SessionContext::new();
@@ -1097,6 +1100,7 @@ mod roundtrip_tests {
let test_expr = Expr::AggregateUDF {
fun: Arc::new(dummy_agg.clone()),
args: vec![lit(1.0_f64)],
+ filter: Some(Box::new(lit(true))),
};
let mut ctx = SessionContext::new();
diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs
index 8c43f876e..43d649029 100644
--- a/datafusion/proto/src/to_proto.rs
+++ b/datafusion/proto/src/to_proto.rs
@@ -585,6 +585,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
ref fun,
ref args,
ref distinct,
+ ref filter
} => {
let aggr_function = match fun {
AggregateFunction::ApproxDistinct => {
@@ -633,9 +634,13 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
.map(|v| v.try_into())
.collect::<Result<Vec<_>, _>>()?,
distinct: *distinct,
+ filter: match filter {
+ Some(e) => Some(Box::new(e.as_ref().try_into()?)),
+ None => None,
+ }
};
Self {
- expr_type: Some(ExprType::AggregateExpr(aggregate_expr)),
+ expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))),
}
}
Expr::ScalarVariable(_, _) => return Err(Error::General("Proto serialization error: Scalar Variable not supported".to_string())),
@@ -663,17 +668,23 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
.collect::<Result<Vec<_>, Error>>()?,
})),
},
- Expr::AggregateUDF { fun, args } => Self {
- expr_type: Some(ExprType::AggregateUdfExpr(
- protobuf::AggregateUdfExprNode {
- fun_name: fun.name.clone(),
- args: args.iter().map(|expr| expr.try_into()).collect::<Result<
- Vec<_>,
- Error,
- >>(
- )?,
- },
- )),
+ Expr::AggregateUDF { fun, args, filter } => {
+ Self {
+ expr_type: Some(ExprType::AggregateUdfExpr(
+ Box::new(protobuf::AggregateUdfExprNode {
+ fun_name: fun.name.clone(),
+ args: args.iter().map(|expr| expr.try_into()).collect::<Result<
+ Vec<_>,
+ Error,
+ >>(
+ )?,
+ filter: match filter {
+ Some(e) => Some(Box::new(e.as_ref().try_into()?)),
+ None => None,
+ }
+ },
+ ))),
+ }
},
Expr::Not(expr) => {
let expr = Box::new(protobuf::Not {
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 881de1ebe..5d30b670f 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -2089,6 +2089,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Ok(Expr::ScalarFunction { fun, args })
}
+ SQLExpr::AggregateExpressionWithFilter { expr, filter } => {
+ match self.sql_expr_to_logical_expr(*expr, schema, ctes)? {
+ Expr::AggregateFunction {
+ fun, args, distinct, ..
+ } => Ok(Expr::AggregateFunction { fun, args, distinct, filter: Some(Box::new(self.sql_expr_to_logical_expr(*filter, schema, ctes)?)) }),
+ _ => Err(DataFusionError::Internal("AggregateExpressionWithFilter expression was not an AggregateFunction".to_string()))
+ }
+
+ }
+
SQLExpr::Function(mut function) => {
let name = if function.name.0.len() > 1 {
// DF doesn't handle compound identifiers
@@ -2185,6 +2195,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
fun,
distinct,
args,
+ filter: None
});
};
@@ -2198,7 +2209,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
None => match self.schema_provider.get_aggregate_meta(&name) {
Some(fm) => {
let args = self.function_args_to_expr(function.args, schema)?;
- Ok(Expr::AggregateUDF { fun: fm, args })
+ Ok(Expr::AggregateUDF { fun: fm, args, filter: None })
}
_ => Err(DataFusionError::Plan(format!(
"Invalid function '{}'",
@@ -2217,7 +2228,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
SQLExpr::Subquery(subquery) => self.parse_scalar_subquery(&subquery, schema, ctes),
_ => Err(DataFusionError::NotImplemented(format!(
- "Unsupported ast node {:?} in sqltorel",
+ "Unsupported ast node in sqltorel: {:?}",
sql
))),
}
@@ -2731,7 +2742,7 @@ fn parse_sql_number(n: &str) -> Result<Expr> {
mod tests {
use super::*;
use crate::assert_contains;
- use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect};
+ use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect};
use std::any::Any;
#[test]
@@ -4966,6 +4977,18 @@ mod tests {
quick_test(sql, expected);
}
+ #[test]
+ fn hive_aggregate_with_filter() -> Result<()> {
+ let dialect = &HiveDialect {};
+ let sql = "SELECT SUM(age) FILTER (WHERE age > 4) FROM person";
+ let plan = logical_plan_with_dialect(sql, dialect)?;
+ let expected = "Projection: #SUM(person.age) FILTER (WHERE #age > Int64(4))\
+ \n Aggregate: groupBy=[[]], aggr=[[SUM(#person.age) FILTER (WHERE #age > Int64(4))]]\
+ \n TableScan: person".to_string();
+ assert_eq!(expected, format!("{}", plan.display_indent()));
+ Ok(())
+ }
+
#[test]
fn order_by_unaliased_name() {
// https://github.com/apache/arrow-datafusion/issues/3160
diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs
index 25f5c549a..eb58509d0 100644
--- a/datafusion/sql/src/utils.rs
+++ b/datafusion/sql/src/utils.rs
@@ -163,6 +163,7 @@ where
fun,
args,
distinct,
+ filter,
} => Ok(Expr::AggregateFunction {
fun: fun.clone(),
args: args
@@ -170,6 +171,7 @@ where
.map(|e| clone_with_replacement(e, replacement_fn))
.collect::<Result<Vec<Expr>>>()?,
distinct: *distinct,
+ filter: filter.clone(),
}),
Expr::WindowFunction {
fun,
@@ -193,12 +195,13 @@ where
.collect::<Result<Vec<_>>>()?,
window_frame: *window_frame,
}),
- Expr::AggregateUDF { fun, args } => Ok(Expr::AggregateUDF {
+ Expr::AggregateUDF { fun, args, filter } => Ok(Expr::AggregateUDF {
fun: fun.clone(),
args: args
.iter()
.map(|e| clone_with_replacement(e, replacement_fn))
.collect::<Result<Vec<Expr>>>()?,
+ filter: filter.clone(),
}),
Expr::Alias(nested_expr, alias_name) => Ok(Expr::Alias(
Box::new(clone_with_replacement(nested_expr, replacement_fn)?),