Repository: spark
Updated Branches:
  refs/heads/branch-1.5 6d46e9b7c -> b3117d312


[SPARK-5133] [ML] Added featureImportance to RandomForestClassifier and 
Regressor

Added featureImportances to RandomForestClassificationModel and
RandomForestRegressionModel (the models fit by RandomForestClassifier and
RandomForestRegressor).

This follows the scikit-learn implementation here: 
[https://github.com/scikit-learn/scikit-learn/blob/a95203b249c1cf392f86d001ad999e29b2392739/sklearn/tree/_tree.pyx#L3341]

CC: yanboliang  Would you mind taking a look?  Thanks!

Author: Joseph K. Bradley <jos...@databricks.com>
Author: Feynman Liang <fli...@databricks.com>

Closes #7838 from jkbradley/dt-feature-importance and squashes the following 
commits:

72a167a [Joseph K. Bradley] fixed unit test
86cea5f [Joseph K. Bradley] Modified RF featuresImportances to return Vector 
instead of Map
5aa74f0 [Joseph K. Bradley] finally fixed unit test for real
33df5db [Joseph K. Bradley] fix unit test
42a2d3b [Joseph K. Bradley] fix unit test
fe94e72 [Joseph K. Bradley] modified feature importance unit tests
cc693ee [Feynman Liang] Add classifier tests
79a6f87 [Feynman Liang] Compare dense vectors in test
21d01fc [Feynman Liang] Added failing SKLearn test
ac0b254 [Joseph K. Bradley] Added featureImportance to 
RandomForestClassifier/Regressor.  Need to add unit tests

(cherry picked from commit ff9169a002f1b75231fd25b7d04157a912503038)
Signed-off-by: Joseph K. Bradley <jos...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b3117d31
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b3117d31
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b3117d31

Branch: refs/heads/branch-1.5
Commit: b3117d312332af3b4bd416857f632cacb5230feb
Parents: 6d46e9b
Author: Joseph K. Bradley <jos...@databricks.com>
Authored: Mon Aug 3 12:17:46 2015 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Mon Aug 3 12:17:56 2015 -0700

----------------------------------------------------------------------
 .../classification/RandomForestClassifier.scala |  30 +++++-
 .../ml/regression/RandomForestRegressor.scala   |  33 ++++--
 .../scala/org/apache/spark/ml/tree/Node.scala   |  19 +++-
 .../spark/ml/tree/impl/RandomForest.scala       |  92 ++++++++++++++++
 .../org/apache/spark/ml/tree/treeModels.scala   |   6 ++
 .../JavaRandomForestClassifierSuite.java        |   2 +
 .../JavaRandomForestRegressorSuite.java         |   2 +
 .../RandomForestClassifierSuite.scala           |  31 +++++-
 .../org/apache/spark/ml/impl/TreeTests.scala    |  18 ++++
 .../regression/RandomForestRegressorSuite.scala |  27 ++++-
 .../spark/ml/tree/impl/RandomForestSuite.scala  | 107 +++++++++++++++++++
 11 files changed, 351 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b3117d31/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 56e80cc..b59826a 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -95,7 +95,8 @@ final class RandomForestClassifier(override val uid: String)
     val trees =
       RandomForest.run(oldDataset, strategy, getNumTrees, 
getFeatureSubsetStrategy, getSeed)
         .map(_.asInstanceOf[DecisionTreeClassificationModel])
-    new RandomForestClassificationModel(trees, numClasses)
+    val numFeatures = oldDataset.first().features.size
+    new RandomForestClassificationModel(trees, numFeatures, numClasses)
   }
 
   override def copy(extra: ParamMap): RandomForestClassifier = 
defaultCopy(extra)
@@ -118,11 +119,13 @@ object RandomForestClassifier {
  * features.
  * @param _trees  Decision trees in the ensemble.
  *               Warning: These have null parents.
+ * @param numFeatures  Number of features used by this model
  */
 @Experimental
 final class RandomForestClassificationModel private[ml] (
     override val uid: String,
     private val _trees: Array[DecisionTreeClassificationModel],
+    val numFeatures: Int,
     override val numClasses: Int)
   extends ProbabilisticClassificationModel[Vector, 
RandomForestClassificationModel]
   with TreeEnsembleModel with Serializable {
@@ -133,8 +136,8 @@ final class RandomForestClassificationModel private[ml] (
    * Construct a random forest classification model, with all trees weighted 
equally.
    * @param trees  Component trees
    */
-  def this(trees: Array[DecisionTreeClassificationModel], numClasses: Int) =
-    this(Identifiable.randomUID("rfc"), trees, numClasses)
+  def this(trees: Array[DecisionTreeClassificationModel], numFeatures: Int, 
numClasses: Int) =
+    this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)
 
   override def trees: Array[DecisionTreeModel] = 
_trees.asInstanceOf[Array[DecisionTreeModel]]
 
@@ -182,13 +185,30 @@ final class RandomForestClassificationModel private[ml] (
   }
 
   override def copy(extra: ParamMap): RandomForestClassificationModel = {
-    copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), 
extra)
+    copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, 
numClasses), extra)
   }
 
   override def toString: String = {
     s"RandomForestClassificationModel with $numTrees trees"
   }
 
+  /**
+   * Estimate of the importance of each feature.
+   *
+   * This generalizes the idea of "Gini" importance to other losses,
+   * following the explanation of Gini importance from "Random Forests" 
documentation
+   * by Leo Breiman and Adele Cutler, and following the implementation from 
scikit-learn.
+   *
+   * This feature importance is calculated as follows:
+   *  - Average over trees:
+   *     - importance(feature j) = sum (over nodes which split on feature j) 
of the gain,
+   *       where gain is scaled by the number of instances passing through node
+   *     - Normalize importances for tree to sum to 1 (divide each value by the
+   *       sum of that tree's importances).
+   *  - Normalize feature importance vector to sum to 1.
+   */
+  lazy val featureImportances: Vector = RandomForest.featureImportances(trees, 
numFeatures)
+
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldRandomForestModel = {
     new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
@@ -210,6 +230,6 @@ private[ml] object RandomForestClassificationModel {
       DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
     }
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
-    new RandomForestClassificationModel(uid, newTrees, numClasses)
+    new RandomForestClassificationModel(uid, newTrees, -1, numClasses)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b3117d31/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 17fb1ad..1ee43c8 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.tree.model.{RandomForestModel 
=> OldRandomForestMo
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
+
 
 /**
  * :: Experimental ::
@@ -87,7 +87,8 @@ final class RandomForestRegressor(override val uid: String)
     val trees =
       RandomForest.run(oldDataset, strategy, getNumTrees, 
getFeatureSubsetStrategy, getSeed)
         .map(_.asInstanceOf[DecisionTreeRegressionModel])
-    new RandomForestRegressionModel(trees)
+    val numFeatures = oldDataset.first().features.size
+    new RandomForestRegressionModel(trees, numFeatures)
   }
 
   override def copy(extra: ParamMap): RandomForestRegressor = 
defaultCopy(extra)
@@ -108,11 +109,13 @@ object RandomForestRegressor {
  * [[http://en.wikipedia.org/wiki/Random_forest  Random Forest]] model for 
regression.
  * It supports both continuous and categorical features.
  * @param _trees  Decision trees in the ensemble.
+ * @param numFeatures  Number of features used by this model
  */
 @Experimental
 final class RandomForestRegressionModel private[ml] (
     override val uid: String,
-    private val _trees: Array[DecisionTreeRegressionModel])
+    private val _trees: Array[DecisionTreeRegressionModel],
+    val numFeatures: Int)
   extends PredictionModel[Vector, RandomForestRegressionModel]
   with TreeEnsembleModel with Serializable {
 
@@ -122,7 +125,8 @@ final class RandomForestRegressionModel private[ml] (
    * Construct a random forest regression model, with all trees weighted 
equally.
    * @param trees  Component trees
    */
-  def this(trees: Array[DecisionTreeRegressionModel]) = 
this(Identifiable.randomUID("rfr"), trees)
+  def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) =
+    this(Identifiable.randomUID("rfr"), trees, numFeatures)
 
   override def trees: Array[DecisionTreeModel] = 
_trees.asInstanceOf[Array[DecisionTreeModel]]
 
@@ -147,13 +151,30 @@ final class RandomForestRegressionModel private[ml] (
   }
 
   override def copy(extra: ParamMap): RandomForestRegressionModel = {
-    copyValues(new RandomForestRegressionModel(uid, _trees), extra)
+    copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), 
extra)
   }
 
   override def toString: String = {
     s"RandomForestRegressionModel with $numTrees trees"
   }
 
+  /**
+   * Estimate of the importance of each feature.
+   *
+   * This generalizes the idea of "Gini" importance to other losses,
+   * following the explanation of Gini importance from "Random Forests" 
documentation
+   * by Leo Breiman and Adele Cutler, and following the implementation from 
scikit-learn.
+   *
+   * This feature importance is calculated as follows:
+   *  - Average over trees:
+   *     - importance(feature j) = sum (over nodes which split on feature j) 
of the gain,
+   *       where gain is scaled by the number of instances passing through node
+   *     - Normalize importances for tree to sum to 1 (divide each value by the
+   *       sum of that tree's importances).
+   *  - Normalize feature importance vector to sum to 1.
+   */
+  lazy val featureImportances: Vector = RandomForest.featureImportances(trees, 
numFeatures)
+
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldRandomForestModel = {
     new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
@@ -173,6 +194,6 @@ private[ml] object RandomForestRegressionModel {
       // parent for each tree is null since there is no good way to set this.
       DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
     }
-    new RandomForestRegressionModel(parent.uid, newTrees)
+    new RandomForestRegressionModel(parent.uid, newTrees, -1)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b3117d31/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index 8879352..cd24931 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -44,7 +44,7 @@ sealed abstract class Node extends Serializable {
    * and probabilities.
    * For classification, the array of class counts must be normalized to a 
probability distribution.
    */
-  private[tree] def impurityStats: ImpurityCalculator
+  private[ml] def impurityStats: ImpurityCalculator
 
   /** Recursive prediction helper method */
   private[ml] def predictImpl(features: Vector): LeafNode
@@ -72,6 +72,12 @@ sealed abstract class Node extends Serializable {
    * @param id  Node ID using old format IDs
    */
   private[ml] def toOld(id: Int): OldNode
+
+  /**
+   * Trace down the tree, and return the largest feature index used in any 
split.
+   * @return  Max feature index used in a split, or -1 if there are no splits 
(single leaf node).
+   */
+  private[ml] def maxSplitFeatureIndex(): Int
 }
 
 private[ml] object Node {
@@ -109,7 +115,7 @@ private[ml] object Node {
 final class LeafNode private[ml] (
     override val prediction: Double,
     override val impurity: Double,
-    override val impurityStats: ImpurityCalculator) extends Node {
+    override private[ml] val impurityStats: ImpurityCalculator) extends Node {
 
   override def toString: String =
     s"LeafNode(prediction = $prediction, impurity = $impurity)"
@@ -129,6 +135,8 @@ final class LeafNode private[ml] (
     new OldNode(id, new OldPredict(prediction, prob = 
impurityStats.prob(prediction)),
       impurity, isLeaf = true, None, None, None, None)
   }
+
+  override private[ml] def maxSplitFeatureIndex(): Int = -1
 }
 
 /**
@@ -150,7 +158,7 @@ final class InternalNode private[ml] (
     val leftChild: Node,
     val rightChild: Node,
     val split: Split,
-    override val impurityStats: ImpurityCalculator) extends Node {
+    override private[ml] val impurityStats: ImpurityCalculator) extends Node {
 
   override def toString: String = {
     s"InternalNode(prediction = $prediction, impurity = $impurity, split = 
$split)"
@@ -190,6 +198,11 @@ final class InternalNode private[ml] (
         new OldPredict(leftChild.prediction, prob = 0.0),
         new OldPredict(rightChild.prediction, prob = 0.0))))
   }
+
+  override private[ml] def maxSplitFeatureIndex(): Int = {
+    math.max(split.featureIndex,
+      math.max(leftChild.maxSplitFeatureIndex(), 
rightChild.maxSplitFeatureIndex()))
+  }
 }
 
 private object InternalNode {

http://git-wip-us.apache.org/repos/asf/spark/blob/b3117d31/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index a8b90d9..4ac51a4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -26,6 +26,7 @@ import org.apache.spark.Logging
 import org.apache.spark.ml.classification.DecisionTreeClassificationModel
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree._
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => 
OldStrategy}
 import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, 
DecisionTreeMetadata,
@@ -34,6 +35,7 @@ import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
 import org.apache.spark.mllib.tree.model.ImpurityStats
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.collection.OpenHashMap
 import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
 
 
@@ -1113,4 +1115,94 @@ private[ml] object RandomForest extends Logging {
     }
   }
 
+  /**
+   * Given a Random Forest model, compute the importance of each feature.
+   * This generalizes the idea of "Gini" importance to other losses,
+   * following the explanation of Gini importance from "Random Forests" 
documentation
+   * by Leo Breiman and Adele Cutler, and following the implementation from 
scikit-learn.
+   *
+   * This feature importance is calculated as follows:
+   *  - Average over trees:
+   *     - importance(feature j) = sum (over nodes which split on feature j) 
of the gain,
+   *       where gain is scaled by the number of instances passing through node
+   *     - Normalize importances for tree to sum to 1 (divide each value by the
+   *       sum of that tree's importances).
+   *  - Normalize feature importance vector to sum to 1.
+   *
+   * Note: This should not be used with Gradient-Boosted Trees.  It only makes 
sense for
+   *       independently trained trees.
+   * @param trees  Unweighted forest of trees
+   * @param numFeatures  Number of features in model (even if not all are 
explicitly used by
+   *                     the model).
+   *                     If -1, then numFeatures is set based on the max 
feature index in all trees.
+   * @return  Feature importance values, of length numFeatures.
+   */
+  private[ml] def featureImportances(trees: Array[DecisionTreeModel], 
numFeatures: Int): Vector = {
+    val totalImportances = new OpenHashMap[Int, Double]()
+    trees.foreach { tree =>
+      // Aggregate feature importance vector for this tree
+      val importances = new OpenHashMap[Int, Double]()
+      computeFeatureImportance(tree.rootNode, importances)
+      // Normalize importance vector for this tree, and add it to total.
+      // TODO: In the future, also support normalizing by 
tree.rootNode.impurityStats.count?
+      val treeNorm = importances.map(_._2).sum
+      if (treeNorm != 0) {
+        importances.foreach { case (idx, impt) =>
+          val normImpt = impt / treeNorm
+          totalImportances.changeValue(idx, normImpt, _ + normImpt)
+        }
+      }
+    }
+    // Normalize importances
+    normalizeMapValues(totalImportances)
+    // Construct vector
+    val d = if (numFeatures != -1) {
+      numFeatures
+    } else {
+      // Find max feature index used in trees
+      val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
+      maxFeatureIndex + 1
+    }
+    if (d == 0) {
+      assert(totalImportances.size == 0, s"Unknown error in computing 
RandomForest feature" +
+        s" importance: No splits in forest, but some non-zero importances.")
+    }
+    val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
+    Vectors.sparse(d, indices.toArray, values.toArray)
+  }
+
+  /**
+   * Recursive method for computing feature importances for one tree.
+   * This walks down the tree, adding to the importance of 1 feature at each 
node.
+   * @param node  Current node in recursion
+   * @param importances  Aggregate feature importances, modified by this method
+   */
+  private[impl] def computeFeatureImportance(
+      node: Node,
+      importances: OpenHashMap[Int, Double]): Unit = {
+    node match {
+      case n: InternalNode =>
+        val feature = n.split.featureIndex
+        val scaledGain = n.gain * n.impurityStats.count
+        importances.changeValue(feature, scaledGain, _ + scaledGain)
+        computeFeatureImportance(n.leftChild, importances)
+        computeFeatureImportance(n.rightChild, importances)
+      case n: LeafNode =>
+        // do nothing
+    }
+  }
+
+  /**
+   * Normalize the values of this map to sum to 1, in place.
+   * If all values are 0, this method does nothing.
+   * @param map  Map with non-negative values.
+   */
+  private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
+    val total = map.map(_._2).sum
+    if (total != 0) {
+      val keys = map.iterator.map(_._1).toArray
+      keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b3117d31/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 2287390..b771911 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -53,6 +53,12 @@ private[ml] trait DecisionTreeModel {
     val header = toString + "\n"
     header + rootNode.subtreeToString(2)
   }
+
+  /**
+   * Trace down the tree, and return the largest feature index used in any 
split.
+   * @return  Max feature index used in a split, or -1 if there are no splits 
(single leaf node).
+   */
+  private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex()
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/b3117d31/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
index 32d0b38..a66a1e1 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.impl.TreeTests;
 import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.DataFrame;
 
@@ -85,6 +86,7 @@ public class JavaRandomForestClassifierSuite implements 
Serializable {
     model.toDebugString();
     model.trees();
     model.treeWeights();
+    Vector importances = model.featureImportances();
 
     /*
     // TODO: Add test once save/load are implemented.  SPARK-6725

http://git-wip-us.apache.org/repos/asf/spark/blob/b3117d31/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
index e306eba..a00ce5e 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.classification.LogisticRegressionSuite;
 import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.DataFrame;
 
@@ -85,6 +86,7 @@ public class JavaRandomForestRegressorSuite implements 
Serializable {
     model.toDebugString();
     model.trees();
     model.treeWeights();
+    Vector importances = model.featureImportances();
 
     /*
     // TODO: Add test once save/load are implemented.   SPARK-6725

http://git-wip-us.apache.org/repos/asf/spark/blob/b3117d31/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index edf848b..6ca4b5a 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -67,7 +67,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with 
MLlibTestSparkConte
   test("params") {
     ParamsSuite.checkParams(new RandomForestClassifier)
     val model = new RandomForestClassificationModel("rfc",
-      Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, 
null), 2)), 2)
+      Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, 
null), 2)), 2, 2)
     ParamsSuite.checkParams(model)
   }
 
@@ -150,6 +150,35 @@ class RandomForestClassifierSuite extends SparkFunSuite 
with MLlibTestSparkConte
   }
 
   /////////////////////////////////////////////////////////////////////////////
+  // Tests of feature importance
+  /////////////////////////////////////////////////////////////////////////////
+  test("Feature importance with toy data") {
+    val numClasses = 2
+    val rf = new RandomForestClassifier()
+      .setImpurity("Gini")
+      .setMaxDepth(3)
+      .setNumTrees(3)
+      .setFeatureSubsetStrategy("all")
+      .setSubsamplingRate(1.0)
+      .setSeed(123)
+
+    // In this data, feature 1 is very important.
+    val data: RDD[LabeledPoint] = sc.parallelize(Seq(
+      new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
+      new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
+      new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
+      new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
+      new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
+    ))
+    val categoricalFeatures = Map.empty[Int, Int]
+    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 
numClasses)
+
+    val importances = rf.fit(df).featureImportances
+    val mostImportantFeature = importances.argmax
+    assert(mostImportantFeature === 1)
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b3117d31/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala 
b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
index 778abcb..460849c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -124,4 +124,22 @@ private[ml] object TreeTests extends SparkFunSuite {
         "checkEqual failed since the two tree ensembles were not identical")
     }
   }
+
+  /**
+   * Helper method for constructing a tree for testing.
+   * Given left, right children, construct a parent node.
+   * @param split  Split for parent node
+   * @return  Parent node with children attached
+   */
+  def buildParentNode(left: Node, right: Node, split: Split): Node = {
+    val leftImp = left.impurityStats
+    val rightImp = right.impurityStats
+    val parentImp = leftImp.copy.add(rightImp)
+    val leftWeight = leftImp.count / parentImp.count.toDouble
+    val rightWeight = rightImp.count / parentImp.count.toDouble
+    val gain = parentImp.calculate() -
+      (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate())
+    val pred = parentImp.predict
+    new InternalNode(pred, parentImp.calculate(), gain, left, right, split, 
parentImp)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b3117d31/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index b24ecaa..992ce95 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.regression
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => 
OldRandomForest}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -26,7 +27,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 
-
 /**
  * Test suite for [[RandomForestRegressor]].
  */
@@ -71,6 +71,31 @@ class RandomForestRegressorSuite extends SparkFunSuite with 
MLlibTestSparkContex
     regressionTestWithContinuousFeatures(rf)
   }
 
+  test("Feature importance with toy data") {
+    val rf = new RandomForestRegressor()
+      .setImpurity("variance")
+      .setMaxDepth(3)
+      .setNumTrees(3)
+      .setFeatureSubsetStrategy("all")
+      .setSubsamplingRate(1.0)
+      .setSeed(123)
+
+    // In this data, feature 1 is very important.
+    val data: RDD[LabeledPoint] = sc.parallelize(Seq(
+      new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
+      new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
+      new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
+      new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
+      new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
+    ))
+    val categoricalFeatures = Map.empty[Int, Int]
+    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+
+    val importances = rf.fit(df).featureImportances
+    val mostImportantFeature = importances.argmax
+    assert(mostImportantFeature === 1)
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/b3117d31/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
new file mode 100644
index 0000000..dc85279
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.classification.DecisionTreeClassificationModel
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, 
Node}
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.tree.impurity.GiniCalculator
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.collection.OpenHashMap
+
+/**
+ * Test suite for [[RandomForest]].
+ */
+class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  import RandomForestSuite.mapToVec
+
+  test("computeFeatureImportance, featureImportances") {
+    /* Build tree for testing, with this structure:
+          grandParent
+      left2       parent
+                left  right
+     */
+    val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0))
+    val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
+
+    val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0))
+    val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
+
+    val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 
0.5))
+    val parentImp = parent.impurityStats
+
+    val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0))
+    val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
+
+    val grandParent = TreeTests.buildParentNode(left2, parent, new 
ContinuousSplit(1, 1.0))
+    val grandImp = grandParent.impurityStats
+
+    // Test feature importance computed at different subtrees.
+    def testNode(node: Node, expected: Map[Int, Double]): Unit = {
+      val map = new OpenHashMap[Int, Double]()
+      RandomForest.computeFeatureImportance(node, map)
+      assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
+    }
+
+    // Leaf node
+    testNode(left, Map.empty[Int, Double])
+
+    // Internal node with 2 leaf children
+    val feature0importance = parentImp.calculate() * parentImp.count -
+      (leftImp.calculate() * leftImp.count + rightImp.calculate() * 
rightImp.count)
+    testNode(parent, Map(0 -> feature0importance))
+
+    // Full tree
+    val feature1importance = grandImp.calculate() * grandImp.count -
+      (left2Imp.calculate() * left2Imp.count + parentImp.calculate() * 
parentImp.count)
+    testNode(grandParent, Map(0 -> feature0importance, 1 -> 
feature1importance))
+
+    // Forest consisting of (full tree) + (internal node with 2 leafs)
+    val trees = Array(parent, grandParent).map { root =>
+      new DecisionTreeClassificationModel(root, numClasses = 
3).asInstanceOf[DecisionTreeModel]
+    }
+    val importances: Vector = RandomForest.featureImportances(trees, 2)
+    val tree2norm = feature0importance + feature1importance
+    val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0,
+      (feature1importance / tree2norm) / 2.0)
+    assert(importances ~== expected relTol 0.01)
+  }
+
+  test("normalizeMapValues") {
+    val map = new OpenHashMap[Int, Double]()
+    map(0) = 1.0
+    map(2) = 2.0
+    RandomForest.normalizeMapValues(map)
+    val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
+    assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
+  }
+
+}
+
+private object RandomForestSuite {
+
+  def mapToVec(map: Map[Int, Double]): Vector = {
+    val size = (map.keys.toSeq :+ 0).max + 1
+    val (indices, values) = map.toSeq.sortBy(_._1).unzip
+    Vectors.sparse(size, indices.toArray, values.toArray)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to