Repository: spark Updated Branches: refs/heads/master 1c751fcf4 -> 643b4e225
[SPARK-14510][MLLIB] Add args-checking for LDA and StreamingKMeans ## What changes were proposed in this pull request? add the checking for LDA and StreamingKMeans ## How was this patch tested? manual tests Author: Zheng RuiFeng <ruife...@foxmail.com> Closes #12062 from zhengruifeng/initmodel. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/643b4e22 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/643b4e22 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/643b4e22 Branch: refs/heads/master Commit: 643b4e2257c56338b192f8554e2fe5523bea4bdf Parents: 1c751fc Author: Zheng RuiFeng <ruife...@foxmail.com> Authored: Mon Apr 11 09:33:52 2016 -0700 Committer: Xiangrui Meng <m...@databricks.com> Committed: Mon Apr 11 09:33:52 2016 -0700 ---------------------------------------------------------------------- .../scala/org/apache/spark/mllib/clustering/LDA.scala | 10 +++++++--- .../apache/spark/mllib/clustering/StreamingKMeans.scala | 10 ++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/643b4e22/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 12813fd..d999b9b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -130,7 +130,8 @@ class LDA private ( */ @Since("1.5.0") def setDocConcentration(docConcentration: Vector): this.type = { - require(docConcentration.size > 0, "docConcentration must have > 0 elements") + require(docConcentration.size == 1 || docConcentration.size == k, + s"Size of docConcentration must be 1 or ${k} but got ${docConcentration.size}") this.docConcentration = docConcentration this } @@ -260,15 +261,18 @@ class LDA private ( def getCheckpointInterval: Int = checkpointInterval /** - * Period (in iterations) between checkpoints (default = 10). Checkpointing helps with recovery + * Parameter for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that + * the cache will get checkpointed every 10 iterations. Checkpointing helps with recovery * (when nodes fail). It also helps with eliminating temporary shuffle files on disk, which can be * important when LDA is run for many iterations. If the checkpoint directory is not set in - * [[org.apache.spark.SparkContext]], this setting is ignored. + * [[org.apache.spark.SparkContext]], this setting is ignored. (default = 10) * * @see [[org.apache.spark.SparkContext#setCheckpointDir]] */ @Since("1.3.0") def setCheckpointInterval(checkpointInterval: Int): this.type = { + require(checkpointInterval == -1 || checkpointInterval > 0, + s"Period between checkpoints must be -1 or positive but got ${checkpointInterval}") this.checkpointInterval = checkpointInterval this } http://git-wip-us.apache.org/repos/asf/spark/blob/643b4e22/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 4eb8fc0..24e1cff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -218,6 +218,12 @@ class StreamingKMeans @Since("1.2.0") ( */ @Since("1.2.0") def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = { + require(centers.size == weights.size, + "Number of initial centers must be equal to number of weights") + require(centers.size == k, + s"Number of initial centers must be ${k} but got ${centers.size}") + require(weights.forall(_ >= 0), + s"Weight for each inital center must be nonnegative but got [${weights.mkString(" ")}]") model = new StreamingKMeansModel(centers, weights) this } @@ -231,6 +237,10 @@ class StreamingKMeans @Since("1.2.0") ( */ @Since("1.2.0") def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = { + require(dim > 0, + s"Number of dimensions must be positive but got ${dim}") + require(weight >= 0, + s"Weight for each center must be nonnegative but got ${weight}") val random = new XORShiftRandom(seed) val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian()))) val weights = Array.fill(k)(weight) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org