This is an automated email from the ASF dual-hosted git repository.
comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new f2328870f feat: Support ANSI mode sum expr (int inputs) (#2600)
f2328870f is described below
commit f2328870fb6a15cfa1b19cba11614c970b0d4171
Author: B Vadlamani <[email protected]>
AuthorDate: Tue Dec 23 14:06:00 2025 -0800
feat: Support ANSI mode sum expr (int inputs) (#2600)
---
docs/source/user-guide/latest/compatibility.md | 3 +-
native/core/src/execution/planner.rs | 7 +
native/spark-expr/src/agg_funcs/mod.rs | 2 +
native/spark-expr/src/agg_funcs/sum_int.rs | 589 +++++++++++++++++++++
.../scala/org/apache/comet/serde/aggregates.scala | 15 +-
.../apache/comet/exec/CometAggregateSuite.scala | 277 ++++++++--
.../spark/sql/comet/CometPlanStabilitySuite.scala | 3 +-
7 files changed, 840 insertions(+), 56 deletions(-)
diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md
index 58dd8d6ab..3d2c9a7b5 100644
--- a/docs/source/user-guide/latest/compatibility.md
+++ b/docs/source/user-guide/latest/compatibility.md
@@ -32,12 +32,11 @@ Comet has the following limitations when reading Parquet files:
## ANSI Mode
-Comet will fall back to Spark for the following expressions when ANSI mode is enabled. Thes expressions can be enabled by setting
+Comet will fall back to Spark for the following expressions when ANSI mode is enabled. These expressions can be enabled by setting
`spark.comet.expression.EXPRNAME.allowIncompatible=true`, where `EXPRNAME` is the Spark expression class name. See
the [Comet Supported Expressions Guide](expressions.md) for more information on this configuration setting.
- Average
-- Sum
- Cast (in some cases)
There is an [epic](https://github.com/apache/datafusion-comet/issues/313) where we are tracking the work to fully implement ANSI support.
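
For example, to opt back in to one of the expressions still listed above, the per-expression config described in the doc can be set (illustrative only; substitute the Spark expression class name you actually need, e.g. Average):

    spark.comet.expression.Average.allowIncompatible=true
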
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 56de19d67..8e8191dd0 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -71,6 +71,7 @@ use datafusion::{
use datafusion_comet_spark_expr::{
create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle,
BloomFilterAgg, BloomFilterMightContain, EvalMode, SparkHour, SparkMinute, SparkSecond,
+ SumInteger,
};
use iceberg::expr::Bind;
@@ -1813,6 +1814,12 @@ impl PhysicalPlanner {
AggregateUDF::new_from_impl(SumDecimal::try_new(datatype, eval_mode)?);
AggregateExprBuilder::new(Arc::new(func), vec![child])
}
+ DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
+ let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+ let func =
+ AggregateUDF::new_from_impl(SumInteger::try_new(datatype, eval_mode)?);
+ AggregateExprBuilder::new(Arc::new(func), vec![child])
+ }
_ => {
// cast to the result data type of SUM if necessary, we should not expect
// a cast failure since it should have already been checked at Spark side
diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs
index 252da7889..b1027153e 100644
--- a/native/spark-expr/src/agg_funcs/mod.rs
+++ b/native/spark-expr/src/agg_funcs/mod.rs
@@ -21,6 +21,7 @@ mod correlation;
mod covariance;
mod stddev;
mod sum_decimal;
+mod sum_int;
mod variance;
pub use avg::Avg;
@@ -29,4 +30,5 @@ pub use correlation::Correlation;
pub use covariance::Covariance;
pub use stddev::Stddev;
pub use sum_decimal::SumDecimal;
+pub use sum_int::SumInteger;
pub use variance::Variance;
diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs
new file mode 100644
index 000000000..d226c5ede
--- /dev/null
+++ b/native/spark-expr/src/agg_funcs/sum_int.rs
@@ -0,0 +1,589 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::{arithmetic_overflow_error, EvalMode};
+use arrow::array::{
+ as_primitive_array, cast::AsArray, Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType,
+ BooleanArray, Int64Array, PrimitiveArray,
+};
+use arrow::datatypes::{
+ ArrowNativeType, DataType, Field, FieldRef, Int16Type, Int32Type, Int64Type, Int8Type,
+};
+use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue};
+use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
+use datafusion::logical_expr::Volatility::Immutable;
+use datafusion::logical_expr::{
+ Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
+};
+use std::{any::Any, sync::Arc};
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SumInteger {
+ signature: Signature,
+ eval_mode: EvalMode,
+}
+
+impl SumInteger {
+ pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult<Self> {
+ match data_type {
+ DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => Ok(Self {
+ signature: Signature::user_defined(Immutable),
+ eval_mode,
+ }),
+ _ => Err(DataFusionError::Internal(
+ "Invalid data type for SumInteger".into(),
+ )),
+ }
+ }
+}
+
+impl AggregateUDFImpl for SumInteger {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ "sum"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+ Ok(DataType::Int64)
+ }
+
+ fn accumulator(&self, _acc_args: AccumulatorArgs) -> DFResult<Box<dyn Accumulator>> {
+ Ok(Box::new(SumIntegerAccumulator::new(self.eval_mode)))
+ }
+
+ fn state_fields(&self, _args: StateFieldsArgs) -> DFResult<Vec<FieldRef>> {
+ if self.eval_mode == EvalMode::Try {
+ Ok(vec![
+ Arc::new(Field::new("sum", DataType::Int64, true)),
+ Arc::new(Field::new("has_all_nulls", DataType::Boolean, false)),
+ ])
+ } else {
+ Ok(vec![Arc::new(Field::new("sum", DataType::Int64, true))])
+ }
+ }
+
+ fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
+ true
+ }
+
+ fn create_groups_accumulator(
+ &self,
+ _args: AccumulatorArgs,
+ ) -> DFResult<Box<dyn GroupsAccumulator>> {
+ Ok(Box::new(SumIntGroupsAccumulator::new(self.eval_mode)))
+ }
+
+ fn reverse_expr(&self) -> ReversedUDAF {
+ ReversedUDAF::Identical
+ }
+}
+
+#[derive(Debug)]
+struct SumIntegerAccumulator {
+ sum: Option<i64>,
+ eval_mode: EvalMode,
+ has_all_nulls: bool,
+}
+
+impl SumIntegerAccumulator {
+ fn new(eval_mode: EvalMode) -> Self {
+ if eval_mode == EvalMode::Try {
+ Self {
+ // Try mode starts with 0 (because if this were initialized to None, we can't tell whether None means all nulls or an overflow)
+ sum: Some(0),
+ has_all_nulls: true,
+ eval_mode,
+ }
+ } else {
+ Self {
+ sum: None,
+ has_all_nulls: false,
+ eval_mode,
+ }
+ }
+ }
+}
+
+impl Accumulator for SumIntegerAccumulator {
+ fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
+ // internal helper to add to the running sum; returns a null sum (with has_all_nulls false) if there is an overflow in Try eval mode
+ fn update_sum_internal<T>(
+ int_array: &PrimitiveArray<T>,
+ eval_mode: EvalMode,
+ mut sum: i64,
+ ) -> Result<Option<i64>, DataFusionError>
+ where
+ T: ArrowPrimitiveType,
+ {
+ for i in 0..int_array.len() {
+ if !int_array.is_null(i) {
+ let v = int_array.value(i).to_i64().ok_or_else(|| {
+ DataFusionError::Internal(format!(
+ "Failed to convert value {:?} to i64",
+ int_array.value(i)
+ ))
+ })?;
+ match eval_mode {
+ EvalMode::Legacy => {
+ sum = v.add_wrapping(sum);
+ }
+ EvalMode::Ansi | EvalMode::Try => {
+ match v.add_checked(sum) {
+ Ok(v) => sum = v,
+ Err(_e) => {
+ return if eval_mode == EvalMode::Ansi {
+ Err(DataFusionError::from(arithmetic_overflow_error(
+ "integer",
+ )))
+ } else {
+ Ok(None)
+ };
+ }
+ };
+ }
+ }
+ }
+ }
+ Ok(Some(sum))
+ }
+
+ if self.eval_mode == EvalMode::Try && !self.has_all_nulls && self.sum.is_none() {
+ // we saw an overflow earlier (Try eval mode). Skip processing
+ return Ok(());
+ }
+ let values = &values[0];
+ if values.len() == values.null_count() {
+ Ok(())
+ } else {
+ // Not all values are null, so there should be a non-null sum (or a null sum in case of overflow in Try eval mode)
+ let running_sum = self.sum.unwrap_or(0);
+ let sum = match values.data_type() {
+ DataType::Int64 => update_sum_internal(
+ as_primitive_array::<Int64Type>(values),
+ self.eval_mode,
+ running_sum,
+ )?,
+ DataType::Int32 => update_sum_internal(
+ as_primitive_array::<Int32Type>(values),
+ self.eval_mode,
+ running_sum,
+ )?,
+ DataType::Int16 => update_sum_internal(
+ as_primitive_array::<Int16Type>(values),
+ self.eval_mode,
+ running_sum,
+ )?,
+ DataType::Int8 => update_sum_internal(
+ as_primitive_array::<Int8Type>(values),
+ self.eval_mode,
+ running_sum,
+ )?,
+ _ => {
+ return Err(DataFusionError::Internal(format!(
+ "unsupported data type: {:?}",
+ values.data_type()
+ )));
+ }
+ };
+ self.sum = sum;
+ self.has_all_nulls = false;
+ Ok(())
+ }
+ }
+
+ fn evaluate(&mut self) -> DFResult<ScalarValue> {
+ if self.has_all_nulls {
+ Ok(ScalarValue::Int64(None))
+ } else {
+ Ok(ScalarValue::Int64(self.sum))
+ }
+ }
+
+ fn size(&self) -> usize {
+ std::mem::size_of_val(self)
+ }
+
+ fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
+ if self.eval_mode == EvalMode::Try {
+ Ok(vec![
+ ScalarValue::Int64(self.sum),
+ ScalarValue::Boolean(Some(self.has_all_nulls)),
+ ])
+ } else {
+ Ok(vec![ScalarValue::Int64(self.sum)])
+ }
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
+ let expected_state_len = if self.eval_mode == EvalMode::Try {
+ 2
+ } else {
+ 1
+ };
+ if expected_state_len != states.len() {
+ return Err(DataFusionError::Internal(format!(
+ "Invalid state while merging batch. Expected {} elements but found {}",
+ expected_state_len,
+ states.len()
+ )));
+ }
+
+ let that_sum_array = states[0].as_primitive::<Int64Type>();
+ let that_sum = if that_sum_array.is_null(0) {
+ None
+ } else {
+ Some(that_sum_array.value(0))
+ };
+
+ // Check for overflow for early termination
+ if self.eval_mode == EvalMode::Try {
+ let that_has_all_nulls = states[1].as_boolean().value(0);
+ let that_overflowed = !that_has_all_nulls && that_sum.is_none();
+ let this_overflowed = !self.has_all_nulls && self.sum.is_none();
+ if that_overflowed || this_overflowed {
+ self.sum = None;
+ self.has_all_nulls = false;
+ return Ok(());
+ }
+ if that_has_all_nulls {
+ return Ok(());
+ }
+ if self.has_all_nulls {
+ self.sum = that_sum;
+ self.has_all_nulls = false;
+ return Ok(());
+ }
+ } else {
+ if that_sum.is_none() {
+ return Ok(());
+ }
+ if self.sum.is_none() {
+ self.sum = that_sum;
+ return Ok(());
+ }
+ }
+
+ // safe to unwrap (since we checked nulls above) but handling error just in case state is corrupt
+ let left = self.sum.ok_or_else(|| {
+ DataFusionError::Internal(
+ "Invalid state in merging batch. Current batch's sum is None".to_string(),
+ )
+ })?;
+ let right = that_sum.ok_or_else(|| {
+ DataFusionError::Internal(
+ "Invalid state in merging batch. Incoming sum is None".to_string(),
+ )
+ })?;
+
+ match self.eval_mode {
+ EvalMode::Legacy => {
+ self.sum = Some(left.add_wrapping(right));
+ }
+ EvalMode::Ansi | EvalMode::Try => match left.add_checked(right) {
+ Ok(v) => self.sum = Some(v),
+ Err(_) => {
+ if self.eval_mode == EvalMode::Ansi {
+ return Err(DataFusionError::from(arithmetic_overflow_error("integer")));
+ } else {
+ self.sum = None;
+ self.has_all_nulls = false;
+ }
+ }
+ },
+ }
+ Ok(())
+ }
+}
+
+struct SumIntGroupsAccumulator {
+ sums: Vec<Option<i64>>,
+ has_all_nulls: Vec<bool>,
+ eval_mode: EvalMode,
+}
+
+impl SumIntGroupsAccumulator {
+ fn new(eval_mode: EvalMode) -> Self {
+ Self {
+ sums: Vec::new(),
+ eval_mode,
+ has_all_nulls: Vec::new(),
+ }
+ }
+
+ fn resize_helper(&mut self, total_num_groups: usize) {
+ if self.eval_mode == EvalMode::Try {
+ self.sums.resize(total_num_groups, Some(0));
+ self.has_all_nulls.resize(total_num_groups, true);
+ } else {
+ self.sums.resize(total_num_groups, None);
+ self.has_all_nulls.resize(total_num_groups, false);
+ }
+ }
+}
+
+impl GroupsAccumulator for SumIntGroupsAccumulator {
+ fn update_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&BooleanArray>,
+ total_num_groups: usize,
+ ) -> DFResult<()> {
+ fn update_groups_sum_internal<T>(
+ int_array: &PrimitiveArray<T>,
+ group_indices: &[usize],
+ sums: &mut [Option<i64>],
+ has_all_nulls: &mut [bool],
+ eval_mode: EvalMode,
+ ) -> DFResult<()>
+ where
+ T: ArrowPrimitiveType,
+ T::Native: ArrowNativeType,
+ {
+ for (i, &group_index) in group_indices.iter().enumerate() {
+ if !int_array.is_null(i) {
+ // an overflow occurred earlier for this group in Try eval mode. Skip processing
+ if eval_mode == EvalMode::Try
+ && !has_all_nulls[group_index]
+ && sums[group_index].is_none()
+ {
+ continue;
+ }
+ let v = int_array.value(i).to_i64().ok_or_else(|| {
+ DataFusionError::Internal("Failed to convert value to i64".to_string())
+ })?;
+ match eval_mode {
+ EvalMode::Legacy => {
+ sums[group_index] =
+ Some(sums[group_index].unwrap_or(0).add_wrapping(v));
+ }
+ EvalMode::Ansi | EvalMode::Try => {
+ match sums[group_index].unwrap_or(0).add_checked(v) {
+ Ok(new_sum) => {
+ sums[group_index] = Some(new_sum);
+ }
+ Err(_) => {
+ if eval_mode == EvalMode::Ansi {
+ return Err(DataFusionError::from(
+ arithmetic_overflow_error("integer"),
+ ));
+ } else {
+ sums[group_index] = None;
+ }
+ }
+ };
+ }
+ }
+ has_all_nulls[group_index] = false
+ }
+ }
+ Ok(())
+ }
+
+ debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
+ let values = &values[0];
+ self.resize_helper(total_num_groups);
+
+ match values.data_type() {
+ DataType::Int64 => update_groups_sum_internal(
+ as_primitive_array::<Int64Type>(values),
+ group_indices,
+ &mut self.sums,
+ &mut self.has_all_nulls,
+ self.eval_mode,
+ )?,
+ DataType::Int32 => update_groups_sum_internal(
+ as_primitive_array::<Int32Type>(values),
+ group_indices,
+ &mut self.sums,
+ &mut self.has_all_nulls,
+ self.eval_mode,
+ )?,
+ DataType::Int16 => update_groups_sum_internal(
+ as_primitive_array::<Int16Type>(values),
+ group_indices,
+ &mut self.sums,
+ &mut self.has_all_nulls,
+ self.eval_mode,
+ )?,
+ DataType::Int8 => update_groups_sum_internal(
+ as_primitive_array::<Int8Type>(values),
+ group_indices,
+ &mut self.sums,
+ &mut self.has_all_nulls,
+ self.eval_mode,
+ )?,
+ _ => {
+ return Err(DataFusionError::Internal(format!(
+ "Unsupported data type for SumIntGroupsAccumulator: {:?}",
+ values.data_type()
+ )))
+ }
+ };
+ Ok(())
+ }
+
+ fn evaluate(&mut self, emit_to: EmitTo) -> DFResult<ArrayRef> {
+ match emit_to {
+ EmitTo::All => {
+ let result = Arc::new(Int64Array::from_iter(
+ self.sums
+ .iter()
+ .zip(self.has_all_nulls.iter())
+ .map(|(&sum, &is_null)| if is_null { None } else { sum }),
+ )) as ArrayRef;
+
+ self.sums.clear();
+ self.has_all_nulls.clear();
+ Ok(result)
+ }
+ EmitTo::First(n) => {
+ let result = Arc::new(Int64Array::from_iter(
+ self.sums
+ .drain(..n)
+ .zip(self.has_all_nulls.drain(..n))
+ .map(|(sum, is_null)| if is_null { None } else { sum }),
+ )) as ArrayRef;
+ Ok(result)
+ }
+ }
+ }
+
+ fn state(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
+ let sums = emit_to.take_needed(&mut self.sums);
+
+ if self.eval_mode == EvalMode::Try {
+ let has_all_nulls = emit_to.take_needed(&mut self.has_all_nulls);
+ Ok(vec![
+ Arc::new(Int64Array::from(sums)),
+ Arc::new(BooleanArray::from(has_all_nulls)),
+ ])
+ } else {
+ Ok(vec![Arc::new(Int64Array::from(sums))])
+ }
+ }
+
+ fn merge_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&BooleanArray>,
+ total_num_groups: usize,
+ ) -> DFResult<()> {
+ debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
+
+ let expected_state_len = if self.eval_mode == EvalMode::Try {
+ 2
+ } else {
+ 1
+ };
+ if expected_state_len != values.len() {
+ return Err(DataFusionError::Internal(format!(
+ "Invalid state while merging batch. Expected {} elements but found {}",
+ expected_state_len,
+ values.len()
+ )));
+ }
+ let that_sums = values[0].as_primitive::<Int64Type>();
+
+ self.resize_helper(total_num_groups);
+
+ let that_sums_is_all_nulls = if self.eval_mode == EvalMode::Try {
+ Some(values[1].as_boolean())
+ } else {
+ None
+ };
+
+ for (idx, &group_index) in group_indices.iter().enumerate() {
+ let that_sum = if that_sums.is_null(idx) {
+ None
+ } else {
+ Some(that_sums.value(idx))
+ };
+
+ if self.eval_mode == EvalMode::Try {
+ let that_has_all_nulls = that_sums_is_all_nulls.unwrap().value(idx);
+
+ let that_overflowed = !that_has_all_nulls && that_sum.is_none();
+ let this_overflowed =
+ !self.has_all_nulls[group_index] && self.sums[group_index].is_none();
+
+ if that_overflowed || this_overflowed {
+ self.sums[group_index] = None;
+ self.has_all_nulls[group_index] = false;
+ continue;
+ }
+
+ if that_has_all_nulls {
+ continue;
+ }
+
+ if self.has_all_nulls[group_index] {
+ self.sums[group_index] = that_sum;
+ self.has_all_nulls[group_index] = false;
+ continue;
+ }
+ } else {
+ if that_sum.is_none() {
+ continue;
+ }
+ if self.sums[group_index].is_none() {
+ self.sums[group_index] = that_sum;
+ continue;
+ }
+ }
+
+ // Both sides have non-null. Update sums now
+ let left = self.sums[group_index].unwrap();
+ let right = that_sum.unwrap();
+
+ match self.eval_mode {
+ EvalMode::Legacy => {
+ self.sums[group_index] = Some(left.add_wrapping(right));
+ }
+ EvalMode::Ansi | EvalMode::Try => {
+ match left.add_checked(right) {
+ Ok(v) => self.sums[group_index] = Some(v),
+ Err(_) => {
+ if self.eval_mode == EvalMode::Ansi {
+ return Err(DataFusionError::from(arithmetic_overflow_error(
+ "integer",
+ )));
+ } else {
+ // overflow. update flag accordingly
+ self.sums[group_index] = None;
+ self.has_all_nulls[group_index] = false;
+ }
+ }
+ }
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn size(&self) -> usize {
+ std::mem::size_of_val(self)
+ }
+}
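
For reviewers, a minimal test-style sketch of the per-mode overflow semantics implemented above (illustrative only and not part of this commit; the module and test names are hypothetical, while the types and eval modes are the ones defined in sum_int.rs):

#[cfg(test)]
mod overflow_semantics_sketch {
    use super::*;

    #[test]
    fn sum_overflow_per_eval_mode() -> DFResult<()> {
        // A batch whose sum overflows i64: i64::MAX + 1
        let batch: ArrayRef = Arc::new(Int64Array::from(vec![i64::MAX, 1]));

        // ANSI mode: overflow surfaces as an arithmetic overflow error
        let mut ansi = SumIntegerAccumulator::new(EvalMode::Ansi);
        assert!(ansi.update_batch(&[Arc::clone(&batch)]).is_err());

        // TRY mode: overflow produces a NULL sum instead of an error
        let mut try_acc = SumIntegerAccumulator::new(EvalMode::Try);
        try_acc.update_batch(&[Arc::clone(&batch)])?;
        assert_eq!(try_acc.evaluate()?, ScalarValue::Int64(None));

        // LEGACY mode: the sum wraps around, matching Spark's non-ANSI behavior
        let mut legacy = SumIntegerAccumulator::new(EvalMode::Legacy);
        legacy.update_batch(&[batch])?;
        assert_eq!(legacy.evaluate()?, ScalarValue::Int64(Some(i64::MIN)));

        Ok(())
    }
}
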
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index 8ab568dc8..a05efaebb 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -213,17 +213,6 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
object CometSum extends CometAggregateExpressionSerde[Sum] {
- override def getSupportLevel(sum: Sum): SupportLevel = {
- sum.evalMode match {
- case EvalMode.ANSI if !sum.dataType.isInstanceOf[DecimalType] =>
- Incompatible(Some("ANSI mode for non decimal inputs is not supported"))
- case EvalMode.TRY if !sum.dataType.isInstanceOf[DecimalType] =>
- Incompatible(Some("TRY mode for non decimal inputs is not supported"))
- case _ =>
- Compatible()
- }
- }
-
override def convert(
aggExpr: AggregateExpression,
sum: Sum,
@@ -236,6 +225,8 @@ object CometSum extends CometAggregateExpressionSerde[Sum] {
return None
}
+ val evalMode = sum.evalMode
+
val childExpr = exprToProto(sum.child, inputs, binding)
val dataType = serializeDataType(sum.dataType)
@@ -243,7 +234,7 @@ object CometSum extends CometAggregateExpressionSerde[Sum] {
val builder = ExprOuterClass.Sum.newBuilder()
builder.setChild(childExpr.get)
builder.setDatatype(dataType.get)
- builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(sum.evalMode)))
+ builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(evalMode)))
Some(
ExprOuterClass.AggExpr
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 060579b2b..9b2816c2f 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -24,7 +24,6 @@ import scala.util.Random
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.Cast
-import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.optimizer.EliminateSorts
import org.apache.spark.sql.comet.CometHashAggregateExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -1472,11 +1471,22 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ test("ANSI support for sum - null test") {
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+ withParquetTable(
+ Seq((null.asInstanceOf[java.lang.Long], "a"), (null.asInstanceOf[java.lang.Long], "b")),
+ "null_tbl") {
+ val res = sql("SELECT sum(_1) FROM null_tbl")
+ checkSparkAnswerAndOperator(res)
+ }
+ }
+ }
+ }
+
test("ANSI support for decimal sum - null test") {
Seq(true, false).foreach { ansiEnabled =>
- withSQLConf(
- SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
- CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
withParquetTable(
Seq(
(null.asInstanceOf[java.math.BigDecimal], "a"),
@@ -1490,11 +1500,22 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ test("ANSI support for try_sum - null test") {
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+ withParquetTable(
+ Seq((null.asInstanceOf[java.lang.Long], "a"), (null.asInstanceOf[java.lang.Long], "b")),
+ "null_tbl") {
+ val res = sql("SELECT try_sum(_1) FROM null_tbl")
+ checkSparkAnswerAndOperator(res)
+ }
+ }
+ }
+ }
+
test("ANSI support for try_sum decimal - null test") {
Seq(true, false).foreach { ansiEnabled =>
- withSQLConf(
- SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
- CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
withParquetTable(
Seq(
(null.asInstanceOf[java.math.BigDecimal], "a"),
@@ -1508,11 +1529,28 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ test("ANSI support for sum - null test (group by)") {
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+ withParquetTable(
+ Seq(
+ (null.asInstanceOf[java.lang.Long], "a"),
+ (null.asInstanceOf[java.lang.Long], "a"),
+ (null.asInstanceOf[java.lang.Long], "b"),
+ (null.asInstanceOf[java.lang.Long], "b"),
+ (null.asInstanceOf[java.lang.Long], "b")),
+ "tbl") {
+ val res = sql("SELECT _2, sum(_1) FROM tbl group by 1")
+ checkSparkAnswerAndOperator(res)
+ assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null)))
+ }
+ }
+ }
+ }
+
test("ANSI support for decimal sum - null test (group by)") {
Seq(true, false).foreach { ansiEnabled =>
- withSQLConf(
- SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
- CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
withParquetTable(
Seq(
(null.asInstanceOf[java.math.BigDecimal], "a"),
@@ -1529,11 +1567,27 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ test("ANSI support for try_sum - null test (group by)") {
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+ withParquetTable(
+ Seq(
+ (null.asInstanceOf[java.lang.Long], "a"),
+ (null.asInstanceOf[java.lang.Long], "a"),
+ (null.asInstanceOf[java.lang.Long], "b"),
+ (null.asInstanceOf[java.lang.Long], "b"),
+ (null.asInstanceOf[java.lang.Long], "b")),
+ "tbl") {
+ val res = sql("SELECT _2, try_sum(_1) FROM tbl group by 1")
+ checkSparkAnswerAndOperator(res)
+ }
+ }
+ }
+ }
+
test("ANSI support for try_sum decimal - null test (group by)") {
Seq(true, false).foreach { ansiEnabled =>
- withSQLConf(
- SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
- CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
withParquetTable(
Seq(
(null.asInstanceOf[java.math.BigDecimal], "a"),
@@ -1544,7 +1598,6 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
"tbl") {
val res = sql("SELECT _2, try_sum(_1) FROM tbl group by 1")
checkSparkAnswerAndOperator(res)
- assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null)))
}
}
}
@@ -1555,11 +1608,64 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
(1 to 50).flatMap(_ => Seq((maxDec38_0, 1)))
}
+ test("ANSI support - SUM function") {
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+ // Test long overflow
+ withParquetTable(Seq((Long.MaxValue, 1L), (100L, 1L)), "tbl") {
+ val res = sql("SELECT SUM(_1) FROM tbl")
+ if (ansiEnabled) {
+ checkSparkAnswerMaybeThrows(res) match {
+ case (Some(sparkExc), Some(cometExc)) =>
+ // make sure that the error thrown is an overflow exception
+ assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+ assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+ case _ => fail("Exception should be thrown for Long overflow in ANSI mode")
+ }
+ } else {
+ checkSparkAnswerAndOperator(res)
+ }
+ }
+ // Test long underflow
+ withParquetTable(Seq((Long.MinValue, 1L), (-100L, 1L)), "tbl") {
+ val res = sql("SELECT SUM(_1) FROM tbl")
+ if (ansiEnabled) {
+ checkSparkAnswerMaybeThrows(res) match {
+ case (Some(sparkExc), Some(cometExc)) =>
+ assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+ assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+ case _ => fail("Exception should be thrown for Long underflow in ANSI mode")
+ }
+ } else {
+ checkSparkAnswerAndOperator(res)
+ }
+ }
+ // Test Int SUM (should not overflow)
+ withParquetTable(Seq((Int.MaxValue, 1), (Int.MaxValue, 1), (100, 1)), "tbl") {
+ val res = sql("SELECT SUM(_1) FROM tbl")
+ checkSparkAnswerAndOperator(res)
+ }
+ // Test Short SUM (should not overflow)
+ withParquetTable(
+ Seq((Short.MaxValue, 1.toShort), (Short.MaxValue, 1.toShort), (100.toShort, 1.toShort)),
+ "tbl") {
+ val res = sql("SELECT SUM(_1) FROM tbl")
+ checkSparkAnswerAndOperator(res)
+ }
+ // Test Byte SUM (should not overflow)
+ withParquetTable(
+ Seq((Byte.MaxValue, 1.toByte), (Byte.MaxValue, 1.toByte), (10.toByte, 1.toByte)),
+ "tbl") {
+ val res = sql("SELECT SUM(_1) FROM tbl")
+ checkSparkAnswerAndOperator(res)
+ }
+ }
+ }
+ }
+
test("ANSI support for decimal SUM function") {
Seq(true, false).foreach { ansiEnabled =>
- withSQLConf(
- SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
- CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
withParquetTable(generateOverflowDecimalInputs, "tbl") {
val res = sql("SELECT SUM(_1) FROM tbl")
if (ansiEnabled) {
@@ -1578,11 +1684,68 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ test("ANSI support for SUM - GROUP BY") {
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
+ withParquetTable(
+ Seq((Long.MaxValue, 1), (100L, 1), (Long.MaxValue, 2), (200L, 2)),
+ "tbl") {
+ val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2").repartition(2)
+ if (ansiEnabled) {
+ checkSparkAnswerMaybeThrows(res) match {
+ case (Some(sparkExc), Some(cometExc)) =>
+ assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+ assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+ case _ =>
+ fail("Exception should be thrown for Long overflow with GROUP BY in ANSI mode")
+ }
+ } else {
+ checkSparkAnswerAndOperator(res)
+ }
+ }
+
+ withParquetTable(
+ Seq((Long.MinValue, 1), (-100L, 1), (Long.MinValue, 2), (-200L, 2)),
+ "tbl") {
+ val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
+ if (ansiEnabled) {
+ checkSparkAnswerMaybeThrows(res) match {
+ case (Some(sparkExc), Some(cometExc)) =>
+ assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+ assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+ case _ =>
+ fail("Exception should be thrown for Long underflow with GROUP
BY in ANSI mode")
+ }
+ } else {
+ checkSparkAnswerAndOperator(res)
+ }
+ }
+ // Test Int with GROUP BY
+ withParquetTable(Seq((Int.MaxValue, 1), (Int.MaxValue, 1), (100, 2), (200, 2)), "tbl") {
+ val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
+ checkSparkAnswerAndOperator(res)
+ }
+ // Test Short with GROUP BY
+ withParquetTable(
+ Seq((Short.MaxValue, 1), (Short.MaxValue, 1), (100.toShort, 2), (200.toShort, 2)),
+ "tbl") {
+ val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
+ checkSparkAnswerAndOperator(res)
+ }
+ // Test Byte with GROUP BY
+ withParquetTable(
+ Seq((Byte.MaxValue, 1), (Byte.MaxValue, 1), (10.toByte, 2), (20.toByte, 2)),
+ "tbl") {
+ val res = sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2")
+ checkSparkAnswerAndOperator(res)
+ }
+ }
+ }
+ }
+
test("ANSI support for decimal SUM - GROUP BY") {
Seq(true, false).foreach { ansiEnabled =>
- withSQLConf(
- SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
- CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
withParquetTable(generateOverflowDecimalInputs, "tbl") {
val res =
sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2").repartition(2)
@@ -1602,35 +1765,69 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ test("try_sum overflow - with GROUP BY") {
+ // Test Long overflow with GROUP BY - some groups overflow while some don't
+ withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (200L, 2), (300L, 2)), "tbl") {
+ // repartition to trigger merge batch and state checks
+ val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+ // first group should return NULL (overflow) and group 2 should return 500
+ checkSparkAnswerAndOperator(res)
+ }
+
+ // Test Long underflow with GROUP BY
+ withParquetTable(Seq((Long.MinValue, 1), (-100L, 1), (-200L, 2), (-300L, 2)), "tbl") {
+ val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+ // first group should return NULL (underflow), second group should return -500
+ checkSparkAnswerAndOperator(res)
+ }
+
+ // Test all groups overflow
+ withParquetTable(Seq((Long.MaxValue, 1), (100L, 1), (Long.MaxValue, 2), (100L, 2)), "tbl") {
+ val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+ // Both groups should return NULL
+ checkSparkAnswerAndOperator(res)
+ }
+
+ // Test Short with GROUP BY (should NOT overflow)
+ withParquetTable(
+ Seq((Short.MaxValue, 1), (Short.MaxValue, 1), (100.toShort, 2), (200.toShort, 2)),
+ "tbl") {
+ val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+ checkSparkAnswerAndOperator(res)
+ }
+
+ // Test Byte with GROUP BY (no overflow)
+ withParquetTable(
+ Seq((Byte.MaxValue, 1), (Byte.MaxValue, 1), (10.toByte, 2), (20.toByte, 2)),
+ "tbl") {
+ val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+ checkSparkAnswerAndOperator(res)
+ }
+ }
+
test("try_sum decimal overflow") {
- withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
- withParquetTable(generateOverflowDecimalInputs, "tbl") {
- val res = sql("SELECT try_sum(_1) FROM tbl")
- checkSparkAnswerAndOperator(res)
- }
+ withParquetTable(generateOverflowDecimalInputs, "tbl") {
+ val res = sql("SELECT try_sum(_1) FROM tbl")
+ checkSparkAnswerAndOperator(res)
}
}
test("try_sum decimal overflow - with GROUP BY") {
- withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
- withParquetTable(generateOverflowDecimalInputs, "tbl") {
- val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
- checkSparkAnswerAndOperator(res)
- }
+ withParquetTable(generateOverflowDecimalInputs, "tbl") {
+ val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+ checkSparkAnswerAndOperator(res)
}
}
test("try_sum decimal partial overflow - with GROUP BY") {
- withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
- // Group 1 overflows, Group 2 succeeds
- val data: Seq[(java.math.BigDecimal, Int)] = generateOverflowDecimalInputs ++ Seq(
- (new java.math.BigDecimal(300), 2),
- (new java.math.BigDecimal(200), 2))
- withParquetTable(data, "tbl") {
- val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2")
- // Group 1 should be NULL, Group 2 should be 500
- checkSparkAnswerAndOperator(res)
- }
+ // Group 1 overflows, Group 2 succeeds
+ val data: Seq[(java.math.BigDecimal, Int)] = generateOverflowDecimalInputs ++ Seq(
+ (new java.math.BigDecimal(300), 2),
+ (new java.math.BigDecimal(200), 2))
+ withParquetTable(data, "tbl") {
+ val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2")
+ // Group 1 should be NULL, Group 2 should be 500
+ checkSparkAnswerAndOperator(res)
}
}
diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
index cf2b3dcdd..b1848ff51 100644
--- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE}
import org.apache.spark.sql.TPCDSBase
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Cast}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}
+import org.apache.spark.sql.catalyst.expressions.aggregate.Average
import org.apache.spark.sql.catalyst.util.resourceToString
import org.apache.spark.sql.execution.{FormattedMode, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec, SubqueryExec}
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
@@ -228,7 +228,6 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa
CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key -> "true",
// Allow Incompatible is needed for Sum + Average for Spark 4.0.0 / ANSI support
CometConf.getExprAllowIncompatConfigKey(classOf[Average]) -> "true",
- CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true",
// as well as for v1.4/q9, v1.4/q44, v2.7.0/q6, v2.7.0/q64
CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]