GitHub user smurching commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19433#discussion_r151011913
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala ---
    @@ -627,221 +621,37 @@ private[spark] object RandomForest extends Logging {
       }
     
       /**
    -   * Calculate the impurity statistics for a given (feature, split) based upon left/right
    -   * aggregates.
    -   *
    -   * @param stats the recycle impurity statistics for this feature's all splits,
    -   *              only 'impurity' and 'impurityCalculator' are valid between each iteration
    -   * @param leftImpurityCalculator left node aggregates for this (feature, split)
    -   * @param rightImpurityCalculator right node aggregate for this (feature, split)
    -   * @param metadata learning and dataset metadata for DecisionTree
    -   * @return Impurity statistics for this (feature, split)
    +   * Return a list of pairs (featureIndexIdx, featureIndex) where featureIndex is the global
    +   * (across all trees) index of a feature and featureIndexIdx is the index of a feature within the
    +   * list of features for a given node. Filters out constant features (features with 0 splits)
        */
    -  private def calculateImpurityStats(
    -      stats: ImpurityStats,
    -      leftImpurityCalculator: ImpurityCalculator,
    -      rightImpurityCalculator: ImpurityCalculator,
    -      metadata: DecisionTreeMetadata): ImpurityStats = {
    -
    -    val parentImpurityCalculator: ImpurityCalculator = if (stats == null) {
    -      leftImpurityCalculator.copy.add(rightImpurityCalculator)
    -    } else {
    -      stats.impurityCalculator
    -    }
    -
    -    val impurity: Double = if (stats == null) {
    -      parentImpurityCalculator.calculate()
    -    } else {
    -      stats.impurity
    -    }
    -
    -    val leftCount = leftImpurityCalculator.count
    -    val rightCount = rightImpurityCalculator.count
    -
    -    val totalCount = leftCount + rightCount
    -
    -    // If left child or right child doesn't satisfy minimum instances per node,
    -    // then this split is invalid, return invalid information gain stats.
    -    if ((leftCount < metadata.minInstancesPerNode) ||
    -      (rightCount < metadata.minInstancesPerNode)) {
    -      return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
    -    }
    -
    -    val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
    -    val rightImpurity = rightImpurityCalculator.calculate()
    -
    -    val leftWeight = leftCount / totalCount.toDouble
    -    val rightWeight = rightCount / totalCount.toDouble
    -
    -    val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
    -
    -    // if information gain doesn't satisfy minimum information gain,
    -    // then this split is invalid, return invalid information gain stats.
    -    if (gain < metadata.minInfoGain) {
    -      return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
    +  private[impl] def getNonConstantFeatures(
    +      metadata: DecisionTreeMetadata,
    +      featuresForNode: Option[Array[Int]]): Seq[(Int, Int)] = {
    +    Range(0, metadata.numFeaturesPerNode).map { featureIndexIdx =>
    --- End diff --
    
    At some point while refactoring, I was hitting errors caused by a stateful operation within a `map` over the output of this method (IIRC the result of the `map` was accessed repeatedly, causing the stateful operation to inadvertently be run multiple times).
    
    However, using `withFilter` and `view` now seems to work, so I'll change it back :)
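    
    For readers unfamiliar with this failure mode, here is a minimal sketch of how a stateful operation inside a `map` can run more than once. It assumes the mapped collection is a lazy `view`, which the comment suggests but does not confirm; `LazyViewSketch` and `countCalls` are hypothetical names and are not part of the PR.

```scala
object LazyViewSketch {
  // Traverses a mapped view twice and returns how many times the mapping
  // function ran. The counter stands in for the "stateful operation".
  def countCalls(forceFirst: Boolean): Int = {
    var calls = 0
    // `view` makes `map` lazy: nothing runs until the result is traversed.
    val mapped = Range(0, 3).view.map { i => calls += 1; i * 2 }
    // Forcing with `toList` materializes the elements exactly once.
    val result: Iterable[Int] = if (forceFirst) mapped.toList else mapped
    result.foreach(_ => ()) // first access
    result.foreach(_ => ()) // second access re-runs the function if still lazy
    calls
  }

  def main(args: Array[String]): Unit = {
    println(s"lazy view traversed twice: ${countCalls(forceFirst = false)} calls")
    println(s"forced once, then traversed twice: ${countCalls(forceFirst = true)} calls")
  }
}
```

    In this sketch the unforced view invokes the function on every traversal, while forcing it first fixes the number of invocations at one per element.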


---
