This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 2bf283581 feat: Support ANSI mode avg expr (int inputs) (#2817)
2bf283581 is described below
commit 2bf2835812e818da976d192f5bb24e8947828dc4
Author: B Vadlamani <[email protected]>
AuthorDate: Mon Dec 29 07:29:19 2025 -0800
feat: Support ANSI mode avg expr (int inputs) (#2817)
---
docs/source/user-guide/latest/compatibility.md | 2 +-
native/core/src/execution/planner.rs | 10 +--
native/proto/src/proto/expr.proto | 2 +-
native/spark-expr/src/agg_funcs/avg.rs | 14 ++--
.../scala/org/apache/comet/serde/aggregates.scala | 15 +---
.../apache/comet/exec/CometAggregateSuite.scala | 83 ++++++++++++++++++++++
.../spark/sql/comet/CometPlanStabilitySuite.scala | 3 -
7 files changed, 99 insertions(+), 30 deletions(-)
diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md
index 35bf09724..31270404c 100644
--- a/docs/source/user-guide/latest/compatibility.md
+++ b/docs/source/user-guide/latest/compatibility.md
@@ -36,7 +36,7 @@ Comet will fall back to Spark for the following expressions when ANSI mode is en
`spark.comet.expression.EXPRNAME.allowIncompatible=true`, where `EXPRNAME` is
the Spark expression class name. See
the [Comet Supported Expressions Guide](expressions.md) for more information
on this configuration setting.
-- Average
+- Average (supports all numeric inputs except decimal types)
- Cast (in some cases)
There is an [epic](https://github.com/apache/datafusion-comet/issues/313)
where we are tracking the work to fully implement ANSI support.
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 8e8191dd0..93fbb59c1 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -1840,6 +1840,7 @@ impl PhysicalPlanner {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap());
+
let builder = match datatype {
DataType::Decimal128(_, _) => {
let func =
@@ -1847,12 +1848,11 @@ impl PhysicalPlanner {
AggregateExprBuilder::new(Arc::new(func), vec![child])
}
_ => {
- // cast to the result data type of AVG if the result data type is different
- // from the input type, e.g. AVG(Int32). We should not expect a cast
- // failure since it should have already been checked at Spark side.
+ // For all other numeric types (Int8/16/32/64, Float32/64):
+ // Cast to Float64 for accumulation
let child: Arc<dyn PhysicalExpr> =
- Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None));
- let func = AggregateUDF::new_from_impl(Avg::new("avg", datatype));
+ Arc::new(CastExpr::new(Arc::clone(&child), DataType::Float64, None));
+ let func = AggregateUDF::new_from_impl(Avg::new("avg", DataType::Float64));
AggregateExprBuilder::new(Arc::new(func), vec![child])
}
};
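Note on the change above: the non-decimal path can always accumulate in Float64 because f64 addition never raises an ANSI overflow error; it degrades toward +/-infinity, which Spark's average accepts as a valid result. A minimal standalone Rust sketch of that behavior (illustrative only, not code from this patch):

fn avg_as_f64(values: &[Option<i64>]) -> Option<f64> {
    // Mirror the Float64 path: cast each integer input to f64, then sum
    // without any overflow check.
    let (mut sum, mut count) = (0.0_f64, 0_i64);
    for v in values.iter().flatten() {
        sum += *v as f64; // f64 can reach +/-infinity, never an ANSI error
        count += 1;
    }
    (count > 0).then(|| sum / count as f64)
}

fn main() {
    // Summing two i64::MAX values would overflow i64, but is fine in f64.
    let vals = [Some(i64::MAX), Some(i64::MAX), None];
    println!("{:?}", avg_as_f64(&vals)); // avg is roughly 9.22e18, no error
}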
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 1c453b633..5f258fd67 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -138,7 +138,7 @@ message Avg {
Expr child = 1;
DataType datatype = 2;
DataType sum_datatype = 3;
- bool fail_on_error = 4; // currently unused (useful for deciding Ansi vs Legacy mode)
+ EvalMode eval_mode = 4;
}
message First {
diff --git a/native/spark-expr/src/agg_funcs/avg.rs b/native/spark-expr/src/agg_funcs/avg.rs
index e746aaf6e..d1d71cca2 100644
--- a/native/spark-expr/src/agg_funcs/avg.rs
+++ b/native/spark-expr/src/agg_funcs/avg.rs
@@ -73,7 +73,7 @@ impl AggregateUDFImpl for Avg {
}
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
- // instantiate specialized accumulator based for the type
+ // All numeric types use Float64 accumulation after casting
match (&self.input_data_type, &self.result_data_type) {
(Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
_ => not_impl_err!(
@@ -115,7 +115,6 @@ impl AggregateUDFImpl for Avg {
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
- // instantiate specialized accumulator based for the type
match (&self.input_data_type, &self.result_data_type) {
(Float64, Float64) => Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
&self.input_data_type,
@@ -172,7 +171,7 @@ impl Accumulator for AvgAccumulator {
// counts are summed
self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();
- // sums are summed
+ // sums are summed - no overflow checking in all Eval Modes
if let Some(x) = sum(states[0].as_primitive::<Float64Type>()) {
let v = self.sum.get_or_insert(0.);
*v += x;
@@ -182,7 +181,7 @@ impl Accumulator for AvgAccumulator {
fn evaluate(&mut self) -> Result<ScalarValue> {
if self.count == 0 {
- // If all input are nulls, count will be 0 and we will get null after the division.
+ // If all input are nulls, count will be 0, and we will get null after the division.
// This is consistent with Spark Average implementation.
Ok(ScalarValue::Float64(None))
} else {
@@ -198,7 +197,8 @@ impl Accumulator for AvgAccumulator {
}
/// An accumulator to compute the average of `[PrimitiveArray<T>]`.
-/// Stores values as native types, and does overflow checking
+/// Stores values as native types (
+/// no overflow check all eval modes since inf is a perfectly valid value per spark impl)
///
/// F: Function that calculates the average value from a sum of
/// T::Native and a total count
@@ -260,6 +260,7 @@ where
if values.null_count() == 0 {
for (&group_index, &value) in iter {
let sum = &mut self.sums[group_index];
+ // No overflow checking - Infinity is a valid result
*sum = (*sum).add_wrapping(value);
self.counts[group_index] += 1;
}
@@ -296,7 +297,7 @@ where
self.counts[group_index] += partial_count;
}
- // update sums
+ // update sums - no overflow checking (in all eval modes)
self.sums.resize(total_num_groups, T::default_value());
let iter2 = group_indices.iter().zip(partial_sums.values().iter());
for (&group_index, &new_value) in iter2 {
@@ -325,7 +326,6 @@ where
Ok(Arc::new(array))
}
- // return arrays for sums and counts
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
let counts = emit_to.take_needed(&mut self.counts);
let counts = Int64Array::new(counts.into(), None);
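For context on the accumulator comments above: each partition produces a (sum, count) pair, the pairs are merged by plain addition with no overflow checking in any eval mode, and an all-null input (count == 0) evaluates to NULL rather than an error. A rough standalone Rust sketch under those assumptions (not code from this commit):

fn merge_and_evaluate(partials: &[(Option<f64>, i64)]) -> Option<f64> {
    let mut sum: Option<f64> = None;
    let mut count = 0_i64;
    for &(partial_sum, partial_count) in partials {
        // counts are summed
        count += partial_count;
        // sums are summed, with no overflow checking
        if let Some(x) = partial_sum {
            *sum.get_or_insert(0.0) += x;
        }
    }
    if count == 0 {
        None // all inputs were null: NULL result, matching Spark's Average
    } else {
        Some(sum.unwrap_or(0.0) / count as f64)
    }
}

fn main() {
    let partials = [(Some(30.0), 2), (Some(300.0), 2), (None, 0)];
    println!("{:?}", merge_and_evaluate(&partials)); // Some(82.5)
    println!("{:?}", merge_and_evaluate(&[(None, 0)])); // None (NULL)
}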
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index a05efaebb..8e58c0874 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde
import scala.jdk.CollectionConverters._
-import org.apache.spark.sql.catalyst.expressions.{Attribute, EvalMode}
+import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType}
@@ -151,17 +151,6 @@ object CometCount extends CometAggregateExpressionSerde[Count] {
object CometAverage extends CometAggregateExpressionSerde[Average] {
- override def getSupportLevel(avg: Average): SupportLevel = {
- avg.evalMode match {
- case EvalMode.ANSI =>
- Incompatible(Some("ANSI mode is not supported"))
- case EvalMode.TRY =>
- Incompatible(Some("TRY mode is not supported"))
- case _ =>
- Compatible()
- }
- }
-
override def convert(
aggExpr: AggregateExpression,
avg: Average,
@@ -193,7 +182,7 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
val builder = ExprOuterClass.Avg.newBuilder()
builder.setChild(childExpr.get)
builder.setDatatype(dataType.get)
- builder.setFailOnError(avg.evalMode == EvalMode.ANSI)
+ builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(avg.evalMode)))
builder.setSumDatatype(sumDataType.get)
Some(
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 9b2816c2f..14b5dc309 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -1471,6 +1471,89 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ test("AVG and try_avg - basic functionality") {
+ withParquetTable(
+ Seq(
+ (10L, 1),
+ (20L, 1),
+ (null.asInstanceOf[Long], 1),
+ (100L, 2),
+ (200L, 2),
+ (null.asInstanceOf[Long], 3)),
+ "tbl") {
+
+ Seq(true, false).foreach({ ansiMode =>
+ // without GROUP BY
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
+ val res = sql("SELECT avg(_1) FROM tbl")
+ checkSparkAnswerAndOperator(res)
+ }
+
+ // with GROUP BY
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
+ val res = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2")
+ checkSparkAnswerAndOperator(res)
+ }
+ })
+
+ // try_avg without GROUP BY
+ val resTry = sql("SELECT try_avg(_1) FROM tbl")
+ checkSparkAnswerAndOperator(resTry)
+
+ // try_avg with GROUP BY
+ val resTryGroup = sql("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2")
+ checkSparkAnswerAndOperator(resTryGroup)
+
+ }
+ }
+
+ test("AVG and try_avg - special numbers") {
+
+ val negativeNumbers: Seq[(Long, Int)] = Seq(
+ (-1L, 1),
+ (-123L, 1),
+ (-456L, 1),
+ (Long.MinValue, 1),
+ (Long.MinValue, 1),
+ (Long.MinValue, 2),
+ (Long.MinValue, 2),
+ (null.asInstanceOf[Long], 3))
+
+ val zeroSeq: Seq[(Long, Int)] =
+ Seq((0L, 1), (-0L, 1), (+0L, 2), (+0L, 2), (null.asInstanceOf[Long], 3))
+
+ val highValNumbers: Seq[(Long, Int)] = Seq(
+ (Long.MaxValue, 1),
+ (Long.MaxValue, 1),
+ (Long.MaxValue, 2),
+ (Long.MaxValue, 2),
+ (null.asInstanceOf[Long], 3))
+
+ val inputs = Seq(negativeNumbers, highValNumbers, zeroSeq)
+ inputs.foreach(inputSeq => {
+ withParquetTable(inputSeq, "tbl") {
+ Seq(true, false).foreach({ ansiMode =>
+ // without GROUP BY
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
+ checkSparkAnswerAndOperator("SELECT avg(_1) FROM tbl")
+ }
+
+ // with GROUP BY
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
+ checkSparkAnswerAndOperator("SELECT _2, avg(_1) FROM tbl GROUP BY
_2")
+ }
+ })
+
+ // try_avg without GROUP BY
+ checkSparkAnswerAndOperator("SELECT try_avg(_1) FROM tbl")
+
+ // try_avg with GROUP BY
+ checkSparkAnswerAndOperator("SELECT _2, try_avg(_1) FROM tbl GROUP BY
_2")
+
+ }
+ })
+ }
+
test("ANSI support for sum - null test") {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
index b1848ff51..adf74ba54 100644
--- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
@@ -29,7 +29,6 @@ import org.apache.spark.SparkContext
import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE}
import org.apache.spark.sql.TPCDSBase
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Cast}
-import org.apache.spark.sql.catalyst.expressions.aggregate.Average
import org.apache.spark.sql.catalyst.util.resourceToString
import org.apache.spark.sql.execution.{FormattedMode, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec, SubqueryExec}
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
@@ -226,8 +225,6 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa
CometConf.COMET_DPP_FALLBACK_ENABLED.key -> "false",
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key -> "true",
- // Allow Incompatible is needed for Sum + Average for Spark 4.0.0 / ANSI support
- CometConf.getExprAllowIncompatConfigKey(classOf[Average]) -> "true",
// as well as for v1.4/q9, v1.4/q44, v2.7.0/q6, v2.7.0/q64
CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]