This is an automated email from the ASF dual-hosted git repository.

mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 0d63bc13c minor: Small refactor for consistent serde for hash aggregate (#2764)
0d63bc13c is described below

commit 0d63bc13c20e47570b4f2dc1a994df6920e8bc4e
Author: Andy Grove <[email protected]>
AuthorDate: Wed Nov 12 15:15:39 2025 -0700

    minor: Small refactor for consistent serde for hash aggregate (#2764)
---
 .../org/apache/comet/rules/CometExecRule.scala     | 88 ++++++++--------------
 .../comet/serde/operator/CometAggregate.scala      | 41 +++++++++-
 .../org/apache/spark/sql/comet/operators.scala     | 10 ++-
 3 files changed, 81 insertions(+), 58 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index 708a48d6a..c10a8b5af 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -23,7 +23,6 @@ import scala.collection.mutable.ListBuffer
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, 
EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, 
GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, 
NamedExpression, Remainder}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
 import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, 
RangePartitioning, RoundRobinPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -32,7 +31,7 @@ import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, 
CometNativeShuffle, CometShuffleExchangeExec, CometShuffleManager}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, 
AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
-import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, 
HashAggregateExec, ObjectHashAggregateExec}
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, 
ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.command.ExecutedCommandExec
 import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, 
ReusedExchangeExec, ShuffleExchangeExec}
@@ -232,44 +231,37 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
           op,
           CometExpandExec(_, op, op.output, op.projections, op.child, 
SerializedPlan(None)))
 
-      // When Comet shuffle is disabled, we don't want to transform the 
HashAggregate
-      // to CometHashAggregate. Otherwise, we probably get partial Comet 
aggregation
-      // and final Spark aggregation.
-      case op: BaseAggregateExec
-          if op.isInstanceOf[HashAggregateExec] ||
-            op.isInstanceOf[ObjectHashAggregateExec] &&
-            isCometShuffleEnabled(conf) =>
-        val modes = op.aggregateExpressions.map(_.mode).distinct
-        // In distinct aggregates there can be a combination of modes
-        val multiMode = modes.size > 1
-        // For a final mode HashAggregate, we only need to transform the 
HashAggregate
-        // if there is Comet partial aggregation.
-        val sparkFinalMode = modes.contains(Final) && 
findCometPartialAgg(op.child).isEmpty
-
-        if (multiMode || sparkFinalMode) {
-          op
-        } else {
-          newPlanWithProto(
-            op,
-            nativeOp => {
-              // The aggExprs could be empty. For example, if the aggregate 
functions only have
-              // distinct aggregate functions or only have group by, the 
aggExprs is empty and
-              // modes is empty too. If aggExprs is not empty, we need to 
verify all the
-              // aggregates have the same mode.
-              assert(modes.length == 1 || modes.isEmpty)
-              CometHashAggregateExec(
-                nativeOp,
-                op,
-                op.output,
-                op.groupingExpressions,
-                op.aggregateExpressions,
-                op.resultExpressions,
-                op.child.output,
-                modes.headOption,
-                op.child,
-                SerializedPlan(None))
-            })
-        }
+      case op: HashAggregateExec =>
+        newPlanWithProto(
+          op,
+          nativeOp => {
+            CometHashAggregateExec(
+              nativeOp,
+              op,
+              op.output,
+              op.groupingExpressions,
+              op.aggregateExpressions,
+              op.resultExpressions,
+              op.child.output,
+              op.child,
+              SerializedPlan(None))
+          })
+
+      case op: ObjectHashAggregateExec =>
+        newPlanWithProto(
+          op,
+          nativeOp => {
+            CometHashAggregateExec(
+              nativeOp,
+              op,
+              op.output,
+              op.groupingExpressions,
+              op.aggregateExpressions,
+              op.resultExpressions,
+              op.child.output,
+              op.child,
+              SerializedPlan(None))
+          })
 
       case op: ShuffledHashJoinExec
           if CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) &&
@@ -738,22 +730,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     }
   }
 
-  /**
-   * Find the first Comet partial aggregate in the plan. If it reaches a Spark 
HashAggregate with
-   * partial mode, it will return None.
-   */
-  private def findCometPartialAgg(plan: SparkPlan): 
Option[CometHashAggregateExec] = {
-    plan.collectFirst {
-      case agg: CometHashAggregateExec if 
agg.aggregateExpressions.forall(_.mode == Partial) =>
-        Some(agg)
-      case agg: HashAggregateExec if agg.aggregateExpressions.forall(_.mode == 
Partial) => None
-      case agg: ObjectHashAggregateExec if 
agg.aggregateExpressions.forall(_.mode == Partial) =>
-        None
-      case a: AQEShuffleReadExec => findCometPartialAgg(a.child)
-      case s: ShuffleQueryStageExec => findCometPartialAgg(s.plan)
-    }.flatten
-  }
-
   /**
    * Returns true if a given spark plan is Comet shuffle operator.
    */
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometAggregate.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometAggregate.scala
index 93e5d52c8..b0c359f08 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometAggregate.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometAggregate.scala
@@ -23,11 +23,14 @@ import scala.jdk.CollectionConverters._
 
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
+import org.apache.spark.sql.comet.CometHashAggregateExec
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, 
ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, 
HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.types.MapType
 
 import org.apache.comet.{CometConf, ConfigEntry}
-import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleEnabled, 
withInfo}
 import org.apache.comet.serde.{CometOperatorSerde, OperatorOuterClass}
 import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => 
CometAggregateMode, Operator}
 import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto}
@@ -38,6 +41,18 @@ trait CometBaseAggregate {
       aggregate: BaseAggregateExec,
       builder: Operator.Builder,
       childOp: OperatorOuterClass.Operator*): 
Option[OperatorOuterClass.Operator] = {
+
+    val modes = aggregate.aggregateExpressions.map(_.mode).distinct
+    // In distinct aggregates there can be a combination of modes
+    val multiMode = modes.size > 1
+    // For a final mode HashAggregate, we only need to transform the 
HashAggregate
+    // if there is Comet partial aggregation.
+    val sparkFinalMode = modes.contains(Final) && 
findCometPartialAgg(aggregate.child).isEmpty
+
+    if (multiMode || sparkFinalMode) {
+      return None
+    }
+
     val groupingExpressions = aggregate.groupingExpressions
     val aggregateExpressions = aggregate.aggregateExpressions
     val aggregateAttributes = aggregate.aggregateAttributes
@@ -163,6 +178,22 @@ trait CometBaseAggregate {
 
   }
 
+  /**
+   * Find the first Comet partial aggregate in the plan. If it reaches a Spark 
HashAggregate with
+   * partial mode, it will return None.
+   */
+  private def findCometPartialAgg(plan: SparkPlan): 
Option[CometHashAggregateExec] = {
+    plan.collectFirst {
+      case agg: CometHashAggregateExec if 
agg.aggregateExpressions.forall(_.mode == Partial) =>
+        Some(agg)
+      case agg: HashAggregateExec if agg.aggregateExpressions.forall(_.mode == 
Partial) => None
+      case agg: ObjectHashAggregateExec if 
agg.aggregateExpressions.forall(_.mode == Partial) =>
+        None
+      case a: AQEShuffleReadExec => findCometPartialAgg(a.child)
+      case s: ShuffleQueryStageExec => findCometPartialAgg(s.plan)
+    }.flatten
+  }
+
 }
 
 object CometHashAggregate extends CometOperatorSerde[HashAggregateExec] with 
CometBaseAggregate {
@@ -189,6 +220,14 @@ object CometObjectHashAggregate
       aggregate: ObjectHashAggregateExec,
       builder: Operator.Builder,
       childOp: OperatorOuterClass.Operator*): 
Option[OperatorOuterClass.Operator] = {
+
+    if (!isCometShuffleEnabled(aggregate.conf)) {
+      // When Comet shuffle is disabled, we don't want to transform the 
HashAggregate
+      // to CometHashAggregate. Otherwise, we probably get partial Comet 
aggregation
+      // and final Spark aggregation.
+      return None
+    }
+
     doConvert(aggregate, builder, childOp: _*)
   }
 }
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index d7a743eb2..3d1fccc98 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -739,11 +739,19 @@ case class CometHashAggregateExec(
     aggregateExpressions: Seq[AggregateExpression],
     resultExpressions: Seq[NamedExpression],
     input: Seq[Attribute],
-    mode: Option[AggregateMode],
     child: SparkPlan,
     override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec
     with PartitioningPreservingUnaryExecNode {
+
+  // The aggExprs could be empty. For example, if the aggregate functions only 
have
+  // distinct aggregate functions or only have group by, the aggExprs is empty 
and
+  // modes is empty too. If aggExprs is not empty, we need to verify all the
+  // aggregates have the same mode.
+  val modes: Seq[AggregateMode] = aggregateExpressions.map(_.mode).distinct
+  assert(modes.length == 1 || modes.isEmpty)
+  val mode = modes.headOption
+
   override def producedAttributes: AttributeSet = outputSet ++ 
AttributeSet(resultExpressions)
 
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to