Repository: spark
Updated Branches:
  refs/heads/master 1fa48d942 -> d679843a3


[SPARK-1327] GLM needs to check addIntercept for intercept and weights

GLM needs to check addIntercept when parsing the intercept and weights: the 
current implementation always treats the first weight as the intercept, even 
when no intercept was added. Added a test for training without adding an 
intercept.

JIRA: https://spark-project.atlassian.net/browse/SPARK-1327
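
For reference, a minimal sketch of the corrected behavior, using only the MLlib
API that appears in the diff below; the SparkContext sc and the data values are
assumed for illustration:

    import org.apache.spark.mllib.regression.LinearRegressionWithSGD
    import org.apache.spark.mllib.util.LinearDataGenerator

    // Generate labeled points for y = 10*x1 + 10*x2 with a zero intercept.
    val points = LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42)
    val trainRDD = sc.parallelize(points, 2).cache()

    // With addIntercept disabled, the whole solution vector is kept as feature
    // weights; before this fix its first entry was misreported as the intercept.
    val model = new LinearRegressionWithSGD().setIntercept(false).run(trainRDD)
    assert(model.intercept == 0.0)
    assert(model.weights.length == 2)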

Author: Xiangrui Meng <m...@databricks.com>

Closes #236 from mengxr/glm and squashes the following commits:

bcac1ac [Xiangrui Meng] add two tests to ensure {Lasso, Ridge}.setIntercept will throw an exception
a104072 [Xiangrui Meng] remove protected to be compatible with 0.9
0e57aa4 [Xiangrui Meng] update Lasso and RidgeRegression to parse the weights correctly from GLM; mark createModel protected; mark predictPoint protected
d7f629f [Xiangrui Meng] fix a bug in GLM when intercept is not used


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d679843a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d679843a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d679843a

Branch: refs/heads/master
Commit: d679843a39bb4918a08a5aebdf113ac8886a5275
Parents: 1fa48d9
Author: Xiangrui Meng <m...@databricks.com>
Authored: Wed Mar 26 19:30:20 2014 -0700
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Wed Mar 26 19:30:20 2014 -0700

----------------------------------------------------------------------
 .../regression/GeneralizedLinearAlgorithm.scala | 21 +++++++++-------
 .../apache/spark/mllib/regression/Lasso.scala   | 20 ++++++++++-----
 .../mllib/regression/LinearRegression.scala     | 20 +++++++--------
 .../mllib/regression/RidgeRegression.scala      | 18 ++++++++++----
 .../spark/mllib/regression/LassoSuite.scala     |  9 ++++---
 .../regression/LinearRegressionSuite.scala      | 26 +++++++++++++++++++-
 .../mllib/regression/RidgeRegressionSuite.scala |  9 ++++---
 7 files changed, 86 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d679843a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index b962153..3e1ed91 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -136,25 +136,28 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
 
     // Prepend an extra variable consisting of all 1.0's for the intercept.
     val data = if (addIntercept) {
-      input.map(labeledPoint => (labeledPoint.label, labeledPoint.features.+:(1.0)))
+      input.map(labeledPoint => (labeledPoint.label, 1.0 +: labeledPoint.features))
     } else {
       input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
     }
 
     val initialWeightsWithIntercept = if (addIntercept) {
-      initialWeights.+:(1.0)
+      0.0 +: initialWeights
     } else {
       initialWeights
     }
 
-    val weights = optimizer.optimize(data, initialWeightsWithIntercept)
-    val intercept = weights(0)
-    val weightsScaled = weights.tail
+    val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept)
 
-    val model = createModel(weightsScaled, intercept)
+    val (intercept, weights) = if (addIntercept) {
+      (weightsWithIntercept(0), weightsWithIntercept.tail)
+    } else {
+      (0.0, weightsWithIntercept)
+    }
+
+    logInfo("Final weights " + weights.mkString(","))
+    logInfo("Final intercept " + intercept)
 
-    logInfo("Final model weights " + model.weights.mkString(","))
-    logInfo("Final model intercept " + model.intercept)
-    model
+    createModel(weights, intercept)
   }
 }
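
The core of the fix above: the optimizer's solution is split into an intercept
plus weights only when an all-ones column was actually prepended. A standalone
sketch of that split, with a hypothetical flag and solution vector:

    val addIntercept = false                      // mirrors GLM's addIntercept setting
    val weightsWithIntercept = Array(10.1, 9.8)   // hypothetical optimizer output

    val (intercept, weights) = if (addIntercept) {
      (weightsWithIntercept(0), weightsWithIntercept.tail)  // first entry is the intercept
    } else {
      (0.0, weightsWithIntercept)                           // entire vector is feature weights
    }
    // Here intercept == 0.0 and both entries remain in weights, instead of the
    // first one being silently consumed as an intercept.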

http://git-wip-us.apache.org/repos/asf/spark/blob/d679843a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index fb2bc9b..be63ce8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -36,8 +36,10 @@ class LassoModel(
   extends GeneralizedLinearModel(weights, intercept)
   with RegressionModel with Serializable {
 
-  override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
-      intercept: Double) = {
+  override def predictPoint(
+      dataMatrix: DoubleMatrix,
+      weightMatrix: DoubleMatrix,
+      intercept: Double): Double = {
     dataMatrix.dot(weightMatrix) + intercept
   }
 }
@@ -66,7 +68,7 @@ class LassoWithSGD private (
     .setMiniBatchFraction(miniBatchFraction)
 
   // We don't want to penalize the intercept, so set this to false.
-  setIntercept(false)
+  super.setIntercept(false)
 
   var yMean = 0.0
   var xColMean: DoubleMatrix = _
@@ -77,10 +79,16 @@ class LassoWithSGD private (
    */
   def this() = this(1.0, 100, 1.0, 1.0)
 
-  def createModel(weights: Array[Double], intercept: Double) = {
-    val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) 
++ weights):_*)
+  override def setIntercept(addIntercept: Boolean): this.type = {
+    // TODO: Support adding intercept.
+    if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.")
+    this
+  }
+
+  override def createModel(weights: Array[Double], intercept: Double) = {
+    val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*)
     val weightsScaled = weightsMat.div(xColSd)
-    val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0))
+    val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)
 
     new LassoModel(weightsScaled.data, interceptScaled)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/d679843a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 8ee40ad..f5f15d1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -31,13 +31,14 @@ import org.jblas.DoubleMatrix
  * @param intercept Intercept computed for this model.
  */
 class LinearRegressionModel(
-                  override val weights: Array[Double],
-                  override val intercept: Double)
-  extends GeneralizedLinearModel(weights, intercept)
-  with RegressionModel with Serializable {
-
-  override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
-                            intercept: Double) = {
+    override val weights: Array[Double],
+    override val intercept: Double)
+  extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable {
+
+  override def predictPoint(
+      dataMatrix: DoubleMatrix,
+      weightMatrix: DoubleMatrix,
+      intercept: Double): Double = {
     dataMatrix.dot(weightMatrix) + intercept
   }
 }
@@ -55,8 +56,7 @@ class LinearRegressionWithSGD private (
     var stepSize: Double,
     var numIterations: Int,
     var miniBatchFraction: Double)
-  extends GeneralizedLinearAlgorithm[LinearRegressionModel]
-  with Serializable {
+  extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable {
 
   val gradient = new LeastSquaresGradient()
   val updater = new SimpleUpdater()
@@ -69,7 +69,7 @@ class LinearRegressionWithSGD private (
    */
   def this() = this(1.0, 100, 1.0)
 
-  def createModel(weights: Array[Double], intercept: Double) = {
+  override def createModel(weights: Array[Double], intercept: Double) = {
     new LinearRegressionModel(weights, intercept)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d679843a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index c504d3d..feb100f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -36,8 +36,10 @@ class RidgeRegressionModel(
   extends GeneralizedLinearModel(weights, intercept)
   with RegressionModel with Serializable {
 
-  override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
-                            intercept: Double) = {
+  override def predictPoint(
+      dataMatrix: DoubleMatrix,
+      weightMatrix: DoubleMatrix,
+      intercept: Double): Double = {
     dataMatrix.dot(weightMatrix) + intercept
   }
 }
@@ -67,7 +69,7 @@ class RidgeRegressionWithSGD private (
     .setMiniBatchFraction(miniBatchFraction)
 
  // We don't want to penalize the intercept in RidgeRegression, so set this to false.
-  setIntercept(false)
+  super.setIntercept(false)
 
   var yMean = 0.0
   var xColMean: DoubleMatrix = _
@@ -78,8 +80,14 @@ class RidgeRegressionWithSGD private (
    */
   def this() = this(1.0, 100, 1.0, 1.0)
 
-  def createModel(weights: Array[Double], intercept: Double) = {
-    val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) 
++ weights):_*)
+  override def setIntercept(addIntercept: Boolean): this.type = {
+    // TODO: Support adding intercept.
+    if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.")
+    this
+  }
+
+  override def createModel(weights: Array[Double], intercept: Double) = {
+    val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*)
     val weightsScaled = weightsMat.div(xColSd)
     val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)
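
The createModel change above keeps the rescaling math but no longer assumes the
first weight is an intercept. A standalone sketch of that rescaling with
hypothetical column statistics (jblas only, using the same calls as the patch):

    import org.jblas.DoubleMatrix

    val weights  = Array(2.0, 3.0)                   // weights learned on standardized features
    val xColMean = new DoubleMatrix(2, 1, 1.0, 4.0)  // per-column feature means (assumed values)
    val xColSd   = new DoubleMatrix(2, 1, 0.5, 2.0)  // per-column feature standard deviations
    val yMean    = 10.0                              // label mean (assumed value)

    val weightsMat      = new DoubleMatrix(weights.length, 1, weights: _*)
    val weightsScaled   = weightsMat.div(xColSd)     // undo the feature standardization
    val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)
    // weightsScaled.data and interceptScaled would feed new RidgeRegressionModel(...)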
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d679843a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index 64e4cbb..2cebac9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -17,11 +17,8 @@
 
 package org.apache.spark.mllib.regression
 
-
-import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
 
-import org.apache.spark.SparkContext
 import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
 
 class LassoSuite extends FunSuite with LocalSparkContext {
@@ -104,4 +101,10 @@ class LassoSuite extends FunSuite with LocalSparkContext {
     // Test prediction on Array.
     validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
   }
+
+  test("do not support intercept") {
+    intercept[UnsupportedOperationException] {
+      new LassoWithSGD().setIntercept(true)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d679843a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 281f9df..5d251bc 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.mllib.regression
 
-import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
@@ -57,4 +56,29 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
     // Test prediction on Array.
     validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
   }
+
+  // Test if we can correctly learn Y = 10*X1 + 10*X2
+  test("linear regression without intercept") {
+    val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput(
+      0.0, Array(10.0, 10.0), 100, 42), 2).cache()
+    val linReg = new LinearRegressionWithSGD().setIntercept(false)
+    linReg.optimizer.setNumIterations(1000).setStepSize(1.0)
+
+    val model = linReg.run(testRDD)
+
+    assert(model.intercept === 0.0)
+    assert(model.weights.length === 2)
+    assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0)
+    assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0)
+
+    val validationData = LinearDataGenerator.generateLinearInput(
+      0.0, Array(10.0, 10.0), 100, 17)
+    val validationRDD = sc.parallelize(validationData, 2).cache()
+
+    // Test prediction on RDD.
+    validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
+
+    // Test prediction on Array.
+    validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d679843a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index 67dd06c..b2044ed 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -17,14 +17,11 @@
 
 package org.apache.spark.mllib.regression
 
-
 import org.jblas.DoubleMatrix
-import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
 
-
 class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
 
   def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = {
@@ -74,4 +71,10 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
     assert(ridgeErr < linearErr,
       "ridgeError (" + ridgeErr + ") was not less than linearError(" + 
linearErr + ")")
   }
+
+  test("do not support intercept") {
+    intercept[UnsupportedOperationException] {
+      new RidgeRegressionWithSGD().setIntercept(true)
+    }
+  }
 }
