Repository: spark
Updated Branches:
  refs/heads/master a95043b17 -> 25e271d9f


[SPARK-6025] [MLlib] Add helper method evaluateEachIteration to extract 
learning curve

Added evaluateEachIteration to GradientBoostedTreesModel so that users can manually
extract the error for each iteration of gradient boosting. Optimising the internal
implementation can be dealt with later.

Author: MechCoder <manojkumarsivaraj...@gmail.com>

Closes #4906 from MechCoder/spark-6025 and squashes the following commits:

67146ab [MechCoder] Minor
352001f [MechCoder] Minor
6e8aa10 [MechCoder] Made the following changes: used mapPartitions instead of map;
refactored computeError; unpersisted broadcast variables
bc99ac6 [MechCoder] Refactor the method and stuff
dbda033 [MechCoder] [SPARK-6025] Add helper method evaluateEachIteration to 
extract learning curve


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/25e271d9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/25e271d9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/25e271d9

Branch: refs/heads/master
Commit: 25e271d9fbb3394931d23822a1b2020e9d9b46b3
Parents: a95043b
Author: MechCoder <manojkumarsivaraj...@gmail.com>
Authored: Fri Mar 20 17:14:09 2015 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Fri Mar 20 17:14:09 2015 -0700

----------------------------------------------------------------------
 docs/mllib-ensembles.md                         |  4 +-
 .../spark/mllib/tree/loss/AbsoluteError.scala   | 17 ++----
 .../apache/spark/mllib/tree/loss/LogLoss.scala  | 20 ++------
 .../org/apache/spark/mllib/tree/loss/Loss.scala | 14 ++++-
 .../spark/mllib/tree/loss/SquaredError.scala    | 17 ++----
 .../mllib/tree/model/treeEnsembleModels.scala   | 54 ++++++++++++++++++++
 .../mllib/tree/GradientBoostedTreesSuite.scala  | 16 +++++-
 7 files changed, 96 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/25e271d9/docs/mllib-ensembles.md
----------------------------------------------------------------------
diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md
index cbfb682..7521fb1 100644
--- a/docs/mllib-ensembles.md
+++ b/docs/mllib-ensembles.md
@@ -464,8 +464,8 @@ first one being the training dataset and the second being 
the validation dataset
 The training is stopped when the improvement in the validation error is not 
more than a certain tolerance
 (supplied by the `validationTol` argument in `BoostingStrategy`). In practice, 
the validation error
 decreases initially and later increases. There might be cases in which the 
validation error does not change monotonically,
-and the user is advised to set a large enough negative tolerance and examine 
the validation curve to to tune the number of
-iterations.
+and the user is advised to set a large enough negative tolerance and examine 
the validation curve using `evaluateEachIteration`
+(which gives the error or loss per iteration) to tune the number of iterations.
 
 ### Examples
 

http://git-wip-us.apache.org/repos/asf/spark/blob/25e271d9/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
index d1bde15..793dd66 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -47,18 +47,9 @@ object AbsoluteError extends Loss {
     if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
   }
 
-  /**
-   * Method to calculate loss of the base learner for the gradient boosting 
calculation.
-   * Note: This method is not used by the gradient boosting algorithm but is 
useful for debugging
-   * purposes.
-   * @param model Ensemble model
-   * @param data Training dataset: RDD of 
[[org.apache.spark.mllib.regression.LabeledPoint]].
-   * @return  Mean absolute error of model on data
-   */
-  override def computeError(model: TreeEnsembleModel, data: 
RDD[LabeledPoint]): Double = {
-    data.map { y =>
-      val err = model.predict(y.features) - y.label
-      math.abs(err)
-    }.mean()
+  /**
+   * Absolute error of a single prediction: |label - prediction|.
+   * @param prediction Predicted label.
+   * @param label True label.
+   * @return Absolute error of this prediction.
+   */
+  override def computeError(prediction: Double, label: Double): Double = {
+    val err = label - prediction
+    math.abs(err)
   }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/25e271d9/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 55213e6..51b1aed 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -50,20 +50,10 @@ object LogLoss extends Loss {
     - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction))
   }
 
-  /**
-   * Method to calculate loss of the base learner for the gradient boosting 
calculation.
-   * Note: This method is not used by the gradient boosting algorithm but is 
useful for debugging
-   * purposes.
-   * @param model Ensemble model
-   * @param data Training dataset: RDD of 
[[org.apache.spark.mllib.regression.LabeledPoint]].
-   * @return Mean log loss of model on data
-   */
-  override def computeError(model: TreeEnsembleModel, data: 
RDD[LabeledPoint]): Double = {
-    data.map { case point =>
-      val prediction = model.predict(point.features)
-      val margin = 2.0 * point.label * prediction
-      // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more 
numerically stable.
-      2.0 * MLUtils.log1pExp(-margin)
-    }.mean()
+  /**
+   * Log loss of a single prediction: 2 * log(1 + exp(-2 * label * prediction)).
+   * NOTE(review): the margin formula presumes labels in {-1, +1}; callers such as
+   * evaluateEachIteration remap 0/1 labels before calling this — confirm for other callers.
+   * @param prediction Predicted value (raw ensemble output).
+   * @param label True label.
+   * @return Log loss of this prediction.
+   */
+  override def computeError(prediction: Double, label: Double): Double = {
+    val margin = 2.0 * label * prediction
+    // Equivalent to 2.0 * log(1 + exp(-margin)), but numerically stable for large |margin|.
+    2.0 * MLUtils.log1pExp(-margin)
   }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/25e271d9/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index e1169d9..357869f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -47,6 +47,18 @@ trait Loss extends Serializable {
    * @param data Training dataset: RDD of 
[[org.apache.spark.mllib.regression.LabeledPoint]].
    * @return Measure of model error on data
    */
-  def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double
+  def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double 
= {
+    // Mean over the dataset of the per-point loss of model.predict(features).
+    data.map(point => computeError(model.predict(point.features), 
point.label)).mean()
+  }
+
+  /**
+   * Method to calculate loss when the predictions are already known.
+   * Note: evaluateEachIteration uses this to avoid recomputing the predicted values of
+   * previously fit trees, and the RDD-based computeError above delegates to it per point.
+   * Implementations define the per-point loss; the RDD overload only averages it.
+   * @param prediction Predicted label.
+   * @param label True label.
+   * @return Measure of model error on datapoint.
+   */
+  def computeError(prediction: Double, label: Double): Double
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/25e271d9/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
index 50ecaa2..b990707 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -47,18 +47,9 @@ object SquaredError extends Loss {
     2.0 * (model.predict(point.features) - point.label)
   }
 
-  /**
-   * Method to calculate loss of the base learner for the gradient boosting 
calculation.
-   * Note: This method is not used by the gradient boosting algorithm but is 
useful for debugging
-   * purposes.
-   * @param model Ensemble model
-   * @param data Training dataset: RDD of 
[[org.apache.spark.mllib.regression.LabeledPoint]].
-   * @return  Mean squared error of model on data
-   */
-  override def computeError(model: TreeEnsembleModel, data: 
RDD[LabeledPoint]): Double = {
-    data.map { y =>
-      val err = model.predict(y.features) - y.label
-      err * err
-    }.mean()
+  /**
+   * Squared error of a single prediction: (prediction - label)^2.
+   * @param prediction Predicted label.
+   * @param label True label.
+   * @return Squared error of this prediction.
+   */
+  override def computeError(prediction: Double, label: Double): Double = {
+    val err = prediction - label
+    err * err
   }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/25e271d9/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index f160852..1950254 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -28,9 +28,11 @@ import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.mllib.tree.loss.Loss
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
@@ -108,6 +110,58 @@ class GradientBoostedTreesModel(
   }
 
   override protected def formatVersion: String = 
TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+
+  /**
+   * Method to compute error or loss for every iteration of gradient boosting.
+   *
+   * The data is scanned once: for each point the ensemble prediction is accumulated one tree
+   * at a time, and the loss after tree i is added to slot i. Each tree is therefore evaluated
+   * exactly once per point. For classification models each label y is mapped to 2y - 1
+   * (0/1 -> -1/+1) before the loss is computed.
+   *
+   * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+   * @param loss evaluation metric.
+   * @return an array with index i having the losses or errors for the ensemble
+   *         containing the first i+1 trees (empty if the model has no trees; entries are
+   *         NaN if `data` is empty)
+   */
+  def evaluateEachIteration(
+      data: RDD[LabeledPoint],
+      loss: Loss): Array[Double] = {
+
+    val sc = data.sparkContext
+    val remappedData = algo match {
+      case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+      case _ => data
+    }
+
+    val numIterations = trees.length
+    // Ship the trees once per executor and copy the weights into a local val, so that the
+    // task closures capture neither `this` nor the full model.
+    val broadcastTrees = sc.broadcast(trees)
+    val localTreeWeights = treeWeights
+
+    // A single aggregate replaces the previous chain of uncached zip/mapPartitions RDDs.
+    // Spark recomputed that chain from iteration 0 on every mean(), i.e. O(T^2) work.
+    // aggregate() permits seqOp/combOp to mutate and return their first argument.
+    val (lossSums, count) = remappedData.aggregate((new Array[Double](numIterations), 0L))(
+      { case ((sums, n), point) =>
+        val localTrees = broadcastTrees.value
+        var pred = 0.0
+        var i = 0
+        while (i < numIterations) {
+          pred += localTrees(i).predict(point.features) * localTreeWeights(i)
+          sums(i) += loss.computeError(pred, point.label)
+          i += 1
+        }
+        (sums, n + 1L)
+      },
+      { case ((sums1, n1), (sums2, n2)) =>
+        var i = 0
+        while (i < numIterations) {
+          sums1(i) += sums2(i)
+          i += 1
+        }
+        (sums1, n1 + n2)
+      })
+
+    broadcastTrees.unpersist()
+    // Per-iteration mean loss; 0.0 / 0 yields NaN when the dataset is empty.
+    lossSums.map(_ / count)
+  }
+
 }
 
 object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {

http://git-wip-us.apache.org/repos/asf/spark/blob/25e271d9/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index b437aea..55b0bac 100644
--- 
a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -175,10 +175,11 @@ class GradientBoostedTreesSuite extends FunSuite with 
MLlibTestSparkContext {
           new BoostingStrategy(treeStrategy, loss, numIterations, 
validationTol = 0.0)
         val gbtValidate = new GradientBoostedTrees(boostingStrategy)
           .runWithValidation(trainRdd, validateRdd)
-        assert(gbtValidate.numTrees !== numIterations)
+        val numTrees = gbtValidate.numTrees
+        assert(numTrees !== numIterations)
 
         // Test that it performs better on the validation dataset.
-        val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
+        val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
         val (errorWithoutValidation, errorWithValidation) = {
           if (algo == Classification) {
             val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * 
x.label - 1, x.features))
@@ -188,6 +189,17 @@ class GradientBoostedTreesSuite extends FunSuite with 
MLlibTestSparkContext {
           }
         }
         assert(errorWithValidation <= errorWithoutValidation)
+
+        // Test that the learning curve from evaluateEachIteration agrees with runWithValidation.
+        // Note that validationTol (not convergenceTol) is set to 0.0 in boostingStrategy, so
+        // runWithValidation stops once the validation error stops improving. The assertions
+        // below check that the curve is non-increasing over the first numTrees entries and
+        // higher at index numTrees.
+        // NOTE(review): evaluationArray(numTrees) relies on numTrees < numIterations; the
+        // earlier assert only checks numTrees !== numIterations.
+        val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
+        assert(evaluationArray.length === numIterations)
+        assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
+        var i = 1
+        while (i < numTrees) {
+          assert(evaluationArray(i) <= evaluationArray(i - 1))
+          i += 1
+        }
       }
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to