This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 52f340b78a Support `GroupsAccumulator` for Avg duration (#15748)
52f340b78a is described below

commit 52f340b78a07811ba5f2ce24c3463078acf8d4c0
Author: Shruti Sharma <[email protected]>
AuthorDate: Wed May 21 16:55:37 2025 +0530

    Support `GroupsAccumulator` for Avg duration (#15748)
    
    * Support GroupsAccumulator for avg duration
    
    * update test
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 .../functions-aggregate-common/src/accumulator.rs  |  9 ++++-
 datafusion/functions-aggregate/src/average.rs      | 41 +++++++++++++++++++++-
 datafusion/sqllogictest/test_files/aggregate.slt   | 22 ++++++++++++
 3 files changed, 70 insertions(+), 2 deletions(-)

diff --git a/datafusion/functions-aggregate-common/src/accumulator.rs b/datafusion/functions-aggregate-common/src/accumulator.rs
index f67e2f49dc..37bbd1508c 100644
--- a/datafusion/functions-aggregate-common/src/accumulator.rs
+++ b/datafusion/functions-aggregate-common/src/accumulator.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::datatypes::{Field, Schema};
+use arrow::datatypes::{DataType, Field, Schema};
 use datafusion_common::Result;
 use datafusion_expr_common::accumulator::Accumulator;
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
@@ -71,6 +71,13 @@ pub struct AccumulatorArgs<'a> {
     pub exprs: &'a [Arc<dyn PhysicalExpr>],
 }
 
+impl AccumulatorArgs<'_> {
+    /// Returns the return type of the aggregate function.
+    pub fn return_type(&self) -> &DataType {
+        self.return_field.data_type()
+    }
+}
+
 /// Factory that returns an accumulator for the given aggregate function.
 pub type AccumulatorFactoryFunction =
     Arc<dyn Fn(AccumulatorArgs) -> Result<Box<dyn Accumulator>> + Send + Sync>;
diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs
index 15b5db2d72..3ca39aa315 100644
--- a/datafusion/functions-aggregate/src/average.rs
+++ b/datafusion/functions-aggregate/src/average.rs
@@ -182,7 +182,7 @@ impl AggregateUDFImpl for Avg {
     fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
         matches!(
             args.return_field.data_type(),
-            DataType::Float64 | DataType::Decimal128(_, _)
+            DataType::Float64 | DataType::Decimal128(_, _) | DataType::Duration(_)
         )
     }
 
@@ -243,6 +243,45 @@ impl AggregateUDFImpl for Avg {
                 )))
             }
 
+            (Duration(time_unit), Duration(_result_unit)) => {
+                let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64);
+
+                match time_unit {
+                    TimeUnit::Second => Ok(Box::new(AvgGroupsAccumulator::<
+                        DurationSecondType,
+                        _,
+                    >::new(
+                        &data_type,
+                        args.return_type(),
+                        avg_fn,
+                    ))),
+                    TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::<
+                        DurationMillisecondType,
+                        _,
+                    >::new(
+                        &data_type,
+                        args.return_type(),
+                        avg_fn,
+                    ))),
+                    TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::<
+                        DurationMicrosecondType,
+                        _,
+                    >::new(
+                        &data_type,
+                        args.return_type(),
+                        avg_fn,
+                    ))),
+                    TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::<
+                        DurationNanosecondType,
+                        _,
+                    >::new(
+                        &data_type,
+                        args.return_type(),
+                        avg_fn,
+                    ))),
+                }
+            }
+
             _ => not_impl_err!(
                 "AvgGroupsAccumulator for ({} --> {})",
                 &data_type,
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 19f92ed72e..41ce15d794 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -5098,6 +5098,28 @@ FROM d WHERE column1 IS NOT NULL;
 statement ok
 drop table d;
 
+statement ok
+create table dn as values
+  (arrow_cast(10, 'Duration(Second)'), 'a', 1),
+  (arrow_cast(20, 'Duration(Second)'), 'a', 2),
+  (NULL, 'b', 1),
+  (arrow_cast(40, 'Duration(Second)'), 'b', 2),
+  (arrow_cast(50, 'Duration(Second)'), 'c', 1),
+  (NULL, 'c', 2);
+
+query T?I
+SELECT column2, avg(column1), column3 FROM dn GROUP BY column2, column3 ORDER BY column2, column3;
+----
+a 0 days 0 hours 0 mins 10 secs 1
+a 0 days 0 hours 0 mins 20 secs 2
+b NULL 1
+b 0 days 0 hours 0 mins 40 secs 2
+c 0 days 0 hours 0 mins 50 secs 1
+c NULL 2
+
+statement ok
+drop table dn;
+
 # Prepare the table with dictionary values for testing
 statement ok
 CREATE TABLE value(x bigint) AS VALUES (1), (2), (3), (1), (3), (4), (5), (2);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to