Github user sethah commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20472#discussion_r169386551

    --- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala ---
    @@ -1001,11 +996,18 @@ private[spark] object RandomForest extends Logging {
         } else {
           val numSplits = metadata.numSplits(featureIndex)
    -      // get count for each distinct value
    -      val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
    +      // get count for each distinct value except zero value
    +      val (partValueCountMap, partNumSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
             case ((m, cnt), x) => (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
           }
    +
    +      // Calculate the number of samples for finding splits
    +      val numSamples: Int = (samplesFractionForFindSplits(metadata) * metadata.numExamples).toInt
    +
    +      // add zero value count and get complete statistics
    +      val valueCountMap: Map[Double, Int] = partValueCountMap + (0.0 -> (numSamples - partNumSamples))
    --- End diff --

    There can be negative values, right?
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org