Repository: spark
Updated Branches:
  refs/heads/master 1c751fcf4 -> 643b4e225


[SPARK-14510][MLLIB] Add args-checking for LDA and StreamingKMeans

## What changes were proposed in this pull request?
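Add argument checking to LDA (docConcentration size, checkpointInterval) and StreamingKMeans (initial centers/weights, random-center dimension and weight).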

## How was this patch tested?
Manual tests.

Author: Zheng RuiFeng <ruife...@foxmail.com>

Closes #12062 from zhengruifeng/initmodel.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/643b4e22
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/643b4e22
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/643b4e22

Branch: refs/heads/master
Commit: 643b4e2257c56338b192f8554e2fe5523bea4bdf
Parents: 1c751fc
Author: Zheng RuiFeng <ruife...@foxmail.com>
Authored: Mon Apr 11 09:33:52 2016 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Mon Apr 11 09:33:52 2016 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/mllib/clustering/LDA.scala     | 10 +++++++---
 .../apache/spark/mllib/clustering/StreamingKMeans.scala   | 10 ++++++++++
 2 files changed, 17 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/643b4e22/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
index 12813fd..d999b9b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -130,7 +130,8 @@ class LDA private (
    */
   @Since("1.5.0")
   def setDocConcentration(docConcentration: Vector): this.type = {
-    require(docConcentration.size > 0, "docConcentration must have > 0 
elements")
+    require(docConcentration.size == 1 || docConcentration.size == k,
+      s"Size of docConcentration must be 1 or ${k} but got 
${docConcentration.size}")
     this.docConcentration = docConcentration
     this
   }
@@ -260,15 +261,18 @@ class LDA private (
   def getCheckpointInterval: Int = checkpointInterval
 
   /**
-   * Period (in iterations) between checkpoints (default = 10). Checkpointing 
helps with recovery
+   * Parameter for set checkpoint interval (>= 1) or disable checkpoint (-1). 
E.g. 10 means that
+   * the cache will get checkpointed every 10 iterations. Checkpointing helps 
with recovery
    * (when nodes fail). It also helps with eliminating temporary shuffle files 
on disk, which can be
    * important when LDA is run for many iterations. If the checkpoint 
directory is not set in
-   * [[org.apache.spark.SparkContext]], this setting is ignored.
+   * [[org.apache.spark.SparkContext]], this setting is ignored. (default = 10)
    *
    * @see [[org.apache.spark.SparkContext#setCheckpointDir]]
    */
   @Since("1.3.0")
   def setCheckpointInterval(checkpointInterval: Int): this.type = {
+    require(checkpointInterval == -1 || checkpointInterval > 0,
+      s"Period between checkpoints must be -1 or positive but got 
${checkpointInterval}")
     this.checkpointInterval = checkpointInterval
     this
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/643b4e22/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 4eb8fc0..24e1cff 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -218,6 +218,12 @@ class StreamingKMeans @Since("1.2.0") (
    */
   @Since("1.2.0")
   def setInitialCenters(centers: Array[Vector], weights: Array[Double]): 
this.type = {
+    require(centers.size == weights.size,
+      "Number of initial centers must be equal to number of weights")
+    require(centers.size == k,
+      s"Number of initial centers must be ${k} but got ${centers.size}")
+    require(weights.forall(_ >= 0),
+      s"Weight for each inital center must be nonnegative but got 
[${weights.mkString(" ")}]")
     model = new StreamingKMeansModel(centers, weights)
     this
   }
@@ -231,6 +237,10 @@ class StreamingKMeans @Since("1.2.0") (
    */
   @Since("1.2.0")
   def setRandomCenters(dim: Int, weight: Double, seed: Long = 
Utils.random.nextLong): this.type = {
+    require(dim > 0,
+      s"Number of dimensions must be positive but got ${dim}")
+    require(weight >= 0,
+      s"Weight for each center must be nonnegative but got ${weight}")
     val random = new XORShiftRandom(seed)
     val centers = 
Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
     val weights = Array.fill(k)(weight)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to