Repository: spark
Updated Branches:
  refs/heads/master 524827f06 -> a95a4af76


[SPARK-23120][PYSPARK][ML] Add basic PMML export support to PySpark

## What changes were proposed in this pull request?

Adds basic PMML export support for Spark ML stages to PySpark, as was previously 
done in Scala. LinearRegressionModel is the first stage to implement it.

## How was this patch tested?

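Doctest; the main testing work for this is on the Scala side. A basic unit test 
(`test_linear_regression_pmml_basic` in python/pyspark/ml/tests.py) is also 
included to check that the "pmml" format flag is respected.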

Author: Holden Karau <hol...@pigscanfly.ca>

Closes #21172 from holdenk/SPARK-23120-add-pmml-export-support-to-pyspark.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a95a4af7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a95a4af7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a95a4af7

Branch: refs/heads/master
Commit: a95a4af76459016b0d52df90adab68a49904da99
Parents: 524827f
Author: Holden Karau <hol...@pigscanfly.ca>
Authored: Thu Jun 28 13:20:08 2018 -0700
Committer: Holden Karau <hol...@pigscanfly.ca>
Committed: Thu Jun 28 13:20:08 2018 -0700

----------------------------------------------------------------------
 python/pyspark/ml/regression.py |  3 ++-
 python/pyspark/ml/tests.py      | 17 +++++++++++++
 python/pyspark/ml/util.py       | 46 ++++++++++++++++++++++++++++++++++++
 3 files changed, 65 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a95a4af7/python/pyspark/ml/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index dba0e57..83f0edb 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -95,6 +95,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
     True
     >>> model.numFeatures
     1
+    >>> model.write().format("pmml").save(model_path + "_2")
 
     .. versionadded:: 1.4.0
     """
@@ -161,7 +162,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
         return self.getOrDefault(self.epsilon)
 
 
-class LinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, 
JavaMLReadable):
+class LinearRegressionModel(JavaModel, JavaPredictionModel, 
GeneralJavaMLWritable, JavaMLReadable):
     """
     Model fitted by :class:`LinearRegression`.
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a95a4af7/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index ebd36cb..bc78213 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1362,6 +1362,23 @@ class PersistenceTest(SparkSessionTestCase):
         except OSError:
             pass
 
+    def test_linear_regression_pmml_basic(self):
+        # Most of the validation is done in the Scala side, here we just check
+        # that we output text rather than parquet (e.g. that the format flag
+        # was respected).
+        df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+                                         (0.0, 2.0, Vectors.sparse(1, [], 
[]))],
+                                        ["label", "weight", "features"])
+        lr = LinearRegression(maxIter=1)
+        model = lr.fit(df)
+        path = tempfile.mkdtemp()
+        lr_path = path + "/lr-pmml"
+        model.write().format("pmml").save(lr_path)
+        pmml_text_list = self.sc.textFile(lr_path).collect()
+        pmml_text = "\n".join(pmml_text_list)
+        self.assertIn("Apache Spark", pmml_text)
+        self.assertIn("PMML", pmml_text)
+
     def test_logistic_regression(self):
         lr = LogisticRegression(maxIter=1)
         path = tempfile.mkdtemp()

http://git-wip-us.apache.org/repos/asf/spark/blob/a95a4af7/python/pyspark/ml/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 9fa8566..080cd299 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -149,6 +149,23 @@ class MLWriter(BaseReadWrite):
 
 
 @inherit_doc
+class GeneralMLWriter(MLWriter):
+    """
+    Utility class that can save ML instances in different formats.
+
+    .. versionadded:: 2.4.0
+    """
+
+    def format(self, source):
+        """
+        Specifies the format of ML export (e.g. "pmml", "internal", or the 
fully qualified class
+        name for export).
+        """
+        self.source = source
+        return self
+
+
+@inherit_doc
 class JavaMLWriter(MLWriter):
     """
     (Private) Specialization of :py:class:`MLWriter` for 
:py:class:`JavaParams` types
@@ -193,6 +210,24 @@ class JavaMLWriter(MLWriter):
 
 
 @inherit_doc
+class GeneralJavaMLWriter(JavaMLWriter):
+    """
+    (Private) Specialization of :py:class:`GeneralMLWriter` for 
:py:class:`JavaParams` types
+    """
+
+    def __init__(self, instance):
+        super(GeneralJavaMLWriter, self).__init__(instance)
+
+    def format(self, source):
+        """
+        Specifies the format of ML export (e.g. "pmml", "internal", or the 
fully qualified class
+        name for export).
+        """
+        self._jwrite.format(source)
+        return self
+
+
+@inherit_doc
 class MLWritable(object):
     """
     Mixin for ML instances that provide :py:class:`MLWriter`.
@@ -221,6 +256,17 @@ class JavaMLWritable(MLWritable):
 
 
 @inherit_doc
+class GeneralJavaMLWritable(JavaMLWritable):
+    """
+    (Private) Mixin for ML instances that provide 
:py:class:`GeneralJavaMLWriter`.
+    """
+
+    def write(self):
+        """Returns an GeneralMLWriter instance for this ML instance."""
+        return GeneralJavaMLWriter(self)
+
+
+@inherit_doc
 class MLReader(BaseReadWrite):
     """
     Utility class that can load ML instances.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to