Repository: spark Updated Branches: refs/heads/master 524827f06 -> a95a4af76
[SPARK-23120][PYSPARK][ML] Add basic PMML export support to PySpark ## What changes were proposed in this pull request? Adds basic PMML export support for Spark ML stages to PySpark as was previously done in Scala. Includes LinearRegressionModel as the first stage to implement. ## How was this patch tested? Doctest; the main testing work for this is on the Scala side. (TODO holden: add the unittest once I finish locally.) Author: Holden Karau <hol...@pigscanfly.ca> Closes #21172 from holdenk/SPARK-23120-add-pmml-export-support-to-pyspark. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a95a4af7 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a95a4af7 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a95a4af7 Branch: refs/heads/master Commit: a95a4af76459016b0d52df90adab68a49904da99 Parents: 524827f Author: Holden Karau <hol...@pigscanfly.ca> Authored: Thu Jun 28 13:20:08 2018 -0700 Committer: Holden Karau <hol...@pigscanfly.ca> Committed: Thu Jun 28 13:20:08 2018 -0700 ---------------------------------------------------------------------- python/pyspark/ml/regression.py | 3 ++- python/pyspark/ml/tests.py | 17 +++++++++++++ python/pyspark/ml/util.py | 46 ++++++++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/a95a4af7/python/pyspark/ml/regression.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index dba0e57..83f0edb 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -95,6 +95,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction True >>> model.numFeatures 1 + >>> model.write().format("pmml").save(model_path + "_2") .. 
versionadded:: 1.4.0 """ @@ -161,7 +162,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction return self.getOrDefault(self.epsilon) -class LinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): +class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable): """ Model fitted by :class:`LinearRegression`. http://git-wip-us.apache.org/repos/asf/spark/blob/a95a4af7/python/pyspark/ml/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index ebd36cb..bc78213 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1362,6 +1362,23 @@ class PersistenceTest(SparkSessionTestCase): except OSError: pass + def test_linear_regression_pmml_basic(self): + # Most of the validation is done in the Scala side, here we just check + # that we output text rather than parquet (e.g. that the format flag + # was respected). 
+ df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=1) + model = lr.fit(df) + path = tempfile.mkdtemp() + lr_path = path + "/lr-pmml" + model.write().format("pmml").save(lr_path) + pmml_text_list = self.sc.textFile(lr_path).collect() + pmml_text = "\n".join(pmml_text_list) + self.assertIn("Apache Spark", pmml_text) + self.assertIn("PMML", pmml_text) + def test_logistic_regression(self): lr = LogisticRegression(maxIter=1) path = tempfile.mkdtemp() http://git-wip-us.apache.org/repos/asf/spark/blob/a95a4af7/python/pyspark/ml/util.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 9fa8566..080cd299 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -149,6 +149,23 @@ class MLWriter(BaseReadWrite): @inherit_doc +class GeneralMLWriter(MLWriter): + """ + Utility class that can save ML instances in different formats. + + .. versionadded:: 2.4.0 + """ + + def format(self, source): + """ + Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class + name for export). + """ + self.source = source + return self + + +@inherit_doc class JavaMLWriter(MLWriter): """ (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types @@ -193,6 +210,24 @@ class JavaMLWriter(MLWriter): @inherit_doc +class GeneralJavaMLWriter(JavaMLWriter): + """ + (Private) Specialization of :py:class:`GeneralMLWriter` for :py:class:`JavaParams` types + """ + + def __init__(self, instance): + super(GeneralJavaMLWriter, self).__init__(instance) + + def format(self, source): + """ + Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class + name for export). 
+ """ + self._jwrite.format(source) + return self + + +@inherit_doc class MLWritable(object): """ Mixin for ML instances that provide :py:class:`MLWriter`. @@ -221,6 +256,17 @@ class JavaMLWritable(MLWritable): @inherit_doc +class GeneralJavaMLWritable(JavaMLWritable): + """ + (Private) Mixin for ML instances that provide :py:class:`GeneralJavaMLWriter`. + """ + + def write(self): + """Returns an GeneralMLWriter instance for this ML instance.""" + return GeneralJavaMLWriter(self) + + +@inherit_doc class MLReader(BaseReadWrite): """ Utility class that can load ML instances. --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org