Repository: spark
Updated Branches:
  refs/heads/master 46d50f151 -> b1aa8fe98


[SPARK-2309][MLlib] Multinomial Logistic Regression

#1379 was automatically closed by asfgit, and GitHub cannot reopen it once it's
closed, so this is the new PR.

Binary Logistic Regression can be extended to Multinomial Logistic Regression
by training K-1 independent Binary Logistic Regression models against a pivot
class. The implemented formulation is described at
http://www.slideshare.net/dbtsai/2014-0620-mlor-36132297/25
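
For reference, a minimal usage sketch of the API added in this patch (the
SparkContext `sc` and the toy data are illustrative assumptions, not part of
the patch):

    import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint

    // Labels for a k-class problem must be in {0, 1, ..., k - 1}.
    val training = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(1.0, 0.1)),
      LabeledPoint(1.0, Vectors.dense(0.2, 1.3)),
      LabeledPoint(2.0, Vectors.dense(-0.5, 2.1))))

    // setNumClasses(3) switches LBFGS training to the multinomial gradient.
    val model = new LogisticRegressionWithLBFGS()
      .setIntercept(true)
      .setNumClasses(3)
      .run(training)

    model.predict(Vectors.dense(0.2, 1.2))  // returns the predicted class as a Double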

Author: DB Tsai <dbt...@alpinenow.com>

Closes #3833 from dbtsai/mlor and squashes the following commits:

4e2f354 [DB Tsai] trigger jenkins
697b7c9 [DB Tsai] address some feedback
4ce4d33 [DB Tsai] refactoring
ff843b3 [DB Tsai] rebase
f114135 [DB Tsai] refactoring
4348426 [DB Tsai] Addressed feedback from Sean Owen
a252197 [DB Tsai] first commit


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b1aa8fe9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b1aa8fe9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b1aa8fe9

Branch: refs/heads/master
Commit: b1aa8fe988301b924048039529234278aeb0298a
Parents: 46d50f1
Author: DB Tsai <dbt...@alpinenow.com>
Authored: Mon Feb 2 15:59:15 2015 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Mon Feb 2 15:59:15 2015 -0800

----------------------------------------------------------------------
 .../classification/LogisticRegression.scala     | 128 ++++++++++--
 .../spark/mllib/optimization/Gradient.scala     | 200 ++++++++++++++++---
 .../regression/GeneralizedLinearAlgorithm.scala | 101 ++++++++--
 .../spark/mllib/util/DataValidators.scala       |  18 +-
 .../LogisticRegressionSuite.scala               | 179 ++++++++++++++++-
 5 files changed, 565 insertions(+), 61 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b1aa8fe9/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 94d757b..282fb3f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -18,30 +18,41 @@
 package org.apache.spark.mllib.classification
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.BLAS.dot
+import org.apache.spark.mllib.linalg.{DenseVector, Vector}
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.DataValidators
+import org.apache.spark.mllib.util.{DataValidators, MLUtils}
 import org.apache.spark.rdd.RDD
 
 /**
- * Classification model trained using Logistic Regression.
+ * Classification model trained using Multinomial/Binary Logistic Regression.
  *
  * @param weights Weights computed for every feature.
- * @param intercept Intercept computed for this model.
+ * @param intercept Intercept computed for this model. (Only used in Binary Logistic
+ *                  Regression. In Multinomial Logistic Regression, the intercept is not a
+ *                  single value, so the intercepts are stored as part of the weights.)
+ * @param numFeatures the dimension of the features.
+ * @param numClasses the number of possible outcomes for a k-class classification problem in
+ *                   Multinomial Logistic Regression. By default, it is binary logistic
+ *                   regression, so numClasses is set to 2.
  */
 class LogisticRegressionModel (
     override val weights: Vector,
-    override val intercept: Double)
+    override val intercept: Double,
+    val numFeatures: Int,
+    val numClasses: Int)
  extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
 
+  def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2)
+
   private var threshold: Option[Double] = Some(0.5)
 
   /**
    * :: Experimental ::
-   * Sets the threshold that separates positive predictions from negative predictions. An example
-   * with prediction score greater than or equal to this threshold is identified as an positive,
-   * and negative otherwise. The default value is 0.5.
+   * Sets the threshold that separates positive predictions from negative predictions
+   * in Binary Logistic Regression. An example with prediction score greater than or equal to
+   * this threshold is identified as a positive, and negative otherwise. The default value is 0.5.
    */
   @Experimental
   def setThreshold(threshold: Double): this.type = {
@@ -61,20 +72,68 @@ class LogisticRegressionModel (
 
   override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
       intercept: Double) = {
-    val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
-    val score = 1.0 / (1.0 + math.exp(-margin))
-    threshold match {
-      case Some(t) => if (score > t) 1.0 else 0.0
-      case None => score
+    require(dataMatrix.size == numFeatures)
+
+    // If dataMatrix and weightMatrix have the same dimension, it's binary logistic regression.
+    if (numClasses == 2) {
+      require(numFeatures == weightMatrix.size)
+      val margin = dot(weights, dataMatrix) + intercept
+      val score = 1.0 / (1.0 + math.exp(-margin))
+      threshold match {
+        case Some(t) => if (score > t) 1.0 else 0.0
+        case None => score
+      }
+    } else {
+      val dataWithBiasSize = weightMatrix.size / (numClasses - 1)
+
+      val weightsArray = weights match {
+        case dv: DenseVector => dv.values
+        case _ =>
+          throw new IllegalArgumentException(
+            s"weights only supports dense vector but got type 
${weights.getClass}.")
+      }
+
+      val margins = (0 until numClasses - 1).map { i =>
+        var margin = 0.0
+        dataMatrix.foreachActive { (index, value) =>
+          if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index)
+        }
+        // If the intercept is part of the weights, add it to the margin.
+        if (dataMatrix.size + 1 == dataWithBiasSize) {
+          margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size)
+        }
+        margin
+      }
+
+      /**
+       * Find the class with the maximum margin. If maxMargin is negative, then the
+       * prediction result will be the first class.
+       *
+       * Note: if you want to compute the probabilities for each outcome instead of the outcome
+       * with maximum probability, remember to subtract maxMargin from the margins if maxMargin
+       * is positive, to prevent overflow.
+       */
+      var bestClass = 0
+      var maxMargin = 0.0
+      var i = 0
+      while (i < margins.size) {
+        if (margins(i) > maxMargin) {
+          maxMargin = margins(i)
+          bestClass = i + 1
+        }
+        i += 1
+      }
+      bestClass.toDouble
     }
   }
 }
 
 /**
- * Train a classification model for Logistic Regression using Stochastic Gradient Descent. By
- * default L2 regularization is used, which can be changed via
- * [[LogisticRegressionWithSGD.optimizer]].
- * NOTE: Labels used in Logistic Regression should be {0, 1}.
+ * Train a classification model for Binary Logistic Regression
+ * using Stochastic Gradient Descent. By default L2 regularization is used,
+ * which can be changed via [[LogisticRegressionWithSGD.optimizer]].
+ * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1}
+ * for a k-class multiclass classification problem.
  * Using [[LogisticRegressionWithLBFGS]] is recommended over this.
  */
 class LogisticRegressionWithSGD private (
@@ -194,9 +253,10 @@ object LogisticRegressionWithSGD {
 }
 
 /**
- * Train a classification model for Logistic Regression using Limited-memory BFGS.
- * Standard feature scaling and L2 regularization are used by default.
- * NOTE: Labels used in Logistic Regression should be {0, 1}
+ * Train a classification model for Multinomial/Binary Logistic Regression using
+ * Limited-memory BFGS. Standard feature scaling and L2 regularization are used by default.
+ * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1}
+ * for a k-class multiclass classification problem.
  */
 class LogisticRegressionWithLBFGS
  extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable {
@@ -205,9 +265,33 @@ class LogisticRegressionWithLBFGS
 
  override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater)
 
-  override protected val validators = List(DataValidators.binaryLabelValidator)
+  override protected val validators = List(multiLabelValidator)
+
+  private def multiLabelValidator: RDD[LabeledPoint] => Boolean = { data =>
+    if (numOfLinearPredictor > 1) {
+      DataValidators.multiLabelValidator(numOfLinearPredictor + 1)(data)
+    } else {
+      DataValidators.binaryLabelValidator(data)
+    }
+  }
+
+  /**
+   * :: Experimental ::
+   * Set the number of possible outcomes for a k-class classification problem in
+   * Multinomial Logistic Regression.
+   * By default, it is binary logistic regression, so k will be set to 2.
+   */
+  @Experimental
+  def setNumClasses(numClasses: Int): this.type = {
+    require(numClasses > 1)
+    numOfLinearPredictor = numClasses - 1
+    if (numClasses > 2) {
+      optimizer.setGradient(new LogisticGradient(numClasses))
+    }
+    this
+  }
 
   override protected def createModel(weights: Vector, intercept: Double) = {
-    new LogisticRegressionModel(weights, intercept)
+    new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1)
   }
 }
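
Note: the argmax in predictPoint above treats class 0 as having an implicit
margin of 0.0, so class 0 is predicted only when no margin is positive. A
standalone sketch of just that rule (illustrative, not part of the patch):

    // margins(i) is the margin of class i + 1; class 0 has an implicit margin of 0.0.
    def bestClass(margins: Seq[Double]): Int = {
      var best = 0
      var maxMargin = 0.0
      var i = 0
      while (i < margins.size) {
        if (margins(i) > maxMargin) {
          maxMargin = margins(i)
          best = i + 1
        }
        i += 1
      }
      best
    }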

http://git-wip-us.apache.org/repos/asf/spark/blob/b1aa8fe9/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
index 1ca0f36..0acdab7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.mllib.optimization
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
 import org.apache.spark.mllib.util.MLUtils
 
@@ -55,24 +55,86 @@ abstract class Gradient extends Serializable {
 
 /**
  * :: DeveloperApi ::
- * Compute gradient and loss for a logistic loss function, as used in binary classification.
- * See also the documentation for the precise formulation.
+ * Compute gradient and loss for a multinomial logistic loss function, as used
+ * in multi-class classification (it is also used in binary logistic regression).
+ *
+ * In `The Elements of Statistical Learning: Data Mining, Inference, and Prediction, 2nd Edition`
+ * by Trevor Hastie, Robert Tibshirani, and Jerome Friedman, which can be downloaded from
+ * http://statweb.stanford.edu/~tibs/ElemStatLearn/ , Eq. (4.17) on page 119 gives the formula of
+ * the multinomial logistic regression model. A simple calculation shows that
+ *
+ * P(y=0|x, w) = 1 / (1 + \sum_i^{K-1} \exp(x w_i))
+ * P(y=1|x, w) = \exp(x w_1) / (1 + \sum_i^{K-1} \exp(x w_i))
+ *   ...
+ * P(y=K-1|x, w) = \exp(x w_{K-1}) / (1 + \sum_i^{K-1} \exp(x w_i))
+ *
+ * for a K-class multiclass classification problem.
+ *
+ * The model weights w = (w_1, w_2, ..., w_{K-1})^T become a matrix with dimension
+ * (K-1) * (N+1) if the intercepts are added. If the intercepts are not added, the dimension
+ * will be (K-1) * N.
+ *
+ * As a result, the loss of the objective function for a single instance of data can be written as
+ * l(w, x) = -log P(y|x, w)
+ *         = log(1 + \sum_i^{K-1} \exp(x w_i)) - \alpha(y) x w_{y-1}
+ *         = log(1 + \sum_i^{K-1} \exp(margins_i)) - \alpha(y) margins_{y-1}
+ *
+ * where \alpha(i) = 1 if i != 0, and
+ *       \alpha(i) = 0 if i == 0,
+ *       margins_i = x w_i.
+ *
+ * For optimization, we have to calculate the first derivative of the loss function, and
+ * a simple calculation shows that
+ *
+ * \frac{\partial l(w, x)}{\partial w_{ij}}
+ *   = (\exp(x w_i) / (1 + \sum_k^{K-1} \exp(x w_k)) - \alpha(y) \delta_{y, i+1}) * x_j
+ *   = multiplier_i * x_j
+ *
+ * where \delta_{i, j} = 1 if i == j,
+ *       \delta_{i, j} = 0 if i != j, and
+ *       multiplier_i
+ *         = \exp(margins_i) / (1 + \sum_k^{K-1} \exp(margins_k)) - \alpha(y) \delta_{y, i+1}.
+ *
+ * If any margin is larger than 709.78, the numerical computation of the multiplier and the loss
+ * function will suffer from arithmetic overflow. This issue occurs when there are outliers in
+ * the data which are far away from the hyperplane, and training will fail once
+ * infinity / infinity is introduced. Note that this is only a concern when max(margins) > 0.
+ *
+ * Fortunately, when max(margins) = maxMargin > 0, the loss function and the multiplier can be
+ * easily rewritten into the following equivalent, numerically stable formula:
+ *
+ * l(w, x) = log(1 + \sum_i^{K-1} \exp(margins_i)) - \alpha(y) margins_{y-1}
+ *         = log(\exp(-maxMargin) + \sum_i^{K-1} \exp(margins_i - maxMargin)) + maxMargin
+ *           - \alpha(y) margins_{y-1}
+ *         = log(1 + sum) + maxMargin - \alpha(y) margins_{y-1}
+ *
+ * where sum = \exp(-maxMargin) + \sum_i^{K-1} \exp(margins_i - maxMargin) - 1.
+ *
+ * Note that each term, (margins_i - maxMargin), in \exp is smaller than zero; as a result,
+ * overflow will not happen with this formula.
+ *
+ * A similar trick can be applied to the multiplier:
+ *
+ * multiplier_i
+ *   = \exp(margins_i) / (1 + \sum_k^{K-1} \exp(margins_k)) - \alpha(y) \delta_{y, i+1}
+ *   = \exp(margins_i - maxMargin) / (1 + sum) - \alpha(y) \delta_{y, i+1}
+ *
+ * where each term in \exp is also smaller than zero, so overflow is not a concern.
+ *
+ * For the detailed mathematical derivation, see the reference at
+ * http://www.slideshare.net/dbtsai/2014-0620-mlor-36132297
+ *
+ * @param numClasses the number of possible outcomes for a k-class classification problem in
+ *                   Multinomial Logistic Regression. By default, it is binary logistic
+ *                   regression, so numClasses will be set to 2.
  */
 @DeveloperApi
-class LogisticGradient extends Gradient {
-  override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
-    val margin = -1.0 * dot(data, weights)
-    val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
-    val gradient = data.copy
-    scal(gradientMultiplier, gradient)
-    val loss =
-      if (label > 0) {
-        // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
-        MLUtils.log1pExp(margin)
-      } else {
-        MLUtils.log1pExp(margin) - margin
-      }
+class LogisticGradient(numClasses: Int) extends Gradient {
 
+  def this() = this(2)
+
+  override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
+    val gradient = Vectors.zeros(weights.size)
+    val loss = compute(data, label, weights, gradient)
     (gradient, loss)
   }
 
@@ -81,14 +143,104 @@ class LogisticGradient extends Gradient {
       label: Double,
       weights: Vector,
       cumGradient: Vector): Double = {
-    val margin = -1.0 * dot(data, weights)
-    val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
-    axpy(gradientMultiplier, data, cumGradient)
-    if (label > 0) {
-      // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
-      MLUtils.log1pExp(margin)
-    } else {
-      MLUtils.log1pExp(margin) - margin
+    val dataSize = data.size
+
+    // (weights.size / dataSize + 1) is the number of classes.
+    require(weights.size % dataSize == 0 && numClasses == weights.size / dataSize + 1)
+    numClasses match {
+      case 2 =>
+        /**
+         * For Binary Logistic Regression.
+         *
+         * Although the multinomial loss and gradient calculation is more general and can
+         * also be used in the binary case, we still implement a specialized binary version
+         * for performance reasons.
+         */
+        val margin = -1.0 * dot(data, weights)
+        val multiplier = (1.0 / (1.0 + math.exp(margin))) - label
+        axpy(multiplier, data, cumGradient)
+        if (label > 0) {
+          // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
+          MLUtils.log1pExp(margin)
+        } else {
+          MLUtils.log1pExp(margin) - margin
+        }
+      case _ =>
+        /**
+         * For Multinomial Logistic Regression.
+         */
+        val weightsArray = weights match {
+          case dv: DenseVector => dv.values
+          case _ =>
+            throw new IllegalArgumentException(
+              s"weights only supports dense vector but got type 
${weights.getClass}.")
+        }
+        val cumGradientArray = cumGradient match {
+          case dv: DenseVector => dv.values
+          case _ =>
+            throw new IllegalArgumentException(
+              s"cumGradient only supports dense vector but got type 
${cumGradient.getClass}.")
+        }
+
+        // marginY is margins(label - 1) in the formula.
+        var marginY = 0.0
+        var maxMargin = Double.NegativeInfinity
+        var maxMarginIndex = 0
+
+        val margins = Array.tabulate(numClasses - 1) { i =>
+          var margin = 0.0
+          data.foreachActive { (index, value) =>
+            if (value != 0.0) margin += value * weightsArray((i * dataSize) + index)
+          }
+          if (i == label.toInt - 1) marginY = margin
+          if (margin > maxMargin) {
+            maxMargin = margin
+            maxMarginIndex = i
+          }
+          margin
+        }
+
+        /**
+         * When maxMargin > 0, the original formula will cause overflow, as discussed
+         * in the comment above. We address this by subtracting maxMargin from all the
+         * margins, which guarantees that all of the new margins are smaller than zero,
+         * preventing arithmetic overflow.
+         */
+        val sum = {
+          var temp = 0.0
+          if (maxMargin > 0) {
+            for (i <- 0 until numClasses - 1) {
+              margins(i) -= maxMargin
+              if (i == maxMarginIndex) {
+                temp += math.exp(-maxMargin)
+              } else {
+                temp += math.exp(margins(i))
+              }
+            }
+          } else {
+            for (i <- 0 until numClasses - 1) {
+              temp += math.exp(margins(i))
+            }
+          }
+          temp
+        }
+
+        for (i <- 0 until numClasses - 1) {
+          val multiplier = math.exp(margins(i)) / (sum + 1.0) - {
+            if (label != 0.0 && label == i + 1) 1.0 else 0.0
+          }
+          data.foreachActive { (index, value) =>
+            if (value != 0.0) cumGradientArray(i * dataSize + index) += multiplier * value
+          }
+        }
+
+        val loss = if (label > 0.0) math.log1p(sum) - marginY else math.log1p(sum)
+
+        if (maxMargin > 0) {
+          loss + maxMargin
+        } else {
+          loss
+        }
     }
   }
 }
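
Note: the max-margin shift described in the class comment above can be
sketched standalone as follows (illustrative, not part of the patch; `margins`
is assumed to hold the K-1 margins of classes 1..K-1, and `label` is in
{0, ..., K-1}):

    def stableMultinomialLoss(margins: Array[Double], label: Int): Double = {
      val maxMargin = margins.max
      if (maxMargin > 0) {
        // sum = \exp(-maxMargin) + \sum_i \exp(margins_i - maxMargin) - 1, so every
        // argument passed to exp is <= 0 and the computation cannot overflow.
        val sum = math.exp(-maxMargin) + margins.map(m => math.exp(m - maxMargin)).sum - 1.0
        val loss = math.log1p(sum) + maxMargin
        if (label > 0) loss - margins(label - 1) else loss
      } else {
        val sum = margins.map(math.exp).sum
        if (label > 0) math.log1p(sum) - margins(label - 1) else math.log1p(sum)
      }
    }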

http://git-wip-us.apache.org/repos/asf/spark/blob/b1aa8fe9/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 0287f04..17de215 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -99,6 +99,23 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
   protected var validateData: Boolean = true
 
   /**
+   * In `GeneralizedLinearModel`, only a single linear predictor is allowed for both weights
+   * and intercept. However, for multinomial logistic regression with K possible outcomes,
+   * we are training K-1 independent binary logistic regression models, which requires K-1 sets
+   * of linear predictors.
+   *
+   * As a result, the workaround here is that if more than one set of linear predictors is
+   * needed, we construct a bigger `weights` vector that can hold both weights and intercepts.
+   * If the intercepts are added, the dimension of `weights` will be
+   * (numOfLinearPredictor) * (numFeatures + 1). If the intercepts are not added,
+   * the dimension of `weights` will be (numOfLinearPredictor) * numFeatures.
+   *
+   * Thus, the intercepts are encapsulated into the weights, and we leave the value of intercept
+   * in GeneralizedLinearModel as zero.
+   */
+  protected var numOfLinearPredictor: Int = 1
+
+  /**
   * Whether to perform feature scaling before model training to reduce the condition numbers
   * which can significantly help the optimizer converging faster. The scaling correction will be
   * translated back to resulting model weights, so it's transparent to users.
@@ -107,6 +124,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
   private var useFeatureScaling = false
 
   /**
+   * The dimension of training features.
+   */
+  protected var numFeatures: Int = 0
+
+  /**
   * Set if the algorithm should use feature scaling to improve the convergence during optimization.
   */
  private[mllib] def setFeatureScaling(useFeatureScaling: Boolean): this.type = {
@@ -141,8 +163,28 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
    * RDD of LabeledPoint entries.
    */
   def run(input: RDD[LabeledPoint]): M = {
-    val numFeatures: Int = input.first().features.size
-    val initialWeights = Vectors.dense(new Array[Double](numFeatures))
+    numFeatures = input.first().features.size
+
+    /**
+     * When `numOfLinearPredictor > 1`, the intercepts are encapsulated into the weights,
+     * so `weights` will include the intercepts. When `numOfLinearPredictor == 1`,
+     * the intercept will be stored as a separate value in `GeneralizedLinearModel`.
+     * This results in different behaviors: when `numOfLinearPredictor == 1`,
+     * users have no way to set the initial intercept, while in the other case, users
+     * can set the intercepts as part of the weights.
+     *
+     * TODO: See if we can deprecate `intercept` in `GeneralizedLinearModel`, and always
+     * have the intercept as part of the weights for a consistent design.
+     */
+    val initialWeights = {
+      if (numOfLinearPredictor == 1) {
+        Vectors.dense(new Array[Double](numFeatures))
+      } else if (addIntercept) {
+        Vectors.dense(new Array[Double]((numFeatures + 1) * numOfLinearPredictor))
+      } else {
+        Vectors.dense(new Array[Double](numFeatures * numOfLinearPredictor))
+      }
+    }
     run(input, initialWeights)
   }
 
@@ -151,6 +193,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
    * of LabeledPoint entries starting from the initial weights provided.
    */
   def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {
+    numFeatures = input.first().features.size
 
     if (input.getStorageLevel == StorageLevel.NONE) {
       logWarning("The input data is not directly cached, which may hurt 
performance if its"
@@ -182,14 +225,14 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
      * Currently, it's only enabled in LogisticRegressionWithLBFGS
      */
     val scaler = if (useFeatureScaling) {
-      (new StandardScaler).fit(input.map(x => x.features))
+      (new StandardScaler(withStd = true, withMean = false)).fit(input.map(x => x.features))
     } else {
       null
     }
 
     // Prepend an extra variable consisting of all 1.0's for the intercept.
     val data = if (addIntercept) {
-      if(useFeatureScaling) {
+      if (useFeatureScaling) {
         input.map(labeledPoint =>
          (labeledPoint.label, appendBias(scaler.transform(labeledPoint.features))))
       } else {
@@ -203,21 +246,31 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
       }
     }
 
-    val initialWeightsWithIntercept = if (addIntercept) {
+    /**
+     * TODO: For better convergence, in logistic regression, the intercepts should be computed
+     * from the prior probability distribution of the outcomes; for linear regression,
+     * the intercept should be set to the average of the response.
+     */
+    val initialWeightsWithIntercept = if (addIntercept && numOfLinearPredictor == 1) {
       appendBias(initialWeights)
     } else {
+      /** If `numOfLinearPredictor > 1`, initialWeights already contains the intercepts. */
       initialWeights
     }
 
    val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept)
 
-    val intercept = if (addIntercept) weightsWithIntercept(weightsWithIntercept.size - 1) else 0.0
-    var weights =
-      if (addIntercept) {
-        Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1))
-      } else {
-        weightsWithIntercept
-      }
+    val intercept = if (addIntercept && numOfLinearPredictor == 1) {
+      weightsWithIntercept(weightsWithIntercept.size - 1)
+    } else {
+      0.0
+    }
+
+    var weights = if (addIntercept && numOfLinearPredictor == 1) {
+      Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1))
+    } else {
+      weightsWithIntercept
+    }
 
     /**
     * The weights and intercept are trained in the scaled space; we're converting them back to
@@ -228,7 +281,29 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
     * is the coefficient in the original space, and v_i is the variance of the column i.
      */
     if (useFeatureScaling) {
-      weights = scaler.transform(weights)
+      if (numOfLinearPredictor == 1) {
+        weights = scaler.transform(weights)
+      } else {
+        /**
+         * For `numOfLinearPredictor > 1`, we have to transform the weights back to the original
+         * scale for each set of linear predictors. Note that the intercepts have to be explicitly
+         * excluded when `addIntercept == true` since the intercepts are now part of the weights.
+         */
+        var i = 0
+        val n = weights.size / numOfLinearPredictor
+        val weightsArray = weights.toArray
+        while (i < numOfLinearPredictor) {
+          val start = i * n
+          val end = (i + 1) * n - { if (addIntercept) 1 else 0 }
+
+          val partialWeightsArray = scaler.transform(
+            Vectors.dense(weightsArray.slice(start, end))).toArray
+
+          System.arraycopy(partialWeightsArray, 0, weightsArray, start, partialWeightsArray.size)
+          i += 1
+        }
+        weights = Vectors.dense(weightsArray)
+      }
     }
 
     // Warn at the end of the run as well, for increased visibility.
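
Note: with `addIntercept == true` and `numOfLinearPredictor > 1`, the flattened
`weights` vector above consists of numOfLinearPredictor rows of length
numFeatures + 1, each row holding the feature weights followed by that
predictor's intercept. A hypothetical helper (not part of the patch) that
unpacks one row:

    def weightsAndIntercept(
        flat: Array[Double],   // the flattened weights vector
        predictor: Int,        // 0 .. numOfLinearPredictor - 1
        numFeatures: Int): (Array[Double], Double) = {
      val rowLength = numFeatures + 1
      val start = predictor * rowLength
      (flat.slice(start, start + numFeatures), flat(start + numFeatures))
    }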

http://git-wip-us.apache.org/repos/asf/spark/blob/b1aa8fe9/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala
index 45f9548..be335a1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala
@@ -34,11 +34,27 @@ object DataValidators extends Logging {
    *
    * @return True if labels are all zero or one, false otherwise.
    */
-   val binaryLabelValidator: RDD[LabeledPoint] => Boolean = { data =>
+  val binaryLabelValidator: RDD[LabeledPoint] => Boolean = { data =>
     val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count()
     if (numInvalid != 0) {
       logError("Classification labels should be 0 or 1. Found " + numInvalid + 
" invalid labels")
     }
     numInvalid == 0
   }
+
+  /**
+   * Function to check if labels used for k-class multiclass classification are
+   * in the range of {0, 1, ..., k - 1}.
+   *
+   * @return True if labels are all in the range of {0, 1, ..., k-1}, false otherwise.
+   */
+  def multiLabelValidator(k: Int): RDD[LabeledPoint] => Boolean = { data =>
+    val numInvalid = data.filter(x =>
+      x.label - x.label.toInt != 0.0 || x.label < 0 || x.label > k - 1).count()
+    if (numInvalid != 0) {
+      logError("Classification labels should be in {0 to " + (k - 1) + "}. " +
+        "Found " + numInvalid + " invalid labels")
+    }
+    numInvalid == 0
+  }
 }
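
Note: a validator is just a function from RDD[LabeledPoint] to Boolean, so it
can be applied directly to the training data; a usage sketch (`data` is an
assumed RDD[LabeledPoint], not part of the patch):

    val validForThreeClasses: Boolean = DataValidators.multiLabelValidator(3)(data)
    val validForBinary: Boolean = DataValidators.binaryLabelValidator(data)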

http://git-wip-us.apache.org/repos/asf/spark/blob/b1aa8fe9/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 94b0e00..3fb4593 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -17,13 +17,14 @@
 
 package org.apache.spark.mllib.classification
 
+import scala.util.control.Breaks._
 import scala.util.Random
 import scala.collection.JavaConversions._
 
 import org.scalatest.FunSuite
 import org.scalatest.Matchers
 
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression._
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
@@ -55,6 +56,97 @@ object LogisticRegressionSuite {
    val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(Array(x1(i)))))
     testData
   }
+
+  /**
+   * Generates `k`-class multinomial synthetic logistic input in `n`-dimensional space given the
+   * model weights and mean/variance of the features. The synthetic data will be drawn from
+   * the probability distribution constructed by the weights using the following formula.
+   *
+   * P(y = 0 | x) = 1 / norm
+   * P(y = 1 | x) = exp(x * w_1) / norm
+   * P(y = 2 | x) = exp(x * w_2) / norm
+   *   ...
+   * P(y = k-1 | x) = exp(x * w_{k-1}) / norm
+   * where norm = 1 + exp(x * w_1) + exp(x * w_2) + ... + exp(x * w_{k-1})
+   *
+   * @param weights the weights matrix flattened into a vector; as a result, the dimension of
+   *                the weights vector will be (k - 1) * (n + 1) if `addIntercept == true`, and
+   *                (k - 1) * n if `addIntercept != true`.
+   * @param xMean the mean of the generated features. Often, if the features are not properly
+   *              standardized, a poorly implemented algorithm will have difficulty converging.
+   * @param xVariance the variance of the generated features.
+   * @param addIntercept whether to add the intercept.
+   * @param nPoints the number of instances of generated data.
+   * @param seed the seed for the random generator; it is fixed for consistent test results.
+   */
+  def generateMultinomialLogisticInput(
+      weights: Array[Double],
+      xMean: Array[Double],
+      xVariance: Array[Double],
+      addIntercept: Boolean,
+      nPoints: Int,
+      seed: Int): Seq[LabeledPoint] = {
+    val rnd = new Random(seed)
+
+    val xDim = xMean.size
+    val xWithInterceptsDim = if (addIntercept) xDim + 1 else xDim
+    val nClasses = weights.size / xWithInterceptsDim + 1
+
+    val x = Array.fill[Vector](nPoints)(Vectors.dense(Array.fill[Double](xDim)(rnd.nextGaussian())))
+
+    x.map(vector => {
+      // This doesn't work if `vector` is a sparse vector.
+      val vectorArray = vector.toArray
+      var i = 0
+      while (i < vectorArray.size) {
+        vectorArray(i) = vectorArray(i) * math.sqrt(xVariance(i)) + xMean(i)
+        i += 1
+      }
+    })
+
+    val y = (0 until nPoints).map { idx =>
+      val xArray = x(idx).toArray
+      val margins = Array.ofDim[Double](nClasses)
+      val probs = Array.ofDim[Double](nClasses)
+
+      for (i <- 0 until nClasses - 1) {
+        for (j <- 0 until xDim) margins(i + 1) += weights(i * xWithInterceptsDim + j) * xArray(j)
+        if (addIntercept) margins(i + 1) += weights((i + 1) * xWithInterceptsDim - 1)
+      }
+      // Prevent overflow when computing the probabilities.
+      val maxMargin = margins.max
+      if (maxMargin > 0) for (i <- 0 until nClasses) margins(i) -= maxMargin
+
+      // Computing the probabilities for each class from the margins.
+      val norm = {
+        var temp = 0.0
+        for (i <- 0 until nClasses) {
+          probs(i) = math.exp(margins(i))
+          temp += probs(i)
+        }
+        temp
+      }
+      for (i <- 0 until nClasses) probs(i) /= norm
+
+      // Compute the cumulative probability so we can generate a random number and assign a label.
+      for (i <- 1 until nClasses) probs(i) += probs(i - 1)
+      val p = rnd.nextDouble()
+      var y = 0
+      breakable {
+        for (i <- 0 until nClasses) {
+          if (p < probs(i)) {
+            y = i
+            break
+          }
+        }
+      }
+      y
+    }
+
+    val testData = (0 until nPoints).map(i => LabeledPoint(y(i), x(i)))
+    testData
+  }
 }
 
class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
@@ -285,6 +377,91 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
     assert(modelB1.weights(0) !~== modelB3.weights(0) * 1.0E6 absTol 0.1)
   }
 
+  test("multinomial logistic regression with LBFGS") {
+    val nPoints = 10000
+
+    /**
+     * The following weights and xMean/xVariance are computed from the iris dataset with
+     * lambda = 0.2. As a result, we are actually drawing samples from the probability
+     * distribution of the built model.
+     */
+    val weights = Array(
+      -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
+      -0.16624, -0.84355, -0.048509, -0.301789, 4.170682)
+
+    val xMean = Array(5.843, 3.057, 3.758, 1.199)
+    val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
+
+    val testData = LogisticRegressionSuite.generateMultinomialLogisticInput(
+      weights, xMean, xVariance, true, nPoints, 42)
+
+    val testRDD = sc.parallelize(testData, 2)
+    testRDD.cache()
+
+    val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(3)
+    lr.optimizer.setConvergenceTol(1E-15).setNumIterations(200)
+
+    val model = lr.run(testRDD)
+
+    /**
+     * The following is the instruction to reproduce the model using R's glmnet package.
+     *
+     * First of all, use the following Scala code to save the data into `path`.
+     *
+     *    testRDD.map(x => x.label + ", " + x.features(0) + ", " + x.features(1) + ", " +
+     *      x.features(2) + ", " + x.features(3)).saveAsTextFile("path")
+     *
+     * Then use the following R code to load the data and train the model with the glmnet package.
+     *
+     *    library("glmnet")
+     *    data <- read.csv("path", header=FALSE)
+     *    label = factor(data$V1)
+     *    features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+     *    weights = coef(glmnet(features, label, family="multinomial", alpha = 0, lambda = 0))
+     *
+     * The model weights of multinomial logistic regression in R have `K` sets of linear
+     * predictors for a `K`-class classification problem; however, only `K-1` sets are required
+     * if the first outcome is chosen as a "pivot", and the other `K-1` outcomes are separately
+     * regressed against the pivot outcome. This can be done by subtracting the first set of
+     * weights from the other `K-1` sets of weights. The mathematical discussion and proof can
+     * be found here:
+     * http://en.wikipedia.org/wiki/Multinomial_logistic_regression
+     *
+     *    weights1 = weights$`1` - weights$`0`
+     *    weights2 = weights$`2` - weights$`0`
+     *
+     *    > weights1
+     *    5 x 1 sparse Matrix of class "dgCMatrix"
+     *                    s0
+     *             2.6228269
+     *    data.V2 -0.5837166
+     *    data.V3  0.9285260
+     *    data.V4 -0.3783612
+     *    data.V5 -0.8123411
+     *    > weights2
+     *    5 x 1 sparse Matrix of class "dgCMatrix"
+     *                     s0
+     *             4.11197445
+     *    data.V2 -0.16918650
+     *    data.V3 -0.81104784
+     *    data.V4 -0.06463799
+     *    data.V5 -0.29198337
+     */
+
+    val weightsR = Vectors.dense(Array(
+      -0.5837166, 0.9285260, -0.3783612, -0.8123411, 2.6228269,
+      -0.1691865, -0.811048, -0.0646380, -0.2919834, 4.1119745))
+
+    assert(model.weights ~== weightsR relTol 0.05)
+
+    val validationData = LogisticRegressionSuite.generateMultinomialLogisticInput(
+      weights, xMean, xVariance, true, nPoints, 17)
+    val validationRDD = sc.parallelize(validationData, 2)
+    // The validation accuracy is not high because this model (even with the original weights)
+    // does not have a very steep logistic curve, so when we draw samples from the distribution
+    // it is easy for an instance to be assigned a different label. However, this prediction
+    // result is consistent with R's.
+    validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData, 0.47)
+
+  }
+
 }
 
class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {

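Note: the pivoting used in the test above works because multinomial
probabilities are invariant under a common shift of the linear predictors:

    P(y=i|x) = \exp(x w_i) / \sum_{j=0}^{K-1} \exp(x w_j)
             = \exp(x (w_i - w_0)) / \sum_{j=0}^{K-1} \exp(x (w_j - w_0))

so taking w'_i = w_i - w_0 forces w'_0 = 0 and turns glmnet's K sets of linear
predictors into the K-1 sets that MLlib trains against the pivot class 0.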
