GitHub user kayousterhout commented on a diff in the pull request:

    https://github.com/apache/spark/pull/16639#discussion_r100413863
  
    --- Diff: core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala ---
    @@ -133,6 +123,153 @@ class ExecutorSuite extends SparkFunSuite {
           }
         }
       }
    +
    +  test("SPARK-19276: Handle Fetch Failed for all intervening user code") {
    +    sc = new SparkContext(
    +      new SparkConf().setMaster("local").setAppName("executor suite test"))
    +
    +    // NOTE(review): FakeShuffleRDD / FetchFailureHidingRDD are defined after this excerpt;
    +    // judging by their names, the first raises a fetch failure and the second hides it from
    +    // user code -- confirm against their definitions.
    +    val closureSer = SparkEnv.get.closureSerializer.newInstance()
    +    val countRows = (ctx: TaskContext, rows: Iterator[Int]) => rows.size
    +    val shuffleRDD = new FakeShuffleRDD(sc)
    +    val hidingRDD = new FetchFailureHidingRDD(sc, shuffleRDD)
    +    val taskBinary = sc.broadcast(closureSer.serialize((hidingRDD, countRows)).array())
    +    val metricsBytes = closureSer.serialize(TaskMetrics.registered).array()
    +
    +    // A single result task over the first partition of the wrapping RDD.
    +    val task = new ResultTask(
    +      stageId = 1,
    +      stageAttemptId = 0,
    +      taskBinary = taskBinary,
    +      partition = hidingRDD.partitions(0),
    +      locs = Seq.empty,
    +      outputId = 0,
    +      localProperties = new Properties(),
    +      serializedTaskMetrics = metricsBytes
    +    )
    +
    +    val reason = runTaskAndGetFailReason(fakeTaskDescription(closureSer.serialize(task)))
    +    assert(reason.isInstanceOf[FetchFailed])
    +  }
    +
    +  test("Gracefully handle error in task deserialization") {
    +    val conf = new SparkConf
    +    val serializer = new JavaSerializer(conf)
    +    // Called only for its side effect: mockEnv installs the mocked env via SparkEnv.set, and
    +    // runTaskAndGetFailReason builds its Executor from SparkEnv.get.
    +    mockEnv(conf, serializer)
    +    val serializedTask = serializer.newInstance().serialize(new NonDeserializableTask)
    +    val taskDescription = fakeTaskDescription(serializedTask)
    +
    +    val failReason = runTaskAndGetFailReason(taskDescription)
    +    failReason match {
    +      case ef: ExceptionFailure =>
    +        assert(ef.exception.map(_.getMessage) === Some("failure in deserialization"))
    +      case _ =>
    +        // The s-interpolator is required; without it the literal text "$failReason" is printed.
    +        fail(s"unexpected failure type: $failReason")
    +    }
    +  }
    +
    +  /**
    +   * Builds a Mockito mock of [[SparkEnv]] that returns `conf`, uses `serializer` as both the
    +   * data and closure serializer, and returns fresh mocks for the RPC env, metrics system and
    +   * memory manager. The mock is installed globally with `SparkEnv.set` and then returned.
    +   */
    +  private def mockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
    +    val env = mock[SparkEnv]
    +    val rpcEnv = mock[RpcEnv]
    +    val metrics = mock[MetricsSystem]
    +    val memoryManager = mock[MemoryManager]
    +
    +    // configuration and serializers
    +    when(env.conf).thenReturn(conf)
    +    when(env.serializer).thenReturn(serializer)
    +    when(env.closureSerializer).thenReturn(serializer)
    +    // collaborating subsystems
    +    when(env.rpcEnv).thenReturn(rpcEnv)
    +    when(env.metricsSystem).thenReturn(metrics)
    +    when(env.memoryManager).thenReturn(memoryManager)
    +
    +    SparkEnv.set(env)
    +    env
    +  }
    +
    +  /**
    +   * Wraps an already-serialized task in a minimal [[TaskDescription]]: task id 0, attempt 0,
    +   * blank executor id and name, no added files or jars, and empty properties.
    +   */
    +  private def fakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = {
    +    new TaskDescription(
    +      taskId = 0,
    +      attemptNumber = 0,
    +      executorId = "",
    +      name = "",
    +      index = 0,
    +      addedFiles = Map.empty[String, Long],
    +      addedJars = Map.empty[String, Long],
    +      properties = new Properties(),
    +      serializedTask)
    +  }
    +
    +  /**
    +   * Runs `taskDescription` on a fresh local [[Executor]] built from `SparkEnv.get`, waits up to
    +   * `timeoutMs` milliseconds for it to finish, and returns the [[TaskFailedReason]] deserialized
    +   * from the FAILED status update that the executor sent to a mocked [[ExecutorBackend]].
    +   *
    +   * Fails the test if the task is still running after the timeout, or if the backend did not
    +   * receive a RUNNING update followed by a FAILED update for task id 0 (the id that
    +   * fakeTaskDescription uses).
    +   *
    +   * @param timeoutMs how long to wait for the task to finish; defaults to 5 seconds
    +   */
    +  private def runTaskAndGetFailReason(
    +      taskDescription: TaskDescription,
    +      timeoutMs: Long = 5000): TaskFailedReason = {
    +    val mockBackend = mock[ExecutorBackend]
    +    val executor =
    +      new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true)
    +    try {
    +      // the task will be launched in a dedicated worker thread
    +      executor.launchTask(mockBackend, taskDescription)
    +      // Use a monotonic clock so wall-clock adjustments cannot shorten or stretch the wait.
    +      val deadline =
    +        System.nanoTime() + java.util.concurrent.TimeUnit.MILLISECONDS.toNanos(timeoutMs)
    +      while (executor.numRunningTasks > 0 && System.nanoTime() < deadline) {
    +        Thread.sleep(10)
    +      }
    +      assert(executor.numRunningTasks === 0)
    +    } finally {
    +      executor.stop()
    +    }
    +    val orderedMock = inOrder(mockBackend)
    +    val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
    +    orderedMock.verify(mockBackend)
    +      .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
    +    orderedMock.verify(mockBackend)
    +      .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
    +    val updates = statusCaptor.getAllValues
    +    // first statusUpdate for RUNNING has empty data
    +    assert(updates.get(0).remaining() === 0)
    +    // the FAILED update carries the serialized failure reason
    +    SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](updates.get(1))
    +  }
    +}
    +
    +class FakeShuffleRDD(sc: SparkContext) extends RDD[Int](sc, Nil) {
    --- End diff --
    
    How about FetchFailureThrowingShuffleRDD? (to make it obvious what the point of this is)


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and you would like it, or if the feature is enabled but not working,
please contact infrastructure at infrastruct...@apache.org or file a JIRA
ticket with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to