Repository: spark Updated Branches: refs/heads/branch-2.2 eed3a5aa1 -> 7123ec8e1
[SPARK-20702][CORE] TaskContextImpl.markTaskCompleted should not hide the original error ## What changes were proposed in this pull request? This PR adds an `error` parameter to `TaskContextImpl.markTaskCompleted` to propagate the original error. It also fixes an issue where `TaskCompletionListenerException.getMessage` doesn't include `previousError`. ## How was this patch tested? New unit tests. Author: Shixiong Zhu <shixi...@databricks.com> Closes #17942 from zsxwing/SPARK-20702. (cherry picked from commit 7d6ff39106938fa4bbb68b3d5114b93a4d332c5c) Signed-off-by: Shixiong Zhu <shixi...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7123ec8e Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7123ec8e Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7123ec8e Branch: refs/heads/branch-2.2 Commit: 7123ec8e144746d9333d4fd0f59fccda2d59c508 Parents: eed3a5a Author: Shixiong Zhu <shixi...@databricks.com> Authored: Fri May 12 10:46:44 2017 -0700 Committer: Shixiong Zhu <shixi...@databricks.com> Committed: Fri May 12 10:46:54 2017 -0700 ---------------------------------------------------------------------- .../org/apache/spark/TaskContextImpl.scala | 4 +- .../scala/org/apache/spark/scheduler/Task.scala | 39 ++++++++++++-------- .../org/apache/spark/util/taskListeners.scala | 14 ++++--- .../spark/scheduler/TaskContextSuite.scala | 36 ++++++++++++++++-- .../storage/PartiallySerializedBlockSuite.scala | 2 +- .../ShuffleBlockFetcherIteratorSuite.scala | 2 +- 6 files changed, 68 insertions(+), 29 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/7123ec8e/core/src/main/scala/org/apache/spark/TaskContextImpl.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala 
b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 8cd1d1c..01d8973 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -110,10 +110,10 @@ private[spark] class TaskContextImpl( /** Marks the task as completed and triggers the completion listeners. */ @GuardedBy("this") - private[spark] def markTaskCompleted(): Unit = synchronized { + private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = synchronized { if (completed) return completed = true - invokeListeners(onCompleteCallbacks, "TaskCompletionListener", None) { + invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) { _.onTaskCompletion(this) } } http://git-wip-us.apache.org/repos/asf/spark/blob/7123ec8e/core/src/main/scala/org/apache/spark/scheduler/Task.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 5c337b9..7767ef1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -115,26 +115,33 @@ private[spark] abstract class Task[T]( case t: Throwable => e.addSuppressed(t) } + context.markTaskCompleted(Some(e)) throw e } finally { - // Call the task completion callbacks. - context.markTaskCompleted() try { - Utils.tryLogNonFatalError { - // Release memory used by this thread for unrolling blocks - SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP) - SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP) - // Notify any tasks waiting for execution memory to be freed to wake up and try to - // acquire memory again. This makes impossible the scenario where a task sleeps forever - // because there are no other tasks left to notify it. 
Since this is safe to do but may - // not be strictly necessary, we should revisit whether we can remove this in the future. - val memoryManager = SparkEnv.get.memoryManager - memoryManager.synchronized { memoryManager.notifyAll() } - } + // Call the task completion callbacks. If "markTaskCompleted" is called twice, the second + // one is no-op. + context.markTaskCompleted(None) } finally { - // Though we unset the ThreadLocal here, the context member variable itself is still queried - // directly in the TaskRunner to check for FetchFailedExceptions. - TaskContext.unset() + try { + Utils.tryLogNonFatalError { + // Release memory used by this thread for unrolling blocks + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP) + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask( + MemoryMode.OFF_HEAP) + // Notify any tasks waiting for execution memory to be freed to wake up and try to + // acquire memory again. This makes impossible the scenario where a task sleeps forever + // because there are no other tasks left to notify it. Since this is safe to do but may + // not be strictly necessary, we should revisit whether we can remove this in the + // future. + val memoryManager = SparkEnv.get.memoryManager + memoryManager.synchronized { memoryManager.notifyAll() } + } + } finally { + // Though we unset the ThreadLocal here, the context member variable itself is still + // queried directly in the TaskRunner to check for FetchFailedExceptions. 
+ TaskContext.unset() + } } } } http://git-wip-us.apache.org/repos/asf/spark/blob/7123ec8e/core/src/main/scala/org/apache/spark/util/taskListeners.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/taskListeners.scala b/core/src/main/scala/org/apache/spark/util/taskListeners.scala index 1be31e8..51feccf 100644 --- a/core/src/main/scala/org/apache/spark/util/taskListeners.scala +++ b/core/src/main/scala/org/apache/spark/util/taskListeners.scala @@ -55,14 +55,16 @@ class TaskCompletionListenerException( extends RuntimeException { override def getMessage: String = { - if (errorMessages.size == 1) { - errorMessages.head - } else { - errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") - } + - previousError.map { e => + val listenerErrorMessage = + if (errorMessages.size == 1) { + errorMessages.head + } else { + errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") + } + val previousErrorMessage = previousError.map { e => "\n\nPrevious exception in task: " + e.getMessage + "\n" + e.getStackTrace.mkString("\t", "\n\t", "") }.getOrElse("") + listenerErrorMessage + previousErrorMessage } } http://git-wip-us.apache.org/repos/asf/spark/blob/7123ec8e/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index b22da56..992d339 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -100,7 +100,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark context.addTaskCompletionListener(_ => throw new Exception("blah")) intercept[TaskCompletionListenerException] 
{ - context.markTaskCompleted() + context.markTaskCompleted(None) } verify(listener, times(1)).onTaskCompletion(any()) @@ -231,10 +231,10 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark test("immediately call a completion listener if the context is completed") { var invocations = 0 val context = TaskContext.empty() - context.markTaskCompleted() + context.markTaskCompleted(None) context.addTaskCompletionListener(_ => invocations += 1) assert(invocations == 1) - context.markTaskCompleted() + context.markTaskCompleted(None) assert(invocations == 1) } @@ -254,6 +254,36 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(lastError == error) assert(invocations == 1) } + + test("TaskCompletionListenerException.getMessage should include previousError") { + val listenerErrorMessage = "exception in listener" + val taskErrorMessage = "exception in task" + val e = new TaskCompletionListenerException( + Seq(listenerErrorMessage), + Some(new RuntimeException(taskErrorMessage))) + assert(e.getMessage.contains(listenerErrorMessage) && e.getMessage.contains(taskErrorMessage)) + } + + test("all TaskCompletionListeners should be called even if some fail or a task") { + val context = TaskContext.empty() + val listener = mock(classOf[TaskCompletionListener]) + context.addTaskCompletionListener(_ => throw new Exception("exception in listener1")) + context.addTaskCompletionListener(listener) + context.addTaskCompletionListener(_ => throw new Exception("exception in listener3")) + + val e = intercept[TaskCompletionListenerException] { + context.markTaskCompleted(Some(new Exception("exception in task"))) + } + + // Make sure listener 2 was called. 
+ verify(listener, times(1)).onTaskCompletion(any()) + + // also need to check failure in TaskCompletionListener does not mask earlier exception + assert(e.getMessage.contains("exception in listener1")) + assert(e.getMessage.contains("exception in listener3")) + assert(e.getMessage.contains("exception in task")) + } + } private object TaskContextSuite { http://git-wip-us.apache.org/repos/asf/spark/blob/7123ec8e/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala index 3050f9a..5351053 100644 --- a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala @@ -145,7 +145,7 @@ class PartiallySerializedBlockSuite try { TaskContext.setTaskContext(TaskContext.empty()) val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) - TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted() + TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted(None) Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer).dispose() Mockito.verifyNoMoreInteractions(memoryStore) } finally { http://git-wip-us.apache.org/repos/asf/spark/blob/7123ec8e/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index e56e440..9900d1e 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -192,7 +192,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Complete the task; then the 2nd block buffer should be exhausted verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release() - taskContext.markTaskCompleted() + taskContext.markTaskCompleted(None) verify(blocks(ShuffleBlockId(0, 1, 0)), times(1)).release() // The 3rd block should not be retained because the iterator is already in zombie state --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org