This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 1ed6e5138f feat: support named arguments for aggregate and window udfs
(#18389)
1ed6e5138f is described below
commit 1ed6e5138f6e18f2325739af7f03af6fc9611e53
Author: bubulalabu <[email protected]>
AuthorDate: Wed Nov 5 02:46:09 2025 +0100
feat: support named arguments for aggregate and window udfs (#18389)
## Which issue does this PR close?
Addresses portions of https://github.com/apache/datafusion/issues/17379.
## Rationale for this change
Add named-argument support for aggregate and window UDFs in the same way as
was done for scalar UDFs in https://github.com/apache/datafusion/pull/18019
## Are these changes tested?
Yes
## Are there any user-facing changes?
Yes, the changes are user-facing, documented, purely additive and
non-breaking.
---
datafusion/functions-aggregate/src/correlation.rs | 4 +-
.../functions-aggregate/src/percentile_cont.rs | 4 +-
datafusion/functions-window/src/lead_lag.rs | 8 +-
datafusion/sql/src/expr/function.rs | 56 ++++++++-
.../sqllogictest/test_files/named_arguments.slt | 132 +++++++++++++++++++++
.../library-user-guide/functions/adding-udfs.md | 48 +++-----
6 files changed, 210 insertions(+), 42 deletions(-)
diff --git a/datafusion/functions-aggregate/src/correlation.rs
b/datafusion/functions-aggregate/src/correlation.rs
index 20f23662ca..f2a464de41 100644
--- a/datafusion/functions-aggregate/src/correlation.rs
+++ b/datafusion/functions-aggregate/src/correlation.rs
@@ -88,7 +88,9 @@ impl Correlation {
signature: Signature::exact(
vec![DataType::Float64, DataType::Float64],
Volatility::Immutable,
- ),
+ )
+ .with_parameter_names(vec!["y".to_string(), "x".to_string()])
+ .expect("valid parameter names for corr"),
}
}
}
diff --git a/datafusion/functions-aggregate/src/percentile_cont.rs
b/datafusion/functions-aggregate/src/percentile_cont.rs
index 7ef0f8baf0..1e06461e56 100644
--- a/datafusion/functions-aggregate/src/percentile_cont.rs
+++ b/datafusion/functions-aggregate/src/percentile_cont.rs
@@ -146,7 +146,9 @@ impl PercentileCont {
variants.push(TypeSignature::Exact(vec![num.clone(),
DataType::Float64]));
}
Self {
- signature: Signature::one_of(variants, Volatility::Immutable),
+ signature: Signature::one_of(variants, Volatility::Immutable)
+ .with_parameter_names(vec!["expr".to_string(),
"percentile".to_string()])
+ .expect("valid parameter names for percentile_cont"),
aliases: vec![String::from("quantile_cont")],
}
}
diff --git a/datafusion/functions-window/src/lead_lag.rs
b/datafusion/functions-window/src/lead_lag.rs
index 3910a0be57..02d7fc290b 100644
--- a/datafusion/functions-window/src/lead_lag.rs
+++ b/datafusion/functions-window/src/lead_lag.rs
@@ -137,7 +137,13 @@ impl WindowShift {
TypeSignature::Any(3),
],
Volatility::Immutable,
- ),
+ )
+ .with_parameter_names(vec![
+ "expr".to_string(),
+ "offset".to_string(),
+ "default".to_string(),
+ ])
+ .expect("valid parameter names for lead/lag"),
kind,
}
}
diff --git a/datafusion/sql/src/expr/function.rs
b/datafusion/sql/src/expr/function.rs
index 2d20aaf523..50e479af36 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -386,7 +386,30 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
};
if let Ok(fun) = self.find_window_func(&name) {
- let args = self.function_args_to_expr(args, schema,
planner_context)?;
+ let (args, arg_names) =
+ self.function_args_to_expr_with_names(args, schema,
planner_context)?;
+
+ let resolved_args = if arg_names.iter().any(|name|
name.is_some()) {
+ let signature = match &fun {
+ WindowFunctionDefinition::AggregateUDF(udaf) =>
udaf.signature(),
+ WindowFunctionDefinition::WindowUDF(udwf) =>
udwf.signature(),
+ };
+
+ if let Some(param_names) = &signature.parameter_names {
+ datafusion_expr::arguments::resolve_function_arguments(
+ param_names,
+ args,
+ arg_names,
+ )?
+ } else {
+ return plan_err!(
+ "Window function '{}' does not support named
arguments",
+ name
+ );
+ }
+ } else {
+ args
+ };
// Plan FILTER clause if present
let filter = filter
@@ -396,7 +419,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
let mut window_expr = RawWindowExpr {
func_def: fun,
- args,
+ args: resolved_args,
partition_by,
order_by,
window_frame,
@@ -464,8 +487,8 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
);
}
- let mut args =
- self.function_args_to_expr(args, schema, planner_context)?;
+ let (mut args, mut arg_names) =
+ self.function_args_to_expr_with_names(args, schema,
planner_context)?;
let order_by = if fm.supports_within_group_clause() {
let within_group = self.order_by_to_sort_expr(
@@ -479,6 +502,12 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
// Add the WITHIN GROUP ordering expressions to the front
of the argument list
// So function(arg) WITHIN GROUP (ORDER BY x) becomes
function(x, arg)
if !within_group.is_empty() {
+ // Prepend None arg names for each WITHIN GROUP
expression
+ let within_group_count = within_group.len();
+ arg_names = std::iter::repeat_n(None,
within_group_count)
+ .chain(arg_names)
+ .collect();
+
args = within_group
.iter()
.map(|sort| sort.expr.clone())
@@ -506,9 +535,26 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
.transpose()?
.map(Box::new);
+ let resolved_args = if arg_names.iter().any(|name|
name.is_some()) {
+ if let Some(param_names) = &fm.signature().parameter_names
{
+ datafusion_expr::arguments::resolve_function_arguments(
+ param_names,
+ args,
+ arg_names,
+ )?
+ } else {
+ return plan_err!(
+ "Aggregate function '{}' does not support named
arguments",
+ fm.name()
+ );
+ }
+ } else {
+ args
+ };
+
let mut aggregate_expr = RawAggregateExpr {
func: fm,
- args,
+ args: resolved_args,
distinct,
filter,
order_by,
diff --git a/datafusion/sqllogictest/test_files/named_arguments.slt
b/datafusion/sqllogictest/test_files/named_arguments.slt
index c93da7e7a8..4eab799fd2 100644
--- a/datafusion/sqllogictest/test_files/named_arguments.slt
+++ b/datafusion/sqllogictest/test_files/named_arguments.slt
@@ -137,3 +137,135 @@ SELECT substr(str => 'hello world', start_pos => 7,
length => 5);
# Reset to default dialect
statement ok
set datafusion.sql_parser.dialect = 'Generic';
+
+#############
+## Aggregate UDF Tests - using corr(y, x) function
+#############
+
+# Setup test data
+statement ok
+CREATE TABLE correlation_test(col1 DOUBLE, col2 DOUBLE) AS VALUES
+ (1.0, 2.0),
+ (2.0, 4.0),
+ (3.0, 6.0),
+ (4.0, 8.0);
+
+# Test positional arguments (baseline)
+query R
+SELECT corr(col1, col2) FROM correlation_test;
+----
+1
+
+# Test named arguments out of order (proves named args work for aggregates)
+query R
+SELECT corr(x => col2, y => col1) FROM correlation_test;
+----
+1
+
+# Error: function doesn't support named arguments (count has no parameter
names)
+query error DataFusion error: Error during planning: Aggregate function
'count' does not support named arguments
+SELECT count(value => col1) FROM correlation_test;
+
+# Cleanup
+statement ok
+DROP TABLE correlation_test;
+
+#############
+## Aggregate UDF with WITHIN GROUP Tests - using percentile_cont(expression,
percentile)
+## This tests the special handling where WITHIN GROUP ORDER BY expressions are
prepended to args
+#############
+
+# Setup test data
+statement ok
+CREATE TABLE percentile_test(salary DOUBLE) AS VALUES
+ (50000.0),
+ (60000.0),
+ (70000.0),
+ (80000.0),
+ (90000.0);
+
+# Test positional arguments (baseline) - standard call without WITHIN GROUP
+query R
+SELECT percentile_cont(salary, 0.5) FROM percentile_test;
+----
+70000
+
+# Test WITHIN GROUP with positional argument
+query R
+SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY salary) FROM
percentile_test;
+----
+70000
+
+# Test WITHIN GROUP with named argument for percentile
+# The ORDER BY expression (salary) is prepended internally, becoming:
percentile_cont(salary, 0.5)
+# We use named argument for percentile, which should work correctly
+query R
+SELECT percentile_cont(percentile => 0.5) WITHIN GROUP (ORDER BY salary) FROM
percentile_test;
+----
+70000
+
+# Verify the WITHIN GROUP prepending logic with different percentile value
+query R
+SELECT percentile_cont(percentile => 0.25) WITHIN GROUP (ORDER BY salary) FROM
percentile_test;
+----
+60000
+
+# Cleanup
+statement ok
+DROP TABLE percentile_test;
+
+#############
+## Window UDF Tests - using lead(expression, offset, default) function
+#############
+
+# Setup test data
+statement ok
+CREATE TABLE window_test(id INT, value INT) AS VALUES
+ (1, 10),
+ (2, 20),
+ (3, 30),
+ (4, 40);
+
+# Test positional arguments (baseline)
+query II
+SELECT id, lead(value, 1, 0) OVER (ORDER BY id) FROM window_test ORDER BY id;
+----
+1 20
+2 30
+3 40
+4 0
+
+# Test named arguments out of order (proves named args work for window
functions)
+query II
+SELECT id, lead(default => 0, offset => 1, expr => value) OVER (ORDER BY id)
FROM window_test ORDER BY id;
+----
+1 20
+2 30
+3 40
+4 0
+
+# Test with 1 argument (offset and default use defaults)
+query II
+SELECT id, lead(expr => value) OVER (ORDER BY id) FROM window_test ORDER BY id;
+----
+1 20
+2 30
+3 40
+4 NULL
+
+# Test with 2 arguments (default uses default)
+query II
+SELECT id, lead(expr => value, offset => 2) OVER (ORDER BY id) FROM
window_test ORDER BY id;
+----
+1 30
+2 40
+3 NULL
+4 NULL
+
+# Error: function doesn't support named arguments (row_number has no parameter
names)
+query error DataFusion error: Error during planning: Window function
'row_number' does not support named arguments
+SELECT row_number(value => 1) OVER (ORDER BY id) FROM window_test;
+
+# Cleanup
+statement ok
+DROP TABLE window_test;
diff --git a/docs/source/library-user-guide/functions/adding-udfs.md
b/docs/source/library-user-guide/functions/adding-udfs.md
index 7581d8b650..e56790a4b7 100644
--- a/docs/source/library-user-guide/functions/adding-udfs.md
+++ b/docs/source/library-user-guide/functions/adding-udfs.md
@@ -588,10 +588,17 @@ For async UDF implementation details, see
[`async_udf.rs`](https://github.com/ap
## Named Arguments
-DataFusion supports PostgreSQL-style named arguments for scalar functions,
allowing you to pass arguments by parameter name:
+DataFusion supports named arguments for Scalar, Window, and Aggregate UDFs,
allowing you to pass arguments by parameter name:
```sql
+-- Scalar function
SELECT substr(str => 'hello', start_pos => 2, length => 3);
+
+-- Window function
+SELECT lead(expr => value, offset => 1) OVER (ORDER BY id) FROM table;
+
+-- Aggregate function
+SELECT corr(y => col1, x => col2) FROM table;
```
Named arguments can be mixed with positional arguments, but positional
arguments must come first:
@@ -602,38 +609,7 @@ SELECT substr('hello', start_pos => 2, length => 3); --
Valid
### Implementing Functions with Named Arguments
-To support named arguments in your UDF, add parameter names to your function's
signature using `.with_parameter_names()`:
-
-```rust
-# use arrow::datatypes::DataType;
-# use datafusion_expr::{Signature, Volatility};
-#
-# #[derive(Debug)]
-# struct MyFunction {
-# signature: Signature,
-# }
-#
-impl MyFunction {
- fn new() -> Self {
- Self {
- signature: Signature::uniform(
- 2,
- vec![DataType::Float64],
- Volatility::Immutable
- )
- .with_parameter_names(vec![
- "base".to_string(),
- "exponent".to_string()
- ])
- .expect("valid parameter names"),
- }
- }
-}
-```
-
-The parameter names should match the order of arguments in your function's
signature. DataFusion automatically resolves named arguments to the correct
positional order before invoking your function.
-
-### Example
+To support named arguments in your UDF, add parameter names to your function's
signature using `.with_parameter_names()`. This works the same way for Scalar,
Window, and Aggregate UDFs:
```rust
# use std::sync::Arc;
@@ -681,10 +657,14 @@ impl ScalarUDFImpl for PowerFunction {
}
```
-Once registered, users can call your function with named arguments:
+The parameter names should match the order of arguments in your function's
signature. DataFusion automatically resolves named arguments to the correct
positional order before invoking your function.
+
+Once registered, users can call your functions with named arguments in any
order:
```sql
+-- All equivalent
SELECT power(base => 2.0, exponent => 3.0);
+SELECT power(exponent => 3.0, base => 2.0);
SELECT power(2.0, exponent => 3.0);
```
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]