Repository: spark
Updated Branches:
  refs/heads/master 46de6c05e -> b5bd75d90


[SPARK-6255] [MLLIB] Support multiclass classification in Python API

A Python API parity check for classification, including multiclass classification support, found major disparities; the following Scala APIs need to be added to Python:
```scala
LogisticRegressionWithLBFGS
    setNumClasses
    setValidateData
LogisticRegressionModel
    getThreshold
    numClasses
    numFeatures
SVMWithSGD
    setValidateData
SVMModel
    getThreshold
```
For users, the greatest benefit of this PR is that the Python API now supports multiclass classification: users can train a multiclass classification model and use it for prediction in PySpark, as the sketch below shows.
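
A minimal usage sketch, mirroring the doctest added in this PR (it assumes a live SparkContext `sc`; the expected predictions are taken from that doctest):
```python
from pyspark.mllib.classification import LogisticRegressionWithLBFGS
from pyspark.mllib.regression import LabeledPoint

# Three classes, one indicator feature per class.
multi_class_data = [
    LabeledPoint(0.0, [0.0, 1.0, 0.0]),
    LabeledPoint(1.0, [1.0, 0.0, 0.0]),
    LabeledPoint(2.0, [0.0, 0.0, 1.0]),
]
mcm = LogisticRegressionWithLBFGS.train(sc.parallelize(multi_class_data), numClasses=3)
mcm.predict([0.0, 0.5, 0.0])  # 0
mcm.predict([0.8, 0.0, 0.0])  # 1
mcm.predict([0.0, 0.0, 0.3])  # 2
```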

Author: Yanbo Liang <yblia...@gmail.com>

Closes #5137 from yanboliang/spark-6255 and squashes the following commits:

0bd531e [Yanbo Liang] address comments
444d5e2 [Yanbo Liang] LogisticRegressionModel.predict() optimization
fc7990b [Yanbo Liang] address comments
b0d9c63 [Yanbo Liang] Support multinomial LR model predict in Python API
ded847c [Yanbo Liang] Python API parity check for classification (support multiclass classification)


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b5bd75d9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b5bd75d9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b5bd75d9

Branch: refs/heads/master
Commit: b5bd75d90a761199c3f9cb583c1fe48c8fda7780
Parents: 46de6c0
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Tue Mar 31 11:32:14 2015 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Tue Mar 31 11:32:14 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/api/python/PythonMLLibAPI.scala |  22 ++-
 python/pyspark/mllib/classification.py          | 134 +++++++++++++++----
 python/pyspark/mllib/regression.py              |  10 +-
 3 files changed, 134 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b5bd75d9/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 22fa684..662ec5f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -77,7 +77,13 @@ private[python] class PythonMLLibAPI extends Serializable {
       initialWeights: Vector): JList[Object] = {
     try {
      val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights)
-      List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+      if (model.isInstanceOf[LogisticRegressionModel]) {
+        val lrModel = model.asInstanceOf[LogisticRegressionModel]
+        List(lrModel.weights, lrModel.intercept, lrModel.numFeatures, lrModel.numClasses)
+          .map(_.asInstanceOf[Object]).asJava
+      } else {
+        List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+      }
     } finally {
       data.rdd.unpersist(blocking = false)
     }
@@ -190,9 +196,11 @@ private[python] class PythonMLLibAPI extends Serializable {
       miniBatchFraction: Double,
       initialWeights: Vector,
       regType: String,
-      intercept: Boolean): JList[Object] = {
+      intercept: Boolean,
+      validateData: Boolean): JList[Object] = {
     val SVMAlg = new SVMWithSGD()
     SVMAlg.setIntercept(intercept)
+      .setValidateData(validateData)
     SVMAlg.optimizer
       .setNumIterations(numIterations)
       .setRegParam(regParam)
@@ -216,9 +224,11 @@ private[python] class PythonMLLibAPI extends Serializable {
       initialWeights: Vector,
       regParam: Double,
       regType: String,
-      intercept: Boolean): JList[Object] = {
+      intercept: Boolean,
+      validateData: Boolean): JList[Object] = {
     val LogRegAlg = new LogisticRegressionWithSGD()
     LogRegAlg.setIntercept(intercept)
+      .setValidateData(validateData)
     LogRegAlg.optimizer
       .setNumIterations(numIterations)
       .setRegParam(regParam)
@@ -242,9 +252,13 @@ private[python] class PythonMLLibAPI extends Serializable {
       regType: String,
       intercept: Boolean,
       corrections: Int,
-      tolerance: Double): JList[Object] = {
+      tolerance: Double,
+      validateData: Boolean,
+      numClasses: Int): JList[Object] = {
     val LogRegAlg = new LogisticRegressionWithLBFGS()
     LogRegAlg.setIntercept(intercept)
+      .setValidateData(validateData)
+      .setNumClasses(numClasses)
     LogRegAlg.optimizer
       .setNumIterations(numIterations)
       .setRegParam(regParam)

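As background for the Python changes below: for K classes, the multinomial model stores its coefficients as one flat vector that reshapes into K - 1 rows, with class 0 as the implicit reference class (margin 0.0). The following standalone numpy sketch of the intercept-free decision rule that `LogisticRegressionModel.predict` implements is illustrative only and is not part of this patch:
```python
import numpy as np

def predict_multiclass(x, weights_matrix):
    # Class 0 is the reference class with margin 0.0; each row of
    # weights_matrix scores one of the remaining classes, and the
    # largest positive margin wins.
    best_class, max_margin = 0, 0.0
    for i, row in enumerate(weights_matrix):
        margin = np.dot(x, row)
        if margin > max_margin:
            max_margin = margin
            best_class = i + 1
    return best_class

# 3 classes, 3 features, no intercept: a (3 - 1) x 3 weights matrix.
w = np.array([[10.0, -5.0, -5.0],   # margins for class 1
              [-5.0, -5.0, 10.0]])  # margins for class 2
print(predict_multiclass(np.array([0.8, 0.0, 0.0]), w))  # 1
print(predict_multiclass(np.array([0.0, 0.5, 0.0]), w))  # 0 (reference class)
```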
http://git-wip-us.apache.org/repos/asf/spark/blob/b5bd75d9/python/pyspark/mllib/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 6766f3e..2466e8a 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -22,7 +22,7 @@ from numpy import array
 
 from pyspark import RDD
 from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py
-from pyspark.mllib.linalg import SparseVector, _convert_to_vector
+from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector
 from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
 from pyspark.mllib.util import Saveable, Loader, inherit_doc
 
@@ -31,13 +31,13 @@ __all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegr
            'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes']
 
 
-class LinearBinaryClassificationModel(LinearModel):
+class LinearClassificationModel(LinearModel):
     """
-    Represents a linear binary classification model that predicts to whether an
-    example is positive (1.0) or negative (0.0).
+    A private abstract class representing a multiclass classification model.
+    The categories are represented by int values: 0, 1, 2, etc.
     """
     def __init__(self, weights, intercept):
-        super(LinearBinaryClassificationModel, self).__init__(weights, intercept)
+        super(LinearClassificationModel, self).__init__(weights, intercept)
         self._threshold = None
 
     def setThreshold(self, value):
@@ -47,14 +47,26 @@ class LinearBinaryClassificationModel(LinearModel):
         Sets the threshold that separates positive predictions from negative
         predictions. An example with prediction score greater than or equal
         to this threshold is identified as a positive, and negative otherwise.
+        It is used for binary classification only.
         """
         self._threshold = value
 
+    @property
+    def threshold(self):
+        """
+        .. note:: Experimental
+
+        Returns the threshold (if any) used for converting raw prediction scores
+        into 0/1 predictions. It is used for binary classification only.
+        """
+        return self._threshold
+
     def clearThreshold(self):
         """
         .. note:: Experimental
 
        Clears the threshold so that `predict` will output raw prediction scores.
+        It is used for binary classification only.
         """
         self._threshold = None
 
@@ -66,7 +78,7 @@ class LinearBinaryClassificationModel(LinearModel):
         raise NotImplementedError
 
 
-class LogisticRegressionModel(LinearBinaryClassificationModel):
+class LogisticRegressionModel(LinearClassificationModel):
 
     """A linear binary classification model derived from logistic regression.
 
@@ -112,10 +124,39 @@ class LogisticRegressionModel(LinearBinaryClassificationModel):
     ...    os.removedirs(path)
     ... except:
     ...    pass
+    >>> multi_class_data = [
+    ...     LabeledPoint(0.0, [0.0, 1.0, 0.0]),
+    ...     LabeledPoint(1.0, [1.0, 0.0, 0.0]),
+    ...     LabeledPoint(2.0, [0.0, 0.0, 1.0])
+    ... ]
+    >>> mcm = LogisticRegressionWithLBFGS.train(data=sc.parallelize(multi_class_data), numClasses=3)
+    >>> mcm.predict([0.0, 0.5, 0.0])
+    0
+    >>> mcm.predict([0.8, 0.0, 0.0])
+    1
+    >>> mcm.predict([0.0, 0.0, 0.3])
+    2
     """
-    def __init__(self, weights, intercept):
+    def __init__(self, weights, intercept, numFeatures, numClasses):
         super(LogisticRegressionModel, self).__init__(weights, intercept)
+        self._numFeatures = int(numFeatures)
+        self._numClasses = int(numClasses)
         self._threshold = 0.5
+        if self._numClasses == 2:
+            self._dataWithBiasSize = None
+            self._weightsMatrix = None
+        else:
+            self._dataWithBiasSize = self._coeff.size / (self._numClasses - 1)
+            self._weightsMatrix = self._coeff.toArray().reshape(self._numClasses - 1,
+                                                                self._dataWithBiasSize)
+
+    @property
+    def numFeatures(self):
+        return self._numFeatures
+
+    @property
+    def numClasses(self):
+        return self._numClasses
 
     def predict(self, x):
         """
@@ -126,20 +167,38 @@ class LogisticRegressionModel(LinearBinaryClassificationModel):
             return x.map(lambda v: self.predict(v))
 
         x = _convert_to_vector(x)
-        margin = self.weights.dot(x) + self._intercept
-        if margin > 0:
-            prob = 1 / (1 + exp(-margin))
+        if self.numClasses == 2:
+            margin = self.weights.dot(x) + self._intercept
+            if margin > 0:
+                prob = 1 / (1 + exp(-margin))
+            else:
+                exp_margin = exp(margin)
+                prob = exp_margin / (1 + exp_margin)
+            if self._threshold is None:
+                return prob
+            else:
+                return 1 if prob > self._threshold else 0
         else:
-            exp_margin = exp(margin)
-            prob = exp_margin / (1 + exp_margin)
-        if self._threshold is None:
-            return prob
-        else:
-            return 1 if prob > self._threshold else 0
+            best_class = 0
+            max_margin = 0.0
+            if x.size + 1 == self._dataWithBiasSize:
+                for i in range(0, self._numClasses - 1):
+                    margin = x.dot(self._weightsMatrix[i][0:x.size]) + \
+                        self._weightsMatrix[i][x.size]
+                    if margin > max_margin:
+                        max_margin = margin
+                        best_class = i + 1
+            else:
+                for i in range(0, self._numClasses - 1):
+                    margin = x.dot(self._weightsMatrix[i])
+                    if margin > max_margin:
+                        max_margin = margin
+                        best_class = i + 1
+            return best_class
 
     def save(self, sc, path):
        java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel(
-            _py2java(sc, self._coeff), self.intercept)
+            _py2java(sc, self._coeff), self.intercept, self.numFeatures, self.numClasses)
         java_model.save(sc._jsc.sc(), path)
 
     @classmethod
@@ -148,8 +207,10 @@ class LogisticRegressionModel(LinearBinaryClassificationModel):
             sc._jsc.sc(), path)
         weights = _java2py(sc, java_model.weights())
         intercept = java_model.intercept()
+        numFeatures = java_model.numFeatures()
+        numClasses = java_model.numClasses()
         threshold = java_model.getThreshold().get()
-        model = LogisticRegressionModel(weights, intercept)
+        model = LogisticRegressionModel(weights, intercept, numFeatures, numClasses)
         model.setThreshold(threshold)
         return model
 
@@ -158,7 +219,8 @@ class LogisticRegressionWithSGD(object):
 
     @classmethod
     def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
-              initialWeights=None, regParam=0.01, regType="l2", intercept=False):
+              initialWeights=None, regParam=0.01, regType="l2", intercept=False,
+              validateData=True):
         """
         Train a logistic regression model on the given data.
 
@@ -184,11 +246,14 @@ class LogisticRegressionWithSGD(object):
                                   or not of the augmented representation for
                                   training data (i.e. whether bias features
                                   are activated or not).
+        :param validateData:      Boolean parameter which indicates if the
+                                  algorithm should validate data before training.
+                                  (default: True)
         """
         def train(rdd, i):
            return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations),
                                 float(step), float(miniBatchFraction), i, float(regParam), regType,
-                                 bool(intercept))
+                                 bool(intercept), bool(validateData))
 
        return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
 
@@ -197,7 +262,7 @@ class LogisticRegressionWithLBFGS(object):
 
     @classmethod
    def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2",
-              intercept=False, corrections=10, tolerance=1e-4):
+              intercept=False, corrections=10, tolerance=1e-4, validateData=True, numClasses=2):
         """
         Train a logistic regression model on the given data.
 
@@ -223,6 +288,11 @@ class LogisticRegressionWithLBFGS(object):
                                update (default: 10).
         :param tolerance:      The convergence tolerance of iterations for
                                L-BFGS (default: 1e-4).
+        :param validateData:   Boolean parameter which indicates if the
+                               algorithm should validate data before training.
+                               (default: True)
+        :param numClasses:     The number of classes (i.e., outcomes) a label can take
+                               in Multinomial Logistic Regression (default: 2).
 
         >>> data = [
         ...     LabeledPoint(0.0, [0.0, 1.0]),
@@ -237,12 +307,20 @@ class LogisticRegressionWithLBFGS(object):
         def train(rdd, i):
            return callMLlibFunc("trainLogisticRegressionModelWithLBFGS", rdd, int(iterations), i,
                                 float(regParam), regType, bool(intercept), int(corrections),
-                                 float(tolerance))
-
+                                 float(tolerance), bool(validateData), int(numClasses))
+
+        if initialWeights is None:
+            if numClasses == 2:
+                initialWeights = [0.0] * len(data.first().features)
+            else:
+                if intercept:
+                    initialWeights = [0.0] * (len(data.first().features) + 1) * (numClasses - 1)
+                else:
+                    initialWeights = [0.0] * len(data.first().features) * (numClasses - 1)
        return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
 
 
-class SVMModel(LinearBinaryClassificationModel):
+class SVMModel(LinearClassificationModel):
 
     """A support vector machine.
 
@@ -325,7 +403,8 @@ class SVMWithSGD(object):
 
     @classmethod
     def train(cls, data, iterations=100, step=1.0, regParam=0.01,
-              miniBatchFraction=1.0, initialWeights=None, regType="l2", intercept=False):
+              miniBatchFraction=1.0, initialWeights=None, regType="l2",
+              intercept=False, validateData=True):
         """
         Train a support vector machine on the given data.
 
@@ -351,11 +430,14 @@ class SVMWithSGD(object):
                                   or not of the augmented representation for
                                   training data (i.e. whether bias features
                                   are activated or not).
+        :param validateData:      Boolean parameter which indicates if the
+                                  algorithm should validate data before training.
+                                  (default: True)
         """
         def train(rdd, i):
            return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step),
                                 float(regParam), float(miniBatchFraction), i, regType,
-                                 bool(intercept))
+                                 bool(intercept), bool(validateData))
 
         return _regression_train_wrapper(train, SVMModel, data, initialWeights)
 

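Both SGD trainers above now accept the `validateData` flag. A hedged sketch of the resulting call (it assumes a live SparkContext `sc`; the data is illustrative):
```python
from pyspark.mllib.classification import SVMWithSGD
from pyspark.mllib.regression import LabeledPoint

points = sc.parallelize([
    LabeledPoint(0.0, [0.0, 1.0]),
    LabeledPoint(1.0, [1.0, 0.0]),
])
# validateData=True (the default) validates the input RDD before training;
# pass False to skip that check for pre-validated data.
model = SVMWithSGD.train(points, iterations=100, validateData=True)
model.predict([1.0, 0.0])
```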
http://git-wip-us.apache.org/repos/asf/spark/blob/b5bd75d9/python/pyspark/mllib/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 209f1ee..cd7310a 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -167,13 +167,19 @@ class LinearRegressionModel(LinearRegressionModelBase):
 # return the result of a call to the appropriate JVM stub.
 # _regression_train_wrapper is responsible for setup and error checking.
 def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
+    from pyspark.mllib.classification import LogisticRegressionModel
     first = data.first()
     if not isinstance(first, LabeledPoint):
        raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first)
     if initial_weights is None:
         initial_weights = [0.0] * len(data.first().features)
-    weights, intercept = train_func(data, _convert_to_vector(initial_weights))
-    return modelClass(weights, intercept)
+    if (modelClass == LogisticRegressionModel):
+        weights, intercept, numFeatures, numClasses = train_func(
+            data, _convert_to_vector(initial_weights))
+        return modelClass(weights, intercept, numFeatures, numClasses)
+    else:
+        weights, intercept = train_func(data, _convert_to_vector(initial_weights))
+        return modelClass(weights, intercept)
 
 
 class LinearRegressionWithSGD(object):

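In isolation, the dispatch added to `_regression_train_wrapper` amounts to the following simplified sketch (names shortened; the real wrapper also validates the input RDD and fills in default initial weights):
```python
from pyspark.mllib.classification import LogisticRegressionModel

def _build_model(model_class, java_result):
    # Logistic regression now returns four values from the JVM helper;
    # every other linear model still returns (weights, intercept).
    if model_class is LogisticRegressionModel:
        weights, intercept, num_features, num_classes = java_result
        return model_class(weights, intercept, num_features, num_classes)
    weights, intercept = java_result
    return model_class(weights, intercept)
```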
