Repository: spark
Updated Branches:
  refs/heads/master 1a2655a9e -> ca8243f30


[MINOR][ML] Minor correction in the powerIterationSuite

## What changes were proposed in this pull request?

Currently, the power iteration clustering test in Spark ML maps the results to 
the fixed labels 0 and 1 for the assertion. Since the cluster labels produced by 
the algorithm need not match these mapped labels, the test case may fail. Even 
when they happen to match, we cannot guarantee in theory which set receives which 
cluster label: KMeans can assign label 0 to either set.
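To make the failure mode concrete, here is a minimal plain-Scala sketch (no Spark involved; the object name `LabelSwapDemo` and the sizes `n1 = 3`, `n = 6` are hypothetical). It encodes the old test's fixed-label expectation and compares it against two labelings of the same partition:

```scala
// Hypothetical illustration, not Spark code: the same two-way partition of
// ids 0..5 expressed with both possible label assignments.
object LabelSwapDemo {
  val n1 = 3
  val n = 6

  // The old test's expectation: ids [0, n1) carry label 1, ids [n1, n) label 0.
  val expected: Set[(Long, Int)] =
    (0 until n1).map(x => (x.toLong, 1)).toSet ++
      (n1 until n).map(x => (x.toLong, 0)).toSet

  // Two equally valid clusterings of the same partition; only label names differ.
  val labelingA: Set[(Long, Int)] = expected
  val labelingB: Set[(Long, Int)] = expected.map { case (id, c) => (id, 1 - c) }

  def main(args: Array[String]): Unit = {
    println(labelingA == expected) // true
    println(labelingB == expected) // false, although the partition is identical
  }
}
```

Both labelings describe the same two clusters, so a correct test should accept either one.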

PowerIterationClusteringSuite in MLlib checks the clustering results without 
mapping them to particular cluster labels, as shown below:
```scala
val predictions = Array.fill(2)(mutable.Set.empty[Long])
model.assignments.collect().foreach { a =>
  predictions(a.cluster) += a.id
}
assert(predictions.toSet == Set((0 until n1).toSet, (n1 until n).toSet))
```
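The same idea can be written as a small label-agnostic helper. This is a sketch under assumptions: `PartitionCheck` and `partitionOf` are hypothetical names, not Spark APIs, and the input is the collected `(id, cluster)` pairs rather than a DataFrame:

```scala
// Hypothetical helper (not part of Spark): compares a clustering against an
// expected partition while ignoring which integer label each cluster receives.
object PartitionCheck {
  /** Groups (id, cluster) pairs into one Set of ids per cluster, then drops the labels. */
  def partitionOf(assignments: Seq[(Long, Int)]): Set[Set[Long]] =
    assignments
      .groupBy { case (_, cluster) => cluster }
      .values
      .map(_.map { case (id, _) => id }.toSet)
      .toSet

  def main(args: Array[String]): Unit = {
    val n1 = 3
    val n = 6
    val expected: Set[Set[Long]] =
      Set((0 until n1).map(_.toLong).toSet, (n1 until n).map(_.toLong).toSet)

    // Same partition, opposite label choices.
    val labelingA: Seq[(Long, Int)] = (0 until n).map(i => (i.toLong, if (i < n1) 1 else 0))
    val labelingB: Seq[(Long, Int)] = labelingA.map { case (id, c) => (id, 1 - c) }

    println(partitionOf(labelingA) == expected) // true
    println(partitionOf(labelingB) == expected) // true
  }
}
```

Unlike `Array.fill(2)`, grouping with `groupBy` does not need to know k in advance and cannot index out of bounds on an unexpected label. Note that it omits empty clusters, whereas the array-based version keeps an empty set for them.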

## How was this patch tested?
Existing tests

Author: Shahid <shahidk...@gmail.com>

Closes #21689 from shahidki31/picTestSuiteMinorCorrection.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ca8243f3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ca8243f3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ca8243f3

Branch: refs/heads/master
Commit: ca8243f30fc6939ee099a9534e3b811d5c64d2cf
Parents: 1a2655a
Author: Shahid <shahidk...@gmail.com>
Authored: Wed Jul 4 09:56:24 2018 -0500
Committer: Sean Owen <sro...@gmail.com>
Committed: Wed Jul 4 09:56:24 2018 -0500

----------------------------------------------------------------------
 .../PowerIterationClusteringSuite.scala         | 30 +++++++++++++-------
 1 file changed, 20 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ca8243f3/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
index b707272..55b460f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml.clustering
 
+import scala.collection.mutable
+
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -76,12 +78,15 @@ class PowerIterationClusteringSuite extends SparkFunSuite
       .setMaxIter(40)
       .setWeightCol("weight")
       .assignClusters(data)
-    val localAssignments = assignments
-      .select('id, 'cluster)
-      .as[(Long, Int)].collect().toSet
-    val expectedResult = (0 until n1).map(x => (x, 1)).toSet ++
-      (n1 until n).map(x => (x, 0)).toSet
-    assert(localAssignments === expectedResult)
+      .select("id", "cluster")
+      .as[(Long, Int)]
+      .collect()
+
+    val predictions = Array.fill(2)(mutable.Set.empty[Long])
+    assignments.foreach {
+      case (id, cluster) => predictions(cluster) += id
+    }
+    assert(predictions.toSet === Set((0 until n1).toSet, (n1 until n).toSet))
 
     val assignments2 = new PowerIterationClustering()
       .setK(2)
@@ -89,10 +94,15 @@ class PowerIterationClusteringSuite extends SparkFunSuite
       .setInitMode("degree")
       .setWeightCol("weight")
       .assignClusters(data)
-    val localAssignments2 = assignments2
-      .select('id, 'cluster)
-      .as[(Long, Int)].collect().toSet
-    assert(localAssignments2 === expectedResult)
+      .select("id", "cluster")
+      .as[(Long, Int)]
+      .collect()
+
+    val predictions2 = Array.fill(2)(mutable.Set.empty[Long])
+    assignments2.foreach {
+      case (id, cluster) => predictions2(cluster) += id
+    }
+    assert(predictions2.toSet === Set((0 until n1).toSet, (n1 until n).toSet))
   }
 
   test("supported input types") {

