This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new efca99108 fix: Making shuffle files generated in native shuffle mode
reclaimable (#1568)
efca99108 is described below
commit efca99108b1cd8a648795995456b6074b57dd25b
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Sat Apr 5 02:15:23 2025 +0800
fix: Making shuffle files generated in native shuffle mode reclaimable
(#1568)
* Making shuffle files generated in native shuffle mode reclaimable
* Add a unit test
* Use eventually in unit test
* Address review comments
---
.../shuffle/CometNativeShuffleWriter.scala | 221 +++++++++++++++++++++
.../execution/shuffle/CometShuffleDependency.scala | 8 +-
.../shuffle/CometShuffleExchangeExec.scala | 198 +-----------------
.../execution/shuffle/CometShuffleManager.scala | 101 +++++++---
.../comet/exec/CometNativeShuffleSuite.scala | 17 ++
5 files changed, 321 insertions(+), 224 deletions(-)
diff --git
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala
new file mode 100644
index 000000000..b67b267fc
--- /dev/null
+++
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.execution.shuffle
+
+import java.nio.{ByteBuffer, ByteOrder}
+import java.nio.file.{Files, Paths}
+
+import scala.collection.JavaConverters.asJavaIterableConverter
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.shuffle.{IndexShuffleBlockResolver,
ShuffleWriteMetricsReporter, ShuffleWriter}
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning,
Partitioning, SinglePartition}
+import org.apache.spark.sql.comet.{CometExec, CometMetricNode}
+import org.apache.spark.sql.execution.metric.{SQLMetric,
SQLShuffleWriteMetricsReporter}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+import org.apache.comet.CometConf
+import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass,
QueryPlanSerde}
+import org.apache.comet.serde.OperatorOuterClass.{CompressionCodec, Operator}
+import org.apache.comet.serde.QueryPlanSerde.serializeDataType
+
+/**
+ * A [[ShuffleWriter]] that will delegate shuffle write to native shuffle.
+ */
+class CometNativeShuffleWriter[K, V](
+ outputPartitioning: Partitioning,
+ outputAttributes: Seq[Attribute],
+ metrics: Map[String, SQLMetric],
+ numParts: Int,
+ shuffleId: Int,
+ mapId: Long,
+ context: TaskContext,
+ metricsReporter: ShuffleWriteMetricsReporter)
+ extends ShuffleWriter[K, V] {
+
+ private val OFFSET_LENGTH = 8
+
+ var partitionLengths: Array[Long] = _
+ var mapStatus: MapStatus = _
+
+ override def write(inputs: Iterator[Product2[K, V]]): Unit = {
+ val shuffleBlockResolver =
+
SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver]
+ val dataFile = shuffleBlockResolver.getDataFile(shuffleId, mapId)
+ val indexFile = shuffleBlockResolver.getIndexFile(shuffleId, mapId)
+ val tempDataFilename = dataFile.getPath.replace(".data", ".data.tmp")
+ val tempIndexFilename = indexFile.getPath.replace(".index", ".index.tmp")
+ val tempDataFilePath = Paths.get(tempDataFilename)
+ val tempIndexFilePath = Paths.get(tempIndexFilename)
+
+ // Call native shuffle write
+ val nativePlan = getNativePlan(tempDataFilename, tempIndexFilename)
+
+ val detailedMetrics = Seq(
+ "elapsed_compute",
+ "encode_time",
+ "repart_time",
+ "mempool_time",
+ "input_batches",
+ "spill_count",
+ "spilled_bytes")
+
+ // Maps native metrics to SQL metrics
+ val nativeSQLMetrics = Map(
+ "output_rows" ->
metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN),
+ "data_size" -> metrics("dataSize"),
+ "write_time" ->
metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_WRITE_TIME)) ++
+ metrics.filterKeys(detailedMetrics.contains)
+ val nativeMetrics = CometMetricNode(nativeSQLMetrics)
+
+ // Getting rid of the fake partitionId
+ val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any,
Any]]].map(_._2)
+
+ val cometIter = CometExec.getCometIterator(
+ Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
+ outputAttributes.length,
+ nativePlan,
+ nativeMetrics,
+ numParts,
+ context.partitionId())
+
+ while (cometIter.hasNext) {
+ cometIter.next()
+ }
+ cometIter.close()
+
+ // get partition lengths from shuffle write output index file
+ var offset = 0L
+ partitionLengths = Files
+ .readAllBytes(tempIndexFilePath)
+ .grouped(OFFSET_LENGTH)
+ .drop(1) // first partition offset is always 0
+ .map(indexBytes => {
+ val partitionOffset =
+ ByteBuffer.wrap(indexBytes).order(ByteOrder.LITTLE_ENDIAN).getLong
+ val partitionLength = partitionOffset - offset
+ offset = partitionOffset
+ partitionLength
+ })
+ .toArray
+ Files.delete(tempIndexFilePath)
+
+ // Total written bytes at native
+ metricsReporter.incBytesWritten(Files.size(tempDataFilePath))
+
+ // commit
+ shuffleBlockResolver.writeMetadataFileAndCommit(
+ shuffleId,
+ mapId,
+ partitionLengths,
+ Array.empty, // TODO: add checksums
+ tempDataFilePath.toFile)
+ mapStatus =
+ MapStatus.apply(SparkEnv.get.blockManager.shuffleServerId,
partitionLengths, mapId)
+ }
+
+ private def getNativePlan(dataFile: String, indexFile: String): Operator = {
+ val scanBuilder =
OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput")
+ val opBuilder = OperatorOuterClass.Operator.newBuilder()
+
+ val scanTypes = outputAttributes.flatten { attr =>
+ serializeDataType(attr.dataType)
+ }
+
+ if (scanTypes.length == outputAttributes.length) {
+ scanBuilder.addAllFields(scanTypes.asJava)
+
+ val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder()
+ shuffleWriterBuilder.setOutputDataFile(dataFile)
+ shuffleWriterBuilder.setOutputIndexFile(indexFile)
+ shuffleWriterBuilder.setEnableFastEncoding(
+ CometConf.COMET_SHUFFLE_ENABLE_FAST_ENCODING.get())
+
+ if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) {
+ val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match
{
+ case "zstd" => CompressionCodec.Zstd
+ case "lz4" => CompressionCodec.Lz4
+ case "snappy" => CompressionCodec.Snappy
+ case other => throw new UnsupportedOperationException(s"invalid
codec: $other")
+ }
+ shuffleWriterBuilder.setCodec(codec)
+ } else {
+ shuffleWriterBuilder.setCodec(CompressionCodec.None)
+ }
+ shuffleWriterBuilder.setCompressionLevel(
+ CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get)
+
+ outputPartitioning match {
+ case _: HashPartitioning =>
+ val hashPartitioning =
outputPartitioning.asInstanceOf[HashPartitioning]
+
+ val partitioning =
PartitioningOuterClass.HashRepartition.newBuilder()
+ partitioning.setNumPartitions(outputPartitioning.numPartitions)
+
+ val partitionExprs = hashPartitioning.expressions
+ .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes))
+
+ if (partitionExprs.length != hashPartitioning.expressions.length) {
+ throw new UnsupportedOperationException(
+ s"Partitioning $hashPartitioning is not supported.")
+ }
+
+ partitioning.addAllHashExpression(partitionExprs.asJava)
+
+ val partitioningBuilder =
PartitioningOuterClass.Partitioning.newBuilder()
+ shuffleWriterBuilder.setPartitioning(
+ partitioningBuilder.setHashPartition(partitioning).build())
+
+ case SinglePartition =>
+ val partitioning =
PartitioningOuterClass.SinglePartition.newBuilder()
+
+ val partitioningBuilder =
PartitioningOuterClass.Partitioning.newBuilder()
+ shuffleWriterBuilder.setPartitioning(
+ partitioningBuilder.setSinglePartition(partitioning).build())
+
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Partitioning $outputPartitioning is not supported.")
+ }
+
+ val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder()
+ shuffleWriterOpBuilder
+ .setShuffleWriter(shuffleWriterBuilder)
+ .addChildren(opBuilder.setScan(scanBuilder).build())
+ .build()
+ } else {
+      // There are unsupported scan types
+ throw new UnsupportedOperationException(
+ s"$outputAttributes contains unsupported data types for
CometShuffleExchangeExec.")
+ }
+ }
+
+ override def stop(success: Boolean): Option[MapStatus] = {
+ if (success) {
+ Some(mapStatus)
+ } else {
+ None
+ }
+ }
+
+ override def getPartitionLengths(): Array[Long] = partitionLengths
+}
diff --git
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala
index 8c8aed28e..ff35b10eb 100644
---
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala
+++
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala
@@ -25,6 +25,8 @@ import org.apache.spark.{Aggregator, Partitioner,
ShuffleDependency, SparkEnv}
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleWriteProcessor
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType
@@ -41,7 +43,11 @@ class CometShuffleDependency[K: ClassTag, V: ClassTag, C:
ClassTag](
override val shuffleWriterProcessor: ShuffleWriteProcessor = new
ShuffleWriteProcessor,
val shuffleType: ShuffleType = CometNativeShuffle,
val schema: Option[StructType] = None,
- val decodeTime: SQLMetric)
+ val decodeTime: SQLMetric,
+ val outputPartitioning: Option[Partitioning] = None,
+ val outputAttributes: Seq[Attribute] = Seq.empty,
+ val shuffleWriteMetrics: Map[String, SQLMetric] = Map.empty,
+ val numParts: Int = 0)
extends ShuffleDependency[K, V, C](
_rdd,
partitioner,
diff --git
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index f2af4402d..d121bd6d5 100644
---
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -19,27 +19,21 @@
package org.apache.spark.sql.comet.execution.shuffle
-import java.nio.{ByteBuffer, ByteOrder}
-import java.nio.file.{Files, Paths}
import java.util.function.Supplier
-import scala.collection.JavaConverters.asJavaIterableConverter
import scala.concurrent.Future
import org.apache.spark._
import org.apache.spark.internal.config
import org.apache.spark.rdd.RDD
-import org.apache.spark.scheduler.MapStatus
import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{IndexShuffleBlockResolver,
ShuffleWriteMetricsReporter}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference,
UnsafeProjection, UnsafeRow}
import
org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometPlan}
-import org.apache.spark.sql.comet.shims.ShimCometShuffleWriteProcessor
+import org.apache.spark.sql.comet.{CometMetricNode, CometPlan}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS,
ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -52,10 +46,6 @@ import org.apache.spark.util.random.XORShiftRandom
import com.google.common.base.Objects
-import org.apache.comet.CometConf
-import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass,
QueryPlanSerde}
-import org.apache.comet.serde.OperatorOuterClass.{CompressionCodec, Operator}
-import org.apache.comet.serde.QueryPlanSerde.serializeDataType
import org.apache.comet.shims.ShimCometShuffleExchangeExec
/**
@@ -232,20 +222,23 @@ object CometShuffleExchangeExec extends
ShimCometShuffleExchangeExec {
(0, _)
), // adding fake partitionId that is always 0 because ShuffleDependency
requires it
serializer = serializer,
- shuffleWriterProcessor =
- new CometShuffleWriteProcessor(outputPartitioning, outputAttributes,
metrics, numParts),
+ shuffleWriterProcessor =
ShuffleExchangeExec.createShuffleWriteProcessor(metrics),
shuffleType = CometNativeShuffle,
partitioner = new Partitioner {
override def numPartitions: Int = outputPartitioning.numPartitions
override def getPartition(key: Any): Int = key.asInstanceOf[Int]
},
- decodeTime = metrics("decode_time"))
+ decodeTime = metrics("decode_time"),
+ outputPartitioning = Some(outputPartitioning),
+ outputAttributes = outputAttributes,
+ shuffleWriteMetrics = metrics,
+ numParts = numParts)
dependency
}
/**
* This is copied from Spark
`ShuffleExchangeExec.needToCopyObjectsBeforeShuffle`. The only
- * difference is that we use `BosonShuffleManager` instead of
`SortShuffleManager`.
+ * difference is that we use `CometShuffleManager` instead of
`SortShuffleManager`.
*/
private def needToCopyObjectsBeforeShuffle(partitioner: Partitioner):
Boolean = {
// Note: even though we only use the partitioner's `numPartitions` field,
we require it to be
@@ -442,178 +435,3 @@ object CometShuffleExchangeExec extends
ShimCometShuffleExchangeExec {
dependency
}
}
-
-/**
- * A [[ShuffleWriteProcessor]] that will delegate shuffle write to native
shuffle.
- * @param metrics
- * metrics to report
- */
-class CometShuffleWriteProcessor(
- outputPartitioning: Partitioning,
- outputAttributes: Seq[Attribute],
- metrics: Map[String, SQLMetric],
- numParts: Int)
- extends ShimCometShuffleWriteProcessor {
-
- private val OFFSET_LENGTH = 8
-
- override protected def createMetricsReporter(
- context: TaskContext): ShuffleWriteMetricsReporter = {
- new
SQLShuffleWriteMetricsReporter(context.taskMetrics().shuffleWriteMetrics,
metrics)
- }
-
- override def write(
- inputs: Iterator[_],
- dep: ShuffleDependency[_, _, _],
- mapId: Long,
- mapIndex: Int,
- context: TaskContext): MapStatus = {
- val metricsReporter = createMetricsReporter(context)
- val shuffleBlockResolver =
-
SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver]
- val dataFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
- val indexFile = shuffleBlockResolver.getIndexFile(dep.shuffleId, mapId)
- val tempDataFilename = dataFile.getPath.replace(".data", ".data.tmp")
- val tempIndexFilename = indexFile.getPath.replace(".index", ".index.tmp")
- val tempDataFilePath = Paths.get(tempDataFilename)
- val tempIndexFilePath = Paths.get(tempIndexFilename)
-
- // Call native shuffle write
- val nativePlan = getNativePlan(tempDataFilename, tempIndexFilename)
-
- val detailedMetrics = Seq(
- "elapsed_compute",
- "encode_time",
- "repart_time",
- "mempool_time",
- "input_batches",
- "spill_count",
- "spilled_bytes")
-
- // Maps native metrics to SQL metrics
- val nativeSQLMetrics = Map(
- "output_rows" ->
metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN),
- "data_size" -> metrics("dataSize"),
- "write_time" ->
metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_WRITE_TIME)) ++
- metrics.filterKeys(detailedMetrics.contains)
- val nativeMetrics = CometMetricNode(nativeSQLMetrics)
-
- // Getting rid of the fake partitionId
- val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any,
Any]]].map(_._2)
-
- val cometIter = CometExec.getCometIterator(
- Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
- outputAttributes.length,
- nativePlan,
- nativeMetrics,
- numParts,
- context.partitionId())
-
- while (cometIter.hasNext) {
- cometIter.next()
- }
-
- // get partition lengths from shuffle write output index file
- var offset = 0L
- val partitionLengths = Files
- .readAllBytes(tempIndexFilePath)
- .grouped(OFFSET_LENGTH)
- .drop(1) // first partition offset is always 0
- .map(indexBytes => {
- val partitionOffset =
- ByteBuffer.wrap(indexBytes).order(ByteOrder.LITTLE_ENDIAN).getLong
- val partitionLength = partitionOffset - offset
- offset = partitionOffset
- partitionLength
- })
- .toArray
-
- // Total written bytes at native
- metricsReporter.incBytesWritten(Files.size(tempDataFilePath))
-
- // commit
- shuffleBlockResolver.writeMetadataFileAndCommit(
- dep.shuffleId,
- mapId,
- partitionLengths,
- Array.empty, // TODO: add checksums
- tempDataFilePath.toFile)
- MapStatus.apply(SparkEnv.get.blockManager.shuffleServerId,
partitionLengths, mapId)
- }
-
- def getNativePlan(dataFile: String, indexFile: String): Operator = {
- val scanBuilder =
OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput")
- val opBuilder = OperatorOuterClass.Operator.newBuilder()
-
- val scanTypes = outputAttributes.flatten { attr =>
- serializeDataType(attr.dataType)
- }
-
- if (scanTypes.length == outputAttributes.length) {
- scanBuilder.addAllFields(scanTypes.asJava)
-
- val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder()
- shuffleWriterBuilder.setOutputDataFile(dataFile)
- shuffleWriterBuilder.setOutputIndexFile(indexFile)
- shuffleWriterBuilder.setEnableFastEncoding(
- CometConf.COMET_SHUFFLE_ENABLE_FAST_ENCODING.get())
-
- if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) {
- val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match
{
- case "zstd" => CompressionCodec.Zstd
- case "lz4" => CompressionCodec.Lz4
- case "snappy" => CompressionCodec.Snappy
- case other => throw new UnsupportedOperationException(s"invalid
codec: $other")
- }
- shuffleWriterBuilder.setCodec(codec)
- } else {
- shuffleWriterBuilder.setCodec(CompressionCodec.None)
- }
- shuffleWriterBuilder.setCompressionLevel(
- CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get)
-
- outputPartitioning match {
- case _: HashPartitioning =>
- val hashPartitioning =
outputPartitioning.asInstanceOf[HashPartitioning]
-
- val partitioning =
PartitioningOuterClass.HashRepartition.newBuilder()
- partitioning.setNumPartitions(outputPartitioning.numPartitions)
-
- val partitionExprs = hashPartitioning.expressions
- .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes))
-
- if (partitionExprs.length != hashPartitioning.expressions.length) {
- throw new UnsupportedOperationException(
- s"Partitioning $hashPartitioning is not supported.")
- }
-
- partitioning.addAllHashExpression(partitionExprs.asJava)
-
- val partitioningBuilder =
PartitioningOuterClass.Partitioning.newBuilder()
- shuffleWriterBuilder.setPartitioning(
- partitioningBuilder.setHashPartition(partitioning).build())
-
- case SinglePartition =>
- val partitioning =
PartitioningOuterClass.SinglePartition.newBuilder()
-
- val partitioningBuilder =
PartitioningOuterClass.Partitioning.newBuilder()
- shuffleWriterBuilder.setPartitioning(
- partitioningBuilder.setSinglePartition(partitioning).build())
-
- case _ =>
- throw new UnsupportedOperationException(
- s"Partitioning $outputPartitioning is not supported.")
- }
-
- val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder()
- shuffleWriterOpBuilder
- .setShuffleWriter(shuffleWriterBuilder)
- .addChildren(opBuilder.setScan(scanBuilder).build())
- .build()
- } else {
- // There are unsupported scan type
- throw new UnsupportedOperationException(
- s"$outputAttributes contains unsupported data types for
CometShuffleExchangeExec.")
- }
- }
-}
diff --git
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala
index b2cc2c2ba..1142c6af1 100644
---
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala
+++
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala
@@ -86,38 +86,52 @@ class CometShuffleManager(conf: SparkConf) extends
ShuffleManager with Logging {
def registerShuffle[K, V, C](
shuffleId: Int,
dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
- if (dependency.isInstanceOf[CometShuffleDependency[_, _, _]]) {
- // Comet shuffle dependency, which comes from `CometShuffleExchangeExec`.
- if (shouldBypassMergeSort(conf, dependency) ||
- !SortShuffleManager.canUseSerializedShuffle(dependency)) {
- new CometBypassMergeSortShuffleHandle(
- shuffleId,
- dependency.asInstanceOf[ShuffleDependency[K, V, V]])
- } else {
- new CometSerializedShuffleHandle(
- shuffleId,
- dependency.asInstanceOf[ShuffleDependency[K, V, V]])
- }
- } else {
- // It is a Spark shuffle dependency, so we use Spark Sort Shuffle
Manager.
- if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
- // If there are fewer than spark.shuffle.sort.bypassMergeThreshold
partitions and we don't
- // need map-side aggregation, then write numPartitions files directly
and just concatenate
- // them at the end. This avoids doing serialization and
deserialization twice to merge
- // together the spilled files, which would happen with the normal code
path. The downside is
- // having multiple files open at a time and thus more memory allocated
to buffers.
- new BypassMergeSortShuffleHandle[K, V](
- shuffleId,
- dependency.asInstanceOf[ShuffleDependency[K, V, V]])
- } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
- // Otherwise, try to buffer map outputs in a serialized form, since
this is more efficient:
- new SerializedShuffleHandle[K, V](
- shuffleId,
- dependency.asInstanceOf[ShuffleDependency[K, V, V]])
- } else {
- // Otherwise, buffer map outputs in a deserialized form:
- new BaseShuffleHandle(shuffleId, dependency)
- }
+ dependency match {
+ case cometShuffleDependency: CometShuffleDependency[_, _, _] =>
+ // Comet shuffle dependency, which comes from
`CometShuffleExchangeExec`.
+ cometShuffleDependency.shuffleType match {
+ case CometColumnarShuffle =>
+ // Comet columnar shuffle, which uses Arrow format to shuffle data.
+ if (shouldBypassMergeSort(conf, dependency) ||
+ !SortShuffleManager.canUseSerializedShuffle(dependency)) {
+ new CometBypassMergeSortShuffleHandle(
+ shuffleId,
+ dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+ } else {
+ new CometSerializedShuffleHandle(
+ shuffleId,
+ dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+ }
+ case CometNativeShuffle =>
+ new CometNativeShuffleHandle(
+ shuffleId,
+ dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+ case _ =>
+ // Unsupported shuffle type.
+ throw new UnsupportedOperationException(
+ s"Unsupported shuffle type:
${cometShuffleDependency.shuffleType}")
+ }
+ case _ =>
+ // It is a Spark shuffle dependency, so we use Spark Sort Shuffle
Manager.
+ if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
+ // If there are fewer than spark.shuffle.sort.bypassMergeThreshold
partitions and we don't
+ // need map-side aggregation, then write numPartitions files
directly and just concatenate
+ // them at the end. This avoids doing serialization and
deserialization twice to merge
+ // together the spilled files, which would happen with the normal
code path. The downside
+ // is having multiple files open at a time and thus more memory
allocated to buffers.
+ new BypassMergeSortShuffleHandle[K, V](
+ shuffleId,
+ dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+ } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
+ // Otherwise, try to buffer map outputs in a serialized form, since
this is more
+ // efficient:
+ new SerializedShuffleHandle[K, V](
+ shuffleId,
+ dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+ } else {
+ // Otherwise, buffer map outputs in a deserialized form:
+ new BaseShuffleHandle(shuffleId, dependency)
+ }
}
}
@@ -150,7 +164,8 @@ class CometShuffleManager(conf: SparkConf) extends
ShuffleManager with Logging {
}
if (handle.isInstanceOf[CometBypassMergeSortShuffleHandle[_, _]] ||
- handle.isInstanceOf[CometSerializedShuffleHandle[_, _]]) {
+ handle.isInstanceOf[CometSerializedShuffleHandle[_, _]] ||
+ handle.isInstanceOf[CometNativeShuffleHandle[_, _]]) {
new CometBlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
blocksByAddress,
@@ -184,6 +199,17 @@ class CometShuffleManager(conf: SparkConf) extends
ShuffleManager with Logging {
}
val env = SparkEnv.get
handle match {
+ case cometShuffleHandle: CometNativeShuffleHandle[K @unchecked, V
@unchecked] =>
+ val dep =
cometShuffleHandle.dependency.asInstanceOf[CometShuffleDependency[_, _, _]]
+ new CometNativeShuffleWriter(
+ dep.outputPartitioning.get,
+ dep.outputAttributes,
+ dep.shuffleWriteMetrics,
+ dep.numParts,
+ dep.shuffleId,
+ mapId,
+ context,
+ metrics)
case bypassMergeSortHandle: CometBypassMergeSortShuffleHandle[K
@unchecked, V @unchecked] =>
new CometBypassMergeSortShuffleWriter(
env.blockManager,
@@ -295,3 +321,12 @@ private[spark] class CometSerializedShuffleHandle[K, V](
shuffleId: Int,
dependency: ShuffleDependency[K, V, V])
extends BaseShuffleHandle(shuffleId, dependency) {}
+
+/**
+ * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to
use the native shuffle
+ * writer.
+ */
+private[spark] class CometNativeShuffleHandle[K, V](
+ shuffleId: Int,
+ dependency: ShuffleDependency[K, V, V])
+ extends BaseShuffleHandle(shuffleId, dependency) {}
diff --git
a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
index 9d40a0cdf..e1cbd7406 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
@@ -19,10 +19,13 @@
package org.apache.comet.exec
+import scala.concurrent.duration.DurationInt
+
import org.scalactic.source.Position
import org.scalatest.Tag
import org.apache.hadoop.fs.Path
+import org.apache.spark.SparkEnv
import org.apache.spark.sql.{CometTestBase, DataFrame}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -201,6 +204,20 @@ class CometNativeShuffleSuite extends CometTestBase with
AdaptiveSparkPlanHelper
}
}
+ test("fix: Comet native shuffle deletes shuffle files after query") {
+ withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") {
+ var df = sql("SELECT count(_2), sum(_2) FROM tbl GROUP BY _1")
+ df.collect()
+ val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
+ assert(diskBlockManager.getAllFiles().nonEmpty)
+ df = null
+ eventually(timeout(30.seconds), interval(1.seconds)) {
+ System.gc()
+ assert(diskBlockManager.getAllFiles().isEmpty)
+ }
+ }
+ }
+
/**
* Checks that `df` produces the same answer as Spark does, and has the
`expectedNum` Comet
* exchange operators. When `checkNativeOperators` is true, this also checks
that all operators
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]