Repository: spark
Updated Branches:
  refs/heads/master 007da1a9d -> 95eb65163
[SPARK-11945][ML][PYSPARK] Add computeCost to KMeansModel for PySpark spark.ml

Add ```computeCost``` to ```KMeansModel``` as evaluator for PySpark spark.ml.

Author: Yanbo Liang <yblia...@gmail.com>

Closes #9931 from yanboliang/SPARK-11945.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/95eb6516
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/95eb6516
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/95eb6516

Branch: refs/heads/master
Commit: 95eb65163391b9e910277a948b72efccf6136e0c
Parents: 007da1a
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Wed Jan 6 10:50:02 2016 -0800
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Wed Jan 6 10:50:02 2016 -0800

----------------------------------------------------------------------
 python/pyspark/ml/clustering.py | 10 ++++++++++
 1 file changed, 10 insertions(+)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/95eb6516/python/pyspark/ml/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 7bb8ab9..9189c02 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -36,6 +36,14 @@ class KMeansModel(JavaModel):
         """Get the cluster centers, represented as a list of NumPy arrays."""
         return [c.toArray() for c in self._call_java("clusterCenters")]
 
+    @since("2.0.0")
+    def computeCost(self, dataset):
+        """
+        Return the K-means cost (sum of squared distances of points to their nearest center)
+        for this model on the given data.
+        """
+        return self._call_java("computeCost", dataset)
+
 
 @inherit_doc
 class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed):
@@ -53,6 +61,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
     >>> centers = model.clusterCenters()
     >>> len(centers)
     2
+    >>> model.computeCost(df)
+    2.000...
     >>> transformed = model.transform(df).select("features", "prediction")
     >>> rows = transformed.collect()
     >>> rows[0].prediction == rows[1].prediction

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org