Github user smurching commented on a diff in the pull request: https://github.com/apache/spark/pull/19433#discussion_r151011913 --- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala --- @@ -627,221 +621,37 @@ private[spark] object RandomForest extends Logging { } /** - * Calculate the impurity statistics for a given (feature, split) based upon left/right - * aggregates. - * - * @param stats the recycle impurity statistics for this feature's all splits, - * only 'impurity' and 'impurityCalculator' are valid between each iteration - * @param leftImpurityCalculator left node aggregates for this (feature, split) - * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @param metadata learning and dataset metadata for DecisionTree - * @return Impurity statistics for this (feature, split) + * Return a list of pairs (featureIndexIdx, featureIndex) where featureIndex is the global + * (across all trees) index of a feature and featureIndexIdx is the index of a feature within the + * list of features for a given node. Filters out constant features (features with 0 splits) */ - private def calculateImpurityStats( - stats: ImpurityStats, - leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator, - metadata: DecisionTreeMetadata): ImpurityStats = { - - val parentImpurityCalculator: ImpurityCalculator = if (stats == null) { - leftImpurityCalculator.copy.add(rightImpurityCalculator) - } else { - stats.impurityCalculator - } - - val impurity: Double = if (stats == null) { - parentImpurityCalculator.calculate() - } else { - stats.impurity - } - - val leftCount = leftImpurityCalculator.count - val rightCount = rightImpurityCalculator.count - - val totalCount = leftCount + rightCount - - // If left child or right child doesn't satisfy minimum instances per node, - // then this split is invalid, return invalid information gain stats. 
- if ((leftCount < metadata.minInstancesPerNode) || - (rightCount < metadata.minInstancesPerNode)) { - return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) - } - - val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 - val rightImpurity = rightImpurityCalculator.calculate() - - val leftWeight = leftCount / totalCount.toDouble - val rightWeight = rightCount / totalCount.toDouble - - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - - // if information gain doesn't satisfy minimum information gain, - // then this split is invalid, return invalid information gain stats. - if (gain < metadata.minInfoGain) { - return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) + private[impl] def getNonConstantFeatures( + metadata: DecisionTreeMetadata, + featuresForNode: Option[Array[Int]]): Seq[(Int, Int)] = { + Range(0, metadata.numFeaturesPerNode).map { featureIndexIdx => --- End diff -- At some point when refactoring I was hitting errors caused by a stateful operation within a `map` over the output of this method (IIRC the result of the `map` was accessed repeatedly, causing the stateful operation to inadvertently be run multiple times). However, using `withFilter` and `view` now seems to work, so I'll change it back :)
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org