Github user kayousterhout commented on a diff in the pull request: https://github.com/apache/spark/pull/16639#discussion_r100413863 --- Diff: core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala --- @@ -133,6 +123,153 @@ class ExecutorSuite extends SparkFunSuite { } } } + + test("SPARK-19276: Handle Fetch Failed for all intervening user code") { + val conf = new SparkConf().setMaster("local").setAppName("executor suite test") + sc = new SparkContext(conf) + + val serializer = SparkEnv.get.closureSerializer.newInstance() + val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size + val inputRDD = new FakeShuffleRDD(sc) + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD) + val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) + val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() + val task = new ResultTask( + stageId = 1, + stageAttemptId = 0, + taskBinary = taskBinary, + partition = secondRDD.partitions(0), + locs = Seq(), + outputId = 0, + localProperties = new Properties(), + serializedTaskMetrics = serializedTaskMetrics + ) + + val serTask = serializer.serialize(task) + val taskDescription = fakeTaskDescription(serTask) + + + val failReason = runTaskAndGetFailReason(taskDescription) + assert(failReason.isInstanceOf[FetchFailed]) + } + + test("Gracefully handle error in task deserialization") { + val conf = new SparkConf + val serializer = new JavaSerializer(conf) + val env = mockEnv(conf, serializer) + val serializedTask = serializer.newInstance().serialize(new NonDeserializableTask) + val taskDescription = fakeTaskDescription(serializedTask) + + val failReason = runTaskAndGetFailReason(taskDescription) + failReason match { + case ef: ExceptionFailure => + assert(ef.exception.isDefined) + assert(ef.exception.get.getMessage() === "failure in deserialization") + case _ => + fail("unexpected failure type: $failReason") + } + } + + private def mockEnv(conf: SparkConf, 
serializer: JavaSerializer): SparkEnv = { + val mockEnv = mock[SparkEnv] + val mockRpcEnv = mock[RpcEnv] + val mockMetricsSystem = mock[MetricsSystem] + val mockMemoryManager = mock[MemoryManager] + when(mockEnv.conf).thenReturn(conf) + when(mockEnv.serializer).thenReturn(serializer) + when(mockEnv.rpcEnv).thenReturn(mockRpcEnv) + when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem) + when(mockEnv.memoryManager).thenReturn(mockMemoryManager) + when(mockEnv.closureSerializer).thenReturn(serializer) + SparkEnv.set(mockEnv) + mockEnv + } + + private def fakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = { + new TaskDescription( + taskId = 0, + attemptNumber = 0, + executorId = "", + name = "", + index = 0, + addedFiles = Map[String, Long](), + addedJars = Map[String, Long](), + properties = new Properties, + serializedTask) + } + + private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = { + val mockBackend = mock[ExecutorBackend] + var executor: Executor = null + try { + executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true) + // the task will be launched in a dedicated worker thread + executor.launchTask(mockBackend, taskDescription) + val startTime = System.currentTimeMillis() + val maxTime = startTime + 5000 + while (executor.numRunningTasks > 0 && System.currentTimeMillis() < maxTime) { + Thread.sleep(10) + } + assert(executor.numRunningTasks === 0) + } finally { + if (executor != null) { + executor.stop() + } + } + val orderedMock = inOrder(mockBackend) + val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) + orderedMock.verify(mockBackend) + .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture()) + orderedMock.verify(mockBackend) + .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture()) + // first statusUpdate for RUNNING has empty data + assert(statusCaptor.getAllValues().get(0).remaining() === 0) + // second update is more 
interesting + val failureData = statusCaptor.getAllValues.get(1) + SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData) + } +} + +class FakeShuffleRDD(sc: SparkContext) extends RDD[Int](sc, Nil) { --- End diff -- How about renaming this to FetchFailureThrowingShuffleRDD, to make its purpose obvious?
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and would like it, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org