This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new cdbd964346 support window function sql2expr (#10243)
cdbd964346 is described below
commit cdbd96434676c8c34e742a6cfea6bc7499e97cde
Author: Junhao Liu <[email protected]>
AuthorDate: Fri Apr 26 09:18:16 2024 -0600
support window function sql2expr (#10243)
---
datafusion/expr/src/built_in_window_function.rs | 2 +-
datafusion/expr/src/expr.rs | 10 ++
datafusion/sql/src/unparser/expr.rs | 175 +++++++++++++++++++-----
3 files changed, 151 insertions(+), 36 deletions(-)
diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs
index 1001bbb015..18a888ae8b 100644
--- a/datafusion/expr/src/built_in_window_function.rs
+++ b/datafusion/expr/src/built_in_window_function.rs
@@ -71,7 +71,7 @@ pub enum BuiltInWindowFunction {
}
impl BuiltInWindowFunction {
- fn name(&self) -> &str {
+ pub fn name(&self) -> &str {
use BuiltInWindowFunction::*;
match self {
RowNumber => "ROW_NUMBER",
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 0d8e8d816b..e310eaa7e4 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -669,6 +669,16 @@ impl WindowFunctionDefinition {
WindowFunctionDefinition::WindowUDF(fun) => fun.signature().clone(),
}
}
+
+ /// Function's name for display
+ pub fn name(&self) -> &str {
+ match self {
+ WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.name(),
+ WindowFunctionDefinition::WindowUDF(fun) => fun.name(),
+ WindowFunctionDefinition::AggregateFunction(fun) => fun.name(),
+ WindowFunctionDefinition::AggregateUDF(fun) => fun.name(),
+ }
+ }
}
impl fmt::Display for WindowFunctionDefinition {
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index d091fbe14d..7194b0a7d8 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -21,10 +21,7 @@ use datafusion_common::{
internal_datafusion_err, not_impl_err, plan_err, Column, Result, ScalarValue,
};
use datafusion_expr::{
- expr::{
- AggregateFunctionDefinition, Alias, Exists, InList, ScalarFunction, Sort,
- WindowFunction,
- },
+ expr::{Alias, Exists, InList, ScalarFunction, Sort, WindowFunction},
Between, BinaryExpr, Case, Cast, Expr, Like, Operator,
};
use sqlparser::ast::{
@@ -170,14 +167,56 @@ impl Unparser<'_> {
Expr::Literal(value) => Ok(self.scalar_to_sql(value)?),
Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql(expr),
Expr::WindowFunction(WindowFunction {
- fun: _,
- args: _,
- partition_by: _,
+ fun,
+ args,
+ partition_by,
order_by: _,
- window_frame: _,
+ window_frame,
null_treatment: _,
}) => {
- not_impl_err!("Unsupported expression: {expr:?}")
+ let func_name = fun.name();
+
+ let args = self.function_args_to_sql(args)?;
+
+ let units = match window_frame.units {
+ datafusion_expr::window_frame::WindowFrameUnits::Rows => {
+ ast::WindowFrameUnits::Rows
+ }
+ datafusion_expr::window_frame::WindowFrameUnits::Range => {
+ ast::WindowFrameUnits::Range
+ }
+ datafusion_expr::window_frame::WindowFrameUnits::Groups => {
+ ast::WindowFrameUnits::Groups
+ }
+ };
+ let start_bound = self.convert_bound(&window_frame.start_bound);
+ let end_bound = self.convert_bound(&window_frame.end_bound);
+ let over = Some(ast::WindowType::WindowSpec(ast::WindowSpec {
+ window_name: None,
+ partition_by: partition_by
+ .iter()
+ .map(|e| self.expr_to_sql(e))
+ .collect::<Result<Vec<_>>>()?,
+ order_by: vec![],
+ window_frame: Some(ast::WindowFrame {
+ units,
+ start_bound,
+ end_bound: Option::from(end_bound),
+ }),
+ }));
+ Ok(ast::Expr::Function(Function {
+ name: ast::ObjectName(vec![Ident {
+ value: func_name.to_string(),
+ quote_style: None,
+ }]),
+ args,
+ filter: None,
+ null_treatment: None,
+ over,
+ distinct: false,
+ special: false,
+ order_by: vec![],
+ }))
}
Expr::SimilarTo(Like {
negated,
@@ -199,37 +238,20 @@ impl Unparser<'_> {
escape_char: *escape_char,
}),
Expr::AggregateFunction(agg) => {
- let func_name = if let AggregateFunctionDefinition::BuiltIn(built_in) =
- &agg.func_def
- {
- built_in.name()
- } else {
- return not_impl_err!(
- "Only built in agg functions are supported, got {agg:?}"
- );
- };
-
- let args = agg
- .args
- .iter()
- .map(|e| {
- if matches!(e, Expr::Wildcard { qualifier: None }) {
- Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard))
- } else {
- self.expr_to_sql(e).map(|e| {
- FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e))
- })
- }
- })
- .collect::<Result<Vec<_>>>()?;
+ let func_name = agg.func_def.name();
+ let args = self.function_args_to_sql(&agg.args)?;
+ let filter = match &agg.filter {
+ Some(filter) => Some(Box::new(self.expr_to_sql(filter)?)),
+ None => None,
+ };
Ok(ast::Expr::Function(Function {
name: ast::ObjectName(vec![Ident {
value: func_name.to_string(),
quote_style: None,
}]),
args,
- filter: None,
+ filter,
null_treatment: None,
over: None,
distinct: agg.distinct,
@@ -355,6 +377,40 @@ impl Unparser<'_> {
Ok(ast::Expr::Identifier(self.new_ident(col.name.to_string())))
}
+ fn convert_bound(
+ &self,
+ bound: &datafusion_expr::window_frame::WindowFrameBound,
+ ) -> ast::WindowFrameBound {
+ match bound {
+ datafusion_expr::window_frame::WindowFrameBound::Preceding(val) => {
+ ast::WindowFrameBound::Preceding(
+ self.scalar_to_sql(val).map(Box::new).ok(),
+ )
+ }
+ datafusion_expr::window_frame::WindowFrameBound::Following(val) => {
+ ast::WindowFrameBound::Following(
+ self.scalar_to_sql(val).map(Box::new).ok(),
+ )
+ }
+ datafusion_expr::window_frame::WindowFrameBound::CurrentRow => {
+ ast::WindowFrameBound::CurrentRow
+ }
+ }
+ }
+
+ fn function_args_to_sql(&self, args: &[Expr]) -> Result<Vec<ast::FunctionArg>> {
+ args.iter()
+ .map(|e| {
+ if matches!(e, Expr::Wildcard { qualifier: None }) {
+ Ok(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard))
+ } else {
+ self.expr_to_sql(e)
+ .map(|e| ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)))
+ }
+ })
+ .collect::<Result<Vec<_>>>()
+ }
+
pub(super) fn new_ident(&self, str: String) -> ast::Ident {
ast::Ident {
value: str,
@@ -735,8 +791,10 @@ mod tests {
use arrow::datatypes::{Field, Schema};
use datafusion_common::TableReference;
use datafusion_expr::{
- case, col, exists, expr::AggregateFunction, lit, not, not_exists, table_scan,
- ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
+ case, col, exists,
+ expr::{AggregateFunction, AggregateFunctionDefinition},
+ lit, not, not_exists, table_scan, wildcard, ColumnarValue, ScalarUDF,
+ ScalarUDFImpl, Signature, Volatility, WindowFrame, WindowFunctionDefinition,
};
use crate::unparser::dialect::CustomDialect;
@@ -901,6 +959,53 @@ mod tests {
}),
"COUNT(DISTINCT *)",
),
+ (
+ Expr::AggregateFunction(AggregateFunction {
+ func_def: AggregateFunctionDefinition::BuiltIn(
+ datafusion_expr::AggregateFunction::Count,
+ ),
+ args: vec![Expr::Wildcard { qualifier: None }],
+ distinct: false,
+ filter: Some(Box::new(lit(true))),
+ order_by: None,
+ null_treatment: None,
+ }),
+ "COUNT(*) FILTER (WHERE true)",
+ ),
+ (
+ Expr::WindowFunction(WindowFunction {
+ fun: WindowFunctionDefinition::BuiltInWindowFunction(
+ datafusion_expr::BuiltInWindowFunction::RowNumber,
+ ),
+ args: vec![col("col")],
+ partition_by: vec![],
+ order_by: vec![],
+ window_frame: WindowFrame::new(None),
+ null_treatment: None,
+ }),
+ r#"ROW_NUMBER("col") OVER (ROWS BETWEEN NULL PRECEDING AND NULL FOLLOWING)"#,
+ ),
+ (
+ Expr::WindowFunction(WindowFunction {
+ fun: WindowFunctionDefinition::AggregateFunction(
+ datafusion_expr::AggregateFunction::Count,
+ ),
+ args: vec![wildcard()],
+ partition_by: vec![],
+ order_by: vec![],
+ window_frame: WindowFrame::new_bounds(
+ datafusion_expr::WindowFrameUnits::Range,
+ datafusion_expr::WindowFrameBound::Preceding(
+ ScalarValue::UInt32(Some(6)),
+ ),
+ datafusion_expr::WindowFrameBound::Following(
+ ScalarValue::UInt32(Some(2)),
+ ),
+ ),
+ null_treatment: None,
+ }),
+ r#"COUNT(*) OVER (RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING)"#,
+ ),
(col("a").is_not_null(), r#""a" IS NOT NULL"#),
(
(col("a") + col("b")).gt(lit(4)).is_true(),
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]