This is an automated email from the ASF dual-hosted git repository. wuyi pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new bc80c84 [SPARK-36575][CORE] Should ignore task finished event if its task set is gone in TaskSchedulerImpl.handleSuccessfulTask bc80c84 is described below commit bc80c844fcb37d8d699d46bb34edadb98ed0d9f7 Author: hujiahua <hujia...@youzan.com> AuthorDate: Wed Nov 10 11:20:35 2021 +0800 [SPARK-36575][CORE] Should ignore task finished event if its task set is gone in TaskSchedulerImpl.handleSuccessfulTask ### What changes were proposed in this pull request? When an executor finishes a task of some stage, the driver receives a `StatusUpdate` event to handle it. If, at the same time, the driver finds that the executor's heartbeat has timed out, the driver also needs to handle an ExecutorLost event simultaneously. There is a race condition here, which leaves `TaskSetManager.successful` and `TaskSetManager.tasksSuccessful` with wrong values. The problem is that `TaskResultGetter.enqueueSuccessfulTask` uses an asynchronous thread to handle the successful task, which means the synchronized lock of `TaskSchedulerImpl` is released midway through processing (see https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala#L61). So `TaskSchedulerImpl` may handle executorLost first, and then the asynchronous thread goes on to handle the successful task. It causes `TaskSetManager.successful` and `T [...] ### Why are the changes needed? Without this change, `TaskSetManager.successful` and `TaskSetManager.tasksSuccessful` can hold wrong values. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added a new test. Closes #33872 from sleep1661/SPARK-36575. 
Lead-authored-by: hujiahua <hujia...@youzan.com> Co-authored-by: MattHu <hujia...@youzan.com> Signed-off-by: yi.wu <yi...@databricks.com> --- .../apache/spark/scheduler/TaskSchedulerImpl.scala | 8 +- .../spark/scheduler/TaskSchedulerImplSuite.scala | 86 +++++++++++++++++++++- 2 files changed, 92 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 55db73a..282f12b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -871,7 +871,13 @@ private[spark] class TaskSchedulerImpl( taskSetManager: TaskSetManager, tid: Long, taskResult: DirectTaskResult[_]): Unit = synchronized { - taskSetManager.handleSuccessfulTask(tid, taskResult) + if (taskIdToTaskSetManager.contains(tid)) { + taskSetManager.handleSuccessfulTask(tid, taskResult) + } else { + logInfo(s"Ignoring update with state finished for task (TID $tid) because its task set " + + "is gone (this is likely the result of receiving duplicate task finished status updates)" + + " or its executor has been marked as failed.") + } } def handleFailedTask( diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 53dc14c..551d55d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -18,9 +18,12 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer +import java.util.Properties +import java.util.concurrent.{CountDownLatch, ExecutorService, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.duration._ +import scala.language.reflectiveCalls import 
org.mockito.ArgumentMatchers.{any, anyInt, anyString, eq => meq} import org.mockito.Mockito.{atLeast, atMost, never, spy, times, verify, when} @@ -34,7 +37,7 @@ import org.apache.spark.internal.config import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfile, TaskResourceRequests} import org.apache.spark.resource.ResourceUtils._ import org.apache.spark.resource.TestResourceIDs._ -import org.apache.spark.util.{Clock, ManualClock} +import org.apache.spark.util.{Clock, ManualClock, ThreadUtils} class FakeSchedulerBackend extends SchedulerBackend { def start(): Unit = {} @@ -1995,6 +1998,87 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(!normalTSM.runningTasksSet.contains(taskId)) } + test("SPARK-36575: Should ignore task finished event if its task set is gone " + + "in TaskSchedulerImpl.handleSuccessfulTask") { + val taskScheduler = setupScheduler() + + val latch = new CountDownLatch(2) + val resultGetter = new TaskResultGetter(sc.env, taskScheduler) { + override protected val getTaskResultExecutor: ExecutorService = + new ThreadPoolExecutor(1, 1, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue[Runnable], + ThreadUtils.namedThreadFactory("task-result-getter")) { + override def execute(command: Runnable): Unit = { + super.execute(new Runnable { + override def run(): Unit = { + command.run() + latch.countDown() + } + }) + } + } + def taskResultExecutor() : ExecutorService = getTaskResultExecutor + } + taskScheduler.taskResultGetter = resultGetter + + val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1), + new WorkerOffer("executor1", "host1", 1)) + val task1 = new ShuffleMapTask(1, 0, null, new Partition { + override def index: Int = 0 + }, Seq(TaskLocation("host0", "executor0")), new Properties, null) + + val task2 = new ShuffleMapTask(1, 0, null, new Partition { + override def index: Int = 1 + }, Seq(TaskLocation("host1", "executor1")), new Properties, null) + + val taskSet = new 
TaskSet(Array(task1, task2), 0, 0, 0, null, 0) + + taskScheduler.submitTasks(taskSet) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(2 === taskDescriptions.length) + + val ser = sc.env.serializer.newInstance() + val directResult = new DirectTaskResult[Int](ser.serialize(1), Seq(), Array.empty) + val resultBytes = ser.serialize(directResult) + + val busyTask = new Runnable { + val lock : Object = new Object + override def run(): Unit = { + lock.synchronized { + lock.wait() + } + } + def markTaskDone: Unit = { + lock.synchronized { + lock.notify() + } + } + } + // make getTaskResultExecutor busy + resultGetter.taskResultExecutor().submit(busyTask) + + // task1 finished + val tid = taskDescriptions(0).taskId + taskScheduler.statusUpdate( + tid = tid, + state = TaskState.FINISHED, + serializedData = resultBytes + ) + + // mark executor heartbeat timed out + taskScheduler.executorLost(taskDescriptions(0).executorId, ExecutorProcessLost("Executor " + + "heartbeat timed out")) + + busyTask.markTaskDone + + // Wait until all events are processed + latch.await() + + val taskSetManager = taskScheduler.taskIdToTaskSetManager.get(taskDescriptions(1).taskId) + assert(taskSetManager != null) + assert(0 == taskSetManager.tasksSuccessful) + assert(!taskSetManager.successful(taskDescriptions(0).index)) + } + /** * Used by tests to simulate a task failure. This calls the failure handler explicitly, to ensure * that all the state is updated when this method returns. Otherwise, there's no way to know when --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org