This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 2bf283581 feat: Support ANSI mode avg expr (int inputs) (#2817)
2bf283581 is described below

commit 2bf2835812e818da976d192f5bb24e8947828dc4
Author: B Vadlamani <[email protected]>
AuthorDate: Mon Dec 29 07:29:19 2025 -0800

    feat: Support ANSI mode avg expr (int inputs) (#2817)
---
 docs/source/user-guide/latest/compatibility.md     |  2 +-
 native/core/src/execution/planner.rs               | 10 +--
 native/proto/src/proto/expr.proto                  |  2 +-
 native/spark-expr/src/agg_funcs/avg.rs             | 14 ++--
 .../scala/org/apache/comet/serde/aggregates.scala  | 15 +---
 .../apache/comet/exec/CometAggregateSuite.scala    | 83 ++++++++++++++++++++++
 .../spark/sql/comet/CometPlanStabilitySuite.scala  |  3 -
 7 files changed, 99 insertions(+), 30 deletions(-)

diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md
index 35bf09724..31270404c 100644
--- a/docs/source/user-guide/latest/compatibility.md
+++ b/docs/source/user-guide/latest/compatibility.md
@@ -36,7 +36,7 @@ Comet will fall back to Spark for the following expressions when ANSI mode is en
 `spark.comet.expression.EXPRNAME.allowIncompatible=true`, where `EXPRNAME` is the Spark expression class name. See
 the [Comet Supported Expressions Guide](expressions.md) for more information on this configuration setting.
 
-- Average
+- Average (supports all numeric inputs except decimal types)
 - Cast (in some cases)
 
 There is an [epic](https://github.com/apache/datafusion-comet/issues/313) where we are tracking the work to fully implement ANSI support.
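
For orientation, a hedged usage sketch of the doc change above (the config names are the ones described in the guide; whether a plan stays fully native still depends on the rest of the query). As I read the change, avg over integer inputs no longer needs the override when ANSI mode is on, while decimal inputs remain opt-in:

    // Assumed spark-shell session, Scala API
    spark.conf.set("spark.sql.ansi.enabled", "true")
    spark.range(0, 1000).selectExpr("avg(id)").show()  // integer input: handled by Comet's native avg
    // Decimal inputs are still listed as unsupported under ANSI; enabling them is an explicit opt-in:
    spark.conf.set("spark.comet.expression.Average.allowIncompatible", "true")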
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 8e8191dd0..93fbb59c1 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -1840,6 +1840,7 @@ impl PhysicalPlanner {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
                 let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
                 let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap());
+
                 let builder = match datatype {
                     DataType::Decimal128(_, _) => {
                         let func =
@@ -1847,12 +1848,11 @@ impl PhysicalPlanner {
                         AggregateExprBuilder::new(Arc::new(func), vec![child])
                     }
                     _ => {
-                        // cast to the result data type of AVG if the result data type is different
-                        // from the input type, e.g. AVG(Int32). We should not expect a cast
-                        // failure since it should have already been checked at Spark side.
+                        // For all other numeric types (Int8/16/32/64, Float32/64):
+                        // Cast to Float64 for accumulation
                         let child: Arc<dyn PhysicalExpr> =
-                            Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None));
-                        let func = AggregateUDF::new_from_impl(Avg::new("avg", datatype));
+                            Arc::new(CastExpr::new(Arc::clone(&child), DataType::Float64, None));
+                        let func = AggregateUDF::new_from_impl(Avg::new("avg", DataType::Float64));
                         AggregateExprBuilder::new(Arc::new(func), vec![child])
                     }
                 };
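
To make the new comment concrete, a small plain-Scala sketch (illustrative only, not Comet or DataFusion code) of why accumulating in Float64 sidesteps ANSI overflow errors: 64-bit integer addition wraps (or raises under ANSI), while Double addition degrades gracefully, with Infinity as the worst case, and Spark's Average treats that as a valid result.

    object Float64AccumulationSketch {
      def main(args: Array[String]): Unit = {
        val xs = Seq(Long.MaxValue, Long.MaxValue)
        // Long accumulation overflows: Long.MaxValue + Long.MaxValue wraps to -2
        println(xs.reduce(_ + _))
        // Float64 accumulation is well-defined (average of about 9.22e18 here) ...
        println(xs.map(_.toDouble).sum / xs.size)
        // ... and at worst saturates to Infinity rather than raising an error
        println(Double.MaxValue + Double.MaxValue)
      }
    }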
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 1c453b633..5f258fd67 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -138,7 +138,7 @@ message Avg {
   Expr child = 1;
   DataType datatype = 2;
   DataType sum_datatype = 3;
-  bool fail_on_error = 4; // currently unused (useful for deciding Ansi vs Legacy mode)
+  EvalMode eval_mode = 4;
 }
 
 message First {
diff --git a/native/spark-expr/src/agg_funcs/avg.rs b/native/spark-expr/src/agg_funcs/avg.rs
index e746aaf6e..d1d71cca2 100644
--- a/native/spark-expr/src/agg_funcs/avg.rs
+++ b/native/spark-expr/src/agg_funcs/avg.rs
@@ -73,7 +73,7 @@ impl AggregateUDFImpl for Avg {
     }
 
     fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
-        // instantiate specialized accumulator based for the type
+        // All numeric types use Float64 accumulation after casting
         match (&self.input_data_type, &self.result_data_type) {
             (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
             _ => not_impl_err!(
@@ -115,7 +115,6 @@ impl AggregateUDFImpl for Avg {
         &self,
         _args: AccumulatorArgs,
     ) -> Result<Box<dyn GroupsAccumulator>> {
-        // instantiate specialized accumulator based for the type
         match (&self.input_data_type, &self.result_data_type) {
             (Float64, Float64) => Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
                 &self.input_data_type,
@@ -172,7 +171,7 @@ impl Accumulator for AvgAccumulator {
         // counts are summed
         self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();
 
-        // sums are summed
+        // sums are summed - no overflow checking in any eval mode
         if let Some(x) = sum(states[0].as_primitive::<Float64Type>()) {
             let v = self.sum.get_or_insert(0.);
             *v += x;
@@ -182,7 +181,7 @@ impl Accumulator for AvgAccumulator {
 
     fn evaluate(&mut self) -> Result<ScalarValue> {
         if self.count == 0 {
-            // If all input are nulls, count will be 0 and we will get null after the division.
+            // If all inputs are null, count will be 0, and we will get null after the division.
             // This is consistent with Spark Average implementation.
             Ok(ScalarValue::Float64(None))
         } else {
@@ -198,7 +197,8 @@ impl Accumulator for AvgAccumulator {
 }
 
 /// An accumulator to compute the average of `[PrimitiveArray<T>]`.
-/// Stores values as native types, and does overflow checking
+/// Stores values as native types, with no overflow checking in any eval mode,
+/// since infinity is a perfectly valid value per the Spark implementation.
 ///
 /// F: Function that calculates the average value from a sum of
 /// T::Native and a total count
@@ -260,6 +260,7 @@ where
         if values.null_count() == 0 {
             for (&group_index, &value) in iter {
                 let sum = &mut self.sums[group_index];
+                // No overflow checking - Infinity is a valid result
                 *sum = (*sum).add_wrapping(value);
                 self.counts[group_index] += 1;
             }
@@ -296,7 +297,7 @@ where
             self.counts[group_index] += partial_count;
         }
 
-        // update sums
+        // update sums - no overflow checking (in any eval mode)
         self.sums.resize(total_num_groups, T::default_value());
         let iter2 = group_indices.iter().zip(partial_sums.values().iter());
         for (&group_index, &new_value) in iter2 {
@@ -325,7 +326,6 @@ where
         Ok(Arc::new(array))
     }
 
-    // return arrays for sums and counts
     fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
         let counts = emit_to.take_needed(&mut self.counts);
         let counts = Int64Array::new(counts.into(), None);
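
As a summary of the accumulator semantics this file now implements, here is a minimal Scala sketch (an illustrative paraphrase under my own naming, not the Rust code): a Float64 running sum plus an Int64 count, unchecked additions in every eval mode, and a null result when no non-null input was seen.

    // Hypothetical names (AvgState, update, evaluate); behavior mirrors the diff above.
    final case class AvgState(sum: Option[Double] = None, count: Long = 0L) {
      def update(value: Option[Double]): AvgState = value match {
        case Some(v) => AvgState(Some(sum.getOrElse(0.0) + v), count + 1) // unchecked add
        case None    => this                                              // nulls are skipped
      }
      def evaluate: Option[Double] =
        if (count == 0) None                    // all inputs null => null result, as in Spark
        else Some(sum.getOrElse(0.0) / count)
    }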
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index a05efaebb..8e58c0874 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde
 
 import scala.jdk.CollectionConverters._
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, EvalMode}
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType}
@@ -151,17 +151,6 @@ object CometCount extends CometAggregateExpressionSerde[Count] {
 
 object CometAverage extends CometAggregateExpressionSerde[Average] {
 
-  override def getSupportLevel(avg: Average): SupportLevel = {
-    avg.evalMode match {
-      case EvalMode.ANSI =>
-        Incompatible(Some("ANSI mode is not supported"))
-      case EvalMode.TRY =>
-        Incompatible(Some("TRY mode is not supported"))
-      case _ =>
-        Compatible()
-    }
-  }
-
   override def convert(
       aggExpr: AggregateExpression,
       avg: Average,
@@ -193,7 +182,7 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
       val builder = ExprOuterClass.Avg.newBuilder()
       builder.setChild(childExpr.get)
       builder.setDatatype(dataType.get)
-      builder.setFailOnError(avg.evalMode == EvalMode.ANSI)
+      builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(avg.evalMode)))
       builder.setSumDatatype(sumDataType.get)
 
       Some(
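
For context on why the eval mode is now carried through, a hedged sketch of how each mode reaches this serde path (assuming a table named tbl, as in the new tests below): LEGACY is the default, enabling spark.sql.ansi.enabled switches Average to ANSI, and the try_avg function uses TRY.

    // Assumed spark-shell session, Scala API
    spark.sql("SELECT avg(_1) FROM tbl")      // LEGACY (default)
    spark.conf.set("spark.sql.ansi.enabled", "true")
    spark.sql("SELECT avg(_1) FROM tbl")      // ANSI
    spark.sql("SELECT try_avg(_1) FROM tbl")  // TRY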
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 9b2816c2f..14b5dc309 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -1471,6 +1471,89 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("AVG and try_avg - basic functionality") {
+    withParquetTable(
+      Seq(
+        (10L, 1),
+        (20L, 1),
+        (null.asInstanceOf[Long], 1),
+        (100L, 2),
+        (200L, 2),
+        (null.asInstanceOf[Long], 3)),
+      "tbl") {
+
+      Seq(true, false).foreach({ ansiMode =>
+        // without GROUP BY
+        withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
+          val res = sql("SELECT avg(_1) FROM tbl")
+          checkSparkAnswerAndOperator(res)
+        }
+
+        // with GROUP BY
+        withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
+          val res = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2")
+          checkSparkAnswerAndOperator(res)
+        }
+      })
+
+      // try_avg without GROUP BY
+      val resTry = sql("SELECT try_avg(_1) FROM tbl")
+      checkSparkAnswerAndOperator(resTry)
+
+      // try_avg with GROUP BY
+      val resTryGroup = sql("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2")
+      checkSparkAnswerAndOperator(resTryGroup)
+
+    }
+  }
+
+  test("AVG and try_avg - special numbers") {
+
+    val negativeNumbers: Seq[(Long, Int)] = Seq(
+      (-1L, 1),
+      (-123L, 1),
+      (-456L, 1),
+      (Long.MinValue, 1),
+      (Long.MinValue, 1),
+      (Long.MinValue, 2),
+      (Long.MinValue, 2),
+      (null.asInstanceOf[Long], 3))
+
+    val zeroSeq: Seq[(Long, Int)] =
+      Seq((0L, 1), (-0L, 1), (+0L, 2), (+0L, 2), (null.asInstanceOf[Long], 3))
+
+    val highValNumbers: Seq[(Long, Int)] = Seq(
+      (Long.MaxValue, 1),
+      (Long.MaxValue, 1),
+      (Long.MaxValue, 2),
+      (Long.MaxValue, 2),
+      (null.asInstanceOf[Long], 3))
+
+    val inputs = Seq(negativeNumbers, highValNumbers, zeroSeq)
+    inputs.foreach(inputSeq => {
+      withParquetTable(inputSeq, "tbl") {
+        Seq(true, false).foreach({ ansiMode =>
+          // without GROUP BY
+          withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
+            checkSparkAnswerAndOperator("SELECT avg(_1) FROM tbl")
+          }
+
+          // with GROUP BY
+          withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
+            checkSparkAnswerAndOperator("SELECT _2, avg(_1) FROM tbl GROUP BY _2")
+          }
+        })
+
+        // try_avg without GROUP BY
+        checkSparkAnswerAndOperator("SELECT try_avg(_1) FROM tbl")
+
+        // try_avg with GROUP BY
+        checkSparkAnswerAndOperator("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2")
+
+      }
+    })
+  }
+
   test("ANSI support for sum - null test") {
     Seq(true, false).foreach { ansiEnabled =>
       withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
index b1848ff51..adf74ba54 100644
--- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
@@ -29,7 +29,6 @@ import org.apache.spark.SparkContext
 import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE}
 import org.apache.spark.sql.TPCDSBase
 import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Cast}
-import org.apache.spark.sql.catalyst.expressions.aggregate.Average
 import org.apache.spark.sql.catalyst.util.resourceToString
 import org.apache.spark.sql.execution.{FormattedMode, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec, SubqueryExec}
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
@@ -226,8 +225,6 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa
       CometConf.COMET_DPP_FALLBACK_ENABLED.key -> "false",
       CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
       CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key -> "true",
-      // Allow Incompatible is needed for Sum + Average for Spark 4.0.0 / ANSI support
-      CometConf.getExprAllowIncompatConfigKey(classOf[Average]) -> "true",
       // as well as for v1.4/q9, v1.4/q44, v2.7.0/q6, v2.7.0/q64
       CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
