This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 1b221f3  [SPARK-31472][CORE][3.0] Make sure Barrier Task always return messages or exception with abortableRpcFuture check
1b221f3 is described below

commit 1b221f35abd1657a3ecd49335118bfd5dcb811ee
Author: yi.wu <yi...@databricks.com>
AuthorDate: Thu Apr 23 14:43:27 2020 +0000

[SPARK-31472][CORE][3.0] Make sure Barrier Task always return messages or exception with abortableRpcFuture check

### What changes were proposed in this pull request?

Rewrite the periodic check logic of `abortableRpcFuture` to make sure that the barrier task always returns either the desired messages or the expected exception. This PR also simplifies `AbortableRpcFuture` a bit.

### Why are the changes needed?

Currently, the periodic check logic of `abortableRpcFuture` works as follows:

```scala
...
var messages: Array[String] = null

while (!abortableRpcFuture.toFuture.isCompleted) {
  messages = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
  ...
}
return messages
```

It's possible for `abortableRpcFuture` to complete (e.g. with a failure from the coordinator) after one `awaitResult` call gives up but before the loop re-checks `isCompleted` and invokes `messages = ...` again. In that case, the task may return null messages, or finish successfully when it should instead throw an exception (e.g. a `SparkException` from `BarrierCoordinator`).

And here's a flaky test failure caused by this bug:

```
[info] BarrierTaskContextSuite:
[info] - share messages with allGather() call *** FAILED *** (18 seconds, 705 milliseconds)
[info]   org.apache.spark.SparkException: Job aborted due to stage failure: Could not recover from a failed barrier ResultStage. Most recent failure reason: Stage failed because barrier task ResultTask(0, 2) finished unsuccessfully.
[info]   java.lang.NullPointerException
[info]   at scala.collection.mutable.ArrayOps$ofRef$.length$extension(ArrayOps.scala:204)
[info]   at scala.collection.mutable.ArrayOps$ofRef.length(ArrayOps.scala:204)
[info]   at scala.collection.IndexedSeqOptimized.toList(IndexedSeqOptimized.scala:285)
[info]   at scala.collection.IndexedSeqOptimized.toList$(IndexedSeqOptimized.scala:284)
[info]   at scala.collection.mutable.ArrayOps$ofRef.toList(ArrayOps.scala:198)
[info]   at org.apache.spark.scheduler.BarrierTaskContextSuite.$anonfun$new$4(BarrierTaskContextSuite.scala:68)
...
```

The test exception can be reproduced by changing the line `messages = ...` to the following:

```scala
messages = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 10.micros)
Thread.sleep(5000)
```
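The same race can also be reproduced entirely outside Spark. Below is a minimal, self-contained sketch using plain `scala.concurrent` primitives (the object name `PollingRaceDemo`, the timings, and the error message are invented for illustration); the `Thread.sleep` in the catch block plays the same role as the `Thread.sleep(5000)` above:

```scala
import scala.concurrent.{Await, Promise, TimeoutException}
import scala.concurrent.duration._

object PollingRaceDemo {
  def main(args: Array[String]): Unit = {
    val promise = Promise[Array[String]]()

    // Plays the role of the BarrierCoordinator failing the RPC from another thread.
    new Thread(() => {
      Thread.sleep(50)
      promise.tryFailure(new RuntimeException("barrier stage failed"))
    }).start()

    var messages: Array[String] = null
    while (!promise.future.isCompleted) {
      try {
        // If the failure landed *during* this await, it would propagate.
        messages = Await.result(promise.future, 10.millis)
      } catch {
        case _: TimeoutException =>
          // The failure lands during this window; the loop guard then sees
          // isCompleted == true and exits without ever surfacing it.
          Thread.sleep(100)
      }
    }
    // Prints "messages = null" instead of propagating the RuntimeException.
    println(s"messages = $messages")
  }
}
```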
### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Manually tested and updated some unit tests.

Closes #28312 from Ngone51/cherry-pick-31472.

Authored-by: yi.wu <yi...@databricks.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../org/apache/spark/BarrierTaskContext.scala      | 30 ++++++++++------------
 .../org/apache/spark/rpc/RpcEndpointRef.scala      | 10 +++-----
 .../org/apache/spark/rpc/netty/NettyRpcEnv.scala   | 12 +++++----
 .../scala/org/apache/spark/util/ThreadUtils.scala  |  5 ++--
 .../scala/org/apache/spark/rpc/RpcEnvSuite.scala   | 12 ++++-----
 5 files changed, 32 insertions(+), 37 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
index 06f8024..4d76548 100644
--- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -20,9 +20,9 @@ package org.apache.spark
 import java.util.{Properties, Timer, TimerTask}
 
 import scala.collection.JavaConverters._
-import scala.concurrent.TimeoutException
 import scala.concurrent.duration._
 import scala.language.postfixOps
+import scala.util.{Failure, Success => ScalaSuccess, Try}
 
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.executor.TaskMetrics
@@ -85,28 +85,26 @@ class BarrierTaskContext private[spark] (
       // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
       timeout = new RpcTimeout(365.days, "barrierTimeout"))
 
-    // messages which consist of all barrier tasks' messages
-    var messages: Array[String] = null
     // Wait the RPC future to be completed, but every 1 second it will jump out waiting
     // and check whether current spark task is killed. If killed, then throw
     // a `TaskKilledException`, otherwise continue wait RPC until it completes.
-    try {
-      while (!abortableRpcFuture.toFuture.isCompleted) {
+
+    while (!abortableRpcFuture.future.isCompleted) {
+      try {
         // wait RPC future for at most 1 second
-        try {
-          messages = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
-        } catch {
-          case _: TimeoutException | _: InterruptedException =>
-            // If `TimeoutException` thrown, waiting RPC future reach 1 second.
-            // If `InterruptedException` thrown, it is possible this task is killed.
-            // So in this two cases, we should check whether task is killed and then
-            // throw `TaskKilledException`
-            taskContext.killTaskIfInterrupted()
+        Thread.sleep(1000)
+      } catch {
+        case _: InterruptedException => // task is killed by driver
+      } finally {
+        Try(taskContext.killTaskIfInterrupted()) match {
+          case ScalaSuccess(_) => // task is still running healthily
+          case Failure(e) => abortableRpcFuture.abort(e)
         }
       }
-    } finally {
-      abortableRpcFuture.abort(taskContext.getKillReason().getOrElse("Unknown reason."))
     }
+    // messages which consist of all barrier tasks' messages. The future will return the
+    // desired messages if it is completed successfully. Otherwise, exception could be thrown.
+    val messages = abortableRpcFuture.future.value.get.get
 
     barrierEpoch += 1
     logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " +
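A note on the new `val messages = abortableRpcFuture.future.value.get.get` line: once the loop above has observed `isCompleted`, `Future.value` is guaranteed to be a `Some[Try[T]]`, so the double `get` either returns the gathered messages or rethrows the stored failure. That is the "messages or exception" guarantee in the title. A short standalone illustration of the `Option[Try[T]]` behavior (plain `scala.concurrent`; the values are hypothetical):

```scala
import scala.concurrent.Promise

object FutureValueDemo {
  def main(args: Array[String]): Unit = {
    // A completed Future's `value` is Some(Success(...)), so `.get.get` yields the result.
    val ok = Promise[Array[String]]().success(Array("msg-from-task-0")).future
    println(ok.value.get.get.toList) // List(msg-from-task-0)

    // A failed Future's `value` is Some(Failure(e)), and Try.get rethrows e,
    // so a caller can never observe a null result.
    val failed = Promise[Array[String]]().failure(new RuntimeException("aborted")).future
    println(failed.value.get.get) // throws RuntimeException: aborted
  }
}
```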
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
index 56f3d37..a3d27b0 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
@@ -114,11 +114,7 @@ private[spark] class RpcAbortException(message: String) extends Exception(messag
  * A wrapper for [[Future]] but add abort method.
  * This is used in long run RPC and provide an approach to abort the RPC.
  */
-private[spark] class AbortableRpcFuture[T: ClassTag](
-    future: Future[T],
-    onAbort: String => Unit) {
-
-  def abort(reason: String): Unit = onAbort(reason)
-
-  def toFuture: Future[T] = future
+private[spark]
+class AbortableRpcFuture[T: ClassTag](val future: Future[T], onAbort: Throwable => Unit) {
+  def abort(t: Throwable): Unit = onAbort(t)
 }
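To make the new contract concrete: `abort(t)` is expected to fail the underlying future with the caller-supplied `Throwable`, which is why the `RpcEnvSuite` change below asserts on `e.getCause` instead of `RpcAbortException`. Here is a minimal stand-in, not the actual Spark wiring (that lives in `NettyRpcEnv.askAbortable` below); `AbortableFutureSketch` and `AbortDemo` are invented names:

```scala
import scala.concurrent.{Future, Promise}
import scala.util.Failure

// Minimal stand-in for the reworked AbortableRpcFuture (no ClassTag, no RPC):
// the future plus an abort hook that fails it with the caller-supplied Throwable.
class AbortableFutureSketch[T](val future: Future[T], onAbort: Throwable => Unit) {
  def abort(t: Throwable): Unit = onAbort(t)
}

object AbortDemo {
  def main(args: Array[String]): Unit = {
    val promise = Promise[String]()
    val abortable = new AbortableFutureSketch[String](promise.future, t => promise.tryFailure(t))

    abortable.abort(new RuntimeException("TestAbort"))

    // The future now holds the exact exception passed to abort().
    abortable.future.value match {
      case Some(Failure(t)) => println(s"aborted with: ${t.getMessage}") // aborted with: TestAbort
      case other => println(s"unexpected: $other")
    }
  }
}
```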
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 265e158..9259ec7 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -208,6 +208,7 @@ private[netty] class NettyRpcEnv(
       message: RequestMessage, timeout: RpcTimeout): AbortableRpcFuture[T] = {
     val promise = Promise[Any]()
     val remoteAddr = message.receiver.address
+    var rpcMsg: Option[RpcOutboxMessage] = None
 
     def onFailure(e: Throwable): Unit = {
       if (!promise.tryFailure(e)) {
@@ -226,8 +227,9 @@ private[netty] class NettyRpcEnv(
       }
     }
 
-    def onAbort(reason: String): Unit = {
-      onFailure(new RpcAbortException(reason))
+    def onAbort(t: Throwable): Unit = {
+      onFailure(t)
+      rpcMsg.foreach(_.onAbort())
     }
 
     try {
@@ -242,10 +244,10 @@ private[netty] class NettyRpcEnv(
         val rpcMessage = RpcOutboxMessage(message.serialize(this),
           onFailure,
           (client, response) => onSuccess(deserialize[Any](client, response)))
+        rpcMsg = Option(rpcMessage)
         postToOutbox(message.receiver, rpcMessage)
         promise.future.failed.foreach {
           case _: TimeoutException => rpcMessage.onTimeout()
-          case _: RpcAbortException => rpcMessage.onAbort()
           case _ =>
         }(ThreadUtils.sameThread)
       }
@@ -270,7 +272,7 @@ private[netty] class NettyRpcEnv(
   }
 
   private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = {
-    askAbortable(message, timeout).toFuture
+    askAbortable(message, timeout).future
   }
 
   private[netty] def serialize(content: Any): ByteBuffer = {
@@ -547,7 +549,7 @@ private[netty] class NettyRpcEndpointRef(
   }
 
   override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
-    askAbortable(message, timeout).toFuture
+    askAbortable(message, timeout).future
   }
 
   override def send(message: Any): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index e7872bb..78206c5 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -29,7 +29,6 @@ import scala.util.control.NonFatal
 import com.google.common.util.concurrent.ThreadFactoryBuilder
 
 import org.apache.spark.SparkException
-import org.apache.spark.rpc.RpcAbortException
 
 private[spark] object ThreadUtils {
 
@@ -299,7 +298,7 @@ private[spark] object ThreadUtils {
       // TimeoutException and RpcAbortException is thrown in the current thread, so not need to warp
       // the exception.
       case NonFatal(t)
-          if !t.isInstanceOf[TimeoutException] && !t.isInstanceOf[RpcAbortException] =>
+          if !t.isInstanceOf[TimeoutException] =>
         throw new SparkException("Exception thrown in awaitResult: ", t)
     }
   }
@@ -316,7 +315,7 @@ private[spark] object ThreadUtils {
       case e: SparkFatalException =>
         throw e.throwable
       case NonFatal(t)
-          if !t.isInstanceOf[TimeoutException] && !t.isInstanceOf[RpcAbortException] =>
+          if !t.isInstanceOf[TimeoutException] =>
         throw new SparkException("Exception thrown in awaitResult: ", t)
     }
   }
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index c10f2c2..01c67b3 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -209,7 +209,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     // Use anotherEnv to find out the RpcEndpointRef
     val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-abort")
     try {
-      val e = intercept[RpcAbortException] {
+      val e = intercept[SparkException] {
         val timeout = new RpcTimeout(10.seconds, shortProp)
         val abortableRpcFuture = rpcEndpointRef.askAbortable[String](
           "hello", timeout)
@@ -217,15 +217,15 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
         new Thread {
           override def run: Unit = {
             Thread.sleep(100)
-            abortableRpcFuture.abort("TestAbort")
+            abortableRpcFuture.abort(new RuntimeException("TestAbort"))
           }
         }.start()
 
-        timeout.awaitResult(abortableRpcFuture.toFuture)
+        timeout.awaitResult(abortableRpcFuture.future)
       }
-      // The SparkException cause should be a RpcAbortException with "TestAbort" message
-      assert(e.isInstanceOf[RpcAbortException])
-      assert(e.getMessage.contains("TestAbort"))
+      // The SparkException cause should be a RuntimeException with "TestAbort" message
+      assert(e.getCause.isInstanceOf[RuntimeException])
+      assert(e.getCause.getMessage.contains("TestAbort"))
     } finally {
       anotherEnv.shutdown()
       anotherEnv.awaitTermination()

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org