jcsherin commented on code in PR #11287:
URL: https://github.com/apache/datafusion/pull/11287#discussion_r1667017516
##########
datafusion/functions-aggregate/src/nth_value.rs:
##########
@@ -19,152 +19,150 @@
//! that can evaluated at runtime during query execution
use std::any::Any;
-use std::collections::VecDeque;
+use std::cmp::Ordering;
+use std::collections::{BinaryHeap, VecDeque};
use std::sync::Arc;
-use crate::aggregate::array_agg_ordered::merge_ordered_arrays;
-use crate::aggregate::utils::{down_cast_any_ref, ordering_fields};
-use crate::expressions::{format_state_name, Literal};
-use crate::{
- reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr,
+use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray};
+use arrow_schema::{DataType, Field, Fields, SortOptions};
+
+use datafusion_common::utils::{
+ array_into_list_array_nullable, compare_rows, get_row_at_idx,
+};
+use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue};
+use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
+use datafusion_expr::utils::format_state_name;
+use datafusion_expr::{
+ Accumulator, AggregateUDF, AggregateUDFImpl, Expr, ReversedUDAF, Signature,
+ Volatility,
+};
+use datafusion_physical_expr_common::aggregate::utils::ordering_fields;
+use datafusion_physical_expr_common::sort_expr::{
+ limited_convert_logical_sort_exprs_to_physical, LexOrdering, PhysicalSortExpr,
};
-use arrow_array::cast::AsArray;
-use arrow_array::{new_empty_array, ArrayRef, StructArray};
-use arrow_schema::{DataType, Field, Fields};
-use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx};
-use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
-use datafusion_expr::utils::AggregateOrderSensitivity;
-use datafusion_expr::Accumulator;
+make_udaf_expr_and_func!(
+ NthValueAgg,
+ nth_value,
+ "Returns the nth value in a group of values.",
+ nth_value_udaf
+);
/// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. In a multi
/// partition setting, partial aggregations are computed for every partition,
/// and then their results are merged.
#[derive(Debug)]
pub struct NthValueAgg {
- /// Column name
- name: String,
- /// The `DataType` for the input expression
- input_data_type: DataType,
- /// The input expression
- expr: Arc<dyn PhysicalExpr>,
- /// The `N` value.
- n: i64,
- /// If the input expression can have `NULL`s
- nullable: bool,
- /// Ordering data types
- order_by_data_types: Vec<DataType>,
- /// Ordering requirement
- ordering_req: LexOrdering,
+ signature: Signature,
+ /// Determines whether `N` is relative to the beginning or the end
+ /// of the aggregation. When set to `true`, then `N` is from the end.
+ reversed: bool,
}
impl NthValueAgg {
/// Create a new `NthValueAgg` aggregate function
- pub fn new(
- expr: Arc<dyn PhysicalExpr>,
- n: i64,
- name: impl Into<String>,
- input_data_type: DataType,
- nullable: bool,
- order_by_data_types: Vec<DataType>,
- ordering_req: LexOrdering,
- ) -> Self {
+ pub fn new() -> Self {
Self {
- name: name.into(),
- input_data_type,
- expr,
- n,
- nullable,
- order_by_data_types,
- ordering_req,
+ signature: Signature::any(2, Volatility::Immutable),
+ reversed: false,
}
}
+
+ pub fn with_reversed(mut self, reversed: bool) -> Self {
+ self.reversed = reversed;
+ self
+ }
}
-impl AggregateExpr for NthValueAgg {
+impl Default for NthValueAgg {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl AggregateUDFImpl for NthValueAgg {
fn as_any(&self) -> &dyn Any {
self
}
- fn field(&self) -> Result<Field> {
- Ok(Field::new(&self.name, self.input_data_type.clone(), true))
+ fn name(&self) -> &str {
+ "nth_value"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
}
- fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- Ok(Box::new(NthValueAccumulator::try_new(
- self.n,
- &self.input_data_type,
- &self.order_by_data_types,
- self.ordering_req.clone(),
- )?))
+ fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+ Ok(arg_types[0].clone())
+ }
+
+ fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+ let n = match acc_args.input_exprs[1] {
+ Expr::Literal(ScalarValue::Int64(Some(value))) => {
+ if self.reversed {
+ Ok(-value)
+ } else {
+ Ok(value)
+ }
+ }
+ _ => not_impl_err!(
+ "{} not supported for n: {}",
+ self.name(),
+ &acc_args.input_exprs[1]
+ ),
+ }?;
+
+ let ordering_req = limited_convert_logical_sort_exprs_to_physical(
+ acc_args.sort_exprs,
+ acc_args.schema,
+ )?;
+
+ let ordering_dtypes = ordering_req
+ .iter()
+ .map(|e| e.expr.data_type(acc_args.schema))
+ .collect::<Result<Vec<_>>>()?;
+
+ NthValueAccumulator::try_new(
+ n,
+ acc_args.input_type,
+ &ordering_dtypes,
+ ordering_req,
+ )
+ .map(|acc| Box::new(acc) as _)
}
- fn state_fields(&self) -> Result<Vec<Field>> {
+ fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let mut fields = vec![Field::new_list(
- format_state_name(&self.name, "nth_value"),
- Field::new("item", self.input_data_type.clone(), true),
- self.nullable, // This should be the same as field()
+ format_state_name(self.name(), "nth_value"),
+ Field::new("item", args.input_type.clone(), true),
+ false,
Review Comment:
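Should `nullable` be configurable here? The previous implementation passed
`self.nullable` (kept in sync with `field()`), whereas this version hard-codes
`false`. However, nullability is not available in `StateFieldsArgs`. I think
this is related to #11274 and #11094.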
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]