This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new efca99108 fix: Making shuffle files generated in native shuffle mode reclaimable (#1568)
efca99108 is described below

commit efca99108b1cd8a648795995456b6074b57dd25b
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Sat Apr 5 02:15:23 2025 +0800

    fix: Making shuffle files generated in native shuffle mode reclaimable (#1568)
    
    * Making shuffle files generated in native shuffle mode reclaimable
    
    * Add a unit test
    
    * Use eventually in unit test
    
    * Address review comments
---
 .../shuffle/CometNativeShuffleWriter.scala         | 221 +++++++++++++++++++++
 .../execution/shuffle/CometShuffleDependency.scala |   8 +-
 .../shuffle/CometShuffleExchangeExec.scala         | 198 +-----------------
 .../execution/shuffle/CometShuffleManager.scala    | 101 +++++++---
 .../comet/exec/CometNativeShuffleSuite.scala       |  17 ++
 5 files changed, 321 insertions(+), 224 deletions(-)

diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala
new file mode 100644
index 000000000..b67b267fc
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.execution.shuffle
+
+import java.nio.{ByteBuffer, ByteOrder}
+import java.nio.file.{Files, Paths}
+
+import scala.collection.JavaConverters.asJavaIterableConverter
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.shuffle.{IndexShuffleBlockResolver, 
ShuffleWriteMetricsReporter, ShuffleWriter}
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, 
Partitioning, SinglePartition}
+import org.apache.spark.sql.comet.{CometExec, CometMetricNode}
+import org.apache.spark.sql.execution.metric.{SQLMetric, 
SQLShuffleWriteMetricsReporter}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+import org.apache.comet.CometConf
+import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, 
QueryPlanSerde}
+import org.apache.comet.serde.OperatorOuterClass.{CompressionCodec, Operator}
+import org.apache.comet.serde.QueryPlanSerde.serializeDataType
+
+/**
+ * A [[ShuffleWriter]] that delegates the shuffle write of columnar batches to Comet's native
+ * shuffle writer.
+ *
+ * The native plan writes a data file and an index file to temporary paths derived from the
+ * final shuffle output paths. The partition lengths are decoded from the temporary index file,
+ * and the temporary data file is then committed through [[IndexShuffleBlockResolver]]. Any
+ * temporary file still present when `write` finishes (successfully or not) is removed, so a
+ * failed native write does not leave orphaned files on local disk.
+ *
+ * @param outputPartitioning
+ *   partitioning of the shuffle output; only [[HashPartitioning]] and [[SinglePartition]] are
+ *   supported
+ * @param outputAttributes
+ *   attributes describing the columns of the input batches
+ * @param metrics
+ *   SQL metrics that the native writer reports into
+ * @param numParts
+ *   number of partitions, forwarded to `CometExec.getCometIterator`
+ * @param shuffleId
+ *   id of the shuffle being written
+ * @param mapId
+ *   id of the map output being written
+ * @param context
+ *   context of the running task
+ * @param metricsReporter
+ *   reporter that receives the number of bytes written by the native writer
+ */
+class CometNativeShuffleWriter[K, V](
+    outputPartitioning: Partitioning,
+    outputAttributes: Seq[Attribute],
+    metrics: Map[String, SQLMetric],
+    numParts: Int,
+    shuffleId: Int,
+    mapId: Long,
+    context: TaskContext,
+    metricsReporter: ShuffleWriteMetricsReporter)
+    extends ShuffleWriter[K, V] {
+
+  /** Size in bytes of one partition offset stored in the native index file. */
+  private val OFFSET_LENGTH = 8
+
+  /** Byte length of each output partition; assigned by a successful `write`. */
+  var partitionLengths: Array[Long] = _
+
+  /** Map status describing the committed output; assigned by a successful `write`. */
+  var mapStatus: MapStatus = _
+
+  /**
+   * Writes `inputs` through the native shuffle writer and commits the resulting shuffle files.
+   *
+   * @param inputs
+   *   records whose values are [[ColumnarBatch]]es; the keys are placeholder partition ids and
+   *   are ignored
+   */
+  override def write(inputs: Iterator[Product2[K, V]]): Unit = {
+    val shuffleBlockResolver =
+      SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver]
+    val dataFile = shuffleBlockResolver.getDataFile(shuffleId, mapId)
+    val indexFile = shuffleBlockResolver.getIndexFile(shuffleId, mapId)
+    // Append the suffix rather than using `String.replace`, which would also rewrite a ".data" or
+    // ".index" occurring in a parent directory name and point the native writer at a wrong path.
+    val tempDataFilename = s"${dataFile.getPath}.tmp"
+    val tempIndexFilename = s"${indexFile.getPath}.tmp"
+    val tempDataFilePath = Paths.get(tempDataFilename)
+    val tempIndexFilePath = Paths.get(tempIndexFilename)
+
+    try {
+      runNativeShuffleWrite(inputs, tempDataFilename, tempIndexFilename)
+
+      partitionLengths = readPartitionLengths(tempIndexFilePath)
+
+      // Total written bytes at native
+      metricsReporter.incBytesWritten(Files.size(tempDataFilePath))
+
+      // commit
+      shuffleBlockResolver.writeMetadataFileAndCommit(
+        shuffleId,
+        mapId,
+        partitionLengths,
+        Array.empty, // TODO: add checksums
+        tempDataFilePath.toFile)
+      mapStatus =
+        MapStatus.apply(SparkEnv.get.blockManager.shuffleServerId, partitionLengths, mapId)
+    } finally {
+      // The temporary index file is only read above and is never committed. The temporary data
+      // file is handed to the commit; if it is still present here (the write or the commit did
+      // not complete), it is an orphan. Deleting a file that is already gone is a no-op.
+      deleteTempFile(tempIndexFilePath)
+      deleteTempFile(tempDataFilePath)
+    }
+  }
+
+  /**
+   * Runs the native shuffle writer over `inputs`, writing shuffle data and index to the given
+   * paths. The native iterator is closed even when the write fails.
+   */
+  private def runNativeShuffleWrite(
+      inputs: Iterator[Product2[K, V]],
+      dataFile: String,
+      indexFile: String): Unit = {
+    val nativePlan = getNativePlan(dataFile, indexFile)
+
+    val detailedMetrics = Seq(
+      "elapsed_compute",
+      "encode_time",
+      "repart_time",
+      "mempool_time",
+      "input_batches",
+      "spill_count",
+      "spilled_bytes")
+
+    // Maps native metrics to SQL metrics
+    val nativeSQLMetrics = Map(
+      "output_rows" -> metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN),
+      "data_size" -> metrics("dataSize"),
+      "write_time" -> metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_WRITE_TIME)) ++
+      metrics.filterKeys(detailedMetrics.contains)
+    val nativeMetrics = CometMetricNode(nativeSQLMetrics)
+
+    // Getting rid of the fake partitionId
+    val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any, Any]]].map(_._2)
+
+    val cometIter = CometExec.getCometIterator(
+      Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
+      outputAttributes.length,
+      nativePlan,
+      nativeMetrics,
+      numParts,
+      context.partitionId())
+
+    try {
+      // Draining the iterator drives the native writer; the returned batches are discarded.
+      while (cometIter.hasNext) {
+        cometIter.next()
+      }
+    } finally {
+      cometIter.close()
+    }
+  }
+
+  /**
+   * Decodes per-partition byte lengths from a native index file, which stores one little-endian
+   * offset of `OFFSET_LENGTH` bytes per partition boundary, the first of which is 0.
+   */
+  private def readPartitionLengths(indexFilePath: java.nio.file.Path): Array[Long] = {
+    var offset = 0L
+    Files
+      .readAllBytes(indexFilePath)
+      .grouped(OFFSET_LENGTH)
+      .drop(1) // first partition offset is always 0
+      .map(indexBytes => {
+        val partitionOffset =
+          ByteBuffer.wrap(indexBytes).order(ByteOrder.LITTLE_ENDIAN).getLong
+        val partitionLength = partitionOffset - offset
+        offset = partitionOffset
+        partitionLength
+      })
+      .toArray
+  }
+
+  /**
+   * Best-effort removal of a temporary file. `File.delete` reports failure by returning false
+   * instead of throwing, so cleanup in a `finally` cannot mask the exception being propagated.
+   */
+  private def deleteTempFile(path: java.nio.file.Path): Unit = {
+    path.toFile.delete()
+    ()
+  }
+
+  /**
+   * Builds the native plan: a `ShuffleWriter` operator on top of a `Scan` over the
+   * "ShuffleWriterInput" source.
+   *
+   * @throws UnsupportedOperationException
+   *   if an output attribute has an unsupported data type, the configured compression codec is
+   *   unknown, or the output partitioning (or one of its hash expressions) is not supported
+   */
+  private def getNativePlan(dataFile: String, indexFile: String): Operator = {
+    val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput")
+    val opBuilder = OperatorOuterClass.Operator.newBuilder()
+
+    val scanTypes = outputAttributes.flatMap(attr => serializeDataType(attr.dataType))
+
+    if (scanTypes.length == outputAttributes.length) {
+      scanBuilder.addAllFields(scanTypes.asJava)
+
+      val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder()
+      shuffleWriterBuilder.setOutputDataFile(dataFile)
+      shuffleWriterBuilder.setOutputIndexFile(indexFile)
+      shuffleWriterBuilder.setEnableFastEncoding(
+        CometConf.COMET_SHUFFLE_ENABLE_FAST_ENCODING.get())
+
+      if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) {
+        val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match {
+          case "zstd" => CompressionCodec.Zstd
+          case "lz4" => CompressionCodec.Lz4
+          case "snappy" => CompressionCodec.Snappy
+          case other => throw new UnsupportedOperationException(s"invalid codec: $other")
+        }
+        shuffleWriterBuilder.setCodec(codec)
+      } else {
+        shuffleWriterBuilder.setCodec(CompressionCodec.None)
+      }
+      shuffleWriterBuilder.setCompressionLevel(
+        CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get)
+
+      outputPartitioning match {
+        case hashPartitioning: HashPartitioning =>
+          val partitioning = PartitioningOuterClass.HashRepartition.newBuilder()
+          partitioning.setNumPartitions(outputPartitioning.numPartitions)
+
+          val partitionExprs = hashPartitioning.expressions
+            .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes))
+
+          // Every hash expression must be convertible, otherwise rows would be routed differently.
+          if (partitionExprs.length != hashPartitioning.expressions.length) {
+            throw new UnsupportedOperationException(
+              s"Partitioning $hashPartitioning is not supported.")
+          }
+
+          partitioning.addAllHashExpression(partitionExprs.asJava)
+
+          val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder()
+          shuffleWriterBuilder.setPartitioning(
+            partitioningBuilder.setHashPartition(partitioning).build())
+
+        case SinglePartition =>
+          val partitioning = PartitioningOuterClass.SinglePartition.newBuilder()
+
+          val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder()
+          shuffleWriterBuilder.setPartitioning(
+            partitioningBuilder.setSinglePartition(partitioning).build())
+
+        case _ =>
+          throw new UnsupportedOperationException(
+            s"Partitioning $outputPartitioning is not supported.")
+      }
+
+      val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder()
+      shuffleWriterOpBuilder
+        .setShuffleWriter(shuffleWriterBuilder)
+        .addChildren(opBuilder.setScan(scanBuilder).build())
+        .build()
+    } else {
+      // There are unsupported scan type
+      throw new UnsupportedOperationException(
+        s"$outputAttributes contains unsupported data types for CometShuffleExchangeExec.")
+    }
+  }
+
+  /**
+   * Returns the map status produced by `write` when the task succeeded, and `None` otherwise.
+   * Temporary files of a failed write have already been removed by `write` itself.
+   */
+  override def stop(success: Boolean): Option[MapStatus] = {
+    if (success) {
+      Some(mapStatus)
+    } else {
+      None
+    }
+  }
+
+  override def getPartitionLengths(): Array[Long] = partitionLengths
+}
diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala
 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala
index 8c8aed28e..ff35b10eb 100644
--- 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala
+++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala
@@ -25,6 +25,8 @@ import org.apache.spark.{Aggregator, Partitioner, 
ShuffleDependency, SparkEnv}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.ShuffleWriteProcessor
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
 
@@ -41,7 +43,11 @@ class CometShuffleDependency[K: ClassTag, V: ClassTag, C: 
ClassTag](
     override val shuffleWriterProcessor: ShuffleWriteProcessor = new 
ShuffleWriteProcessor,
     val shuffleType: ShuffleType = CometNativeShuffle,
     val schema: Option[StructType] = None,
-    val decodeTime: SQLMetric)
+    val decodeTime: SQLMetric,
+    val outputPartitioning: Option[Partitioning] = None,
+    val outputAttributes: Seq[Attribute] = Seq.empty,
+    val shuffleWriteMetrics: Map[String, SQLMetric] = Map.empty,
+    val numParts: Int = 0)
     extends ShuffleDependency[K, V, C](
       _rdd,
       partitioner,
diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index f2af4402d..d121bd6d5 100644
--- 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -19,27 +19,21 @@
 
 package org.apache.spark.sql.comet.execution.shuffle
 
-import java.nio.{ByteBuffer, ByteOrder}
-import java.nio.file.{Files, Paths}
 import java.util.function.Supplier
 
-import scala.collection.JavaConverters.asJavaIterableConverter
 import scala.concurrent.Future
 
 import org.apache.spark._
 import org.apache.spark.internal.config
 import org.apache.spark.rdd.RDD
-import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{IndexShuffleBlockResolver, 
ShuffleWriteMetricsReporter}
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, 
UnsafeProjection, UnsafeRow}
 import 
org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometPlan}
-import org.apache.spark.sql.comet.shims.ShimCometShuffleWriteProcessor
+import org.apache.spark.sql.comet.{CometMetricNode, CometPlan}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, 
ShuffleExchangeLike, ShuffleOrigin}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -52,10 +46,6 @@ import org.apache.spark.util.random.XORShiftRandom
 
 import com.google.common.base.Objects
 
-import org.apache.comet.CometConf
-import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, 
QueryPlanSerde}
-import org.apache.comet.serde.OperatorOuterClass.{CompressionCodec, Operator}
-import org.apache.comet.serde.QueryPlanSerde.serializeDataType
 import org.apache.comet.shims.ShimCometShuffleExchangeExec
 
 /**
@@ -232,20 +222,23 @@ object CometShuffleExchangeExec extends 
ShimCometShuffleExchangeExec {
         (0, _)
       ), // adding fake partitionId that is always 0 because ShuffleDependency 
requires it
       serializer = serializer,
-      shuffleWriterProcessor =
-        new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, 
metrics, numParts),
+      shuffleWriterProcessor = 
ShuffleExchangeExec.createShuffleWriteProcessor(metrics),
       shuffleType = CometNativeShuffle,
       partitioner = new Partitioner {
         override def numPartitions: Int = outputPartitioning.numPartitions
         override def getPartition(key: Any): Int = key.asInstanceOf[Int]
       },
-      decodeTime = metrics("decode_time"))
+      decodeTime = metrics("decode_time"),
+      outputPartitioning = Some(outputPartitioning),
+      outputAttributes = outputAttributes,
+      shuffleWriteMetrics = metrics,
+      numParts = numParts)
     dependency
   }
 
   /**
    * This is copied from Spark 
`ShuffleExchangeExec.needToCopyObjectsBeforeShuffle`. The only
-   * difference is that we use `BosonShuffleManager` instead of 
`SortShuffleManager`.
+   * difference is that we use `CometShuffleManager` instead of 
`SortShuffleManager`.
    */
   private def needToCopyObjectsBeforeShuffle(partitioner: Partitioner): 
Boolean = {
     // Note: even though we only use the partitioner's `numPartitions` field, 
we require it to be
@@ -442,178 +435,3 @@ object CometShuffleExchangeExec extends 
ShimCometShuffleExchangeExec {
     dependency
   }
 }
-
-/**
- * A [[ShuffleWriteProcessor]] that will delegate shuffle write to native 
shuffle.
- * @param metrics
- *   metrics to report
- */
-class CometShuffleWriteProcessor(
-    outputPartitioning: Partitioning,
-    outputAttributes: Seq[Attribute],
-    metrics: Map[String, SQLMetric],
-    numParts: Int)
-    extends ShimCometShuffleWriteProcessor {
-
-  private val OFFSET_LENGTH = 8
-
-  override protected def createMetricsReporter(
-      context: TaskContext): ShuffleWriteMetricsReporter = {
-    new 
SQLShuffleWriteMetricsReporter(context.taskMetrics().shuffleWriteMetrics, 
metrics)
-  }
-
-  override def write(
-      inputs: Iterator[_],
-      dep: ShuffleDependency[_, _, _],
-      mapId: Long,
-      mapIndex: Int,
-      context: TaskContext): MapStatus = {
-    val metricsReporter = createMetricsReporter(context)
-    val shuffleBlockResolver =
-      
SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver]
-    val dataFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
-    val indexFile = shuffleBlockResolver.getIndexFile(dep.shuffleId, mapId)
-    val tempDataFilename = dataFile.getPath.replace(".data", ".data.tmp")
-    val tempIndexFilename = indexFile.getPath.replace(".index", ".index.tmp")
-    val tempDataFilePath = Paths.get(tempDataFilename)
-    val tempIndexFilePath = Paths.get(tempIndexFilename)
-
-    // Call native shuffle write
-    val nativePlan = getNativePlan(tempDataFilename, tempIndexFilename)
-
-    val detailedMetrics = Seq(
-      "elapsed_compute",
-      "encode_time",
-      "repart_time",
-      "mempool_time",
-      "input_batches",
-      "spill_count",
-      "spilled_bytes")
-
-    // Maps native metrics to SQL metrics
-    val nativeSQLMetrics = Map(
-      "output_rows" -> 
metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN),
-      "data_size" -> metrics("dataSize"),
-      "write_time" -> 
metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_WRITE_TIME)) ++
-      metrics.filterKeys(detailedMetrics.contains)
-    val nativeMetrics = CometMetricNode(nativeSQLMetrics)
-
-    // Getting rid of the fake partitionId
-    val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any, 
Any]]].map(_._2)
-
-    val cometIter = CometExec.getCometIterator(
-      Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
-      outputAttributes.length,
-      nativePlan,
-      nativeMetrics,
-      numParts,
-      context.partitionId())
-
-    while (cometIter.hasNext) {
-      cometIter.next()
-    }
-
-    // get partition lengths from shuffle write output index file
-    var offset = 0L
-    val partitionLengths = Files
-      .readAllBytes(tempIndexFilePath)
-      .grouped(OFFSET_LENGTH)
-      .drop(1) // first partition offset is always 0
-      .map(indexBytes => {
-        val partitionOffset =
-          ByteBuffer.wrap(indexBytes).order(ByteOrder.LITTLE_ENDIAN).getLong
-        val partitionLength = partitionOffset - offset
-        offset = partitionOffset
-        partitionLength
-      })
-      .toArray
-
-    // Total written bytes at native
-    metricsReporter.incBytesWritten(Files.size(tempDataFilePath))
-
-    // commit
-    shuffleBlockResolver.writeMetadataFileAndCommit(
-      dep.shuffleId,
-      mapId,
-      partitionLengths,
-      Array.empty, // TODO: add checksums
-      tempDataFilePath.toFile)
-    MapStatus.apply(SparkEnv.get.blockManager.shuffleServerId, 
partitionLengths, mapId)
-  }
-
-  def getNativePlan(dataFile: String, indexFile: String): Operator = {
-    val scanBuilder = 
OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput")
-    val opBuilder = OperatorOuterClass.Operator.newBuilder()
-
-    val scanTypes = outputAttributes.flatten { attr =>
-      serializeDataType(attr.dataType)
-    }
-
-    if (scanTypes.length == outputAttributes.length) {
-      scanBuilder.addAllFields(scanTypes.asJava)
-
-      val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder()
-      shuffleWriterBuilder.setOutputDataFile(dataFile)
-      shuffleWriterBuilder.setOutputIndexFile(indexFile)
-      shuffleWriterBuilder.setEnableFastEncoding(
-        CometConf.COMET_SHUFFLE_ENABLE_FAST_ENCODING.get())
-
-      if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) {
-        val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match 
{
-          case "zstd" => CompressionCodec.Zstd
-          case "lz4" => CompressionCodec.Lz4
-          case "snappy" => CompressionCodec.Snappy
-          case other => throw new UnsupportedOperationException(s"invalid 
codec: $other")
-        }
-        shuffleWriterBuilder.setCodec(codec)
-      } else {
-        shuffleWriterBuilder.setCodec(CompressionCodec.None)
-      }
-      shuffleWriterBuilder.setCompressionLevel(
-        CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get)
-
-      outputPartitioning match {
-        case _: HashPartitioning =>
-          val hashPartitioning = 
outputPartitioning.asInstanceOf[HashPartitioning]
-
-          val partitioning = 
PartitioningOuterClass.HashRepartition.newBuilder()
-          partitioning.setNumPartitions(outputPartitioning.numPartitions)
-
-          val partitionExprs = hashPartitioning.expressions
-            .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes))
-
-          if (partitionExprs.length != hashPartitioning.expressions.length) {
-            throw new UnsupportedOperationException(
-              s"Partitioning $hashPartitioning is not supported.")
-          }
-
-          partitioning.addAllHashExpression(partitionExprs.asJava)
-
-          val partitioningBuilder = 
PartitioningOuterClass.Partitioning.newBuilder()
-          shuffleWriterBuilder.setPartitioning(
-            partitioningBuilder.setHashPartition(partitioning).build())
-
-        case SinglePartition =>
-          val partitioning = 
PartitioningOuterClass.SinglePartition.newBuilder()
-
-          val partitioningBuilder = 
PartitioningOuterClass.Partitioning.newBuilder()
-          shuffleWriterBuilder.setPartitioning(
-            partitioningBuilder.setSinglePartition(partitioning).build())
-
-        case _ =>
-          throw new UnsupportedOperationException(
-            s"Partitioning $outputPartitioning is not supported.")
-      }
-
-      val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder()
-      shuffleWriterOpBuilder
-        .setShuffleWriter(shuffleWriterBuilder)
-        .addChildren(opBuilder.setScan(scanBuilder).build())
-        .build()
-    } else {
-      // There are unsupported scan type
-      throw new UnsupportedOperationException(
-        s"$outputAttributes contains unsupported data types for 
CometShuffleExchangeExec.")
-    }
-  }
-}
diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala
 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala
index b2cc2c2ba..1142c6af1 100644
--- 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala
+++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala
@@ -86,38 +86,52 @@ class CometShuffleManager(conf: SparkConf) extends 
ShuffleManager with Logging {
   def registerShuffle[K, V, C](
       shuffleId: Int,
       dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
-    if (dependency.isInstanceOf[CometShuffleDependency[_, _, _]]) {
-      // Comet shuffle dependency, which comes from `CometShuffleExchangeExec`.
-      if (shouldBypassMergeSort(conf, dependency) ||
-        !SortShuffleManager.canUseSerializedShuffle(dependency)) {
-        new CometBypassMergeSortShuffleHandle(
-          shuffleId,
-          dependency.asInstanceOf[ShuffleDependency[K, V, V]])
-      } else {
-        new CometSerializedShuffleHandle(
-          shuffleId,
-          dependency.asInstanceOf[ShuffleDependency[K, V, V]])
-      }
-    } else {
-      // It is a Spark shuffle dependency, so we use Spark Sort Shuffle 
Manager.
-      if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
-        // If there are fewer than spark.shuffle.sort.bypassMergeThreshold 
partitions and we don't
-        // need map-side aggregation, then write numPartitions files directly 
and just concatenate
-        // them at the end. This avoids doing serialization and 
deserialization twice to merge
-        // together the spilled files, which would happen with the normal code 
path. The downside is
-        // having multiple files open at a time and thus more memory allocated 
to buffers.
-        new BypassMergeSortShuffleHandle[K, V](
-          shuffleId,
-          dependency.asInstanceOf[ShuffleDependency[K, V, V]])
-      } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
-        // Otherwise, try to buffer map outputs in a serialized form, since 
this is more efficient:
-        new SerializedShuffleHandle[K, V](
-          shuffleId,
-          dependency.asInstanceOf[ShuffleDependency[K, V, V]])
-      } else {
-        // Otherwise, buffer map outputs in a deserialized form:
-        new BaseShuffleHandle(shuffleId, dependency)
-      }
+    dependency match {
+      case cometShuffleDependency: CometShuffleDependency[_, _, _] =>
+        // Comet shuffle dependency, which comes from 
`CometShuffleExchangeExec`.
+        cometShuffleDependency.shuffleType match {
+          case CometColumnarShuffle =>
+            // Comet columnar shuffle, which uses Arrow format to shuffle data.
+            if (shouldBypassMergeSort(conf, dependency) ||
+              !SortShuffleManager.canUseSerializedShuffle(dependency)) {
+              new CometBypassMergeSortShuffleHandle(
+                shuffleId,
+                dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+            } else {
+              new CometSerializedShuffleHandle(
+                shuffleId,
+                dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+            }
+          case CometNativeShuffle =>
+            new CometNativeShuffleHandle(
+              shuffleId,
+              dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+          case _ =>
+            // Unsupported shuffle type.
+            throw new UnsupportedOperationException(
+              s"Unsupported shuffle type: 
${cometShuffleDependency.shuffleType}")
+        }
+      case _ =>
+        // It is a Spark shuffle dependency, so we use Spark Sort Shuffle 
Manager.
+        if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
+          // If there are fewer than spark.shuffle.sort.bypassMergeThreshold 
partitions and we don't
+          // need map-side aggregation, then write numPartitions files 
directly and just concatenate
+          // them at the end. This avoids doing serialization and 
deserialization twice to merge
+          // together the spilled files, which would happen with the normal 
code path. The downside
+          // is having multiple files open at a time and thus more memory 
allocated to buffers.
+          new BypassMergeSortShuffleHandle[K, V](
+            shuffleId,
+            dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+        } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
+          // Otherwise, try to buffer map outputs in a serialized form, since 
this is more
+          // efficient:
+          new SerializedShuffleHandle[K, V](
+            shuffleId,
+            dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+        } else {
+          // Otherwise, buffer map outputs in a deserialized form:
+          new BaseShuffleHandle(shuffleId, dependency)
+        }
     }
   }
 
@@ -150,7 +164,8 @@ class CometShuffleManager(conf: SparkConf) extends 
ShuffleManager with Logging {
       }
 
     if (handle.isInstanceOf[CometBypassMergeSortShuffleHandle[_, _]] ||
-      handle.isInstanceOf[CometSerializedShuffleHandle[_, _]]) {
+      handle.isInstanceOf[CometSerializedShuffleHandle[_, _]] ||
+      handle.isInstanceOf[CometNativeShuffleHandle[_, _]]) {
       new CometBlockStoreShuffleReader(
         handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
         blocksByAddress,
@@ -184,6 +199,17 @@ class CometShuffleManager(conf: SparkConf) extends 
ShuffleManager with Logging {
     }
     val env = SparkEnv.get
     handle match {
+      case cometShuffleHandle: CometNativeShuffleHandle[K @unchecked, V 
@unchecked] =>
+        val dep = 
cometShuffleHandle.dependency.asInstanceOf[CometShuffleDependency[_, _, _]]
+        new CometNativeShuffleWriter(
+          dep.outputPartitioning.get,
+          dep.outputAttributes,
+          dep.shuffleWriteMetrics,
+          dep.numParts,
+          dep.shuffleId,
+          mapId,
+          context,
+          metrics)
       case bypassMergeSortHandle: CometBypassMergeSortShuffleHandle[K 
@unchecked, V @unchecked] =>
         new CometBypassMergeSortShuffleWriter(
           env.blockManager,
@@ -295,3 +321,12 @@ private[spark] class CometSerializedShuffleHandle[K, V](
     shuffleId: Int,
     dependency: ShuffleDependency[K, V, V])
     extends BaseShuffleHandle(shuffleId, dependency) {}
+
+/**
+ * Marker [[BaseShuffleHandle]] for shuffles that `CometShuffleManager` routes to the native
+ * shuffle writer.
+ *
+ * `registerShuffle` hands out this handle for a `CometShuffleDependency` whose shuffle type is
+ * `CometNativeShuffle`. `getWriter` turns it into a `CometNativeShuffleWriter`, and `getReader`
+ * reads the resulting output with `CometBlockStoreShuffleReader`.
+ */
+private[spark] class CometNativeShuffleHandle[K, V](
+    shuffleId: Int,
+    dependency: ShuffleDependency[K, V, V])
+    extends BaseShuffleHandle(shuffleId, dependency)
diff --git 
a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
index 9d40a0cdf..e1cbd7406 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
@@ -19,10 +19,13 @@
 
 package org.apache.comet.exec
 
+import scala.concurrent.duration.DurationInt
+
 import org.scalactic.source.Position
 import org.scalatest.Tag
 
 import org.apache.hadoop.fs.Path
+import org.apache.spark.SparkEnv
 import org.apache.spark.sql.{CometTestBase, DataFrame}
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -201,6 +204,20 @@ class CometNativeShuffleSuite extends CometTestBase with 
AdaptiveSparkPlanHelper
     }
   }
 
+  test("fix: Comet native shuffle deletes shuffle files after query") {
+    withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") {
+      // The GROUP BY is expected to introduce a shuffle, so shuffle files should exist on the
+      // local disk once the query has run.
+      var df = sql("SELECT count(_2), sum(_2) FROM tbl GROUP BY _1")
+      df.collect()
+      val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
+      assert(diskBlockManager.getAllFiles().nonEmpty)
+      // `df` is a `var` only so that this, its sole reference, can be dropped here and the
+      // query's shuffle can become unreachable.
+      df = null
+      // NOTE(review): the files are presumably removed by Spark's ContextCleaner once the
+      // shuffle dependency is garbage collected, hence the repeated System.gc() until the disk
+      // block manager reports no files; confirm against the cleaner configuration used here.
+      eventually(timeout(30.seconds), interval(1.seconds)) {
+        System.gc()
+        assert(diskBlockManager.getAllFiles().isEmpty)
+      }
+    }
+  }
+
   /**
    * Checks that `df` produces the same answer as Spark does, and has the 
`expectedNum` Comet
    * exchange operators. When `checkNativeOperators` is true, this also checks 
that all operators


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to