Repository: spark Updated Branches: refs/heads/master 88c9c467a -> 9893dc975
[SPARK-15610][ML] update error message for k in pca ## What changes were proposed in this pull request? Fix the wrong bound of `k` in `PCA` `require(k <= sources.first().size, ...` -> `require(k < sources.first().size` BTW, remove unused import in `ml.ElementwiseProduct` ## How was this patch tested? manual tests Author: Zheng RuiFeng <ruife...@foxmail.com> Closes #13356 from zhengruifeng/fix_pca. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9893dc97 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9893dc97 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9893dc97 Branch: refs/heads/master Commit: 9893dc975784551a62f65bbd709f8972e0204b2a Parents: 88c9c46 Author: Zheng RuiFeng <ruife...@foxmail.com> Authored: Fri May 27 21:57:41 2016 -0500 Committer: Sean Owen <so...@cloudera.com> Committed: Fri May 27 21:57:41 2016 -0500 ---------------------------------------------------------------------- .../scala/org/apache/spark/ml/feature/ElementwiseProduct.scala | 1 - mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/9893dc97/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index 91989c3..9d2e60f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -23,7 +23,6 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param.Param import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.sql.types.DataType http://git-wip-us.apache.org/repos/asf/spark/blob/9893dc97/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala index 30c403e..15b7220 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala @@ -40,8 +40,9 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { */ @Since("1.4.0") def fit(sources: RDD[Vector]): PCAModel = { - require(k <= sources.first().size, - s"source vector size is ${sources.first().size} must be greater than k=$k") + val numFeatures = sources.first().size + require(k <= numFeatures, + s"source vector size $numFeatures must be no less than k=$k") val mat = new RowMatrix(sources) val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k) @@ -58,7 +59,6 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { case m => throw new IllegalArgumentException("Unsupported matrix format. Expected " + s"SparseMatrix or DenseMatrix. Instead got: ${m.getClass}") - } val denseExplainedVariance = explainedVariance match { case dv: DenseVector => --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org