This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new eef654c3b0 Introduce return type for aggregate sum (#8141)
eef654c3b0 is described below

commit eef654c3b0c22b1f845b1441320b8bb718ddd605
Author: Jay Zhan <[email protected]>
AuthorDate: Tue Nov 14 22:29:47 2023 +0800

    Introduce return type for aggregate sum (#8141)
    
    * introduce return type for aggregate sum
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fmt
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix state field type
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix state field
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 datafusion/physical-expr/src/aggregate/sum.rs      | 26 +++++++++++++---------
 .../physical-expr/src/aggregate/sum_distinct.rs    | 11 +++++----
 2 files changed, 22 insertions(+), 15 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs
index d6c23d0dfa..03f666cc4e 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -41,7 +41,10 @@ use datafusion_expr::Accumulator;
 #[derive(Debug, Clone)]
 pub struct Sum {
     name: String,
+    // The DataType for the input expression
     data_type: DataType,
+    // The DataType for the final sum
+    return_type: DataType,
     expr: Arc<dyn PhysicalExpr>,
     nullable: bool,
 }
@@ -53,11 +56,12 @@ impl Sum {
         name: impl Into<String>,
         data_type: DataType,
     ) -> Self {
-        let data_type = sum_return_type(&data_type).unwrap();
+        let return_type = sum_return_type(&data_type).unwrap();
         Self {
             name: name.into(),
-            expr,
             data_type,
+            return_type,
+            expr,
             nullable: true,
         }
     }
@@ -70,13 +74,13 @@ impl Sum {
 /// `s` is a `Sum`, `helper` is a macro accepting (ArrowPrimitiveType, DataType)
 macro_rules! downcast_sum {
     ($s:ident, $helper:ident) => {
-        match $s.data_type {
-            DataType::UInt64 => $helper!(UInt64Type, $s.data_type),
-            DataType::Int64 => $helper!(Int64Type, $s.data_type),
-            DataType::Float64 => $helper!(Float64Type, $s.data_type),
-            DataType::Decimal128(_, _) => $helper!(Decimal128Type, 
$s.data_type),
-            DataType::Decimal256(_, _) => $helper!(Decimal256Type, 
$s.data_type),
-            _ => not_impl_err!("Sum not supported for {}: {}", $s.name, 
$s.data_type),
+        match $s.return_type {
+            DataType::UInt64 => $helper!(UInt64Type, $s.return_type),
+            DataType::Int64 => $helper!(Int64Type, $s.return_type),
+            DataType::Float64 => $helper!(Float64Type, $s.return_type),
+            DataType::Decimal128(_, _) => $helper!(Decimal128Type, 
$s.return_type),
+            DataType::Decimal256(_, _) => $helper!(Decimal256Type, 
$s.return_type),
+            _ => not_impl_err!("Sum not supported for {}: {}", $s.name, 
$s.return_type),
         }
     };
 }
@@ -91,7 +95,7 @@ impl AggregateExpr for Sum {
     fn field(&self) -> Result<Field> {
         Ok(Field::new(
             &self.name,
-            self.data_type.clone(),
+            self.return_type.clone(),
             self.nullable,
         ))
     }
@@ -108,7 +112,7 @@ impl AggregateExpr for Sum {
     fn state_fields(&self) -> Result<Vec<Field>> {
         Ok(vec![Field::new(
             format_state_name(&self.name, "sum"),
-            self.data_type.clone(),
+            self.return_type.clone(),
             self.nullable,
         )])
     }
diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs
index ef1bd039a5..0cf4a90ab8 100644
--- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs
@@ -40,8 +40,10 @@ use datafusion_expr::Accumulator;
 pub struct DistinctSum {
     /// Column name
     name: String,
-    /// The DataType for the final sum
+    // The DataType for the input expression
     data_type: DataType,
+    // The DataType for the final sum
+    return_type: DataType,
     /// The input arguments, only contains 1 item for sum
     exprs: Vec<Arc<dyn PhysicalExpr>>,
 }
@@ -53,10 +55,11 @@ impl DistinctSum {
         name: String,
         data_type: DataType,
     ) -> Self {
-        let data_type = sum_return_type(&data_type).unwrap();
+        let return_type = sum_return_type(&data_type).unwrap();
         Self {
             name,
             data_type,
+            return_type,
             exprs,
         }
     }
@@ -68,14 +71,14 @@ impl AggregateExpr for DistinctSum {
     }
 
     fn field(&self) -> Result<Field> {
-        Ok(Field::new(&self.name, self.data_type.clone(), true))
+        Ok(Field::new(&self.name, self.return_type.clone(), true))
     }
 
     fn state_fields(&self) -> Result<Vec<Field>> {
         // State field is a List which stores items to rebuild hash set.
         Ok(vec![Field::new_list(
             format_state_name(&self.name, "sum distinct"),
-            Field::new("item", self.data_type.clone(), true),
+            Field::new("item", self.return_type.clone(), true),
             false,
         )])
     }

Reply via email to