Github user smurching commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19433#discussion_r151019591

    --- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala ---
    @@ -852,6 +662,41 @@ private[spark] object RandomForest extends Logging {
       }
     
       /**
    +   * Find the best split for a node.
    +   *
    +   * @param binAggregates Bin statistics.
    +   * @return tuple for best split: (Split, information gain, prediction at node)
    +   */
    +  private[tree] def binsToBestSplit(
    +      binAggregates: DTStatsAggregator,
    +      splits: Array[Array[Split]],
    +      featuresForNode: Option[Array[Int]],
    +      node: LearningNode): (Split, ImpurityStats) = {
    +    val validFeatureSplits = getNonConstantFeatures(binAggregates.metadata, featuresForNode)
    +    // For each (feature, split), calculate the gain, and select the best (feature, split).
    +    val parentImpurityCalc = if (node.stats == null) None else Some(node.stats.impurityCalculator)
    --- End diff --

    I believe so: the top-level nodes are created ([RandomForest.scala:178](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala#L178)) with [`LearningNode.emptyNode`](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala#L341), which sets `node.stats = null`. I could change this to check node depth (via node index) instead, but if we're planning on deprecating node indices in the future, it might be best not to.
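For illustration, the null check discussed above can be sketched in isolation. This is a minimal sketch, not the PR's code: `ImpurityCalculator`, `ImpurityStats`, and `LearningNode` below are simplified, hypothetical stand-ins for Spark's internal classes, and `emptyNode` takes no arguments here purely for brevity. It shows that `Option(node.stats)` gives the same result as the explicit `if (node.stats == null) None else Some(...)` in the diff, since `Option(null)` evaluates to `None`.

```scala
// Hypothetical, simplified stand-ins for Spark's internal classes.
// The names mirror the diff, but these are NOT the real definitions.
final case class ImpurityCalculator(stats: Array[Double])
final case class ImpurityStats(impurityCalculator: ImpurityCalculator)
final class LearningNode(var stats: ImpurityStats)

object LearningNode {
  // Assumption taken from the comment above: a freshly created node has stats == null.
  // (Simplified to take no arguments in this sketch.)
  def emptyNode: LearningNode = new LearningNode(null)
}

object ParentImpurityExample {
  // Option(x) is None when x is null, so this is equivalent to the
  // explicit null check on node.stats in the diff.
  def parentImpurityCalc(node: LearningNode): Option[ImpurityCalculator] =
    Option(node.stats).map(_.impurityCalculator)

  def main(args: Array[String]): Unit = {
    val root = LearningNode.emptyNode
    val child = new LearningNode(ImpurityStats(ImpurityCalculator(Array(3.0, 1.0))))
    println(parentImpurityCalc(root).isDefined)
    println(parentImpurityCalc(child).isDefined)
  }
}
```

Either form avoids relying on node indices to detect a top-level node, which is the trade-off raised in the comment.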