Repository: spark
Updated Branches:
  refs/heads/master d88e69561 -> 7058a5393


[SPARK-2796] [mllib] DecisionTree bug fix: ordered categorical features

Bug: In DecisionTree, the method 
sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins 
from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is 
the bound for unordered categorical features, not ordered ones. The (exclusive) 
upper bound should be the arity of the feature (i.e., its number of categories, 
which is one more than its max value).

Added new test to DecisionTreeSuite to catch this: "regression stump with 
categorical variables of arity 2"

Bug fix: Modified upper bound discussed above.

Also: Small improvements to coding style in DecisionTree.

CC mengxr manishamde

Author: Joseph K. Bradley <joseph.kurata.brad...@gmail.com>

Closes #1720 from jkbradley/decisiontree-bugfix2 and squashes the following 
commits:

225822f [Joseph K. Bradley] Bug: In DecisionTree, the method 
sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins 
from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is 
the bound for unordered categorical features, not ordered ones. The (exclusive) 
upper bound should be the arity of the feature (i.e., its number of categories, 
which is one more than its max value).


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7058a539
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7058a539
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7058a539

Branch: refs/heads/master
Commit: 7058a5393bccc2f917189fa9b4cf7f314410b0de
Parents: d88e695
Author: Joseph K. Bradley <joseph.kurata.brad...@gmail.com>
Authored: Fri Aug 1 15:52:21 2014 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Fri Aug 1 15:52:21 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/tree/DecisionTree.scala  | 45 ++++++++++++--------
 .../spark/mllib/tree/DecisionTreeSuite.scala    | 29 +++++++++++++
 2 files changed, 56 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7058a539/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 7d123dd..382e76a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -498,7 +498,7 @@ object DecisionTree extends Serializable with Logging {
           val bin = binForFeatures(mid)
           val lowThreshold = bin.lowSplit.threshold
           val highThreshold = bin.highSplit.threshold
-          if ((lowThreshold < feature) && (highThreshold >= feature)){
+          if ((lowThreshold < feature) && (highThreshold >= feature)) {
             return mid
           }
           else if (lowThreshold >= feature) {
@@ -522,28 +522,36 @@ object DecisionTree extends Serializable with Logging {
       }
 
       /**
-       * Sequential search helper method to find bin for categorical feature.
+       * Sequential search helper method to find bin for categorical feature
+       * (for classification and regression).
        */
-      def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): 
Int = {
+      def sequentialBinSearchForOrderedCategoricalFeature(): Int = {
         val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
-        val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
+        val featureValue = labeledPoint.features(featureIndex)
         var binIndex = 0
-        while (binIndex < numCategoricalBins) {
+        while (binIndex < featureCategories) {
           val bin = bins(featureIndex)(binIndex)
           val categories = bin.highSplit.categories
-          val features = labeledPoint.features
-          if (categories.contains(features(featureIndex))) {
+          if (categories.contains(featureValue)) {
             return binIndex
           }
           binIndex += 1
         }
+        if (featureValue < 0 || featureValue >= featureCategories) {
+          throw new IllegalArgumentException(
+            s"DecisionTree given invalid data:" +
+            s" Feature $featureIndex is categorical with values in" +
+            s" {0,...,${featureCategories - 1}," +
+            s" but a data point gives it value $featureValue.\n" +
+            "  Bad data point: " + labeledPoint.toString)
+        }
         -1
       }
 
       if (isFeatureContinuous) {
         // Perform binary search for finding bin for continuous features.
         val binIndex = binarySearchForBins()
-        if (binIndex == -1){
+        if (binIndex == -1) {
           throw new UnknownError("no bin was found for continuous variable.")
         }
         binIndex
@@ -555,10 +563,10 @@ object DecisionTree extends Serializable with Logging {
           if (isUnorderedFeature) {
             sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
           } else {
-            sequentialBinSearchForOrderedCategoricalFeatureInClassification()
+            sequentialBinSearchForOrderedCategoricalFeature()
           }
         }
-        if (binIndex == -1){
+        if (binIndex == -1) {
           throw new UnknownError("no bin was found for categorical variable.")
         }
         binIndex
@@ -642,11 +650,12 @@ object DecisionTree extends Serializable with Logging {
       val arrShift = 1 + numFeatures * nodeIndex
       val arrIndex = arrShift + featureIndex
       // Update the left or right count for one bin.
-      val aggShift = numClasses * numBins * numFeatures * nodeIndex
-      val aggIndex
-        = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt 
* numClasses
-      val labelInt = label.toInt
-      agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1
+      val aggIndex =
+        numClasses * numBins * numFeatures * nodeIndex +
+        numClasses * numBins * featureIndex +
+        numClasses * arr(arrIndex).toInt +
+        label.toInt
+      agg(aggIndex) += 1
     }
 
     /**
@@ -1127,7 +1136,7 @@ object DecisionTree extends Serializable with Logging {
           val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 
numClasses)
           var featureIndex = 0
           while (featureIndex < numFeatures) {
-            if (isMulticlassClassificationWithCategoricalFeatures){
+            if (isMulticlassClassificationWithCategoricalFeatures) {
               val isFeatureContinuous = 
strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
               if (isFeatureContinuous) {
                 findAggForOrderedFeatureClassification(leftNodeAgg, 
rightNodeAgg, featureIndex)
@@ -1393,7 +1402,7 @@ object DecisionTree extends Serializable with Logging {
 
         // Iterate over all features.
         var featureIndex = 0
-        while (featureIndex < numFeatures){
+        while (featureIndex < numFeatures) {
           // Check whether the feature is continuous.
           val isFeatureContinuous = 
strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
           if (isFeatureContinuous) {
@@ -1513,7 +1522,7 @@ object DecisionTree extends Serializable with Logging {
           if (isFeatureContinuous) { // Bins for categorical variables are 
already assigned.
             bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, 
Continuous),
               splits(featureIndex)(0), Continuous, Double.MinValue)
-            for (index <- 1 until numBins - 1){
+            for (index <- 1 until numBins - 1) {
               val bin = new Bin(splits(featureIndex)(index-1), 
splits(featureIndex)(index),
                 Continuous, Double.MinValue)
               bins(featureIndex)(index) = bin

http://git-wip-us.apache.org/repos/asf/spark/blob/7058a539/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 10462db..546a132 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -42,6 +42,18 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
     assert(accuracy >= requiredAccuracy)
   }
 
+  def validateRegressor(
+      model: DecisionTreeModel,
+      input: Seq[LabeledPoint],
+      requiredMSE: Double) {
+    val predictions = input.map(x => model.predict(x.features))
+    val squaredError = predictions.zip(input).map { case (prediction, 
expected) =>
+      (prediction - expected.label) * (prediction - expected.label)
+    }.sum
+    val mse = squaredError / input.length
+    assert(mse <= requiredMSE)
+  }
+
   test("split and bin calculation") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
@@ -454,6 +466,23 @@ class DecisionTreeSuite extends FunSuite with 
LocalSparkContext {
     assert(stats.impurity > 0.2)
   }
 
+  test("regression stump with categorical variables of arity 2") {
+    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+    assert(arr.length === 1000)
+    val rdd = sc.parallelize(arr)
+    val strategy = new Strategy(
+      Regression,
+      Variance,
+      maxDepth = 2,
+      maxBins = 100,
+      categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
+
+    val model = DecisionTree.train(rdd, strategy)
+    validateRegressor(model, arr, 0.0)
+    assert(model.numNodes === 3)
+    assert(model.depth === 1)
+  }
+
   test("stump with fixed label 0 for Gini") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
     assert(arr.length === 1000)

Reply via email to