This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c44493db1dd7 [SPARK-47764][CORE][SQL] Cleanup shuffle dependencies 
based on ShuffleCleanupMode
c44493db1dd7 is described below

commit c44493db1dd7bd56dc41ade19138563b6b76529b
Author: Bo Zhang <bo.zh...@databricks.com>
AuthorDate: Wed Apr 24 16:14:06 2024 +0800

    [SPARK-47764][CORE][SQL] Cleanup shuffle dependencies based on 
ShuffleCleanupMode
    
    ### What changes were proposed in this pull request?
    This change adds a new trait, `ShuffleCleanupMode`, defined in `QueryExecution.scala`, and two new configs, `spark.sql.shuffleDependency.skipMigration.enabled` and `spark.sql.shuffleDependency.fileCleanup.enabled`.
    
    For Spark Connect query executions, `ShuffleCleanupMode` is controlled by the two new configs (file cleanup takes precedence when both are enabled), and shuffle dependency cleanup is performed accordingly.
    
    When `spark.sql.shuffleDependency.fileCleanup.enabled` is `true`, shuffle 
dependency files will be cleaned up at the end of query executions.
    
    When `spark.sql.shuffleDependency.skipMigration.enabled` is `true`, shuffle dependencies will be skipped during shuffle data migration when nodes are decommissioned.
    
    ### Why are the changes needed?
    This is to (1) speed up shuffle data migration during decommissioning, and (2) when file cleanup mode is enabled, release disk space occupied by unused shuffle files.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. This change adds two new configs, `spark.sql.shuffleDependency.skipMigration.enabled` and `spark.sql.shuffleDependency.fileCleanup.enabled`, to control the cleanup behavior. Both default to `Utils.isTesting`.
    
    ### How was this patch tested?
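    Existing tests, plus new tests in `QueryExecutionSuite` covering the `DoNotCleanup`, `SkipMigration`, and `RemoveShuffleFiles` modes.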
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #45930 from bozhang2820/spark-47764.
    
    Authored-by: Bo Zhang <bo.zh...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../execution/SparkConnectPlanExecution.scala      | 15 +++++++-
 .../spark/shuffle/IndexShuffleBlockResolver.scala  | 12 +++++-
 .../apache/spark/shuffle/MigratableResolver.scala  |  5 +++
 .../org/apache/spark/storage/BlockManager.scala    |  2 +-
 project/MimaExcludes.scala                         |  7 +++-
 .../org/apache/spark/sql/internal/SQLConf.scala    | 16 ++++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 20 +++++++++-
 .../spark/sql/execution/QueryExecution.scala       | 19 +++++++++-
 .../apache/spark/sql/execution/SQLExecution.scala  | 24 +++++++++++-
 .../execution/adaptive/AdaptiveSparkPlanExec.scala |  9 ++++-
 .../execution/exchange/ShuffleExchangeExec.scala   |  7 ++++
 .../spark/sql/SparkSessionExtensionSuite.scala     |  1 +
 .../spark/sql/execution/QueryExecutionSuite.scala  | 43 ++++++++++++++++++++++
 13 files changed, 169 insertions(+), 11 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index 23390bf7aba8..32cdae7bae56 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -35,8 +35,9 @@ import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.service.ExecuteHolder
 import org.apache.spark.sql.connect.utils.MetricGenerator
-import org.apache.spark.sql.execution.{LocalTableScanExec, SQLExecution}
+import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, RemoveShuffleFiles, SkipMigration, SQLExecution}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.ThreadUtils
 
@@ -58,11 +59,21 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     }
     val planner = new SparkConnectPlanner(executeHolder)
     val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
+    val conf = session.sessionState.conf
+    val shuffleCleanupMode =
+      if (conf.getConf(SQLConf.SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED)) {
+        RemoveShuffleFiles
+      } else if (conf.getConf(SQLConf.SHUFFLE_DEPENDENCY_SKIP_MIGRATION_ENABLED)) {
+        SkipMigration
+      } else {
+        DoNotCleanup
+      }
     val dataframe =
       Dataset.ofRows(
         sessionHolder.session,
         planner.transformRelation(request.getPlan.getRoot),
-        tracker)
+        tracker,
+        shuffleCleanupMode)
     responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema))
     processAsArrowBatches(dataframe, responseObserver, executeHolder)
     responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe))
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index b878c88c43b0..e7b9a7e2f0ee 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -24,6 +24,8 @@ import java.nio.file.Files
 
 import scala.collection.mutable.ArrayBuffer
 
+import com.google.common.cache.CacheBuilder
+
 import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException}
 import org.apache.spark.errors.SparkCoreErrors
 import org.apache.spark.internal.{config, Logging, MDC}
@@ -76,13 +78,21 @@ private[spark] class IndexShuffleBlockResolver(
   override def getStoredShuffles(): Seq[ShuffleBlockInfo] = {
     val allBlocks = blockManager.diskBlockManager.getAllBlocks()
     allBlocks.flatMap {
-      case ShuffleIndexBlockId(shuffleId, mapId, _) =>
+      case ShuffleIndexBlockId(shuffleId, mapId, _)
+        if Option(shuffleIdsToSkip.getIfPresent(shuffleId)).isEmpty =>
         Some(ShuffleBlockInfo(shuffleId, mapId))
       case _ =>
         None
     }
   }
 
+  private val shuffleIdsToSkip =
+    CacheBuilder.newBuilder().maximumSize(1000).build[java.lang.Integer, java.lang.Boolean]()
+
+  override def addShuffleToSkip(shuffleId: ShuffleId): Unit = {
+    shuffleIdsToSkip.put(shuffleId, true)
+  }
+
   private def getShuffleBytesStored(): Long = {
     val shuffleFiles: Seq[File] = getStoredShuffles().map {
       si => getDataFile(si.shuffleId, si.mapId)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/MigratableResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/MigratableResolver.scala
index 9908281deed8..19835d515fec 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/MigratableResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/MigratableResolver.scala
@@ -35,6 +35,11 @@ trait MigratableResolver {
    */
   def getStoredShuffles(): Seq[ShuffleBlockInfo]
 
+  /**
+   * Mark a shuffle that should not be migrated.
+   */
+  def addShuffleToSkip(shuffleId: Int): Unit = {}
+
   /**
    * Write a provided shuffle block as a stream. Used for block migrations.
    * Up to the implementation to support STORAGE_REMOTE_SHUFFLE_MAX_DISK
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 31669e688a19..ebe782301044 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -305,7 +305,7 @@ private[spark] class BlockManager(
 
   // This is a lazy val so someone can migrating RDDs even if they don't have a MigratableResolver
   // for shuffles. Used in BlockManagerDecommissioner & block puts.
-  private[storage] lazy val migratableResolver: MigratableResolver = {
+  lazy val migratableResolver: MigratableResolver = {
     shuffleManager.shuffleBlockResolver.asInstanceOf[MigratableResolver]
   }
 
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 0783b6a611b8..c684e2e30f7f 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -15,7 +15,8 @@
  * limitations under the License.
  */
 
-import com.typesafe.tools.mima.core._
+import com.typesafe.tools.mima.core
+import com.typesafe.tools.mima.core.*
 
 /**
  * Additional excludes for checking of Spark's binary compatibility.
@@ -93,7 +94,9 @@ object MimaExcludes {
     ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.api.python.TestWritable"),
     ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.api.python.TestWritable$"),
     ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.api.python.WriteInputFormatTestDataGenerator"),
-    ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.api.python.WriteInputFormatTestDataGenerator$")
+    ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.api.python.WriteInputFormatTestDataGenerator$"),
+    // SPARK-47764: Cleanup shuffle dependencies based on ShuffleCleanupMode
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.MigratableResolver.addShuffleToSkip")
   )
 
   // Default exclude rules
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 428ad052eba8..974810133859 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2874,6 +2874,22 @@ object SQLConf {
       .intConf
       .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get)
 
+  val SHUFFLE_DEPENDENCY_SKIP_MIGRATION_ENABLED =
+    buildConf("spark.sql.shuffleDependency.skipMigration.enabled")
+      .doc("When enabled, shuffle dependencies for a Spark Connect SQL execution are marked at " +
+        "the end of the execution, and they will not be migrated during decommissions.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(Utils.isTesting)
+
+  val SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED =
+    buildConf("spark.sql.shuffleDependency.fileCleanup.enabled")
+      .doc("When enabled, shuffle files will be cleaned up at the end of Spark Connect " +
+        "SQL executions.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(Utils.isTesting)
+
   val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD =
     buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold")
       .internal()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c29fd968fc19..18c9704afdf8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -95,10 +95,26 @@ private[sql] object Dataset {
       new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
     }
 
+  def ofRows(
+      sparkSession: SparkSession,
+      logicalPlan: LogicalPlan,
+      shuffleCleanupMode: ShuffleCleanupMode): DataFrame =
+    sparkSession.withActive {
+      val qe = new QueryExecution(
+        sparkSession, logicalPlan, shuffleCleanupMode = shuffleCleanupMode)
+      qe.assertAnalyzed()
+      new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
+    }
+
   /** A variant of ofRows that allows passing in a tracker so we can track query parsing time. */
-  def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan, tracker: QueryPlanningTracker)
+  def ofRows(
+      sparkSession: SparkSession,
+      logicalPlan: LogicalPlan,
+      tracker: QueryPlanningTracker,
+      shuffleCleanupMode: ShuffleCleanupMode = DoNotCleanup)
     : DataFrame = sparkSession.withActive {
-    val qe = new QueryExecution(sparkSession, logicalPlan, tracker)
+    val qe = new QueryExecution(
+      sparkSession, logicalPlan, tracker, shuffleCleanupMode = shuffleCleanupMode)
     qe.assertAnalyzed()
     new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 00666190c0cc..c3775b9d14d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -60,7 +60,8 @@ class QueryExecution(
     val sparkSession: SparkSession,
     val logical: LogicalPlan,
     val tracker: QueryPlanningTracker = new QueryPlanningTracker,
-    val mode: CommandExecutionMode.Value = CommandExecutionMode.ALL) extends Logging {
+    val mode: CommandExecutionMode.Value = CommandExecutionMode.ALL,
+    val shuffleCleanupMode: ShuffleCleanupMode = DoNotCleanup) extends Logging {
 
   val id: Long = QueryExecution.nextExecutionId
 
@@ -459,6 +460,22 @@ object CommandExecutionMode extends Enumeration {
   val SKIP, NON_ROOT, ALL = Value
 }
 
+/**
+ * Modes for shuffle dependency cleanup.
+ *
+ * DoNotCleanup: Do not perform any cleanup.
+ * SkipMigration: Shuffle dependencies will not be migrated at node decommissions.
+ * RemoveShuffleFiles: Shuffle dependency files are removed at the end of SQL executions.
+ */
+sealed trait ShuffleCleanupMode
+
+case object DoNotCleanup extends ShuffleCleanupMode
+
+case object SkipMigration extends ShuffleCleanupMode
+
+case object RemoveShuffleFiles extends ShuffleCleanupMode
+
+
 object QueryExecution {
   private val _nextExecutionId = new AtomicLong(0)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 561deacfb72d..f4be03c90be7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -20,14 +20,16 @@ package org.apache.spark.sql.execution
 import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future => JFuture}
 import java.util.concurrent.atomic.AtomicLong
 
+import scala.jdk.CollectionConverters._
 import scala.util.control.NonFatal
 
-import org.apache.spark.{ErrorMessageFormat, JobArtifactSet, SparkException, SparkThrowable, SparkThrowableHelper}
+import org.apache.spark.{ErrorMessageFormat, JobArtifactSet, SparkEnv, SparkException, SparkThrowable, SparkThrowableHelper}
 import org.apache.spark.SparkContext.{SPARK_JOB_DESCRIPTION, SPARK_JOB_INTERRUPT_ON_CANCEL}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.{SPARK_DRIVER_PREFIX, SPARK_EXECUTOR_PREFIX}
 import org.apache.spark.internal.config.Tests.IS_TESTING
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.StaticSQLConf.SQL_EVENT_TRUNCATE_LENGTH
@@ -115,6 +117,7 @@ object SQLExecution extends Logging {
 
       withSQLConfPropagated(sparkSession) {
         var ex: Option[Throwable] = None
+        var isExecutedPlanAvailable = false
         val startTime = System.nanoTime()
         val startEvent = SparkListenerSQLExecutionStart(
           executionId = executionId,
@@ -147,6 +150,7 @@ object SQLExecution extends Logging {
               }
               sc.listenerBus.post(
                 startEvent.copy(physicalPlanDescription = planDesc, sparkPlanInfo = planInfo))
+              isExecutedPlanAvailable = true
               f()
           }
         } catch {
@@ -161,6 +165,24 @@ object SQLExecution extends Logging {
             case e =>
               Utils.exceptionString(e)
           }
+          if (queryExecution.shuffleCleanupMode != DoNotCleanup
+            && isExecutedPlanAvailable) {
+            val shuffleIds = queryExecution.executedPlan match {
+              case ae: AdaptiveSparkPlanExec =>
+                ae.context.shuffleIds.asScala.keys
+              case _ =>
+                Iterable.empty
+            }
+            shuffleIds.foreach { shuffleId =>
+              queryExecution.shuffleCleanupMode match {
+                case RemoveShuffleFiles =>
+                  SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
+                case SkipMigration =>
+                  SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId)
+                case _ => // this should not happen
+              }
+            }
+          }
           val event = SparkListenerSQLExecutionEnd(
             executionId,
             System.currentTimeMillis(),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index a5e681535cb8..ca4400068250 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.adaptive
 
 import java.util
-import java.util.concurrent.LinkedBlockingQueue
+import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}
 
 import scala.collection.concurrent.TrieMap
 import scala.collection.mutable
@@ -302,6 +302,11 @@ case class AdaptiveSparkPlanExec(
             try {
               stage.materialize().onComplete { res =>
                 if (res.isSuccess) {
+                  // record shuffle IDs for successful stages for cleanup
+                  stage.plan.collect {
+                    case s: ShuffleExchangeLike =>
+                      context.shuffleIds.put(s.shuffleId, true)
+                  }
                   events.offer(StageSuccess(stage, res.get))
                 } else {
                   events.offer(StageFailure(stage, res.failed.get))
@@ -869,6 +874,8 @@ case class AdaptiveExecutionContext(session: SparkSession, qe: QueryExecution) {
    */
   val stageCache: TrieMap[SparkPlan, ExchangeQueryStageExec] =
     new TrieMap[SparkPlan, ExchangeQueryStageExec]()
+
+  val shuffleIds: ConcurrentHashMap[Int, Boolean] = new ConcurrentHashMap[Int, Boolean]()
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 69705afbb7c7..6f9402e1c9e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -86,6 +86,11 @@ trait ShuffleExchangeLike extends Exchange {
    * Returns the runtime statistics after shuffle materialization.
    */
   def runtimeStatistics: Statistics
+
+  /**
+   * The shuffle ID.
+   */
+  def shuffleId: Int
 }
 
 // Describes where the shuffle operator comes from.
@@ -166,6 +171,8 @@ case class ShuffleExchangeExec(
     Statistics(dataSize, Some(rowCount))
   }
 
+  override def shuffleId: Int = shuffleDependency.shuffleId
+
   /**
    * A [[ShuffleDependency]] that will partition rows of its child based on
    * the partitioning scheme defined in `newPartitioning`. Those partitions of
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 18c1f4dcc4e0..1c44d0c3b4ea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -1014,6 +1014,7 @@ case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleE
     val attributeStats = AttributeMap(Seq((child.output.head, columnStats)))
     Statistics(stats.sizeInBytes, stats.rowCount, attributeStats)
   }
+  override def shuffleId: Int = delegate.shuffleId
   override def child: SparkPlan = delegate.child
   override protected def doExecute(): RDD[InternalRow] = delegate.execute()
   override def outputPartitioning: Partitioning = delegate.outputPartitioning
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
index 73e516582932..3608e7c92076 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.datasources.v2.ShowTablesExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.storage.ShuffleIndexBlockId
 import org.apache.spark.util.Utils
 
 case class QueryExecutionTestRecord(
@@ -314,6 +315,48 @@ class QueryExecutionSuite extends SharedSparkSession {
     mockCallback.assertExecutedPlanPrepared()
   }
 
+  private def cleanupShuffles(): Unit = {
+    val blockManager = spark.sparkContext.env.blockManager
+    blockManager.diskBlockManager.getAllBlocks().foreach {
+      case ShuffleIndexBlockId(shuffleId, _, _) =>
+        spark.sparkContext.env.shuffleManager.unregisterShuffle(shuffleId)
+      case _ =>
+    }
+  }
+
+  test("SPARK-47764: Cleanup shuffle dependencies - DoNotCleanup mode") {
+    val plan = spark.range(100).repartition(10).logicalPlan
+    val df = Dataset.ofRows(spark, plan, DoNotCleanup)
+    df.collect()
+
+    val blockManager = spark.sparkContext.env.blockManager
+    assert(blockManager.migratableResolver.getStoredShuffles().nonEmpty)
+    assert(blockManager.diskBlockManager.getAllBlocks().nonEmpty)
+    cleanupShuffles()
+  }
+
+  test("SPARK-47764: Cleanup shuffle dependencies - SkipMigration mode") {
+    val plan = spark.range(100).repartition(10).logicalPlan
+    val df = Dataset.ofRows(spark, plan, SkipMigration)
+    df.collect()
+
+    val blockManager = spark.sparkContext.env.blockManager
+    assert(blockManager.migratableResolver.getStoredShuffles().isEmpty)
+    assert(blockManager.diskBlockManager.getAllBlocks().nonEmpty)
+    cleanupShuffles()
+  }
+
+  test("SPARK-47764: Cleanup shuffle dependencies - RemoveShuffleFiles mode") {
+    val plan = spark.range(100).repartition(10).logicalPlan
+    val df = Dataset.ofRows(spark, plan, RemoveShuffleFiles)
+    df.collect()
+
+    val blockManager = spark.sparkContext.env.blockManager
+    assert(blockManager.migratableResolver.getStoredShuffles().isEmpty)
+    assert(blockManager.diskBlockManager.getAllBlocks().isEmpty)
+    cleanupShuffles()
+  }
+
   test("SPARK-35378: Return UnsafeRow in CommandResultExecCheck execute methods") {
     val plan = spark.sql("SHOW FUNCTIONS").queryExecution.executedPlan
     assert(plan.isInstanceOf[CommandResultExec])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to