Repository: spark Updated Branches: refs/heads/branch-1.5 547120287 -> 295266049
[SPARK-8481] [MLLIB] GaussianMixtureModel predict accepting single vector Resubmit of [https://github.com/apache/spark/pull/6906] for adding single-vec predict to GMMs CC: dkobylarz mengxr To be merged with master and branch-1.5 Primary author: dkobylarz Author: Dariusz Kobylarz <darek.kobyl...@gmail.com> Closes #8039 from jkbradley/gmm-predict-vec and squashes the following commits: bfbedc4 [Dariusz Kobylarz] [SPARK-8481] [MLlib] GaussianMixtureModel predict accepting single vector (cherry picked from commit e2fbbe73111d4624390f596a19a1799c86a05f6c) Signed-off-by: Joseph K. Bradley <jos...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/29526604 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/29526604 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/29526604 Branch: refs/heads/branch-1.5 Commit: 29526604916a5e1dff12fcbc395f1039b3a69dcd Parents: 5471202 Author: Dariusz Kobylarz <darek.kobyl...@gmail.com> Authored: Fri Aug 7 14:51:03 2015 -0700 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Fri Aug 7 14:51:13 2015 -0700 ---------------------------------------------------------------------- .../spark/mllib/clustering/GaussianMixtureModel.scala | 13 +++++++++++++ .../spark/mllib/clustering/GaussianMixtureSuite.scala | 10 ++++++++++ 2 files changed, 23 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/29526604/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index cb807c8..76aeebd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -66,6 +66,12 @@ class GaussianMixtureModel( responsibilityMatrix.map(r => r.indexOf(r.max)) } + /** Maps given point to its cluster index. */ + def predict(point: Vector): Int = { + val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) + r.indexOf(r.max) + } + /** Java-friendly version of [[predict()]] */ def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] @@ -84,6 +90,13 @@ class GaussianMixtureModel( } /** + * Given the input vector, return the membership values to all mixture components. + */ + def predictSoft(point: Vector): Array[Double] = { + computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) + } + + /** * Compute the partial assignments for each vector */ private def computeSoftAssignments( http://git-wip-us.apache.org/repos/asf/spark/blob/29526604/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index b218d72..b636d02 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -148,6 +148,16 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("model prediction, parallel and local") { + val data = sc.parallelize(GaussianTestData.data) + val gmm = new GaussianMixture().setK(2).setSeed(0).run(data) + + val batchPredictions = gmm.predict(data) + batchPredictions.zip(data).collect().foreach { case (batchPred, datum) => + assert(batchPred === gmm.predict(datum)) + } + } + object GaussianTestData { val data = Array( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org