Repository: spark
Updated Branches:
  refs/heads/master cc18a7199 -> df78a934a


[SPARK-9835][ML] Implement IterativelyReweightedLeastSquares solver

Implement the ```IterativelyReweightedLeastSquares``` solver for GLM. I consider 
it a solver rather than an estimator, and it is only used internally, so I keep 
it ```private[ml]```.
There are two limitations in the current implementation compared with R:
* It does not support a ```Tuple``` as the response for the ```Binomial``` 
family, as in the following code:
```
glm( cbind(using, notUsing) ~  age + education + wantsMore , family = binomial)
```
* It does not support ```offset```.

Because ```RFormula``` does not support a ```Tuple``` label or the 
```offset``` keyword, I simplified the implementation. Adding support for 
these two features is not very hard, and I can do it in a follow-up PR if 
necessary. Meanwhile, we can also add an R-like statistical summary for 
IRLS.
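
As a rough illustration of the first limitation: in R, a two-column binomial 
response such as ```cbind(using, notUsing)``` gives the same fit as the observed 
proportion of successes used as the response with the number of trials as the 
prior weight. The sketch below shows that conversion. ```BinomialCounts```, 
```WeightedObs``` and ```fromCounts``` are hypothetical names used only for this 
example; they are not part of this patch or of Spark.

```scala
// Hypothetical helper, not part of this patch: turn (successes, failures)
// counts into a (label, weight) pair that a weighted binomial fit can consume.
object BinomialCounts {
  final case class WeightedObs(label: Double, weight: Double)

  def fromCounts(successes: Double, failures: Double): WeightedObs = {
    val total = successes + failures
    require(total > 0.0, "each row needs at least one trial")
    // label = observed proportion of successes, weight = number of trials
    WeightedObs(successes / total, total)
  }
}
```

The resulting label and weight could then be used as the ```label``` and 
```weight``` of an ```Instance```, assuming the reweight function accepts 
fractional labels.
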
The implementation refers to R, 
[statsmodels](https://github.com/statsmodels/statsmodels) and 
[sparkGLM](https://github.com/AlteryxLabs/sparkGLM).
Please focus on the main structure and overlook minor issues and docs, which I 
will update later. Any comments and opinions will be appreciated.
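
For readers who want to see the main structure without Spark, here is a minimal 
local sketch of the same loop for logistic regression with one feature and an 
intercept. It is an illustration under stated assumptions, not the code in this 
patch: ```IrlsSketch```, ```Obs```, ```wls``` and ```fit``` are hypothetical 
names, the weighted least squares step is solved with closed-form 2x2 normal 
equations, and the iteration starts from zero coefficients instead of an initial 
WLS model.

```scala
object IrlsSketch {
  // One observation: label in [0, 1], prior weight, and a single feature x.
  final case class Obs(label: Double, weight: Double, x: Double)

  // Weighted least squares for z ~ b0 + b1 * x via the 2x2 normal equations.
  // Each tuple is (working response z, working weight w, feature x).
  def wls(data: Seq[(Double, Double, Double)]): (Double, Double) = {
    val sw = data.map(_._2).sum
    val swx = data.map { case (_, w, x) => w * x }.sum
    val swxx = data.map { case (_, w, x) => w * x * x }.sum
    val swz = data.map { case (z, w, _) => w * z }.sum
    val swxz = data.map { case (z, w, x) => w * x * z }.sum
    val det = sw * swxx - swx * swx
    val b1 = (sw * swxz - swx * swz) / det
    val b0 = (swz - b1 * swx) / sw
    (b0, b1)
  }

  // Returns (intercept, slope).
  def fit(obs: Seq[Obs], maxIter: Int = 25, tol: Double = 1e-8): (Double, Double) = {
    var beta = (0.0, 0.0) // assumption: start from zero, so mu = 0.5 everywhere
    var converged = false
    var iter = 0
    while (iter < maxIter && !converged) {
      val (b0, b1) = beta
      // Reweight step: same formulas as BinomialReweightFunc in the test suite.
      val reweighted = obs.map { o =>
        val eta = b0 + b1 * o.x
        val mu = 1.0 / (1.0 + math.exp(-eta))
        val z = eta + (o.label - mu) / (mu * (1.0 - mu))
        val w = mu * (1.0 - mu) * o.weight
        (z, w, o.x)
      }
      val next = wls(reweighted)
      // Converged when the largest absolute coefficient change is below tol.
      val delta = math.max(math.abs(next._1 - b0), math.abs(next._2 - b1))
      converged = delta < tol
      beta = next
      iter += 1
    }
    beta
  }
}
```

Calling ```IrlsSketch.fit``` on a small, non-separable data set should return 
coefficients at which the weighted score equations are numerically zero, which 
is one simple way to sanity-check a reweight function before comparing against 
R.
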

cc mengxr jkbradley

Author: Yanbo Liang <yblia...@gmail.com>

Closes #10639 from yanboliang/spark-9835.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/df78a934
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/df78a934
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/df78a934

Branch: refs/heads/master
Commit: df78a934a07a4ce5af43243be9ba5fe60b91eee6
Parents: cc18a71
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Thu Jan 28 14:29:47 2016 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Thu Jan 28 14:29:47 2016 -0800

----------------------------------------------------------------------
 .../IterativelyReweightedLeastSquares.scala     | 108 ++++++++++
 .../spark/ml/optim/WeightedLeastSquares.scala   |   7 +-
 ...IterativelyReweightedLeastSquaresSuite.scala | 200 +++++++++++++++++++
 3 files changed, 314 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/df78a934/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala
new file mode 100644
index 0000000..6aa44e6
--- /dev/null
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.optim
+
+import org.apache.spark.Logging
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.rdd.RDD
+
+/**
+ * Model fitted by [[IterativelyReweightedLeastSquares]].
+ * @param coefficients model coefficients
+ * @param intercept model intercept
+ */
+private[ml] class IterativelyReweightedLeastSquaresModel(
+    val coefficients: DenseVector,
+    val intercept: Double) extends Serializable
+
+/**
+ * Implements the method of iteratively reweighted least squares (IRLS) which 
is used to solve
+ * certain optimization problems by an iterative method. In each step of the 
iterations, it
+ * involves solving a weighted least squares (WLS) problem by 
[[WeightedLeastSquares]].
+ * It can be used to find maximum likelihood estimates of a generalized linear 
model (GLM),
+ * find M-estimator in robust regression and other optimization problems.
+ *
+ * @param initialModel the initial guess model.
+ * @param reweightFunc the reweight function which is used to update offsets 
and weights
+ *                     at each iteration.
+ * @param fitIntercept whether to fit intercept.
+ * @param regParam L2 regularization parameter used by WLS.
+ * @param maxIter maximum number of iterations.
+ * @param tol the convergence tolerance.
+ *
+ * @see [[http://www.jstor.org/stable/2345503 P. J. Green, Iteratively 
Reweighted Least Squares
+ *     for Maximum Likelihood Estimation, and some Robust and Resistant 
Alternatives,
+ *     Journal of the Royal Statistical Society. Series B, 1984.]]
+ */
+private[ml] class IterativelyReweightedLeastSquares(
+    val initialModel: WeightedLeastSquaresModel,
+    val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, 
Double),
+    val fitIntercept: Boolean,
+    val regParam: Double,
+    val maxIter: Int,
+    val tol: Double) extends Logging with Serializable {
+
+  def fit(instances: RDD[Instance]): IterativelyReweightedLeastSquaresModel = {
+
+    var converged = false
+    var iter = 0
+
+    var model: WeightedLeastSquaresModel = initialModel
+    var oldModel: WeightedLeastSquaresModel = null
+
+    while (iter < maxIter && !converged) {
+
+      oldModel = model
+
+      // Update offsets and weights using reweightFunc
+      val newInstances = instances.map { instance =>
+        val (newOffset, newWeight) = reweightFunc(instance, oldModel)
+        Instance(newOffset, newWeight, instance.features)
+      }
+
+      // Estimate new model
+      model = new WeightedLeastSquares(fitIntercept, regParam, 
standardizeFeatures = false,
+        standardizeLabel = false).fit(newInstances)
+
+      // Check convergence
+      val oldCoefficients = oldModel.coefficients
+      val coefficients = model.coefficients
+      BLAS.axpy(-1.0, coefficients, oldCoefficients)
+      val maxTolOfCoefficients = oldCoefficients.toArray.reduce { (x, y) =>
+        math.max(math.abs(x), math.abs(y))
+      }
+      val maxTol = math.max(maxTolOfCoefficients, math.abs(oldModel.intercept 
- model.intercept))
+
+      if (maxTol < tol) {
+        converged = true
+        logInfo(s"IRLS converged in $iter iterations.")
+      }
+
+      logInfo(s"Iteration $iter : relative tolerance = $maxTol")
+      iter = iter + 1
+
+      if (iter == maxIter) {
+        logInfo(s"IRLS reached the max number of iterations: $maxIter.")
+      }
+
+    }
+
+    new IterativelyReweightedLeastSquaresModel(model.coefficients, 
model.intercept)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/df78a934/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala 
b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index 797870e..61b3642 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -31,7 +31,12 @@ import org.apache.spark.rdd.RDD
 private[ml] class WeightedLeastSquaresModel(
     val coefficients: DenseVector,
     val intercept: Double,
-    val diagInvAtWA: DenseVector) extends Serializable
+    val diagInvAtWA: DenseVector) extends Serializable {
+
+  def predict(features: Vector): Double = {
+    BLAS.dot(coefficients, features) + intercept
+  }
+}
 
 /**
  * Weighted least squares solver via normal equation.

http://git-wip-us.apache.org/repos/asf/spark/blob/df78a934/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala
new file mode 100644
index 0000000..6040212
--- /dev/null
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.optim
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.rdd.RDD
+
+class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with 
MLlibTestSparkContext {
+
+  private var instances1: RDD[Instance] = _
+  private var instances2: RDD[Instance] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    /*
+       R code:
+
+       A <- matrix(c(0, 1, 2, 3, 5, 2, 1, 3), 4, 2)
+       b <- c(1, 0, 1, 0)
+       w <- c(1, 2, 3, 4)
+     */
+    instances1 = sc.parallelize(Seq(
+      Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+      Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)),
+      Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)),
+      Instance(0.0, 4.0, Vectors.dense(3.0, 3.0))
+    ), 2)
+    /*
+       R code:
+
+       A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
+       b <- c(2, 8, 3, 9)
+       w <- c(1, 2, 3, 4)
+     */
+    instances2 = sc.parallelize(Seq(
+      Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+      Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)),
+      Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)),
+      Instance(9.0, 4.0, Vectors.dense(3.0, 13.0))
+    ), 2)
+  }
+
+  test("IRLS against GLM with Binomial errors") {
+    /*
+       R code:
+
+       df <- as.data.frame(cbind(A, b))
+       for (formula in c(b ~ . -1, b ~ .)) {
+         model <- glm(formula, family="binomial", data=df, weights=w)
+         print(as.vector(coef(model)))
+       }
+
+       [1] -0.30216651 -0.04452045
+       [1]  3.5651651 -1.2334085 -0.7348971
+     */
+    val expected = Seq(
+      Vectors.dense(0.0, -0.30216651, -0.04452045),
+      Vectors.dense(3.5651651, -1.2334085, -0.7348971))
+
+    import IterativelyReweightedLeastSquaresSuite._
+
+    var idx = 0
+    for (fitIntercept <- Seq(false, true)) {
+      val newInstances = instances1.map { instance =>
+        val mu = (instance.label + 0.5) / 2.0
+        val eta = math.log(mu / (1.0 - mu))
+        Instance(eta, instance.weight, instance.features)
+      }
+      val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0,
+        standardizeFeatures = false, standardizeLabel = 
false).fit(newInstances)
+      val irls = new IterativelyReweightedLeastSquares(initial, 
BinomialReweightFunc,
+        fitIntercept, regParam = 0.0, maxIter = 25, tol = 1e-8).fit(instances1)
+      val actual = Vectors.dense(irls.intercept, irls.coefficients(0), 
irls.coefficients(1))
+      assert(actual ~== expected(idx) absTol 1e-4)
+      idx += 1
+    }
+  }
+
+  test("IRLS against GLM with Poisson errors") {
+    /*
+       R code:
+
+       df <- as.data.frame(cbind(A, b))
+       for (formula in c(b ~ . -1, b ~ .)) {
+         model <- glm(formula, family="poisson", data=df, weights=w)
+         print(as.vector(coef(model)))
+       }
+
+       [1] -0.09607792  0.18375613
+       [1]  6.299947  3.324107 -1.081766
+     */
+    val expected = Seq(
+      Vectors.dense(0.0, -0.09607792, 0.18375613),
+      Vectors.dense(6.299947, 3.324107, -1.081766))
+
+    import IterativelyReweightedLeastSquaresSuite._
+
+    var idx = 0
+    for (fitIntercept <- Seq(false, true)) {
+      val yMean = instances2.map(_.label).mean
+      val newInstances = instances2.map { instance =>
+        val mu = (instance.label + yMean) / 2.0
+        val eta = math.log(mu)
+        Instance(eta, instance.weight, instance.features)
+      }
+      val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0,
+        standardizeFeatures = false, standardizeLabel = 
false).fit(newInstances)
+      val irls = new IterativelyReweightedLeastSquares(initial, 
PoissonReweightFunc,
+        fitIntercept, regParam = 0.0, maxIter = 25, tol = 1e-8).fit(instances2)
+      val actual = Vectors.dense(irls.intercept, irls.coefficients(0), 
irls.coefficients(1))
+      assert(actual ~== expected(idx) absTol 1e-4)
+      idx += 1
+    }
+  }
+
+  test("IRLS against L1Regression") {
+    /*
+       R code:
+
+       library(quantreg)
+
+       df <- as.data.frame(cbind(A, b))
+       for (formula in c(b ~ . -1, b ~ .)) {
+         model <- rq(formula, data=df, weights=w)
+         print(as.vector(coef(model)))
+       }
+
+       [1] 1.266667 0.400000
+       [1] 29.5 17.0 -5.5
+     */
+    val expected = Seq(
+      Vectors.dense(0.0, 1.266667, 0.400000),
+      Vectors.dense(29.5, 17.0, -5.5))
+
+    import IterativelyReweightedLeastSquaresSuite._
+
+    var idx = 0
+    for (fitIntercept <- Seq(false, true)) {
+      val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0,
+        standardizeFeatures = false, standardizeLabel = false).fit(instances2)
+      val irls = new IterativelyReweightedLeastSquares(initial, 
L1RegressionReweightFunc,
+        fitIntercept, regParam = 0.0, maxIter = 200, tol = 
1e-7).fit(instances2)
+      val actual = Vectors.dense(irls.intercept, irls.coefficients(0), 
irls.coefficients(1))
+      assert(actual ~== expected(idx) absTol 1e-4)
+      idx += 1
+    }
+  }
+}
+
+object IterativelyReweightedLeastSquaresSuite {
+
+  def BinomialReweightFunc(
+      instance: Instance,
+      model: WeightedLeastSquaresModel): (Double, Double) = {
+    val eta = model.predict(instance.features)
+    val mu = 1.0 / (1.0 + math.exp(-1.0 * eta))
+    val z = eta + (instance.label - mu) / (mu * (1.0 - mu))
+    val w = mu * (1 - mu) * instance.weight
+    (z, w)
+  }
+
+  def PoissonReweightFunc(
+      instance: Instance,
+      model: WeightedLeastSquaresModel): (Double, Double) = {
+    val eta = model.predict(instance.features)
+    val mu = math.exp(eta)
+    val z = eta + (instance.label - mu) / mu
+    val w = mu * instance.weight
+    (z, w)
+  }
+
+  def L1RegressionReweightFunc(
+      instance: Instance,
+      model: WeightedLeastSquaresModel): (Double, Double) = {
+    val eta = model.predict(instance.features)
+    val e = math.max(math.abs(eta - instance.label), 1e-7)
+    val w = 1 / e
+    val y = instance.label
+    (y, w)
+  }
+}


