Repository: spark
Updated Branches:
  refs/heads/master b526f70c1 -> 7d6ff3910


[SPARK-20702][CORE] TaskContextImpl.markTaskCompleted should not hide the original error

## What changes were proposed in this pull request?

This PR adds an `error` parameter to `TaskContextImpl.markTaskCompleted` to propagate the original error.

It also fixes an issue where `TaskCompletionListenerException.getMessage` does not include `previousError`.
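
As a sketch of the fixed behavior (this mirrors the new unit test below; `TaskCompletionListenerException` is `private[spark]`, so this hypothetical harness assumes it is compiled inside the `org.apache.spark` package):

```scala
package org.apache.spark

import org.apache.spark.util.TaskCompletionListenerException

// Hypothetical harness, not part of this patch: demonstrates that
// getMessage now reports the original task error alongside the listener
// error instead of dropping it.
object PreviousErrorSketch {
  def main(args: Array[String]): Unit = {
    val e = new TaskCompletionListenerException(
      Seq("exception in listener"),
      Some(new RuntimeException("exception in task")))
    // Prints the listener message, then a "Previous exception in task:"
    // section with the original error message and its stack trace.
    println(e.getMessage)
  }
}
```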

## How was this patch tested?

New unit tests.

Author: Shixiong Zhu <shixi...@databricks.com>

Closes #17942 from zsxwing/SPARK-20702.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7d6ff391
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7d6ff391
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7d6ff391

Branch: refs/heads/master
Commit: 7d6ff39106938fa4bbb68b3d5114b93a4d332c5c
Parents: b526f70
Author: Shixiong Zhu <shixi...@databricks.com>
Authored: Fri May 12 10:46:44 2017 -0700
Committer: Shixiong Zhu <shixi...@databricks.com>
Committed: Fri May 12 10:46:44 2017 -0700

----------------------------------------------------------------------
 .../org/apache/spark/TaskContextImpl.scala      |  4 +-
 .../scala/org/apache/spark/scheduler/Task.scala | 39 ++++++++++++--------
 .../org/apache/spark/util/taskListeners.scala   | 14 ++++---
 .../spark/scheduler/TaskContextSuite.scala      | 36 ++++++++++++++++--
 .../storage/PartiallySerializedBlockSuite.scala |  2 +-
 .../ShuffleBlockFetcherIteratorSuite.scala      |  2 +-
 6 files changed, 68 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7d6ff391/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 8cd1d1c..01d8973 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -110,10 +110,10 @@ private[spark] class TaskContextImpl(
 
   /** Marks the task as completed and triggers the completion listeners. */
   @GuardedBy("this")
-  private[spark] def markTaskCompleted(): Unit = synchronized {
+  private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = synchronized {
     if (completed) return
     completed = true
-    invokeListeners(onCompleteCallbacks, "TaskCompletionListener", None) {
+    invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) {
       _.onTaskCompletion(this)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/7d6ff391/core/src/main/scala/org/apache/spark/scheduler/Task.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 5c337b9..7767ef1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -115,26 +115,33 @@ private[spark] abstract class Task[T](
           case t: Throwable =>
             e.addSuppressed(t)
         }
+        context.markTaskCompleted(Some(e))
         throw e
     } finally {
-      // Call the task completion callbacks.
-      context.markTaskCompleted()
       try {
-        Utils.tryLogNonFatalError {
-          // Release memory used by this thread for unrolling blocks
-          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
-          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP)
-          // Notify any tasks waiting for execution memory to be freed to wake up and try to
-          // acquire memory again. This makes impossible the scenario where a task sleeps forever
-          // because there are no other tasks left to notify it. Since this is safe to do but may
-          // not be strictly necessary, we should revisit whether we can remove this in the future.
-          val memoryManager = SparkEnv.get.memoryManager
-          memoryManager.synchronized { memoryManager.notifyAll() }
-        }
+        // Call the task completion callbacks. If "markTaskCompleted" is called twice, the second
+        // one is no-op.
+        context.markTaskCompleted(None)
       } finally {
-        // Though we unset the ThreadLocal here, the context member variable itself is still queried
-        // directly in the TaskRunner to check for FetchFailedExceptions.
-        TaskContext.unset()
+        try {
+          Utils.tryLogNonFatalError {
+            // Release memory used by this thread for unrolling blocks
+            SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
+            SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
+              MemoryMode.OFF_HEAP)
+            // Notify any tasks waiting for execution memory to be freed to wake up and try to
+            // acquire memory again. This makes impossible the scenario where a task sleeps forever
+            // because there are no other tasks left to notify it. Since this is safe to do but may
+            // not be strictly necessary, we should revisit whether we can remove this in the
+            // future.
+            val memoryManager = SparkEnv.get.memoryManager
+            memoryManager.synchronized { memoryManager.notifyAll() }
+          }
+        } finally {
+          // Though we unset the ThreadLocal here, the context member variable itself is still
+          // queried directly in the TaskRunner to check for FetchFailedExceptions.
+          TaskContext.unset()
+        }
       }
     }
   }
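
Net effect of the restructured `finally` blocks above: the completion callbacks now run before the unroll-memory cleanup and `TaskContext.unset()`, and each later cleanup step still runs even if an earlier one throws, so a failing listener can no longer skip memory release or leak the thread-local context. A self-contained sketch of that ordering (hypothetical stand-in functions, not the Spark code itself):

```scala
// Hypothetical stand-ins for the cleanup steps in Task.run.
object CleanupOrderingSketch {
  def markTaskCompleted(): Unit = throw new RuntimeException("listener failed")
  def releaseUnrollMemory(): Unit = println("unroll memory released")
  def unsetTaskContext(): Unit = println("TaskContext unset")

  def main(args: Array[String]): Unit = {
    try {
      try {
        markTaskCompleted() // throws, yet both nested finally blocks still run
      } finally {
        try {
          releaseUnrollMemory()
        } finally {
          unsetTaskContext()
        }
      }
    } catch {
      // The listener failure is propagated only after cleanup finishes.
      case e: RuntimeException => println(s"propagated: ${e.getMessage}")
    }
  }
}
```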

http://git-wip-us.apache.org/repos/asf/spark/blob/7d6ff391/core/src/main/scala/org/apache/spark/util/taskListeners.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/taskListeners.scala b/core/src/main/scala/org/apache/spark/util/taskListeners.scala
index 1be31e8..51feccf 100644
--- a/core/src/main/scala/org/apache/spark/util/taskListeners.scala
+++ b/core/src/main/scala/org/apache/spark/util/taskListeners.scala
@@ -55,14 +55,16 @@ class TaskCompletionListenerException(
   extends RuntimeException {
 
   override def getMessage: String = {
-    if (errorMessages.size == 1) {
-      errorMessages.head
-    } else {
-      errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
-    } +
-    previousError.map { e =>
+    val listenerErrorMessage =
+      if (errorMessages.size == 1) {
+        errorMessages.head
+      } else {
+        errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
+      }
+    val previousErrorMessage = previousError.map { e =>
       "\n\nPrevious exception in task: " + e.getMessage + "\n" +
         e.getStackTrace.mkString("\t", "\n\t", "")
     }.getOrElse("")
+    listenerErrorMessage + previousErrorMessage
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7d6ff391/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index b22da56..992d339 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -100,7 +100,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     context.addTaskCompletionListener(_ => throw new Exception("blah"))
 
     intercept[TaskCompletionListenerException] {
-      context.markTaskCompleted()
+      context.markTaskCompleted(None)
     }
 
     verify(listener, times(1)).onTaskCompletion(any())
@@ -231,10 +231,10 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
   test("immediately call a completion listener if the context is completed") {
     var invocations = 0
     val context = TaskContext.empty()
-    context.markTaskCompleted()
+    context.markTaskCompleted(None)
     context.addTaskCompletionListener(_ => invocations += 1)
     assert(invocations == 1)
-    context.markTaskCompleted()
+    context.markTaskCompleted(None)
     assert(invocations == 1)
   }
 
@@ -254,6 +254,36 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     assert(lastError == error)
     assert(invocations == 1)
   }
+
+  test("TaskCompletionListenerException.getMessage should include 
previousError") {
+    val listenerErrorMessage = "exception in listener"
+    val taskErrorMessage = "exception in task"
+    val e = new TaskCompletionListenerException(
+      Seq(listenerErrorMessage),
+      Some(new RuntimeException(taskErrorMessage)))
+    assert(e.getMessage.contains(listenerErrorMessage) && e.getMessage.contains(taskErrorMessage))
+  }
+
+  test("all TaskCompletionListeners should be called even if some fail or a 
task") {
+    val context = TaskContext.empty()
+    val listener = mock(classOf[TaskCompletionListener])
+    context.addTaskCompletionListener(_ => throw new Exception("exception in listener1"))
+    context.addTaskCompletionListener(listener)
+    context.addTaskCompletionListener(_ => throw new Exception("exception in listener3"))
+
+    val e = intercept[TaskCompletionListenerException] {
+      context.markTaskCompleted(Some(new Exception("exception in task")))
+    }
+
+    // Make sure listener 2 was called.
+    verify(listener, times(1)).onTaskCompletion(any())
+
+    // also need to check failure in TaskCompletionListener does not mask earlier exception
+    assert(e.getMessage.contains("exception in listener1"))
+    assert(e.getMessage.contains("exception in listener3"))
+    assert(e.getMessage.contains("exception in task"))
+  }
+
 }
 
 private object TaskContextSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/7d6ff391/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
index 3050f9a..5351053 100644
--- a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
@@ -145,7 +145,7 @@ class PartiallySerializedBlockSuite
     try {
       TaskContext.setTaskContext(TaskContext.empty())
       val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
-      TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted()
+      TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted(None)
       Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer).dispose()
       Mockito.verifyNoMoreInteractions(memoryStore)
     } finally {

http://git-wip-us.apache.org/repos/asf/spark/blob/7d6ff391/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index e56e440..9900d1e 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -192,7 +192,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
 
     // Complete the task; then the 2nd block buffer should be exhausted
     verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release()
-    taskContext.markTaskCompleted()
+    taskContext.markTaskCompleted(None)
     verify(blocks(ShuffleBlockId(0, 1, 0)), times(1)).release()
 
     // The 3rd block should not be retained because the iterator is already in zombie state

