imatiach-msft commented on a change in pull request #21632: [SPARK-19591][ML][MLlib] Add sample weights to decision trees

URL: https://github.com/apache/spark/pull/21632#discussion_r247175614
File path: `mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala`

```diff
@@ -1002,19 +1019,20 @@ private[spark] object RandomForest extends Logging with Serializable {
     val numSplits = metadata.numSplits(featureIndex)

     // get count for each distinct value except zero value
-    val partNumSamples = featureSamples.size
-    val partValueCountMap = scala.collection.mutable.Map[Double, Int]()
-    featureSamples.foreach { x =>
-      partValueCountMap(x) = partValueCountMap.getOrElse(x, 0) + 1
-    }
+    val (partValueCountMap, partNumSamples) =
+      featureSamples.foldLeft((Map.empty[Double, Double], 0.0)) {
+        case ((m, cnt), (w, x)) =>
+          (m + ((x, m.getOrElse(x, 0.0) + w)), cnt + w)
+      }

     // Calculate the expected number of samples for finding splits
-    val numSamples = (samplesFractionForFindSplits(metadata) * metadata.numExamples).toInt
+    val weightedNumSamples = samplesFractionForFindSplits(metadata) *
+      metadata.weightedNumExamples

     // add expected zero value count and get complete statistics
-    val valueCountMap: Map[Double, Int] = if (numSamples - partNumSamples > 0) {
-      partValueCountMap.toMap + (0.0 -> (numSamples - partNumSamples))
+    val valueCountMap: Map[Double, Double] = if (weightedNumSamples - partNumSamples > 1e-5) {
```

Review comment:

OK, the tests all pass when I do this, but if I add a print here:

```scala
val valueCountMap = if (weightedNumSamples - partNumSamples > Utils.EPSILON) {
  println("adding zero weight: " + (weightedNumSamples - partNumSamples))
  partValueCountMap + (0.0 -> (weightedNumSamples - partNumSamples))
}
```

and run the sample weights test:

```
> testOnly org.apache.spark.ml.regression.DecisionTreeRegressorSuite -- -z "sample weights"
```

I get this output:

```
adding zero weight: 4.440892098500626E-14
adding zero weight: 4.440892098500626E-14
adding zero weight: 1.432454155292362E-11
adding zero weight: 1.432454155292362E-11
adding zero weight: 4.440892098500626E-14
adding zero weight: 4.440892098500626E-14
adding zero weight: 1.432454155292362E-11
adding zero weight: 1.432454155292362E-11
```

We really should be ignoring those: as you can see, most of the added weights are around 1e-11 to 1e-14. The problem is that `partNumSamples` is the sum of many doubles, and adding so many values together loses precision due to the limits of floating-point representation. For example, if you print out `weightedNumSamples` and `partNumSamples`, you will see something similar to this for the 1000-weight case:

```
weightedNumSamples: 1000.1999999999784
partNumSamples:     1000.1999999999641
```

`Utils.EPSILON` would usually be fine as a tolerance, but since we are adding so many values together it turns out to be too strict. I think I've come up with a good formula, though:

```scala
val tolerance = Utils.EPSILON * unweightedNumSamples * unweightedNumSamples
```

This usually works well enough, and it customizes the tolerance based on the number of examples being summed.
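The precision issue described above is easy to reproduce outside Spark. Below is a small standalone sketch (a hypothetical example, not Spark code; `EPSILON` here is a local stand-in for Spark's `Utils.EPSILON`, computed as the machine epsilon for `Double`): summing a non-representable weight 1000 times drifts away from the exact product by far more than machine epsilon, so a bare `EPSILON` comparison misfires, while a tolerance scaled by the square of the sample count absorbs the drift.

```scala
object ToleranceSketch {
  // Local stand-in for Spark's Utils.EPSILON: machine epsilon for Double,
  // i.e. the smallest eps such that 1.0 + eps != 1.0 (about 2.22e-16).
  val EPSILON: Double = {
    var eps = 1.0
    while (1.0 + eps / 2.0 != 1.0) eps /= 2.0
    eps
  }

  val n = 1000
  val weight = 0.1                           // not exactly representable in binary
  val summed = Iterator.fill(n)(weight).sum  // n - 1 additions; rounding error accumulates
  val direct = weight * n                    // a single multiply; nearly exact
  val drift = math.abs(summed - direct)

  def main(args: Array[String]): Unit = {
    // The drift dwarfs machine epsilon, so a bare EPSILON check misfires...
    println(s"drift = $drift, EPSILON = $EPSILON")
    // ...but it stays comfortably inside the scaled tolerance.
    val tolerance = EPSILON * n * n
    println(s"tolerance = $tolerance, within tolerance: ${drift <= tolerance}")
  }
}
```

The drift here lands around 1e-12, many orders of magnitude above `EPSILON` but well inside `EPSILON * n * n`, which matches the 1e-11 to 1e-14 residues seen in the test output above.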
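Separately, the weighted value-count fold from the diff can be exercised in isolation. This is a standalone sketch under the assumption that `featureSamples` is a local `Seq[(Double, Double)]` of (weight, value) pairs rather than Spark's actual sampled iterator:

```scala
object WeightedCountSketch {
  // (weight, value) pairs, mirroring the (w, x) tuples in the diff.
  val featureSamples: Seq[(Double, Double)] =
    Seq((0.5, 1.0), (1.5, 1.0), (2.0, 3.0), (1.0, 3.0))

  // Same shape as the fold in the diff: one pass accumulates both the
  // total weight per distinct value and the overall weight sum.
  val (partValueCountMap, partNumSamples) =
    featureSamples.foldLeft((Map.empty[Double, Double], 0.0)) {
      case ((m, cnt), (w, x)) =>
        (m + ((x, m.getOrElse(x, 0.0) + w)), cnt + w)
    }

  def main(args: Array[String]): Unit = {
    println(partValueCountMap) // Map(1.0 -> 2.0, 3.0 -> 3.0)
    println(partNumSamples)    // 5.0
  }
}
```

With these exactly representable weights the sums are exact; it is only with long streams of inexact weights, as in the real code path, that `partNumSamples` starts to drift from `weightedNumSamples`.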