Repository: spark
Updated Branches:
  refs/heads/master 59c184e02 -> d9783380f
[SPARK-18036][ML][MLLIB] Fixing decision trees handling edge cases

## What changes were proposed in this pull request?

Decision trees/GBT/RF do not handle edge cases such as constant features or
empty features. In the case of constant features, we now choose an arbitrary
split instead of failing with a cryptic error message. In the case of empty
features, we now fail with a clearer error message:

  DecisionTree requires number of features > 0, but was given an empty features vector

instead of the cryptic:

  java.lang.UnsupportedOperationException: empty.max

## How was this patch tested?

Unit tests are added in the patch for:
- DecisionTreeRegressor
- GBTRegressor
- RandomForestRegressor

Author: Ilya Matiach <il...@microsoft.com>

Closes #16377 from imatiach-msft/ilmat/fix-decision-tree.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d9783380
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d9783380
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d9783380

Branch: refs/heads/master
Commit: d9783380ff0a6440117348dee3205826d0f9687e
Parents: 59c184e
Author: Ilya Matiach <il...@microsoft.com>
Authored: Tue Jan 24 10:25:12 2017 -0800
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Tue Jan 24 10:25:12 2017 -0800

----------------------------------------------------------------------
 .../ml/tree/impl/DecisionTreeMetadata.scala    |  2 ++
 .../spark/ml/tree/impl/RandomForest.scala      | 22 +++++++++++--
 .../spark/ml/tree/impl/RandomForestSuite.scala | 33 +++++++++++++++++---
 3 files changed, 51 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d9783380/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index bc3c86a..8a9dcb4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -113,6 +113,8 @@ private[spark] object DecisionTreeMetadata extends Logging {
       throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " +
         s"but was given by empty one.")
     }
+    require(numFeatures > 0, s"DecisionTree requires number of features > 0, " +
+      s"but was given an empty features vector")
     val numExamples = input.count()
     val numClasses = strategy.algo match {
       case Classification => strategy.numClasses
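The new require in DecisionTreeMetadata is what produces the clearer failure
described in the commit message. A minimal sketch (not part of the commit) of
how it surfaces through the public ML API, assuming a spark-shell session
where the SparkSession is in scope as `spark`:

    import org.apache.spark.ml.feature.LabeledPoint
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.ml.regression.DecisionTreeRegressor

    // Five rows whose feature vectors are empty. Previously training died
    // inside split finding with "empty.max"; the new check fails fast with
    // a readable message.
    val empty = spark.createDataFrame(
      Seq.fill(5)(LabeledPoint(1.0, Vectors.dense(Array.empty[Double]))))

    try {
      new DecisionTreeRegressor().fit(empty)
    } catch {
      case e: IllegalArgumentException => println(e.getMessage)
    }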
http://git-wip-us.apache.org/repos/asf/spark/blob/d9783380/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index a61ea37..008dd19 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -714,7 +714,7 @@ private[spark] object RandomForest extends Logging {
     }
 
     // For each (feature, split), calculate the gain, and select the best (feature, split).
-    val (bestSplit, bestSplitStats) =
+    val splitsAndImpurityInfo =
       validFeatureSplits.map { case (featureIndexIdx, featureIndex) =>
         val numSplits = binAggregates.metadata.numSplits(featureIndex)
         if (binAggregates.metadata.isContinuous(featureIndex)) {
@@ -828,8 +828,26 @@ private[spark] object RandomForest extends Logging {
             new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories)
           (bestFeatureSplit, bestFeatureGainStats)
         }
-      }.maxBy(_._2.gain)
+      }
+    val (bestSplit, bestSplitStats) =
+      if (splitsAndImpurityInfo.isEmpty) {
+        // If no valid splits for features, then this split is invalid,
+        // return invalid information gain stats.  Take any split and continue.
+        // Splits is empty, so arbitrarily choose to split on any threshold
+        val dummyFeatureIndex = featuresForNode.map(_.head).getOrElse(0)
+        val parentImpurityCalculator = binAggregates.getParentImpurityCalculator()
+        if (binAggregates.metadata.isContinuous(dummyFeatureIndex)) {
+          (new ContinuousSplit(dummyFeatureIndex, 0),
+            ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
+        } else {
+          val numCategories = binAggregates.metadata.featureArity(dummyFeatureIndex)
+          (new CategoricalSplit(dummyFeatureIndex, Array(), numCategories),
+            ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
+        }
+      } else {
+        splitsAndImpurityInfo.maxBy(_._2.gain)
+      }
     (bestSplit, bestSplitStats)
   }
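With this fallback, a node whose features admit no valid split receives an
arbitrary split paired with invalid impurity stats, and split selection then
leaves the node as a leaf rather than crashing. A sketch (not from the patch;
same spark-shell assumption as above) of the user-visible effect when every
feature is constant:

    import org.apache.spark.ml.feature.LabeledPoint
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.ml.regression.DecisionTreeRegressor

    // All rows are identical, so no split can reduce impurity anywhere.
    val constant = spark.createDataFrame(
      Seq.fill(5)(LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0))))

    val model = new DecisionTreeRegressor().fit(constant)
    // Training succeeds and degenerates to a single leaf predicting the
    // constant label.
    assert(model.depth == 0)
    assert(model.rootNode.prediction == 1.0)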
http://git-wip-us.apache.org/repos/asf/spark/blob/d9783380/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 3bded9c..e1ab7c2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -26,9 +26,8 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper}
-import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy,
-  Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator, Variance}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.util.collection.OpenHashMap
 
@@ -161,6 +160,21 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("train with empty arrays") {
+    val lp = LabeledPoint(1.0, Vectors.dense(Array.empty[Double]))
+    val data = Array.fill(5)(lp)
+    val rdd = sc.parallelize(data)
+
+    val strategy = new OldStrategy(OldAlgo.Regression, Gini, maxDepth = 2,
+      maxBins = 5)
+    withClue("DecisionTree requires number of features > 0," +
+      " but was given an empty features vector") {
+      intercept[IllegalArgumentException] {
+        RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
+      }
+    }
+  }
+
   test("train with constant features") {
     val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0))
     val data = Array.fill(5)(lp)
@@ -170,12 +184,23 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       Gini,
       maxDepth = 2,
       numClasses = 2,
-      maxBins = 100,
+      maxBins = 5,
       categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5))
     val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
     assert(tree.rootNode.impurity === -1.0)
     assert(tree.depth === 0)
     assert(tree.rootNode.prediction === lp.label)
+
+    // Test with no categorical features
+    val strategy2 = new OldStrategy(
+      OldAlgo.Regression,
+      Variance,
+      maxDepth = 2,
+      maxBins = 5)
+    val Array(tree2) = RandomForest.run(rdd, strategy2, 1, "all", 42L, instr = None)
+    assert(tree2.rootNode.impurity === -1.0)
+    assert(tree2.depth === 0)
+    assert(tree2.rootNode.prediction === lp.label)
   }
 
   test("Multiclass classification with unordered categorical features: split calculations") {
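The commit message also lists GBTRegressor among the tested learners. Since
each boosting iteration fits a regression tree, the same constant-features
fallback should let boosting run to completion; a sketch (not from the patch,
same spark-shell assumption as above, and the iteration count is arbitrary):

    import org.apache.spark.ml.feature.LabeledPoint
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.ml.regression.GBTRegressor

    val df = spark.createDataFrame(
      Seq.fill(5)(LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0))))

    // Each boosting round now fits a single-leaf tree instead of throwing,
    // so the ensemble trains end to end.
    val gbt = new GBTRegressor().setMaxIter(3).fit(df)
    assert(gbt.trees.length == 3)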