This is an automated email from the ASF dual-hosted git repository.

wuyi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new bc80c84  [SPARK-36575][CORE] Should ignore task finished event if its task set is gone in TaskSchedulerImpl.handleSuccessfulTask
bc80c84 is described below

commit bc80c844fcb37d8d699d46bb34edadb98ed0d9f7
Author: hujiahua <hujia...@youzan.com>
AuthorDate: Wed Nov 10 11:20:35 2021 +0800

    [SPARK-36575][CORE] Should ignore task finished event if its task set is gone in TaskSchedulerImpl.handleSuccessfulTask
    
    ### What changes were proposed in this pull request?
    
    When an executor finishes a task of some stage, the driver receives a `StatusUpdate` event to handle it. If the driver finds at the same time that the executor's heartbeat has timed out, it also needs to handle an `ExecutorLost` event. There is a race condition between the two, which can leave `TaskSetManager.successful` and `TaskSetManager.tasksSuccessful` with wrong results.
    
    The problem is that `TaskResultGetter.enqueueSuccessfulTask` handles the successful task on an asynchronous thread, which means the synchronized lock of `TaskSchedulerImpl` is released prematurely partway through the handling (see https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala#L61). So `TaskSchedulerImpl` may handle the executor loss first, and the asynchronous thread then goes on to handle the successful task, leaving `TaskSetManager.successful` and `TaskSetManager.tasksSuccessful` with wrong results. A sketch of the interleaving follows.
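    To make the window concrete, here is a minimal, hypothetical sketch of the interleaving in plain Scala (not Spark code; the names only echo `TaskSchedulerImpl` and `TaskResultGetter`, all state is reduced to a task-id map and a counter, and the guard in `handleSuccessfulTask` mirrors the one this commit adds):
    
    ```scala
    import java.util.concurrent.Executors
    
    // Hypothetical, minimal sketch of the race (plain Scala, not Spark code).
    object Spark36575Sketch {
      private val schedulerLock = new Object
      private var taskIdToTaskSet = Map(0L -> "TaskSet 0.0") // tid -> task set
      private var tasksSuccessful = 0
    
      // Single async result-handling thread, like getTaskResultExecutor.
      private val resultGetter = Executors.newSingleThreadExecutor()
    
      // Driver, on StatusUpdate(FINISHED): it only *enqueues* the handling, so
      // the scheduler lock is released before handleSuccessfulTask ever runs.
      def enqueueSuccessfulTask(tid: Long): Unit = schedulerLock.synchronized {
        resultGetter.execute(() => handleSuccessfulTask(tid))
      } // lock released here; the race window opens
    
      // Driver, on heartbeat timeout: can take the lock before the queued
      // handler does, removing the task set out from under it.
      def executorLost(tid: Long): Unit = schedulerLock.synchronized {
        taskIdToTaskSet -= tid
      }
    
      private def handleSuccessfulTask(tid: Long): Unit =
        schedulerLock.synchronized {
          // The guard this commit adds: without the contains() check, the
          // counter would be bumped for a task whose task set is already gone.
          if (taskIdToTaskSet.contains(tid)) tasksSuccessful += 1
          else println(s"Ignoring finished task (TID $tid): its task set is gone")
        }
    
      def main(args: Array[String]): Unit = {
        enqueueSuccessfulTask(0L) // handling queued, lock released
        executorLost(0L)          // if this wins the race, the task set is gone
        resultGetter.shutdown()   // queued handler still runs, then must bail out
      }
    }
    ```
    
    If `executorLost` wins the race, the late `handleSuccessfulTask` call now logs and drops the event instead of corrupting the counters.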
    
    ### Why are the changes needed?
    
    Without the fix, `TaskSetManager.successful` and `TaskSetManager.tasksSuccessful` can end up with wrong results.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added a new test that keeps the single `task-result-getter` thread busy, delivers the task-finished status update, marks the executor as lost, and only then unblocks the thread, asserting that the late task-finished event is ignored.
    
    Closes #33872 from sleep1661/SPARK-36575.
    
    Lead-authored-by: hujiahua <hujia...@youzan.com>
    Co-authored-by: MattHu <hujia...@youzan.com>
    Signed-off-by: yi.wu <yi...@databricks.com>
---
 .../apache/spark/scheduler/TaskSchedulerImpl.scala |  8 +-
 .../spark/scheduler/TaskSchedulerImplSuite.scala   | 86 +++++++++++++++++++++-
 2 files changed, 92 insertions(+), 2 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 55db73a..282f12b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -871,7 +871,13 @@ private[spark] class TaskSchedulerImpl(
       taskSetManager: TaskSetManager,
       tid: Long,
       taskResult: DirectTaskResult[_]): Unit = synchronized {
-    taskSetManager.handleSuccessfulTask(tid, taskResult)
+    if (taskIdToTaskSetManager.contains(tid)) {
+      taskSetManager.handleSuccessfulTask(tid, taskResult)
+    } else {
+      logInfo(s"Ignoring update with state finished for task (TID $tid) 
because its task set " +
+        "is gone (this is likely the result of receiving duplicate task 
finished status updates)" +
+        " or its executor has been marked as failed.")
+    }
   }
 
   def handleFailedTask(
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 53dc14c..551d55d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -18,9 +18,12 @@
 package org.apache.spark.scheduler
 
 import java.nio.ByteBuffer
+import java.util.Properties
+import java.util.concurrent.{CountDownLatch, ExecutorService, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
 
 import scala.collection.mutable.{ArrayBuffer, HashMap}
 import scala.concurrent.duration._
+import scala.language.reflectiveCalls
 
 import org.mockito.ArgumentMatchers.{any, anyInt, anyString, eq => meq}
 import org.mockito.Mockito.{atLeast, atMost, never, spy, times, verify, when}
@@ -34,7 +37,7 @@ import org.apache.spark.internal.config
 import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfile, TaskResourceRequests}
 import org.apache.spark.resource.ResourceUtils._
 import org.apache.spark.resource.TestResourceIDs._
-import org.apache.spark.util.{Clock, ManualClock}
+import org.apache.spark.util.{Clock, ManualClock, ThreadUtils}
 
 class FakeSchedulerBackend extends SchedulerBackend {
   def start(): Unit = {}
@@ -1995,6 +1998,87 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     assert(!normalTSM.runningTasksSet.contains(taskId))
   }
 
+  test("SPARK-36575: Should ignore task finished event if its task set is gone 
" +
+    "in TaskSchedulerImpl.handleSuccessfulTask") {
+    val taskScheduler = setupScheduler()
+
+    val latch = new CountDownLatch(2)
+    val resultGetter = new TaskResultGetter(sc.env, taskScheduler) {
+      override protected val getTaskResultExecutor: ExecutorService =
+        new ThreadPoolExecutor(1, 1, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue[Runnable],
+          ThreadUtils.namedThreadFactory("task-result-getter")) {
+          override def execute(command: Runnable): Unit = {
+            super.execute(new Runnable {
+              override def run(): Unit = {
+                command.run()
+                latch.countDown()
+              }
+            })
+          }
+        }
+      def taskResultExecutor() : ExecutorService = getTaskResultExecutor
+    }
+    taskScheduler.taskResultGetter = resultGetter
+
+    val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1),
+      new WorkerOffer("executor1", "host1", 1))
+    val task1 = new ShuffleMapTask(1, 0, null, new Partition {
+      override def index: Int = 0
+    }, Seq(TaskLocation("host0", "executor0")), new Properties, null)
+
+    val task2 = new ShuffleMapTask(1, 0, null, new Partition {
+      override def index: Int = 1
+    }, Seq(TaskLocation("host1", "executor1")), new Properties, null)
+
+    val taskSet = new TaskSet(Array(task1, task2), 0, 0, 0, null, 0)
+
+    taskScheduler.submitTasks(taskSet)
+    val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(2 === taskDescriptions.length)
+
+    val ser = sc.env.serializer.newInstance()
+    val directResult = new DirectTaskResult[Int](ser.serialize(1), Seq(), Array.empty)
+    val resultBytes = ser.serialize(directResult)
+
+    val busyTask = new Runnable {
+      val lock : Object = new Object
+      override def run(): Unit = {
+        lock.synchronized {
+          lock.wait()
+        }
+      }
+      def markTaskDone: Unit = {
+        lock.synchronized {
+          lock.notify()
+        }
+      }
+    }
+    // make getTaskResultExecutor busy
+    resultGetter.taskResultExecutor().submit(busyTask)
+
+    // task1 finished
+    val tid = taskDescriptions(0).taskId
+    taskScheduler.statusUpdate(
+      tid = tid,
+      state = TaskState.FINISHED,
+      serializedData = resultBytes
+    )
+
+    // mark executor heartbeat timed out
+    taskScheduler.executorLost(taskDescriptions(0).executorId, ExecutorProcessLost("Executor " +
+      "heartbeat timed out"))
+
+    busyTask.markTaskDone
+
+    // Wait until all events are processed
+    latch.await()
+
+    val taskSetManager = taskScheduler.taskIdToTaskSetManager.get(taskDescriptions(1).taskId)
+    assert(taskSetManager != null)
+    assert(0 == taskSetManager.tasksSuccessful)
+    assert(!taskSetManager.successful(taskDescriptions(0).index))
+  }
+
   /**
    * Used by tests to simulate a task failure. This calls the failure handler explicitly, to ensure
    * that all the state is updated when this method returns. Otherwise, there's no way to know when

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
