Repository: spark
Updated Branches:
  refs/heads/master aa1e22b17 -> 7450a992b


[SPARK-4749] [mllib]: Allow initializing KMeans clusters using a seed

This implements the functionality for SPARK-4749 and provides unit tests in Scala and PySpark.
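
For reference, a minimal usage sketch of the new Scala seed parameter (assuming an existing SparkContext `sc`; the data below is illustrative, mirroring the new Scala test):

    import org.apache.spark.mllib.clustering.KMeans
    import org.apache.spark.mllib.linalg.Vectors

    // Illustrative data: 1000 points along the diagonal.
    val points = (0 until 1000).map(n => Vectors.dense(n.toDouble, n.toDouble))
    val rdd = sc.parallelize(points, 3)

    // Training twice with the same seed should yield identical cluster centers.
    val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1,
      initializationMode = "k-means||", seed = 42L)
    val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1,
      initializationMode = "k-means||", seed = 42L)
    model1.clusterCenters.zip(model2.clusterCenters).foreach { case (c1, c2) =>
      assert(c1.toArray.sameElements(c2.toArray))
    }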

Author: nate.crosswhite <nate.crosswh...@stresearch.com>
Author: nxwhite-str <nxwhite-...@users.noreply.github.com>
Author: Xiangrui Meng <m...@databricks.com>

Closes #3610 from nxwhite-str/master and squashes the following commits:

a2ebbd3 [nxwhite-str] Merge pull request #1 from mengxr/SPARK-4749-kmeans-seed
7668124 [Xiangrui Meng] minor updates
f8d5928 [nate.crosswhite] Addressing PR issues
277d367 [nate.crosswhite] Merge remote-tracking branch 'upstream/master'
9156a57 [nate.crosswhite] Merge remote-tracking branch 'upstream/master'
5d087b4 [nate.crosswhite] Adding KMeans train with seed and Scala unit test
616d111 [nate.crosswhite] Merge remote-tracking branch 'upstream/master'
35c1884 [nate.crosswhite] Add kmeans initial seed to pyspark API


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7450a992
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7450a992
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7450a992

Branch: refs/heads/master
Commit: 7450a992b3b543a373c34fc4444a528954ac4b4a
Parents: aa1e22b
Author: nate.crosswhite <nate.crosswh...@stresearch.com>
Authored: Wed Jan 21 10:32:10 2015 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Wed Jan 21 10:32:10 2015 -0800

----------------------------------------------------------------------
 .../spark/mllib/api/python/PythonMLLibAPI.scala |  6 ++-
 .../apache/spark/mllib/clustering/KMeans.scala  | 48 ++++++++++++++++----
 .../spark/mllib/clustering/KMeansSuite.scala    | 21 +++++++++
 python/pyspark/mllib/clustering.py              |  4 +-
 python/pyspark/mllib/tests.py                   | 17 ++++++-
 5 files changed, 84 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7450a992/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 555da8c..430d763 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -266,12 +266,16 @@ class PythonMLLibAPI extends Serializable {
       k: Int,
       maxIterations: Int,
       runs: Int,
-      initializationMode: String): KMeansModel = {
+      initializationMode: String,
+      seed: java.lang.Long): KMeansModel = {
     val kMeansAlg = new KMeans()
       .setK(k)
       .setMaxIterations(maxIterations)
       .setRuns(runs)
       .setInitializationMode(initializationMode)
+
+    if (seed != null) kMeansAlg.setSeed(seed)
+
     try {
       kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
     } finally {

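The boxed java.lang.Long above (rather than Scala's primitive Long) is what lets a Python caller omit the seed: Py4J maps Python's None to null. A hypothetical sketch of the same pattern (`configure` is an illustrative helper name, not part of the patch):

    import org.apache.spark.mllib.clustering.KMeans

    // A null boxed Long means "no seed was supplied"; Py4J maps Python's None to null.
    def configure(seed: java.lang.Long): KMeans = {
      val alg = new KMeans()
      if (seed != null) alg.setSeed(seed) // Scala auto-unboxes java.lang.Long to Long
      alg
    }
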
http://git-wip-us.apache.org/repos/asf/spark/blob/7450a992/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 54c301d..6b5c934 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -19,14 +19,14 @@ package org.apache.spark.mllib.clustering
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
-import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
 
 /**
@@ -43,13 +43,14 @@ class KMeans private (
     private var runs: Int,
     private var initializationMode: String,
     private var initializationSteps: Int,
-    private var epsilon: Double) extends Serializable with Logging {
+    private var epsilon: Double,
+    private var seed: Long) extends Serializable with Logging {
 
   /**
    * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1,
-   * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4}.
+   * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}.
    */
-  def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4)
+  def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong())
 
   /** Set the number of clusters to create (k). Default: 2. */
   def setK(k: Int): this.type = {
@@ -112,6 +113,12 @@ class KMeans private (
     this
   }
 
+  /** Set the random seed for cluster initialization. */
+  def setSeed(seed: Long): this.type = {
+    this.seed = seed
+    this
+  }
+
   /**
    * Train a K-means model on the given set of points; `data` should be cached for high
    * performance, because this is an iterative algorithm.
@@ -255,7 +262,7 @@ class KMeans private (
   private def initRandom(data: RDD[VectorWithNorm])
   : Array[Array[VectorWithNorm]] = {
     // Sample all the cluster centers in one pass to avoid repeated scans
-    val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq
+    val sample = data.takeSample(true, runs * k, new XORShiftRandom(this.seed).nextInt()).toSeq
     Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v =>
       new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm)
     }.toArray)
@@ -273,7 +280,7 @@ class KMeans private (
   private def initKMeansParallel(data: RDD[VectorWithNorm])
   : Array[Array[VectorWithNorm]] = {
     // Initialize each run's center to a random point
-    val seed = new XORShiftRandom().nextInt()
+    val seed = new XORShiftRandom(this.seed).nextInt()
     val sample = data.takeSample(true, runs, seed).toSeq
     val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
 
@@ -333,7 +340,32 @@ object KMeans {
   /**
    * Trains a k-means model using the given set of parameters.
    *
-   * @param data training points stored as `RDD[Array[Double]]`
+   * @param data training points stored as `RDD[Vector]`
+   * @param k number of clusters
+   * @param maxIterations max number of iterations
+   * @param runs number of parallel runs, defaults to 1. The best model is returned.
+   * @param initializationMode initialization mode, either "random" or "k-means||" (default).
+   * @param seed random seed value for cluster initialization
+   */
+  def train(
+      data: RDD[Vector],
+      k: Int,
+      maxIterations: Int,
+      runs: Int,
+      initializationMode: String,
+      seed: Long): KMeansModel = {
+    new KMeans().setK(k)
+      .setMaxIterations(maxIterations)
+      .setRuns(runs)
+      .setInitializationMode(initializationMode)
+      .setSeed(seed)
+      .run(data)
+  }
+
+  /**
+   * Trains a k-means model using the given set of parameters.
+   *
+   * @param data training points stored as `RDD[Vector]`
    * @param k number of clusters
    * @param maxIterations max number of iterations
    * @param runs number of parallel runs, defaults to 1. The best model is returned.

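As a usage note, the builder-style API gains the same control through the new setSeed method; a minimal sketch (again assuming an existing SparkContext `sc` and illustrative data):

    import org.apache.spark.mllib.clustering.KMeans
    import org.apache.spark.mllib.linalg.Vectors

    val data = sc.parallelize(Seq(
      Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0), Vectors.dense(9.0, 9.0)))

    // Fixing the seed makes run() initialize centers deterministically, since
    // both initRandom and initKMeansParallel now draw from XORShiftRandom(this.seed).
    val model = new KMeans()
      .setK(2)
      .setMaxIterations(10)
      .setSeed(42L)
      .run(data)
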
http://git-wip-us.apache.org/repos/asf/spark/blob/7450a992/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 9ebef84..caee591 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -90,6 +90,27 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
     assert(model.clusterCenters.size === 3)
   }
 
+  test("deterministic initialization") {
+    // Create a large-ish set of points for clustering
+    val points = List.tabulate(1000)(n => Vectors.dense(n, n))
+    val rdd = sc.parallelize(points, 3)
+
+    for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) {
+      // Create three deterministic models and compare cluster means
+      val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1,
+        initializationMode = initMode, seed = 42)
+      val centers1 = model1.clusterCenters
+
+      val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1,
+        initializationMode = initMode, seed = 42)
+      val centers2 = model2.clusterCenters
+
+      centers1.zip(centers2).foreach { case (c1, c2) =>
+        assert(c1 ~== c2 absTol 1E-14)
+      }
+    }
+  }
+
   test("single cluster with big dataset") {
     val smallData = Array(
       Vectors.dense(1.0, 2.0, 6.0),

http://git-wip-us.apache.org/repos/asf/spark/blob/7450a992/python/pyspark/mllib/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index e2492ee..6b713aa 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -78,10 +78,10 @@ class KMeansModel(object):
 class KMeans(object):
 
     @classmethod
-    def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"):
+    def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None):
         """Train a k-means clustering model."""
         model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations,
-                              runs, initializationMode)
+                              runs, initializationMode, seed)
         centers = callJavaFunc(rdd.context, model.clusterCenters)
         return KMeansModel([c.toArray() for c in centers])
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7450a992/python/pyspark/mllib/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 140c22b..f48e3d6 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -140,7 +140,7 @@ class ListTests(PySparkTestCase):
     as NumPy arrays.
     """
 
-    def test_clustering(self):
+    def test_kmeans(self):
         from pyspark.mllib.clustering import KMeans
         data = [
             [0, 1.1],
@@ -152,6 +152,21 @@ class ListTests(PySparkTestCase):
         self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1]))
         self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3]))
 
+    def test_kmeans_deterministic(self):
+        from pyspark.mllib.clustering import KMeans
+        X = range(0, 100, 10)
+        Y = range(0, 100, 10)
+        data = [[x, y] for x, y in zip(X, Y)]
+        clusters1 = KMeans.train(self.sc.parallelize(data),
+                                 3, initializationMode="k-means||", seed=42)
+        clusters2 = KMeans.train(self.sc.parallelize(data),
+                                 3, initializationMode="k-means||", seed=42)
+        centers1 = clusters1.centers
+        centers2 = clusters2.centers
+        for c1, c2 in zip(centers1, centers2):
+            # TODO: Allow small numeric difference.
+            self.assertTrue(array_equal(c1, c2))
+
     def test_classification(self):
         from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
         from pyspark.mllib.tree import DecisionTree

