Repository: spark
Updated Branches:
  refs/heads/master 775cf17ea -> b64482f49


[SPARK-14306][ML][PYSPARK] PySpark ml.classification OneVsRest support 
export/import

## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-14306

Add save/load (ML persistence) support for PySpark `OneVsRest` and `OneVsRestModel`.

## How was this patch tested?

Tested with a new Python unit test, `OneVsRestTests.test_save_load` in `python/pyspark/ml/tests.py`.

Author: Xusen Yin <yinxu...@gmail.com>

Closes #12439 from yinxusen/SPARK-14306-0415.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b64482f4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b64482f4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b64482f4

Branch: refs/heads/master
Commit: b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a
Parents: 775cf17
Author: Xusen Yin <yinxu...@gmail.com>
Authored: Mon Apr 18 11:52:29 2016 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Mon Apr 18 11:52:29 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/classification/OneVsRest.scala     |   7 +
 python/pyspark/ml/classification.py             | 142 ++++++++++++++++---
 python/pyspark/ml/tests.py                      |  25 +++-
 3 files changed, 151 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b64482f4/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 4de1b87..f10c60a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -17,8 +17,10 @@
 
 package org.apache.spark.ml.classification
 
+import java.util.{List => JList}
 import java.util.UUID
 
+import scala.collection.JavaConverters._
 import scala.language.existentials
 
 import org.apache.hadoop.fs.Path
@@ -135,6 +137,11 @@ final class OneVsRestModel private[ml] (
     @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
   extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
 
+  /** A Python-friendly auxiliary constructor. */
+  private[ml] def this(uid: String, models: JList[_ <: ClassificationModel[_, 
_]]) = {
+    this(uid, Metadata.empty, models.asScala.toArray)
+  }
+
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema, fitting = false, 
getClassifier.featuresDataType)

http://git-wip-us.apache.org/repos/asf/spark/blob/b64482f4/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py 
b/python/pyspark/ml/classification.py
index 0893167..de1321b 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -23,7 +23,7 @@ from pyspark.ml.param.shared import *
 from pyspark.ml.regression import (
     RandomForestParams, TreeEnsembleParams, DecisionTreeModel, 
TreeEnsembleModels)
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
 from pyspark.ml.wrapper import JavaWrapper
 from pyspark.mllib.common import inherit_doc
 from pyspark.sql import DataFrame
@@ -1160,8 +1160,33 @@ class MultilayerPerceptronClassificationModel(JavaModel, 
JavaMLWritable, JavaMLR
         return self._call_java("weights")
 
 
+class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol):
+    """
+    Parameters for OneVsRest and OneVsRestModel.
+    """
+
+    classifier = Param(Params._dummy(), "classifier", "base binary classifier")
+
+    @since("2.0.0")
+    def setClassifier(self, value):
+        """
+        Sets the value of :py:attr:`classifier`.
+
+        .. note:: Only LogisticRegression and NaiveBayes are supported now.
+        """
+        self._set(classifier=value)
+        return self
+
+    @since("2.0.0")
+    def getClassifier(self):
+        """
+        Gets the value of classifier or its default value.
+        """
+        return self.getOrDefault(self.classifier)
+
+
 @inherit_doc
-class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol):
+class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
     """
     Reduction of Multiclass Classification to Binary Classification.
     Performs reduction using one against all strategy.
@@ -1195,8 +1220,6 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, 
HasPredictionCol):
     .. versionadded:: 2.0.0
     """
 
-    classifier = Param(Params._dummy(), "classifier", "base binary classifier")
-
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", 
predictionCol="prediction",
                  classifier=None):
@@ -1218,23 +1241,6 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, 
HasPredictionCol):
         kwargs = self.setParams._input_kwargs
         return self._set(**kwargs)
 
-    @since("2.0.0")
-    def setClassifier(self, value):
-        """
-        Sets the value of :py:attr:`classifier`.
-
-        .. note:: Only LogisticRegression and NaiveBayes are supported now.
-        """
-        self._set(classifier=value)
-        return self
-
-    @since("2.0.0")
-    def getClassifier(self):
-        """
-        Gets the value of classifier or its default value.
-        """
-        return self.getOrDefault(self.classifier)
-
     def _fit(self, dataset):
         labelCol = self.getLabelCol()
         featuresCol = self.getFeaturesCol()
@@ -1288,8 +1294,53 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, 
HasPredictionCol):
             newOvr.setClassifier(self.getClassifier().copy(extra))
         return newOvr
 
+    @since("2.0.0")
+    def write(self):
+        """Returns an MLWriter instance for this ML instance."""
+        return JavaMLWriter(self)
+
+    @since("2.0.0")
+    def save(self, path):
+        """Save this ML instance to the given path, a shortcut of 
`write().save(path)`."""
+        self.write().save(path)
+
+    @classmethod
+    @since("2.0.0")
+    def read(cls):
+        """Returns an MLReader instance for this class."""
+        return JavaMLReader(cls)
+
+    @classmethod
+    def _from_java(cls, java_stage):
+        """
+        Given a Java OneVsRest, create and return a Python wrapper of it.
+        Used for ML persistence.
+        """
+        featuresCol = java_stage.getFeaturesCol()
+        labelCol = java_stage.getLabelCol()
+        predictionCol = java_stage.getPredictionCol()
+        classifier = JavaParams._from_java(java_stage.getClassifier())
+        py_stage = cls(featuresCol=featuresCol, labelCol=labelCol, 
predictionCol=predictionCol,
+                       classifier=classifier)
+        py_stage._resetUid(java_stage.uid())
+        return py_stage
+
+    def _to_java(self):
+        """
+        Transfer this instance to a Java OneVsRest. Used for ML persistence.
+
+        :return: Java object equivalent to this instance.
+        """
+        _java_obj = 
JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest",
+                                             self.uid)
+        _java_obj.setClassifier(self.getClassifier()._to_java())
+        _java_obj.setFeaturesCol(self.getFeaturesCol())
+        _java_obj.setLabelCol(self.getLabelCol())
+        _java_obj.setPredictionCol(self.getPredictionCol())
+        return _java_obj
 
-class OneVsRestModel(Model, HasFeaturesCol, HasLabelCol, HasPredictionCol):
+
+class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable):
     """
     Model fitted by OneVsRest.
     This stores the models resulting from training k binary classifiers: one 
for each class.
@@ -1367,6 +1418,53 @@ class OneVsRestModel(Model, HasFeaturesCol, HasLabelCol, 
HasPredictionCol):
         newModel.models = [model.copy(extra) for model in self.models]
         return newModel
 
+    @since("2.0.0")
+    def write(self):
+        """Returns an MLWriter instance for this ML instance."""
+        return JavaMLWriter(self)
+
+    @since("2.0.0")
+    def save(self, path):
+        """Save this ML instance to the given path, a shortcut of 
`write().save(path)`."""
+        self.write().save(path)
+
+    @classmethod
+    @since("2.0.0")
+    def read(cls):
+        """Returns an MLReader instance for this class."""
+        return JavaMLReader(cls)
+
+    @classmethod
+    def _from_java(cls, java_stage):
+        """
+        Given a Java OneVsRestModel, create and return a Python wrapper of it.
+        Used for ML persistence.
+        """
+        featuresCol = java_stage.getFeaturesCol()
+        labelCol = java_stage.getLabelCol()
+        predictionCol = java_stage.getPredictionCol()
+        classifier = JavaParams._from_java(java_stage.getClassifier())
+        models = [JavaParams._from_java(model) for model in 
java_stage.models()]
+        py_stage = 
cls(models=models).setPredictionCol(predictionCol).setLabelCol(labelCol)\
+            .setFeaturesCol(featuresCol).setClassifier(classifier)
+        py_stage._resetUid(java_stage.uid())
+        return py_stage
+
+    def _to_java(self):
+        """
+        Transfer this instance to a Java OneVsRestModel. Used for ML 
persistence.
+
+        :return: Java object equivalent to this instance.
+        """
+        java_models = [model._to_java() for model in self.models]
+        _java_obj = 
JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel",
+                                             self.uid, java_models)
+        _java_obj.set("classifier", self.getClassifier()._to_java())
+        _java_obj.set("featuresCol", self.getFeaturesCol())
+        _java_obj.set("labelCol", self.getLabelCol())
+        _java_obj.set("predictionCol", self.getPredictionCol())
+        return _java_obj
+
 
 if __name__ == "__main__":
     import doctest

http://git-wip-us.apache.org/repos/asf/spark/blob/b64482f4/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index a7a9868..9d6ff47 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -43,7 +43,8 @@ import tempfile
 import numpy as np
 
 from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
-from pyspark.ml.classification import LogisticRegression, 
DecisionTreeClassifier, OneVsRest
+from pyspark.ml.classification import (
+    LogisticRegression, DecisionTreeClassifier, OneVsRest, OneVsRestModel)
 from pyspark.ml.clustering import KMeans
 from pyspark.ml.evaluation import BinaryClassificationEvaluator, 
RegressionEvaluator
 from pyspark.ml.feature import *
@@ -881,6 +882,28 @@ class OneVsRestTests(PySparkTestCase):
         output = model.transform(df)
         self.assertEqual(output.columns, ["label", "features", "prediction"])
 
+    def test_save_load(self):
+        temp_path = tempfile.mkdtemp()
+        sqlContext = SQLContext(self.sc)
+        df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
+                                         (1.0, Vectors.sparse(2, [], [])),
+                                         (2.0, Vectors.dense(0.5, 0.5))],
+                                        ["label", "features"])
+        lr = LogisticRegression(maxIter=5, regParam=0.01)
+        ovr = OneVsRest(classifier=lr)
+        model = ovr.fit(df)
+        ovrPath = temp_path + "/ovr"
+        ovr.save(ovrPath)
+        loadedOvr = OneVsRest.load(ovrPath)
+        self.assertEqual(loadedOvr.getFeaturesCol(), ovr.getFeaturesCol())
+        self.assertEqual(loadedOvr.getLabelCol(), ovr.getLabelCol())
+        self.assertEqual(loadedOvr.getClassifier().uid, 
ovr.getClassifier().uid)
+        modelPath = temp_path + "/ovrModel"
+        model.save(modelPath)
+        loadedModel = OneVsRestModel.load(modelPath)
+        for m, n in zip(model.models, loadedModel.models):
+            self.assertEqual(m.uid, n.uid)
+
 
 class HashingTFTest(PySparkTestCase):
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to