This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 4639a2477 fix: Make AQE capable of converting Comet shuffled joins to 
Comet broadcast hash joins (#1605)
4639a2477 is described below

commit 4639a24772d29ff8fa45533c9314a4a9eff787c8
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Fri Apr 4 22:08:02 2025 +0800

    fix: Make AQE capable of converting Comet shuffled joins to Comet broadcast 
hash joins (#1605)
    
    * Override outputPartitioning in CometBroadcastExchangeExec so that AQE can convert Comet shuffled joins to Comet broadcast hash joins
    
    * Support executeBroadcast for CometColumnarToRow
    
    * Add tests, revert my changes to how CometBroadcastExchange is displayed
    
    * Make the newly added test fail without this fix
    
    * Remove unused imports
    
    * Fix test failure caused by spark conf pollution
---
 .../execution/arrow}/ArrowReaderIterator.scala     |   2 +-
 .../org/apache/spark/sql/comet/util/Utils.scala    |  25 ++++-
 .../apache/comet/CometSparkSessionExtensions.scala |  27 +++++-
 .../org/apache/comet/serde/QueryPlanSerde.scala    |   2 +
 .../sql/comet/CometBroadcastExchangeExec.scala     |  34 ++-----
 .../spark/sql/comet/CometColumnarToRowExec.scala   | 104 ++++++++++++++++++++-
 .../org/apache/comet/exec/CometExecSuite.scala     |  97 ++++++++++++++++++-
 7 files changed, 258 insertions(+), 33 deletions(-)

diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
 
b/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowReaderIterator.scala
similarity index 97%
rename from 
spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
rename to 
common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowReaderIterator.scala
index 933e0b661..0d0093a10 100644
--- 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
+++ 
b/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowReaderIterator.scala
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-package org.apache.spark.sql.comet.execution.shuffle
+package org.apache.spark.sql.comet.execution.arrow
 
 import java.nio.channels.ReadableByteChannel
 
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala 
b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
index 207474286..1b4275180 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
@@ -19,7 +19,7 @@
 
 package org.apache.spark.sql.comet.util
 
-import java.io.{DataOutputStream, File}
+import java.io.{DataInputStream, DataOutputStream, File}
 import java.nio.ByteBuffer
 import java.nio.channels.Channels
 
@@ -35,6 +35,7 @@ import org.apache.arrow.vector.types._
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
 import org.apache.spark.{SparkEnv, SparkException}
 import org.apache.spark.io.CompressionCodec
+import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.io.{ChunkedByteBuffer, 
ChunkedByteBufferOutputStream}
@@ -226,6 +227,28 @@ object Utils {
     }
   }
 
+  /**
+   * Decodes serialized Arrow IPC batches back into ColumnarBatches.
+   * @param bytes
+   *   the serialized batches, compressed with Spark's configured codec
+   * @param source
+   *   the class that calls this method
+   * @return
+   *   a lazy iterator of ColumnarBatch; empty if `bytes` is empty
+   */
+  def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = {
+    if (bytes.size == 0) {
+      Iterator.empty
+    } else {
+      // use Spark's compression codec (LZ4 by default) and not Comet's compression
+      val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+      val cbbis = bytes.toInputStream()
+      val ins = new DataInputStream(codec.compressedInputStream(cbbis))
+      // batches are in Arrow IPC format
+      new ArrowReaderIterator(Channels.newChannel(ins), source)
+    }
+  }
+
   def getBatchFieldVectors(
       batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
     var provider: Option[DictionaryProvider] = None
diff --git 
a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala 
b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index f3d031795..daa1f19ea 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -21,6 +21,7 @@ package org.apache.comet
 
 import java.nio.ByteOrder
 
+import scala.annotation.tailrec
 import scala.collection.mutable.ListBuffer
 
 import org.apache.spark.SparkConf
@@ -37,7 +38,7 @@ import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, 
CometNativeShuffle, CometShuffleExchangeExec, CometShuffleManager}
 import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, 
BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, 
BroadcastQueryStageExec, QueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, 
HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
@@ -728,6 +729,18 @@ class CometSparkSessionExtensions
               s
           }
 
+        case s @ BroadcastQueryStageExec(
+              _,
+              ReusedExchangeExec(_, _: CometBroadcastExchangeExec),
+              _) =>
+          val newOp = transform1(s)
+          newOp match {
+            case Some(nativeOp) =>
+              CometSinkPlaceHolder(nativeOp, s, s)
+            case None =>
+              s
+          }
+
         // `CometBroadcastExchangeExec`'s broadcast output is not compatible 
with Spark's broadcast
         // exchange. It is only used for Comet native execution. We only 
transform Spark broadcast
         // exchange to Comet broadcast exchange if its downstream is a Comet 
native plan or if the
@@ -739,7 +752,7 @@ class CometSparkSessionExtensions
                   CometConf.COMET_EXEC_BROADCAST_EXCHANGE_ENABLED.get(conf) =>
               QueryPlanSerde.operator2Proto(b) match {
                 case Some(nativeOp) =>
-                  val cometOp = CometBroadcastExchangeExec(b, b.output, 
b.child)
+                  val cometOp = CometBroadcastExchangeExec(b, b.output, 
b.mode, b.child)
                   CometSinkPlaceHolder(nativeOp, b, cometOp)
                 case None => b
               }
@@ -1136,7 +1149,7 @@ class CometSparkSessionExtensions
             // and CometSparkToColumnarExec
             sparkToColumnar.child
           }
-        case c @ ColumnarToRowExec(child) if 
child.exists(_.isInstanceOf[CometPlan]) =>
+        case c @ ColumnarToRowExec(child) if hasCometNativeChild(child) =>
           val op = CometColumnarToRowExec(child)
           if (c.logicalLink.isEmpty) {
             op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
@@ -1170,6 +1183,14 @@ class CometSparkSessionExtensions
       }
     }
   }
+
+  /** Whether `op` contains a CometPlan, looking through query stages and reused exchanges. */
+  @tailrec
+  private def hasCometNativeChild(op: SparkPlan): Boolean = op match {
+    case c: QueryStageExec => hasCometNativeChild(c.plan)
+    case ReusedExchangeExec(_, e) => hasCometNativeChild(e)
+    case _ => op.exists(_.isInstanceOf[CometPlan])
+  }
 }
 
 object CometSparkSessionExtensions extends Logging {
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 4f8f8ee86..50b418737 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2765,6 +2765,8 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde with CometExprShim
       case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: 
CometShuffleExchangeExec), _) => true
       case _: TakeOrderedAndProjectExec => true
       case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true
+      case BroadcastQueryStageExec(_, ReusedExchangeExec(_, _: 
CometBroadcastExchangeExec), _) =>
+        true
       case _: BroadcastExchangeExec => true
       case _: WindowExec => true
       case _ => false
diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
 
b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index 3285159be..c17b2f785 100644
--- 
a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -19,8 +19,6 @@
 
 package org.apache.spark.sql.comet
 
-import java.io.DataInputStream
-import java.nio.channels.Channels
 import java.util.UUID
 import java.util.concurrent.{Future, TimeoutException, TimeUnit}
 
@@ -28,13 +26,13 @@ import scala.concurrent.{ExecutionContext, Promise}
 import scala.concurrent.duration.NANOSECONDS
 import scala.util.control.NonFatal
 
-import org.apache.spark.{broadcast, Partition, SparkContext, SparkEnv, 
TaskContext}
-import org.apache.spark.io.CompressionCodec
+import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
-import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, 
BroadcastPartitioning, Partitioning}
+import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, 
SQLExecution}
 import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, 
ShuffleQueryStageExec}
@@ -63,8 +61,10 @@ import org.apache.comet.CometRuntimeException
 case class CometBroadcastExchangeExec(
     originalPlan: SparkPlan,
     override val output: Seq[Attribute],
+    mode: BroadcastMode,
     override val child: SparkPlan)
-    extends BroadcastExchangeLike {
+    extends BroadcastExchangeLike
+    with CometPlan {
   import CometBroadcastExchangeExec._
 
   override val runId: UUID = UUID.randomUUID
@@ -77,7 +77,7 @@ case class CometBroadcastExchangeExec(
     "broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to 
broadcast"))
 
   override def doCanonicalize(): SparkPlan = {
-    CometBroadcastExchangeExec(null, null, child.canonicalized)
+    CometBroadcastExchangeExec(null, null, mode.canonicalized, child.canonicalized)
   }
 
   override def runtimeStatistics: Statistics = {
@@ -86,6 +86,8 @@ case class CometBroadcastExchangeExec(
     Statistics(dataSize, Some(rowCount))
   }
 
+  override def outputPartitioning: Partitioning = BroadcastPartitioning(mode)
+
   @transient
   private lazy val promise = Promise[broadcast.Broadcast[Any]]()
 
@@ -289,23 +291,7 @@ class CometBatchRDD(
   override def compute(split: Partition, context: TaskContext): 
Iterator[ColumnarBatch] = {
     val partition = split.asInstanceOf[CometBatchPartition]
     partition.value.value.toIterator
-      .flatMap(decodeBatches(_, this.getClass.getSimpleName))
-  }
-
-  /**
-   * Decodes the byte arrays back to ColumnarBatchs and put them into buffer.
-   */
-  private def decodeBatches(bytes: ChunkedByteBuffer, source: String): 
Iterator[ColumnarBatch] = {
-    if (bytes.size == 0) {
-      return Iterator.empty
-    }
-
-    // use Spark's compression codec (LZ4 by default) and not Comet's 
compression
-    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
-    val cbbis = bytes.toInputStream()
-    val ins = new DataInputStream(codec.compressedInputStream(cbbis))
-    // batches are in Arrow IPC format
-    new ArrowReaderIterator(Channels.newChannel(ins), source)
+      .flatMap(Utils.decodeBatches(_, this.getClass.getSimpleName))
   }
 }
 
diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala
index 18d95a473..0391a1c3b 100644
--- 
a/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala
+++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala
@@ -19,20 +19,30 @@
 
 package org.apache.spark.sql.comet
 
+import java.util.UUID
+import java.util.concurrent.{Future, TimeoutException, TimeUnit}
+
 import scala.collection.JavaConverters._
+import scala.concurrent.Promise
+import scala.util.control.NonFatal
 
+import org.apache.spark.{broadcast, SparkException}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, 
UnsafeProjection}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.execution.{CodegenSupport, 
ColumnarToRowTransition, SparkPlan}
+import org.apache.spark.sql.comet.util.{Utils => CometUtils}
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.{CodegenSupport, 
ColumnarToRowTransition, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.execution.vectorized.{ConstantColumnVector, 
WritableColumnVector}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SparkFatalException, Utils}
+import org.apache.spark.util.io.ChunkedByteBuffer
 
 import org.apache.comet.vector.CometPlainVector
 
@@ -76,6 +86,96 @@ case class CometColumnarToRowExec(child: SparkPlan)
     }
   }
 
+  @transient
+  private lazy val promise = Promise[broadcast.Broadcast[Any]]()
+
+  @transient
+  private val timeout: Long = conf.broadcastTimeout
+
+  private val runId: UUID = UUID.randomUUID
+
+  private lazy val cometBroadcastExchange = findCometBroadcastExchange(child)
+
+  @transient
+  lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
+    SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
+      session,
+      CometBroadcastExchangeExec.executionContext) {
+      try {
+        // Set up a job group here so later it may get cancelled by groupId if necessary.
+        sparkContext.setJobGroup(
+          runId.toString,
+          s"CometColumnarToRow broadcast exchange (runId $runId)",
+          interruptOnCancel = true)
+
+        val numOutputRows = longMetric("numOutputRows")
+        val numInputBatches = longMetric("numInputBatches")
+        val serializedBatches = child.executeBroadcast[Array[ChunkedByteBuffer]]().value
+        val toUnsafe = UnsafeProjection.create(output, output)
+        val rows = serializedBatches.iterator
+          .flatMap(CometUtils.decodeBatches(_, this.getClass.getSimpleName))
+          .flatMap { batch =>
+            numInputBatches += 1
+            numOutputRows += batch.numRows()
+            // The projection reuses its output row; copy it so buffered rows stay distinct.
+            batch.rowIterator().asScala.map(row => toUnsafe(row).copy())
+          }
+
+        val mode = cometBroadcastExchange.get.mode
+        // `rows` is lazy, so numOutputRows is still 0 here: pass no size hint instead of 0.
+        val relation = mode.transform(rows, None)
+        val broadcasted = sparkContext.broadcastInternal(relation, serializedOnly = true)
+        val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+        SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+        promise.trySuccess(broadcasted)
+        broadcasted
+      } catch {
+        // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
+        // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
+        // will catch this exception and re-throw the wrapped fatal throwable.
+        case oe: OutOfMemoryError =>
+          val ex = new SparkFatalException(oe)
+          promise.tryFailure(ex)
+          throw ex
+        case e if !NonFatal(e) =>
+          val ex = new SparkFatalException(e)
+          promise.tryFailure(ex)
+          throw ex
+        case e: Throwable =>
+          promise.tryFailure(e)
+          throw e
+      }
+    }
+  }
+
+  override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
+    // Blocks on relationFuture, which requires a CometBroadcastExchangeExec under this node.
+    if (cometBroadcastExchange.isEmpty) {
+      throw new SparkException(
+        "CometColumnarToRowExec only supports doExecuteBroadcast when child contains a " +
+          "CometBroadcastExchange, but got " + child)
+    }
+    try {
+      relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
+    } catch {
+      case ex: TimeoutException =>
+        logError(s"Could not execute broadcast in $timeout secs.", ex)
+        if (!relationFuture.isDone) {
+          sparkContext.cancelJobGroup(runId.toString)
+          relationFuture.cancel(true)
+        }
+        throw QueryExecutionErrors.executeBroadcastTimeoutError(timeout, Some(ex))
+    }
+  }
+
+  /** Locates the CometBroadcastExchangeExec under `op`, seeing through reused exchanges. */
+  private def findCometBroadcastExchange(op: SparkPlan): Option[CometBroadcastExchangeExec] =
+    op match {
+      case b: CometBroadcastExchangeExec => Some(b)
+      case b: BroadcastQueryStageExec => findCometBroadcastExchange(b.broadcast)
+      case _ => op.children.collectFirst(Function.unlift(findCometBroadcastExchange))
+    }
+
   /**
    * Generate [[ColumnVector]] expressions for our parent to consume as rows. 
This is called once
    * per [[ColumnVector]] in the batch.
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index f8c1cf90f..3fe300e97 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -36,10 +36,10 @@ import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Bloom
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, 
CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, 
SQLExecution, UnionExec}
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, 
BroadcastQueryStageExec}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, 
ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, 
CartesianProductExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, 
BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
 import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.expressions.Window
@@ -750,6 +750,99 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("Comet Shuffled Join should be optimized to CometBroadcastHashJoin by 
AQE") {
+    withSQLConf(
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "10485760",
+      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+      CometConf.COMET_SHUFFLE_MODE.key -> "native") {
+      withParquetTable((0 until 100).map(i => (i, i + 1)), "tbl_a") {
+        withParquetTable((0 until 100).map(i => (i, i + 2)), "tbl_b") {
+          withParquetTable((0 until 100).map(i => (i, i + 3)), "tbl_c") {
+            val df = sql("""SELECT /*+ BROADCAST(c) */ a1, sum_b2, c._2 FROM (
+                |  SELECT a._1 a1, SUM(b._2) sum_b2 FROM tbl_a a
+                |  JOIN tbl_b b ON a._1 = b._1
+                |  GROUP BY a._1) t
+                |JOIN tbl_c c ON t.a1 = c._1
+                |""".stripMargin)
+            checkSparkAnswerAndOperator(df)
+
+            // Before AQE: 1 broadcast join
+            var broadcastHashJoinExec = 
stripAQEPlan(df.queryExecution.executedPlan).collect {
+              case s: CometBroadcastHashJoinExec => s
+            }
+            assert(broadcastHashJoinExec.length == 1)
+
+            // After AQE: shuffled join optimized to broadcast join
+            df.collect()
+            broadcastHashJoinExec = 
stripAQEPlan(df.queryExecution.executedPlan).collect {
+              case s: CometBroadcastHashJoinExec => s
+            }
+            assert(broadcastHashJoinExec.length == 2)
+          }
+        }
+      }
+    }
+  }
+
+  test("CometBroadcastExchange could be converted to rows using 
CometColumnarToRow") {
+    withSQLConf(
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "10485760",
+      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+      CometConf.COMET_SHUFFLE_MODE.key -> "auto") {
+      withParquetTable((0 until 100).map(i => (i, i + 1)), "tbl_a") {
+        withParquetTable((0 until 100).map(i => (i, i + 2)), "tbl_b") {
+          withParquetTable((0 until 100).map(i => (i, i + 3)), "tbl_c") {
+            val df = sql("""SELECT /*+ BROADCAST(c) */ a1, sum_b2, c._2 FROM (
+                |  SELECT a._1 a1, SUM(b._2) sum_b2 FROM tbl_a a
+                |  JOIN tbl_b b ON a._1 = b._1
+                |  GROUP BY a._1) t
+                |JOIN tbl_c c ON t.a1 = c._1
+                |""".stripMargin)
+            checkSparkAnswer(df)
+
+            // Before AQE: one CometBroadcastExchange, no CometColumnarToRow
+            var columnarToRowExec = 
stripAQEPlan(df.queryExecution.executedPlan).collect {
+              case s: CometColumnarToRowExec => s
+            }
+            assert(columnarToRowExec.isEmpty)
+
+            // Disable CometExecRule after the initial plan is generated. The CometSortMergeJoin
+            // and CometBroadcastHashJoin nodes in the initial plan will be converted to Spark
+            // BroadcastHashJoin during AQE, so the output of CometBroadcastExchangeExec has to
+            // be converted to rows for use by the Spark BroadcastHashJoin.
+            withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") {
+              df.collect()
+            }
+
+            // After AQE: CometBroadcastExchange has to be converted to rows 
to conform to Spark
+            // BroadcastHashJoin.
+            val plan = stripAQEPlan(df.queryExecution.executedPlan)
+            columnarToRowExec = plan.collect { case s: CometColumnarToRowExec 
=>
+              s
+            }
+            assert(columnarToRowExec.length == 1)
+
+            // This CometColumnarToRowExec should be the immediate child of BroadcastHashJoinExec
+            val parent = plan.find(_.children.contains(columnarToRowExec.head))
+            assert(parent.get.isInstanceOf[BroadcastHashJoinExec])
+
+            // There should be a CometBroadcastExchangeExec under 
CometColumnarToRowExec
+            val broadcastQueryStage =
+              
columnarToRowExec.head.find(_.isInstanceOf[BroadcastQueryStageExec])
+            assert(broadcastQueryStage.isDefined)
+            assert(
+              broadcastQueryStage.get
+                .asInstanceOf[BroadcastQueryStageExec]
+                .broadcast
+                .isInstanceOf[CometBroadcastExchangeExec])
+          }
+        }
+      }
+    }
+  }
+
   test("expand operator") {
     val data1 = (0 until 1000)
       .map(_ % 5) // reduce value space to trigger dictionary encoding


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to