Repository: spark Updated Branches: refs/heads/master 8f0df6bc1 -> 734ed7a7b
[SPARK-21806][MLLIB] BinaryClassificationMetrics pr(): first point (0.0, 1.0) is misleading ## What changes were proposed in this pull request? Prepend (0.0, p) to the precision-recall curve instead of (0.0, 1.0), where p is the precision at the lowest-recall point on the curve ## How was this patch tested? Updated tests. Author: Sean Owen <so...@cloudera.com> Closes #19038 from srowen/SPARK-21806. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/734ed7a7 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/734ed7a7 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/734ed7a7 Branch: refs/heads/master Commit: 734ed7a7b397578f16549070f350215bde369b3c Parents: 8f0df6b Author: Sean Owen <so...@cloudera.com> Authored: Wed Aug 30 11:36:00 2017 +0100 Committer: Sean Owen <so...@cloudera.com> Committed: Wed Aug 30 11:36:00 2017 +0100 ---------------------------------------------------------------------- .../BinaryClassificationMetrics.scala | 8 +++---- .../BinaryClassificationMetricsSuite.scala | 22 +++++++++----------- 2 files changed, 14 insertions(+), 16 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/734ed7a7/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index 9b7cd04..2cfcf38 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -98,16 +98,16 @@ class BinaryClassificationMetrics @Since("1.3.0") ( /** * Returns the precision-recall curve, which is an RDD of (recall, precision), - * NOT (precision, recall), with 
(0.0, 1.0) prepended to it. + * NOT (precision, recall), with (0.0, p) prepended to it, where p is the precision + * associated with the lowest recall on the curve. * @see <a href="http://en.wikipedia.org/wiki/Precision_and_recall"> * Precision and recall (Wikipedia)</a> */ @Since("1.0.0") def pr(): RDD[(Double, Double)] = { val prCurve = createCurve(Recall, Precision) - val sc = confusions.context - val first = sc.makeRDD(Seq((0.0, 1.0)), 1) - first.union(prCurve) + val (_, firstPrecision) = prCurve.first() + confusions.context.parallelize(Seq((0.0, firstPrecision)), 1).union(prCurve) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/734ed7a7/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index 99d52fa..a08917a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -23,18 +23,16 @@ import org.apache.spark.mllib.util.TestingUtils._ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { - private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5 - - private def pairsWithinEpsilon(x: ((Double, Double), (Double, Double))): Boolean = - (x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5) - - private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]): Unit = { - assert(left.zip(right).forall(areWithinEpsilon)) + private def assertSequencesMatch(actual: Seq[Double], expected: Seq[Double]): Unit = { + actual.zip(expected).foreach { case (a, e) => assert(a ~== e absTol 1.0e-5) } } - private def 
assertTupleSequencesMatch(left: Seq[(Double, Double)], - right: Seq[(Double, Double)]): Unit = { - assert(left.zip(right).forall(pairsWithinEpsilon)) + private def assertTupleSequencesMatch(actual: Seq[(Double, Double)], + expected: Seq[(Double, Double)]): Unit = { + actual.zip(expected).foreach { case ((ax, ay), (ex, ey)) => + assert(ax ~== ex absTol 1.0e-5) + assert(ay ~== ey absTol 1.0e-5) + } } private def validateMetrics(metrics: BinaryClassificationMetrics, @@ -44,7 +42,7 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark expectedFMeasures1: Seq[Double], expectedFmeasures2: Seq[Double], expectedPrecisions: Seq[Double], - expectedRecalls: Seq[Double]) = { + expectedRecalls: Seq[Double]): Unit = { assertSequencesMatch(metrics.thresholds().collect(), expectedThresholds) assertTupleSequencesMatch(metrics.roc().collect(), expectedROCCurve) @@ -111,7 +109,7 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark val fpr = Seq(1.0) val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0)) val pr = recalls.zip(precisions) - val prCurve = Seq((0.0, 1.0)) ++ pr + val prCurve = Seq((0.0, 0.0)) ++ pr val f1 = pr.map { case (0, 0) => 0.0 case (r, p) => 2.0 * (p * r) / (p + r) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org