Repository: spark Updated Branches: refs/heads/branch-2.2 55834a898 -> f971ce5dd
[SPARK-5484][GRAPHX] Periodically do checkpoint in Pregel ## What changes were proposed in this pull request? Pregel-based iterative algorithms with more than ~50 iterations begin to slow down and eventually fail with a StackOverflowError due to Spark's lack of support for long lineage chains. This PR causes Pregel to checkpoint the graph periodically if the checkpoint directory is set. This PR moves PeriodicGraphCheckpointer.scala from mllib to graphx, moves PeriodicRDDCheckpointer.scala, PeriodicCheckpointer.scala from mllib to core ## How was this patch tested? unit tests, manual tests (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Author: ding <ding@localhost.localdomain> Author: dding3 <ding.d...@intel.com> Author: Michael Allman <mich...@videoamp.com> Closes #15125 from dding3/cp2_pregel. (cherry picked from commit 0a7f5f2798b6e8b2ba15e8b3aa07d5953ad1c695) Signed-off-by: Felix Cheung <felixche...@apache.org> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f971ce5d Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f971ce5d Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f971ce5d Branch: refs/heads/branch-2.2 Commit: f971ce5dd0788fe7f5d2ca820b9ea3db72033ddc Parents: 55834a8 Author: ding <ding@localhost.localdomain> Authored: Tue Apr 25 11:20:32 2017 -0700 Committer: Felix Cheung <felixche...@apache.org> Committed: Tue Apr 25 11:20:52 2017 -0700 ---------------------------------------------------------------------- .../main/scala/org/apache/spark/rdd/RDD.scala | 4 +- .../rdd/util/PeriodicRDDCheckpointer.scala | 98 ++++++++++ .../spark/util/PeriodicCheckpointer.scala | 193 ++++++++++++++++++ .../org/apache/spark/rdd/SortingSuite.scala | 2 +- .../util/PeriodicRDDCheckpointerSuite.scala | 175 +++++++++++++++++ docs/configuration.md | 14 ++ 
docs/graphx-programming-guide.md | 9 +- .../scala/org/apache/spark/graphx/Pregel.scala | 25 ++- .../graphx/util/PeriodicGraphCheckpointer.scala | 105 ++++++++++ .../util/PeriodicGraphCheckpointerSuite.scala | 194 +++++++++++++++++++ .../org/apache/spark/ml/clustering/LDA.scala | 3 +- .../ml/tree/impl/GradientBoostedTrees.scala | 2 +- .../spark/mllib/clustering/LDAOptimizer.scala | 2 +- .../spark/mllib/impl/PeriodicCheckpointer.scala | 183 ----------------- .../mllib/impl/PeriodicGraphCheckpointer.scala | 102 ---------- .../mllib/impl/PeriodicRDDCheckpointer.scala | 97 ---------- .../impl/PeriodicGraphCheckpointerSuite.scala | 189 ------------------ .../impl/PeriodicRDDCheckpointerSuite.scala | 175 ----------------- 18 files changed, 812 insertions(+), 760 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/f971ce5d/core/src/main/scala/org/apache/spark/rdd/RDD.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e524675..63a87e7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -41,7 +41,7 @@ import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.{BoundedPriorityQueue, Utils} -import org.apache.spark.util.collection.OpenHashMap +import org.apache.spark.util.collection.{OpenHashMap, Utils => collectionUtils} import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, SamplingUtils} @@ -1420,7 +1420,7 @@ abstract class RDD[T: ClassTag]( val mapRDDs = mapPartitions { items => // Priority keeps the largest elements, so let's reverse the ordering. 
val queue = new BoundedPriorityQueue[T](num)(ord.reverse) - queue ++= util.collection.Utils.takeOrdered(items, num)(ord) + queue ++= collectionUtils.takeOrdered(items, num)(ord) Iterator.single(queue) } if (mapRDDs.partitions.length == 0) { http://git-wip-us.apache.org/repos/asf/spark/blob/f971ce5d/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala new file mode 100644 index 0000000..ab72add --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd.util + +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.PeriodicCheckpointer + + +/** + * This class helps with persisting and checkpointing RDDs. 
+ * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as + * unpersisting and removing checkpoint files. + * + * Users should call update() when a new RDD has been created, + * before the RDD has been materialized. After updating [[PeriodicRDDCheckpointer]], users are + * responsible for materializing the RDD to ensure that persisting and checkpointing actually + * occur. + * + * When update() is called, this does the following: + * - Persist new RDD (if not yet persisted), and put in queue of persisted RDDs. + * - Unpersist RDDs from queue until there are at most 3 persisted RDDs. + * - If using checkpointing and the checkpoint interval has been reached, + * - Checkpoint the new RDD, and put in a queue of checkpointed RDDs. + * - Remove older checkpoints. + * + * WARNINGS: + * - This class should NOT be copied (since copies may conflict on which RDDs should be + * checkpointed). + * - This class removes checkpoint files once later RDDs have been checkpointed. + * However, references to the older RDDs will still return isCheckpointed = true. + * + * Example usage: + * {{{ + * val (rdd1, rdd2, rdd3, ...) = ... + * val cp = new PeriodicRDDCheckpointer(2, sc) + * rdd1.count(); + * // persisted: rdd1 + * cp.update(rdd2) + * rdd2.count(); + * // persisted: rdd1, rdd2 + * // checkpointed: rdd2 + * cp.update(rdd3) + * rdd3.count(); + * // persisted: rdd1, rdd2, rdd3 + * // checkpointed: rdd2 + * cp.update(rdd4) + * rdd4.count(); + * // persisted: rdd2, rdd3, rdd4 + * // checkpointed: rdd4 + * cp.update(rdd5) + * rdd5.count(); + * // persisted: rdd3, rdd4, rdd5 + * // checkpointed: rdd4 + * }}} + * + * @param checkpointInterval RDDs will be checkpointed at this interval + * @tparam T RDD element type + * + * TODO: Move this out of MLlib? 
+ */ +private[spark] class PeriodicRDDCheckpointer[T]( + checkpointInterval: Int, + sc: SparkContext) + extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) { + + override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint() + + override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed + + override protected def persist(data: RDD[T]): Unit = { + if (data.getStorageLevel == StorageLevel.NONE) { + data.persist() + } + } + + override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false) + + override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = { + data.getCheckpointFile.map(x => x) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/f971ce5d/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala new file mode 100644 index 0000000..ce06e18 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import scala.collection.mutable + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging +import org.apache.spark.storage.StorageLevel + + +/** + * This abstraction helps with persisting and checkpointing RDDs and types derived from RDDs + * (such as Graphs and DataFrames). In documentation, we use the phrase "Dataset" to refer to + * the distributed data type (RDD, Graph, etc.). + * + * Specifically, this abstraction automatically handles persisting and (optionally) checkpointing, + * as well as unpersisting and removing checkpoint files. + * + * Users should call update() when a new Dataset has been created, + * before the Dataset has been materialized. After updating [[PeriodicCheckpointer]], users are + * responsible for materializing the Dataset to ensure that persisting and checkpointing actually + * occur. + * + * When update() is called, this does the following: + * - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets. + * - Unpersist Datasets from queue until there are at most 3 persisted Datasets. + * - If using checkpointing and the checkpoint interval has been reached, + * - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets. + * - Remove older checkpoints. + * + * WARNINGS: + * - This class should NOT be copied (since copies may conflict on which Datasets should be + * checkpointed). + * - This class removes checkpoint files once later Datasets have been checkpointed. + * However, references to the older Datasets will still return isCheckpointed = true. + * + * @param checkpointInterval Datasets will be checkpointed at this interval. + * If this interval was set as -1, then checkpointing will be disabled. 
+ * @param sc SparkContext for the Datasets given to this checkpointer + * @tparam T Dataset type, such as RDD[Double] + */ +private[spark] abstract class PeriodicCheckpointer[T]( + val checkpointInterval: Int, + val sc: SparkContext) extends Logging { + + /** FIFO queue of past checkpointed Datasets */ + private val checkpointQueue = mutable.Queue[T]() + + /** FIFO queue of past persisted Datasets */ + private val persistedQueue = mutable.Queue[T]() + + /** Number of times [[update()]] has been called */ + private var updateCount = 0 + + /** + * Update with a new Dataset. Handle persistence and checkpointing as needed. + * Since this handles persistence and checkpointing, this should be called before the Dataset + * has been materialized. + * + * @param newData New Dataset created from previous Datasets in the lineage. + */ + def update(newData: T): Unit = { + persist(newData) + persistedQueue.enqueue(newData) + // We try to maintain 2 Datasets in persistedQueue to support the semantics of this class: + // Users should call [[update()]] when a new Dataset has been created, + // before the Dataset has been materialized. + while (persistedQueue.size > 3) { + val dataToUnpersist = persistedQueue.dequeue() + unpersist(dataToUnpersist) + } + updateCount += 1 + + // Handle checkpointing (after persisting) + if (checkpointInterval != -1 && (updateCount % checkpointInterval) == 0 + && sc.getCheckpointDir.nonEmpty) { + // Add new checkpoint before removing old checkpoints. + checkpoint(newData) + checkpointQueue.enqueue(newData) + // Remove checkpoints before the latest one. + var canDelete = true + while (checkpointQueue.size > 1 && canDelete) { + // Delete the oldest checkpoint only if the next checkpoint exists. 
+ if (isCheckpointed(checkpointQueue.head)) { + removeCheckpointFile() + } else { + canDelete = false + } + } + } + } + + /** Checkpoint the Dataset */ + protected def checkpoint(data: T): Unit + + /** Return true iff the Dataset is checkpointed */ + protected def isCheckpointed(data: T): Boolean + + /** + * Persist the Dataset. + * Note: This should handle checking the current [[StorageLevel]] of the Dataset. + */ + protected def persist(data: T): Unit + + /** Unpersist the Dataset */ + protected def unpersist(data: T): Unit + + /** Get list of checkpoint files for this given Dataset */ + protected def getCheckpointFiles(data: T): Iterable[String] + + /** + * Call this to unpersist the Dataset. + */ + def unpersistDataSet(): Unit = { + while (persistedQueue.nonEmpty) { + val dataToUnpersist = persistedQueue.dequeue() + unpersist(dataToUnpersist) + } + } + + /** + * Call this at the end to delete any remaining checkpoint files. + */ + def deleteAllCheckpoints(): Unit = { + while (checkpointQueue.nonEmpty) { + removeCheckpointFile() + } + } + + /** + * Call this at the end to delete any remaining checkpoint files, except for the last checkpoint. + * Note that there may not be any checkpoints at all. + */ + def deleteAllCheckpointsButLast(): Unit = { + while (checkpointQueue.size > 1) { + removeCheckpointFile() + } + } + + /** + * Get all current checkpoint files. + * This is useful in combination with [[deleteAllCheckpointsButLast()]]. + */ + def getAllCheckpointFiles: Array[String] = { + checkpointQueue.flatMap(getCheckpointFiles).toArray + } + + /** + * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files. + * This prints a warning but does not fail if the files cannot be removed. + */ + private def removeCheckpointFile(): Unit = { + val old = checkpointQueue.dequeue() + // Since the old checkpoint is not deleted by Spark, we manually delete it. 
+ getCheckpointFiles(old).foreach( + PeriodicCheckpointer.removeCheckpointFile(_, sc.hadoopConfiguration)) + } +} + +private[spark] object PeriodicCheckpointer extends Logging { + + /** Delete a checkpoint file, and log a warning if deletion fails. */ + def removeCheckpointFile(checkpointFile: String, conf: Configuration): Unit = { + try { + val path = new Path(checkpointFile) + val fs = path.getFileSystem(conf) + fs.delete(path, true) + } catch { + case e: Exception => + logWarning("PeriodicCheckpointer could not remove old checkpoint file: " + + checkpointFile) + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/f971ce5d/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index f9a7f15..7f20206 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -135,7 +135,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers w } test("get a range of elements in an array not partitioned by a range partitioner") { - val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) + val pairArr = scala.util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) val pairs = sc.parallelize(pairArr, 10) val range = pairs.filterByRange(200, 800).collect() assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted) http://git-wip-us.apache.org/repos/asf/spark/blob/f971ce5d/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala new file mode 100644 index 0000000..f9e1b79 --- 
/dev/null +++ b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.utils + +import org.apache.hadoop.fs.Path + +import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} +import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.util.PeriodicRDDCheckpointer +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + + +class PeriodicRDDCheckpointerSuite extends SparkFunSuite with SharedSparkContext { + + import PeriodicRDDCheckpointerSuite._ + + test("Persisting") { + var rddsToCheck = Seq.empty[RDDToCheck] + + val rdd1 = createRDD(sc) + val checkpointer = new PeriodicRDDCheckpointer[Double](10, rdd1.sparkContext) + checkpointer.update(rdd1) + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1) + checkPersistence(rddsToCheck, 1) + + var iteration = 2 + while (iteration < 9) { + val rdd = createRDD(sc) + checkpointer.update(rdd) + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration) + checkPersistence(rddsToCheck, iteration) + iteration += 1 + } + } + + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val 
checkpointInterval = 2 + var rddsToCheck = Seq.empty[RDDToCheck] + sc.setCheckpointDir(path) + val rdd1 = createRDD(sc) + val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext) + checkpointer.update(rdd1) + rdd1.count() + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1) + checkCheckpoint(rddsToCheck, 1, checkpointInterval) + + var iteration = 2 + while (iteration < 9) { + val rdd = createRDD(sc) + checkpointer.update(rdd) + rdd.count() + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration) + checkCheckpoint(rddsToCheck, iteration, checkpointInterval) + iteration += 1 + } + + checkpointer.deleteAllCheckpoints() + rddsToCheck.foreach { rdd => + confirmCheckpointRemoved(rdd.rdd) + } + + Utils.deleteRecursively(tempDir) + } +} + +private object PeriodicRDDCheckpointerSuite { + + case class RDDToCheck(rdd: RDD[Double], gIndex: Int) + + def createRDD(sc: SparkContext): RDD[Double] = { + sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0)) + } + + def checkPersistence(rdds: Seq[RDDToCheck], iteration: Int): Unit = { + rdds.foreach { g => + checkPersistence(g.rdd, g.gIndex, iteration) + } + } + + /** + * Check storage level of rdd. + * @param gIndex Index of rdd in order inserted into checkpointer (from 1). + * @param iteration Total number of rdds inserted into checkpointer. 
+ */ + def checkPersistence(rdd: RDD[_], gIndex: Int, iteration: Int): Unit = { + try { + if (gIndex + 2 < iteration) { + assert(rdd.getStorageLevel == StorageLevel.NONE) + } else { + assert(rdd.getStorageLevel != StorageLevel.NONE) + } + } catch { + case _: AssertionError => + throw new Exception(s"PeriodicRDDCheckpointerSuite.checkPersistence failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t rdd.getStorageLevel = ${rdd.getStorageLevel}\n") + } + } + + def checkCheckpoint(rdds: Seq[RDDToCheck], iteration: Int, checkpointInterval: Int): Unit = { + rdds.reverse.foreach { g => + checkCheckpoint(g.rdd, g.gIndex, iteration, checkpointInterval) + } + } + + def confirmCheckpointRemoved(rdd: RDD[_]): Unit = { + // Note: We cannot check rdd.isCheckpointed since that value is never updated. + // Instead, we check for the presence of the checkpoint files. + // This test should continue to work even after this rdd.isCheckpointed issue + // is fixed (though it can then be simplified and not look for the files). + val hadoopConf = rdd.sparkContext.hadoopConfiguration + rdd.getCheckpointFile.foreach { checkpointFile => + val path = new Path(checkpointFile) + val fs = path.getFileSystem(hadoopConf) + assert(!fs.exists(path), "RDD checkpoint file should have been removed") + } + } + + /** + * Check checkpointed status of rdd. + * @param gIndex Index of rdd in order inserted into checkpointer (from 1). + * @param iteration Total number of rdds inserted into checkpointer. + */ + def checkCheckpoint( + rdd: RDD[_], + gIndex: Int, + iteration: Int, + checkpointInterval: Int): Unit = { + try { + if (gIndex % checkpointInterval == 0) { + // We allow 2 checkpoint intervals since we perform an action (checkpointing a second rdd) + // only AFTER PeriodicRDDCheckpointer decides whether to remove the previous checkpoint. 
+ if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) { + assert(rdd.isCheckpointed, "RDD should be checkpointed") + assert(rdd.getCheckpointFile.nonEmpty, "RDD should have 2 checkpoint files") + } else { + confirmCheckpointRemoved(rdd) + } + } else { + // RDD should never be checkpointed + assert(!rdd.isCheckpointed, "RDD should never have been checkpointed") + assert(rdd.getCheckpointFile.isEmpty, "RDD should not have any checkpoint files") + } + } catch { + case e: AssertionError => + throw new Exception(s"PeriodicRDDCheckpointerSuite.checkCheckpoint failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t checkpointInterval = $checkpointInterval\n" + + s"\t rdd.isCheckpointed = ${rdd.isCheckpointed}\n" + + s"\t rdd.getCheckpointFile = ${rdd.getCheckpointFile.mkString(", ")}\n" + + s" AssertionError message: ${e.getMessage}") + } + } + +} http://git-wip-us.apache.org/repos/asf/spark/blob/f971ce5d/docs/configuration.md ---------------------------------------------------------------------- diff --git a/docs/configuration.md b/docs/configuration.md index 6b65d2b..87b7632 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -2149,6 +2149,20 @@ showDF(properties, numRows = 200, truncate = FALSE) </table> +### GraphX + +<table class="table"> +<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr> +<tr> + <td><code>spark.graphx.pregel.checkpointInterval</code></td> + <td>-1</td> + <td> + Checkpoint interval for the graph and messages in Pregel. It is used to avoid StackOverflowError due to long lineage chains + after lots of iterations. Checkpointing is disabled by default. 
+ </td> +</tr> +</table> + ### Deploy <table class="table"> http://git-wip-us.apache.org/repos/asf/spark/blob/f971ce5d/docs/graphx-programming-guide.md ---------------------------------------------------------------------- diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index e271b28..76aa7b4 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -708,7 +708,9 @@ messages remaining. > messaging function. These constraints allow additional optimization within > GraphX. The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch* -of its implementation (note calls to graph.cache have been removed): +of its implementation (note: to avoid StackOverflowError due to long lineage chains, Pregel supports periodically +checkpointing the graph and messages when "spark.graphx.pregel.checkpointInterval" is set to a positive number, +say 10, and a checkpoint directory is also set using SparkContext.setCheckpointDir(directory: String)): {% highlight scala %} class GraphOps[VD, ED] { @@ -722,6 +724,7 @@ class GraphOps[VD, ED] { : Graph[VD, ED] = { // Receive the initial message at each vertex var g = mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) ).cache() + // compute the messages var messages = g.mapReduceTriplets(sendMsg, mergeMsg) var activeMessages = messages.count() @@ -734,8 +737,8 @@ class GraphOps[VD, ED] { // Send new messages, skipping edges where neither side received a message. We must cache // messages so it can be materialized on the next line, allowing us to uncache the previous // iteration. 
- messages = g.mapReduceTriplets( - sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() + messages = GraphXUtils.mapReduceTriplets( + g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() activeMessages = messages.count() i += 1 } http://git-wip-us.apache.org/repos/asf/spark/blob/f971ce5d/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala ---------------------------------------------------------------------- diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 646462b..755c6fe 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -19,7 +19,10 @@ package org.apache.spark.graphx import scala.reflect.ClassTag +import org.apache.spark.graphx.util.PeriodicGraphCheckpointer import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.util.PeriodicRDDCheckpointer /** * Implements a Pregel-like bulk-synchronous message-passing API. 
@@ -122,27 +125,39 @@ object Pregel extends Logging { require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," + s" but got ${maxIterations}") - var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache() + val checkpointInterval = graph.vertices.sparkContext.getConf + .getInt("spark.graphx.pregel.checkpointInterval", -1) + var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)) + val graphCheckpointer = new PeriodicGraphCheckpointer[VD, ED]( + checkpointInterval, graph.vertices.sparkContext) + graphCheckpointer.update(g) + // compute the messages var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg) + val messageCheckpointer = new PeriodicRDDCheckpointer[(VertexId, A)]( + checkpointInterval, graph.vertices.sparkContext) + messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]]) var activeMessages = messages.count() + // Loop var prevG: Graph[VD, ED] = null var i = 0 while (activeMessages > 0 && i < maxIterations) { // Receive the messages and update the vertices. prevG = g - g = g.joinVertices(messages)(vprog).cache() + g = g.joinVertices(messages)(vprog) + graphCheckpointer.update(g) val oldMessages = messages // Send new messages, skipping edges where neither side received a message. We must cache // messages so it can be materialized on the next line, allowing us to uncache the previous // iteration. messages = GraphXUtils.mapReduceTriplets( - g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() + g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))) // The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages // (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages // and the vertices of g). 
+ messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]]) activeMessages = messages.count() logInfo("Pregel finished iteration " + i) @@ -154,7 +169,9 @@ object Pregel extends Logging { // count the iteration i += 1 } - messages.unpersist(blocking = false) + messageCheckpointer.unpersistDataSet() + graphCheckpointer.deleteAllCheckpoints() + messageCheckpointer.deleteAllCheckpoints() g } // end of apply http://git-wip-us.apache.org/repos/asf/spark/blob/f971ce5d/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala ---------------------------------------------------------------------- diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala new file mode 100644 index 0000000..fda501a --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.graphx.util + +import org.apache.spark.SparkContext +import org.apache.spark.graphx.Graph +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.PeriodicCheckpointer + + +/** + * This class helps with persisting and checkpointing Graphs. + * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as + * unpersisting and removing checkpoint files. + * + * Users should call update() when a new graph has been created, + * before the graph has been materialized. After updating [[PeriodicGraphCheckpointer]], users are + * responsible for materializing the graph to ensure that persisting and checkpointing actually + * occur. + * + * When update() is called, this does the following: + * - Persist new graph (if not yet persisted), and put in queue of persisted graphs. + * - Unpersist graphs from queue until there are at most 3 persisted graphs. + * - If using checkpointing and the checkpoint interval has been reached, + * - Checkpoint the new graph, and put in a queue of checkpointed graphs. + * - Remove older checkpoints. + * + * WARNINGS: + * - This class should NOT be copied (since copies may conflict on which Graphs should be + * checkpointed). + * - This class removes checkpoint files once later graphs have been checkpointed. + * However, references to the older graphs will still return isCheckpointed = true. + * + * Example usage: + * {{{ + * val (graph1, graph2, graph3, ...) = ... 
+ * val cp = new PeriodicGraphCheckpointer(2, sc) + * graph1.vertices.count(); graph1.edges.count() + * // persisted: graph1 + * cp.update(graph2) + * graph2.vertices.count(); graph2.edges.count() + * // persisted: graph1, graph2 + * // checkpointed: graph2 + * cp.update(graph3) + * graph3.vertices.count(); graph3.edges.count() + * // persisted: graph1, graph2, graph3 + * // checkpointed: graph2 + * cp.update(graph4) + * graph4.vertices.count(); graph4.edges.count() + * // persisted: graph2, graph3, graph4 + * // checkpointed: graph4 + * cp.update(graph5) + * graph5.vertices.count(); graph5.edges.count() + * // persisted: graph3, graph4, graph5 + * // checkpointed: graph4 + * }}} + * + * @param checkpointInterval Graphs will be checkpointed at this interval. + * If this interval is set to -1, then checkpointing will be disabled. + * @tparam VD Vertex descriptor type + * @tparam ED Edge descriptor type + * + */ +private[spark] class PeriodicGraphCheckpointer[VD, ED]( + checkpointInterval: Int, + sc: SparkContext) + extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) { + + override protected def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint() + + override protected def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed + + override protected def persist(data: Graph[VD, ED]): Unit = { + if (data.vertices.getStorageLevel == StorageLevel.NONE) { + /* We need to use cache because persist does not honor the default storage level requested + * when constructing the graph. Only cache does that. 
+ */ + data.vertices.cache() + } + if (data.edges.getStorageLevel == StorageLevel.NONE) { + data.edges.cache() + } + } + + override protected def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false) + + override protected def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = { + data.getCheckpointFiles + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/f971ce5d/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala ---------------------------------------------------------------------- diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala new file mode 100644 index 0000000..e0c65e6 --- /dev/null +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.graphx.util + +import org.apache.hadoop.fs.Path + +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.graphx.{Edge, Graph, LocalSparkContext} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + + +class PeriodicGraphCheckpointerSuite extends SparkFunSuite with LocalSparkContext { + + import PeriodicGraphCheckpointerSuite._ + + test("Persisting") { + var graphsToCheck = Seq.empty[GraphToCheck] + + withSpark { sc => + val graph1 = createGraph(sc) + val checkpointer = + new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) + checkpointer.update(graph1) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkPersistence(graphsToCheck, 1) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.update(graph) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkPersistence(graphsToCheck, iteration) + iteration += 1 + } + } + } + + test("Checkpointing") { + withSpark { sc => + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val checkpointInterval = 2 + var graphsToCheck = Seq.empty[GraphToCheck] + sc.setCheckpointDir(path) + val graph1 = createGraph(sc) + val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( + checkpointInterval, graph1.vertices.sparkContext) + checkpointer.update(graph1) + graph1.edges.count() + graph1.vertices.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkCheckpoint(graphsToCheck, 1, checkpointInterval) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.update(graph) + graph.vertices.count() + graph.edges.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkCheckpoint(graphsToCheck, iteration, checkpointInterval) + iteration += 1 + } + + checkpointer.deleteAllCheckpoints() + graphsToCheck.foreach { graph => + 
confirmCheckpointRemoved(graph.graph) + } + + Utils.deleteRecursively(tempDir) + } + } +} + +private object PeriodicGraphCheckpointerSuite { + private val defaultStorageLevel = StorageLevel.MEMORY_ONLY_SER + + case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int) + + val edges = Seq( + Edge[Double](0, 1, 0), + Edge[Double](1, 2, 0), + Edge[Double](2, 3, 0), + Edge[Double](3, 4, 0)) + + def createGraph(sc: SparkContext): Graph[Double, Double] = { + Graph.fromEdges[Double, Double]( + sc.parallelize(edges), 0, defaultStorageLevel, defaultStorageLevel) + } + + def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = { + graphs.foreach { g => + checkPersistence(g.graph, g.gIndex, iteration) + } + } + + /** + * Check storage level of graph. + * @param gIndex Index of graph in order inserted into checkpointer (from 1). + * @param iteration Total number of graphs inserted into checkpointer. + */ + def checkPersistence(graph: Graph[_, _], gIndex: Int, iteration: Int): Unit = { + try { + if (gIndex + 2 < iteration) { + assert(graph.vertices.getStorageLevel == StorageLevel.NONE) + assert(graph.edges.getStorageLevel == StorageLevel.NONE) + } else { + assert(graph.vertices.getStorageLevel == defaultStorageLevel) + assert(graph.edges.getStorageLevel == defaultStorageLevel) + } + } catch { + case _: AssertionError => + throw new Exception(s"PeriodicGraphCheckpointerSuite.checkPersistence failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t graph.vertices.getStorageLevel = ${graph.vertices.getStorageLevel}\n" + + s"\t graph.edges.getStorageLevel = ${graph.edges.getStorageLevel}\n") + } + } + + def checkCheckpoint(graphs: Seq[GraphToCheck], iteration: Int, checkpointInterval: Int): Unit = { + graphs.reverse.foreach { g => + checkCheckpoint(g.graph, g.gIndex, iteration, checkpointInterval) + } + } + + def confirmCheckpointRemoved(graph: Graph[_, _]): Unit = { + // Note: We cannot check graph.isCheckpointed since that 
value is never updated. + // Instead, we check for the presence of the checkpoint files. + // This test should continue to work even after this graph.isCheckpointed issue + // is fixed (though it can then be simplified and not look for the files). + val hadoopConf = graph.vertices.sparkContext.hadoopConfiguration + graph.getCheckpointFiles.foreach { checkpointFile => + val path = new Path(checkpointFile) + val fs = path.getFileSystem(hadoopConf) + assert(!fs.exists(path), + "Graph checkpoint file should have been removed") + } + } + + /** + * Check checkpointed status of graph. + * @param gIndex Index of graph in order inserted into checkpointer (from 1). + * @param iteration Total number of graphs inserted into checkpointer. + */ + def checkCheckpoint( + graph: Graph[_, _], + gIndex: Int, + iteration: Int, + checkpointInterval: Int): Unit = { + try { + if (gIndex % checkpointInterval == 0) { + // We allow 2 checkpoint intervals since we perform an action (checkpointing a second graph) + // only AFTER PeriodicGraphCheckpointer decides whether to remove the previous checkpoint. 
+ if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) { + assert(graph.isCheckpointed, "Graph should be checkpointed") + assert(graph.getCheckpointFiles.length == 2, "Graph should have 2 checkpoint files") + } else { + confirmCheckpointRemoved(graph) + } + } else { + // Graph should never be checkpointed + assert(!graph.isCheckpointed, "Graph should never have been checkpointed") + assert(graph.getCheckpointFiles.isEmpty, "Graph should not have any checkpoint files") + } + } catch { + case e: AssertionError => + throw new Exception(s"PeriodicGraphCheckpointerSuite.checkCheckpoint failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t checkpointInterval = $checkpointInterval\n" + + s"\t graph.isCheckpointed = ${graph.isCheckpointed}\n" + + s"\t graph.getCheckpointFiles = ${graph.getCheckpointFiles.mkString(", ")}\n" + + s" AssertionError message: ${e.getMessage}") + } + } + +} http://git-wip-us.apache.org/repos/asf/spark/blob/f971ce5d/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 2f50dc7..e3026c8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -36,7 +36,6 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedL EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, OnlineLDAOptimizer => OldOnlineLDAOptimizer} -import org.apache.spark.mllib.impl.PeriodicCheckpointer import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.MatrixImplicits._ import org.apache.spark.mllib.linalg.VectorImplicits._ @@ -45,9 +44,9 @@ 
import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.PeriodicCheckpointer import org.apache.spark.util.VersionUtils - private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter with HasSeed with HasCheckpointInterval { http://git-wip-us.apache.org/repos/asf/spark/blob/f971ce5d/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index 4c525c0..ce2bd7b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -21,12 +21,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} -import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy} import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.util.PeriodicRDDCheckpointer import org.apache.spark.storage.StorageLevel http://git-wip-us.apache.org/repos/asf/spark/blob/f971ce5d/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala ---------------------------------------------------------------------- diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 48bae42..3697a9b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -25,7 +25,7 @@ import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.graphx._ -import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer +import org.apache.spark.graphx.util.PeriodicGraphCheckpointer import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel http://git-wip-us.apache.org/repos/asf/spark/blob/f971ce5d/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala deleted file mode 100644 index 4dd498c..0000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala +++ /dev/null @@ -1,183 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.impl - -import scala.collection.mutable - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path - -import org.apache.spark.SparkContext -import org.apache.spark.internal.Logging -import org.apache.spark.storage.StorageLevel - - -/** - * This abstraction helps with persisting and checkpointing RDDs and types derived from RDDs - * (such as Graphs and DataFrames). In documentation, we use the phrase "Dataset" to refer to - * the distributed data type (RDD, Graph, etc.). - * - * Specifically, this abstraction automatically handles persisting and (optionally) checkpointing, - * as well as unpersisting and removing checkpoint files. - * - * Users should call update() when a new Dataset has been created, - * before the Dataset has been materialized. After updating [[PeriodicCheckpointer]], users are - * responsible for materializing the Dataset to ensure that persisting and checkpointing actually - * occur. - * - * When update() is called, this does the following: - * - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets. - * - Unpersist Datasets from queue until there are at most 3 persisted Datasets. - * - If using checkpointing and the checkpoint interval has been reached, - * - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets. - * - Remove older checkpoints. - * - * WARNINGS: - * - This class should NOT be copied (since copies may conflict on which Datasets should be - * checkpointed). 
- * - This class removes checkpoint files once later Datasets have been checkpointed. - * However, references to the older Datasets will still return isCheckpointed = true. - * - * @param checkpointInterval Datasets will be checkpointed at this interval. - * If this interval was set as -1, then checkpointing will be disabled. - * @param sc SparkContext for the Datasets given to this checkpointer - * @tparam T Dataset type, such as RDD[Double] - */ -private[mllib] abstract class PeriodicCheckpointer[T]( - val checkpointInterval: Int, - val sc: SparkContext) extends Logging { - - /** FIFO queue of past checkpointed Datasets */ - private val checkpointQueue = mutable.Queue[T]() - - /** FIFO queue of past persisted Datasets */ - private val persistedQueue = mutable.Queue[T]() - - /** Number of times [[update()]] has been called */ - private var updateCount = 0 - - /** - * Update with a new Dataset. Handle persistence and checkpointing as needed. - * Since this handles persistence and checkpointing, this should be called before the Dataset - * has been materialized. - * - * @param newData New Dataset created from previous Datasets in the lineage. - */ - def update(newData: T): Unit = { - persist(newData) - persistedQueue.enqueue(newData) - // We try to maintain 2 Datasets in persistedQueue to support the semantics of this class: - // Users should call [[update()]] when a new Dataset has been created, - // before the Dataset has been materialized. - while (persistedQueue.size > 3) { - val dataToUnpersist = persistedQueue.dequeue() - unpersist(dataToUnpersist) - } - updateCount += 1 - - // Handle checkpointing (after persisting) - if (checkpointInterval != -1 && (updateCount % checkpointInterval) == 0 - && sc.getCheckpointDir.nonEmpty) { - // Add new checkpoint before removing old checkpoints. - checkpoint(newData) - checkpointQueue.enqueue(newData) - // Remove checkpoints before the latest one. 
- var canDelete = true - while (checkpointQueue.size > 1 && canDelete) { - // Delete the oldest checkpoint only if the next checkpoint exists. - if (isCheckpointed(checkpointQueue.head)) { - removeCheckpointFile() - } else { - canDelete = false - } - } - } - } - - /** Checkpoint the Dataset */ - protected def checkpoint(data: T): Unit - - /** Return true iff the Dataset is checkpointed */ - protected def isCheckpointed(data: T): Boolean - - /** - * Persist the Dataset. - * Note: This should handle checking the current [[StorageLevel]] of the Dataset. - */ - protected def persist(data: T): Unit - - /** Unpersist the Dataset */ - protected def unpersist(data: T): Unit - - /** Get list of checkpoint files for this given Dataset */ - protected def getCheckpointFiles(data: T): Iterable[String] - - /** - * Call this at the end to delete any remaining checkpoint files. - */ - def deleteAllCheckpoints(): Unit = { - while (checkpointQueue.nonEmpty) { - removeCheckpointFile() - } - } - - /** - * Call this at the end to delete any remaining checkpoint files, except for the last checkpoint. - * Note that there may not be any checkpoints at all. - */ - def deleteAllCheckpointsButLast(): Unit = { - while (checkpointQueue.size > 1) { - removeCheckpointFile() - } - } - - /** - * Get all current checkpoint files. - * This is useful in combination with [[deleteAllCheckpointsButLast()]]. - */ - def getAllCheckpointFiles: Array[String] = { - checkpointQueue.flatMap(getCheckpointFiles).toArray - } - - /** - * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files. - * This prints a warning but does not fail if the files cannot be removed. - */ - private def removeCheckpointFile(): Unit = { - val old = checkpointQueue.dequeue() - // Since the old checkpoint is not deleted by Spark, we manually delete it. 
- getCheckpointFiles(old).foreach( - PeriodicCheckpointer.removeCheckpointFile(_, sc.hadoopConfiguration)) - } -} - -private[spark] object PeriodicCheckpointer extends Logging { - - /** Delete a checkpoint file, and log a warning if deletion fails. */ - def removeCheckpointFile(checkpointFile: String, conf: Configuration): Unit = { - try { - val path = new Path(checkpointFile) - val fs = path.getFileSystem(conf) - fs.delete(path, true) - } catch { - case e: Exception => - logWarning("PeriodicCheckpointer could not remove old checkpoint file: " + - checkpointFile) - } - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/f971ce5d/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala deleted file mode 100644 index 8007489..0000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.mllib.impl - -import org.apache.spark.SparkContext -import org.apache.spark.graphx.Graph -import org.apache.spark.storage.StorageLevel - - -/** - * This class helps with persisting and checkpointing Graphs. - * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as - * unpersisting and removing checkpoint files. - * - * Users should call update() when a new graph has been created, - * before the graph has been materialized. After updating [[PeriodicGraphCheckpointer]], users are - * responsible for materializing the graph to ensure that persisting and checkpointing actually - * occur. - * - * When update() is called, this does the following: - * - Persist new graph (if not yet persisted), and put in queue of persisted graphs. - * - Unpersist graphs from queue until there are at most 3 persisted graphs. - * - If using checkpointing and the checkpoint interval has been reached, - * - Checkpoint the new graph, and put in a queue of checkpointed graphs. - * - Remove older checkpoints. - * - * WARNINGS: - * - This class should NOT be copied (since copies may conflict on which Graphs should be - * checkpointed). - * - This class removes checkpoint files once later graphs have been checkpointed. - * However, references to the older graphs will still return isCheckpointed = true. - * - * Example usage: - * {{{ - * val (graph1, graph2, graph3, ...) = ... 
- * val cp = new PeriodicGraphCheckpointer(2, sc) - * graph1.vertices.count(); graph1.edges.count() - * // persisted: graph1 - * cp.updateGraph(graph2) - * graph2.vertices.count(); graph2.edges.count() - * // persisted: graph1, graph2 - * // checkpointed: graph2 - * cp.updateGraph(graph3) - * graph3.vertices.count(); graph3.edges.count() - * // persisted: graph1, graph2, graph3 - * // checkpointed: graph2 - * cp.updateGraph(graph4) - * graph4.vertices.count(); graph4.edges.count() - * // persisted: graph2, graph3, graph4 - * // checkpointed: graph4 - * cp.updateGraph(graph5) - * graph5.vertices.count(); graph5.edges.count() - * // persisted: graph3, graph4, graph5 - * // checkpointed: graph4 - * }}} - * - * @param checkpointInterval Graphs will be checkpointed at this interval. - * If this interval was set as -1, then checkpointing will be disabled. - * @tparam VD Vertex descriptor type - * @tparam ED Edge descriptor type - * - * TODO: Move this out of MLlib? - */ -private[mllib] class PeriodicGraphCheckpointer[VD, ED]( - checkpointInterval: Int, - sc: SparkContext) - extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) { - - override protected def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint() - - override protected def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed - - override protected def persist(data: Graph[VD, ED]): Unit = { - if (data.vertices.getStorageLevel == StorageLevel.NONE) { - data.vertices.persist() - } - if (data.edges.getStorageLevel == StorageLevel.NONE) { - data.edges.persist() - } - } - - override protected def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false) - - override protected def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = { - data.getCheckpointFiles - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/f971ce5d/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala 
---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala deleted file mode 100644 index 145dc22..0000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.impl - -import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel - - -/** - * This class helps with persisting and checkpointing RDDs. - * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as - * unpersisting and removing checkpoint files. - * - * Users should call update() when a new RDD has been created, - * before the RDD has been materialized. After updating [[PeriodicRDDCheckpointer]], users are - * responsible for materializing the RDD to ensure that persisting and checkpointing actually - * occur. 
- * - * When update() is called, this does the following: - * - Persist new RDD (if not yet persisted), and put in queue of persisted RDDs. - * - Unpersist RDDs from queue until there are at most 3 persisted RDDs. - * - If using checkpointing and the checkpoint interval has been reached, - * - Checkpoint the new RDD, and put in a queue of checkpointed RDDs. - * - Remove older checkpoints. - * - * WARNINGS: - * - This class should NOT be copied (since copies may conflict on which RDDs should be - * checkpointed). - * - This class removes checkpoint files once later RDDs have been checkpointed. - * However, references to the older RDDs will still return isCheckpointed = true. - * - * Example usage: - * {{{ - * val (rdd1, rdd2, rdd3, ...) = ... - * val cp = new PeriodicRDDCheckpointer(2, sc) - * rdd1.count(); - * // persisted: rdd1 - * cp.update(rdd2) - * rdd2.count(); - * // persisted: rdd1, rdd2 - * // checkpointed: rdd2 - * cp.update(rdd3) - * rdd3.count(); - * // persisted: rdd1, rdd2, rdd3 - * // checkpointed: rdd2 - * cp.update(rdd4) - * rdd4.count(); - * // persisted: rdd2, rdd3, rdd4 - * // checkpointed: rdd4 - * cp.update(rdd5) - * rdd5.count(); - * // persisted: rdd3, rdd4, rdd5 - * // checkpointed: rdd4 - * }}} - * - * @param checkpointInterval RDDs will be checkpointed at this interval - * @tparam T RDD element type - * - * TODO: Move this out of MLlib? 
- */ -private[spark] class PeriodicRDDCheckpointer[T]( - checkpointInterval: Int, - sc: SparkContext) - extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) { - - override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint() - - override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed - - override protected def persist(data: RDD[T]): Unit = { - if (data.getStorageLevel == StorageLevel.NONE) { - data.persist() - } - } - - override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false) - - override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = { - data.getCheckpointFile.map(x => x) - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/f971ce5d/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala deleted file mode 100644 index a13e7f6..0000000 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ /dev/null @@ -1,189 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.impl - -import org.apache.hadoop.fs.Path - -import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.graphx.{Edge, Graph} -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils - - -class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { - - import PeriodicGraphCheckpointerSuite._ - - test("Persisting") { - var graphsToCheck = Seq.empty[GraphToCheck] - - val graph1 = createGraph(sc) - val checkpointer = - new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) - checkpointer.update(graph1) - graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) - checkPersistence(graphsToCheck, 1) - - var iteration = 2 - while (iteration < 9) { - val graph = createGraph(sc) - checkpointer.update(graph) - graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) - checkPersistence(graphsToCheck, iteration) - iteration += 1 - } - } - - test("Checkpointing") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - val checkpointInterval = 2 - var graphsToCheck = Seq.empty[GraphToCheck] - sc.setCheckpointDir(path) - val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( - checkpointInterval, graph1.vertices.sparkContext) - checkpointer.update(graph1) - graph1.edges.count() - graph1.vertices.count() - graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) - checkCheckpoint(graphsToCheck, 1, checkpointInterval) - - var iteration = 2 - while (iteration < 9) { - val graph = createGraph(sc) - checkpointer.update(graph) - graph.vertices.count() - graph.edges.count() - graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) - checkCheckpoint(graphsToCheck, iteration, checkpointInterval) - iteration += 1 - } - 
- checkpointer.deleteAllCheckpoints() - graphsToCheck.foreach { graph => - confirmCheckpointRemoved(graph.graph) - } - - Utils.deleteRecursively(tempDir) - } -} - -private object PeriodicGraphCheckpointerSuite { - - case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int) - - val edges = Seq( - Edge[Double](0, 1, 0), - Edge[Double](1, 2, 0), - Edge[Double](2, 3, 0), - Edge[Double](3, 4, 0)) - - def createGraph(sc: SparkContext): Graph[Double, Double] = { - Graph.fromEdges[Double, Double](sc.parallelize(edges), 0) - } - - def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = { - graphs.foreach { g => - checkPersistence(g.graph, g.gIndex, iteration) - } - } - - /** - * Check storage level of graph. - * @param gIndex Index of graph in order inserted into checkpointer (from 1). - * @param iteration Total number of graphs inserted into checkpointer. - */ - def checkPersistence(graph: Graph[_, _], gIndex: Int, iteration: Int): Unit = { - try { - if (gIndex + 2 < iteration) { - assert(graph.vertices.getStorageLevel == StorageLevel.NONE) - assert(graph.edges.getStorageLevel == StorageLevel.NONE) - } else { - assert(graph.vertices.getStorageLevel != StorageLevel.NONE) - assert(graph.edges.getStorageLevel != StorageLevel.NONE) - } - } catch { - case _: AssertionError => - throw new Exception(s"PeriodicGraphCheckpointerSuite.checkPersistence failed with:\n" + - s"\t gIndex = $gIndex\n" + - s"\t iteration = $iteration\n" + - s"\t graph.vertices.getStorageLevel = ${graph.vertices.getStorageLevel}\n" + - s"\t graph.edges.getStorageLevel = ${graph.edges.getStorageLevel}\n") - } - } - - def checkCheckpoint(graphs: Seq[GraphToCheck], iteration: Int, checkpointInterval: Int): Unit = { - graphs.reverse.foreach { g => - checkCheckpoint(g.graph, g.gIndex, iteration, checkpointInterval) - } - } - - def confirmCheckpointRemoved(graph: Graph[_, _]): Unit = { - // Note: We cannot check graph.isCheckpointed since that value is never updated. 
- // Instead, we check for the presence of the checkpoint files. - // This test should continue to work even after this graph.isCheckpointed issue - // is fixed (though it can then be simplified and not look for the files). - val hadoopConf = graph.vertices.sparkContext.hadoopConfiguration - graph.getCheckpointFiles.foreach { checkpointFile => - val path = new Path(checkpointFile) - val fs = path.getFileSystem(hadoopConf) - assert(!fs.exists(path), - "Graph checkpoint file should have been removed") - } - } - - /** - * Check checkpointed status of graph. - * @param gIndex Index of graph in order inserted into checkpointer (from 1). - * @param iteration Total number of graphs inserted into checkpointer. - */ - def checkCheckpoint( - graph: Graph[_, _], - gIndex: Int, - iteration: Int, - checkpointInterval: Int): Unit = { - try { - if (gIndex % checkpointInterval == 0) { - // We allow 2 checkpoint intervals since we perform an action (checkpointing a second graph) - // only AFTER PeriodicGraphCheckpointer decides whether to remove the previous checkpoint. 
- if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) { - assert(graph.isCheckpointed, "Graph should be checkpointed") - assert(graph.getCheckpointFiles.length == 2, "Graph should have 2 checkpoint files") - } else { - confirmCheckpointRemoved(graph) - } - } else { - // Graph should never be checkpointed - assert(!graph.isCheckpointed, "Graph should never have been checkpointed") - assert(graph.getCheckpointFiles.isEmpty, "Graph should not have any checkpoint files") - } - } catch { - case e: AssertionError => - throw new Exception(s"PeriodicGraphCheckpointerSuite.checkCheckpoint failed with:\n" + - s"\t gIndex = $gIndex\n" + - s"\t iteration = $iteration\n" + - s"\t checkpointInterval = $checkpointInterval\n" + - s"\t graph.isCheckpointed = ${graph.isCheckpointed}\n" + - s"\t graph.getCheckpointFiles = ${graph.getCheckpointFiles.mkString(", ")}\n" + - s" AssertionError message: ${e.getMessage}") - } - } - -} http://git-wip-us.apache.org/repos/asf/spark/blob/f971ce5d/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala deleted file mode 100644 index 14adf8c..0000000 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.impl - -import org.apache.hadoop.fs.Path - -import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils - - -class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { - - import PeriodicRDDCheckpointerSuite._ - - test("Persisting") { - var rddsToCheck = Seq.empty[RDDToCheck] - - val rdd1 = createRDD(sc) - val checkpointer = new PeriodicRDDCheckpointer[Double](10, rdd1.sparkContext) - checkpointer.update(rdd1) - rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1) - checkPersistence(rddsToCheck, 1) - - var iteration = 2 - while (iteration < 9) { - val rdd = createRDD(sc) - checkpointer.update(rdd) - rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration) - checkPersistence(rddsToCheck, iteration) - iteration += 1 - } - } - - test("Checkpointing") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - val checkpointInterval = 2 - var rddsToCheck = Seq.empty[RDDToCheck] - sc.setCheckpointDir(path) - val rdd1 = createRDD(sc) - val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext) - checkpointer.update(rdd1) - rdd1.count() - rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1) - checkCheckpoint(rddsToCheck, 1, checkpointInterval) - - var iteration = 2 - while (iteration < 9) { - val rdd = createRDD(sc) - checkpointer.update(rdd) - rdd.count() - rddsToCheck = rddsToCheck :+ 
RDDToCheck(rdd, iteration) - checkCheckpoint(rddsToCheck, iteration, checkpointInterval) - iteration += 1 - } - - checkpointer.deleteAllCheckpoints() - rddsToCheck.foreach { rdd => - confirmCheckpointRemoved(rdd.rdd) - } - - Utils.deleteRecursively(tempDir) - } -} - -private object PeriodicRDDCheckpointerSuite { - - case class RDDToCheck(rdd: RDD[Double], gIndex: Int) - - def createRDD(sc: SparkContext): RDD[Double] = { - sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0)) - } - - def checkPersistence(rdds: Seq[RDDToCheck], iteration: Int): Unit = { - rdds.foreach { g => - checkPersistence(g.rdd, g.gIndex, iteration) - } - } - - /** - * Check storage level of rdd. - * @param gIndex Index of rdd in order inserted into checkpointer (from 1). - * @param iteration Total number of rdds inserted into checkpointer. - */ - def checkPersistence(rdd: RDD[_], gIndex: Int, iteration: Int): Unit = { - try { - if (gIndex + 2 < iteration) { - assert(rdd.getStorageLevel == StorageLevel.NONE) - } else { - assert(rdd.getStorageLevel != StorageLevel.NONE) - } - } catch { - case _: AssertionError => - throw new Exception(s"PeriodicRDDCheckpointerSuite.checkPersistence failed with:\n" + - s"\t gIndex = $gIndex\n" + - s"\t iteration = $iteration\n" + - s"\t rdd.getStorageLevel = ${rdd.getStorageLevel}\n") - } - } - - def checkCheckpoint(rdds: Seq[RDDToCheck], iteration: Int, checkpointInterval: Int): Unit = { - rdds.reverse.foreach { g => - checkCheckpoint(g.rdd, g.gIndex, iteration, checkpointInterval) - } - } - - def confirmCheckpointRemoved(rdd: RDD[_]): Unit = { - // Note: We cannot check rdd.isCheckpointed since that value is never updated. - // Instead, we check for the presence of the checkpoint files. - // This test should continue to work even after this rdd.isCheckpointed issue - // is fixed (though it can then be simplified and not look for the files). 
- val hadoopConf = rdd.sparkContext.hadoopConfiguration - rdd.getCheckpointFile.foreach { checkpointFile => - val path = new Path(checkpointFile) - val fs = path.getFileSystem(hadoopConf) - assert(!fs.exists(path), "RDD checkpoint file should have been removed") - } - } - - /** - * Check checkpointed status of rdd. - * @param gIndex Index of rdd in order inserted into checkpointer (from 1). - * @param iteration Total number of rdds inserted into checkpointer. - */ - def checkCheckpoint( - rdd: RDD[_], - gIndex: Int, - iteration: Int, - checkpointInterval: Int): Unit = { - try { - if (gIndex % checkpointInterval == 0) { - // We allow 2 checkpoint intervals since we perform an action (checkpointing a second rdd) - // only AFTER PeriodicRDDCheckpointer decides whether to remove the previous checkpoint. - if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) { - assert(rdd.isCheckpointed, "RDD should be checkpointed") - assert(rdd.getCheckpointFile.nonEmpty, "RDD should have 2 checkpoint files") - } else { - confirmCheckpointRemoved(rdd) - } - } else { - // RDD should never be checkpointed - assert(!rdd.isCheckpointed, "RDD should never have been checkpointed") - assert(rdd.getCheckpointFile.isEmpty, "RDD should not have any checkpoint files") - } - } catch { - case e: AssertionError => - throw new Exception(s"PeriodicRDDCheckpointerSuite.checkCheckpoint failed with:\n" + - s"\t gIndex = $gIndex\n" + - s"\t iteration = $iteration\n" + - s"\t checkpointInterval = $checkpointInterval\n" + - s"\t rdd.isCheckpointed = ${rdd.isCheckpointed}\n" + - s"\t rdd.getCheckpointFile = ${rdd.getCheckpointFile.mkString(", ")}\n" + - s" AssertionError message: ${e.getMessage}") - } - } - -} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org