Repository: spark
Updated Branches:
  refs/heads/master 16fab6b0e -> 7f96f2d7f


[SPARK-16957][MLLIB] Use midpoints for split values.

## What changes were proposed in this pull request?

Use midpoints for split values now, and maybe later to make it weighted.

## How was this patch tested?

+ [x] add unit test.
+ [x] revise Split's unit test.

Author: Yan Facai (颜发才) <facai....@gmail.com>
Author: 颜发才(Yan Facai) <facai....@gmail.com>

Closes #17556 from facaiy/ENH/decision_tree_overflow_and_precision_in_aggregation.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7f96f2d7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7f96f2d7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7f96f2d7

Branch: refs/heads/master
Commit: 7f96f2d7f2d5abf81dd7f8ca27fea35cf798fd65
Parents: 16fab6b
Author: Yan Facai (颜发才) <facai....@gmail.com>
Authored: Wed May 3 10:54:40 2017 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Wed May 3 10:54:40 2017 +0100

----------------------------------------------------------------------
 .../spark/ml/tree/impl/RandomForest.scala       | 15 ++++---
 .../spark/ml/tree/impl/RandomForestSuite.scala  | 41 +++++++++++++++++---
 python/pyspark/mllib/tree.py                    | 12 +++---
 3 files changed, 51 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7f96f2d7/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 008dd19..82e1ed8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -996,7 +996,7 @@ private[spark] object RandomForest extends Logging {
     require(metadata.isContinuous(featureIndex),
       "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
 
-    val splits = if (featureSamples.isEmpty) {
+    val splits: Array[Double] = if (featureSamples.isEmpty) {
       Array.empty[Double]
     } else {
       val numSplits = metadata.numSplits(featureIndex)
@@ -1009,10 +1009,15 @@ private[spark] object RandomForest extends Logging {
       // sort distinct values
       val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
 
-      // if possible splits is not enough or just enough, just return all possible splits
       val possibleSplits = valueCounts.length - 1
-      if (possibleSplits <= numSplits) {
-        valueCounts.map(_._1).init
+      if (possibleSplits == 0) {
+        // constant feature
+        Array.empty[Double]
+      } else if (possibleSplits <= numSplits) {
+        // if possible splits is not enough or just enough, just return all possible splits
+        (1 to possibleSplits)
+          .map(index => (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0)
+          .toArray
       } else {
         // stride between splits
         val stride: Double = numSamples.toDouble / (numSplits + 1)
@@ -1037,7 +1042,7 @@ private[spark] object RandomForest extends Logging {
           // makes the gap between currentCount and targetCount smaller,
           // previous value is a split threshold.
           if (previousGap < currentGap) {
-            splitsBuilder += valueCounts(index - 1)._1
+            splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0
             targetCount += stride
           }
           index += 1

http://git-wip-us.apache.org/repos/asf/spark/blob/7f96f2d7/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index e1ab7c2..df155b4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -104,6 +104,31 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       assert(splits.distinct.length === splits.length)
     }
 
+    // SPARK-16957: Use midpoints for split values.
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+        Map(), Set(),
+        Array(3), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0, 0
+      )
+
+      // possibleSplits <= numSplits
+      {
+        val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble)
+        val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+        val expectedSplits = Array((0.0 + 1.0) / 2)
+        assert(splits === expectedSplits)
+      }
+
+      // possibleSplits > numSplits
+      {
+        val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble)
+        val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+        val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2)
+        assert(splits === expectedSplits)
+      }
+    }
+
     // find splits should not return identical splits
     // when there are not enough split candidates, reduce the number of splits in metadata
     {
@@ -112,9 +137,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
         Array(5), Gini, QuantileStrategy.Sort,
         0, 0, 0.0, 0, 0
       )
-      val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
+      val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble)
       val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-      assert(splits === Array(1.0, 2.0))
+      val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2)
+      assert(splits === expectedSplits)
       // check returned splits are distinct
       assert(splits.distinct.length === splits.length)
     }
@@ -126,9 +152,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
         Array(3), Gini, QuantileStrategy.Sort,
         0, 0, 0.0, 0, 0
       )
-      val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
+      val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5)
+        .map(_.toDouble)
       val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-      assert(splits === Array(2.0, 3.0))
+      val expectedSplits = Array((2.0 + 3.0) / 2, (3.0 + 4.0) / 2)
+      assert(splits === expectedSplits)
     }
 
     // find splits when most samples close to the maximum
@@ -138,9 +166,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
         Array(2), Gini, QuantileStrategy.Sort,
         0, 0, 0.0, 0, 0
       )
-      val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
+      val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
       val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-      assert(splits === Array(1.0))
+      val expectedSplits = Array((1.0 + 2.0) / 2)
+      assert(splits === expectedSplits)
     }
 
     // find splits for constant feature

http://git-wip-us.apache.org/repos/asf/spark/blob/7f96f2d7/python/pyspark/mllib/tree.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index a6089fc..619fa16 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -199,9 +199,9 @@ class DecisionTree(object):
 
         >>> print(model.toDebugString())
         DecisionTreeModel classifier of depth 1 with 3 nodes
-          If (feature 0 <= 0.0)
+          If (feature 0 <= 0.5)
            Predict: 0.0
-          Else (feature 0 > 0.0)
+          Else (feature 0 > 0.5)
            Predict: 1.0
         <BLANKLINE>
         >>> model.predict(array([1.0]))
@@ -383,14 +383,14 @@ class RandomForest(object):
           Tree 0:
             Predict: 1.0
           Tree 1:
-            If (feature 0 <= 1.0)
+            If (feature 0 <= 1.5)
              Predict: 0.0
-            Else (feature 0 > 1.0)
+            Else (feature 0 > 1.5)
              Predict: 1.0
           Tree 2:
-            If (feature 0 <= 1.0)
+            If (feature 0 <= 1.5)
              Predict: 0.0
-            Else (feature 0 > 1.0)
+            Else (feature 0 > 1.5)
              Predict: 1.0
         <BLANKLINE>
         >>> model.predict([2.0])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to