Repository: spark
Updated Branches:
  refs/heads/master 21385d02a -> 393db655c


[SPARK-15858][ML] Fix calculating error by tree stack over flow prob…

## What changes were proposed in this pull request?

What changes were proposed in this pull request?

Improve the evaluateEachIteration function in MLlib, as it fails when trying to 
calculate the per-tree error for a model that has more than 500 trees.

## How was this patch tested?

The patch was tested on a production data set (2K rows x 2K features), training a 
gradient boosted model without validation with a maxIterations setting of 1000, then 
producing the per-tree error. The new patch was able to perform the 
calculation within 30 seconds, whereas previously it would take hours and then fail.

**PS**: It would be better if this PR could be cherry-picked into the release 
branches 1.6.1 and 2.0.

Author: Mahmoud Rawas <mhmo...@gmail.com>
Author: Mahmoud Rawas <mahmoud.ra...@quantium.com.au>

Closes #13624 from mhmoudr/SPARK-15858.master.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/393db655
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/393db655
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/393db655

Branch: refs/heads/master
Commit: 393db655c3c43155305fbba1b2f8c48a95f18d93
Parents: 21385d0
Author: Mahmoud Rawas <mhmo...@gmail.com>
Authored: Wed Jun 29 13:12:17 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Wed Jun 29 13:12:17 2016 +0100

----------------------------------------------------------------------
 .../ml/tree/impl/GradientBoostedTrees.scala     | 40 ++++++++++----------
 .../mllib/tree/model/treeEnsembleModels.scala   | 37 ++++++++----------
 2 files changed, 34 insertions(+), 43 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/393db655/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index a0faff2..7bef899 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -205,31 +205,29 @@ private[spark] object GradientBoostedTrees extends 
Logging {
       case _ => data
     }
 
-    val numIterations = trees.length
-    val evaluationArray = Array.fill(numIterations)(0.0)
-    val localTreeWeights = treeWeights
-
-    var predictionAndError = computeInitialPredictionAndError(
-      remappedData, localTreeWeights(0), trees(0), loss)
-
-    evaluationArray(0) = predictionAndError.values.mean()
-
     val broadcastTrees = sc.broadcast(trees)
-    (1 until numIterations).foreach { nTree =>
-      predictionAndError = remappedData.zip(predictionAndError).mapPartitions 
{ iter =>
-        val currentTree = broadcastTrees.value(nTree)
-        val currentTreeWeight = localTreeWeights(nTree)
-        iter.map { case (point, (pred, error)) =>
-          val newPred = updatePrediction(point.features, pred, currentTree, 
currentTreeWeight)
-          val newError = loss.computeError(newPred, point.label)
-          (newPred, newError)
-        }
+    val localTreeWeights = treeWeights
+    val treesIndices = trees.indices
+
+    val dataCount = remappedData.count()
+    val evaluation = remappedData.map { point =>
+      treesIndices.map { idx =>
+        val prediction = broadcastTrees.value(idx)
+          .rootNode
+          .predictImpl(point.features)
+          .prediction
+        prediction * localTreeWeights(idx)
       }
-      evaluationArray(nTree) = predictionAndError.values.mean()
+      .scanLeft(0.0)(_ + _).drop(1)
+      .map(prediction => loss.computeError(prediction, point.label))
     }
+    .aggregate(treesIndices.map(_ => 0.0))(
+      (aggregated, row) => treesIndices.map(idx => aggregated(idx) + row(idx)),
+      (a, b) => treesIndices.map(idx => a(idx) + b(idx)))
+    .map(_ / dataCount)
 
-    broadcastTrees.unpersist()
-    evaluationArray
+    broadcastTrees.destroy()
+    evaluation.toArray
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/393db655/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index f7d9b22..657ed0a 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -151,31 +151,24 @@ class GradientBoostedTreesModel @Since("1.2.0") (
       case _ => data
     }
 
-    val numIterations = trees.length
-    val evaluationArray = Array.fill(numIterations)(0.0)
-    val localTreeWeights = treeWeights
-
-    var predictionAndError = 
GradientBoostedTreesModel.computeInitialPredictionAndError(
-      remappedData, localTreeWeights(0), trees(0), loss)
-
-    evaluationArray(0) = predictionAndError.values.mean()
-
     val broadcastTrees = sc.broadcast(trees)
-    (1 until numIterations).foreach { nTree =>
-      predictionAndError = remappedData.zip(predictionAndError).mapPartitions 
{ iter =>
-        val currentTree = broadcastTrees.value(nTree)
-        val currentTreeWeight = localTreeWeights(nTree)
-        iter.map { case (point, (pred, error)) =>
-          val newPred = pred + currentTree.predict(point.features) * 
currentTreeWeight
-          val newError = loss.computeError(newPred, point.label)
-          (newPred, newError)
-        }
-      }
-      evaluationArray(nTree) = predictionAndError.values.mean()
+    val localTreeWeights = treeWeights
+    val treesIndices = trees.indices
+
+    val dataCount = remappedData.count()
+    val evaluation = remappedData.map { point =>
+      treesIndices
+        .map(idx => broadcastTrees.value(idx).predict(point.features) * 
localTreeWeights(idx))
+        .scanLeft(0.0)(_ + _).drop(1)
+        .map(prediction => loss.computeError(prediction, point.label))
     }
+    .aggregate(treesIndices.map(_ => 0.0))(
+      (aggregated, row) => treesIndices.map(idx => aggregated(idx) + row(idx)),
+      (a, b) => treesIndices.map(idx => a(idx) + b(idx)))
+    .map(_ / dataCount)
 
-    broadcastTrees.unpersist()
-    evaluationArray
+    broadcastTrees.destroy()
+    evaluation.toArray
   }
 
   override protected def formatVersion: String = 
GradientBoostedTreesModel.formatVersion


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to