This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 1b221f3  [SPARK-31472][CORE][3.0] Make sure Barrier Task always return 
messages or exception with abortableRpcFuture check
1b221f3 is described below

commit 1b221f35abd1657a3ecd49335118bfd5dcb811ee
Author: yi.wu <yi...@databricks.com>
AuthorDate: Thu Apr 23 14:43:27 2020 +0000

    [SPARK-31472][CORE][3.0] Make sure Barrier Task always return messages or 
exception with abortableRpcFuture check
    
    ### What changes were proposed in this pull request?
    
    Rewrite the periodic check logic of `abortableRpcFuture` to make sure that a barrier task always returns either the desired messages or the expected exception.
    
    This PR also simplifies `AbortableRpcFuture` a bit.
    
    ### Why are the changes needed?
    
    Currently, the periodic check of `abortableRpcFuture` is done as follows:
    
    ```scala
    ...
    var messages: Array[String] = null
    
    while (!abortableRpcFuture.toFuture.isCompleted) {
       messages = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
       ...
    }
    return messages
    ```
    It's possible that `abortableRpcFuture` completes before the next invocation of `messages = ...`. In that case the loop exits without ever reading the future's result, so the task may return null messages, or finish successfully when it should throw an exception (e.g. a `SparkException` from `BarrierCoordinator`).
    
    And here's a flaky test failure caused by this bug:
    
    ```
    [info] BarrierTaskContextSuite:
    [info] - share messages with allGather() call *** FAILED *** (18 seconds, 
705 milliseconds)
    [info]   org.apache.spark.SparkException: Job aborted due to stage failure: 
Could not recover from a failed barrier ResultStage. Most recent failure 
reason: Stage failed because barrier task ResultTask(0, 2) finished 
unsuccessfully.
    [info] java.lang.NullPointerException
    [info]      at 
scala.collection.mutable.ArrayOps$ofRef$.length$extension(ArrayOps.scala:204)
    [info]      at 
scala.collection.mutable.ArrayOps$ofRef.length(ArrayOps.scala:204)
    [info]      at 
scala.collection.IndexedSeqOptimized.toList(IndexedSeqOptimized.scala:285)
    [info]      at 
scala.collection.IndexedSeqOptimized.toList$(IndexedSeqOptimized.scala:284)
    [info]      at 
scala.collection.mutable.ArrayOps$ofRef.toList(ArrayOps.scala:198)
    [info]      at 
org.apache.spark.scheduler.BarrierTaskContextSuite.$anonfun$new$4(BarrierTaskContextSuite.scala:68)
    ...
    ```
    
    The test failure can be reproduced by changing the line `messages = ...` to the following:
    
    ```scala
    messages = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 10.micros)
    Thread.sleep(5000)
    ```
    
    ### Does this PR introduce any user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Tested manually and updated some unit tests.
    
    Closes #28312 from Ngone51/cherry-pick-31472.
    
    Authored-by: yi.wu <yi...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../org/apache/spark/BarrierTaskContext.scala      | 30 ++++++++++------------
 .../org/apache/spark/rpc/RpcEndpointRef.scala      | 10 +++-----
 .../org/apache/spark/rpc/netty/NettyRpcEnv.scala   | 12 +++++----
 .../scala/org/apache/spark/util/ThreadUtils.scala  |  5 ++--
 .../scala/org/apache/spark/rpc/RpcEnvSuite.scala   | 12 ++++-----
 5 files changed, 32 insertions(+), 37 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala 
b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
index 06f8024..4d76548 100644
--- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -20,9 +20,9 @@ package org.apache.spark
 import java.util.{Properties, Timer, TimerTask}
 
 import scala.collection.JavaConverters._
-import scala.concurrent.TimeoutException
 import scala.concurrent.duration._
 import scala.language.postfixOps
+import scala.util.{Failure, Success => ScalaSuccess, Try}
 
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.executor.TaskMetrics
@@ -85,28 +85,26 @@ class BarrierTaskContext private[spark] (
         // BarrierCoordinator on timeout, instead of RPCTimeoutException from 
the RPC framework.
         timeout = new RpcTimeout(365.days, "barrierTimeout"))
 
-      // messages which consist of all barrier tasks' messages
-      var messages: Array[String] = null
       // Wait the RPC future to be completed, but every 1 second it will jump 
out waiting
       // and check whether current spark task is killed. If killed, then throw
       // a `TaskKilledException`, otherwise continue wait RPC until it 
completes.
-      try {
-        while (!abortableRpcFuture.toFuture.isCompleted) {
+
+      while (!abortableRpcFuture.future.isCompleted) {
+        try {
           // wait RPC future for at most 1 second
-          try {
-            messages = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 
1.second)
-          } catch {
-            case _: TimeoutException | _: InterruptedException =>
-              // If `TimeoutException` thrown, waiting RPC future reach 1 
second.
-              // If `InterruptedException` thrown, it is possible this task is 
killed.
-              // So in this two cases, we should check whether task is killed 
and then
-              // throw `TaskKilledException`
-              taskContext.killTaskIfInterrupted()
+          Thread.sleep(1000)
+        } catch {
+          case _: InterruptedException => // task is killed by driver
+        } finally {
+          Try(taskContext.killTaskIfInterrupted()) match {
+            case ScalaSuccess(_) => // task is still running healthily
+            case Failure(e) => abortableRpcFuture.abort(e)
           }
         }
-      } finally {
-        
abortableRpcFuture.abort(taskContext.getKillReason().getOrElse("Unknown 
reason."))
       }
+      // messages which consist of all barrier tasks' messages. The future 
will return the
+      // desired messages if it is completed successfully. Otherwise, 
exception could be thrown.
+      val messages = abortableRpcFuture.future.value.get.get
 
       barrierEpoch += 1
       logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt 
$stageAttemptNumber) finished " +
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala 
b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
index 56f3d37..a3d27b0 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
@@ -114,11 +114,7 @@ private[spark] class RpcAbortException(message: String) 
extends Exception(messag
  * A wrapper for [[Future]] but add abort method.
  * This is used in long run RPC and provide an approach to abort the RPC.
  */
-private[spark] class AbortableRpcFuture[T: ClassTag](
-    future: Future[T],
-    onAbort: String => Unit) {
-
-  def abort(reason: String): Unit = onAbort(reason)
-
-  def toFuture: Future[T] = future
+private[spark]
+class AbortableRpcFuture[T: ClassTag](val future: Future[T], onAbort: 
Throwable => Unit) {
+  def abort(t: Throwable): Unit = onAbort(t)
 }
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala 
b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 265e158..9259ec7 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -208,6 +208,7 @@ private[netty] class NettyRpcEnv(
       message: RequestMessage, timeout: RpcTimeout): AbortableRpcFuture[T] = {
     val promise = Promise[Any]()
     val remoteAddr = message.receiver.address
+    var rpcMsg: Option[RpcOutboxMessage] = None
 
     def onFailure(e: Throwable): Unit = {
       if (!promise.tryFailure(e)) {
@@ -226,8 +227,9 @@ private[netty] class NettyRpcEnv(
         }
     }
 
-    def onAbort(reason: String): Unit = {
-      onFailure(new RpcAbortException(reason))
+    def onAbort(t: Throwable): Unit = {
+      onFailure(t)
+      rpcMsg.foreach(_.onAbort())
     }
 
     try {
@@ -242,10 +244,10 @@ private[netty] class NettyRpcEnv(
         val rpcMessage = RpcOutboxMessage(message.serialize(this),
           onFailure,
           (client, response) => onSuccess(deserialize[Any](client, response)))
+        rpcMsg = Option(rpcMessage)
         postToOutbox(message.receiver, rpcMessage)
         promise.future.failed.foreach {
           case _: TimeoutException => rpcMessage.onTimeout()
-          case _: RpcAbortException => rpcMessage.onAbort()
           case _ =>
         }(ThreadUtils.sameThread)
       }
@@ -270,7 +272,7 @@ private[netty] class NettyRpcEnv(
   }
 
   private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: 
RpcTimeout): Future[T] = {
-    askAbortable(message, timeout).toFuture
+    askAbortable(message, timeout).future
   }
 
   private[netty] def serialize(content: Any): ByteBuffer = {
@@ -547,7 +549,7 @@ private[netty] class NettyRpcEndpointRef(
   }
 
   override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] 
= {
-    askAbortable(message, timeout).toFuture
+    askAbortable(message, timeout).future
   }
 
   override def send(message: Any): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala 
b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index e7872bb..78206c5 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -29,7 +29,6 @@ import scala.util.control.NonFatal
 import com.google.common.util.concurrent.ThreadFactoryBuilder
 
 import org.apache.spark.SparkException
-import org.apache.spark.rpc.RpcAbortException
 
 private[spark] object ThreadUtils {
 
@@ -299,7 +298,7 @@ private[spark] object ThreadUtils {
       // TimeoutException and RpcAbortException is thrown in the current 
thread, so not need to warp
       // the exception.
       case NonFatal(t)
-          if !t.isInstanceOf[TimeoutException] && 
!t.isInstanceOf[RpcAbortException] =>
+          if !t.isInstanceOf[TimeoutException] =>
         throw new SparkException("Exception thrown in awaitResult: ", t)
     }
   }
@@ -316,7 +315,7 @@ private[spark] object ThreadUtils {
       case e: SparkFatalException =>
         throw e.throwable
       case NonFatal(t)
-        if !t.isInstanceOf[TimeoutException] && 
!t.isInstanceOf[RpcAbortException] =>
+        if !t.isInstanceOf[TimeoutException] =>
         throw new SparkException("Exception thrown in awaitResult: ", t)
     }
   }
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala 
b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index c10f2c2..01c67b3 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -209,7 +209,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with 
BeforeAndAfterAll {
     // Use anotherEnv to find out the RpcEndpointRef
     val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-abort")
     try {
-      val e = intercept[RpcAbortException] {
+      val e = intercept[SparkException] {
         val timeout = new RpcTimeout(10.seconds, shortProp)
         val abortableRpcFuture = rpcEndpointRef.askAbortable[String](
           "hello", timeout)
@@ -217,15 +217,15 @@ abstract class RpcEnvSuite extends SparkFunSuite with 
BeforeAndAfterAll {
         new Thread {
           override def run: Unit = {
             Thread.sleep(100)
-            abortableRpcFuture.abort("TestAbort")
+            abortableRpcFuture.abort(new RuntimeException("TestAbort"))
           }
         }.start()
 
-        timeout.awaitResult(abortableRpcFuture.toFuture)
+        timeout.awaitResult(abortableRpcFuture.future)
       }
-      // The SparkException cause should be a RpcAbortException with 
"TestAbort" message
-      assert(e.isInstanceOf[RpcAbortException])
-      assert(e.getMessage.contains("TestAbort"))
+      // The SparkException cause should be a RuntimeException with 
"TestAbort" message
+      assert(e.getCause.isInstanceOf[RuntimeException])
+      assert(e.getCause.getMessage.contains("TestAbort"))
     } finally {
       anotherEnv.shutdown()
       anotherEnv.awaitTermination()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to