This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new eef654c3b0 Introduce return type for aggregate sum (#8141)
eef654c3b0 is described below
commit eef654c3b0c22b1f845b1441320b8bb718ddd605
Author: Jay Zhan <[email protected]>
AuthorDate: Tue Nov 14 22:29:47 2023 +0800
Introduce return type for aggregate sum (#8141)
* introduce return type for aggregate sum
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* fix state field type
Signed-off-by: jayzhan211 <[email protected]>
* fix state field
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
datafusion/physical-expr/src/aggregate/sum.rs | 26 +++++++++++++---------
.../physical-expr/src/aggregate/sum_distinct.rs | 11 +++++----
2 files changed, 22 insertions(+), 15 deletions(-)
diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs
index d6c23d0dfa..03f666cc4e 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -41,7 +41,10 @@ use datafusion_expr::Accumulator;
#[derive(Debug, Clone)]
pub struct Sum {
name: String,
+ /// The DataType for the input expression
data_type: DataType,
+ /// The DataType for the final sum
+ return_type: DataType,
expr: Arc<dyn PhysicalExpr>,
nullable: bool,
}
@@ -53,11 +56,12 @@ impl Sum {
name: impl Into<String>,
data_type: DataType,
) -> Self {
- let data_type = sum_return_type(&data_type).unwrap();
+ let return_type = sum_return_type(&data_type).unwrap();
Self {
name: name.into(),
- expr,
data_type,
+ return_type,
+ expr,
nullable: true,
}
}
@@ -70,13 +74,13 @@ impl Sum {
/// `s` is a `Sum`, `helper` is a macro accepting (ArrowPrimitiveType, DataType)
macro_rules! downcast_sum {
($s:ident, $helper:ident) => {
- match $s.data_type {
- DataType::UInt64 => $helper!(UInt64Type, $s.data_type),
- DataType::Int64 => $helper!(Int64Type, $s.data_type),
- DataType::Float64 => $helper!(Float64Type, $s.data_type),
- DataType::Decimal128(_, _) => $helper!(Decimal128Type,
$s.data_type),
- DataType::Decimal256(_, _) => $helper!(Decimal256Type,
$s.data_type),
- _ => not_impl_err!("Sum not supported for {}: {}", $s.name,
$s.data_type),
+ match $s.return_type {
+ DataType::UInt64 => $helper!(UInt64Type, $s.return_type),
+ DataType::Int64 => $helper!(Int64Type, $s.return_type),
+ DataType::Float64 => $helper!(Float64Type, $s.return_type),
+ DataType::Decimal128(_, _) => $helper!(Decimal128Type,
$s.return_type),
+ DataType::Decimal256(_, _) => $helper!(Decimal256Type,
$s.return_type),
+ _ => not_impl_err!("Sum not supported for {}: {}", $s.name,
$s.return_type),
}
};
}
@@ -91,7 +95,7 @@ impl AggregateExpr for Sum {
fn field(&self) -> Result<Field> {
Ok(Field::new(
&self.name,
- self.data_type.clone(),
+ self.return_type.clone(),
self.nullable,
))
}
@@ -108,7 +112,7 @@ impl AggregateExpr for Sum {
fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![Field::new(
format_state_name(&self.name, "sum"),
- self.data_type.clone(),
+ self.return_type.clone(),
self.nullable,
)])
}
diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs
index ef1bd039a5..0cf4a90ab8 100644
--- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs
@@ -40,8 +40,10 @@ use datafusion_expr::Accumulator;
pub struct DistinctSum {
/// Column name
name: String,
- /// The DataType for the final sum
+ /// The DataType for the input expression
data_type: DataType,
+ /// The DataType for the final sum
+ return_type: DataType,
/// The input arguments, only contains 1 item for sum
exprs: Vec<Arc<dyn PhysicalExpr>>,
}
@@ -53,10 +55,11 @@ impl DistinctSum {
name: String,
data_type: DataType,
) -> Self {
- let data_type = sum_return_type(&data_type).unwrap();
+ let return_type = sum_return_type(&data_type).unwrap();
Self {
name,
data_type,
+ return_type,
exprs,
}
}
@@ -68,14 +71,14 @@ impl AggregateExpr for DistinctSum {
}
fn field(&self) -> Result<Field> {
- Ok(Field::new(&self.name, self.data_type.clone(), true))
+ Ok(Field::new(&self.name, self.return_type.clone(), true))
}
fn state_fields(&self) -> Result<Vec<Field>> {
// State field is a List which stores items to rebuild hash set.
Ok(vec![Field::new_list(
format_state_name(&self.name, "sum distinct"),
- Field::new("item", self.data_type.clone(), true),
+ Field::new("item", self.return_type.clone(), true),
false,
)])
}