Github user smurching commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19433#discussion_r151019591
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala ---
    @@ -852,6 +662,41 @@ private[spark] object RandomForest extends Logging {
       }
     
       /**
    +   * Find the best split for a node.
    +   *
    +   * @param binAggregates Bin statistics.
    +   * @return tuple for best split: (Split, information gain, prediction at node)
    +   */
    +  private[tree] def binsToBestSplit(
    +      binAggregates: DTStatsAggregator,
    +      splits: Array[Array[Split]],
    +      featuresForNode: Option[Array[Int]],
    +      node: LearningNode): (Split, ImpurityStats) = {
    +    val validFeatureSplits = getNonConstantFeatures(binAggregates.metadata, featuresForNode)
    +    // For each (feature, split), calculate the gain, and select the best (feature, split).
    +    val parentImpurityCalc = if (node.stats == null) None else Some(node.stats.impurityCalculator)
    --- End diff --
    
    I believe so; the top-level nodes are created ([RandomForest.scala:178](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala#L178)) with [`LearningNode.emptyNode`](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala#L341), which sets `node.stats = null`.
    
    I could change this to check node depth (via node index) instead, but if we're planning to deprecate node indices in the future it might be best not to.
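    
    For context, here is a minimal self-contained sketch of the two checks being compared. `FakeNode`, `FakeStats`, `parentImpurity`, and `isRootByIndex` are made-up stand-ins (not Spark's actual classes), and the 1-based binary-heap index convention for the depth check is assumed from the old node-index scheme:
    
    ```scala
    object RootCheckSketch {
    
      // Stand-in for ImpurityStats (hypothetical, for illustration only).
      final case class FakeStats(impurity: Double)
    
      // Stand-in for LearningNode: root-level nodes are created with stats = null,
      // mirroring how LearningNode.emptyNode sets node.stats = null.
      final case class FakeNode(index: Int, stats: FakeStats)
    
      // Check 1 (what the diff does): no parent impurity <=> stats is null.
      def parentImpurity(node: FakeNode): Option[FakeStats] =
        Option(node.stats) // None for empty (root-level) nodes
    
      // Check 2 (index-based alternative): assuming the root has index 1 and
      // depth = floor(log2(index)), depth 0 <=> root.
      def isRootByIndex(node: FakeNode): Boolean =
        (31 - Integer.numberOfLeadingZeros(node.index)) == 0
    
      def main(args: Array[String]): Unit = {
        val root  = FakeNode(index = 1, stats = null)
        val child = FakeNode(index = 2, stats = FakeStats(0.5))
        println((parentImpurity(root), parentImpurity(child))) // (None, Some(FakeStats(0.5)))
        println((isRootByIndex(root), isRootByIndex(child)))   // (true, false)
      }
    }
    ```
    
    The `stats == null` check also has the advantage of not depending on node indices at all, which seems preferable if indices may eventually be deprecated.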

