Repository: spark
Updated Branches:
  refs/heads/branch-1.0 692dd6936 -> 5812472c1


[SPARK-1157][MLlib] L-BFGS Optimizer based on Breeze's implementation.

This PR uses Breeze's L-BFGS implementation; the Breeze dependency has already 
been introduced by Xiangrui's sparse input format work in SPARK-1212. Nice work, 
@mengxr!

When used with a regularized updater, we need to compute regVal and regGradient 
(the gradient of the regularization part of the cost function). With the current 
updater design, we can compute those two values in the following way.

Let's review how the updater works when returning newWeights given the input 
parameters.

    w' = w - thisIterStepSize * (gradient + regGradient(w))

Note that regGradient is a function of w!
If we set gradient = 0 and thisIterStepSize = 1, then

    regGradient(w) = w - w'

As a result, regVal can be computed by

    val regVal = updater.compute(
      weights,
      new DoubleMatrix(initialWeights.length, 1), 0, 1, regParam)._2

and regGradient can be obtained by

    val regGradient = weights.sub(
      updater.compute(weights, new DoubleMatrix(initialWeights.length, 1), 1, 1, regParam)._1)
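
To make the trick concrete, here is a minimal self-contained sketch that does 
not depend on Spark. The object RegFromUpdater and its l2Update method are 
hypothetical stand-ins for an updater with a squared-L2 penalty, written only 
for illustration; they assume the update rule stated above and a per-iteration 
step size of stepSize / sqrt(iter), and are not the actual MLlib classes.

```scala
// Hypothetical stand-in for an MLlib-style updater; not the actual Spark class.
object RegFromUpdater {

  /** One updater step for a squared-L2 penalty; returns (newWeights, regVal). */
  def l2Update(
      weights: Array[Double],
      gradient: Array[Double],
      stepSize: Double,
      iter: Int,
      regParam: Double): (Array[Double], Double) = {
    // Assumed step-size schedule: stepSize / sqrt(iter).
    val thisIterStepSize = stepSize / math.sqrt(iter.toDouble)
    // w' = w - thisIterStepSize * (gradient + regGradient(w)), with regGradient(w) = regParam * w
    val newWeights = weights.zip(gradient).map { case (w, g) =>
      w - thisIterStepSize * (g + regParam * w)
    }
    // The regularization value is evaluated at the new weights.
    val regVal = 0.5 * regParam * newWeights.map(x => x * x).sum
    (newWeights, regVal)
  }

  /** regVal at `weights`: zero gradient and step size 0 leave the weights unchanged. */
  def regVal(weights: Array[Double], regParam: Double): Double =
    l2Update(weights, new Array[Double](weights.length), 0.0, 1, regParam)._2

  /** regGradient(w) = w - w', using zero gradient and thisIterStepSize = 1. */
  def regGradient(weights: Array[Double], regParam: Double): Array[Double] = {
    val (stepped, _) = l2Update(weights, new Array[Double](weights.length), 1.0, 1, regParam)
    weights.zip(stepped).map { case (w, s) => w - s }
  }
}
```

For the squared-L2 case, regGradient(w) should equal regParam * w. With 
w = (2.0, -4.0) and regParam = 0.5 the sketch gives regGradient = (1.0, -2.0) 
and regVal = 0.5 * 0.5 * (4 + 16) = 5.0.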

The PR includes tests that compare the results with SGD, with and without 
regularization.

We compared L-BFGS with SGD and often saw 10x fewer steps with L-BFGS,
while the cost per step is the same (just computing the gradient).

The following paper by Prof. Ng at Stanford compares different
optimizers, including L-BFGS and SGD. They use them in the context of
deep learning, but it is worth reading as a reference.
http://cs.stanford.edu/~jngiam/papers/LeNgiamCoatesLahiriProchnowNg2011.pdf

Author: DB Tsai <dbt...@alpinenow.com>

Closes #353 from dbtsai/dbtsai-LBFGS and squashes the following commits:

984b18e [DB Tsai] L-BFGS Optimizer based on Breeze's implementation. Also fixed 
indentation issue in GradientDescent optimizer.
(cherry picked from commit 6843d637e72e5262d05cfa2b1935152743f2bd5a)

Signed-off-by: Patrick Wendell <pwend...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5812472c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5812472c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5812472c

Branch: refs/heads/branch-1.0
Commit: 5812472c14f60b968bbb44e6cb27ffd9f9ab66e6
Parents: 692dd69
Author: DB Tsai <dbt...@alpinenow.com>
Authored: Tue Apr 15 11:12:47 2014 -0700
Committer: Patrick Wendell <pwend...@gmail.com>
Committed: Tue Apr 15 11:12:59 2014 -0700

----------------------------------------------------------------------
 .../mllib/optimization/GradientDescent.scala    |  28 +-
 .../apache/spark/mllib/optimization/LBFGS.scala | 263 +++++++++++++++++++
 .../spark/mllib/optimization/LBFGSSuite.scala   | 203 ++++++++++++++
 3 files changed, 480 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5812472c/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index f60417f..c75909b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -34,8 +34,8 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector}
  */
 @DeveloperApi
 class GradientDescent(private var gradient: Gradient, private var updater: Updater)
-  extends Optimizer with Logging
-{
+  extends Optimizer with Logging {
+
   private var stepSize: Double = 1.0
   private var numIterations: Int = 100
   private var regParam: Double = 0.0
@@ -139,26 +139,26 @@ object GradientDescent extends Logging {
    *         stochastic loss computed for every iteration.
    */
   def runMiniBatchSGD(
-    data: RDD[(Double, Vector)],
-    gradient: Gradient,
-    updater: Updater,
-    stepSize: Double,
-    numIterations: Int,
-    regParam: Double,
-    miniBatchFraction: Double,
-    initialWeights: Vector): (Vector, Array[Double]) = {
+      data: RDD[(Double, Vector)],
+      gradient: Gradient,
+      updater: Updater,
+      stepSize: Double,
+      numIterations: Int,
+      regParam: Double,
+      miniBatchFraction: Double,
+      initialWeights: Vector): (Vector, Array[Double]) = {
 
     val stochasticLossHistory = new ArrayBuffer[Double](numIterations)
 
-    val nexamples: Long = data.count()
-    val miniBatchSize = nexamples * miniBatchFraction
+    val numExamples = data.count()
+    val miniBatchSize = numExamples * miniBatchFraction
 
     // Initialize weights as a column vector
     var weights = Vectors.dense(initialWeights.toArray)
 
     /**
-     * For the first iteration, the regVal will be initialized as sum of sqrt of
-     * weights if it's L2 update; for L1 update; the same logic is followed.
+     * For the first iteration, the regVal will be initialized as sum of weight squares
+     * if it's L2 updater; for L1 updater, the same logic is followed.
      */
     var regVal = updater.compute(
       weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2

http://git-wip-us.apache.org/repos/asf/spark/blob/5812472c/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
new file mode 100644
index 0000000..969a0c5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.optimization
+
+import scala.collection.mutable.ArrayBuffer
+
+import breeze.linalg.{DenseVector => BDV, axpy}
+import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS}
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
+
+/**
+ * :: DeveloperApi ::
+ * Class used to solve an optimization problem using Limited-memory BFGS.
+ * Reference: [[http://en.wikipedia.org/wiki/Limited-memory_BFGS]]
+ * @param gradient Gradient function to be used.
+ * @param updater Updater to be used to update weights after every iteration.
+ */
+@DeveloperApi
+class LBFGS(private var gradient: Gradient, private var updater: Updater)
+  extends Optimizer with Logging {
+
+  private var numCorrections = 10
+  private var convergenceTol = 1E-4
+  private var maxNumIterations = 100
+  private var regParam = 0.0
+  private var miniBatchFraction = 1.0
+
+  /**
+   * Set the number of corrections used in the LBFGS update. Default 10.
+   * Values of numCorrections less than 3 are not recommended; large values
+   * of numCorrections will result in excessive computing time.
+   * 3 < numCorrections < 10 is recommended.
+   * Restriction: numCorrections > 0
+   */
+  def setNumCorrections(corrections: Int): this.type = {
+    assert(corrections > 0)
+    this.numCorrections = corrections
+    this
+  }
+
+  /**
+   * Set fraction of data to be used for each L-BFGS iteration. Default 1.0.
+   */
+  def setMiniBatchFraction(fraction: Double): this.type = {
+    this.miniBatchFraction = fraction
+    this
+  }
+
+  /**
+   * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4.
+   * Smaller value will lead to higher accuracy with the cost of more iterations.
+   */
+  def setConvergenceTol(tolerance: Double): this.type = {
+    this.convergenceTol = tolerance
+    this
+  }
+
+  /**
+   * Set the maximal number of iterations for L-BFGS. Default 100.
+   */
+  def setMaxNumIterations(iters: Int): this.type = {
+    this.maxNumIterations = iters
+    this
+  }
+
+  /**
+   * Set the regularization parameter. Default 0.0.
+   */
+  def setRegParam(regParam: Double): this.type = {
+    this.regParam = regParam
+    this
+  }
+
+  /**
+   * Set the gradient function (of the loss function of one single data example)
+   * to be used for L-BFGS.
+   */
+  def setGradient(gradient: Gradient): this.type = {
+    this.gradient = gradient
+    this
+  }
+
+  /**
+   * Set the updater function to actually perform a gradient step in a given direction.
+   * The updater is responsible to perform the update from the regularization term as well,
+   * and therefore determines what kind of regularization is used, if any.
+   */
+  def setUpdater(updater: Updater): this.type = {
+    this.updater = updater
+    this
+  }
+
+  override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = {
+    val (weights, _) = LBFGS.runMiniBatchLBFGS(
+      data,
+      gradient,
+      updater,
+      numCorrections,
+      convergenceTol,
+      maxNumIterations,
+      regParam,
+      miniBatchFraction,
+      initialWeights)
+    weights
+  }
+
+}
+
+/**
+ * :: DeveloperApi ::
+ * Top-level method to run L-BFGS.
+ */
+@DeveloperApi
+object LBFGS extends Logging {
+  /**
+   * Run Limited-memory BFGS (L-BFGS) in parallel using mini batches.
+   * In each iteration, we sample a subset (fraction miniBatchFraction) of the total data
+   * in order to compute a gradient estimate.
+   * Sampling, and averaging the subgradients over this subset is performed using one standard
+   * spark map-reduce in each iteration.
+   *
+   * @param data - Input data for L-BFGS. RDD of the set of data examples, each of
+   *               the form (label, [feature values]).
+   * @param gradient - Gradient object (used to compute the gradient of the loss function of
+   *                   one single data example)
+   * @param updater - Updater function to actually perform a gradient step in a given direction.
+   * @param numCorrections - The number of corrections used in the L-BFGS update.
+   * @param convergenceTol - The convergence tolerance of iterations for L-BFGS
+   * @param maxNumIterations - Maximal number of iterations that L-BFGS can be run.
+   * @param regParam - Regularization parameter
+   * @param miniBatchFraction - Fraction of the input data set that should be used for
+   *                          one iteration of L-BFGS. Default value 1.0.
+   *
+   * @return A tuple containing two elements. The first element is a column matrix containing
+   *         weights for every feature, and the second element is an array containing the loss
+   *         computed for every iteration.
+   */
+  def runMiniBatchLBFGS(
+      data: RDD[(Double, Vector)],
+      gradient: Gradient,
+      updater: Updater,
+      numCorrections: Int,
+      convergenceTol: Double,
+      maxNumIterations: Int,
+      regParam: Double,
+      miniBatchFraction: Double,
+      initialWeights: Vector): (Vector, Array[Double]) = {
+
+    val lossHistory = new ArrayBuffer[Double](maxNumIterations)
+
+    val numExamples = data.count()
+    val miniBatchSize = numExamples * miniBatchFraction
+
+    val costFun =
+      new CostFun(data, gradient, updater, regParam, miniBatchFraction, lossHistory, miniBatchSize)
+
+    val lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol)
+
+    val weights = Vectors.fromBreeze(
+      lbfgs.minimize(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector))
+
+    logInfo("LBFGS.runMiniBatchLBFGS finished. Last 10 losses %s".format(
+      lossHistory.takeRight(10).mkString(", ")))
+
+    (weights, lossHistory.toArray)
+  }
+
+  /**
+   * CostFun implements Breeze's DiffFunction[T], which returns the loss and gradient
+   * at a particular point (weights). It's used in Breeze's convex optimization routines.
+   */
+  private class CostFun(
+    data: RDD[(Double, Vector)],
+    gradient: Gradient,
+    updater: Updater,
+    regParam: Double,
+    miniBatchFraction: Double,
+    lossHistory: ArrayBuffer[Double],
+    miniBatchSize: Double) extends DiffFunction[BDV[Double]] {
+
+    private var i = 0
+
+    override def calculate(weights: BDV[Double]) = {
+      // Have a local copy to avoid the serialization of CostFun object which is not serializable.
+      val localData = data
+      val localGradient = gradient
+
+      val (gradientSum, lossSum) = localData.sample(false, miniBatchFraction, 42 + i)
+        .aggregate((BDV.zeros[Double](weights.size), 0.0))(
+          seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
+            val l = localGradient.compute(
+              features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad))
+            (grad, loss + l)
+          },
+          combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
+            (grad1 += grad2, loss1 + loss2)
+          })
+
+      /**
+       * regVal is sum of weight squares if it's L2 updater;
+       * for other updater, the same logic is followed.
+       */
+      val regVal = updater.compute(
+        Vectors.fromBreeze(weights),
+        Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2
+
+      val loss = lossSum / miniBatchSize + regVal
+      /**
+       * It will return the gradient part of regularization using updater.
+       *
+       * Given the input parameters, the updater basically does the following,
+       *
+       * w' = w - thisIterStepSize * (gradient + regGradient(w))
+       * Note that regGradient is function of w
+       *
+       * If we set gradient = 0, thisIterStepSize = 1, then
+       *
+       * regGradient(w) = w - w'
+       *
+       * TODO: We need to clean it up by separating the logic of regularization out
+       *       from updater to regularizer.
+       */
+      // The following gradientTotal is actually the regularization part of gradient.
+      // Will add the gradientSum computed from the data with weights in the next step.
+      val gradientTotal = weights - updater.compute(
+        Vectors.fromBreeze(weights),
+        Vectors.dense(new Array[Double](weights.size)), 1, 1, regParam)._1.toBreeze
+
+      // gradientTotal = gradientSum / miniBatchSize + gradientTotal
+      axpy(1.0 / miniBatchSize, gradientSum, gradientTotal)
+
+      /**
+       * NOTE: lossSum and loss are computed using the weights from the previous iteration
+       * and regVal is the regularization value computed in the previous iteration as well.
+       */
+      lossHistory.append(loss)
+
+      i += 1
+
+      (loss, gradientTotal)
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/5812472c/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
new file mode 100644
index 0000000..f33770a
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.optimization
+
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.LocalSparkContext
+
+class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+
+  val nPoints = 10000
+  val A = 2.0
+  val B = -1.5
+
+  val initialB = -1.0
+  val initialWeights = Array(initialB)
+
+  val gradient = new LogisticGradient()
+  val numCorrections = 10
+  val miniBatchFrac = 1.0
+
+  val simpleUpdater = new SimpleUpdater()
+  val squaredL2Updater = new SquaredL2Updater()
+
+  // Add an extra variable consisting of all 1.0's for the intercept.
+  val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42)
+  val data = testData.map { case LabeledPoint(label, features) =>
+    label -> Vectors.dense(1.0, features.toArray: _*)
+  }
+
+  lazy val dataRDD = sc.parallelize(data, 2).cache()
+
+  def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = {
+    math.abs(x - y) / (math.abs(y) + 1e-15) < tol
+  }
+
+  test("LBFGS loss should be decreasing and match the result of Gradient Descent.") {
+    val regParam = 0
+
+    val initialWeightsWithIntercept = Vectors.dense(1.0, initialWeights: _*)
+    val convergenceTol = 1e-12
+    val maxNumIterations = 10
+
+    val (_, loss) = LBFGS.runMiniBatchLBFGS(
+      dataRDD,
+      gradient,
+      simpleUpdater,
+      numCorrections,
+      convergenceTol,
+      maxNumIterations,
+      regParam,
+      miniBatchFrac,
+      initialWeightsWithIntercept)
+
+    // Since the cost function is convex, the loss is guaranteed to be monotonically decreasing
+    // with L-BFGS optimizer.
+    // (SGD doesn't guarantee this, and the loss will be fluctuating in the optimization process.)
+    assert((loss, loss.tail).zipped.forall(_ > _), "loss should be monotonically decreasing.")
+
+    val stepSize = 1.0
+    // Well, GD converges slower, so it requires more iterations!
+    val numGDIterations = 50
+    val (_, lossGD) = GradientDescent.runMiniBatchSGD(
+      dataRDD,
+      gradient,
+      simpleUpdater,
+      stepSize,
+      numGDIterations,
+      regParam,
+      miniBatchFrac,
+      initialWeightsWithIntercept)
+
+    // GD converges way slower than L-BFGS. To achieve 1% difference,
+    // it requires 90 iterations in GD. No matter how much we increase
+    // the number of iterations in GD here, the lossGD will always be
+    // larger than lossLBFGS. This is based on observation, not theoretically guaranteed.
+    assert(Math.abs((lossGD.last - loss.last) / loss.last) < 0.02,
+      "LBFGS should match GD result within 2% difference.")
+  }
+
+  test("LBFGS and Gradient Descent with L2 regularization should get the same result.") {
+    val regParam = 0.2
+
+    // Prepare another set of non-zero weights to compare the loss in the first iteration.
+    val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12)
+    val convergenceTol = 1e-12
+    val maxNumIterations = 10
+
+    val (weightLBFGS, lossLBFGS) = LBFGS.runMiniBatchLBFGS(
+      dataRDD,
+      gradient,
+      squaredL2Updater,
+      numCorrections,
+      convergenceTol,
+      maxNumIterations,
+      regParam,
+      miniBatchFrac,
+      initialWeightsWithIntercept)
+
+    val numGDIterations = 50
+    val stepSize = 1.0
+    val (weightGD, lossGD) = GradientDescent.runMiniBatchSGD(
+      dataRDD,
+      gradient,
+      squaredL2Updater,
+      stepSize,
+      numGDIterations,
+      regParam,
+      miniBatchFrac,
+      initialWeightsWithIntercept)
+
+    assert(compareDouble(lossGD(0), lossLBFGS(0)),
+      "The first losses of LBFGS and GD should be the same.")
+
+    // The 2% difference here is based on observation, but is not theoretically guaranteed.
+    assert(compareDouble(lossGD.last, lossLBFGS.last, 0.02),
+      "The last losses of LBFGS and GD should be within 2% difference.")
+
+    assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) &&
+      compareDouble(weightLBFGS(1), weightGD(1), 0.02),
+      "The weight differences between LBFGS and GD should be within 2%.")
+  }
+
+  test("The convergence criteria should work as we expect.") {
+    val regParam = 0.0
+
+    /**
+     * For the first run, we set the convergenceTol to 0.0, so that the algorithm will
+     * run up to the maxNumIterations which is 8 here.
+     */
+    val initialWeightsWithIntercept = Vectors.dense(0.0, 0.0)
+    val maxNumIterations = 8
+    var convergenceTol = 0.0
+
+    val (_, lossLBFGS1) = LBFGS.runMiniBatchLBFGS(
+      dataRDD,
+      gradient,
+      squaredL2Updater,
+      numCorrections,
+      convergenceTol,
+      maxNumIterations,
+      regParam,
+      miniBatchFrac,
+      initialWeightsWithIntercept)
+
+    // Note that the first loss is computed with initial weights,
+    // so the total numbers of loss will be numbers of iterations + 1
+    assert(lossLBFGS1.length == 9)
+
+    convergenceTol = 0.1
+    val (_, lossLBFGS2) = LBFGS.runMiniBatchLBFGS(
+      dataRDD,
+      gradient,
+      squaredL2Updater,
+      numCorrections,
+      convergenceTol,
+      maxNumIterations,
+      regParam,
+      miniBatchFrac,
+      initialWeightsWithIntercept)
+
+    // Based on observation, lossLBFGS2 runs 3 iterations, not theoretically guaranteed.
+    assert(lossLBFGS2.length == 4)
+    assert((lossLBFGS2(2) - lossLBFGS2(3)) / lossLBFGS2(2) < convergenceTol)
+
+    convergenceTol = 0.01
+    val (_, lossLBFGS3) = LBFGS.runMiniBatchLBFGS(
+      dataRDD,
+      gradient,
+      squaredL2Updater,
+      numCorrections,
+      convergenceTol,
+      maxNumIterations,
+      regParam,
+      miniBatchFrac,
+      initialWeightsWithIntercept)
+
+    // With smaller convergenceTol, it takes more steps.
+    assert(lossLBFGS3.length > lossLBFGS2.length)
+
+    // Based on observation, lossLBFGS3 runs 5 iterations, not theoretically guaranteed.
+    assert(lossLBFGS3.length == 6)
+    assert((lossLBFGS3(4) - lossLBFGS3(5)) / lossLBFGS3(4) < convergenceTol)
+  }
+}
