Repository: spark
Updated Branches:
  refs/heads/master c6aa356cd -> e1772d3f1


[SPARK-11861][ML] Add feature importances for decision trees

This patch adds an API entry point for single decision tree feature importances.

Author: sethah <seth.hendrickso...@gmail.com>

Closes #9912 from sethah/SPARK-11861.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e1772d3f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e1772d3f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e1772d3f

Branch: refs/heads/master
Commit: e1772d3f19bed7e69a80de7900ed22d3eeb05300
Parents: c6aa356
Author: sethah <seth.hendrickso...@gmail.com>
Authored: Wed Mar 9 14:44:51 2016 -0800
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Wed Mar 9 14:44:51 2016 -0800

----------------------------------------------------------------------
 .../classification/DecisionTreeClassifier.scala | 19 +++++++++++++
 .../classification/RandomForestClassifier.scala |  4 +--
 .../ml/regression/DecisionTreeRegressor.scala   | 19 +++++++++++++
 .../ml/regression/RandomForestRegressor.scala   |  4 +--
 .../spark/ml/tree/impl/RandomForest.scala       | 30 ++++++++++++++++----
 .../DecisionTreeClassifierSuite.scala           | 21 ++++++++++++++
 .../RandomForestClassifierSuite.scala           | 10 ++-----
 .../org/apache/spark/ml/impl/TreeTests.scala    | 13 +++++++++
 .../regression/DecisionTreeRegressorSuite.scala | 20 +++++++++++++
 .../regression/RandomForestRegressorSuite.scala | 13 ++-------
 10 files changed, 126 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 8c4cec1..7f0397f 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -169,6 +169,25 @@ final class DecisionTreeClassificationModel private[ml] (
     s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with 
$numNodes nodes"
   }
 
+  /**
+   * Estimate of the importance of each feature, computed lazily on first access.
+   *
+   * This generalizes the idea of "Gini" importance to other losses,
+   * following the explanation of Gini importance from "Random Forests" 
documentation
+   * by Leo Breiman and Adele Cutler, and following the implementation from 
scikit-learn.
+   *
+   * This feature importance is calculated as follows:
+   *   - importance(feature j) = sum (over nodes which split on feature j) of 
the gain,
+   *     where gain is scaled by the number of instances passing through node
+   *   - Normalize importances for tree to sum to 1.
+   *
+   * Note: Feature importance for single decision trees can have high variance 
due to
+   *       correlated predictor variables. Consider using a 
[[RandomForestClassifier]]
+   *       to determine feature importance instead.
+   */
+  @Since("2.0.0")
+  lazy val featureImportances: Vector = RandomForest.featureImportances(this, 
numFeatures)
+
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldDecisionTreeModel = {
     new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification)

http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index f7d662d..5da04d3 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -230,10 +230,10 @@ final class RandomForestClassificationModel private[ml] (
    *  - Average over trees:
    *     - importance(feature j) = sum (over nodes which split on feature j) 
of the gain,
    *       where gain is scaled by the number of instances passing through node
-   *     - Normalize importances for tree based on total number of training 
instances used
-   *       to build tree.
+   *     - Normalize importances for tree to sum to 1.
    *  - Normalize feature importance vector to sum to 1.
    */
+  @Since("1.5.0")
   lazy val featureImportances: Vector = RandomForest.featureImportances(trees, 
numFeatures)
 
   /** (private[ml]) Convert to a model in the old API */

http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 18c94f3..897b233 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -169,6 +169,25 @@ final class DecisionTreeRegressionModel private[ml] (
     s"DecisionTreeRegressionModel (uid=$uid) of depth $depth with $numNodes 
nodes"
   }
 
+  /**
+   * Estimate of the importance of each feature, computed lazily on first access.
+   *
+   * This generalizes the idea of "Gini" importance to other losses,
+   * following the explanation of Gini importance from "Random Forests" 
documentation
+   * by Leo Breiman and Adele Cutler, and following the implementation from 
scikit-learn.
+   *
+   * This feature importance is calculated as follows:
+   *   - importance(feature j) = sum (over nodes which split on feature j) of 
the gain,
+   *     where gain is scaled by the number of instances passing through node
+   *   - Normalize importances for tree to sum to 1.
+   *
+   * Note: Feature importance for single decision trees can have high variance 
due to
+   *       correlated predictor variables. Consider using a 
[[RandomForestRegressor]]
+   *       to determine feature importance instead.
+   */
+  @Since("2.0.0")
+  lazy val featureImportances: Vector = RandomForest.featureImportances(this, 
numFeatures)
+
   /** Convert to a model in the old API */
   private[ml] def toOld: OldDecisionTreeModel = {
     new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression)

http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 71e40b5..798947b 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -189,10 +189,10 @@ final class RandomForestRegressionModel private[ml] (
    *  - Average over trees:
    *     - importance(feature j) = sum (over nodes which split on feature j) 
of the gain,
    *       where gain is scaled by the number of instances passing through node
-   *     - Normalize importances for tree based on total number of training 
instances used
-   *       to build tree.
+   *     - Normalize importances for tree to sum to 1.
    *  - Normalize feature importance vector to sum to 1.
    */
+  @Since("1.5.0")
   lazy val featureImportances: Vector = RandomForest.featureImportances(trees, 
numFeatures)
 
   /** (private[ml]) Convert to a model in the old API */

http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index ea733d5..f994c25 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -1089,12 +1089,9 @@ private[ml] object RandomForest extends Logging {
    *  - Average over trees:
    *     - importance(feature j) = sum (over nodes which split on feature j) 
of the gain,
    *       where gain is scaled by the number of instances passing through node
-   *     - Normalize importances for tree based on total number of training 
instances used
-   *       to build tree.
+   *     - Normalize importances for tree to sum to 1.
    *  - Normalize feature importance vector to sum to 1.
    *
-   * Note: This should not be used with Gradient-Boosted Trees.  It only makes 
sense for
-   *       independently trained trees.
    * @param trees  Unweighted forest of trees
    * @param numFeatures  Number of features in model (even if not all are 
explicitly used by
    *                     the model).
@@ -1128,14 +1125,35 @@ private[ml] object RandomForest extends Logging {
       maxFeatureIndex + 1
     }
     if (d == 0) {
-      assert(totalImportances.size == 0, s"Unknown error in computing 
RandomForest feature" +
-        s" importance: No splits in forest, but some non-zero importances.")
+      assert(totalImportances.size == 0, s"Unknown error in computing feature" 
+
+        s" importance: No splits found, but some non-zero importances.")
     }
     val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
     Vectors.sparse(d, indices.toArray, values.toArray)
   }
 
   /**
+   * Given a Decision Tree model, compute the importance of each feature.
+   * This generalizes the idea of "Gini" importance to other losses,
+   * following the explanation of Gini importance from "Random Forests" 
documentation
+   * by Leo Breiman and Adele Cutler, and following the implementation from 
scikit-learn.
+   *
+   * This feature importance is calculated as follows:
+   *  - importance(feature j) = sum (over nodes which split on feature j) of 
the gain,
+   *    where gain is scaled by the number of instances passing through node
+   *  - Normalize importances for tree to sum to 1.
+   *
+   * @param tree  Decision tree to compute importances for; evaluated as a one-tree forest.
+   * @param numFeatures  Number of features in model (even if not all are 
explicitly used by
+   *                     the model).
+   *                     If -1, then numFeatures is set based on the max 
feature index in the tree.
+   * @return  Feature importance values, of length numFeatures.
+   */
+  private[ml] def featureImportances(tree: DecisionTreeModel, numFeatures: 
Int): Vector = {
+    featureImportances(Array(tree), numFeatures)
+  }
+
+  /**
    * Recursive method for computing feature importances for one tree.
    * This walks down the tree, adding to the importance of 1 feature at each 
node.
    * @param node  Current node in recursion

http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 9169bcd..6d68364 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -313,6 +313,27 @@ class DecisionTreeClassifierSuite extends SparkFunSuite 
with MLlibTestSparkConte
     }
   }
 
+  test("Feature importance with toy data") {
+    val dt = new DecisionTreeClassifier()
+      .setImpurity("gini")
+      .setMaxDepth(3)
+      .setSeed(123)
+
+    // In this data, feature 1 is very important.
+    val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+    val numFeatures = data.first().features.size
+    // Valid feature indices are 0 until numFeatures; mark each one as binary categorical.
+    val categoricalFeatures = (0 until numFeatures).map(i => (i, 2)).toMap
+    val df = TreeTests.setMetadata(data, categoricalFeatures, 2)
+    val model = dt.fit(df)
+    val importances = model.featureImportances
+    val mostImportantFeature = importances.argmax
+    assert(mostImportantFeature === 1)
+    // The importances are normalized doubles: allow for rounding when checking the sum.
+    assert(math.abs(importances.toArray.sum - 1.0) < 1e-9)
+    assert(importances.toArray.forall(_ >= 0.0))
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index deb8ec7..6b810ab 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -167,19 +167,15 @@ class RandomForestClassifierSuite extends SparkFunSuite 
with MLlibTestSparkConte
       .setSeed(123)
 
     // In this data, feature 1 is very important.
-    val data: RDD[LabeledPoint] = sc.parallelize(Seq(
-      new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
-      new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
-      new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
-      new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
-      new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
-    ))
+    val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
     val categoricalFeatures = Map.empty[Int, Int]
     val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 
numClasses)
 
     val importances = rf.fit(df).featureImportances
     val mostImportantFeature = importances.argmax
     assert(mostImportantFeature === 1)
+    assert(importances.toArray.sum === 1.0)
+    assert(importances.toArray.forall(_ >= 0.0))
   }
 
   /////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala 
b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
index a808177..5561f6f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -19,10 +19,12 @@ package org.apache.spark.ml.impl
 
 import scala.collection.JavaConverters._
 
+import org.apache.spark.SparkContext
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, 
NumericAttribute}
 import org.apache.spark.ml.tree._
+import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, SQLContext}
@@ -141,4 +143,15 @@ private[ml] object TreeTests extends SparkFunSuite {
     val pred = parentImp.predict
     new InternalNode(pred, parentImp.calculate(), gain, left, right, split, 
parentImp)
   }
+
+  /** Toy data set (5 points, 5 binary features) for testing feature importances. */
+  def featureImportanceData(sc: SparkContext): RDD[LabeledPoint] = {
+    val rows = Seq(
+      0.0 -> Array(1.0, 0.0, 0.0, 0.0, 1.0),
+      1.0 -> Array(1.0, 1.0, 0.0, 1.0, 0.0),
+      1.0 -> Array(1.0, 1.0, 0.0, 0.0, 0.0),
+      0.0 -> Array(1.0, 0.0, 0.0, 0.0, 0.0),
+      1.0 -> Array(1.0, 1.0, 0.0, 0.0, 0.0))
+    sc.parallelize(rows.map { case (label, xs) => new LabeledPoint(label, Vectors.dense(xs)) })
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 13165f6..56b335a 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -96,6 +96,26 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with 
MLlibTestSparkContex
     }
   }
 
+  test("Feature importance with toy data") {
+    val dt = new DecisionTreeRegressor()
+      .setImpurity("variance")
+      .setMaxDepth(3)
+      .setSeed(123)
+
+    // In this data, feature 1 is very important.
+    val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+    val categoricalFeatures = Map.empty[Int, Int]
+    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+
+    val model = dt.fit(df)
+    val importances = model.featureImportances
+    val mostImportantFeature = importances.argmax
+    assert(mostImportantFeature === 1)
+    // The importances are normalized doubles: allow for rounding when checking the sum.
+    assert(math.abs(importances.toArray.sum - 1.0) < 1e-9)
+    assert(importances.toArray.forall(_ >= 0.0))
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 7e751e4..efb117f 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.ml.regression
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.impl.TreeTests
 import org.apache.spark.ml.util.MLTestingUtils
-import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => 
OldRandomForest}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -82,23 +81,17 @@ class RandomForestRegressorSuite extends SparkFunSuite with 
MLlibTestSparkContex
       .setSeed(123)
 
     // In this data, feature 1 is very important.
-    val data: RDD[LabeledPoint] = sc.parallelize(Seq(
-      new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
-      new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
-      new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
-      new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
-      new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
-    ))
+    val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
     val categoricalFeatures = Map.empty[Int, Int]
     val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
 
     val model = rf.fit(df)
 
-    // copied model must have the same parent.
-    MLTestingUtils.checkCopy(model)
     val importances = model.featureImportances
     val mostImportantFeature = importances.argmax
     assert(mostImportantFeature === 1)
+    assert(importances.toArray.sum === 1.0)
+    assert(importances.toArray.forall(_ >= 0.0))
   }
 
   /////////////////////////////////////////////////////////////////////////////


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to