Repository: spark
Updated Branches:
  refs/heads/branch-1.5 547120287 -> 295266049


[SPARK-8481] [MLLIB] GaussianMixtureModel predict accepting single vector

Resubmit of https://github.com/apache/spark/pull/6906, which adds single-vector
predict methods to GaussianMixtureModel (GMM).

CC: dkobylarz  mengxr

To be merged with master and branch-1.5
Primary author: dkobylarz

Author: Dariusz Kobylarz <darek.kobyl...@gmail.com>

Closes #8039 from jkbradley/gmm-predict-vec and squashes the following commits:

bfbedc4 [Dariusz Kobylarz] [SPARK-8481] [MLlib] GaussianMixtureModel predict accepting single vector

(cherry picked from commit e2fbbe73111d4624390f596a19a1799c86a05f6c)
Signed-off-by: Joseph K. Bradley <jos...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/29526604
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/29526604
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/29526604

Branch: refs/heads/branch-1.5
Commit: 29526604916a5e1dff12fcbc395f1039b3a69dcd
Parents: 5471202
Author: Dariusz Kobylarz <darek.kobyl...@gmail.com>
Authored: Fri Aug 7 14:51:03 2015 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Fri Aug 7 14:51:13 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/clustering/GaussianMixtureModel.scala  | 13 +++++++++++++
 .../spark/mllib/clustering/GaussianMixtureSuite.scala  | 10 ++++++++++
 2 files changed, 23 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/29526604/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index cb807c8..76aeebd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -66,6 +66,12 @@ class GaussianMixtureModel(
     responsibilityMatrix.map(r => r.indexOf(r.max))
   }
 
+  /** Maps given point to its cluster index. */
+  def predict(point: Vector): Int = {
+    val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
+    r.indexOf(r.max)
+  }
+
   /** Java-friendly version of [[predict()]] */
   def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
     predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
@@ -84,6 +90,13 @@ class GaussianMixtureModel(
   }
 
   /**
+   * Given the input vector, return the membership values to all mixture components.
+   */
+  def predictSoft(point: Vector): Array[Double] = {
+    computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
+  }
+
+  /**
    * Compute the partial assignments for each vector
    */
   private def computeSoftAssignments(

http://git-wip-us.apache.org/repos/asf/spark/blob/29526604/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
index b218d72..b636d02 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
@@ -148,6 +148,16 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("model prediction, parallel and local") {
+    val data = sc.parallelize(GaussianTestData.data)
+    val gmm = new GaussianMixture().setK(2).setSeed(0).run(data)
+
+    val batchPredictions = gmm.predict(data)
+    batchPredictions.zip(data).collect().foreach { case (batchPred, datum) =>
+      assert(batchPred === gmm.predict(datum))
+    }
+  }
+
   object GaussianTestData {
 
     val data = Array(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to