This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0ea318f3e23f [SPARK-46947][CORE] Delay memory manager initialization until Driver plugin is loaded
0ea318f3e23f is described below

commit 0ea318f3e23fb2daf953c9129047e8716e4ebf9b
Author: Chao Sun <sunc...@apache.org>
AuthorDate: Mon Feb 26 10:17:38 2024 -0800

    [SPARK-46947][CORE] Delay memory manager initialization until Driver plugin is loaded
    
    ### What changes were proposed in this pull request?
    
    This changes the initialization of `SparkEnv.memoryManager` to happen after the `DriverPlugin` is loaded, so that the plugin can customize memory-related configurations.
    
    A minor fix has also been made to `Task` to make sure that it uses the same `BlockManager` throughout task execution. Previously, a different `BlockManager` could be used in some corner cases. A test has been added for the fix.
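    
    In sketch form, the relevant part of `Task.run` after this change looks as follows (condensed from the diff below; `taskAttemptId` and the task body are as in the real method):
    
    ```scala
    // Resolve the BlockManager once at the start of the task and reuse that reference,
    // instead of calling SparkEnv.get.blockManager again at each call site. This keeps task
    // registration and the unroll-memory release tied to a single BlockManager instance,
    // even if a new SparkEnv is created while the task is still running (e.g. a restarted
    // local-mode SparkContext).
    val blockManager = SparkEnv.get.blockManager
    blockManager.registerTask(taskAttemptId)
    try {
      // ... build the TaskContext and run the task body ...
    } finally {
      blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
      blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP)
    }
    ```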
    
    ### Why are the changes needed?
    
    Today, there is no way for a custom `DriverPlugin` to override memory configurations such as `spark.executor.memory`, `spark.executor.memoryOverhead`, `spark.memory.offHeap.size`, etc. This is because the memory manager is initialized before the `DriverPlugin` is loaded.
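    
    With this change, a plugin can adjust these settings from `DriverPlugin.init` before the memory manager is created. Below is a minimal sketch; the class name and the 512m value are illustrative, and since `sc.conf` is `private[spark]`, the sketch (like the `MemoryOverridePlugin` test added below) is shown as living in a Spark-internal package:
    
    ```scala
    package org.apache.spark.internal.plugin
    
    import java.util.{Map => JMap}
    
    import scala.jdk.CollectionConverters._
    
    import org.apache.spark.SparkContext
    import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}
    
    // Illustrative plugin (not part of this PR): turns on off-heap memory before the
    // driver's UnifiedMemoryManager is created.
    class OffHeapOverridePlugin extends SparkPlugin {
      override def driverPlugin(): DriverPlugin = new DriverPlugin {
        override def init(sc: SparkContext, ctx: PluginContext): JMap[String, String] = {
          // The memory manager has not been initialized yet at this point, so these
          // settings take effect when UnifiedMemoryManager is created.
          sc.conf.set("spark.memory.offHeap.enabled", "true")
          sc.conf.set("spark.memory.offHeap.size", "512m")
          Map.empty[String, String].asJava
        }
      }
    
      override def executorPlugin(): ExecutorPlugin = new ExecutorPlugin {}
    }
    ```
    
    The plugin is then registered through the `spark.plugins` configuration, as in the `PluginContainerSuite` test added by this PR.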
    
    A similar change has been made to `shuffleManager` in #43627.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests. Also added new tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45052 from sunchao/SPARK-46947.
    
    Authored-by: Chao Sun <sunc...@apache.org>
    Signed-off-by: Chao Sun <sunc...@apache.org>
---
 .../main/scala/org/apache/spark/SparkContext.scala |  1 +
 .../src/main/scala/org/apache/spark/SparkEnv.scala | 20 ++++--
 .../scala/org/apache/spark/scheduler/Task.scala    | 15 +++--
 .../org/apache/spark/storage/BlockManager.scala    | 20 ++++--
 .../internal/plugin/PluginContainerSuite.scala     | 53 ++++++++++++++++
 .../apache/spark/scheduler/TaskContextSuite.scala  | 71 +++++++++++++++++++++-
 6 files changed, 163 insertions(+), 17 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 801b6dd85a2b..d519617c4095 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -577,6 +577,7 @@ class SparkContext(config: SparkConf) extends Logging {
     // Initialize any plugins before the task scheduler is initialized.
     _plugins = PluginContainer(this, _resources.asJava)
     _env.initializeShuffleManager()
+    _env.initializeMemoryManager(SparkContext.numDriverCores(master, conf))
 
     // Create and start the scheduler
     val (sched, ts) = SparkContext.createTaskScheduler(this, master)
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index ca07c276fbff..84c0fa5840b7 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -67,7 +67,6 @@ class SparkEnv (
     val blockManager: BlockManager,
     val securityManager: SecurityManager,
     val metricsSystem: MetricsSystem,
-    val memoryManager: MemoryManager,
     val outputCommitCoordinator: OutputCommitCoordinator,
     val conf: SparkConf) extends Logging {
 
@@ -77,6 +76,12 @@ class SparkEnv (
 
   def shuffleManager: ShuffleManager = _shuffleManager
 
+  // We initialize the MemoryManager later in SparkContext after DriverPlugin is loaded
+  // to allow the plugin to overwrite executor memory configurations
+  private var _memoryManager: MemoryManager = _
+
+  def memoryManager: MemoryManager = _memoryManager
+
   @volatile private[spark] var isStopped = false
 
   /**
@@ -199,6 +204,12 @@ class SparkEnv (
       "Shuffle manager already initialized to %s", _shuffleManager)
    _shuffleManager = ShuffleManager.create(conf, executorId == SparkContext.DRIVER_IDENTIFIER)
   }
+
+  private[spark] def initializeMemoryManager(numUsableCores: Int): Unit = {
+    Preconditions.checkState(null == memoryManager,
+      "Memory manager already initialized to %s", _memoryManager)
+    _memoryManager = UnifiedMemoryManager(conf, numUsableCores)
+  }
 }
 
 object SparkEnv extends Logging {
@@ -276,6 +287,8 @@ object SparkEnv extends Logging {
       numCores,
       ioEncryptionKey
     )
+    // Set the memory manager since it needs to be initialized explicitly
+    env.initializeMemoryManager(numCores)
     SparkEnv.set(env)
     env
   }
@@ -358,8 +371,6 @@ object SparkEnv extends Logging {
       new MapOutputTrackerMasterEndpoint(
         rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
 
-    val memoryManager: MemoryManager = UnifiedMemoryManager(conf, numUsableCores)
-
     val blockManagerPort = if (isDriver) {
       conf.get(DRIVER_BLOCK_MANAGER_PORT)
     } else {
@@ -418,7 +429,7 @@ object SparkEnv extends Logging {
       blockManagerMaster,
       serializerManager,
       conf,
-      memoryManager,
+      _memoryManager = null,
       mapOutputTracker,
       _shuffleManager = null,
       blockTransferService,
@@ -463,7 +474,6 @@ object SparkEnv extends Logging {
       blockManager,
       securityManager,
       metricsSystem,
-      memoryManager,
       outputCommitCoordinator,
       conf)
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 1ecd185de557..6e449e4dc111 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -93,7 +93,12 @@ private[spark] abstract class Task[T](
 
     require(cpus > 0, "CPUs per task should be > 0")
 
-    SparkEnv.get.blockManager.registerTask(taskAttemptId)
+    // Use the BlockManager resolved at the start of the task throughout the task. In
+    // particular, in local mode a new SparkEnv can be created when the SparkContext is
+    // restarted, and we want to ensure the right env and block manager is used (given
+    // the lazy initialization of the block manager).
+    val blockManager = SparkEnv.get.blockManager
+    blockManager.registerTask(taskAttemptId)
    // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether
     // the stage is barrier.
     val taskContext = new TaskContextImpl(
@@ -143,15 +148,15 @@ private[spark] abstract class Task[T](
       try {
         Utils.tryLogNonFatalError {
           // Release memory used by this thread for unrolling blocks
-          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
-          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
-            MemoryMode.OFF_HEAP)
+          blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
+          blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP)
          // Notify any tasks waiting for execution memory to be freed to wake up and try to
          // acquire memory again. This makes impossible the scenario where a task sleeps forever
          // because there are no other tasks left to notify it. Since this is safe to do but may
          // not be strictly necessary, we should revisit whether we can remove this in the
          // future.
-          val memoryManager = SparkEnv.get.memoryManager
+
+          val memoryManager = blockManager.memoryManager
           memoryManager.synchronized { memoryManager.notifyAll() }
         }
       } finally {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 42bbd025177b..228ec5752e1b 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -185,7 +185,7 @@ private[spark] class BlockManager(
     val master: BlockManagerMaster,
     val serializerManager: SerializerManager,
     val conf: SparkConf,
-    memoryManager: MemoryManager,
+    private val _memoryManager: MemoryManager,
     mapOutputTracker: MapOutputTracker,
     private val _shuffleManager: ShuffleManager,
     val blockTransferService: BlockTransferService,
@@ -198,6 +198,12 @@ private[spark] class BlockManager(
   // (except for tests) and we ask for the instance from the SparkEnv.
  private lazy val shuffleManager = Option(_shuffleManager).getOrElse(SparkEnv.get.shuffleManager)
 
+  // Similarly, we also initialize MemoryManager later after DriverPlugin is loaded, to
+  // allow the plugin to overwrite certain memory configurations. The `_memoryManager` will be
+  // null here and we ask for the instance from SparkEnv.
+  private[spark] lazy val memoryManager =
+    Option(_memoryManager).getOrElse(SparkEnv.get.memoryManager)
+
   // same as `conf.get(config.SHUFFLE_SERVICE_ENABLED)`
  private[spark] val externalShuffleServiceEnabled: Boolean = externalBlockStoreClient.isDefined
   private val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER
@@ -224,17 +230,19 @@ private[spark] class BlockManager(
     ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128))
 
   // Actual storage of where blocks are kept
-  private[spark] val memoryStore =
-    new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this)
+  private[spark] lazy val memoryStore = {
+    val store = new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this)
+    memoryManager.setMemoryStore(store)
+    store
+  }
  private[spark] val diskStore = new DiskStore(conf, diskBlockManager, securityManager)
-  memoryManager.setMemoryStore(memoryStore)
 
  // Note: depending on the memory manager, `maxMemory` may actually vary over time.
  // However, since we use this only for reporting and logging, what we actually want here is
  // the absolute maximum value that `maxMemory` can ever possibly reach. We may need
  // to revisit whether reporting this value as the "max" is intuitive to the user.
-  private val maxOnHeapMemory = memoryManager.maxOnHeapStorageMemory
-  private val maxOffHeapMemory = memoryManager.maxOffHeapStorageMemory
+  private lazy val maxOnHeapMemory = memoryManager.maxOnHeapStorageMemory
+  private lazy val maxOffHeapMemory = memoryManager.maxOffHeapStorageMemory
 
  private[spark] val externalShuffleServicePort = StorageUtils.externalShuffleServicePort(conf)
 
diff --git a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala
index 197c2f13d807..cdbe5553bc95 100644
--- a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala
@@ -36,6 +36,7 @@ import org.apache.spark.TestUtils._
 import org.apache.spark.api.plugin._
 import org.apache.spark.internal.config._
 import org.apache.spark.launcher.SparkLauncher
+import org.apache.spark.memory.MemoryMode
 import org.apache.spark.resource.ResourceInformation
 import org.apache.spark.resource.ResourceUtils.GPU
 import org.apache.spark.resource.TestResourceIDs.{DRIVER_GPU_ID, EXECUTOR_GPU_ID, WORKER_GPU_ID}
@@ -228,6 +229,58 @@ class PluginContainerSuite extends SparkFunSuite with LocalSparkContext {
       assert(driverResources.get(GPU).name === GPU)
     }
   }
+
+  test("memory override in plugin") {
+    val conf = new SparkConf()
+      .setAppName(getClass().getName())
+      .set(SparkLauncher.SPARK_MASTER, "local-cluster[2,1,1024]")
+      .set(PLUGINS, Seq(classOf[MemoryOverridePlugin].getName()))
+
+    var sc: SparkContext = null
+    try {
+      sc = new SparkContext(conf)
+      val memoryManager = sc.env.memoryManager
+
+      assert(memoryManager.tungstenMemoryMode == MemoryMode.OFF_HEAP)
+      assert(memoryManager.maxOffHeapStorageMemory == MemoryOverridePlugin.offHeapMemory)
+
+      // Ensure all executors have started
+      TestUtils.waitUntilExecutorsUp(sc, 1, 60000)
+
+      // Check executor memory is also updated
+      val execInfo = sc.statusTracker.getExecutorInfos.head
+      assert(execInfo.totalOffHeapStorageMemory() == MemoryOverridePlugin.offHeapMemory)
+    } finally {
+      if (sc != null) {
+        sc.stop()
+      }
+    }
+  }
+}
+
+class MemoryOverridePlugin extends SparkPlugin {
+  override def driverPlugin(): DriverPlugin = {
+    new DriverPlugin {
+      override def init(sc: SparkContext, pluginContext: PluginContext): JMap[String, String] = {
+        // Take the original executor memory, and set `spark.memory.offHeap.size` to be the
+        // same value. Also set `spark.memory.offHeap.enabled` to true.
+        val originalExecutorMemBytes =
+          sc.conf.getSizeAsMb(EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString)
+        sc.conf.set(MEMORY_OFFHEAP_ENABLED.key, "true")
+        sc.conf.set(MEMORY_OFFHEAP_SIZE.key, s"${originalExecutorMemBytes}M")
+        MemoryOverridePlugin.offHeapMemory = sc.conf.getSizeAsBytes(MEMORY_OFFHEAP_SIZE.key)
+        Map.empty[String, String].asJava
+      }
+    }
+  }
+
+  override def executorPlugin(): ExecutorPlugin = {
+    new ExecutorPlugin {}
+  }
+}
+
+object MemoryOverridePlugin {
+  var offHeapMemory: Long = _
 }
 
 class NonLocalModeSparkPlugin extends SparkPlugin {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 9aba41cea215..d08e75733abf 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.scheduler
 
 import java.util.Properties
+import java.util.concurrent.Semaphore
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.mutable.ArrayBuffer
@@ -27,7 +28,9 @@ import org.mockito.Mockito._
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark._
+import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin}
 import org.apache.spark.executor.{Executor, TaskMetrics, TaskMetricsSuite}
+import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.METRICS_CONF
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.source.JvmSource
@@ -680,9 +683,49 @@ class TaskContextSuite extends SparkFunSuite with 
BeforeAndAfter with LocalSpark
     context.markTaskCompleted(None)
     assert(isFailed)
   }
+
+  test("SPARK-46947: ensure the correct block manager is used to unroll memory 
for task") {
+    import BlockManagerValidationPlugin._
+    BlockManagerValidationPlugin.resetState()
+
+    // run a task which ignores thread interruption when the SparkContext is shut down
+    sc = new SparkContext("local", "test")
+
+    val rdd = new RDD[String](sc, List()) {
+      override def getPartitions = Array[Partition](StubPartition(0))
+
+      override def compute(split: Partition, context: TaskContext): Iterator[String] = {
+        context.addTaskCompletionListener(new TaskCompletionListener {
+          override def onTaskCompletion(context: TaskContext): Unit = {
+            try {
+              releaseTaskSem.acquire()
+            } catch {
+              case _: InterruptedException =>
+                // ignore thread interruption
+            }
+          }
+        })
+        taskStartedSem.release()
+        Iterator.empty
+      }
+    }
+    // submit the job, but don't block this thread
+    rdd.collectAsync()
+    // wait for task to start
+    taskStartedSem.acquire()
+
+    sc.stop()
+    assert(sc.isStopped)
+
+    // create a new SparkContext which will be blocked for a certain amount of time
+    // while initializing the driver plugin below
+    val conf = new SparkConf()
+    conf.set("spark.plugins", classOf[BlockManagerValidationPlugin].getName)
+    sc = new SparkContext("local", "test", conf)
+  }
 }
 
-private object TaskContextSuite {
+private object TaskContextSuite extends Logging {
   @volatile var completed = false
 
   @volatile var lastError: Throwable = _
@@ -690,4 +733,30 @@ private object TaskContextSuite {
   class FakeTaskFailureException extends Exception("Fake task failure")
 }
 
+class BlockManagerValidationPlugin extends SparkPlugin {
+  override def driverPlugin(): DriverPlugin = {
+    new DriverPlugin() {
+      // does nothing but block the current thread for a certain time, so that the task thread
+      // can make progress and reproduce the issue.
+      BlockManagerValidationPlugin.releaseTaskSem.release()
+      Thread.sleep(2500)
+    }
+  }
+  override def executorPlugin(): ExecutorPlugin = {
+    new ExecutorPlugin() {
+      // do nothing
+    }
+  }
+}
+
+object BlockManagerValidationPlugin {
+  val releaseTaskSem = new Semaphore(0)
+  val taskStartedSem = new Semaphore(0)
+
+  def resetState(): Unit = {
+    releaseTaskSem.drainPermits()
+    taskStartedSem.drainPermits()
+  }
+}
+
 private case class StubPartition(index: Int) extends Partition


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
