This is an automated email from the ASF dual-hosted git repository.
dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 634ed28fd7 Support Decimal256 on AVG aggregate expression (#7853)
634ed28fd7 is described below
commit 634ed28fd7a7aedcf48d86d8b578b5fe4e19081a
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Oct 18 22:31:33 2023 -0700
Support Decimal256 on AVG aggregate expression (#7853)
* More
* More
* More
* More
* Fix clippy
---
datafusion/physical-expr/src/aggregate/average.rs | 94 +++++++++++++++++------
datafusion/physical-expr/src/aggregate/utils.rs | 59 ++++++++------
datafusion/sqllogictest/test_files/decimal.slt | 4 +-
3 files changed, 109 insertions(+), 48 deletions(-)
diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
index 92c806f76f..91f2fb952d 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -21,6 +21,7 @@ use arrow::array::{AsArray, PrimitiveBuilder};
use log::debug;
use std::any::Any;
+use std::fmt::Debug;
use std::sync::Arc;
use crate::aggregate::groups_accumulator::accumulate::NullState;
@@ -33,15 +34,17 @@ use arrow::{
array::{ArrayRef, UInt64Array},
datatypes::Field,
};
+use arrow_array::types::{Decimal256Type, DecimalType};
use arrow_array::{
Array, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType,
PrimitiveArray,
};
+use arrow_buffer::{i256, ArrowNativeType};
use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::type_coercion::aggregates::avg_return_type;
use datafusion_expr::Accumulator;
use super::groups_accumulator::EmitTo;
-use super::utils::Decimal128Averager;
+use super::utils::DecimalAverager;
/// AVG aggregate expression
#[derive(Debug, Clone)]
@@ -88,7 +91,19 @@ impl AggregateExpr for Avg {
(
Decimal128(sum_precision, sum_scale),
Decimal128(target_precision, target_scale),
- ) => Ok(Box::new(DecimalAvgAccumulator {
+ ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal128Type> {
+ sum: None,
+ count: 0,
+ sum_scale: *sum_scale,
+ sum_precision: *sum_precision,
+ target_precision: *target_precision,
+ target_scale: *target_scale,
+ })),
+
+ (
+ Decimal256(sum_precision, sum_scale),
+ Decimal256(target_precision, target_scale),
+ ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal256Type> {
sum: None,
count: 0,
sum_scale: *sum_scale,
@@ -156,7 +171,7 @@ impl AggregateExpr for Avg {
Decimal128(_sum_precision, sum_scale),
Decimal128(target_precision, target_scale),
) => {
- let decimal_averager = Decimal128Averager::try_new(
+ let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
*sum_scale,
*target_precision,
*target_scale,
@@ -172,6 +187,27 @@ impl AggregateExpr for Avg {
)))
}
+ (
+ Decimal256(_sum_precision, sum_scale),
+ Decimal256(target_precision, target_scale),
+ ) => {
+ let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
+ *sum_scale,
+ *target_precision,
+ *target_scale,
+ )?;
+
+ let avg_fn = move |sum: i256, count: u64| {
+ decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap())
+ };
+
+ Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
+ &self.input_data_type,
+ &self.result_data_type,
+ avg_fn,
+ )))
+ }
+
_ => not_impl_err!(
"AvgGroupsAccumulator for ({} --> {})",
self.input_data_type,
@@ -256,9 +292,8 @@ impl Accumulator for AvgAccumulator {
}
/// An accumulator to compute the average for decimals
-#[derive(Debug)]
-struct DecimalAvgAccumulator {
- sum: Option<i128>,
+struct DecimalAvgAccumulator<T: DecimalType + ArrowNumericType> {
+ sum: Option<T::Native>,
count: u64,
sum_scale: i8,
sum_precision: u8,
@@ -266,30 +301,46 @@ struct DecimalAvgAccumulator {
target_scale: i8,
}
-impl Accumulator for DecimalAvgAccumulator {
+impl<T: DecimalType + ArrowNumericType> Debug for DecimalAvgAccumulator<T> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("DecimalAvgAccumulator")
+ .field("sum", &self.sum)
+ .field("count", &self.count)
+ .field("sum_scale", &self.sum_scale)
+ .field("sum_precision", &self.sum_precision)
+ .field("target_precision", &self.target_precision)
+ .field("target_scale", &self.target_scale)
+ .finish()
+ }
+}
+
+impl<T: DecimalType + ArrowNumericType> Accumulator for DecimalAvgAccumulator<T> {
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.count),
- ScalarValue::Decimal128(self.sum, self.sum_precision, self.sum_scale),
+ ScalarValue::new_primitive::<T>(
+ self.sum,
+ &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale),
+ )?,
])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- let values = values[0].as_primitive::<Decimal128Type>();
+ let values = values[0].as_primitive::<T>();
self.count += (values.len() - values.null_count()) as u64;
if let Some(x) = sum(values) {
- let v = self.sum.get_or_insert(0);
- *v += x;
+ let v = self.sum.get_or_insert(T::Native::default());
+ self.sum = Some(v.add_wrapping(x));
}
Ok(())
}
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- let values = values[0].as_primitive::<Decimal128Type>();
+ let values = values[0].as_primitive::<T>();
self.count -= (values.len() - values.null_count()) as u64;
if let Some(x) = sum(values) {
- self.sum = Some(self.sum.unwrap() - x);
+ self.sum = Some(self.sum.unwrap().sub_wrapping(x));
}
Ok(())
}
@@ -299,9 +350,9 @@ impl Accumulator for DecimalAvgAccumulator {
self.count +=
sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();
// sums are summed
- if let Some(x) = sum(states[1].as_primitive::<Decimal128Type>()) {
- let v = self.sum.get_or_insert(0);
- *v += x;
+ if let Some(x) = sum(states[1].as_primitive::<T>()) {
+ let v = self.sum.get_or_insert(T::Native::default());
+ self.sum = Some(v.add_wrapping(x));
}
Ok(())
}
@@ -310,20 +361,19 @@ impl Accumulator for DecimalAvgAccumulator {
let v = self
.sum
.map(|v| {
- Decimal128Averager::try_new(
+ DecimalAverager::<T>::try_new(
self.sum_scale,
self.target_precision,
self.target_scale,
)?
- .avg(v, self.count as _)
+ .avg(v, T::Native::from_usize(self.count as usize).unwrap())
})
.transpose()?;
- Ok(ScalarValue::Decimal128(
+ ScalarValue::new_primitive::<T>(
v,
- self.target_precision,
- self.target_scale,
- ))
+ &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
+ )
}
fn supports_retract_batch(&self) -> bool {
true
diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs
index 2f473f7608..420b26eb2d 100644
--- a/datafusion/physical-expr/src/aggregate/utils.rs
+++ b/datafusion/physical-expr/src/aggregate/utils.rs
@@ -19,12 +19,13 @@
use crate::{AggregateExpr, PhysicalSortExpr};
use arrow::array::ArrayRef;
-use arrow::datatypes::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION};
use arrow_array::cast::AsArray;
use arrow_array::types::{
- Decimal128Type, TimestampMicrosecondType, TimestampMillisecondType,
+ Decimal128Type, DecimalType, TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType, TimestampSecondType,
};
+use arrow_array::ArrowNativeTypeOp;
+use arrow_buffer::ArrowNativeType;
use arrow_schema::{DataType, Field};
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr::Accumulator;
@@ -42,27 +43,25 @@ pub fn get_accum_scalar_values_as_arrays(
.collect::<Vec<_>>())
}
-/// Computes averages for `Decimal128` values, checking for overflow
+/// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow
///
-/// This is needed because different precisions for Decimal128 can
+/// This is needed because different precisions for Decimal128/Decimal256 can
/// store different ranges of values and thus sum/count may not fit in
/// the target type.
///
/// For example, the precision is 3, the max of value is `999` and the min
/// value is `-999`
-pub(crate) struct Decimal128Averager {
+pub(crate) struct DecimalAverager<T: DecimalType> {
/// scale factor for sum values (10^sum_scale)
- sum_mul: i128,
+ sum_mul: T::Native,
/// scale factor for target (10^target_scale)
- target_mul: i128,
- /// The minimum output value possible to represent with the target precision
- target_min: i128,
- /// The maximum output value possible to represent with the target precision
- target_max: i128,
+ target_mul: T::Native,
+ /// the output precision
+ target_precision: u8,
}
-impl Decimal128Averager {
- /// Create a new `Decimal128Averager`:
+impl<T: DecimalType> DecimalAverager<T> {
+ /// Create a new `DecimalAverager`:
///
/// * sum_scale: the scale of `sum` values passed to [`Self::avg`]
/// * target_precision: the output precision
@@ -74,17 +73,23 @@ impl Decimal128Averager {
target_precision: u8,
target_scale: i8,
) -> Result<Self> {
- let sum_mul = 10_i128.pow(sum_scale as u32);
- let target_mul = 10_i128.pow(target_scale as u32);
- let target_min = MIN_DECIMAL_FOR_EACH_PRECISION[target_precision as usize - 1];
- let target_max = MAX_DECIMAL_FOR_EACH_PRECISION[target_precision as usize - 1];
+ let sum_mul = T::Native::from_usize(10_usize)
+ .map(|b| b.pow_wrapping(sum_scale as u32))
+ .ok_or(DataFusionError::Internal(
+ "Failed to compute sum_mul in DecimalAverager".to_string(),
+ ))?;
+
+ let target_mul = T::Native::from_usize(10_usize)
+ .map(|b| b.pow_wrapping(target_scale as u32))
+ .ok_or(DataFusionError::Internal(
+ "Failed to compute target_mul in DecimalAverager".to_string(),
+ ))?;
if target_mul >= sum_mul {
Ok(Self {
sum_mul,
target_mul,
- target_min,
- target_max,
+ target_precision,
})
} else {
// can't convert the lit decimal to the returned data type
@@ -92,17 +97,21 @@ impl Decimal128Averager {
}
}
- /// Returns the `sum`/`count` as a i128 Decimal128 with
+ /// Returns the `sum`/`count` as a i128/i256 Decimal128/Decimal256 with
/// target_scale and target_precision and reporting overflow.
///
/// * sum: The total sum value stored as Decimal128 with sum_scale
/// (passed to `Self::try_new`)
- /// * count: total count, stored as a i128 (*NOT* a Decimal128 value)
+ /// * count: total count, stored as a i128/i256 (*NOT* a Decimal128/Decimal256 value)
#[inline(always)]
- pub fn avg(&self, sum: i128, count: i128) -> Result<i128> {
- if let Some(value) = sum.checked_mul(self.target_mul / self.sum_mul) {
- let new_value = value / count;
- if new_value >= self.target_min && new_value <= self.target_max {
+ pub fn avg(&self, sum: T::Native, count: T::Native) -> Result<T::Native> {
+ if let Ok(value) = sum.mul_checked(self.target_mul.div_wrapping(self.sum_mul)) {
+ let new_value = value.div_wrapping(count);
+
+ let validate =
+ T::validate_decimal_precision(new_value, self.target_precision);
+
+ if validate.is_ok() {
Ok(new_value)
} else {
exec_err!("Arithmetic Overflow in AvgAccumulator")
diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt
index d7632138a8..570116b7a2 100644
--- a/datafusion/sqllogictest/test_files/decimal.slt
+++ b/datafusion/sqllogictest/test_files/decimal.slt
@@ -622,8 +622,10 @@ create table t as values (arrow_cast(123, 'Decimal256(5,2)'));
statement ok
set datafusion.execution.target_partitions = 1;
-query error DataFusion error: This feature is not implemented: AvgAccumulator for \(Decimal256\(5, 2\) --> Decimal256\(9, 6\)\)
+query R
select AVG(column1) from t;
+----
+123
statement ok
drop table t;