This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new b326637  fix: CometShuffleExchangeExec logical link should be correct (#324)
b326637 is described below

commit b326637de6fdd711569413d4f6319ef122b51a81
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Apr 29 23:10:21 2024 -0700

    fix: CometShuffleExchangeExec logical link should be correct (#324)
    
    * fix: CometShuffleExchangeExec logical link should be correct
    
    * Implement equals and hashCode for CometShuffleExchangeExec
    
    * Update plan stability
    
    * Restore plan stability
    
    * Dedup test
    
    * Remove unused import
    
    * Fix test
    
    * Use columnar shuffle
---
 .../apache/comet/CometSparkSessionExtensions.scala | 34 ++++++++++++++++++++--
 .../shuffle/CometShuffleExchangeExec.scala         | 21 +++++++++++++
 .../org/apache/spark/sql/comet/operators.scala     |  8 +++++
 .../comet/shims/ShimCometShuffleExchangeExec.scala |  1 +
 .../org/apache/comet/exec/CometExecSuite.scala     | 31 ++++++++++++++++++--
 .../org/apache/spark/sql/CometTPCHQuerySuite.scala |  1 +
 6 files changed, 91 insertions(+), 5 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 8ef8cb8..1708544 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -741,9 +741,37 @@ class CometSparkSessionExtensions
         }
 
         // Set up logical links
-        newPlan = newPlan.transform { case op: CometExec =>
-          op.originalPlan.logicalLink.foreach(op.setLogicalLink)
-          op
+        newPlan = newPlan.transform {
+          case op: CometExec =>
+            if (op.originalPlan.logicalLink.isEmpty) {
+              op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
+              op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
+            } else {
+              op.originalPlan.logicalLink.foreach(op.setLogicalLink)
+            }
+            op
+          case op: CometShuffleExchangeExec =>
+            // Original Spark shuffle exchange operator might have empty logical link.
+            // But the `setLogicalLink` call above on downstream operator of
+            // `CometShuffleExchangeExec` will set its logical link to the downstream
+            // operators which cause AQE behavior to be incorrect. So we need to unset
+            // the logical link here.
+            if (op.originalPlan.logicalLink.isEmpty) {
+              op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
+              op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
+            } else {
+              op.originalPlan.logicalLink.foreach(op.setLogicalLink)
+            }
+            op
+
+          case op: CometBroadcastExchangeExec =>
+            if (op.originalPlan.logicalLink.isEmpty) {
+              op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
+              op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
+            } else {
+              op.originalPlan.logicalLink.foreach(op.setLogicalLink)
+            }
+            op
         }
 
         // Convert native execution block by linking consecutive native 
operators.
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 232b6bf..fb2f2a2 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -50,6 +50,8 @@ import org.apache.spark.util.MutablePair
 import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, 
RecordComparator}
 import org.apache.spark.util.random.XORShiftRandom
 
+import com.google.common.base.Objects
+
 import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, 
QueryPlanSerde}
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde.serializeDataType
@@ -61,6 +63,7 @@ import org.apache.comet.shims.ShimCometShuffleExchangeExec
 case class CometShuffleExchangeExec(
     override val outputPartitioning: Partitioning,
     child: SparkPlan,
+    originalPlan: ShuffleExchangeLike,
     shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS,
     shuffleType: ShuffleType = CometNativeShuffle,
     advisoryPartitionSize: Option[Long] = None)
@@ -192,6 +195,24 @@ case class CometShuffleExchangeExec(
 
   override protected def withNewChildInternal(newChild: SparkPlan): 
CometShuffleExchangeExec =
     copy(child = newChild)
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case other: CometShuffleExchangeExec =>
+        this.outputPartitioning == other.outputPartitioning &&
+        this.shuffleOrigin == other.shuffleOrigin && this.child == other.child 
&&
+        this.shuffleType == other.shuffleType &&
+        this.advisoryPartitionSize == other.advisoryPartitionSize
+      case _ =>
+        false
+    }
+  }
+
+  override def hashCode(): Int =
+    Objects.hashCode(outputPartitioning, shuffleOrigin, shuffleType, 
advisoryPartitionSize, child)
+
+  override def stringArgs: Iterator[Any] =
+    Iterator(outputPartitioning, shuffleOrigin, shuffleType, child) ++ 
Iterator(s"[plan_id=$id]")
 }
 
 object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 4e6d997..ad07ff0 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -240,6 +240,7 @@ abstract class CometNativeExec extends CometExec {
         val firstNonBroadcastPlan = sparkPlans.zipWithIndex.find {
           case (_: CometBroadcastExchangeExec, _) => false
           case (BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _), 
_) => false
+          case (BroadcastQueryStageExec(_, _: ReusedExchangeExec, _), _) => 
false
           case _ => true
         }
 
@@ -264,6 +265,13 @@ abstract class CometNativeExec extends CometExec {
               inputs += 
c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
             case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) 
=>
               inputs += 
c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
+            case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) =>
+              inputs += 
c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
+            case BroadcastQueryStageExec(
+                  _,
+                  ReusedExchangeExec(_, c: CometBroadcastExchangeExec),
+                  _) =>
+              inputs += 
c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
             case _ if idx == firstNonBroadcastPlan.get._2 =>
               inputs += firstNonBroadcastPlanRDD
             case _ =>
diff --git 
a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
 
b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
index f89dbb8..6b4fad9 100644
--- 
a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
+++ 
b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
@@ -32,6 +32,7 @@ trait ShimCometShuffleExchangeExec {
     CometShuffleExchangeExec(
       s.outputPartitioning,
       s.child,
+      s,
       s.shuffleOrigin,
       shuffleType,
       advisoryPartitionSize)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 3eb20aa..f371835 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -38,7 +38,7 @@ import 
org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
 import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, 
CometBroadcastHashJoinExec, CometCollectLimitExec, CometFilterExec, 
CometHashAggregateExec, CometHashJoinExec, CometProjectExec, 
CometRowToColumnarExec, CometScanExec, CometSortExec, CometSortMergeJoinExec, 
CometTakeOrderedAndProjectExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, 
CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, 
SQLExecution, UnionExec}
-import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, 
ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, 
CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.expressions.Window
@@ -62,6 +62,29 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("CometShuffleExchangeExec logical link should be correct") {
+    withTempView("v") {
+      spark.sparkContext
+        .parallelize((1 to 4).map(i => TestData(i, i.toString)), 2)
+        .toDF("c1", "c2")
+        .createOrReplaceTempView("v")
+
+      Seq(true, false).foreach { columnarShuffle =>
+        withSQLConf(
+          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> 
columnarShuffle.toString) {
+          val df = sql("SELECT * FROM v where c1 = 1 order by c1, c2")
+          val shuffle = find(df.queryExecution.executedPlan) {
+            case _: CometShuffleExchangeExec if columnarShuffle => true
+            case _: ShuffleExchangeExec if !columnarShuffle => true
+            case _ => false
+          }.get
+          assert(shuffle.logicalLink.isEmpty)
+        }
+      }
+    }
+  }
+
   test("Ensure that the correct outputPartitioning of CometSort") {
     withTable("test_data") {
       val tableDF = spark.sparkContext
@@ -302,7 +325,8 @@ class CometExecSuite extends CometTestBase {
     withSQLConf(
       CometConf.COMET_EXEC_ENABLED.key -> "true",
       CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
-      "spark.sql.autoBroadcastJoinThreshold" -> "0",
+      "spark.sql.adaptive.autoBroadcastJoinThreshold" -> "-1",
+      "spark.sql.autoBroadcastJoinThreshold" -> "-1",
       "spark.sql.join.preferSortMergeJoin" -> "true") {
       withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl1") {
         withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl2") {
@@ -407,6 +431,7 @@ class CometExecSuite extends CometTestBase {
     withSQLConf(
       SQLConf.EXCHANGE_REUSE_ENABLED.key -> "true",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
       CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
       CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
       withTable(tableName, dim) {
@@ -1340,3 +1365,5 @@ case class BucketedTableTestSpec(
     expectedShuffle: Boolean = true,
     expectedSort: Boolean = true,
     expectedNumOutputPartitions: Option[Int] = None)
+
+case class TestData(key: Int, value: String)
diff --git 
a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala 
b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala
index 6ec25dd..8d7111e 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala
@@ -91,6 +91,7 @@ class CometTPCHQuerySuite extends QueryTest with CometTPCBase 
with SQLQueryTestH
     conf.set(CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key, "true")
     conf.set(CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key, "true")
     conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true")
+    conf.set(CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key, "true")
     conf.set(MEMORY_OFFHEAP_ENABLED.key, "true")
     conf.set(MEMORY_OFFHEAP_SIZE.key, "2g")
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to