Repository: spark
Updated Branches:
  refs/heads/branch-1.3 44768f582 -> 59798cb44


[SPARK-5604][MLLIB] remove checkpointDir from LDA

`checkpointDir` is a global Spark configuration, so users should set it outside
LDA (e.g. via `SparkContext.setCheckpointDir`). This PR also marks members of
`private[clustering] object LDA` as `private[clustering]` so they don't show up
in the generated Javadoc (SPARK-5610).

jkbradley

Author: Xiangrui Meng <m...@databricks.com>

Closes #4390 from mengxr/SPARK-5604 and squashes the following commits:

a34bb39 [Xiangrui Meng] remove checkpointDir from LDA

(cherry picked from commit c19152cd2a5d407ecf526a90e3bb059f09905b3a)
Signed-off-by: Xiangrui Meng <m...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/59798cb4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/59798cb4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/59798cb4

Branch: refs/heads/branch-1.3
Commit: 59798cb4442f913e2bcb97c69ad931dbeb572349
Parents: 44768f5
Author: Xiangrui Meng <m...@databricks.com>
Authored: Thu Feb 5 15:07:33 2015 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Thu Feb 5 15:07:39 2015 -0800

----------------------------------------------------------------------
 .../spark/examples/mllib/LDAExample.scala       |  2 +-
 .../org/apache/spark/mllib/clustering/LDA.scala | 73 ++++++--------------
 .../mllib/impl/PeriodicGraphCheckpointer.scala  |  8 ---
 .../impl/PeriodicGraphCheckpointerSuite.scala   |  6 +-
 4 files changed, 24 insertions(+), 65 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/59798cb4/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
----------------------------------------------------------------------
diff --git 
a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala 
b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
index f4c545a..0e1b27a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -134,7 +134,7 @@ object LDAExample {
       .setTopicConcentration(params.topicConcentration)
       .setCheckpointInterval(params.checkpointInterval)
     if (params.checkpointDir.nonEmpty) {
-      lda.setCheckpointDir(params.checkpointDir.get)
+      sc.setCheckpointDir(params.checkpointDir.get)
     }
     val startTime = System.nanoTime()
     val ldaModel = lda.run(corpus)

http://git-wip-us.apache.org/repos/asf/spark/blob/59798cb4/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
index d8f8286..a1d3df0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -52,6 +52,9 @@ import org.apache.spark.util.Utils
  *  - Paper which clearly explains several algorithms, including EM:
  *    Asuncion, Welling, Smyth, and Teh.
  *    "On Smoothing and Inference for Topic Models."  UAI, 2009.
+ *
+ * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent 
Dirichlet allocation
+ *       (Wikipedia)]]
  */
 @Experimental
 class LDA private (
@@ -60,11 +63,10 @@ class LDA private (
     private var docConcentration: Double,
     private var topicConcentration: Double,
     private var seed: Long,
-    private var checkpointDir: Option[String],
     private var checkpointInterval: Int) extends Logging {
 
   def this() = this(k = 10, maxIterations = 20, docConcentration = -1, 
topicConcentration = -1,
-    seed = Utils.random.nextLong(), checkpointDir = None, checkpointInterval = 
10)
+    seed = Utils.random.nextLong(), checkpointInterval = 10)
 
   /**
    * Number of topics to infer.  I.e., the number of soft cluster centers.
@@ -201,49 +203,17 @@ class LDA private (
   }
 
   /**
-   * Directory for storing checkpoint files during learning.
-   * This is not necessary, but checkpointing helps with recovery (when nodes 
fail).
-   * It also helps with eliminating temporary shuffle files on disk, which can 
be important when
-   * LDA is run for many iterations.
-   */
-  def getCheckpointDir: Option[String] = checkpointDir
-
-  /**
-   * Directory for storing checkpoint files during learning.
-   * This is not necessary, but checkpointing helps with recovery (when nodes 
fail).
-   * It also helps with eliminating temporary shuffle files on disk, which can 
be important when
-   * LDA is run for many iterations.
-   *
-   * NOTE: If the [[org.apache.spark.SparkContext.checkpointDir]] is already 
set, then the value
-   *       given to LDA is ignored, and the existing directory is kept.
-   *
-   * (default = None)
-   */
-  def setCheckpointDir(checkpointDir: String): this.type = {
-    this.checkpointDir = Some(checkpointDir)
-    this
-  }
-
-  /**
-   * Clear the directory for storing checkpoint files during learning.
-   * If one is already set in the [[org.apache.spark.SparkContext]], then 
checkpointing will still
-   * occur; otherwise, no checkpointing will be used.
-   */
-  def clearCheckpointDir(): this.type = {
-    this.checkpointDir = None
-    this
-  }
-
-  /**
    * Period (in iterations) between checkpoints.
-   * @see [[getCheckpointDir]]
    */
   def getCheckpointInterval: Int = checkpointInterval
 
   /**
-   * Period (in iterations) between checkpoints.
-   * (default = 10)
-   * @see [[getCheckpointDir]]
+   * Period (in iterations) between checkpoints (default = 10). Checkpointing 
helps with recovery
+   * (when nodes fail). It also helps with eliminating temporary shuffle files 
on disk, which can be
+   * important when LDA is run for many iterations. If the checkpoint 
directory is not set in
+   * [[org.apache.spark.SparkContext]], this setting is ignored.
+   *
+   * @see [[org.apache.spark.SparkContext#setCheckpointDir]]
    */
   def setCheckpointInterval(checkpointInterval: Int): this.type = {
     this.checkpointInterval = checkpointInterval
@@ -261,7 +231,7 @@ class LDA private (
    */
   def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
     val state = LDA.initialState(documents, k, getDocConcentration, 
getTopicConcentration, seed,
-      checkpointDir, checkpointInterval)
+      checkpointInterval)
     var iter = 0
     val iterationTimes = Array.fill[Double](maxIterations)(0)
     while (iter < maxIterations) {
@@ -337,18 +307,18 @@ private[clustering] object LDA {
    * Vector over topics (length k) of token counts.
    * The meaning of these counts can vary, and it may or may not be normalized 
to be a distribution.
    */
-  type TopicCounts = BDV[Double]
+  private[clustering] type TopicCounts = BDV[Double]
 
-  type TokenCount = Double
+  private[clustering] type TokenCount = Double
 
   /** Term vertex IDs are {-1, -2, ..., -vocabSize} */
-  def term2index(term: Int): Long = -(1 + term.toLong)
+  private[clustering] def term2index(term: Int): Long = -(1 + term.toLong)
 
-  def index2term(termIndex: Long): Int = -(1 + termIndex).toInt
+  private[clustering] def index2term(termIndex: Long): Int = -(1 + 
termIndex).toInt
 
-  def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0
+  private[clustering] def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 
>= 0
 
-  def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0
+  private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0
 
   /**
    * Optimizer for EM algorithm which stores data + parameter graph, plus 
algorithm parameters.
@@ -360,17 +330,16 @@ private[clustering] object LDA {
    * @param docConcentration  "alpha"
    * @param topicConcentration  "beta" or "eta"
    */
-  class EMOptimizer(
+  private[clustering] class EMOptimizer(
       var graph: Graph[TopicCounts, TokenCount],
       val k: Int,
       val vocabSize: Int,
       val docConcentration: Double,
       val topicConcentration: Double,
-      checkpointDir: Option[String],
       checkpointInterval: Int) {
 
     private[LDA] val graphCheckpointer = new 
PeriodicGraphCheckpointer[TopicCounts, TokenCount](
-      graph, checkpointDir, checkpointInterval)
+      graph, checkpointInterval)
 
     def next(): EMOptimizer = {
       val eta = topicConcentration
@@ -468,7 +437,6 @@ private[clustering] object LDA {
       docConcentration: Double,
       topicConcentration: Double,
       randomSeed: Long,
-      checkpointDir: Option[String],
       checkpointInterval: Int): EMOptimizer = {
     // For each document, create an edge (Document -> Term) for each unique 
term in the document.
     val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, 
termCounts: Vector) =>
@@ -512,8 +480,7 @@ private[clustering] object LDA {
     val graph = Graph(docVertices ++ termVertices, edges)
       .partitionBy(PartitionStrategy.EdgePartition1D)
 
-    new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, 
checkpointDir,
-      checkpointInterval)
+    new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, 
checkpointInterval)
   }
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/59798cb4/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
 
b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
index 76672fe..6e5dd11 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
@@ -74,7 +74,6 @@ import org.apache.spark.storage.StorageLevel
  * }}}
  *
  * @param currentGraph  Initial graph
- * @param checkpointDir The directory for storing checkpoint files
  * @param checkpointInterval Graphs will be checkpointed at this interval
  * @tparam VD  Vertex descriptor type
  * @tparam ED  Edge descriptor type
@@ -83,7 +82,6 @@ import org.apache.spark.storage.StorageLevel
  */
 private[mllib] class PeriodicGraphCheckpointer[VD, ED](
     var currentGraph: Graph[VD, ED],
-    val checkpointDir: Option[String],
     val checkpointInterval: Int) extends Logging {
 
   /** FIFO queue of past checkpointed RDDs */
@@ -101,12 +99,6 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED](
    */
   private val sc = currentGraph.vertices.sparkContext
 
-  // If a checkpoint directory is given, and there's no prior checkpoint 
directory,
-  // then set the checkpoint directory with the given one.
-  if (checkpointDir.nonEmpty && sc.getCheckpointDir.isEmpty) {
-    sc.setCheckpointDir(checkpointDir.get)
-  }
-
   updateGraph(currentGraph)
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/59798cb4/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
index dac28a3..699f009 100644
--- 
a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
@@ -38,7 +38,7 @@ class PeriodicGraphCheckpointerSuite extends FunSuite with 
MLlibTestSparkContext
     var graphsToCheck = Seq.empty[GraphToCheck]
 
     val graph1 = createGraph(sc)
-    val checkpointer = new PeriodicGraphCheckpointer(graph1, None, 10)
+    val checkpointer = new PeriodicGraphCheckpointer(graph1, 10)
     graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
     checkPersistence(graphsToCheck, 1)
 
@@ -57,9 +57,9 @@ class PeriodicGraphCheckpointerSuite extends FunSuite with 
MLlibTestSparkContext
     val path = tempDir.toURI.toString
     val checkpointInterval = 2
     var graphsToCheck = Seq.empty[GraphToCheck]
-
+    sc.setCheckpointDir(path)
     val graph1 = createGraph(sc)
-    val checkpointer = new PeriodicGraphCheckpointer(graph1, Some(path), 
checkpointInterval)
+    val checkpointer = new PeriodicGraphCheckpointer(graph1, 
checkpointInterval)
     graph1.edges.count()
     graph1.vertices.count()
     graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to