Github user squito commented on a diff in the pull request: https://github.com/apache/spark/pull/20244#discussion_r165763274 --- Diff: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala --- @@ -2399,6 +2424,115 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } } + /** + * In this test, we simply simulate the scene in concurrent jobs using the same + * rdd which is marked to do checkpoint: + * Job one has already finished the spark job, and start the process of doCheckpoint; + * Job two is submitted, and submitMissingTasks is called. + * In submitMissingTasks, if taskSerialization is called before doCheckpoint is done, + * while part calculates from stage.rdd.partitions is called after doCheckpoint is done, + * we may get a ClassCastException when execute the task because of some rdd will do + * Partition cast. + * + * With this test case, just want to indicate that we should do taskSerialization and + * part calculate in submitMissingTasks with the same rdd checkpoint status. + */ + test("SPARK-23053: avoid ClassCastException in concurrent execution with checkpoint") { + // set checkpointDir. + val tempDir = Utils.createTempDir() + val checkpointDir = File.createTempFile("temp", "", tempDir) + checkpointDir.delete() + sc.setCheckpointDir(checkpointDir.toString) + + // Semaphores to control the process sequence for the two threads below. + val semaphore1 = new Semaphore(0) + val semaphore2 = new Semaphore(0) + + val rdd = new WrappedRDD(sc.makeRDD(1 to 100, 4)) + rdd.checkpoint() + + val checkpointRunnable = new Runnable { + override def run() = { + // Simply simulate what RDD.doCheckpoint() do here. 
+ rdd.doCheckpointCalled = true + val checkpointData = rdd.checkpointData.get + RDDCheckpointData.synchronized { + if (checkpointData.cpState == CheckpointState.Initialized) { + checkpointData.cpState = CheckpointState.CheckpointingInProgress + } + } + + val newRDD = checkpointData.doCheckpoint() + + // Release semaphore1 after job triggered in checkpoint finished, so that taskBinary + // serialization can start. + semaphore1.release() + // Wait until taskBinary serialization finished in submitMissingTasksThread. + semaphore2.acquire() + + // Update our state and truncate the RDD lineage. + RDDCheckpointData.synchronized { + checkpointData.cpRDD = Some(newRDD) + checkpointData.cpState = CheckpointState.Checkpointed + rdd.markCheckpointed() + } + semaphore1.release() --- End diff -- And then this final `semaphore1.release()` would instead release a separate semaphore, `checkpointStateUpdated`.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org