This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 93289a5dc92 Revert "[SPARK-38916][CORE] Tasks not killed caused by race conditions between killTask() and launchTask()"
93289a5dc92 is described below

commit 93289a5dc929c97d10df03853161d5b931538ba5
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Mon Apr 25 11:57:51 2022 +0800

    Revert "[SPARK-38916][CORE] Tasks not killed caused by race conditions between killTask() and launchTask()"

    This reverts commit 9dd64d40c91253c275fef2313c6a326ef72112cb.
---
 .../scala/org/apache/spark/executor/Executor.scala |  51 +-----
 .../CoarseGrainedExecutorBackendSuite.scala        | 185 +--------------------
 .../org/apache/spark/executor/ExecutorSuite.scala  |  10 +-
 3 files changed, 16 insertions(+), 230 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 4c84224dd05..3f1023e3491 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -83,7 +83,7 @@ private[spark] class Executor(
 
   private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
 
-  private[executor] val conf = env.conf
+  private val conf = env.conf
 
   // No ip or host:port - just hostname
   Utils.checkHost(executorHostname)
@@ -104,7 +104,7 @@ private[spark] class Executor(
   // Use UninterruptibleThread to run tasks so that we can allow running codes without being
   // interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894, HADOOP-10622,
   // will hang forever if some methods are interrupted.
-  private[executor] val threadPool = {
+  private val threadPool = {
     val threadFactory = new ThreadFactoryBuilder()
       .setDaemon(true)
       .setNameFormat("Executor task launch worker-%d")
@@ -174,33 +174,7 @@ private[spark] class Executor(
   private val maxResultSize = conf.get(MAX_RESULT_SIZE)
 
   // Maintains the list of running tasks.
-  private[executor] val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
-
-  // Kill mark TTL in milliseconds - 10 seconds.
-  private val KILL_MARK_TTL_MS = 10000L
-
-  // Kill marks with interruptThread flag, kill reason and timestamp.
-  // This is to avoid dropping the kill event when killTask() is called before launchTask().
-  private[executor] val killMarks = new ConcurrentHashMap[Long, (Boolean, String, Long)]
-
-  private val killMarkCleanupTask = new Runnable {
-    override def run(): Unit = {
-      val oldest = System.currentTimeMillis() - KILL_MARK_TTL_MS
-      val iter = killMarks.entrySet().iterator()
-      while (iter.hasNext) {
-        if (iter.next().getValue._3 < oldest) {
-          iter.remove()
-        }
-      }
-    }
-  }
-
-  // Kill mark cleanup thread executor.
-  private val killMarkCleanupService =
-    ThreadUtils.newDaemonSingleThreadScheduledExecutor("executor-kill-mark-cleanup")
-
-  killMarkCleanupService.scheduleAtFixedRate(
-    killMarkCleanupTask, KILL_MARK_TTL_MS, KILL_MARK_TTL_MS, TimeUnit.MILLISECONDS)
+  private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
 
   /**
    * When an executor is unable to send heartbeats to the driver more than `HEARTBEAT_MAX_FAILURES`
@@ -290,18 +264,9 @@ private[spark] class Executor(
     decommissioned = true
   }
 
-  private[executor] def createTaskRunner(context: ExecutorBackend,
-    taskDescription: TaskDescription) = new TaskRunner(context, taskDescription, plugins)
-
   def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
-    val taskId = taskDescription.taskId
-    val tr = createTaskRunner(context, taskDescription)
-    runningTasks.put(taskId, tr)
-    val killMark = killMarks.get(taskId)
-    if (killMark != null) {
-      tr.kill(killMark._1, killMark._2)
-      killMarks.remove(taskId)
-    }
+    val tr = new TaskRunner(context, taskDescription, plugins)
+    runningTasks.put(taskDescription.taskId, tr)
     threadPool.execute(tr)
     if (decommissioned) {
       log.error(s"Launching a task while in decommissioned state.")
@@ -309,7 +274,6 @@ private[spark] class Executor(
   }
 
   def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = {
-    killMarks.put(taskId, (interruptThread, reason, System.currentTimeMillis()))
     val taskRunner = runningTasks.get(taskId)
     if (taskRunner != null) {
       if (taskReaperEnabled) {
@@ -332,8 +296,6 @@ private[spark] class Executor(
       } else {
         taskRunner.kill(interruptThread = interruptThread, reason = reason)
       }
-      // Safe to remove kill mark as we got a chance with the TaskRunner.
-      killMarks.remove(taskId)
     }
   }
 
@@ -372,9 +334,6 @@ private[spark] class Executor(
     if (threadPool != null) {
       threadPool.shutdown()
     }
-    if (killMarkCleanupService != null) {
-      killMarkCleanupService.shutdown()
-    }
     if (replClassLoader != null && plugins != null) {
       // Notify plugins that executor is shutting down so they can terminate cleanly
       Utils.withContextClassLoader(replClassLoader) {
diff --git a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
index 5210990f3b9..4909a586d31 100644
--- a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala
@@ -21,17 +21,14 @@ import java.io.File
 import java.net.URL
 import java.nio.ByteBuffer
 import java.util.Properties
-import java.util.concurrent.ConcurrentHashMap
 
-import scala.collection.concurrent.TrieMap
 import scala.collection.mutable
 import scala.concurrent.duration._
 
 import org.json4s.{DefaultFormats, Extraction}
 import org.json4s.JsonAST.{JArray, JObject}
 import org.json4s.JsonDSL._
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito._
+import org.mockito.Mockito.when
 import org.scalatest.concurrent.Eventually.{eventually, timeout}
 import org.scalatestplus.mockito.MockitoSugar
 
@@ -42,9 +39,9 @@ import org.apache.spark.resource.ResourceUtils._
 import org.apache.spark.resource.TestResourceIDs._
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.TaskDescription
-import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{KillTask, LaunchTask}
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.LaunchTask
 import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.util.{SerializableBuffer, ThreadUtils, Utils}
+import org.apache.spark.util.{SerializableBuffer, Utils}
 
 class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
     with LocalSparkContext with MockitoSugar {
@@ -360,182 +357,6 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
     assert(arg.bindAddress == "bindaddress1")
   }
 
-  /**
-   * This testcase is to verify that [[Executor.killTask()]] will always cancel a task that is
-   * being executed in [[Executor.TaskRunner]].
-   */
-  test(s"Tasks launched should always be cancelled.") {
-    val conf = new SparkConf
-    val securityMgr = new SecurityManager(conf)
-    val serializer = new JavaSerializer(conf)
-    val threadPool = ThreadUtils.newDaemonFixedThreadPool(32, "test-executor")
-    var backend: CoarseGrainedExecutorBackend = null
-
-    try {
-      val rpcEnv = RpcEnv.create("1", "localhost", 0, conf, securityMgr)
-      val env = createMockEnv(conf, serializer, Some(rpcEnv))
-      backend = new CoarseGrainedExecutorBackend(env.rpcEnv, rpcEnv.address.hostPort, "1",
-        "host1", "host1", 4, env, None,
-        resourceProfile = ResourceProfile.getOrCreateDefaultProfile(conf))
-
-      backend.rpcEnv.setupEndpoint("Executor 1", backend)
-      backend.executor = mock[Executor](CALLS_REAL_METHODS)
-      val executor = backend.executor
-      // Mock the executor.
-      when(executor.threadPool).thenReturn(threadPool)
-      val runningTasks = spy(new ConcurrentHashMap[Long, Executor#TaskRunner])
-      when(executor.runningTasks).thenAnswer(_ => runningTasks)
-      when(executor.conf).thenReturn(conf)
-
-      // We don't really verify the data, just pass it around.
-      val data = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
-
-      val numTasks = 1000
-      val tasksKilled = new TrieMap[Long, Boolean]()
-      val tasksExecuted = new TrieMap[Long, Boolean]()
-
-      // Fake tasks with different taskIds.
-      val taskDescriptions = (1 to numTasks).map {
-        taskId => new TaskDescription(taskId, 2, "1", "TASK ${taskId}", 19,
-          1, mutable.Map.empty, mutable.Map.empty, mutable.Map.empty, new Properties, 1,
-          Map(GPU -> new ResourceInformation(GPU, Array("0", "1"))), data)
-      }
-      assert(taskDescriptions.length == numTasks)
-
-      def getFakeTaskRunner(taskDescription: TaskDescription): Executor#TaskRunner = {
-        new executor.TaskRunner(backend, taskDescription, None) {
-          override def run(): Unit = {
-            tasksExecuted.put(taskDescription.taskId, true)
-            logInfo(s"task ${taskDescription.taskId} runs.")
-          }
-
-          override def kill(interruptThread: Boolean, reason: String): Unit = {
-            logInfo(s"task ${taskDescription.taskId} killed.")
-            tasksKilled.put(taskDescription.taskId, true)
-          }
-        }
-      }
-
-      // Feed the fake task-runners to be executed by the executor.
-      val firstLaunchTask = getFakeTaskRunner(taskDescriptions(1))
-      val otherTasks = taskDescriptions.slice(1, numTasks).map(getFakeTaskRunner(_)).toArray
-      assert (otherTasks.length == numTasks - 1)
-      // Workaround for compilation issue around Mockito.doReturn
-      doReturn(firstLaunchTask, otherTasks: _*).when(executor).
-        createTaskRunner(any(), any())
-
-      // Launch tasks and quickly kill them so that TaskRunner.killTask will be triggered.
-      taskDescriptions.foreach { taskDescription =>
-        val buffer = new SerializableBuffer(TaskDescription.encode(taskDescription))
-        backend.self.send(LaunchTask(buffer))
-        Thread.sleep(1)
-        backend.self.send(KillTask(taskDescription.taskId, "exec1", false, "test"))
-      }
-
-      eventually(timeout(10.seconds)) {
-        verify(runningTasks, times(numTasks)).put(any(), any())
-      }
-
-      assert(tasksExecuted.size == tasksKilled.size,
-        s"Tasks killed ${tasksKilled.size} != tasks executed ${tasksExecuted.size}")
-      assert(tasksExecuted.keySet == tasksKilled.keySet)
-      logInfo(s"Task executed ${tasksExecuted.size}, task killed ${tasksKilled.size}")
-    } finally {
-      if (backend != null) {
-        backend.rpcEnv.shutdown()
-      }
-      threadPool.shutdownNow()
-    }
-  }
-
-  /**
-   * This testcase is to verify that [[Executor.killTask()]] will always cancel a task even if
-   * it has not been launched yet.
-   */
-  test(s"Tasks not launched should always be cancelled.") {
-    val conf = new SparkConf
-    val securityMgr = new SecurityManager(conf)
-    val serializer = new JavaSerializer(conf)
-    val threadPool = ThreadUtils.newDaemonFixedThreadPool(32, "test-executor")
-    var backend: CoarseGrainedExecutorBackend = null
-
-    try {
-      val rpcEnv = RpcEnv.create("1", "localhost", 0, conf, securityMgr)
-      val env = createMockEnv(conf, serializer, Some(rpcEnv))
-      backend = new CoarseGrainedExecutorBackend(env.rpcEnv, rpcEnv.address.hostPort, "1",
-        "host1", "host1", 4, env, None,
-        resourceProfile = ResourceProfile.getOrCreateDefaultProfile(conf))
-
-      backend.rpcEnv.setupEndpoint("Executor 1", backend)
-      backend.executor = mock[Executor](CALLS_REAL_METHODS)
-      val executor = backend.executor
-      // Mock the executor.
-      when(executor.threadPool).thenReturn(threadPool)
-      val runningTasks = spy(new ConcurrentHashMap[Long, Executor#TaskRunner])
-      when(executor.runningTasks).thenAnswer(_ => runningTasks)
-      when(executor.conf).thenReturn(conf)
-
-      // We don't really verify the data, just pass it around.
-      val data = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
-
-      val numTasks = 1000
-      val tasksKilled = new TrieMap[Long, Boolean]()
-      val tasksExecuted = new TrieMap[Long, Boolean]()
-
-      // Fake tasks with different taskIds.
-      val taskDescriptions = (1 to numTasks).map {
-        taskId => new TaskDescription(taskId, 2, "1", "TASK ${taskId}", 19,
-          1, mutable.Map.empty, mutable.Map.empty, mutable.Map.empty, new Properties, 1,
-          Map(GPU -> new ResourceInformation(GPU, Array("0", "1"))), data)
-      }
-      assert(taskDescriptions.length == numTasks)
-
-      def getFakeTaskRunner(taskDescription: TaskDescription): Executor#TaskRunner = {
-        new executor.TaskRunner(backend, taskDescription, None) {
-          override def run(): Unit = {
-            tasksExecuted.put(taskDescription.taskId, true)
-            logInfo(s"task ${taskDescription.taskId} runs.")
-          }
-
-          override def kill(interruptThread: Boolean, reason: String): Unit = {
-            logInfo(s"task ${taskDescription.taskId} killed.")
-            tasksKilled.put(taskDescription.taskId, true)
-          }
-        }
-      }
-
-      // Feed the fake task-runners to be executed by the executor.
-      val firstLaunchTask = getFakeTaskRunner(taskDescriptions(1))
-      val otherTasks = taskDescriptions.slice(1, numTasks).map(getFakeTaskRunner(_)).toArray
-      assert (otherTasks.length == numTasks - 1)
-      // Workaround for compilation issue around Mockito.doReturn
-      doReturn(firstLaunchTask, otherTasks: _*).when(executor).
-        createTaskRunner(any(), any())
-
-      // The reverse order of events can happen when the scheduler tries to cancel a task right
-      // after launching it.
-      taskDescriptions.foreach { taskDescription =>
-        val buffer = new SerializableBuffer(TaskDescription.encode(taskDescription))
-        backend.self.send(KillTask(taskDescription.taskId, "exec1", false, "test"))
-        backend.self.send(LaunchTask(buffer))
-      }
-
-      eventually(timeout(10.seconds)) {
-        verify(runningTasks, times(numTasks)).put(any(), any())
-      }
-
-      assert(tasksExecuted.size == tasksKilled.size,
-        s"Tasks killed ${tasksKilled.size} != tasks executed ${tasksExecuted.size}")
-      assert(tasksExecuted.keySet == tasksKilled.keySet)
-      logInfo(s"Task executed ${tasksExecuted.size}, task killed ${tasksKilled.size}")
-    } finally {
-      if (backend != null) {
-        backend.rpcEnv.shutdown()
-      }
-      threadPool.shutdownNow()
-    }
-  }
-
   private def createMockEnv(conf: SparkConf, serializer: JavaSerializer,
       rpcEnv: Option[RpcEnv] = None): SparkEnv = {
     val mockEnv = mock[SparkEnv]
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index 7f7b10c8c33..a237447b0fa 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -22,7 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler
 import java.net.URL
 import java.nio.ByteBuffer
 import java.util.Properties
-import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, CountDownLatch, TimeUnit}
 import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.immutable
@@ -321,7 +321,13 @@ class ExecutorSuite extends SparkFunSuite
     nonZeroAccumulator.add(1)
     metrics.registerAccumulator(nonZeroAccumulator)
 
-    val tasksMap = executor.runningTasks
+    val executorClass = classOf[Executor]
+    val tasksMap = {
+      val field =
+        executorClass.getDeclaredField("org$apache$spark$executor$Executor$$runningTasks")
+      field.setAccessible(true)
+      field.get(executor).asInstanceOf[ConcurrentHashMap[Long, executor.TaskRunner]]
+    }
     val mockTaskRunner = mock[executor.TaskRunner]
     val mockTask = mock[Task[Any]]
     when(mockTask.metrics).thenReturn(metrics)
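For readers following the revert: the kill-mark mechanism being deleted above is compact enough to summarize in a standalone sketch. The code below is a toy reconstruction, not Spark's actual Executor; MiniExecutor and this TaskRunner are simplified stand-ins, and only the names killMarks, KILL_MARK_TTL_MS, launchTask and killTask mirror the diff. It shows the idea the reverted patch implemented: killTask() records a mark before consulting runningTasks, and launchTask() checks for a mark after registering the runner, so a kill that races ahead of the launch is not dropped.

import java.util.concurrent.{ConcurrentHashMap, Executors}

object KillMarkSketch {
  final class TaskRunner(val taskId: Long) extends Runnable {
    @volatile var killed = false
    def kill(interruptThread: Boolean, reason: String): Unit = {
      killed = true
      println(s"task $taskId killed: $reason")
    }
    override def run(): Unit = if (!killed) println(s"task $taskId ran")
  }

  final class MiniExecutor {
    private val KILL_MARK_TTL_MS = 10000L
    private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
    // (interruptThread, reason, timestamp) per task id, as in the reverted code.
    private val killMarks = new ConcurrentHashMap[Long, (Boolean, String, Long)]
    private val pool = Executors.newFixedThreadPool(4)

    def launchTask(taskId: Long): Unit = {
      val tr = new TaskRunner(taskId)
      runningTasks.put(taskId, tr)
      // Honor a kill that arrived before this launch.
      val mark = killMarks.get(taskId)
      if (mark != null) {
        tr.kill(mark._1, mark._2)
        killMarks.remove(taskId)
      }
      pool.execute(tr)
    }

    def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = {
      // Record the mark first so a concurrent launchTask() cannot miss it.
      killMarks.put(taskId, (interruptThread, reason, System.currentTimeMillis()))
      val tr = runningTasks.get(taskId)
      if (tr != null) {
        tr.kill(interruptThread, reason)
        killMarks.remove(taskId) // safe: the TaskRunner saw the kill
      }
      // In the reverted code, marks for tasks that never launch are expired
      // after KILL_MARK_TTL_MS by a cleanup thread (omitted here for brevity).
    }

    def shutdown(): Unit = pool.shutdown()
  }

  def main(args: Array[String]): Unit = {
    val exec = new MiniExecutor
    exec.killTask(1L, interruptThread = false, reason = "test") // kill before launch
    exec.launchTask(1L) // the kill mark is found, so the task never runs
    exec.shutdown()
  }
}

The TTL exists because a kill may target a task that is never launched; without expiry, killMarks would leak entries.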
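Separately, the ExecutorSuite hunk falls back to Java reflection because runningTasks is private again after the revert. A minimal sketch of that idiom follows, assuming a hypothetical Box class with a private field; the mangled "org$apache$spark$executor$Executor$$..." name in the diff arises because Executor's field is also referenced from nested classes, whereas a simple class-private val usually keeps its plain name.

object PrivateFieldSketch {
  class Box {
    private val secret: String = "hidden"
  }

  def main(args: Array[String]): Unit = {
    val box = new Box
    // Look up the field by its JVM name and bypass private access for the test.
    val field = classOf[Box].getDeclaredField("secret")
    field.setAccessible(true)
    println(field.get(box).asInstanceOf[String]) // prints "hidden"
  }
}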