Repository: spark
Updated Branches:
  refs/heads/master be88383e1 -> a4851ed05


[SPARK-15963][CORE] Catch `TaskKilledException` correctly in Executor.TaskRunner

## The problem

Before this change, if either of the following happened to a task, the
task would be marked as `FAILED` instead of `KILLED`:
- the task was killed before it was deserialized
- `executor.kill()` had marked `taskRunner.killed`, but the worker thread threw
the `TaskKilledException` before `task.kill()` was called, so `task.killed` was
still `false`

The reason is that in the `catch` block of the current
[Executor.TaskRunner](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/executor/Executor.scala#L362)
implementation, we mistakenly catch:
```scala
case _: TaskKilledException | _: InterruptedException if task.killed => ...
```
the semantics of which is:
- **(**`TaskKilledException` **OR** `InterruptedException`**)** **AND** 
`task.killed`

So when a `TaskKilledException` is thrown but `task.killed` has not yet been
set, the guard does not match, the exception falls through to the generic
failure handler, and the task is reported as `FAILED` (when it should really
be `KILLED`).
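
To make the pitfall concrete, here is a minimal, self-contained Scala sketch (not Spark code; `IllegalStateException` and the `killed` flag are stand-ins for `TaskKilledException` and `task.killed`) showing that a guard on an alternative pattern applies to the whole `case`:
```scala
// A guard on `case A | B if cond` means (A OR B) AND cond.
def classify(e: Throwable, killed: Boolean): String = e match {
  // stand-in for: case _: TaskKilledException | _: InterruptedException if task.killed
  case _: IllegalStateException | _: InterruptedException if killed => "KILLED"
  case _ => "FAILED"
}

// If the kill-style exception arrives before the flag is set, it falls
// through to the catch-all case and is reported as FAILED.
classify(new IllegalStateException("killed"), killed = false)  // "FAILED"
classify(new InterruptedException, killed = true)              // "KILLED"
```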

## What changes were proposed in this pull request?

This patch alters the catch condition's semantics from:
- **(**`TaskKilledException` **OR** `InterruptedException`**)** **AND** 
`task.killed`

to

- `TaskKilledException` **OR** **(**`InterruptedException` **AND** 
`task.killed`**)**

so that a `TaskKilledException` is always caught and the task is correctly
marked as `KILLED`.
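
Under the same stand-in names, a sketch of the fixed semantics (the actual change appears in the diff below): the kill exception alone is enough to report `KILLED`, while an `InterruptedException` still requires the kill flag:
```scala
// TaskKilledException OR (InterruptedException AND task.killed),
// expressed as two separate cases.
def classifyFixed(e: Throwable, killed: Boolean): String = e match {
  case _: IllegalStateException => "KILLED"              // stand-in for TaskKilledException
  case _: InterruptedException if killed => "KILLED"
  case _ => "FAILED"
}

classifyFixed(new IllegalStateException("killed"), killed = false)  // now "KILLED"
```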

## How was this patch tested?

Added a unit test that failed before the change; ran the new test 1000 times manually.

Author: Liwei Lin <lwl...@gmail.com>

Closes #13685 from lw-lin/fix-task-killed.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a4851ed0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a4851ed0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a4851ed0

Branch: refs/heads/master
Commit: a4851ed05053a9b7545a258c9159fd529225c455
Parents: be88383
Author: Liwei Lin <lwl...@gmail.com>
Authored: Fri Jun 24 10:09:04 2016 -0500
Committer: Imran Rashid <iras...@cloudera.com>
Committed: Fri Jun 24 10:09:04 2016 -0500

----------------------------------------------------------------------
 .../org/apache/spark/executor/Executor.scala    |   7 +-
 .../apache/spark/executor/ExecutorSuite.scala   | 139 +++++++++++++++++++
 2 files changed, 145 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a4851ed0/core/src/main/scala/org/apache/spark/executor/Executor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 9a017f2..fbf2b86 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -359,11 +359,16 @@ private[spark] class Executor(
           setTaskFinishedAndClearInterruptStatus()
           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
 
-        case _: TaskKilledException | _: InterruptedException if task.killed =>
+        case _: TaskKilledException =>
           logInfo(s"Executor killed $taskName (TID $taskId)")
           setTaskFinishedAndClearInterruptStatus()
           execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
 
+        case _: InterruptedException if task.killed =>
+          logInfo(s"Executor interrupted and killed $taskName (TID $taskId)")
+          setTaskFinishedAndClearInterruptStatus()
+          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+
         case CausedBy(cDE: CommitDeniedException) =>
           val reason = cDE.toTaskEndReason
           setTaskFinishedAndClearInterruptStatus()

http://git-wip-us.apache.org/repos/asf/spark/blob/a4851ed0/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
new file mode 100644
index 0000000..3e69894
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import java.nio.ByteBuffer
+import java.util.concurrent.CountDownLatch
+
+import scala.collection.mutable.HashMap
+
+import org.mockito.Matchers._
+import org.mockito.Mockito.{mock, when}
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+
+import org.apache.spark._
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.memory.MemoryManager
+import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.rpc.RpcEnv
+import org.apache.spark.scheduler.{FakeTask, Task}
+import org.apache.spark.serializer.JavaSerializer
+
+class ExecutorSuite extends SparkFunSuite {
+
+  test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") {
+    // mock some objects to make Executor.launchTask() happy
+    val conf = new SparkConf
+    val serializer = new JavaSerializer(conf)
+    val mockEnv = mock(classOf[SparkEnv])
+    val mockRpcEnv = mock(classOf[RpcEnv])
+    val mockMetricsSystem = mock(classOf[MetricsSystem])
+    val mockMemoryManager = mock(classOf[MemoryManager])
+    when(mockEnv.conf).thenReturn(conf)
+    when(mockEnv.serializer).thenReturn(serializer)
+    when(mockEnv.rpcEnv).thenReturn(mockRpcEnv)
+    when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem)
+    when(mockEnv.memoryManager).thenReturn(mockMemoryManager)
+    when(mockEnv.closureSerializer).thenReturn(serializer)
+    val serializedTask =
+      Task.serializeWithDependencies(
+        new FakeTask(0),
+        HashMap[String, Long](),
+        HashMap[String, Long](),
+        serializer.newInstance())
+
+    // we use latches to force the program to run in this order:
+    // +-----------------------------+---------------------------------------+
+    // |      main test thread       |      worker thread                    |
+    // +-----------------------------+---------------------------------------+
+    // |    executor.launchTask()    |                                       |
+    // |                             | TaskRunner.run() begins               |
+    // |                             |          ...                          |
+    // |                             | execBackend.statusUpdate  // 1st time |
+    // | executor.killAllTasks(true) |                                       |
+    // |                             |          ...                          |
+    // |                             |  task = ser.deserialize               |
+    // |                             |          ...                          |
+    // |                             | execBackend.statusUpdate  // 2nd time |
+    // |                             |          ...                          |
+    // |                             |   TaskRunner.run() ends               |
+    // |       check results         |                                       |
+    // +-----------------------------+---------------------------------------+
+
+    val executorSuiteHelper = new ExecutorSuiteHelper
+
+    val mockExecutorBackend = mock(classOf[ExecutorBackend])
+    when(mockExecutorBackend.statusUpdate(any(), any(), any()))
+      .thenAnswer(new Answer[Unit] {
+        var firstTime = true
+        override def answer(invocationOnMock: InvocationOnMock): Unit = {
+          if (firstTime) {
+            executorSuiteHelper.latch1.countDown()
+            // here between latch1 and latch2, executor.killAllTasks() is called
+            executorSuiteHelper.latch2.await()
+            firstTime = false
+          }
+          else {
+            // save the returned `taskState` and `testFailedReason` into `executorSuiteHelper`
+            val taskState = invocationOnMock.getArguments()(1).asInstanceOf[TaskState]
+            executorSuiteHelper.taskState = taskState
+            val taskEndReason = invocationOnMock.getArguments()(2).asInstanceOf[ByteBuffer]
+            executorSuiteHelper.testFailedReason
+              = serializer.newInstance().deserialize(taskEndReason)
+            // let the main test thread check `taskState` and `testFailedReason`
+            executorSuiteHelper.latch3.countDown()
+          }
+        }
+      })
+
+    var executor: Executor = null
+    try {
+      executor = new Executor("id", "localhost", mockEnv, userClassPath = Nil, isLocal = true)
+      // the task will be launched in a dedicated worker thread
+      executor.launchTask(mockExecutorBackend, 0, 0, "", serializedTask)
+
+      executorSuiteHelper.latch1.await()
+      // we know the task will be started, but not yet deserialized, because of the latches we
+      // use in mockExecutorBackend.
+      executor.killAllTasks(true)
+      executorSuiteHelper.latch2.countDown()
+      executorSuiteHelper.latch3.await()
+
+      // `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED`
+      assert(executorSuiteHelper.testFailedReason === TaskKilled)
+      assert(executorSuiteHelper.taskState === TaskState.KILLED)
+    }
+    finally {
+      if (executor != null) {
+        executor.stop()
+      }
+    }
+  }
+}
+
+// Helps to test("SPARK-15963")
+private class ExecutorSuiteHelper {
+
+  val latch1 = new CountDownLatch(1)
+  val latch2 = new CountDownLatch(1)
+  val latch3 = new CountDownLatch(1)
+
+  @volatile var taskState: TaskState = _
+  @volatile var testFailedReason: TaskFailedReason = _
+}

