This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 18150fb6 chore: Don't transform the HashAggregate to CometHashAggregate if Comet shuffle is disabled (#991)
18150fb6 is described below
commit 18150fb60cee0f63e216a31338dee2217c70743d
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Oct 1 16:14:57 2024 -0700
chore: Don't transform the HashAggregate to CometHashAggregate if Comet shuffle is disabled (#991)
---
.../org/apache/comet/CometSparkSessionExtensions.scala | 18 +++++++++++++-----
1 file changed, 13 insertions(+), 5 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index d52f31f5..bf09d641 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -439,7 +439,11 @@ class CometSparkSessionExtensions
case op: BaseAggregateExec
if op.isInstanceOf[HashAggregateExec] ||
- op.isInstanceOf[ObjectHashAggregateExec] =>
+ op.isInstanceOf[ObjectHashAggregateExec] &&
+ // When Comet shuffle is disabled, we don't want to transform
the HashAggregate
+ // to CometHashAggregate. Otherwise, we probably get partial
Comet aggregation
+ // and final Spark aggregation.
+ isCometShuffleEnabled(conf) =>
val groupingExprs = op.groupingExpressions
val aggExprs = op.aggregateExpressions
val resultExpressions = op.resultExpressions
@@ -451,8 +455,10 @@ class CometSparkSessionExtensions
// Fallback to Spark nevertheless here.
op
} else {
+ // For a final mode HashAggregate, we only need to transform the
HashAggregate
+ // if there is Comet partial aggregation.
val sparkFinalMode = {
- !modes.isEmpty && modes.head == Final &&
findPartialAgg(child).isEmpty
+ !modes.isEmpty && modes.head == Final &&
findCometPartialAgg(child).isEmpty
}
if (sparkFinalMode) {
@@ -995,13 +1001,15 @@ class CometSparkSessionExtensions
* Find the first Comet partial aggregate in the plan. If it reaches a
Spark HashAggregate
* with partial mode, it will return None.
*/
- def findPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
+ def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] =
{
plan.collectFirst {
case agg: CometHashAggregateExec if
agg.aggregateExpressions.forall(_.mode == Partial) =>
Some(agg)
case agg: HashAggregateExec if agg.aggregateExpressions.forall(_.mode
== Partial) => None
- case a: AQEShuffleReadExec => findPartialAgg(a.child)
- case s: ShuffleQueryStageExec => findPartialAgg(s.plan)
+ case agg: ObjectHashAggregateExec if
agg.aggregateExpressions.forall(_.mode == Partial) =>
+ None
+ case a: AQEShuffleReadExec => findCometPartialAgg(a.child)
+ case s: ShuffleQueryStageExec => findCometPartialAgg(s.plan)
}.flatten
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]