This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
     new b326637  fix: CometShuffleExchangeExec logical link should be correct (#324)
b326637 is described below
commit b326637de6fdd711569413d4f6319ef122b51a81
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Apr 29 23:10:21 2024 -0700
fix: CometShuffleExchangeExec logical link should be correct (#324)
* fix: CometShuffleExchangeExec logical link should be correct
* Implement equals and hashCode for CometShuffleExchangeExec
* Update plan stability
* Restore plan stability
* Dedup test
* Remove unused import
* Fix test
* Use columnar shuffle
---
.../apache/comet/CometSparkSessionExtensions.scala | 34 ++++++++++++++++++++--
.../shuffle/CometShuffleExchangeExec.scala | 21 +++++++++++++
.../org/apache/spark/sql/comet/operators.scala | 8 +++++
.../comet/shims/ShimCometShuffleExchangeExec.scala | 1 +
.../org/apache/comet/exec/CometExecSuite.scala | 31 ++++++++++++++++++--
.../org/apache/spark/sql/CometTPCHQuerySuite.scala | 1 +
6 files changed, 91 insertions(+), 5 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 8ef8cb8..1708544 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -741,9 +741,37 @@ class CometSparkSessionExtensions
}
     // Set up logical links
-    newPlan = newPlan.transform { case op: CometExec =>
-      op.originalPlan.logicalLink.foreach(op.setLogicalLink)
-      op
+    newPlan = newPlan.transform {
+      case op: CometExec =>
+        if (op.originalPlan.logicalLink.isEmpty) {
+          op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
+          op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
+        } else {
+          op.originalPlan.logicalLink.foreach(op.setLogicalLink)
+        }
+        op
+      case op: CometShuffleExchangeExec =>
+        // The original Spark shuffle exchange operator might have an empty
+        // logical link. But the `setLogicalLink` call above on the downstream
+        // operators of `CometShuffleExchangeExec` would set its logical link
+        // to that of the downstream operators, which causes incorrect AQE
+        // behavior. So we need to unset the logical link here.
+        if (op.originalPlan.logicalLink.isEmpty) {
+          op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
+          op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
+        } else {
+          op.originalPlan.logicalLink.foreach(op.setLogicalLink)
+        }
+        op
+
+      case op: CometBroadcastExchangeExec =>
+        if (op.originalPlan.logicalLink.isEmpty) {
+          op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
+          op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
+        } else {
+          op.originalPlan.logicalLink.foreach(op.setLogicalLink)
+        }
+        op
     }
     // Convert native execution block by linking consecutive native operators.
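The unset-or-set step above now repeats for all three operator kinds. A minimal sketch of the shared logic, using a hypothetical helper (`syncLogicalLink` is illustrative and not part of this commit):

    import org.apache.spark.sql.execution.SparkPlan

    // Hypothetical helper, not in the Comet codebase: copy the original
    // operator's logical link onto its Comet replacement, clearing any stale
    // link when the original had none. Without the clearing step, `transform`
    // leaves a link propagated from downstream operators, which misleads AQE.
    def syncLogicalLink(op: SparkPlan, original: SparkPlan): SparkPlan = {
      original.logicalLink match {
        case Some(link) => op.setLogicalLink(link)
        case None =>
          op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
          op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
      }
      op
    }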
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 232b6bf..fb2f2a2 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -50,6 +50,8 @@ import org.apache.spark.util.MutablePair
 import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
import org.apache.spark.util.random.XORShiftRandom
+import com.google.common.base.Objects
+
 import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, QueryPlanSerde}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.QueryPlanSerde.serializeDataType
@@ -61,6 +63,7 @@ import org.apache.comet.shims.ShimCometShuffleExchangeExec
case class CometShuffleExchangeExec(
override val outputPartitioning: Partitioning,
child: SparkPlan,
+ originalPlan: ShuffleExchangeLike,
shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS,
shuffleType: ShuffleType = CometNativeShuffle,
advisoryPartitionSize: Option[Long] = None)
@@ -192,6 +195,24 @@ case class CometShuffleExchangeExec(
   override protected def withNewChildInternal(newChild: SparkPlan): CometShuffleExchangeExec =
     copy(child = newChild)
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case other: CometShuffleExchangeExec =>
+        this.outputPartitioning == other.outputPartitioning &&
+        this.shuffleOrigin == other.shuffleOrigin && this.child == other.child &&
+        this.shuffleType == other.shuffleType &&
+        this.advisoryPartitionSize == other.advisoryPartitionSize
+      case _ =>
+        false
+    }
+  }
+
+  override def hashCode(): Int =
+    Objects.hashCode(outputPartitioning, shuffleOrigin, shuffleType, advisoryPartitionSize, child)
+
+  override def stringArgs: Iterator[Any] =
+    Iterator(outputPartitioning, shuffleOrigin, shuffleType, child) ++ Iterator(s"[plan_id=$id]")
}
object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
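Because `originalPlan` is now a constructor field, the case class's synthesized `equals`/`hashCode` would compare it too, so two otherwise-identical exchanges could stop matching under equality-keyed rules such as exchange reuse; the overrides above deliberately skip it. A self-contained sketch of the pattern with toy types (not Comet's classes):

    import com.google.common.base.Objects

    // A case class that carries a bookkeeping field deliberately excluded
    // from equality, as CometShuffleExchangeExec does with `originalPlan`.
    case class Exchange(partitioning: String, child: String, original: AnyRef) {
      override def equals(obj: Any): Boolean = obj match {
        case other: Exchange =>
          partitioning == other.partitioning && child == other.child
        case _ => false
      }
      override def hashCode(): Int = Objects.hashCode(partitioning, child)
    }

    // The two exchanges differ only in the captured original plan, yet still
    // compare equal, so equality-keyed reuse rules continue to match them.
    assert(Exchange("hash(c1)", "scan", new Object()) ==
      Exchange("hash(c1)", "scan", new Object()))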
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 4e6d997..ad07ff0 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -240,6 +240,7 @@ abstract class CometNativeExec extends CometExec {
     val firstNonBroadcastPlan = sparkPlans.zipWithIndex.find {
       case (_: CometBroadcastExchangeExec, _) => false
       case (BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _), _) => false
+      case (BroadcastQueryStageExec(_, _: ReusedExchangeExec, _), _) => false
       case _ => true
     }
@@ -264,6 +265,13 @@ abstract class CometNativeExec extends CometExec {
           inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
         case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) =>
           inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
+        case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) =>
+          inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
+        case BroadcastQueryStageExec(
+              _,
+              ReusedExchangeExec(_, c: CometBroadcastExchangeExec),
+              _) =>
+          inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
case _ if idx == firstNonBroadcastPlan.get._2 =>
inputs += firstNonBroadcastPlanRDD
case _ =>
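The added cases reach through reuse wrappers with nested extractors: a reused broadcast may appear bare or wrapped in a query stage. A self-contained sketch with toy node types (not Spark's) showing the same matching shape:

    // Toy plan nodes standing in for CometBroadcastExchangeExec,
    // ReusedExchangeExec and BroadcastQueryStageExec.
    sealed trait Node
    case class Broadcast(name: String) extends Node
    case class Reused(child: Node) extends Node
    case class Stage(plan: Node) extends Node

    // Nested patterns reach through both wrapper layers in one case each,
    // matching the shape of the new cases in CometNativeExec.
    def underlyingBroadcast(n: Node): Option[Broadcast] = n match {
      case b: Broadcast => Some(b)
      case Reused(b: Broadcast) => Some(b)
      case Stage(b: Broadcast) => Some(b)
      case Stage(Reused(b: Broadcast)) => Some(b)
      case _ => None
    }

    assert(underlyingBroadcast(Stage(Reused(Broadcast("b")))).contains(Broadcast("b")))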
diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
index f89dbb8..6b4fad9 100644
--- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
+++ b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
@@ -32,6 +32,7 @@ trait ShimCometShuffleExchangeExec {
CometShuffleExchangeExec(
s.outputPartitioning,
s.child,
+ s,
s.shuffleOrigin,
shuffleType,
advisoryPartitionSize)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 3eb20aa..f371835 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
 import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometHashJoinExec, CometProjectExec, CometRowToColumnarExec, CometScanExec, CometSortExec, CometSortMergeJoinExec, CometTakeOrderedAndProjectExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
-import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.expressions.Window
@@ -62,6 +62,29 @@ class CometExecSuite extends CometTestBase {
}
}
+  test("CometShuffleExchangeExec logical link should be correct") {
+    withTempView("v") {
+      spark.sparkContext
+        .parallelize((1 to 4).map(i => TestData(i, i.toString)), 2)
+        .toDF("c1", "c2")
+        .createOrReplaceTempView("v")
+
+      Seq(true, false).foreach { columnarShuffle =>
+        withSQLConf(
+          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> columnarShuffle.toString) {
+          val df = sql("SELECT * FROM v where c1 = 1 order by c1, c2")
+          val shuffle = find(df.queryExecution.executedPlan) {
+            case _: CometShuffleExchangeExec if columnarShuffle => true
+            case _: ShuffleExchangeExec if !columnarShuffle => true
+            case _ => false
+          }.get
+          assert(shuffle.logicalLink.isEmpty)
+        }
+      }
+    }
+  }
+
test("Ensure that the correct outputPartitioning of CometSort") {
withTable("test_data") {
val tableDF = spark.sparkContext
@@ -302,7 +325,8 @@ class CometExecSuite extends CometTestBase {
withSQLConf(
CometConf.COMET_EXEC_ENABLED.key -> "true",
CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
- "spark.sql.autoBroadcastJoinThreshold" -> "0",
+ "spark.sql.adaptive.autoBroadcastJoinThreshold" -> "-1",
+ "spark.sql.autoBroadcastJoinThreshold" -> "-1",
"spark.sql.join.preferSortMergeJoin" -> "true") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl1") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl2") {
@@ -407,6 +431,7 @@ class CometExecSuite extends CometTestBase {
withSQLConf(
SQLConf.EXCHANGE_REUSE_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
withTable(tableName, dim) {
@@ -1340,3 +1365,5 @@ case class BucketedTableTestSpec(
expectedShuffle: Boolean = true,
expectedSort: Boolean = true,
expectedNumOutputPartitions: Option[Int] = None)
+
+case class TestData(key: Int, value: String)
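The threshold changes above disable broadcast joins under both the static planner and AQE, since AQE replanning consults its own `spark.sql.adaptive.autoBroadcastJoinThreshold`; setting only the non-adaptive key is not enough to keep a sort-merge join. As a sketch, the pair expressed via `SQLConf` keys:

    import org.apache.spark.sql.internal.SQLConf

    // Disable broadcast joins for both the initial plan and AQE replanning,
    // so the suites keep their expected sort-merge-join shape.
    val noBroadcastConfs = Seq(
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1")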
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala
index 6ec25dd..8d7111e 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala
@@ -91,6 +91,7 @@ class CometTPCHQuerySuite extends QueryTest with CometTPCBase with SQLQueryTestH
conf.set(CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key, "true")
conf.set(CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key, "true")
conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true")
+ conf.set(CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key, "true")
conf.set(MEMORY_OFFHEAP_ENABLED.key, "true")
conf.set(MEMORY_OFFHEAP_SIZE.key, "2g")
}
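For reference, the shuffle and memory settings enabled above, as they might be applied when building a session outside the test harness; a hedged sketch (the builder usage is an assumption, the conf keys come from this diff):

    import org.apache.spark.sql.SparkSession
    import org.apache.comet.CometConf

    // Enable Comet shuffle with the columnar variant, plus the off-heap
    // memory the suite configures; values mirror the diff above.
    val spark = SparkSession.builder()
      .config(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true")
      .config(CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key, "true")
      .config("spark.memory.offHeap.enabled", "true")
      .config("spark.memory.offHeap.size", "2g")
      .getOrCreate()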
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]