This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 18150fb6 chore: Don't transform the HashAggregate to CometHashAggregate if Comet shuffle is disabled (#991)
18150fb6 is described below

commit 18150fb60cee0f63e216a31338dee2217c70743d
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Oct 1 16:14:57 2024 -0700

    chore: Don't transform the HashAggregate to CometHashAggregate if Comet shuffle is disabled (#991)
---
 .../org/apache/comet/CometSparkSessionExtensions.scala | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index d52f31f5..bf09d641 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -439,7 +439,11 @@ class CometSparkSessionExtensions
 
         case op: BaseAggregateExec
             if op.isInstanceOf[HashAggregateExec] ||
-              op.isInstanceOf[ObjectHashAggregateExec] =>
+              op.isInstanceOf[ObjectHashAggregateExec] &&
+              // When Comet shuffle is disabled, we don't want to transform the HashAggregate
+              // to CometHashAggregate. Otherwise, we probably get partial Comet aggregation
+              // and final Spark aggregation.
+              isCometShuffleEnabled(conf) =>
           val groupingExprs = op.groupingExpressions
           val aggExprs = op.aggregateExpressions
           val resultExpressions = op.resultExpressions
@@ -451,8 +455,10 @@ class CometSparkSessionExtensions
             // Fallback to Spark nevertheless here.
             op
           } else {
+            // For a final mode HashAggregate, we only need to transform the HashAggregate
+            // if there is Comet partial aggregation.
             val sparkFinalMode = {
-              !modes.isEmpty && modes.head == Final && findPartialAgg(child).isEmpty
+              !modes.isEmpty && modes.head == Final && findCometPartialAgg(child).isEmpty
             }
 
             if (sparkFinalMode) {
@@ -995,13 +1001,15 @@ class CometSparkSessionExtensions
      * Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate
      * with partial mode, it will return None.
      */
-    def findPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
+    def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
       plan.collectFirst {
         case agg: CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
           Some(agg)
         case agg: HashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) => None
-        case a: AQEShuffleReadExec => findPartialAgg(a.child)
-        case s: ShuffleQueryStageExec => findPartialAgg(s.plan)
+        case agg: ObjectHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
+          None
+        case a: AQEShuffleReadExec => findCometPartialAgg(a.child)
+        case s: ShuffleQueryStageExec => findCometPartialAgg(s.plan)
       }.flatten
     }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to