alamb commented on code in PR #23248: URL: https://github.com/apache/datafusion/pull/23248#discussion_r3501128887
########## datafusion-examples/examples/udf/struct_returning_udaf.rs: ########## @@ -0,0 +1,338 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! This example shows how an extension can return window metadata from an +//! aggregate without using zero-argument aggregate functions. 
+ +use std::sync::Arc; + +use arrow::array::{ + ArrayRef, Float64Array, IntervalMonthDayNanoArray, StructArray, + TimestampNanosecondArray, UInt64Array, +}; +use arrow::datatypes::{ + DataType, Field, Fields, IntervalMonthDayNano, IntervalUnit, Schema, TimeUnit, +}; +use arrow::record_batch::RecordBatch; +use datafusion::assert_batches_eq; +use datafusion::common::{cast::as_primitive_array, exec_err}; +use datafusion::datasource::MemTable; +use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_expr::{ + AccumulatorFactoryFunction, ColumnarValue, Volatility, create_udaf, create_udf, +}; +use datafusion::physical_plan::Accumulator; +use datafusion::prelude::*; +use datafusion::scalar::ScalarValue; + +pub async fn struct_returning_udaf() -> Result<()> { + let ctx = create_context()?; + + register_session_window(&ctx); + register_augmented_avg(&ctx); + + let sql = " + SELECT + augmented_avg(time, value)['window_start'] AS window_start, + augmented_avg(time, value)['window_end'] AS window_end, + augmented_avg(time, value)['window_duration'] AS window_duration, + augmented_avg(time, value)['avg_value'] AS avg_value + FROM t + GROUP BY session_window(time, INTERVAL '5 microseconds') Review Comment: I think you can change this to use `date_bin` (the built-in function) and simplify the example rather than defining a new UDF ########## datafusion-examples/examples/udf/struct_returning_udaf.rs: ########## @@ -0,0 +1,338 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! This example shows how an extension can return window metadata from an +//! aggregate without using zero-argument aggregate functions. + +use std::sync::Arc; + +use arrow::array::{ + ArrayRef, Float64Array, IntervalMonthDayNanoArray, StructArray, + TimestampNanosecondArray, UInt64Array, +}; +use arrow::datatypes::{ + DataType, Field, Fields, IntervalMonthDayNano, IntervalUnit, Schema, TimeUnit, +}; +use arrow::record_batch::RecordBatch; +use datafusion::assert_batches_eq; +use datafusion::common::{cast::as_primitive_array, exec_err}; +use datafusion::datasource::MemTable; +use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_expr::{ + AccumulatorFactoryFunction, ColumnarValue, Volatility, create_udaf, create_udf, +}; +use datafusion::physical_plan::Accumulator; +use datafusion::prelude::*; +use datafusion::scalar::ScalarValue; + +pub async fn struct_returning_udaf() -> Result<()> { + let ctx = create_context()?; + + register_session_window(&ctx); + register_augmented_avg(&ctx); + + let sql = " + SELECT + augmented_avg(time, value)['window_start'] AS window_start, + augmented_avg(time, value)['window_end'] AS window_end, + augmented_avg(time, value)['window_duration'] AS window_duration, + augmented_avg(time, value)['avg_value'] AS avg_value + FROM t + GROUP BY session_window(time, INTERVAL '5 microseconds') + ORDER BY window_start + "; + + let results = ctx.sql(sql).await?.collect().await?; + let expected = [ + 
"+----------------------------+----------------------------+-----------------+-----------+", + "| window_start | window_end | window_duration | avg_value |", + "+----------------------------+----------------------------+-----------------+-----------+", + "| 1970-01-01T00:00:00.000001 | 1970-01-01T00:00:00.000002 | 1000 | 15.0 |", + "| 1970-01-01T00:00:00.000005 | 1970-01-01T00:00:00.000009 | 4000 | 3.0 |", + "+----------------------------+----------------------------+-----------------+-----------+", + ]; + assert_batches_eq!(expected, &results); + + println!("Struct-returning aggregate produced window metadata:"); + ctx.sql(sql).await?.show().await?; + + Ok(()) +} + +fn create_context() -> Result<SessionContext> { + let schema = Arc::new(Schema::new(vec![ + Field::new( + "time", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + Field::new("value", DataType::Float64, false), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(TimestampNanosecondArray::from(vec![ + 1000, 2000, 5000, 7000, 9000, + ])) as ArrayRef, + Arc::new(Float64Array::from(vec![10.0, 20.0, 1.0, 3.0, 5.0])), + ], + )?; + + let ctx = SessionContext::new(); + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + Ok(ctx) +} + +fn register_session_window(ctx: &SessionContext) { + // Minimal stand-in for extension-specific window assignment. Real + // extensions can replace this with session, hopping, or other grouping + // logic while keeping the aggregate shape shown below. 
+ let session_window = create_udf( + "session_window", + vec![ + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Interval(IntervalUnit::MonthDayNano), + ], + DataType::UInt64, + Volatility::Immutable, + Arc::new(|args: &[ColumnarValue]| { + let [ColumnarValue::Array(times), width] = args else { + return exec_err!( + "session_window expects timestamp array and interval width" + ); + }; + let times = + as_primitive_array::<arrow::datatypes::TimestampNanosecondType>(times)?; + let width = match width { + ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some(width))) => { + interval_width_nanos(*width)? + } + ColumnarValue::Array(widths) => { + let widths = widths + .as_any() + .downcast_ref::<IntervalMonthDayNanoArray>() + .ok_or_else(|| { + DataFusionError::Execution( + "Expected IntervalMonthDayNanoArray".to_string(), + ) + })?; + interval_width_nanos(widths.value(0))? + } + other => { + return exec_err!( + "session_window expected MonthDayNano interval width, got {other:?}" + ); + } + }; + + let window_ids = times + .iter() + .map(|time| time.map(|time| (time as u64 / width) + 1)) + .collect::<UInt64Array>(); + + Ok(ColumnarValue::Array(Arc::new(window_ids) as ArrayRef)) + }), + ); + + ctx.register_udf(session_window); +} + +fn interval_width_nanos(width: IntervalMonthDayNano) -> Result<u64> { + if width.months != 0 || width.days != 0 || width.nanoseconds <= 0 { + return exec_err!( + "session_window expected a positive sub-day interval, got {width:?}" + ); + } + Ok(width.nanoseconds as u64) +} + +fn register_augmented_avg(ctx: &SessionContext) { + let accumulator: AccumulatorFactoryFunction = + Arc::new(|_| Ok(Box::new(AugmentedAvg::new()))); + + let augmented_avg = create_udaf( + "augmented_avg", + vec![ + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Float64, + ], + Arc::new(AugmentedAvg::output_datatype()), + Volatility::Immutable, + accumulator, + Arc::new(AugmentedAvg::state_datatypes()), + ); + + 
ctx.register_udaf(augmented_avg); +} + +#[derive(Debug, Clone)] +struct AugmentedAvg { Review Comment: a few comments to help readers read this function and understand what it is doing would be nice ########## datafusion-examples/examples/udf/struct_returning_udaf.rs: ########## @@ -0,0 +1,338 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! This example shows how an extension can return window metadata from an +//! aggregate without using zero-argument aggregate functions. 
+ +use std::sync::Arc; + +use arrow::array::{ + ArrayRef, Float64Array, IntervalMonthDayNanoArray, StructArray, + TimestampNanosecondArray, UInt64Array, +}; +use arrow::datatypes::{ + DataType, Field, Fields, IntervalMonthDayNano, IntervalUnit, Schema, TimeUnit, +}; +use arrow::record_batch::RecordBatch; +use datafusion::assert_batches_eq; +use datafusion::common::{cast::as_primitive_array, exec_err}; +use datafusion::datasource::MemTable; +use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_expr::{ + AccumulatorFactoryFunction, ColumnarValue, Volatility, create_udaf, create_udf, +}; +use datafusion::physical_plan::Accumulator; +use datafusion::prelude::*; +use datafusion::scalar::ScalarValue; + +pub async fn struct_returning_udaf() -> Result<()> { + let ctx = create_context()?; + + register_session_window(&ctx); + register_augmented_avg(&ctx); + + let sql = " + SELECT + augmented_avg(time, value)['window_start'] AS window_start, Review Comment: nice! ########## datafusion/core/tests/user_defined/user_defined_aggregates.rs: ########## @@ -203,6 +204,97 @@ async fn test_udaf_returning_struct_subquery() { "); } +/// Demonstrates the alternative shape suggested in #16453: pass real input Review Comment: since this is already run in the example, I don't think there is extra value in copying the same code in the integration tests (in other words please remove this copy) ########## datafusion-examples/examples/udf/struct_returning_udaf.rs: ########## @@ -0,0 +1,338 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! This example shows how an extension can return window metadata from an +//! aggregate without using zero-argument aggregate functions. + +use std::sync::Arc; + +use arrow::array::{ + ArrayRef, Float64Array, IntervalMonthDayNanoArray, StructArray, + TimestampNanosecondArray, UInt64Array, +}; +use arrow::datatypes::{ + DataType, Field, Fields, IntervalMonthDayNano, IntervalUnit, Schema, TimeUnit, +}; +use arrow::record_batch::RecordBatch; +use datafusion::assert_batches_eq; +use datafusion::common::{cast::as_primitive_array, exec_err}; +use datafusion::datasource::MemTable; +use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_expr::{ + AccumulatorFactoryFunction, ColumnarValue, Volatility, create_udaf, create_udf, +}; +use datafusion::physical_plan::Accumulator; +use datafusion::prelude::*; +use datafusion::scalar::ScalarValue; + +pub async fn struct_returning_udaf() -> Result<()> { + let ctx = create_context()?; + + register_session_window(&ctx); + register_augmented_avg(&ctx); + + let sql = " + SELECT + augmented_avg(time, value)['window_start'] AS window_start, + augmented_avg(time, value)['window_end'] AS window_end, + augmented_avg(time, value)['window_duration'] AS window_duration, + augmented_avg(time, value)['avg_value'] AS avg_value + FROM t + GROUP BY session_window(time, INTERVAL '5 microseconds') + ORDER BY window_start + "; + + let results = ctx.sql(sql).await?.collect().await?; + let expected = [ + 
"+----------------------------+----------------------------+-----------------+-----------+", + "| window_start | window_end | window_duration | avg_value |", + "+----------------------------+----------------------------+-----------------+-----------+", + "| 1970-01-01T00:00:00.000001 | 1970-01-01T00:00:00.000002 | 1000 | 15.0 |", + "| 1970-01-01T00:00:00.000005 | 1970-01-01T00:00:00.000009 | 4000 | 3.0 |", + "+----------------------------+----------------------------+-----------------+-----------+", + ]; + assert_batches_eq!(expected, &results); + + println!("Struct-returning aggregate produced window metadata:"); + ctx.sql(sql).await?.show().await?; + + Ok(()) +} + +fn create_context() -> Result<SessionContext> { + let schema = Arc::new(Schema::new(vec![ + Field::new( + "time", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + Field::new("value", DataType::Float64, false), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(TimestampNanosecondArray::from(vec![ + 1000, 2000, 5000, 7000, 9000, + ])) as ArrayRef, + Arc::new(Float64Array::from(vec![10.0, 20.0, 1.0, 3.0, 5.0])), + ], + )?; + + let ctx = SessionContext::new(); + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + Ok(ctx) +} + +fn register_session_window(ctx: &SessionContext) { Review Comment: See comment below -- I think we can avoid defining a `session-window` function and use `date_bin` (that is already provided in DataFusion) directly ########## datafusion-examples/examples/udf/struct_returning_udaf.rs: ########## @@ -0,0 +1,338 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! This example shows how an extension can return window metadata from an +//! aggregate without using zero-argument aggregate functions. + +use std::sync::Arc; + +use arrow::array::{ + ArrayRef, Float64Array, IntervalMonthDayNanoArray, StructArray, + TimestampNanosecondArray, UInt64Array, +}; +use arrow::datatypes::{ + DataType, Field, Fields, IntervalMonthDayNano, IntervalUnit, Schema, TimeUnit, +}; +use arrow::record_batch::RecordBatch; +use datafusion::assert_batches_eq; +use datafusion::common::{cast::as_primitive_array, exec_err}; +use datafusion::datasource::MemTable; +use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_expr::{ + AccumulatorFactoryFunction, ColumnarValue, Volatility, create_udaf, create_udf, +}; +use datafusion::physical_plan::Accumulator; +use datafusion::prelude::*; +use datafusion::scalar::ScalarValue; + +pub async fn struct_returning_udaf() -> Result<()> { + let ctx = create_context()?; + + register_session_window(&ctx); + register_augmented_avg(&ctx); + + let sql = " Review Comment: It might help to give some context as a comment here -- for example ```rust // The `augmented_avg` window function returns both the average as // well as the information about the window from which it was computed ``` ########## docs/source/library-user-guide/functions/adding-udfs.md: ########## @@ -1229,6 
+1229,37 @@ The `create_udaf` has six arguments to check: - The fifth argument is the function implementation. This is the function that we defined above. - The sixth argument is the description of the state, which will by passed between execution stages. +### Returning multiple values from an Aggregate UDF + +An aggregate UDF can return a `DataType::Struct` when one aggregate result needs +to carry multiple values. This is useful for selector-style functions and for +time-windowing extensions that need to return metadata such as the window start, +window end, and the aggregate value together. + +Prefer passing the relevant input columns to the aggregate instead of using a +zero-argument aggregate. The input columns give the accumulator enough Review Comment: since you can't make a zero column aggregate this is a strange wording (of course you have to prefer the struct approach as the zero arg aggregate will fail) ########## docs/source/library-user-guide/functions/adding-udfs.md: ########## @@ -1229,6 +1229,37 @@ The `create_udaf` has six arguments to check: - The fifth argument is the function implementation. This is the function that we defined above. - The sixth argument is the description of the state, which will by passed between execution stages. +### Returning multiple values from an Aggregate UDF + +An aggregate UDF can return a `DataType::Struct` when one aggregate result needs +to carry multiple values. This is useful for selector-style functions and for Review Comment: I don't think "selector style" functions is a common term -- so maybe either 1. link to definition like https://docs.influxdata.com/influxdb/v2/query-data/influxql/functions/selectors/ 2. (my preference) remove the reference to selector-style and just say "useful for windowing extensions ...." 
########## docs/source/library-user-guide/functions/adding-udfs.md: ########## @@ -1229,6 +1229,37 @@ The `create_udaf` has six arguments to check: - The fifth argument is the function implementation. This is the function that we defined above. - The sixth argument is the description of the state, which will by passed between execution stages. +### Returning multiple values from an Aggregate UDF + +An aggregate UDF can return a `DataType::Struct` when one aggregate result needs +to carry multiple values. This is useful for selector-style functions and for +time-windowing extensions that need to return metadata such as the window start, +window end, and the aggregate value together. + +Prefer passing the relevant input columns to the aggregate instead of using a +zero-argument aggregate. The input columns give the accumulator enough +information to update and merge state normally in multi-stage aggregate plans. +For example, a windowing extension can group rows with a scalar UDF and return +metadata from a struct-returning aggregate. In this example, `session_window` Review Comment: as above, let's simplify by removing session_window and replacing with date_bin -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
