This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new ec8da307 fix: Reuse CometBroadcastExchangeExec with Spark ReuseExchangeAndSubquery rule (#441)
ec8da307 is described below
commit ec8da3072911ad96ced59d02a06352f896b969c3
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sat May 18 09:26:11 2024 -0700
fix: Reuse CometBroadcastExchangeExec with Spark ReuseExchangeAndSubquery rule (#441)
---
.../apache/comet/CometSparkSessionExtensions.scala | 6 ++--
.../sql/comet/CometBroadcastExchangeExec.scala | 4 +--
.../org/apache/spark/sql/comet/operators.scala | 4 +--
.../org/apache/comet/exec/CometExecSuite.scala | 37 +++++++++++++++++++++-
4 files changed, 44 insertions(+), 7 deletions(-)
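
Background on the fix: Spark's ReuseExchangeAndSubquery rule deduplicates repeated exchange subtrees in a physical plan by wrapping later occurrences in ReusedExchangeExec, and it can only do so when the duplicated exchanges compare equal via their equals/hashCode, which is presumably why `output` is dropped from the equality of CometBroadcastExchangeExec and CometProjectExec in this patch. The sketch below is only an illustration of that reuse pattern over simplified, hypothetical plan types (Plan, Scan, BroadcastExchange, ReusedExchange, Join are stand-ins, not the Spark or Comet classes):

import scala.collection.mutable

// Hypothetical, simplified plan nodes -- illustrative stand-ins, not Spark classes.
sealed trait Plan
case class Scan(table: String) extends Plan
case class BroadcastExchange(child: Plan) extends Plan
case class ReusedExchange(original: BroadcastExchange) extends Plan
case class Join(left: Plan, right: Plan) extends Plan

// Replace repeated, equal broadcast exchanges with a reuse marker, in the spirit of
// ReuseExchangeAndSubquery producing ReusedExchangeExec. Case-class equality of
// BroadcastExchange plays the role that the equals/hashCode overrides play for
// CometBroadcastExchangeExec: if two occurrences do not compare equal, nothing is reused.
def reuseExchanges(plan: Plan): Plan = {
  val seen = mutable.Map.empty[BroadcastExchange, BroadcastExchange]
  def transform(p: Plan): Plan = p match {
    case e: BroadcastExchange =>
      seen.get(e) match {
        case Some(first) => ReusedExchange(first) // duplicate: point at the first occurrence
        case None        => seen(e) = e; e        // first occurrence: keep it
      }
    case Join(l, r) => Join(transform(l), transform(r))
    case other      => other
  }
  transform(plan)
}

// Two structurally identical broadcasts over the same scan collapse into one.
val scan = Scan("t")
val examplePlan = Join(BroadcastExchange(scan), Join(BroadcastExchange(scan), scan))
println(reuseExchanges(examplePlan))
// Join(BroadcastExchange(Scan(t)),Join(ReusedExchange(BroadcastExchange(Scan(t))),Scan(t)))
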
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 7ddc950e..7c269c41 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -46,7 +46,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.comet.CometConf._
-import org.apache.comet.CometSparkSessionExtensions.{createMessage, isANSIEnabled, isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, shouldApplyRowToColumnar, withInfo, withInfos}
+import org.apache.comet.CometSparkSessionExtensions.{createMessage, isANSIEnabled, isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, isSpark34Plus, shouldApplyRowToColumnar, withInfo, withInfos}
import org.apache.comet.parquet.{CometParquetScan, SupportsComet}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.QueryPlanSerde
@@ -576,11 +576,13 @@ class CometSparkSessionExtensions
// exchange. It is only used for Comet native execution. We only transform Spark broadcast
// exchange to Comet broadcast exchange if its downstream is a Comet native plan or if the
// broadcast exchange is forced to be enabled by Comet config.
+      // Note that `CometBroadcastExchangeExec` is only supported for Spark 3.4+.
case plan if plan.children.exists(_.isInstanceOf[BroadcastExchangeExec]) =>
val newChildren = plan.children.map {
case b: BroadcastExchangeExec
if isCometNative(b.child) &&
- isCometOperatorEnabled(conf, "broadcastExchangeExec") =>
+ isCometOperatorEnabled(conf, "broadcastExchangeExec") &&
+ isSpark34Plus => // Spark 3.4+ only
QueryPlanSerde.operator2Proto(b) match {
case Some(nativeOp) =>
val cometOp = CometBroadcastExchangeExec(b, b.child)
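
As the comment in this hunk notes, the Spark broadcast exchange is only rewritten to CometBroadcastExchangeExec when its downstream is a Comet native plan or when Comet broadcast is force-enabled by configuration. A minimal, illustrative way to force it, borrowing CometConf.COMET_EXEC_BROADCAST_FORCE_ENABLED from the test added later in this commit (the session setup itself is plain Spark and not part of this patch):

import org.apache.comet.CometConf
import org.apache.spark.sql.SparkSession

// Illustrative only: force Comet broadcast exchange for a session, mirroring the
// withSQLConf settings used in the CometExecSuite test added below.
val spark = SparkSession.builder().appName("comet-broadcast-force").getOrCreate()
spark.conf.set(CometConf.COMET_EXEC_BROADCAST_FORCE_ENABLED.key, "true")
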
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index 06c5898f..7bd34deb 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -238,13 +238,13 @@ case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan)
obj match {
case other: CometBroadcastExchangeExec =>
this.originalPlan == other.originalPlan &&
- this.output == other.output && this.child == other.child
+ this.child == other.child
case _ =>
false
}
}
- override def hashCode(): Int = Objects.hashCode(output, child)
+ override def hashCode(): Int = Objects.hashCode(child)
override def stringArgs: Iterator[Any] = Iterator(output, child)
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index ad07ff0e..63587af3 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -406,14 +406,14 @@ case class CometProjectExec(
obj match {
case other: CometProjectExec =>
this.projectList == other.projectList &&
- this.output == other.output && this.child == other.child &&
+ this.child == other.child &&
this.serializedPlanOpt == other.serializedPlanOpt
case _ =>
false
}
}
- override def hashCode(): Int = Objects.hashCode(projectList, output, child)
+ override def hashCode(): Int = Objects.hashCode(projectList, child)
override protected def outputExpressions: Seq[NamedExpression] = projectList
}
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 8f022988..2e144428 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -38,8 +38,9 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometHashJoinExec, CometProjectExec, CometRowToColumnarExec, CometScanExec, CometSortExec, CometSortMergeJoinExec, CometTakeOrderedAndProjectExec}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, date_add, expr, lead, sum}
@@ -62,6 +63,40 @@ class CometExecSuite extends CometTestBase {
}
}
+ test("ReusedExchangeExec should work on CometBroadcastExchangeExec") {
+    assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+")
+ withSQLConf(
+ CometConf.COMET_EXEC_BROADCAST_FORCE_ENABLED.key -> "true",
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+ SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+ withTempPath { path =>
+ spark
+ .range(5)
+ .withColumn("p", $"id" % 2)
+ .write
+ .mode("overwrite")
+ .partitionBy("p")
+ .parquet(path.toString)
+ withTempView("t") {
+ spark.read.parquet(path.toString).createOrReplaceTempView("t")
+ val df = sql("""
+ |SELECT t1.id, t2.id, t3.id
+ |FROM t AS t1
+ |JOIN t AS t2 ON t2.id = t1.id
+ |JOIN t AS t3 ON t3.id = t2.id
+ |WHERE t1.p = 1 AND t2.p = 1 AND t3.p = 1
+ |""".stripMargin)
+        val reusedPlan = ReuseExchangeAndSubquery.apply(df.queryExecution.executedPlan)
+        val reusedExchanges = collect(reusedPlan) { case r: ReusedExchangeExec =>
+ r
+ }
+ assert(reusedExchanges.size == 1)
+        assert(reusedExchanges.head.child.isInstanceOf[CometBroadcastExchangeExec])
+ }
+ }
+ }
+ }
+
test("CometShuffleExchangeExec logical link should be correct") {
withTempView("v") {
spark.sparkContext
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]