Github user WeichenXu123 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19666#discussion_r149561550
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala ---
    @@ -741,17 +678,43 @@ private[spark] object RandomForest extends Logging {
               (splits(featureIndex)(bestFeatureSplitIndex), 
bestFeatureGainStats)
             } else if (binAggregates.metadata.isUnordered(featureIndex)) {
               // Unordered categorical feature
    -          val leftChildOffset = 
binAggregates.getFeatureOffset(featureIndexIdx)
    -          val (bestFeatureSplitIndex, bestFeatureGainStats) =
    -            Range(0, numSplits).map { splitIndex =>
    -              val leftChildStats = 
binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
    -              val rightChildStats = 
binAggregates.getParentImpurityCalculator()
    -                .subtract(leftChildStats)
    +          val numBins = binAggregates.metadata.numBins(featureIndex)
    +          val featureOffset = 
binAggregates.getFeatureOffset(featureIndexIdx)
    +
    +          val binStatsArray = Array.tabulate(numBins) { binIndex =>
    +            binAggregates.getImpurityCalculator(featureOffset, binIndex)
    +          }
    +          val parentStats = binAggregates.getParentImpurityCalculator()
    +
    +          var bestGain = Double.NegativeInfinity
    +          var bestSet: BitSet = null
    +          var bestLeftChildStats: ImpurityCalculator = null
    +          var bestRightChildStats: ImpurityCalculator = null
    +
    +          traverseUnorderedSplits[ImpurityCalculator](numBins, null,
    +            (stats, binIndex) => {
    +              val binStats = binStatsArray(binIndex)
    +              if (stats == null) {
    +                binStats
    +              } else {
    +                stats.copy.add(binStats)
    +              }
    +            },
    +            (set, leftChildStats) => {
    +              val rightChildStats = 
parentStats.copy.subtract(leftChildStats)
                   gainAndImpurityStats = 
calculateImpurityStats(gainAndImpurityStats,
                     leftChildStats, rightChildStats, binAggregates.metadata)
    -              (splitIndex, gainAndImpurityStats)
    -            }.maxBy(_._2.gain)
    -          (splits(featureIndex)(bestFeatureSplitIndex), 
bestFeatureGainStats)
    +              if (gainAndImpurityStats.gain > bestGain) {
    +                bestGain = gainAndImpurityStats.gain
    +                bestSet = set | new BitSet(numBins) // copy set
    --- End diff --
    
    The `BitSet` class does not support `copy`.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to