This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 4639a2477 fix: Make AQE capable of converting Comet shuffled joins to Comet broadcast hash joins (#1605)
4639a2477 is described below
commit 4639a24772d29ff8fa45533c9314a4a9eff787c8
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Fri Apr 4 22:08:02 2025 +0800
fix: Make AQE capable of converting Comet shuffled joins to Comet broadcast hash joins (#1605)
* Override outputPartitioning in CometBroadcastExchangeExec to make AQE capable of converting Comet shuffled joins to Comet broadcast hash joins
* Support executeBroadcast for CometColumnarToRow
* Add tests, revert my changes to how CometBroadcastExchange is displayed
* Make the newly added test fail before applying this fix
* Remove unused imports
* Fix test failure caused by Spark conf pollution
---
.../execution/arrow}/ArrowReaderIterator.scala | 2 +-
.../org/apache/spark/sql/comet/util/Utils.scala | 25 ++++-
.../apache/comet/CometSparkSessionExtensions.scala | 27 +++++-
.../org/apache/comet/serde/QueryPlanSerde.scala | 2 +
.../sql/comet/CometBroadcastExchangeExec.scala | 34 ++-----
.../spark/sql/comet/CometColumnarToRowExec.scala | 104 ++++++++++++++++++++-
.../org/apache/comet/exec/CometExecSuite.scala | 97 ++++++++++++++++++-
7 files changed, 258 insertions(+), 33 deletions(-)
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowReaderIterator.scala
similarity index 97%
rename from spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
rename to common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowReaderIterator.scala
index 933e0b661..0d0093a10 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowReaderIterator.scala
@@ -17,7 +17,7 @@
* under the License.
*/
-package org.apache.spark.sql.comet.execution.shuffle
+package org.apache.spark.sql.comet.execution.arrow
import java.nio.channels.ReadableByteChannel
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
index 207474286..1b4275180 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
@@ -19,7 +19,7 @@
package org.apache.spark.sql.comet.util
-import java.io.{DataOutputStream, File}
+import java.io.{DataInputStream, DataOutputStream, File}
import java.nio.ByteBuffer
import java.nio.channels.Channels
@@ -35,6 +35,7 @@ import org.apache.arrow.vector.types._
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
import org.apache.spark.{SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
+import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
@@ -226,6 +227,28 @@ object Utils {
}
}
+ /**
+ * Decodes the byte arrays back into ColumnarBatches.
+ * @param bytes
+ * the serialized batches
+ * @param source
+ * the class that calls this method
+ * @return
+ * an iterator of ColumnarBatch
+ */
+ def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = {
+ if (bytes.size == 0) {
+ return Iterator.empty
+ }
+
+ // use Spark's compression codec (LZ4 by default) and not Comet's compression
+ val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+ val cbbis = bytes.toInputStream()
+ val ins = new DataInputStream(codec.compressedInputStream(cbbis))
+ // batches are in Arrow IPC format
+ new ArrowReaderIterator(Channels.newChannel(ins), source)
+ }
+
def getBatchFieldVectors(
batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
var provider: Option[DictionaryProvider] = None
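
The decodeBatches helper above is now shared code: it is reused by CometBatchRDD and by the new broadcast path in CometColumnarToRowExec, both further down in this patch. A minimal usage sketch, assuming Spark and Comet are on the classpath; the method name readSerializedBatches and the source label are illustrative only:

    import org.apache.spark.sql.comet.util.Utils
    import org.apache.spark.sql.vectorized.ColumnarBatch
    import org.apache.spark.util.io.ChunkedByteBuffer

    // Each ChunkedByteBuffer holds Arrow IPC batches compressed with Spark's
    // codec (LZ4 by default); decodeBatches lazily reverses both steps.
    def readSerializedBatches(buffers: Array[ChunkedByteBuffer]): Iterator[ColumnarBatch] =
      buffers.iterator.flatMap(Utils.decodeBatches(_, "illustrative-caller"))
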
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index f3d031795..daa1f19ea 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -21,6 +21,7 @@ package org.apache.comet
import java.nio.ByteOrder
+import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
import org.apache.spark.SparkConf
@@ -37,7 +38,7 @@ import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec, CometShuffleManager}
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, QueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
@@ -728,6 +729,18 @@ class CometSparkSessionExtensions
s
}
+ case s @ BroadcastQueryStageExec(
+ _,
+ ReusedExchangeExec(_, _: CometBroadcastExchangeExec),
+ _) =>
+ val newOp = transform1(s)
+ newOp match {
+ case Some(nativeOp) =>
+ CometSinkPlaceHolder(nativeOp, s, s)
+ case None =>
+ s
+ }
+
// `CometBroadcastExchangeExec`'s broadcast output is not compatible with Spark's broadcast
// exchange. It is only used for Comet native execution. We only transform Spark broadcast
// exchange to Comet broadcast exchange if its downstream is a Comet native plan or if the
@@ -739,7 +752,7 @@ class CometSparkSessionExtensions
CometConf.COMET_EXEC_BROADCAST_EXCHANGE_ENABLED.get(conf) =>
QueryPlanSerde.operator2Proto(b) match {
case Some(nativeOp) =>
- val cometOp = CometBroadcastExchangeExec(b, b.output, b.child)
+ val cometOp = CometBroadcastExchangeExec(b, b.output, b.mode, b.child)
CometSinkPlaceHolder(nativeOp, b, cometOp)
case None => b
}
@@ -1136,7 +1149,7 @@ class CometSparkSessionExtensions
// and CometSparkToColumnarExec
sparkToColumnar.child
}
- case c @ ColumnarToRowExec(child) if child.exists(_.isInstanceOf[CometPlan]) =>
+ case c @ ColumnarToRowExec(child) if hasCometNativeChild(child) =>
val op = CometColumnarToRowExec(child)
if (c.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
@@ -1170,6 +1183,14 @@ class CometSparkSessionExtensions
}
}
}
+
+ @tailrec
+ private def hasCometNativeChild(op: SparkPlan): Boolean = {
+ op match {
+ case c: QueryStageExec => hasCometNativeChild(c.plan)
+ case _ => op.exists(_.isInstanceOf[CometPlan])
+ }
+ }
}
object CometSparkSessionExtensions extends Logging {
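
For context on hasCometNativeChild above (this note is not part of the patch): under AQE an exchange is wrapped in a query stage, and QueryStageExec exposes the materialized stage through its `plan` field rather than as a child node, so the previous `child.exists(_.isInstanceOf[CometPlan])` test never reached a Comet exchange hidden inside a stage. Roughly, the plan shape the helper now sees through is:

    ColumnarToRowExec
    +- BroadcastQueryStageExec                    // no children from the traversal's point of view
         .plan = CometBroadcastExchangeExec(...)  // the CometPlan the rule needs to find
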
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 4f8f8ee86..50b418737 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2765,6 +2765,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => true
case _: TakeOrderedAndProjectExec => true
case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true
+ case BroadcastQueryStageExec(_, ReusedExchangeExec(_, _: CometBroadcastExchangeExec), _) =>
+ true
case _: BroadcastExchangeExec => true
case _: WindowExec => true
case _ => false
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index 3285159be..c17b2f785 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -19,8 +19,6 @@
package org.apache.spark.sql.comet
-import java.io.DataInputStream
-import java.nio.channels.Channels
import java.util.UUID
import java.util.concurrent.{Future, TimeoutException, TimeUnit}
@@ -28,13 +26,13 @@ import scala.concurrent.{ExecutionContext, Promise}
import scala.concurrent.duration.NANOSECONDS
import scala.util.control.NonFatal
-import org.apache.spark.{broadcast, Partition, SparkContext, SparkEnv, TaskContext}
-import org.apache.spark.io.CompressionCodec
+import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.Statistics
-import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
+import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
@@ -63,8 +61,10 @@ import org.apache.comet.CometRuntimeException
case class CometBroadcastExchangeExec(
originalPlan: SparkPlan,
override val output: Seq[Attribute],
+ mode: BroadcastMode,
override val child: SparkPlan)
- extends BroadcastExchangeLike {
+ extends BroadcastExchangeLike
+ with CometPlan {
import CometBroadcastExchangeExec._
override val runId: UUID = UUID.randomUUID
@@ -77,7 +77,7 @@ case class CometBroadcastExchangeExec(
"broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to
broadcast"))
override def doCanonicalize(): SparkPlan = {
- CometBroadcastExchangeExec(null, null, child.canonicalized)
+ CometBroadcastExchangeExec(null, null, mode, child.canonicalized)
}
override def runtimeStatistics: Statistics = {
@@ -86,6 +86,8 @@ case class CometBroadcastExchangeExec(
Statistics(dataSize, Some(rowCount))
}
+ override def outputPartitioning: Partitioning = BroadcastPartitioning(mode)
+
@transient
private lazy val promise = Promise[broadcast.Broadcast[Any]]()
@@ -289,23 +291,7 @@ class CometBatchRDD(
override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
val partition = split.asInstanceOf[CometBatchPartition]
partition.value.value.toIterator
- .flatMap(decodeBatches(_, this.getClass.getSimpleName))
- }
-
- /**
- * Decodes the byte arrays back to ColumnarBatchs and put them into buffer.
- */
- private def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = {
- if (bytes.size == 0) {
- return Iterator.empty
- }
-
- // use Spark's compression codec (LZ4 by default) and not Comet's compression
- val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
- val cbbis = bytes.toInputStream()
- val ins = new DataInputStream(codec.compressedInputStream(cbbis))
- // batches are in Arrow IPC format
- new ArrowReaderIterator(Channels.newChannel(ins), source)
+ .flatMap(Utils.decodeBatches(_, this.getClass.getSimpleName))
}
}
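
For context on the outputPartitioning override above (not part of the patch): Spark only accepts a node as the build side of a broadcast hash join when its partitioning satisfies the required BroadcastDistribution, and BroadcastPartitioning(mode) satisfies BroadcastDistribution(mode) exactly when the modes match. A minimal sketch of that check, runnable in a spark-shell, using Spark's IdentityBroadcastMode as a stand-in:

    import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, BroadcastPartitioning, IdentityBroadcastMode}

    // With the override in place, CometBroadcastExchangeExec advertises
    // BroadcastPartitioning(mode), which is what lets AQE keep the Comet
    // exchange when it replans a shuffled join as a broadcast hash join.
    val mode = IdentityBroadcastMode
    assert(BroadcastPartitioning(mode).satisfies(BroadcastDistribution(mode)))
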
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala
index 18d95a473..0391a1c3b 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala
@@ -19,20 +19,30 @@
package org.apache.spark.sql.comet
+import java.util.UUID
+import java.util.concurrent.{Future, TimeoutException, TimeUnit}
+
import scala.collection.JavaConverters._
+import scala.concurrent.Promise
+import scala.util.control.NonFatal
+import org.apache.spark.{broadcast, SparkException}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.execution.{CodegenSupport, ColumnarToRowTransition, SparkPlan}
+import org.apache.spark.sql.comet.util.{Utils => CometUtils}
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.{CodegenSupport, ColumnarToRowTransition, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.vectorized.{ConstantColumnVector, WritableColumnVector}
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SparkFatalException, Utils}
+import org.apache.spark.util.io.ChunkedByteBuffer
import org.apache.comet.vector.CometPlainVector
@@ -76,6 +86,96 @@ case class CometColumnarToRowExec(child: SparkPlan)
}
}
+ @transient
+ private lazy val promise = Promise[broadcast.Broadcast[Any]]()
+
+ @transient
+ private val timeout: Long = conf.broadcastTimeout
+
+ private val runId: UUID = UUID.randomUUID
+
+ private lazy val cometBroadcastExchange = findCometBroadcastExchange(child)
+
+ @transient
+ lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
+ SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
+ session,
+ CometBroadcastExchangeExec.executionContext) {
+ try {
+ // Setup a job group here so later it may get cancelled by groupId if necessary.
+ sparkContext.setJobGroup(
+ runId.toString,
+ s"CometColumnarToRow broadcast exchange (runId $runId)",
+ interruptOnCancel = true)
+
+ val numOutputRows = longMetric("numOutputRows")
+ val numInputBatches = longMetric("numInputBatches")
+ val localOutput = this.output
+ val broadcastColumnar = child.executeBroadcast()
+ val serializedBatches = broadcastColumnar.value.asInstanceOf[Array[ChunkedByteBuffer]]
+ val toUnsafe = UnsafeProjection.create(localOutput, localOutput)
+ val rows = serializedBatches.iterator
+ .flatMap(CometUtils.decodeBatches(_, this.getClass.getSimpleName))
+ .flatMap { batch =>
+ numInputBatches += 1
+ numOutputRows += batch.numRows()
+ batch.rowIterator().asScala.map(toUnsafe)
+ }
+
+ val mode = cometBroadcastExchange.get.mode
+ val relation = mode.transform(rows, Some(numOutputRows.value))
+ val broadcasted = sparkContext.broadcastInternal(relation, serializedOnly = true)
+ val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+ promise.trySuccess(broadcasted)
+ broadcasted
+ } catch {
+ // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
+ // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
+ // will catch this exception and re-throw the wrapped fatal throwable.
+ case oe: OutOfMemoryError =>
+ val ex = new SparkFatalException(oe)
+ promise.tryFailure(ex)
+ throw ex
+ case e if !NonFatal(e) =>
+ val ex = new SparkFatalException(e)
+ promise.tryFailure(ex)
+ throw ex
+ case e: Throwable =>
+ promise.tryFailure(e)
+ throw e
+ }
+ }
+ }
+
+ override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
+ if (cometBroadcastExchange.isEmpty) {
+ throw new SparkException(
+ "ColumnarToRowExec only supports doExecuteBroadcast when child
contains a " +
+ "CometBroadcastExchange, but got " + child)
+ }
+
+ try {
+ relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
+ } catch {
+ case ex: TimeoutException =>
+ logError(s"Could not execute broadcast in $timeout secs.", ex)
+ if (!relationFuture.isDone) {
+ sparkContext.cancelJobGroup(runId.toString)
+ relationFuture.cancel(true)
+ }
+ throw QueryExecutionErrors.executeBroadcastTimeoutError(timeout, Some(ex))
+ }
+ }
+
+ private def findCometBroadcastExchange(op: SparkPlan): Option[CometBroadcastExchangeExec] = {
+ op match {
+ case b: CometBroadcastExchangeExec => Some(b)
+ case b: BroadcastQueryStageExec => findCometBroadcastExchange(b.plan)
+ case _ => op.children.collectFirst(Function.unlift(findCometBroadcastExchange))
+ }
+ }
+
/**
* Generate [[ColumnVector]] expressions for our parent to consume as rows. This is called once
* per [[ColumnVector]] in the batch.
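
For context on relationFuture and doExecuteBroadcast above (not part of the patch): the new broadcast path decodes the broadcast Arrow batches, projects each row to UnsafeRow, and re-applies the original BroadcastMode so that a downstream Spark BroadcastHashJoinExec can consume the relation. A condensed sketch of that pipeline; the function name and source label are illustrative, and the authoritative logic is the diff above:

    import scala.collection.JavaConverters._
    import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
    import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
    import org.apache.spark.sql.comet.util.{Utils => CometUtils}
    import org.apache.spark.util.io.ChunkedByteBuffer

    def buildBroadcastRelation(
        serialized: Array[ChunkedByteBuffer],
        output: Seq[Attribute],
        mode: BroadcastMode): Any = {
      // Project columnar rows to UnsafeRow before handing them to the mode,
      // which builds e.g. a HashedRelation for a broadcast hash join.
      val toUnsafe = UnsafeProjection.create(output, output)
      val rows = serialized.iterator
        .flatMap(CometUtils.decodeBatches(_, "illustrative-source"))
        .flatMap(_.rowIterator().asScala.map(toUnsafe))
      mode.transform(rows, None) // row-count hint omitted in this sketch
    }
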
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index f8c1cf90f..3fe300e97 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -36,10 +36,10 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Bloom
import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.expressions.Window
@@ -750,6 +750,99 @@ class CometExecSuite extends CometTestBase {
}
}
+ test("Comet Shuffled Join should be optimized to CometBroadcastHashJoin by
AQE") {
+ withSQLConf(
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "10485760",
+ CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+ CometConf.COMET_SHUFFLE_MODE.key -> "native") {
+ withParquetTable((0 until 100).map(i => (i, i + 1)), "tbl_a") {
+ withParquetTable((0 until 100).map(i => (i, i + 2)), "tbl_b") {
+ withParquetTable((0 until 100).map(i => (i, i + 3)), "tbl_c") {
+ val df = sql("""SELECT /*+ BROADCAST(c) */ a1, sum_b2, c._2 FROM (
+ | SELECT a._1 a1, SUM(b._2) sum_b2 FROM tbl_a a
+ | JOIN tbl_b b ON a._1 = b._1
+ | GROUP BY a._1) t
+ |JOIN tbl_c c ON t.a1 = c._1
+ |""".stripMargin)
+ checkSparkAnswerAndOperator(df)
+
+ // Before AQE: 1 broadcast join
+ var broadcastHashJoinExec = stripAQEPlan(df.queryExecution.executedPlan).collect {
+ case s: CometBroadcastHashJoinExec => s
+ }
+ assert(broadcastHashJoinExec.length == 1)
+
+ // After AQE: shuffled join optimized to broadcast join
+ df.collect()
+ broadcastHashJoinExec = stripAQEPlan(df.queryExecution.executedPlan).collect {
+ case s: CometBroadcastHashJoinExec => s
+ }
+ assert(broadcastHashJoinExec.length == 2)
+ }
+ }
+ }
+ }
+ }
+
+ test("CometBroadcastExchange could be converted to rows using
CometColumnarToRow") {
+ withSQLConf(
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "10485760",
+ CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+ CometConf.COMET_SHUFFLE_MODE.key -> "auto") {
+ withParquetTable((0 until 100).map(i => (i, i + 1)), "tbl_a") {
+ withParquetTable((0 until 100).map(i => (i, i + 2)), "tbl_b") {
+ withParquetTable((0 until 100).map(i => (i, i + 3)), "tbl_c") {
+ val df = sql("""SELECT /*+ BROADCAST(c) */ a1, sum_b2, c._2 FROM (
+ | SELECT a._1 a1, SUM(b._2) sum_b2 FROM tbl_a a
+ | JOIN tbl_b b ON a._1 = b._1
+ | GROUP BY a._1) t
+ |JOIN tbl_c c ON t.a1 = c._1
+ |""".stripMargin)
+ checkSparkAnswer(df)
+
+ // Before AQE: one CometBroadcastExchange, no CometColumnarToRow
+ var columnarToRowExec = stripAQEPlan(df.queryExecution.executedPlan).collect {
+ case s: CometColumnarToRowExec => s
+ }
+ assert(columnarToRowExec.isEmpty)
+
+ // Disable CometExecRule after the initial plan is generated. The CometSortMergeJoin and
+ // CometBroadcastHashJoin nodes in the initial plan will be converted to Spark BroadcastHashJoin
+ // during AQE, which forces CometBroadcastExchangeExec to be converted to rows so that Spark's
+ // BroadcastHashJoin can consume them.
+ withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") {
+ df.collect()
+ }
+
+ // After AQE: CometBroadcastExchange has to be converted to rows so that Spark's
+ // BroadcastHashJoin can consume it.
+ val plan = stripAQEPlan(df.queryExecution.executedPlan)
+ columnarToRowExec = plan.collect { case s: CometColumnarToRowExec =>
+ s
+ }
+ assert(columnarToRowExec.length == 1)
+
+ // This ColumnarToRowExec should be the immediate child of BroadcastHashJoinExec
+ val parent = plan.find(_.children.contains(columnarToRowExec.head))
+ assert(parent.get.isInstanceOf[BroadcastHashJoinExec])
+
+ // There should be a CometBroadcastExchangeExec under CometColumnarToRowExec
+ val broadcastQueryStage =
+ columnarToRowExec.head.find(_.isInstanceOf[BroadcastQueryStageExec])
+ assert(broadcastQueryStage.isDefined)
+ assert(
+ broadcastQueryStage.get
+ .asInstanceOf[BroadcastQueryStageExec]
+ .broadcast
+ .isInstanceOf[CometBroadcastExchangeExec])
+ }
+ }
+ }
+ }
+ }
+
test("expand operator") {
val data1 = (0 until 1000)
.map(_ % 5) // reduce value space to trigger dictionary encoding
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]