Repository: spark
Updated Branches:
  refs/heads/master 007da1a9d -> 95eb65163


[SPARK-11945][ML][PYSPARK] Add computeCost to KMeansModel for PySpark spark.ml

Add ```computeCost``` to ```KMeansModel``` as an evaluator for PySpark spark.ml.

Author: Yanbo Liang <yblia...@gmail.com>

Closes #9931 from yanboliang/SPARK-11945.
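
For reference, a minimal end-to-end usage sketch of ```computeCost``` (assuming Spark 2.0+ with an active ```SparkSession``` named ```spark```; the toy data and ```k=2``` are illustrative and not part of this patch):

```python
# Sketch only: toy data and parameters are assumptions, not from the patch.
from pyspark.ml.clustering import KMeans
from pyspark.ml.linalg import Vectors

# Two well-separated groups of 2-D points.
df = spark.createDataFrame(
    [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
     (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)],
    ["features"])

model = KMeans(k=2, seed=1).fit(df)

# Sum of squared distances of points to their nearest cluster center.
print(model.computeCost(df))
```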


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/95eb6516
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/95eb6516
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/95eb6516

Branch: refs/heads/master
Commit: 95eb65163391b9e910277a948b72efccf6136e0c
Parents: 007da1a
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Wed Jan 6 10:50:02 2016 -0800
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Wed Jan 6 10:50:02 2016 -0800

----------------------------------------------------------------------
 python/pyspark/ml/clustering.py | 10 ++++++++++
 1 file changed, 10 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/95eb6516/python/pyspark/ml/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 7bb8ab9..9189c02 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -36,6 +36,14 @@ class KMeansModel(JavaModel):
         """Get the cluster centers, represented as a list of NumPy arrays."""
         return [c.toArray() for c in self._call_java("clusterCenters")]
 
+    @since("2.0.0")
+    def computeCost(self, dataset):
+        """
+        Return the K-means cost (sum of squared distances of points to their nearest center)
+        for this model on the given data.
+        """
+        return self._call_java("computeCost", dataset)
+
 
 @inherit_doc
 class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed):
@@ -53,6 +61,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
     >>> centers = model.clusterCenters()
     >>> len(centers)
     2
+    >>> model.computeCost(df)
+    2.000...
     >>> transformed = model.transform(df).select("features", "prediction")
     >>> rows = transformed.collect()
     >>> rows[0].prediction == rows[1].prediction


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
