GitHub user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/2435#discussion_r17943420
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala ---
    @@ -649,71 +542,65 @@ object DecisionTree extends Serializable with Logging {
         // Calculate bin aggregates.
         timer.start("aggregation")
         val binAggregates: DTStatsAggregator = {
    -      val initAgg = new DTStatsAggregator(metadata, numNodes)
    +      val initAgg = if (metadata.subsamplingFeatures) {
    +        assert(featuresForNodes.nonEmpty)
    +        new DTStatsAggregatorSubsampledFeatures(metadata, groupNodeIndex, featuresForNodes.get)
    +      } else {
    +        new DTStatsAggregatorFixedFeatures(metadata, numNodes)
    +      }
           input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp)
         }
         timer.stop("aggregation")
     
    -    // Calculate best splits for all nodes at a given level
    +    // Calculate best splits for all nodes in the group
         timer.start("chooseSplits")
    -    // On the first iteration, we need to get and return the newly created root node.
    -    var newTopNode: Node = topNode
    -
    -    // Iterate over all nodes at this level
    -    var nodeIndex = 0
    -    var internalNodeCount = 0
    -    while (nodeIndex < numNodes) {
    -      val (split: Split, stats: InformationGainStats, predict: Predict) =
    -        binsToBestSplit(binAggregates, nodeIndex, level, metadata, splits)
    -      logDebug("best split = " + split)
     
    -      val globalNodeIndex = globalNodeIndexOffset + nodeIndex
    -
    -      // Extract info for this node at the current level.
    -      val isLeaf = (stats.gain <= 0) || (level == metadata.maxDepth)
    -      val node =
    -        new Node(globalNodeIndex, predict.predict, isLeaf, Some(split), None, None, Some(stats))
    -      logDebug("Node = " + node)
    -
    -      if (!isLeaf) {
    -        internalNodeCount += 1
    -      }
    -      if (level == 0) {
    -        newTopNode = node
    -      } else {
    -        // Set parent.
    -        val parentNode = Node.getNode(Node.parentIndex(globalNodeIndex), topNode)
    -        if (Node.isLeftChild(globalNodeIndex)) {
    -          parentNode.leftNode = Some(node)
    +    // Iterate over all nodes in this group.
    +    groupNodeIndex.foreach{ case (treeIndex, nodeIndexToAggIndex) =>
    +      nodeIndexToAggIndex.foreach{ case (nodeIndex, aggNodeIndex) =>
    --- End diff --
    
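    Style nit: add a space before `{`, i.e. write `foreach { case ... =>` instead of `foreach{ case ... =>` (applies to both `foreach` lines above).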


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and you would like it to, or if the feature is enabled but not working,
please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to