This is an automated email from the ASF dual-hosted git repository. srowen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 06792af  [SPARK-16838][PYTHON] Add PMML export for ML KMeans in PySpark
06792af is described below

commit 06792afd4c9c719df4af34b0768a999271383330
Author: Huaxin Gao <huax...@us.ibm.com>
AuthorDate: Tue Jan 22 09:34:59 2019 -0600

    [SPARK-16838][PYTHON] Add PMML export for ML KMeans in PySpark

    ## What changes were proposed in this pull request?

    Add PMML export support for ML KMeans to PySpark.

    ## How was this patch tested?

    Add tests in ml.tests.PersistenceTest.

    Closes #23592 from huaxingao/spark-16838.

    Authored-by: Huaxin Gao <huax...@us.ibm.com>
    Signed-off-by: Sean Owen <sean.o...@databricks.com>
---
 python/pyspark/ml/clustering.py             |  2 +-
 python/pyspark/ml/tests/test_persistence.py | 37 +++++++++++++++++++++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 5a776ae..b9c6bdf 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -323,7 +323,7 @@ class KMeansSummary(ClusteringSummary):
         return self._call_java("trainingCost")


-class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
+class KMeansModel(JavaModel, GeneralJavaMLWritable, JavaMLReadable):
     """
     Model fitted by KMeans.

diff --git a/python/pyspark/ml/tests/test_persistence.py b/python/pyspark/ml/tests/test_persistence.py
index 34d6870..63b0594 100644
--- a/python/pyspark/ml/tests/test_persistence.py
+++ b/python/pyspark/ml/tests/test_persistence.py
@@ -23,6 +23,7 @@ import unittest
 from pyspark.ml import Transformer
 from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, OneVsRest, \
     OneVsRestModel
+from pyspark.ml.clustering import KMeans
 from pyspark.ml.feature import Binarizer, HashingTF, PCA
 from pyspark.ml.linalg import Vectors
 from pyspark.ml.param import Params
@@ -89,6 +90,42 @@ class PersistenceTest(SparkSessionTestCase):
         except OSError:
             pass

+    def test_kmeans(self):
+        kmeans = KMeans(k=2, seed=1)
+        path = tempfile.mkdtemp()
+        km_path = path + "/km"
+        kmeans.save(km_path)
+        kmeans2 = KMeans.load(km_path)
+        self.assertEqual(kmeans.uid, kmeans2.uid)
+        self.assertEqual(type(kmeans.uid), type(kmeans2.uid))
+        self.assertEqual(kmeans2.uid, kmeans2.k.parent,
+                         "Loaded KMeans instance uid (%s) did not match Param's uid (%s)"
+                         % (kmeans2.uid, kmeans2.k.parent))
+        self.assertEqual(kmeans._defaultParamMap[kmeans.k], kmeans2._defaultParamMap[kmeans2.k],
+                         "Loaded KMeans instance default params did not match " +
+                         "original defaults")
+        try:
+            rmtree(path)
+        except OSError:
+            pass
+
+    def test_kmean_pmml_basic(self):
+        # Most of the validation is done in the Scala side, here we just check
+        # that we output text rather than parquet (e.g. that the format flag
+        # was respected).
+        data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
+                (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
+        df = self.spark.createDataFrame(data, ["features"])
+        kmeans = KMeans(k=2, seed=1)
+        model = kmeans.fit(df)
+        path = tempfile.mkdtemp()
+        km_path = path + "/km-pmml"
+        model.write().format("pmml").save(km_path)
+        pmml_text_list = self.sc.textFile(km_path).collect()
+        pmml_text = "\n".join(pmml_text_list)
+        self.assertIn("Apache Spark", pmml_text)
+        self.assertIn("PMML", pmml_text)
+
     def _compare_params(self, m1, m2, param):
         """
         Compare 2 ML Params instances for the given param, and assert both have the same param value


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org