Repository: spark
Updated Branches:
  refs/heads/master b271c265b -> 9b746f380


[SPARK-3381] [MLlib] Eliminate bins for unordered features in DecisionTrees

For unordered features, it is sufficient to use splits, since the threshold of 
each split corresponds to the threshold of the bin's HighSplit and the 
LowSplit is never used.
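
To make the correspondence concrete: an unordered categorical feature with 
arity M admits 2^(M-1) - 1 candidate splits, each defined by the subset of 
categories sent left. A minimal Scala sketch of that enumeration (hypothetical 
names, not the MLlib API; the bitmask decoding mirrors what 
extractMultiClassCategories computes):

    object UnorderedSplitsSketch {
      // Decode bitmask `mask` into the category values it selects.
      def categoriesOf(mask: Int, arity: Int): List[Double] =
        (0 until arity).filter(c => ((mask >> c) & 1) == 1).map(_.toDouble).toList

      def main(args: Array[String]): Unit = {
        val arity = 3                           // categories 0.0, 1.0, 2.0
        val numSplits = (1 << (arity - 1)) - 1  // 2^(M-1) - 1 = 3
        (0 until numSplits).foreach { s =>
          val cats = categoriesOf(s + 1, arity)
          // The old Bin.highSplit held exactly `cats`, so testing a feature
          // value against the split's categories alone is sufficient.
          println(s"split $s sends left: $cats")
        }
      }
    }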

Author: MechCoder <manojkumarsivaraj...@gmail.com>

Closes #4231 from MechCoder/spark-3381 and squashes the following commits:

58c19a5 [MechCoder] COSMIT
c274b74 [MechCoder] Remove unordered feature calculation in labeledPointToTreePoint
b2b9b89 [MechCoder] COSMIT
d3ee042 [MechCoder] [SPARK-3381] [MLlib] Eliminate bins for unordered features


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9b746f38
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9b746f38
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9b746f38

Branch: refs/heads/master
Commit: 9b746f380869b54d673e3758ca5e4475f76c864a
Parents: b271c26
Author: MechCoder <manojkumarsivaraj...@gmail.com>
Authored: Tue Feb 17 11:19:23 2015 -0800
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Tue Feb 17 11:19:23 2015 -0800

----------------------------------------------------------------------
 .../apache/spark/mllib/tree/DecisionTree.scala  | 37 ++++++--------------
 .../spark/mllib/tree/impl/TreePoint.scala       | 14 +++-----
 .../spark/mllib/tree/DecisionTreeSuite.scala    | 37 +-------------------
 3 files changed, 15 insertions(+), 73 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9b746f38/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index f1f8599..b9d0c56 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -327,14 +327,14 @@ object DecisionTree extends Serializable with Logging {
   * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
    *             each (feature, bin).
    * @param treePoint  Data point being aggregated.
-   * @param bins possible bins for all features, indexed (numFeatures)(numBins)
+   * @param splits possible splits indexed (numFeatures)(numSplits)
    * @param unorderedFeatures  Set of indices of unordered features.
    * @param instanceWeight  Weight (importance) of instance in dataset.
    */
   private def mixedBinSeqOp(
       agg: DTStatsAggregator,
       treePoint: TreePoint,
-      bins: Array[Array[Bin]],
+      splits: Array[Array[Split]],
       unorderedFeatures: Set[Int],
       instanceWeight: Double,
       featuresForNode: Option[Array[Int]]): Unit = {
@@ -362,7 +362,7 @@ object DecisionTree extends Serializable with Logging {
         val numSplits = agg.metadata.numSplits(featureIndex)
         var splitIndex = 0
         while (splitIndex < numSplits) {
-          if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) {
+          if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) {
             agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
               instanceWeight)
           } else {
@@ -506,8 +506,8 @@ object DecisionTree extends Serializable with Logging {
         if (metadata.unorderedFeatures.isEmpty) {
           orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
         } else {
-          mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
-            instanceWeight, featuresForNode)
+          mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
+            metadata.unorderedFeatures, instanceWeight, featuresForNode)
         }
       }
     }
@@ -1024,35 +1024,15 @@ object DecisionTree extends Serializable with Logging {
             // Categorical feature
             val featureArity = metadata.featureArity(featureIndex)
             if (metadata.isUnordered(featureIndex)) {
-              // TODO: The second half of the bins are unused.  Actually, we could just use
-              //       splits and not build bins for unordered features.  That should be part of
-              //       a later PR since it will require changing other code (using splits instead
-              //       of bins in a few places).
               // Unordered features
-              //   2^(maxFeatureValue - 1) - 1 combinations
+              // 2^(maxFeatureValue - 1) - 1 combinations
               splits(featureIndex) = new Array[Split](numSplits)
-              bins(featureIndex) = new Array[Bin](numBins)
               var splitIndex = 0
               while (splitIndex < numSplits) {
                 val categories: List[Double] =
                   extractMultiClassCategories(splitIndex + 1, featureArity)
                 splits(featureIndex)(splitIndex) =
                  new Split(featureIndex, Double.MinValue, Categorical, categories)
-                bins(featureIndex)(splitIndex) = {
-                  if (splitIndex == 0) {
-                    new Bin(
-                      new DummyCategoricalSplit(featureIndex, Categorical),
-                      splits(featureIndex)(0),
-                      Categorical,
-                      Double.MinValue)
-                  } else {
-                    new Bin(
-                      splits(featureIndex)(splitIndex - 1),
-                      splits(featureIndex)(splitIndex),
-                      Categorical,
-                      Double.MinValue)
-                  }
-                }
                 splitIndex += 1
               }
             } else {
@@ -1060,8 +1040,11 @@ object DecisionTree extends Serializable with Logging {
               //   Bins correspond to feature values, so we do not need to compute splits or bins
               //   beforehand.  Splits are constructed as needed during training.
               splits(featureIndex) = new Array[Split](0)
-              bins(featureIndex) = new Array[Bin](0)
             }
+            // For ordered features, bins correspond to feature values.
+            // For unordered categorical features, there is no need to construct the bins,
+            // since there is a one-to-one correspondence between the splits and the bins.
+            bins(featureIndex) = new Array[Bin](0)
           }
           featureIndex += 1
         }
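
The membership test in mixedBinSeqOp above amounts to keeping separate 
left/right sufficient statistics per candidate split. A hedged sketch of that 
pattern (hypothetical stats holder, not the real DTStatsAggregator):

    class LeftRightStatsSketch(numSplits: Int) {
      val left = new Array[Double](numSplits)   // weight mass routed left
      val right = new Array[Double](numSplits)  // weight mass routed right

      def update(splitCategories: Array[List[Double]], featureValue: Double,
          instanceWeight: Double): Unit = {
        var s = 0
        while (s < numSplits) {
          if (splitCategories(s).contains(featureValue)) left(s) += instanceWeight
          else right(s) += instanceWeight
          s += 1
        }
      }
    }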

http://git-wip-us.apache.org/repos/asf/spark/blob/9b746f38/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
index 35e361a..50b292e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
@@ -55,17 +55,15 @@ private[tree] object TreePoint {
       input: RDD[LabeledPoint],
       bins: Array[Array[Bin]],
       metadata: DecisionTreeMetadata): RDD[TreePoint] = {
-    // Construct arrays for featureArity and isUnordered for efficiency in the inner loop.
+    // Construct an array for featureArity for efficiency in the inner loop.
     val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
-    val isUnordered: Array[Boolean] = new Array[Boolean](metadata.numFeatures)
     var featureIndex = 0
     while (featureIndex < metadata.numFeatures) {
       featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
-      isUnordered(featureIndex) = metadata.isUnordered(featureIndex)
       featureIndex += 1
     }
     input.map { x =>
-      TreePoint.labeledPointToTreePoint(x, bins, featureArity, isUnordered)
+      TreePoint.labeledPointToTreePoint(x, bins, featureArity)
     }
   }
 
@@ -74,19 +72,17 @@ private[tree] object TreePoint {
    * @param bins      Bins for features, of size (numFeatures, numBins).
    * @param featureArity  Array indexed by feature, with value 0 for continuous and numCategories
    *                      for categorical features.
-   * @param isUnordered  Array index by feature, with value true for unordered categorical features.
    */
   private def labeledPointToTreePoint(
       labeledPoint: LabeledPoint,
       bins: Array[Array[Bin]],
-      featureArity: Array[Int],
-      isUnordered: Array[Boolean]): TreePoint = {
+      featureArity: Array[Int]): TreePoint = {
     val numFeatures = labeledPoint.features.size
     val arr = new Array[Int](numFeatures)
     var featureIndex = 0
     while (featureIndex < numFeatures) {
       arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex),
-        isUnordered(featureIndex), bins)
+        bins)
       featureIndex += 1
     }
     new TreePoint(labeledPoint.label, arr)
@@ -96,14 +92,12 @@ private[tree] object TreePoint {
    * Find bin for one (labeledPoint, feature).
    *
    * @param featureArity  0 for continuous features; number of categories for categorical features.
-   * @param isUnorderedFeature  (only applies if feature is categorical)
    * @param bins   Bins for features, of size (numFeatures, numBins).
    */
   private def findBin(
       featureIndex: Int,
       labeledPoint: LabeledPoint,
       featureArity: Int,
-      isUnorderedFeature: Boolean,
       bins: Array[Array[Bin]]): Int = {
 
     /**
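
With the isUnordered flag gone, binning a categorical value is uniform: the 
raw category index doubles as the bin index, whether or not the feature is 
unordered. A minimal sketch of that assumption (hypothetical helper, 
simplified from findBin):

    def categoricalBinSketch(featureValue: Double, featureArity: Int): Int = {
      require(featureArity > 0, "sketch covers categorical features only")
      val bin = featureValue.toInt
      require(bin >= 0 && bin < featureArity,
        s"categorical value $featureValue out of range [0, $featureArity)")
      bin
    }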

http://git-wip-us.apache.org/repos/asf/spark/blob/9b746f38/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 7b1aed5..4c162df 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -190,7 +190,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
     assert(splits.length === 2)
     assert(bins.length === 2)
     assert(splits(0).length === 3)
-    assert(bins(0).length === 6)
+    assert(bins(0).length === 0)
 
     // Expecting 2^2 - 1 = 3 bins/splits
     assert(splits(0)(0).feature === 0)
@@ -228,41 +228,6 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
     assert(splits(1)(2).categories.contains(0.0))
     assert(splits(1)(2).categories.contains(1.0))
 
-    // Check bins.
-
-    assert(bins(0)(0).category === Double.MinValue)
-    assert(bins(0)(0).lowSplit.categories.length === 0)
-    assert(bins(0)(0).highSplit.categories.length === 1)
-    assert(bins(0)(0).highSplit.categories.contains(0.0))
-    assert(bins(1)(0).category === Double.MinValue)
-    assert(bins(1)(0).lowSplit.categories.length === 0)
-    assert(bins(1)(0).highSplit.categories.length === 1)
-    assert(bins(1)(0).highSplit.categories.contains(0.0))
-
-    assert(bins(0)(1).category === Double.MinValue)
-    assert(bins(0)(1).lowSplit.categories.length === 1)
-    assert(bins(0)(1).lowSplit.categories.contains(0.0))
-    assert(bins(0)(1).highSplit.categories.length === 1)
-    assert(bins(0)(1).highSplit.categories.contains(1.0))
-    assert(bins(1)(1).category === Double.MinValue)
-    assert(bins(1)(1).lowSplit.categories.length === 1)
-    assert(bins(1)(1).lowSplit.categories.contains(0.0))
-    assert(bins(1)(1).highSplit.categories.length === 1)
-    assert(bins(1)(1).highSplit.categories.contains(1.0))
-
-    assert(bins(0)(2).category === Double.MinValue)
-    assert(bins(0)(2).lowSplit.categories.length === 1)
-    assert(bins(0)(2).lowSplit.categories.contains(1.0))
-    assert(bins(0)(2).highSplit.categories.length === 2)
-    assert(bins(0)(2).highSplit.categories.contains(1.0))
-    assert(bins(0)(2).highSplit.categories.contains(0.0))
-    assert(bins(1)(2).category === Double.MinValue)
-    assert(bins(1)(2).lowSplit.categories.length === 1)
-    assert(bins(1)(2).lowSplit.categories.contains(1.0))
-    assert(bins(1)(2).highSplit.categories.length === 2)
-    assert(bins(1)(2).highSplit.categories.contains(1.0))
-    assert(bins(1)(2).highSplit.categories.contains(0.0))
-
   }
 
   test("Multiclass classification with ordered categorical features: split and bin calculations") {

