GitHub user manishamde commented on a diff in the pull request:

    https://github.com/apache/spark/pull/886#discussion_r14865144
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala ---
    @@ -768,104 +973,157 @@ object DecisionTree extends Serializable with 
Logging {
         /**
          * Extracts left and right split aggregates.
          * @param binData Array[Double] of size 2*numFeatures*numSplits
    -     * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double],
    -     *         Array[Double]) where each array is of 
size(numFeature,2*(numSplits-1))
    +     * @return (leftNodeAgg, rightNodeAgg) tuple of type 
(Array[Array[Array[Double\]\]\],
    +     *         Array[Array[Array[Double\]\]\]) where each array is of 
size(numFeature,
    +     *         (numBins - 1), numClasses)
          */
         def extractLeftRightNodeAggregates(
    -        binData: Array[Double]): (Array[Array[Double]], 
Array[Array[Double]]) = {
    +        binData: Array[Double]): (Array[Array[Array[Double]]], 
Array[Array[Array[Double]]]) = {
    +
    +
    +      def findAggForOrderedFeatureClassification(
    +          leftNodeAgg: Array[Array[Array[Double]]],
    +          rightNodeAgg: Array[Array[Array[Double]]],
    +          featureIndex: Int) {
    +
    +        // shift for this featureIndex
    +        val shift = numClasses * featureIndex * numBins
    +
    +        var classIndex = 0
    +        while (classIndex < numClasses) {
    +          // left node aggregate for the lowest split
    +          leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + 
classIndex)
    +          // right node aggregate for the highest split
    +          rightNodeAgg(featureIndex)(numBins - 2)(classIndex)
    +            = binData(shift + (numClasses * (numBins - 1)) + classIndex)
    +          classIndex += 1
    +        }
    +
    +        // Iterate over all splits.
    +        var splitIndex = 1
    +        while (splitIndex < numBins - 1) {
    +          // calculating left node aggregate for a split as a sum of left 
node aggregate of a
    +          // lower split and the left bin aggregate of a bin where the 
split is a high split
    +          var innerClassIndex = 0
    +          while (innerClassIndex < numClasses) {
    +            leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex)
    +              = binData(shift + numClasses * splitIndex + innerClassIndex) 
+
    +                leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex)
    +            rightNodeAgg(featureIndex)(numBins - 2 - 
splitIndex)(innerClassIndex) =
    +              binData(shift + (numClasses * (numBins - 1 - splitIndex) + 
innerClassIndex)) +
    +                rightNodeAgg(featureIndex)(numBins - 1 - 
splitIndex)(innerClassIndex)
    +            innerClassIndex += 1
    +          }
    +          splitIndex += 1
    +        }
    +      }
    +
    +      def findAggForUnorderedFeatureClassification(
    +          leftNodeAgg: Array[Array[Array[Double]]],
    +          rightNodeAgg: Array[Array[Array[Double]]],
    +          featureIndex: Int) {
    +
    +        val rightChildShift = numClasses * numBins * numFeatures
    +        var splitIndex = 0
    +        while (splitIndex < numBins - 1) {
    +          var classIndex = 0
    +          while (classIndex < numClasses) {
    +            // shift for this featureIndex
    +            val shift = numClasses * featureIndex * numBins + splitIndex * 
numClasses
    +            val leftBinValue = binData(shift + classIndex)
    +            val rightBinValue = binData(rightChildShift + shift + 
classIndex)
    +            leftNodeAgg(featureIndex)(splitIndex)(classIndex) = 
leftBinValue
    +            rightNodeAgg(featureIndex)(splitIndex)(classIndex) = 
rightBinValue
    +            classIndex += 1
    +          }
    +          splitIndex += 1
    +        }
    +      }
    +
    +      def findAggForRegression(
    +          leftNodeAgg: Array[Array[Array[Double]]],
    +          rightNodeAgg: Array[Array[Array[Double]]],
    +          featureIndex: Int) {
    +
    +        // shift for this featureIndex
    +        val shift = 3 * featureIndex * numBins
    +        // left node aggregate for the lowest split
    +        leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0)
    +        leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1)
    +        leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2)
    +
    +        // right node aggregate for the highest split
    +        rightNodeAgg(featureIndex)(numBins - 2)(0) =
    +          binData(shift + (3 * (numBins - 1)))
    +        rightNodeAgg(featureIndex)(numBins - 2)(1) =
    +          binData(shift + (3 * (numBins - 1)) + 1)
    +        rightNodeAgg(featureIndex)(numBins - 2)(2) =
    +          binData(shift + (3 * (numBins - 1)) + 2)
    +
    +        // Iterate over all splits.
    +        var splitIndex = 1
    +        while (splitIndex < numBins - 1) {
    +          // calculating left node aggregate for a split as a sum of left 
node aggregate of a
    +          // lower split and the left bin aggregate of a bin where the 
split is a high split
    +          leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 3 * 
splitIndex) +
    --- End diff --
    
    Will fix this.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

Reply via email to