This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 634ed28fd7 Support Decimal256 on AVG aggregate expression (#7853)
634ed28fd7 is described below

commit 634ed28fd7a7aedcf48d86d8b578b5fe4e19081a
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Oct 18 22:31:33 2023 -0700

    Support Decimal256 on AVG aggregate expression (#7853)
    
    * More
    
    * More
    
    * More
    
    * More
    
    * Fix clippy
---
 datafusion/physical-expr/src/aggregate/average.rs | 94 +++++++++++++++++------
 datafusion/physical-expr/src/aggregate/utils.rs   | 59 ++++++++------
 datafusion/sqllogictest/test_files/decimal.slt    |  4 +-
 3 files changed, 109 insertions(+), 48 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
index 92c806f76f..91f2fb952d 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -21,6 +21,7 @@ use arrow::array::{AsArray, PrimitiveBuilder};
 use log::debug;
 
 use std::any::Any;
+use std::fmt::Debug;
 use std::sync::Arc;
 
 use crate::aggregate::groups_accumulator::accumulate::NullState;
@@ -33,15 +34,17 @@ use arrow::{
     array::{ArrayRef, UInt64Array},
     datatypes::Field,
 };
+use arrow_array::types::{Decimal256Type, DecimalType};
 use arrow_array::{
     Array, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, 
PrimitiveArray,
 };
+use arrow_buffer::{i256, ArrowNativeType};
 use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue};
 use datafusion_expr::type_coercion::aggregates::avg_return_type;
 use datafusion_expr::Accumulator;
 
 use super::groups_accumulator::EmitTo;
-use super::utils::Decimal128Averager;
+use super::utils::DecimalAverager;
 
 /// AVG aggregate expression
 #[derive(Debug, Clone)]
@@ -88,7 +91,19 @@ impl AggregateExpr for Avg {
             (
                 Decimal128(sum_precision, sum_scale),
                 Decimal128(target_precision, target_scale),
-            ) => Ok(Box::new(DecimalAvgAccumulator {
+            ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal128Type> {
+                sum: None,
+                count: 0,
+                sum_scale: *sum_scale,
+                sum_precision: *sum_precision,
+                target_precision: *target_precision,
+                target_scale: *target_scale,
+            })),
+
+            (
+                Decimal256(sum_precision, sum_scale),
+                Decimal256(target_precision, target_scale),
+            ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal256Type> {
                 sum: None,
                 count: 0,
                 sum_scale: *sum_scale,
@@ -156,7 +171,7 @@ impl AggregateExpr for Avg {
                 Decimal128(_sum_precision, sum_scale),
                 Decimal128(target_precision, target_scale),
             ) => {
-                let decimal_averager = Decimal128Averager::try_new(
+                let decimal_averager = 
DecimalAverager::<Decimal128Type>::try_new(
                     *sum_scale,
                     *target_precision,
                     *target_scale,
@@ -172,6 +187,27 @@ impl AggregateExpr for Avg {
                 )))
             }
 
+            (
+                Decimal256(_sum_precision, sum_scale),
+                Decimal256(target_precision, target_scale),
+            ) => {
+                let decimal_averager = 
DecimalAverager::<Decimal256Type>::try_new(
+                    *sum_scale,
+                    *target_precision,
+                    *target_scale,
+                )?;
+
+                let avg_fn = move |sum: i256, count: u64| {
+                    decimal_averager.avg(sum, i256::from_usize(count as 
usize).unwrap())
+                };
+
+                Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
+                    &self.input_data_type,
+                    &self.result_data_type,
+                    avg_fn,
+                )))
+            }
+
             _ => not_impl_err!(
                 "AvgGroupsAccumulator for ({} --> {})",
                 self.input_data_type,
@@ -256,9 +292,8 @@ impl Accumulator for AvgAccumulator {
 }
 
 /// An accumulator to compute the average for decimals
-#[derive(Debug)]
-struct DecimalAvgAccumulator {
-    sum: Option<i128>,
+struct DecimalAvgAccumulator<T: DecimalType + ArrowNumericType> {
+    sum: Option<T::Native>,
     count: u64,
     sum_scale: i8,
     sum_precision: u8,
@@ -266,30 +301,46 @@ struct DecimalAvgAccumulator {
     target_scale: i8,
 }
 
-impl Accumulator for DecimalAvgAccumulator {
+impl<T: DecimalType + ArrowNumericType> Debug for DecimalAvgAccumulator<T> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("DecimalAvgAccumulator")
+            .field("sum", &self.sum)
+            .field("count", &self.count)
+            .field("sum_scale", &self.sum_scale)
+            .field("sum_precision", &self.sum_precision)
+            .field("target_precision", &self.target_precision)
+            .field("target_scale", &self.target_scale)
+            .finish()
+    }
+}
+
+impl<T: DecimalType + ArrowNumericType> Accumulator for 
DecimalAvgAccumulator<T> {
     fn state(&self) -> Result<Vec<ScalarValue>> {
         Ok(vec![
             ScalarValue::from(self.count),
-            ScalarValue::Decimal128(self.sum, self.sum_precision, 
self.sum_scale),
+            ScalarValue::new_primitive::<T>(
+                self.sum,
+                &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale),
+            )?,
         ])
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        let values = values[0].as_primitive::<Decimal128Type>();
+        let values = values[0].as_primitive::<T>();
 
         self.count += (values.len() - values.null_count()) as u64;
         if let Some(x) = sum(values) {
-            let v = self.sum.get_or_insert(0);
-            *v += x;
+            let v = self.sum.get_or_insert(T::Native::default());
+            self.sum = Some(v.add_wrapping(x));
         }
         Ok(())
     }
 
     fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        let values = values[0].as_primitive::<Decimal128Type>();
+        let values = values[0].as_primitive::<T>();
         self.count -= (values.len() - values.null_count()) as u64;
         if let Some(x) = sum(values) {
-            self.sum = Some(self.sum.unwrap() - x);
+            self.sum = Some(self.sum.unwrap().sub_wrapping(x));
         }
         Ok(())
     }
@@ -299,9 +350,9 @@ impl Accumulator for DecimalAvgAccumulator {
         self.count += 
sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
 
         // sums are summed
-        if let Some(x) = sum(states[1].as_primitive::<Decimal128Type>()) {
-            let v = self.sum.get_or_insert(0);
-            *v += x;
+        if let Some(x) = sum(states[1].as_primitive::<T>()) {
+            let v = self.sum.get_or_insert(T::Native::default());
+            self.sum = Some(v.add_wrapping(x));
         }
         Ok(())
     }
@@ -310,20 +361,19 @@ impl Accumulator for DecimalAvgAccumulator {
         let v = self
             .sum
             .map(|v| {
-                Decimal128Averager::try_new(
+                DecimalAverager::<T>::try_new(
                     self.sum_scale,
                     self.target_precision,
                     self.target_scale,
                 )?
-                .avg(v, self.count as _)
+                .avg(v, T::Native::from_usize(self.count as usize).unwrap())
             })
             .transpose()?;
 
-        Ok(ScalarValue::Decimal128(
+        ScalarValue::new_primitive::<T>(
             v,
-            self.target_precision,
-            self.target_scale,
-        ))
+            &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
+        )
     }
     fn supports_retract_batch(&self) -> bool {
         true
diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs
index 2f473f7608..420b26eb2d 100644
--- a/datafusion/physical-expr/src/aggregate/utils.rs
+++ b/datafusion/physical-expr/src/aggregate/utils.rs
@@ -19,12 +19,13 @@
 
 use crate::{AggregateExpr, PhysicalSortExpr};
 use arrow::array::ArrayRef;
-use arrow::datatypes::{MAX_DECIMAL_FOR_EACH_PRECISION, 
MIN_DECIMAL_FOR_EACH_PRECISION};
 use arrow_array::cast::AsArray;
 use arrow_array::types::{
-    Decimal128Type, TimestampMicrosecondType, TimestampMillisecondType,
+    Decimal128Type, DecimalType, TimestampMicrosecondType, 
TimestampMillisecondType,
     TimestampNanosecondType, TimestampSecondType,
 };
+use arrow_array::ArrowNativeTypeOp;
+use arrow_buffer::ArrowNativeType;
 use arrow_schema::{DataType, Field};
 use datafusion_common::{exec_err, DataFusionError, Result};
 use datafusion_expr::Accumulator;
@@ -42,27 +43,25 @@ pub fn get_accum_scalar_values_as_arrays(
         .collect::<Vec<_>>())
 }
 
-/// Computes averages for `Decimal128` values, checking for overflow
+/// Computes averages for `Decimal128`/`Decimal256` values, checking for 
overflow
 ///
-/// This is needed because different precisions for Decimal128 can
+/// This is needed because different precisions for Decimal128/Decimal256 can
 /// store different ranges of values and thus sum/count may not fit in
 /// the target type.
 ///
 /// For example, the precision is 3, the max of value is `999` and the min
 /// value is `-999`
-pub(crate) struct Decimal128Averager {
+pub(crate) struct DecimalAverager<T: DecimalType> {
     /// scale factor for sum values (10^sum_scale)
-    sum_mul: i128,
+    sum_mul: T::Native,
     /// scale factor for target (10^target_scale)
-    target_mul: i128,
-    /// The minimum output value possible to represent with the target 
precision
-    target_min: i128,
-    /// The maximum output value possible to represent with the target 
precision
-    target_max: i128,
+    target_mul: T::Native,
+    /// the output precision
+    target_precision: u8,
 }
 
-impl Decimal128Averager {
-    /// Create a new `Decimal128Averager`:
+impl<T: DecimalType> DecimalAverager<T> {
+    /// Create a new `DecimalAverager`:
     ///
     /// * sum_scale: the scale of `sum` values passed to [`Self::avg`]
     /// * target_precision: the output precision
@@ -74,17 +73,23 @@ impl Decimal128Averager {
         target_precision: u8,
         target_scale: i8,
     ) -> Result<Self> {
-        let sum_mul = 10_i128.pow(sum_scale as u32);
-        let target_mul = 10_i128.pow(target_scale as u32);
-        let target_min = MIN_DECIMAL_FOR_EACH_PRECISION[target_precision as 
usize - 1];
-        let target_max = MAX_DECIMAL_FOR_EACH_PRECISION[target_precision as 
usize - 1];
+        let sum_mul = T::Native::from_usize(10_usize)
+            .map(|b| b.pow_wrapping(sum_scale as u32))
+            .ok_or(DataFusionError::Internal(
+                "Failed to compute sum_mul in DecimalAverager".to_string(),
+            ))?;
+
+        let target_mul = T::Native::from_usize(10_usize)
+            .map(|b| b.pow_wrapping(target_scale as u32))
+            .ok_or(DataFusionError::Internal(
+                "Failed to compute target_mul in DecimalAverager".to_string(),
+            ))?;
 
         if target_mul >= sum_mul {
             Ok(Self {
                 sum_mul,
                 target_mul,
-                target_min,
-                target_max,
+                target_precision,
             })
         } else {
             // can't convert the lit decimal to the returned data type
@@ -92,17 +97,21 @@ impl Decimal128Averager {
         }
     }
 
-    /// Returns the `sum`/`count` as a i128 Decimal128 with
+    /// Returns the `sum`/`count` as a i128/i256 Decimal128/Decimal256 with
     /// target_scale and target_precision and reporting overflow.
     ///
     /// * sum: The total sum value stored as Decimal128 with sum_scale
     /// (passed to `Self::try_new`)
-    /// * count: total count, stored as a i128 (*NOT* a Decimal128 value)
+    /// * count: total count, stored as a i128/i256 (*NOT* a 
Decimal128/Decimal256 value)
     #[inline(always)]
-    pub fn avg(&self, sum: i128, count: i128) -> Result<i128> {
-        if let Some(value) = sum.checked_mul(self.target_mul / self.sum_mul) {
-            let new_value = value / count;
-            if new_value >= self.target_min && new_value <= self.target_max {
+    pub fn avg(&self, sum: T::Native, count: T::Native) -> Result<T::Native> {
+        if let Ok(value) = 
sum.mul_checked(self.target_mul.div_wrapping(self.sum_mul)) {
+            let new_value = value.div_wrapping(count);
+
+            let validate =
+                T::validate_decimal_precision(new_value, 
self.target_precision);
+
+            if validate.is_ok() {
                 Ok(new_value)
             } else {
                 exec_err!("Arithmetic Overflow in AvgAccumulator")
diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt
index d7632138a8..570116b7a2 100644
--- a/datafusion/sqllogictest/test_files/decimal.slt
+++ b/datafusion/sqllogictest/test_files/decimal.slt
@@ -622,8 +622,10 @@ create table t as values (arrow_cast(123, 'Decimal256(5,2)'));
 statement ok
 set datafusion.execution.target_partitions = 1;
 
-query error DataFusion error: This feature is not implemented: AvgAccumulator 
for \(Decimal256\(5, 2\) --> Decimal256\(9, 6\)\)
+query R
 select AVG(column1) from t;
+----
+123
 
 statement ok
 drop table t;

Reply via email to