This is an automated email from the ASF dual-hosted git repository.

comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new f2328870f feat: Support ANSI mode sum expr (int inputs) (#2600)
f2328870f is described below

commit f2328870fb6a15cfa1b19cba11614c970b0d4171
Author: B Vadlamani <[email protected]>
AuthorDate: Tue Dec 23 14:06:00 2025 -0800

    feat: Support ANSI mode sum expr (int inputs) (#2600)
---
 docs/source/user-guide/latest/compatibility.md     |   3 +-
 native/core/src/execution/planner.rs               |   7 +
 native/spark-expr/src/agg_funcs/mod.rs             |   2 +
 native/spark-expr/src/agg_funcs/sum_int.rs         | 589 +++++++++++++++++++++
 .../scala/org/apache/comet/serde/aggregates.scala  |  15 +-
 .../apache/comet/exec/CometAggregateSuite.scala    | 277 ++++++++--
 .../spark/sql/comet/CometPlanStabilitySuite.scala  |   3 +-
 7 files changed, 840 insertions(+), 56 deletions(-)

diff --git a/docs/source/user-guide/latest/compatibility.md 
b/docs/source/user-guide/latest/compatibility.md
index 58dd8d6ab..3d2c9a7b5 100644
--- a/docs/source/user-guide/latest/compatibility.md
+++ b/docs/source/user-guide/latest/compatibility.md
@@ -32,12 +32,11 @@ Comet has the following limitations when reading Parquet 
files:
 
 ## ANSI Mode
 
-Comet will fall back to Spark for the following expressions when ANSI mode is 
enabled. Thes expressions can be enabled by setting
+Comet will fall back to Spark for the following expressions when ANSI mode is 
enabled. These expressions can be enabled by setting
 `spark.comet.expression.EXPRNAME.allowIncompatible=true`, where `EXPRNAME` is 
the Spark expression class name. See
 the [Comet Supported Expressions Guide](expressions.md) for more information 
on this configuration setting.
 
 - Average
-- Sum
 - Cast (in some cases)
 
 There is an [epic](https://github.com/apache/datafusion-comet/issues/313) 
where we are tracking the work to fully implement ANSI support.
diff --git a/native/core/src/execution/planner.rs 
b/native/core/src/execution/planner.rs
index 56de19d67..8e8191dd0 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -71,6 +71,7 @@ use datafusion::{
 use datafusion_comet_spark_expr::{
     create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, 
BinaryOutputStyle,
     BloomFilterAgg, BloomFilterMightContain, EvalMode, SparkHour, SparkMinute, 
SparkSecond,
+    SumInteger,
 };
 use iceberg::expr::Bind;
 
@@ -1813,6 +1814,12 @@ impl PhysicalPlanner {
                             
AggregateUDF::new_from_impl(SumDecimal::try_new(datatype, eval_mode)?);
                         AggregateExprBuilder::new(Arc::new(func), vec![child])
                     }
+                    DataType::Int8 | DataType::Int16 | DataType::Int32 | 
DataType::Int64 => {
+                        let eval_mode = 
from_protobuf_eval_mode(expr.eval_mode)?;
+                        let func =
+                            
AggregateUDF::new_from_impl(SumInteger::try_new(datatype, eval_mode)?);
+                        AggregateExprBuilder::new(Arc::new(func), vec![child])
+                    }
                     _ => {
                         // cast to the result data type of SUM if necessary, 
we should not expect
                         // a cast failure since it should have already been 
checked at Spark side
diff --git a/native/spark-expr/src/agg_funcs/mod.rs 
b/native/spark-expr/src/agg_funcs/mod.rs
index 252da7889..b1027153e 100644
--- a/native/spark-expr/src/agg_funcs/mod.rs
+++ b/native/spark-expr/src/agg_funcs/mod.rs
@@ -21,6 +21,7 @@ mod correlation;
 mod covariance;
 mod stddev;
 mod sum_decimal;
+mod sum_int;
 mod variance;
 
 pub use avg::Avg;
@@ -29,4 +30,5 @@ pub use correlation::Correlation;
 pub use covariance::Covariance;
 pub use stddev::Stddev;
 pub use sum_decimal::SumDecimal;
+pub use sum_int::SumInteger;
 pub use variance::Variance;
diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs 
b/native/spark-expr/src/agg_funcs/sum_int.rs
new file mode 100644
index 000000000..d226c5ede
--- /dev/null
+++ b/native/spark-expr/src/agg_funcs/sum_int.rs
@@ -0,0 +1,589 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::{arithmetic_overflow_error, EvalMode};
+use arrow::array::{
+    as_primitive_array, cast::AsArray, Array, ArrayRef, ArrowNativeTypeOp, 
ArrowPrimitiveType,
+    BooleanArray, Int64Array, PrimitiveArray,
+};
+use arrow::datatypes::{
+    ArrowNativeType, DataType, Field, FieldRef, Int16Type, Int32Type, 
Int64Type, Int8Type,
+};
+use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue};
+use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
+use datafusion::logical_expr::Volatility::Immutable;
+use datafusion::logical_expr::{
+    Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, 
Signature,
+};
+use std::{any::Any, sync::Arc};
+
+/// Aggregate UDF named `"sum"` for `Int8`/`Int16`/`Int32`/`Int64` inputs.
+///
+/// The result type is always `Int64`; overflow handling is selected by
+/// `eval_mode` and implemented by the accumulators this UDF creates.
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SumInteger {
+    /// User-defined, immutable signature handed back to the planner.
+    signature: Signature,
+    /// Evaluation mode forwarded to every accumulator created by this UDF.
+    eval_mode: EvalMode,
+}
+
+impl SumInteger {
+    /// Builds the UDF for the given input type.
+    ///
+    /// # Errors
+    /// Returns `DataFusionError::Internal` when `data_type` is not one of
+    /// `Int8`, `Int16`, `Int32` or `Int64`.
+    pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult<Self> {
+        let is_signed_integer = matches!(
+            data_type,
+            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
+        );
+        if !is_signed_integer {
+            return Err(DataFusionError::Internal(
+                "Invalid data type for SumInteger".into(),
+            ));
+        }
+        Ok(Self {
+            signature: Signature::user_defined(Immutable),
+            eval_mode,
+        })
+    }
+}
+
+impl AggregateUDFImpl for SumInteger {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "sum"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// The sum is always reported as `Int64`, whatever the input width.
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Int64)
+    }
+
+    /// Creates the row-at-a-time (non-grouped) accumulator.
+    fn accumulator(&self, _acc_args: AccumulatorArgs) -> DFResult<Box<dyn Accumulator>> {
+        let accumulator = SumIntegerAccumulator::new(self.eval_mode);
+        Ok(Box::new(accumulator))
+    }
+
+    /// Describes the intermediate state columns.
+    ///
+    /// Every mode carries a nullable `sum`; Try mode appends a non-nullable
+    /// `has_all_nulls` flag after it.
+    fn state_fields(&self, _args: StateFieldsArgs) -> DFResult<Vec<FieldRef>> {
+        let mut fields: Vec<FieldRef> = vec![Arc::new(Field::new("sum", DataType::Int64, true))];
+        if self.eval_mode == EvalMode::Try {
+            fields.push(Arc::new(Field::new(
+                "has_all_nulls",
+                DataType::Boolean,
+                false,
+            )));
+        }
+        Ok(fields)
+    }
+
+    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
+        true
+    }
+
+    /// Creates the vectorized per-group accumulator.
+    fn create_groups_accumulator(
+        &self,
+        _args: AccumulatorArgs,
+    ) -> DFResult<Box<dyn GroupsAccumulator>> {
+        let accumulator = SumIntGroupsAccumulator::new(self.eval_mode);
+        Ok(Box::new(accumulator))
+    }
+
+    fn reverse_expr(&self) -> ReversedUDAF {
+        ReversedUDAF::Identical
+    }
+}
+
+#[derive(Debug)]
+struct SumIntegerAccumulator {
+    sum: Option<i64>,
+    eval_mode: EvalMode,
+    has_all_nulls: bool,
+}
+
+impl SumIntegerAccumulator {
+    fn new(eval_mode: EvalMode) -> Self {
+        if eval_mode == EvalMode::Try {
+            Self {
+                // Try mode starts with 0 (because if this is init to None we 
cant say if it is none due to all nulls or due to an overflow)
+                sum: Some(0),
+                has_all_nulls: true,
+                eval_mode,
+            }
+        } else {
+            Self {
+                sum: None,
+                has_all_nulls: false,
+                eval_mode,
+            }
+        }
+    }
+}
+
+impl Accumulator for SumIntegerAccumulator {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
+        // accumulator internal to add sum and return null sum (and has_nulls 
false) if there is an overflow in Try Eval mode
+        fn update_sum_internal<T>(
+            int_array: &PrimitiveArray<T>,
+            eval_mode: EvalMode,
+            mut sum: i64,
+        ) -> Result<Option<i64>, DataFusionError>
+        where
+            T: ArrowPrimitiveType,
+        {
+            for i in 0..int_array.len() {
+                if !int_array.is_null(i) {
+                    let v = int_array.value(i).to_i64().ok_or_else(|| {
+                        DataFusionError::Internal(format!(
+                            "Failed to convert value {:?} to i64",
+                            int_array.value(i)
+                        ))
+                    })?;
+                    match eval_mode {
+                        EvalMode::Legacy => {
+                            sum = v.add_wrapping(sum);
+                        }
+                        EvalMode::Ansi | EvalMode::Try => {
+                            match v.add_checked(sum) {
+                                Ok(v) => sum = v,
+                                Err(_e) => {
+                                    return if eval_mode == EvalMode::Ansi {
+                                        
Err(DataFusionError::from(arithmetic_overflow_error(
+                                            "integer",
+                                        )))
+                                    } else {
+                                        Ok(None)
+                                    };
+                                }
+                            };
+                        }
+                    }
+                }
+            }
+            Ok(Some(sum))
+        }
+
+        if self.eval_mode == EvalMode::Try && !self.has_all_nulls && 
self.sum.is_none() {
+            // we saw an overflow earlier (Try eval mode). Skip processing
+            return Ok(());
+        }
+        let values = &values[0];
+        if values.len() == values.null_count() {
+            Ok(())
+        } else {
+            // No nulls so there should be a non-null sum / null incase 
overflow in Try eval
+            let running_sum = self.sum.unwrap_or(0);
+            let sum = match values.data_type() {
+                DataType::Int64 => update_sum_internal(
+                    as_primitive_array::<Int64Type>(values),
+                    self.eval_mode,
+                    running_sum,
+                )?,
+                DataType::Int32 => update_sum_internal(
+                    as_primitive_array::<Int32Type>(values),
+                    self.eval_mode,
+                    running_sum,
+                )?,
+                DataType::Int16 => update_sum_internal(
+                    as_primitive_array::<Int16Type>(values),
+                    self.eval_mode,
+                    running_sum,
+                )?,
+                DataType::Int8 => update_sum_internal(
+                    as_primitive_array::<Int8Type>(values),
+                    self.eval_mode,
+                    running_sum,
+                )?,
+                _ => {
+                    return Err(DataFusionError::Internal(format!(
+                        "unsupported data type: {:?}",
+                        values.data_type()
+                    )));
+                }
+            };
+            self.sum = sum;
+            self.has_all_nulls = false;
+            Ok(())
+        }
+    }
+
+    fn evaluate(&mut self) -> DFResult<ScalarValue> {
+        if self.has_all_nulls {
+            Ok(ScalarValue::Int64(None))
+        } else {
+            Ok(ScalarValue::Int64(self.sum))
+        }
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+
+    fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
+        if self.eval_mode == EvalMode::Try {
+            Ok(vec![
+                ScalarValue::Int64(self.sum),
+                ScalarValue::Boolean(Some(self.has_all_nulls)),
+            ])
+        } else {
+            Ok(vec![ScalarValue::Int64(self.sum)])
+        }
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
+        let expected_state_len = if self.eval_mode == EvalMode::Try {
+            2
+        } else {
+            1
+        };
+        if expected_state_len != states.len() {
+            return Err(DataFusionError::Internal(format!(
+                "Invalid state while merging batch. Expected {} elements but 
found {}",
+                expected_state_len,
+                states.len()
+            )));
+        }
+
+        let that_sum_array = states[0].as_primitive::<Int64Type>();
+        let that_sum = if that_sum_array.is_null(0) {
+            None
+        } else {
+            Some(that_sum_array.value(0))
+        };
+
+        // Check for overflow for early termination
+        if self.eval_mode == EvalMode::Try {
+            let that_has_all_nulls = states[1].as_boolean().value(0);
+            let that_overflowed = !that_has_all_nulls && that_sum.is_none();
+            let this_overflowed = !self.has_all_nulls && self.sum.is_none();
+            if that_overflowed || this_overflowed {
+                self.sum = None;
+                self.has_all_nulls = false;
+                return Ok(());
+            }
+            if that_has_all_nulls {
+                return Ok(());
+            }
+            if self.has_all_nulls {
+                self.sum = that_sum;
+                self.has_all_nulls = false;
+                return Ok(());
+            }
+        } else {
+            if that_sum.is_none() {
+                return Ok(());
+            }
+            if self.sum.is_none() {
+                self.sum = that_sum;
+                return Ok(());
+            }
+        }
+
+        // safe to unwrap (since we checked nulls above) but handling error 
just in case state is corrupt
+        let left = self.sum.ok_or_else(|| {
+            DataFusionError::Internal(
+                "Invalid state in merging batch. Current batch's sum is 
None".to_string(),
+            )
+        })?;
+        let right = that_sum.ok_or_else(|| {
+            DataFusionError::Internal(
+                "Invalid state in merging batch. Incoming sum is 
None".to_string(),
+            )
+        })?;
+
+        match self.eval_mode {
+            EvalMode::Legacy => {
+                self.sum = Some(left.add_wrapping(right));
+            }
+            EvalMode::Ansi | EvalMode::Try => match left.add_checked(right) {
+                Ok(v) => self.sum = Some(v),
+                Err(_) => {
+                    if self.eval_mode == EvalMode::Ansi {
+                        return 
Err(DataFusionError::from(arithmetic_overflow_error("integer")));
+                    } else {
+                        self.sum = None;
+                        self.has_all_nulls = false;
+                    }
+                }
+            },
+        }
+        Ok(())
+    }
+}
+
+/// Per-group state for the vectorized integer sum.
+///
+/// `sums` and `has_all_nulls` are indexed by group index.
+struct SumIntGroupsAccumulator {
+    /// Running sum of each group.
+    sums: Vec<Option<i64>>,
+    /// Try mode only: `true` while a group has seen no non-null input.
+    has_all_nulls: Vec<bool>,
+    eval_mode: EvalMode,
+}
+
+impl SumIntGroupsAccumulator {
+    /// Creates an accumulator that tracks no groups yet.
+    fn new(eval_mode: EvalMode) -> Self {
+        Self {
+            sums: Vec::new(),
+            has_all_nulls: Vec::new(),
+            eval_mode,
+        }
+    }
+
+    /// Resizes both vectors to `total_num_groups` entries.
+    ///
+    /// Newly added groups start as `(Some(0), true)` in Try mode and as
+    /// `(None, false)` in every other mode.
+    fn resize_helper(&mut self, total_num_groups: usize) {
+        let try_mode = self.eval_mode == EvalMode::Try;
+        self.sums.resize(total_num_groups, try_mode.then_some(0));
+        self.has_all_nulls.resize(total_num_groups, try_mode);
+    }
+}
+
+impl GroupsAccumulator for SumIntGroupsAccumulator {
+    /// Adds each non-null input value to the sum of the group it belongs to.
+    ///
+    /// `group_indices[i]` names the group of row `i`. `opt_filter` is not
+    /// supported: it is only checked by a `debug_assert!`.
+    ///
+    /// # Errors
+    /// Returns an arithmetic overflow error in ANSI mode when a group's sum
+    /// leaves the `i64` range, and an internal error for non-integer input.
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> DFResult<()> {
+        /// Applies one typed input array to the per-group state.
+        fn update_groups_sum_internal<T>(
+            int_array: &PrimitiveArray<T>,
+            group_indices: &[usize],
+            sums: &mut [Option<i64>],
+            has_all_nulls: &mut [bool],
+            eval_mode: EvalMode,
+        ) -> DFResult<()>
+        where
+            T: ArrowPrimitiveType,
+            T::Native: ArrowNativeType,
+        {
+            for (i, &group_index) in group_indices.iter().enumerate() {
+                if int_array.is_null(i) {
+                    continue;
+                }
+                // This group already overflowed in Try mode; its result stays null.
+                if eval_mode == EvalMode::Try
+                    && !has_all_nulls[group_index]
+                    && sums[group_index].is_none()
+                {
+                    continue;
+                }
+                let v = int_array.value(i).to_i64().ok_or_else(|| {
+                    DataFusionError::Internal("Failed to convert value to i64".to_string())
+                })?;
+                let current = sums[group_index].unwrap_or(0);
+                sums[group_index] = match eval_mode {
+                    EvalMode::Legacy => Some(current.add_wrapping(v)),
+                    EvalMode::Ansi => Some(current.add_checked(v).map_err(|_| {
+                        DataFusionError::from(arithmetic_overflow_error("integer"))
+                    })?),
+                    // Overflow in Try mode is recorded as a null sum.
+                    EvalMode::Try => current.add_checked(v).ok(),
+                };
+                has_all_nulls[group_index] = false;
+            }
+            Ok(())
+        }
+
+        debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
+        let values = &values[0];
+        self.resize_helper(total_num_groups);
+
+        match values.data_type() {
+            DataType::Int64 => update_groups_sum_internal(
+                as_primitive_array::<Int64Type>(values),
+                group_indices,
+                &mut self.sums,
+                &mut self.has_all_nulls,
+                self.eval_mode,
+            )?,
+            DataType::Int32 => update_groups_sum_internal(
+                as_primitive_array::<Int32Type>(values),
+                group_indices,
+                &mut self.sums,
+                &mut self.has_all_nulls,
+                self.eval_mode,
+            )?,
+            DataType::Int16 => update_groups_sum_internal(
+                as_primitive_array::<Int16Type>(values),
+                group_indices,
+                &mut self.sums,
+                &mut self.has_all_nulls,
+                self.eval_mode,
+            )?,
+            DataType::Int8 => update_groups_sum_internal(
+                as_primitive_array::<Int8Type>(values),
+                group_indices,
+                &mut self.sums,
+                &mut self.has_all_nulls,
+                self.eval_mode,
+            )?,
+            _ => {
+                return Err(DataFusionError::Internal(format!(
+                    "Unsupported data type for SumIntGroupsAccumulator: {:?}",
+                    values.data_type()
+                )))
+            }
+        };
+        Ok(())
+    }
+
+    /// Emits the final sums for the requested groups and removes them from the
+    /// accumulator. A group whose `has_all_nulls` flag is set yields null.
+    fn evaluate(&mut self, emit_to: EmitTo) -> DFResult<ArrayRef> {
+        match emit_to {
+            EmitTo::All => {
+                let result = Arc::new(Int64Array::from_iter(
+                    self.sums
+                        .iter()
+                        .zip(self.has_all_nulls.iter())
+                        .map(|(&sum, &is_null)| if is_null { None } else { sum }),
+                )) as ArrayRef;
+
+                self.sums.clear();
+                self.has_all_nulls.clear();
+                Ok(result)
+            }
+            EmitTo::First(n) => {
+                let result = Arc::new(Int64Array::from_iter(
+                    self.sums
+                        .drain(..n)
+                        .zip(self.has_all_nulls.drain(..n))
+                        .map(|(sum, is_null)| if is_null { None } else { sum }),
+                )) as ArrayRef;
+                Ok(result)
+            }
+        }
+    }
+
+    /// Emits the partial state for the requested groups: `[sums]`, plus the
+    /// `has_all_nulls` column in Try mode.
+    fn state(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
+        let sums = emit_to.take_needed(&mut self.sums);
+        // Drain the flags in every mode so both vectors keep describing the
+        // same groups after a partial emit; only Try mode serializes them.
+        let has_all_nulls = emit_to.take_needed(&mut self.has_all_nulls);
+
+        if self.eval_mode == EvalMode::Try {
+            Ok(vec![
+                Arc::new(Int64Array::from(sums)),
+                Arc::new(BooleanArray::from(has_all_nulls)),
+            ])
+        } else {
+            Ok(vec![Arc::new(Int64Array::from(sums))])
+        }
+    }
+
+    /// Merges partial states produced by [`Self::state`]; row `i` of `values`
+    /// is merged into group `group_indices[i]`.
+    ///
+    /// # Errors
+    /// Returns an internal error when the number of state columns does not
+    /// match the evaluation mode, and an arithmetic overflow error in ANSI mode
+    /// when a combined sum leaves the `i64` range.
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> DFResult<()> {
+        debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
+
+        let try_mode = self.eval_mode == EvalMode::Try;
+        let expected_state_len = if try_mode { 2 } else { 1 };
+        if expected_state_len != values.len() {
+            return Err(DataFusionError::Internal(format!(
+                "Invalid state while merging batch. Expected {} elements but found {}",
+                expected_state_len,
+                values.len()
+            )));
+        }
+        let that_sums = values[0].as_primitive::<Int64Type>();
+
+        self.resize_helper(total_num_groups);
+
+        // The flag column only exists in Try mode.
+        let that_flags = if try_mode {
+            Some(values[1].as_boolean())
+        } else {
+            None
+        };
+
+        for (idx, &group_index) in group_indices.iter().enumerate() {
+            let that_sum = if that_sums.is_null(idx) {
+                None
+            } else {
+                Some(that_sums.value(idx))
+            };
+
+            if let Some(flags) = that_flags {
+                let that_has_all_nulls = flags.value(idx);
+
+                let that_overflowed = !that_has_all_nulls && that_sum.is_none();
+                let this_overflowed =
+                    !self.has_all_nulls[group_index] && self.sums[group_index].is_none();
+
+                if that_overflowed || this_overflowed {
+                    self.sums[group_index] = None;
+                    self.has_all_nulls[group_index] = false;
+                    continue;
+                }
+
+                if that_has_all_nulls {
+                    continue;
+                }
+
+                if self.has_all_nulls[group_index] {
+                    self.sums[group_index] = that_sum;
+                    self.has_all_nulls[group_index] = false;
+                    continue;
+                }
+            } else {
+                if that_sum.is_none() {
+                    continue;
+                }
+                if self.sums[group_index].is_none() {
+                    self.sums[group_index] = that_sum;
+                    continue;
+                }
+            }
+
+            // Both sides are non-null at this point.
+            let left = self.sums[group_index].unwrap();
+            let right = that_sum.unwrap();
+
+            match self.eval_mode {
+                EvalMode::Legacy => {
+                    self.sums[group_index] = Some(left.add_wrapping(right));
+                }
+                EvalMode::Ansi | EvalMode::Try => match left.add_checked(right) {
+                    Ok(v) => self.sums[group_index] = Some(v),
+                    Err(_) => {
+                        if self.eval_mode == EvalMode::Ansi {
+                            return Err(DataFusionError::from(arithmetic_overflow_error(
+                                "integer",
+                            )));
+                        } else {
+                            // Overflow in Try mode: null sum with the flag cleared.
+                            self.sums[group_index] = None;
+                            self.has_all_nulls[group_index] = false;
+                        }
+                    }
+                },
+            }
+        }
+        Ok(())
+    }
+
+    /// Reports the struct itself plus the heap buffers backing both vectors.
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+            + self.sums.capacity() * std::mem::size_of::<Option<i64>>()
+            + self.has_all_nulls.capacity() * std::mem::size_of::<bool>()
+    }
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala 
b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index 8ab568dc8..a05efaebb 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -213,17 +213,6 @@ object CometAverage extends 
CometAggregateExpressionSerde[Average] {
 
 object CometSum extends CometAggregateExpressionSerde[Sum] {
 
-  override def getSupportLevel(sum: Sum): SupportLevel = {
-    sum.evalMode match {
-      case EvalMode.ANSI if !sum.dataType.isInstanceOf[DecimalType] =>
-        Incompatible(Some("ANSI mode for non decimal inputs is not supported"))
-      case EvalMode.TRY if !sum.dataType.isInstanceOf[DecimalType] =>
-        Incompatible(Some("TRY mode for non decimal inputs is not supported"))
-      case _ =>
-        Compatible()
-    }
-  }
-
   override def convert(
       aggExpr: AggregateExpression,
       sum: Sum,
@@ -236,6 +225,8 @@ object CometSum extends CometAggregateExpressionSerde[Sum] {
       return None
     }
 
+    val evalMode = sum.evalMode
+
     val childExpr = exprToProto(sum.child, inputs, binding)
     val dataType = serializeDataType(sum.dataType)
 
@@ -243,7 +234,7 @@ object CometSum extends CometAggregateExpressionSerde[Sum] {
       val builder = ExprOuterClass.Sum.newBuilder()
       builder.setChild(childExpr.get)
       builder.setDatatype(dataType.get)
-      
builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(sum.evalMode)))
+      
builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(evalMode)))
 
       Some(
         ExprOuterClass.AggExpr
diff --git 
a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 060579b2b..9b2816c2f 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -24,7 +24,6 @@ import scala.util.Random
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions.Cast
-import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
 import org.apache.spark.sql.catalyst.optimizer.EliminateSorts
 import org.apache.spark.sql.comet.CometHashAggregateExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -1472,11 +1471,22 @@ class CometAggregateSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     }
   }
 
+  test("ANSI support for sum - null test") {
+    // sum over a column holding only NULLs, with ANSI mode on and off.
+    val nullLong = null.asInstanceOf[java.lang.Long]
+    val rows = Seq((nullLong, "a"), (nullLong, "b"))
+    for (ansiEnabled <- Seq(true, false)) {
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        withParquetTable(rows, "null_tbl") {
+          val res = sql("SELECT sum(_1) FROM null_tbl")
+          checkSparkAnswerAndOperator(res)
+        }
+      }
+    }
+  }
+
   test("ANSI support for decimal sum - null test") {
     Seq(true, false).foreach { ansiEnabled =>
-      withSQLConf(
-        SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
-        CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
         withParquetTable(
           Seq(
             (null.asInstanceOf[java.math.BigDecimal], "a"),
@@ -1490,11 +1500,22 @@ class CometAggregateSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     }
   }
 
+  test("ANSI support for try_sum - null test") {
+    // try_sum over a column holding only NULLs, with ANSI mode on and off.
+    val nullLong = null.asInstanceOf[java.lang.Long]
+    val rows = Seq((nullLong, "a"), (nullLong, "b"))
+    for (ansiEnabled <- Seq(true, false)) {
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        withParquetTable(rows, "null_tbl") {
+          val res = sql("SELECT try_sum(_1) FROM null_tbl")
+          checkSparkAnswerAndOperator(res)
+        }
+      }
+    }
+  }
+
   test("ANSI support for try_sum decimal - null test") {
     Seq(true, false).foreach { ansiEnabled =>
-      withSQLConf(
-        SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
-        CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
         withParquetTable(
           Seq(
             (null.asInstanceOf[java.math.BigDecimal], "a"),
@@ -1508,11 +1529,28 @@ class CometAggregateSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     }
   }
 
+  test("ANSI support for sum - null test (group by)") {
+    // Grouped sum where every value is NULL: each group must come back as NULL.
+    val nullLong = null.asInstanceOf[java.lang.Long]
+    val rows = Seq(
+      (nullLong, "a"),
+      (nullLong, "a"),
+      (nullLong, "b"),
+      (nullLong, "b"),
+      (nullLong, "b"))
+    for (ansiEnabled <- Seq(true, false)) {
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        withParquetTable(rows, "tbl") {
+          val res = sql("SELECT _2, sum(_1) FROM tbl group by 1")
+          checkSparkAnswerAndOperator(res)
+          val expected = Array(Row("a", null), Row("b", null))
+          assert(res.orderBy(col("_2")).collect() === expected)
+        }
+      }
+    }
+  }
+
   test("ANSI support for decimal sum - null test (group by)") {
     Seq(true, false).foreach { ansiEnabled =>
-      withSQLConf(
-        SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
-        CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
         withParquetTable(
           Seq(
             (null.asInstanceOf[java.math.BigDecimal], "a"),
@@ -1529,11 +1567,27 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("ANSI support for try_sum - null test (group by)") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        withParquetTable(
+          Seq(
+            (null.asInstanceOf[java.lang.Long], "a"),
+            (null.asInstanceOf[java.lang.Long], "a"),
+            (null.asInstanceOf[java.lang.Long], "b"),
+            (null.asInstanceOf[java.lang.Long], "b"),
+            (null.asInstanceOf[java.lang.Long], "b")),
+          "tbl") {
+          val res = sql("SELECT _2, try_sum(_1) FROM tbl group by 1")
+          checkSparkAnswerAndOperator(res)
+        }
+      }
+    }
+  }
+
   test("ANSI support for try_sum decimal - null test (group by)") {
     Seq(true, false).foreach { ansiEnabled =>
-      withSQLConf(
-        SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
-        CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
         withParquetTable(
           Seq(
             (null.asInstanceOf[java.math.BigDecimal], "a"),
@@ -1544,7 +1598,6 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
           "tbl") {
           val res = sql("SELECT _2, try_sum(_1) FROM tbl group by 1")
           checkSparkAnswerAndOperator(res)
-          assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null)))
         }
       }
     }
@@ -1555,11 +1608,64 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     (1 to 50).flatMap(_ => Seq((maxDec38_0, 1)))
   }
 
+  test("ANSI support - SUM function") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        // Test long overflow
+        withParquetTable(Seq((Long.MaxValue, 1L), (100L, 1L)), "tbl") {
+          val res = sql("SELECT SUM(_1) FROM tbl")
+          if (ansiEnabled) {
+            checkSparkAnswerMaybeThrows(res) match {
+              case (Some(sparkExc), Some(cometExc)) =>
+                // make sure that the error message throws overflow exception only
+                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+              case _ => fail("Exception should be thrown for Long overflow in ANSI mode")
+            }
+          } else {
+            checkSparkAnswerAndOperator(res)
+          }
+        }
+        // Test long underflow
+        withParquetTable(Seq((Long.MinValue, 1L), (-100L, 1L)), "tbl") {
+          val res = sql("SELECT SUM(_1) FROM tbl")
+          if (ansiEnabled) {
+            checkSparkAnswerMaybeThrows(res) match {
+              case (Some(sparkExc), Some(cometExc)) =>
+                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+              case _ => fail("Exception should be thrown for Long underflow in ANSI mode")
+            }
+          } else {
+            checkSparkAnswerAndOperator(res)
+          }
+        }
+        // Test Int SUM (should not overflow)
+        withParquetTable(Seq((Int.MaxValue, 1), (Int.MaxValue, 1), (100, 1)), "tbl") {
+          val res = sql("SELECT SUM(_1) FROM tbl")
+          checkSparkAnswerAndOperator(res)
+        }
+        // Test Short SUM (should not overflow)
+        withParquetTable(
+          Seq((Short.MaxValue, 1.toShort), (Short.MaxValue, 1.toShort), (100.toShort, 1.toShort)),
+          "tbl") {
+          val res = sql("SELECT SUM(_1) FROM tbl")
+          checkSparkAnswerAndOperator(res)
+        }
+        // Test Byte SUM (should not overflow)
+        withParquetTable(
+          Seq((Byte.MaxValue, 1.toByte), (Byte.MaxValue, 1.toByte), (10.toByte, 1.toByte)),
+          "tbl") {
+          val res = sql("SELECT SUM(_1) FROM tbl")
+          checkSparkAnswerAndOperator(res)
+        }
+      }
+    }
+  }
+
   test("ANSI support for decimal SUM function") {
     Seq(true, false).foreach { ansiEnabled =>
-      withSQLConf(
-        SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
-        CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
         withParquetTable(generateOverflowDecimalInputs, "tbl") {
           val res = sql("SELECT SUM(_1) FROM tbl")
           if (ansiEnabled) {
@@ -1578,11 +1684,68 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("ANSI support for SUM - GROUP BY") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+        withParquetTable(
+          Seq((Long.MaxValue, 1), (100L, 1), (Long.MaxValue, 2), (200L, 2)),
+          "tbl") {
+          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2").repartition(2)
+          if (ansiEnabled) {
+            checkSparkAnswerMaybeThrows(res) match {
+              case (Some(sparkExc), Some(cometExc)) =>
+                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+              case _ =>
+                fail("Exception should be thrown for Long overflow with GROUP BY in ANSI mode")
+            }
+          } else {
+            checkSparkAnswerAndOperator(res)
+          }
+        }
+
+        withParquetTable(
+          Seq((Long.MinValue, 1), (-100L, 1), (Long.MinValue, 2), (-200L, 2)),
+          "tbl") {
+          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
+          if (ansiEnabled) {
+            checkSparkAnswerMaybeThrows(res) match {
+              case (Some(sparkExc), Some(cometExc)) =>
+                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+              case _ =>
+                fail("Exception should be thrown for Long underflow with GROUP BY in ANSI mode")
+            }
+          } else {
+            checkSparkAnswerAndOperator(res)
+          }
+        }
+        // Test Int with GROUP BY
+        withParquetTable(Seq((Int.MaxValue, 1), (Int.MaxValue, 1), (100, 2), (200, 2)), "tbl") {
+          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
+          checkSparkAnswerAndOperator(res)
+        }
+        // Test Short with GROUP BY
+        withParquetTable(
+          Seq((Short.MaxValue, 1), (Short.MaxValue, 1), (100.toShort, 2), (200.toShort, 2)),
+          "tbl") {
+          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
+          checkSparkAnswerAndOperator(res)
+        }
+        // Test Byte with GROUP BY
+        withParquetTable(
+          Seq((Byte.MaxValue, 1), (Byte.MaxValue, 1), (10.toByte, 2), (20.toByte, 2)),
+          "tbl") {
+          val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
+          checkSparkAnswerAndOperator(res)
+        }
+      }
+    }
+  }
+
   test("ANSI support for decimal SUM - GROUP BY") {
     Seq(true, false).foreach { ansiEnabled =>
-      withSQLConf(
-        SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
-        CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
         withParquetTable(generateOverflowDecimalInputs, "tbl") {
           val res =
             sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2").repartition(2)
@@ -1602,35 +1765,69 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("try_sum overflow - with GROUP BY") {
+    // Test Long overflow with GROUP BY - some groups overflow while some don't
+    withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (200L, 2), (300L, 2)), "tbl") {
+      // repartition to trigger merge batch and state checks
+      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+      // first group should return NULL (overflow) and group 2 should return 500
+      checkSparkAnswerAndOperator(res)
+    }
+
+    // Test Long underflow with GROUP BY
+    withParquetTable(Seq((Long.MinValue, 1), (-100L, 1), (-200L, 2), (-300L, 2)), "tbl") {
+      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+      // first group should return NULL (underflow), second group should return neg 500
+      checkSparkAnswerAndOperator(res)
+    }
+
+    // Test all groups overflow
+    withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (Long.MaxValue, 2), (100L, 2)), "tbl") {
+      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+      // Both groups should return NULL
+      checkSparkAnswerAndOperator(res)
+    }
+
+    // Test Short with GROUP BY (should NOT overflow)
+    withParquetTable(
+      Seq((Short.MaxValue, 1), (Short.MaxValue, 1), (100.toShort, 2), (200.toShort, 2)),
+      "tbl") {
+      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+      checkSparkAnswerAndOperator(res)
+    }
+
+    // Test Byte with GROUP BY (no overflow)
+    withParquetTable(
+      Seq((Byte.MaxValue, 1), (Byte.MaxValue, 1), (10.toByte, 2), (20.toByte, 2)),
+      "tbl") {
+      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+      checkSparkAnswerAndOperator(res)
+    }
+  }
+
   test("try_sum decimal overflow") {
-    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
-      withParquetTable(generateOverflowDecimalInputs, "tbl") {
-        val res = sql("SELECT try_sum(_1) FROM tbl")
-        checkSparkAnswerAndOperator(res)
-      }
+    withParquetTable(generateOverflowDecimalInputs, "tbl") {
+      val res = sql("SELECT try_sum(_1) FROM tbl")
+      checkSparkAnswerAndOperator(res)
     }
   }
 
   test("try_sum decimal overflow - with GROUP BY") {
-    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
-      withParquetTable(generateOverflowDecimalInputs, "tbl") {
-        val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
-        checkSparkAnswerAndOperator(res)
-      }
+    withParquetTable(generateOverflowDecimalInputs, "tbl") {
+      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+      checkSparkAnswerAndOperator(res)
     }
   }
 
   test("try_sum decimal partial overflow - with GROUP BY") {
-    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
-      // Group 1 overflows, Group 2 succeeds
-      val data: Seq[(java.math.BigDecimal, Int)] = generateOverflowDecimalInputs ++ Seq(
-        (new java.math.BigDecimal(300), 2),
-        (new java.math.BigDecimal(200), 2))
-      withParquetTable(data, "tbl") {
-        val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2")
-        // Group 1 should be NULL, Group 2 should be 500
-        checkSparkAnswerAndOperator(res)
-      }
+    // Group 1 overflows, Group 2 succeeds
+    val data: Seq[(java.math.BigDecimal, Int)] = generateOverflowDecimalInputs ++ Seq(
+      (new java.math.BigDecimal(300), 2),
+      (new java.math.BigDecimal(200), 2))
+    withParquetTable(data, "tbl") {
+      val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2")
+      // Group 1 should be NULL, Group 2 should be 500
+      checkSparkAnswerAndOperator(res)
     }
   }
 
diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
index cf2b3dcdd..b1848ff51 100644
--- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE}
 import org.apache.spark.sql.TPCDSBase
 import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Cast}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}
+import org.apache.spark.sql.catalyst.expressions.aggregate.Average
 import org.apache.spark.sql.catalyst.util.resourceToString
 import org.apache.spark.sql.execution.{FormattedMode, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec, SubqueryExec}
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
@@ -228,7 +228,6 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa
       CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key -> "true",
       // Allow Incompatible is needed for Sum + Average for Spark 4.0.0 / ANSI support
       CometConf.getExprAllowIncompatConfigKey(classOf[Average]) -> "true",
-      CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true",
       // as well as for v1.4/q9, v1.4/q44, v2.7.0/q6, v2.7.0/q64
       CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to