Repository: spark
Updated Branches:
  refs/heads/master 952e4d1c8 -> 82fb5bfa7


[SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason

## What changes were proposed in this pull request?
The goal is for listeners on onTaskEnd to receive accumulator updates and metrics 
when a task is killed intentionally; today that data is simply thrown away. 
ExceptionFailure already carries this information, so this change applies the same 
approach to TaskKilled.
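
For illustration only (not part of this patch), a minimal sketch of a listener that 
would consume the newly attached data; the class name KilledTaskMetricsListener and 
the log formatting are invented for this example:

```scala
import org.apache.spark.TaskKilled
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

// Hypothetical listener: once TaskKilled carries accumulator updates,
// onTaskEnd can read them the same way it already can for ExceptionFailure.
class KilledTaskMetricsListener extends SparkListener {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    taskEnd.reason match {
      case tk: TaskKilled =>
        // Each AccumulableInfo carries the latest update reported by the executor.
        tk.accumUpdates.foreach { info =>
          println(s"Task ${taskEnd.taskInfo.taskId} killed (${tk.reason}): " +
            s"accumulator ${info.name.getOrElse("<unnamed>")} = ${info.update.getOrElse("n/a")}")
        }
      case _ => // other end reasons are unaffected by this change
    }
  }
}
```

Such a listener would be registered with SparkContext.addSparkListener, e.g. 
sc.addSparkListener(new KilledTaskMetricsListener).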

## How was this patch tested?
Updated existing tests.

This is a rework of https://github.com/apache/spark/pull/17422; all credit should 
go to noodle-fb.

Author: Xianjin YE <advance...@gmail.com>
Author: Charles Lewis <noo...@fb.com>

Closes #21165 from advancedxy/SPARK-20087.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/82fb5bfa
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/82fb5bfa
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/82fb5bfa

Branch: refs/heads/master
Commit: 82fb5bfa770b0325d4f377dd38d89869007c6111
Parents: 952e4d1
Author: Xianjin YE <advance...@gmail.com>
Authored: Tue May 22 21:02:17 2018 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue May 22 21:02:17 2018 +0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/TaskEndReason.scala  |  8 ++-
 .../org/apache/spark/executor/Executor.scala    | 55 +++++++++++++-------
 .../apache/spark/scheduler/DAGScheduler.scala   |  6 +--
 .../apache/spark/scheduler/TaskSetManager.scala |  8 ++-
 .../org/apache/spark/util/JsonProtocol.scala    |  9 +++-
 .../spark/scheduler/DAGSchedulerSuite.scala     | 18 +++++--
 project/MimaExcludes.scala                      |  5 ++
 7 files changed, 78 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/82fb5bfa/core/src/main/scala/org/apache/spark/TaskEndReason.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index a76283e..33901bc 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -212,9 +212,15 @@ case object TaskResultLost extends TaskFailedReason {
  * Task was killed intentionally and needs to be rescheduled.
  */
 @DeveloperApi
-case class TaskKilled(reason: String) extends TaskFailedReason {
+case class TaskKilled(
+    reason: String,
+    accumUpdates: Seq[AccumulableInfo] = Seq.empty,
+    private[spark] val accums: Seq[AccumulatorV2[_, _]] = Nil)
+  extends TaskFailedReason {
+
   override def toErrorString: String = s"TaskKilled ($reason)"
   override def countTowardsTaskFailures: Boolean = false
+
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/82fb5bfa/core/src/main/scala/org/apache/spark/executor/Executor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index c325222..b1856ff 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -287,6 +287,28 @@ private[spark] class Executor(
       notifyAll()
     }
 
+    /**
+     *  Utility function to:
+     *    1. Report executor runtime and JVM gc time if possible
+     *    2. Collect accumulator updates
+     *    3. Set the finished flag to true and clear current thread's interrupt status
+     */
+    private def collectAccumulatorsAndResetStatusOnFailure(taskStartTime: Long) = {
+      // Report executor runtime and JVM gc time
+      Option(task).foreach(t => {
+        t.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStartTime)
+        t.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
+      })
+
+      // Collect latest accumulator values to report back to the driver
+      val accums: Seq[AccumulatorV2[_, _]] =
+        Option(task).map(_.collectAccumulatorUpdates(taskFailed = true)).getOrElse(Seq.empty)
+      val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None))
+
+      setTaskFinishedAndClearInterruptStatus()
+      (accums, accUpdates)
+    }
+
     override def run(): Unit = {
       threadId = Thread.currentThread.getId
       Thread.currentThread.setName(threadName)
@@ -300,7 +322,7 @@ private[spark] class Executor(
       val ser = env.closureSerializer.newInstance()
       logInfo(s"Running $taskName (TID $taskId)")
       execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
-      var taskStart: Long = 0
+      var taskStartTime: Long = 0
       var taskStartCpu: Long = 0
       startGCTime = computeTotalGcTime()
 
@@ -336,7 +358,7 @@ private[spark] class Executor(
         }
 
         // Run the actual task and measure its runtime.
-        taskStart = System.currentTimeMillis()
+        taskStartTime = System.currentTimeMillis()
         taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
           threadMXBean.getCurrentThreadCpuTime
         } else 0L
@@ -396,11 +418,11 @@ private[spark] class Executor(
         // Deserialization happens in two parts: first, we deserialize a Task object, which
         // includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
         task.metrics.setExecutorDeserializeTime(
-          (taskStart - deserializeStartTime) + task.executorDeserializeTime)
+          (taskStartTime - deserializeStartTime) + task.executorDeserializeTime)
         task.metrics.setExecutorDeserializeCpuTime(
           (taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime)
         // We need to subtract Task.run()'s deserialization time to avoid double-counting
-        task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
+        task.metrics.setExecutorRunTime((taskFinish - taskStartTime) - task.executorDeserializeTime)
         task.metrics.setExecutorCpuTime(
           (taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime)
         task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
@@ -482,16 +504,19 @@ private[spark] class Executor(
       } catch {
         case t: TaskKilledException =>
           logInfo(s"Executor killed $taskName (TID $taskId), reason: 
${t.reason}")
-          setTaskFinishedAndClearInterruptStatus()
-          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))
+
+          val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime)
+          val serializedTK = ser.serialize(TaskKilled(t.reason, accUpdates, accums))
+          execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
 
         case _: InterruptedException | NonFatal(_) if
             task != null && task.reasonIfKilled.isDefined =>
           val killReason = task.reasonIfKilled.getOrElse("unknown reason")
           logInfo(s"Executor interrupted and killed $taskName (TID $taskId), 
reason: $killReason")
-          setTaskFinishedAndClearInterruptStatus()
-          execBackend.statusUpdate(
-            taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))
+
+          val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime)
+          val serializedTK = ser.serialize(TaskKilled(killReason, accUpdates, accums))
+          execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
 
         case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
           val reason = task.context.fetchFailed.get.toTaskFailedReason
@@ -524,17 +549,7 @@ private[spark] class Executor(
           // the task failure would not be ignored if the shutdown happened because of premption,
           // instead of an app issue).
           if (!ShutdownHookManager.inShutdown()) {
-            // Collect latest accumulator values to report back to the driver
-            val accums: Seq[AccumulatorV2[_, _]] =
-              if (task != null) {
-                task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart)
-                task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
-                task.collectAccumulatorUpdates(taskFailed = true)
-              } else {
-                Seq.empty
-              }
-
-            val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None))
+            val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime)
 
             val serializedTaskEndReason = {
               try {

http://git-wip-us.apache.org/repos/asf/spark/blob/82fb5bfa/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 5f2d16d..ea7bfd7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1210,7 +1210,7 @@ class DAGScheduler(
           case _ =>
             updateAccumulators(event)
         }
-      case _: ExceptionFailure => updateAccumulators(event)
+      case _: ExceptionFailure | _: TaskKilled => updateAccumulators(event)
       case _ =>
     }
     postTaskEnd(event)
@@ -1414,13 +1414,13 @@ class DAGScheduler(
       case commitDenied: TaskCommitDenied =>
         // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits
 
-      case exceptionFailure: ExceptionFailure =>
+      case _: ExceptionFailure | _: TaskKilled =>
         // Nothing left to do, already handled above for accumulator updates.
 
       case TaskResultLost =>
         // Do nothing here; the TaskScheduler handles these failures and resubmits the task.
 
-      case _: ExecutorLostFailure | _: TaskKilled | UnknownReason =>
+      case _: ExecutorLostFailure | UnknownReason =>
         // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler
         // will abort the job.
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/82fb5bfa/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 195fc80..a18c665 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -851,13 +851,19 @@ private[spark] class TaskSetManager(
         }
         ef.exception
 
+      case tk: TaskKilled =>
+        // TaskKilled might have accumulator updates
+        accumUpdates = tk.accums
+        logWarning(failureReason)
+        None
+
       case e: ExecutorLostFailure if !e.exitCausedByApp =>
         logInfo(s"Task $tid failed because while it was being computed, its 
executor " +
           "exited for a reason unrelated to the task. Not counting this 
failure towards the " +
           "maximum number of failures for the task.")
         None
 
-      case e: TaskFailedReason =>  // TaskResultLost, TaskKilled, and others
+      case e: TaskFailedReason =>  // TaskResultLost and others
         logWarning(failureReason)
         None
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/82fb5bfa/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 40383fe..50c6461 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -407,7 +407,9 @@ private[spark] object JsonProtocol {
         ("Exit Caused By App" -> exitCausedByApp) ~
         ("Loss Reason" -> reason.map(_.toString))
       case taskKilled: TaskKilled =>
-        ("Kill Reason" -> taskKilled.reason)
+        val accumUpdates = JArray(taskKilled.accumUpdates.map(accumulableInfoToJson).toList)
+        ("Kill Reason" -> taskKilled.reason) ~
+        ("Accumulator Updates" -> accumUpdates)
       case _ => emptyJson
     }
     ("Reason" -> reason) ~ json
@@ -917,7 +919,10 @@ private[spark] object JsonProtocol {
       case `taskKilled` =>
         val killReason = jsonOption(json \ "Kill Reason")
           .map(_.extract[String]).getOrElse("unknown reason")
-        TaskKilled(killReason)
+        val accumUpdates = jsonOption(json \ "Accumulator Updates")
+          .map(_.extract[List[JValue]].map(accumulableInfoFromJson))
+          .getOrElse(Seq[AccumulableInfo]())
+        TaskKilled(killReason, accumUpdates)
       case `taskCommitDenied` =>
         // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON
         // de/serialization logic was not added until 1.5.1. To provide backward compatibility

http://git-wip-us.apache.org/repos/asf/spark/blob/82fb5bfa/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 8b6ec37..2987170 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -1852,7 +1852,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     assertDataStructuresEmpty()
   }
 
-  test("accumulators are updated on exception failures") {
+  test("accumulators are updated on exception failures and task killed") {
     val acc1 = AccumulatorSuite.createLongAccum("ingenieur")
     val acc2 = AccumulatorSuite.createLongAccum("boulanger")
     val acc3 = AccumulatorSuite.createLongAccum("agriculteur")
@@ -1868,15 +1868,24 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     val accUpdate3 = new LongAccumulator
     accUpdate3.metadata = acc3.metadata
     accUpdate3.setValue(18)
-    val accumUpdates = Seq(accUpdate1, accUpdate2, accUpdate3)
-    val accumInfo = accumUpdates.map(AccumulatorSuite.makeInfo)
+
+    val accumUpdates1 = Seq(accUpdate1, accUpdate2)
+    val accumInfo1 = accumUpdates1.map(AccumulatorSuite.makeInfo)
     val exceptionFailure = new ExceptionFailure(
       new SparkException("fondue?"),
-      accumInfo).copy(accums = accumUpdates)
+      accumInfo1).copy(accums = accumUpdates1)
     submit(new MyRDD(sc, 1, Nil), Array(0))
     runEvent(makeCompletionEvent(taskSets.head.tasks.head, exceptionFailure, "result"))
+
     assert(AccumulatorContext.get(acc1.id).get.value === 15L)
     assert(AccumulatorContext.get(acc2.id).get.value === 13L)
+
+    val accumUpdates2 = Seq(accUpdate3)
+    val accumInfo2 = accumUpdates2.map(AccumulatorSuite.makeInfo)
+
+    val taskKilled = new TaskKilled( "test", accumInfo2, accums = accumUpdates2)
+    runEvent(makeCompletionEvent(taskSets.head.tasks.head, taskKilled, "result"))
+
     assert(AccumulatorContext.get(acc3.id).get.value === 18L)
   }
 
@@ -2497,6 +2506,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     val accumUpdates = reason match {
       case Success => task.metrics.accumulators()
       case ef: ExceptionFailure => ef.accums
+      case tk: TaskKilled => tk.accums
       case _ => Seq.empty
     }
     CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, taskInfo)

http://git-wip-us.apache.org/repos/asf/spark/blob/82fb5bfa/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 6bae4d1..4f6d5ff 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -36,6 +36,11 @@ object MimaExcludes {
 
   // Exclude rules for 2.4.x
   lazy val v24excludes = v23excludes ++ Seq(
+    // [SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.apply"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.copy"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.this"),
+
     // [SPARK-22941][core] Do not exit JVM when submit fails with in-process launcher.
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printWarning"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.parseSparkConfProperty"),

