This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 77e0e3b8b4 Fixes missing `nth_value` UDAF expr function  (#12279)
77e0e3b8b4 is described below

commit 77e0e3b8b4df21eede9f1e3d1f8ee7709681a2d4
Author: jcsherin <[email protected]>
AuthorDate: Mon Sep 2 18:24:14 2024 +0530

    Fixes missing `nth_value` UDAF expr function  (#12279)
    
    * Makes `nth_value` expression API public
    
    * Updates type of `order_by` parameter
---
 datafusion/functions-aggregate/src/lib.rs          |  1 +
 datafusion/functions-aggregate/src/nth_value.rs    | 28 ++++++++++++++++------
 .../proto/tests/cases/roundtrip_logical_plan.rs    | 13 ++++++++++
 3 files changed, 35 insertions(+), 7 deletions(-)

diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs
index b54cd181a0..ca0276d326 100644
--- a/datafusion/functions-aggregate/src/lib.rs
+++ b/datafusion/functions-aggregate/src/lib.rs
@@ -113,6 +113,7 @@ pub mod expr_fn {
     pub use super::median::median;
     pub use super::min_max::max;
     pub use super::min_max::min;
+    pub use super::nth_value::nth_value;
     pub use super::regr::regr_avgx;
     pub use super::regr::regr_avgy;
     pub use super::regr::regr_count;
diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs
index 7425bdfa18..bbfe56914c 100644
--- a/datafusion/functions-aggregate/src/nth_value.rs
+++ b/datafusion/functions-aggregate/src/nth_value.rs
@@ -30,19 +30,33 @@ use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValu
 use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
 use datafusion_expr::utils::format_state_name;
 use datafusion_expr::{
-    Accumulator, AggregateUDFImpl, ReversedUDAF, Signature, Volatility,
+    lit, Accumulator, AggregateUDFImpl, ExprFunctionExt, ReversedUDAF, Signature,
+    SortExpr, Volatility,
 };
 use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays;
 use datafusion_functions_aggregate_common::utils::ordering_fields;
 use datafusion_physical_expr::expressions::Literal;
 use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
 
-make_udaf_expr_and_func!(
-    NthValueAgg,
-    nth_value,
-    "Returns the nth value in a group of values.",
-    nth_value_udaf
-);
+create_func!(NthValueAgg, nth_value_udaf);
+
+/// Returns the nth value in a group of values.
+pub fn nth_value(
+    expr: datafusion_expr::Expr,
+    n: i64,
+    order_by: Vec<SortExpr>,
+) -> datafusion_expr::Expr {
+    let args = vec![expr, lit(n)];
+    if !order_by.is_empty() {
+        nth_value_udaf()
+            .call(args)
+            .order_by(order_by)
+            .build()
+            .unwrap()
+    } else {
+        nth_value_udaf().call(args)
+    }
+}
 
 /// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. In a multi
 /// partition setting, partial aggregations are computed for every partition,
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index e174d1b507..994ed8ad23 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -71,6 +71,7 @@ use datafusion_expr::{
 use datafusion_functions_aggregate::average::avg_udaf;
 use datafusion_functions_aggregate::expr_fn::{
     approx_distinct, array_agg, avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr,
+    nth_value,
 };
 use datafusion_functions_aggregate::string_agg::string_agg;
 use datafusion_proto::bytes::{
@@ -903,6 +904,18 @@ async fn roundtrip_expr_api() -> Result<()> {
             vec![lit(10), lit(20), lit(30)],
         ),
         row_number(),
+        nth_value(col("b"), 1, vec![]),
+        nth_value(
+            col("b"),
+            1,
+            vec![col("a").sort(false, false), col("b").sort(true, false)],
+        ),
+        nth_value(col("b"), -1, vec![]),
+        nth_value(
+            col("b"),
+            -1,
+            vec![col("a").sort(false, false), col("b").sort(true, false)],
+        ),
     ];
 
     // ensure expressions created with the expr api can be round tripped


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to