Repository: spark Updated Branches: refs/heads/master 775cf17ea -> b64482f49
[SPARK-14306][ML][PYSPARK] PySpark ml.classification OneVsRest support export/import ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-14306 Add PySpark OneVsRest save/load supports. ## How was this patch tested? Test with Python unit test. Author: Xusen Yin <yinxu...@gmail.com> Closes #12439 from yinxusen/SPARK-14306-0415. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b64482f4 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b64482f4 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b64482f4 Branch: refs/heads/master Commit: b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a Parents: 775cf17 Author: Xusen Yin <yinxu...@gmail.com> Authored: Mon Apr 18 11:52:29 2016 -0700 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Mon Apr 18 11:52:29 2016 -0700 ---------------------------------------------------------------------- .../spark/ml/classification/OneVsRest.scala | 7 + python/pyspark/ml/classification.py | 142 ++++++++++++++++--- python/pyspark/ml/tests.py | 25 +++- 3 files changed, 151 insertions(+), 23 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/b64482f4/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 4de1b87..f10c60a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -17,8 +17,10 @@ package org.apache.spark.ml.classification +import java.util.{List => JList} import java.util.UUID +import scala.collection.JavaConverters._ import scala.language.existentials import org.apache.hadoop.fs.Path @@ -135,6 +137,11 @@ final class OneVsRestModel private[ml] ( @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]]) extends Model[OneVsRestModel] with OneVsRestParams with MLWritable { + /** A Python-friendly auxiliary constructor. */ + private[ml] def this(uid: String, models: JList[_ <: ClassificationModel[_, _]]) = { + this(uid, Metadata.empty, models.asScala.toArray) + } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) http://git-wip-us.apache.org/repos/asf/spark/blob/b64482f4/python/pyspark/ml/classification.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 0893167..de1321b 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -23,7 +23,7 @@ from pyspark.ml.param.shared import * from pyspark.ml.regression import ( RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels) from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams from pyspark.ml.wrapper import JavaWrapper from pyspark.mllib.common import inherit_doc from pyspark.sql import DataFrame @@ -1160,8 +1160,33 @@ class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLR return self._call_java("weights") +class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol): + """ + Parameters for OneVsRest and OneVsRestModel. + """ + + classifier = Param(Params._dummy(), "classifier", "base binary classifier") + + @since("2.0.0") + def setClassifier(self, value): + """ + Sets the value of :py:attr:`classifier`. + + .. note:: Only LogisticRegression and NaiveBayes are supported now. + """ + self._set(classifier=value) + return self + + @since("2.0.0") + def getClassifier(self): + """ + Gets the value of classifier or its default value. + """ + return self.getOrDefault(self.classifier) + + @inherit_doc -class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol): +class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): """ Reduction of Multiclass Classification to Binary Classification. Performs reduction using one against all strategy. @@ -1195,8 +1220,6 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol): .. versionadded:: 2.0.0 """ - classifier = Param(Params._dummy(), "classifier", "base binary classifier") - @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", classifier=None): @@ -1218,23 +1241,6 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol): kwargs = self.setParams._input_kwargs return self._set(**kwargs) - @since("2.0.0") - def setClassifier(self, value): - """ - Sets the value of :py:attr:`classifier`. - - .. note:: Only LogisticRegression and NaiveBayes are supported now. - """ - self._set(classifier=value) - return self - - @since("2.0.0") - def getClassifier(self): - """ - Gets the value of classifier or its default value. - """ - return self.getOrDefault(self.classifier) - def _fit(self, dataset): labelCol = self.getLabelCol() featuresCol = self.getFeaturesCol() @@ -1288,8 +1294,53 @@ class OneVsRest(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol): newOvr.setClassifier(self.getClassifier().copy(extra)) return newOvr + @since("2.0.0") + def write(self): + """Returns an MLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + @since("2.0.0") + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write().save(path) + + @classmethod + @since("2.0.0") + def read(cls): + """Returns an MLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java OneVsRest, create and return a Python wrapper of it. + Used for ML persistence. + """ + featuresCol = java_stage.getFeaturesCol() + labelCol = java_stage.getLabelCol() + predictionCol = java_stage.getPredictionCol() + classifier = JavaParams._from_java(java_stage.getClassifier()) + py_stage = cls(featuresCol=featuresCol, labelCol=labelCol, predictionCol=predictionCol, + classifier=classifier) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java OneVsRest. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest", + self.uid) + _java_obj.setClassifier(self.getClassifier()._to_java()) + _java_obj.setFeaturesCol(self.getFeaturesCol()) + _java_obj.setLabelCol(self.getLabelCol()) + _java_obj.setPredictionCol(self.getPredictionCol()) + return _java_obj -class OneVsRestModel(Model, HasFeaturesCol, HasLabelCol, HasPredictionCol): + +class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable): """ Model fitted by OneVsRest. This stores the models resulting from training k binary classifiers: one for each class. @@ -1367,6 +1418,53 @@ class OneVsRestModel(Model, HasFeaturesCol, HasLabelCol, HasPredictionCol): newModel.models = [model.copy(extra) for model in self.models] return newModel + @since("2.0.0") + def write(self): + """Returns an MLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + @since("2.0.0") + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write().save(path) + + @classmethod + @since("2.0.0") + def read(cls): + """Returns an MLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java OneVsRestModel, create and return a Python wrapper of it. + Used for ML persistence. + """ + featuresCol = java_stage.getFeaturesCol() + labelCol = java_stage.getLabelCol() + predictionCol = java_stage.getPredictionCol() + classifier = JavaParams._from_java(java_stage.getClassifier()) + models = [JavaParams._from_java(model) for model in java_stage.models()] + py_stage = cls(models=models).setPredictionCol(predictionCol).setLabelCol(labelCol)\ + .setFeaturesCol(featuresCol).setClassifier(classifier) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java OneVsRestModel. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + java_models = [model._to_java() for model in self.models] + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel", + self.uid, java_models) + _java_obj.set("classifier", self.getClassifier()._to_java()) + _java_obj.set("featuresCol", self.getFeaturesCol()) + _java_obj.set("labelCol", self.getLabelCol()) + _java_obj.set("predictionCol", self.getPredictionCol()) + return _java_obj + if __name__ == "__main__": import doctest http://git-wip-us.apache.org/repos/asf/spark/blob/b64482f4/python/pyspark/ml/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index a7a9868..9d6ff47 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -43,7 +43,8 @@ import tempfile import numpy as np from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer -from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier, OneVsRest +from pyspark.ml.classification import ( + LogisticRegression, DecisionTreeClassifier, OneVsRest, OneVsRestModel) from pyspark.ml.clustering import KMeans from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * @@ -881,6 +882,28 @@ class OneVsRestTests(PySparkTestCase): output = model.transform(df) self.assertEqual(output.columns, ["label", "features", "prediction"]) + def test_save_load(self): + temp_path = tempfile.mkdtemp() + sqlContext = SQLContext(self.sc) + df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + (1.0, Vectors.sparse(2, [], [])), + (2.0, Vectors.dense(0.5, 0.5))], + ["label", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr) + model = ovr.fit(df) + ovrPath = temp_path + "/ovr" + ovr.save(ovrPath) + loadedOvr = OneVsRest.load(ovrPath) + self.assertEqual(loadedOvr.getFeaturesCol(), ovr.getFeaturesCol()) + self.assertEqual(loadedOvr.getLabelCol(), ovr.getLabelCol()) + self.assertEqual(loadedOvr.getClassifier().uid, ovr.getClassifier().uid) + modelPath = temp_path + "/ovrModel" + model.save(modelPath) + loadedModel = OneVsRestModel.load(modelPath) + for m, n in zip(model.models, loadedModel.models): + self.assertEqual(m.uid, n.uid) + class HashingTFTest(PySparkTestCase): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org