This is an automated email from the ASF dual-hosted git repository. srowen pushed a commit to branch branch-2.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push: new b3d30a8 [SPARK-27577][MLLIB] Correct thresholds downsampled in BinaryClassificationMetrics b3d30a8 is described below commit b3d30a8660de4b8d1a0d704ce18f37703b2f5cc2 Author: Shaochen Shi <shishaoc...@bytedance.com> AuthorDate: Tue May 7 08:41:58 2019 -0500 [SPARK-27577][MLLIB] Correct thresholds downsampled in BinaryClassificationMetrics ## What changes were proposed in this pull request? Choose the last record in chunks when calculating metrics with downsampling in `BinaryClassificationMetrics`. ## How was this patch tested? A new unit test is added to verify thresholds from downsampled records. Closes #24470 from shishaochen/spark-mllib-binary-metrics. Authored-by: Shaochen Shi <shishaoc...@bytedance.com> Signed-off-by: Sean Owen <sean.o...@databricks.com> (cherry picked from commit d5308cd86fff1e4bf9c24e0dd73d8d2c92737c4f) Signed-off-by: Sean Owen <sean.o...@databricks.com> --- .../spark/mllib/evaluation/BinaryClassificationMetrics.scala | 11 +++++++---- .../mllib/evaluation/BinaryClassificationMetricsSuite.scala | 11 +++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index 2cfcf38..764806b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -175,12 +175,15 @@ class BinaryClassificationMetrics @Since("1.3.0") ( grouping = Int.MaxValue } counts.mapPartitions(_.grouped(grouping.toInt).map { pairs => - // The score of the combined point will be just the first one's score - val firstScore = pairs.head._1 - // The point will contain all counts in this chunk + // The score of the combined point will be just the last one's score, which is also + // the minimal in each chunk since all scores are already sorted in descending. + val lastScore = pairs.last._1 + // The combined point will contain all counts in this chunk. Thus, calculated + // metrics (like precision, recall, etc.) on its score (or so-called threshold) are + // the same as those without sampling. val agg = new BinaryLabelCounter() pairs.foreach(pair => agg += pair._2) - (firstScore, agg) + (lastScore, agg) }) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index a08917a..4cc9ee5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -155,6 +155,17 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark (1.0, 1.0), (1.0, 1.0) ) == downsampledROC) + + val downsampledRecall = downsampled.recallByThreshold().collect().sorted.toList + assert( + // May have to add 1 if the sample factor didn't divide evenly + numBins + (if (scoreAndLabels.size % numBins == 0) 0 else 1) == + downsampledRecall.size) + assert( + List( + (0.1, 1.0), (0.2, 1.0), (0.4, 0.75), (0.6, 0.75), (0.8, 0.25) + ) == + downsampledRecall) } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org