Repository: spark
Updated Branches:
  refs/heads/master 59c184e02 -> d9783380f


[SPARK-18036][ML][MLLIB] Fixing decision trees handling edge cases

## What changes were proposed in this pull request?

Decision trees, gradient-boosted trees (GBT), and random forests (RF) do not
handle edge cases such as constant features or empty feature vectors.
In the case of constant features, we now choose an arbitrary split instead of
failing with a cryptic error message.
In the case of empty features, we now fail with a clearer error message:

    DecisionTree requires number of features > 0, but was given an empty features vector

instead of the previous, cryptic error:

    java.lang.UnsupportedOperationException: empty.max

## How was this patch tested?

Unit tests are added in this patch (in `RandomForestSuite`, against the shared
`RandomForest.run` implementation) for:
- DecisionTreeRegressor
- GBTRegressor
- RandomForestRegressor

Author: Ilya Matiach <il...@microsoft.com>

Closes #16377 from imatiach-msft/ilmat/fix-decision-tree.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d9783380
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d9783380
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d9783380

Branch: refs/heads/master
Commit: d9783380ff0a6440117348dee3205826d0f9687e
Parents: 59c184e
Author: Ilya Matiach <il...@microsoft.com>
Authored: Tue Jan 24 10:25:12 2017 -0800
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Tue Jan 24 10:25:12 2017 -0800

----------------------------------------------------------------------
 .../ml/tree/impl/DecisionTreeMetadata.scala     |  2 ++
 .../spark/ml/tree/impl/RandomForest.scala       | 22 +++++++++++--
 .../spark/ml/tree/impl/RandomForestSuite.scala  | 33 +++++++++++++++++---
 3 files changed, 51 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d9783380/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index bc3c86a..8a9dcb4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -113,6 +113,8 @@ private[spark] object DecisionTreeMetadata extends Logging {
       throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " +
         s"but was given by empty one.")
     }
+    require(numFeatures > 0, s"DecisionTree requires number of features > 0, " +
+      s"but was given an empty features vector")
     val numExamples = input.count()
     val numClasses = strategy.algo match {
       case Classification => strategy.numClasses

http://git-wip-us.apache.org/repos/asf/spark/blob/d9783380/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index a61ea37..008dd19 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -714,7 +714,7 @@ private[spark] object RandomForest extends Logging {
       }
 
     // For each (feature, split), calculate the gain, and select the best (feature, split).
-    val (bestSplit, bestSplitStats) =
+    val splitsAndImpurityInfo =
       validFeatureSplits.map { case (featureIndexIdx, featureIndex) =>
         val numSplits = binAggregates.metadata.numSplits(featureIndex)
         if (binAggregates.metadata.isContinuous(featureIndex)) {
@@ -828,8 +828,26 @@ private[spark] object RandomForest extends Logging {
             new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories)
           (bestFeatureSplit, bestFeatureGainStats)
         }
-      }.maxBy(_._2.gain)
+      }
 
+    val (bestSplit, bestSplitStats) =
+      if (splitsAndImpurityInfo.isEmpty) {
+        // If no valid splits for features, then this split is invalid,
+        // return invalid information gain stats.  Take any split and continue.
+        // Splits is empty, so arbitrarily choose to split on any threshold
+        val dummyFeatureIndex = featuresForNode.map(_.head).getOrElse(0)
+        val parentImpurityCalculator = binAggregates.getParentImpurityCalculator()
+        if (binAggregates.metadata.isContinuous(dummyFeatureIndex)) {
+          (new ContinuousSplit(dummyFeatureIndex, 0),
+            ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
+        } else {
+          val numCategories = binAggregates.metadata.featureArity(dummyFeatureIndex)
+          (new CategoricalSplit(dummyFeatureIndex, Array(), numCategories),
+            ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
+        }
+      } else {
+        splitsAndImpurityInfo.maxBy(_._2.gain)
+      }
     (bestSplit, bestSplitStats)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d9783380/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 3bded9c..e1ab7c2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -26,9 +26,8 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper}
-import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy,
-  Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator, Variance}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.util.collection.OpenHashMap
 
@@ -161,6 +160,21 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("train with empty arrays") {
+    val lp = LabeledPoint(1.0, Vectors.dense(Array.empty[Double]))
+    val data = Array.fill(5)(lp)
+    val rdd = sc.parallelize(data)
+
+    val strategy = new OldStrategy(OldAlgo.Regression, Gini, maxDepth = 2,
+      maxBins = 5)
+    withClue("DecisionTree requires number of features > 0," +
+      " but was given an empty features vector") {
+      intercept[IllegalArgumentException] {
+        RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
+      }
+    }
+  }
+
   test("train with constant features") {
     val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0))
     val data = Array.fill(5)(lp)
@@ -170,12 +184,23 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
           Gini,
           maxDepth = 2,
           numClasses = 2,
-          maxBins = 100,
+          maxBins = 5,
           categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5))
     val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
     assert(tree.rootNode.impurity === -1.0)
     assert(tree.depth === 0)
     assert(tree.rootNode.prediction === lp.label)
+
+    // Test with no categorical features
+    val strategy2 = new OldStrategy(
+      OldAlgo.Regression,
+      Variance,
+      maxDepth = 2,
+      maxBins = 5)
+    val Array(tree2) = RandomForest.run(rdd, strategy2, 1, "all", 42L, instr = None)
+    assert(tree2.rootNode.impurity === -1.0)
+    assert(tree2.depth === 0)
+    assert(tree2.rootNode.prediction === lp.label)
   }
 
   test("Multiclass classification with unordered categorical features: split calculations") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to