jayzhan211 commented on code in PR #11287:
URL: https://github.com/apache/datafusion/pull/11287#discussion_r1667223626
##########
datafusion/proto/tests/cases/roundtrip_physical_plan.rs:
##########
@@ -362,15 +363,17 @@ fn rountrip_aggregate() -> Result<()> {
false,
)?],
// NTH_VALUE
- vec![Arc::new(NthValueAgg::new(
- col("b", &schema)?,
- 1,
- "NTH_VALUE(b, 1)".to_string(),
- DataType::Int64,
+ vec![udaf::create_aggregate_expr(
+ &AggregateUDF::new_from_impl(NthValueAgg::default()),
+ &[col("b", &schema)?, lit(ScalarValue::UInt64(Some(1)))],
Review Comment:
I guess `lit(1u64)` works too
##########
datafusion/functions-aggregate/src/nth_value.rs:
##########
@@ -430,3 +428,176 @@ impl NthValueAccumulator {
Ok(())
}
}
+
+/// This is a wrapper struct to be able to correctly merge `ARRAY_AGG` data
from
Review Comment:
An alternative is to move this to physical-common; once all the functions are
cleaned up, we can move it back to the functions-aggregate crate.
##########
datafusion/proto/tests/cases/roundtrip_physical_plan.rs:
##########
@@ -362,15 +363,17 @@ fn rountrip_aggregate() -> Result<()> {
false,
)?],
// NTH_VALUE
- vec![Arc::new(NthValueAgg::new(
- col("b", &schema)?,
- 1,
- "NTH_VALUE(b, 1)".to_string(),
- DataType::Int64,
+ vec![udaf::create_aggregate_expr(
+ &AggregateUDF::new_from_impl(NthValueAgg::default()),
Review Comment:
we can just use `nth_value_udaf`
##########
datafusion/sql/src/expr/function.rs:
##########
@@ -415,9 +415,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
) -> Result<WindowFunctionDefinition> {
// check udaf first
let udaf = self.context_provider.get_aggregate_meta(name);
- // Skip first value and last value, since we expect window builtin
first/last value not udaf version
+ // Use the builtin window function instead of the user-defined
aggregate function
if udaf.as_ref().is_some_and(|udaf| {
- udaf.name() != "first_value" && udaf.name() != "last_value"
+ udaf.name() != "first_value"
+ && udaf.name() != "last_value"
+ && udaf.name() != "nth_value"
Review Comment:
We may want to replace this with the UDAF when we convert the builtin window functions to UDWFs (#8709).
##########
datafusion/functions-aggregate/src/nth_value.rs:
##########
@@ -19,152 +19,150 @@
//! that can evaluated at runtime during query execution
use std::any::Any;
-use std::collections::VecDeque;
+use std::cmp::Ordering;
+use std::collections::{BinaryHeap, VecDeque};
use std::sync::Arc;
-use crate::aggregate::array_agg_ordered::merge_ordered_arrays;
-use crate::aggregate::utils::{down_cast_any_ref, ordering_fields};
-use crate::expressions::{format_state_name, Literal};
-use crate::{
- reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr,
PhysicalSortExpr,
+use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray};
+use arrow_schema::{DataType, Field, Fields, SortOptions};
+
+use datafusion_common::utils::{
+ array_into_list_array_nullable, compare_rows, get_row_at_idx,
+};
+use datafusion_common::{exec_err, internal_err, not_impl_err, Result,
ScalarValue};
+use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
+use datafusion_expr::utils::format_state_name;
+use datafusion_expr::{
+ Accumulator, AggregateUDF, AggregateUDFImpl, Expr, ReversedUDAF, Signature,
+ Volatility,
+};
+use datafusion_physical_expr_common::aggregate::utils::ordering_fields;
+use datafusion_physical_expr_common::sort_expr::{
+ limited_convert_logical_sort_exprs_to_physical, LexOrdering,
PhysicalSortExpr,
};
-use arrow_array::cast::AsArray;
-use arrow_array::{new_empty_array, ArrayRef, StructArray};
-use arrow_schema::{DataType, Field, Fields};
-use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx};
-use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
-use datafusion_expr::utils::AggregateOrderSensitivity;
-use datafusion_expr::Accumulator;
+make_udaf_expr_and_func!(
+ NthValueAgg,
+ nth_value,
+ "Returns the nth value in a group of values.",
+ nth_value_udaf
+);
/// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. In a multi
/// partition setting, partial aggregations are computed for every partition,
/// and then their results are merged.
#[derive(Debug)]
pub struct NthValueAgg {
- /// Column name
- name: String,
- /// The `DataType` for the input expression
- input_data_type: DataType,
- /// The input expression
- expr: Arc<dyn PhysicalExpr>,
- /// The `N` value.
- n: i64,
- /// If the input expression can have `NULL`s
- nullable: bool,
- /// Ordering data types
- order_by_data_types: Vec<DataType>,
- /// Ordering requirement
- ordering_req: LexOrdering,
+ signature: Signature,
+ /// Determines whether `N` is relative to the beginning or the end
+ /// of the aggregation. When set to `true`, then `N` is from the end.
+ reversed: bool,
}
impl NthValueAgg {
/// Create a new `NthValueAgg` aggregate function
- pub fn new(
- expr: Arc<dyn PhysicalExpr>,
- n: i64,
- name: impl Into<String>,
- input_data_type: DataType,
- nullable: bool,
- order_by_data_types: Vec<DataType>,
- ordering_req: LexOrdering,
- ) -> Self {
+ pub fn new() -> Self {
Self {
- name: name.into(),
- input_data_type,
- expr,
- n,
- nullable,
- order_by_data_types,
- ordering_req,
+ signature: Signature::any(2, Volatility::Immutable),
+ reversed: false,
}
}
+
+ pub fn with_reversed(mut self, reversed: bool) -> Self {
+ self.reversed = reversed;
+ self
+ }
}
-impl AggregateExpr for NthValueAgg {
+impl Default for NthValueAgg {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl AggregateUDFImpl for NthValueAgg {
fn as_any(&self) -> &dyn Any {
self
}
- fn field(&self) -> Result<Field> {
- Ok(Field::new(&self.name, self.input_data_type.clone(), true))
+ fn name(&self) -> &str {
+ "nth_value"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
}
- fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- Ok(Box::new(NthValueAccumulator::try_new(
- self.n,
- &self.input_data_type,
- &self.order_by_data_types,
- self.ordering_req.clone(),
- )?))
+ fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+ Ok(arg_types[0].clone())
+ }
+
+ fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn
Accumulator>> {
+ let n = match acc_args.input_exprs[1] {
+ Expr::Literal(ScalarValue::Int64(Some(value))) => {
+ if self.reversed {
+ Ok(-value)
+ } else {
+ Ok(value)
+ }
+ }
+ _ => not_impl_err!(
+ "{} not supported for n: {}",
+ self.name(),
+ &acc_args.input_exprs[1]
+ ),
+ }?;
+
+ let ordering_req = limited_convert_logical_sort_exprs_to_physical(
+ acc_args.sort_exprs,
+ acc_args.schema,
+ )?;
+
+ let ordering_dtypes = ordering_req
+ .iter()
+ .map(|e| e.expr.data_type(acc_args.schema))
+ .collect::<Result<Vec<_>>>()?;
+
+ NthValueAccumulator::try_new(
+ n,
+ acc_args.input_type,
+ &ordering_dtypes,
+ ordering_req,
+ )
+ .map(|acc| Box::new(acc) as _)
}
- fn state_fields(&self) -> Result<Vec<Field>> {
+ fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let mut fields = vec![Field::new_list(
- format_state_name(&self.name, "nth_value"),
- Field::new("item", self.input_data_type.clone(), true),
- self.nullable, // This should be the same as field()
+ format_state_name(self.name(), "nth_value"),
+ Field::new("item", args.input_type.clone(), true),
+ false,
Review Comment:
I think, given the existing nth function, we should make nullable
configurable. Also, the nullability is actually for the list element. We should
add nullable to `StateFieldsArgs`.
```rust
let mut fields = vec![Field::new_list(
format_state_name(self.name(), "nth_value"),
Field::new("item", args.input_type.clone(), self.nullable),
false)]
```
@eejbyfeldt is working on it in #11063
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]