This is an automated email from the ASF dual-hosted git repository.
dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new aabd7ad416 Performance: Use a specialized sum accumulator for retractable aggregates (#6888)
aabd7ad416 is described below
commit aabd7ad416a18deaa0e426000789cc9ba7cd1c56
Author: Andrew Lamb <[email protected]>
AuthorDate: Sun Jul 9 03:00:59 2023 -0400
Performance: Use a specialized sum accumulator for retractable aggregates (#6888)
---
datafusion/physical-expr/src/aggregate/sum.rs | 73 +++++++++++++++++++++------
1 file changed, 57 insertions(+), 16 deletions(-)
diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs
index efa55f0602..29996eaf5c 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-//! Defines physical expressions that can evaluated at runtime during query execution
+//! Defines `SUM` and `SUM DISTINCT` aggregate accumulators
use std::any::Any;
use std::convert::TryFrom;
@@ -105,18 +105,11 @@ impl AggregateExpr for Sum {
}
fn state_fields(&self) -> Result<Vec<Field>> {
- Ok(vec![
- Field::new(
- format_state_name(&self.name, "sum"),
- self.data_type.clone(),
- self.nullable,
- ),
- Field::new(
- format_state_name(&self.name, "count"),
- DataType::UInt64,
- self.nullable,
- ),
- ])
+ Ok(vec![Field::new(
+ format_state_name(&self.name, "sum"),
+ self.data_type.clone(),
+ self.nullable,
+ )])
}
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
@@ -146,7 +139,7 @@ impl AggregateExpr for Sum {
}
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- Ok(Box::new(SumAccumulator::try_new(&self.data_type)?))
+ Ok(Box::new(SlidingSumAccumulator::try_new(&self.data_type)?))
}
}
@@ -164,10 +157,10 @@ impl PartialEq<dyn Any> for Sum {
}
}
+/// This accumulator computes SUM incrementally
#[derive(Debug)]
struct SumAccumulator {
sum: ScalarValue,
- count: u64,
}
impl SumAccumulator {
@@ -175,12 +168,32 @@ impl SumAccumulator {
pub fn try_new(data_type: &DataType) -> Result<Self> {
Ok(Self {
sum: ScalarValue::try_from(data_type)?,
+ })
+ }
+}
+
+/// This accumulator incrementally computes sums over a sliding window
+#[derive(Debug)]
+struct SlidingSumAccumulator {
+ sum: ScalarValue,
+ count: u64,
+}
+
+impl SlidingSumAccumulator {
+ /// new sum accumulator
+ pub fn try_new(data_type: &DataType) -> Result<Self> {
+ Ok(Self {
+ // start at zero
+ sum: ScalarValue::try_from(data_type)?,
count: 0,
})
}
}
-// returns the new value after sum with the new values, taking nullability into account
+/// Sums the contents of the `$VALUES` array using the arrow compute
+/// kernel, and return a `ScalarValue::$SCALAR`.
+///
+/// Handles nullability
macro_rules! typed_sum_delta_batch {
($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{
let array = downcast_value!($VALUES, $ARRAYTYPE);
@@ -322,6 +335,34 @@ pub(crate) fn update_avg_to_row(
}
impl Accumulator for SumAccumulator {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![self.sum.clone()])
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let values = &values[0];
+ let delta = sum_batch(values, &self.sum.get_datatype())?;
+ self.sum = self.sum.add(&delta)?;
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ // sum(sum1, sum2, sum3, ...) = sum1 + sum2 + sum3 + ...
+ self.update_batch(states)
+ }
+
+ fn evaluate(&self) -> Result<ScalarValue> {
+ // TODO: add the checker for overflow
+ // For the decimal(precision,_) data type, the absolute of value must be less than 10^precision.
+ Ok(self.sum.clone())
+ }
+
+ fn size(&self) -> usize {
+ std::mem::size_of_val(self) - std::mem::size_of_val(&self.sum) +
self.sum.size()
+ }
+}
+
+impl Accumulator for SlidingSumAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.sum.clone(), ScalarValue::from(self.count)])
}