This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5fa4ba0  [SPARK-26981][MLLIB] Add 'Recall_at_k' metric to RankingMetrics
5fa4ba0 is described below

commit 5fa4ba0cfb126bfadee7451fe9a46cee3d60b67c
Author: masa3141 <masah...@kazama.tv>
AuthorDate: Wed Mar 6 08:28:53 2019 -0600

    [SPARK-26981][MLLIB] Add 'Recall_at_k' metric to RankingMetrics
    
    ## What changes were proposed in this pull request?
    
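    Add a `recallAt(k)` (recall at k) metric to RankingMetrics, plus example
    usage in the Scala/Java examples and a PySpark `RankingMetrics.recallAt` wrapper.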
    
    ## How was this patch tested?
    
    Added recall-at-k assertions to both tests in RankingMetricsSuite and doctests to
    pyspark.mllib.evaluation.RankingMetrics.
    
    Closes #23881 from masa3141/SPARK-26981.
    
    Authored-by: masa3141 <masah...@kazama.tv>
    Signed-off-by: Sean Owen <sean.o...@databricks.com>
---
 .../examples/mllib/JavaRankingMetricsExample.java  |  3 +-
 .../examples/mllib/RankingMetricsExample.scala     |  5 ++
 .../spark/mllib/evaluation/RankingMetrics.scala    | 75 +++++++++++++++++-----
 .../mllib/evaluation/RankingMetricsSuite.scala     | 14 +++-
 python/pyspark/mllib/evaluation.py                 | 20 ++++++
 5 files changed, 97 insertions(+), 20 deletions(-)

diff --git 
a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java
 
b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java
index dc9970d..414d376 100644
--- 
a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java
+++ 
b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java
@@ -99,11 +99,12 @@ public class JavaRankingMetricsExample {
     // Instantiate the metrics object
     RankingMetrics<Integer> metrics = RankingMetrics.of(relevantDocs);
 
-    // Precision and NDCG at k
+    // Precision, NDCG and Recall at k
     Integer[] kVector = {1, 3, 5};
     for (Integer k : kVector) {
       System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k));
       System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k));
+      System.out.format("Recall at %d = %f\n", k, metrics.recallAt(k));
     }
 
     // Mean average precision
diff --git 
a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala
 
b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala
index d514891..34fbe08 100644
--- 
a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala
+++ 
b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala
@@ -89,6 +89,11 @@ object RankingMetricsExample {
       println(s"NDCG at $k = ${metrics.ndcgAt(k)}")
     }
 
+    // Recall at K
+    Array(1, 3, 5).foreach { k =>
+      println(s"Recall at $k = ${metrics.recallAt(k)}")
+    }
+
     // Get predictions for each data point
     val allPredictions = model.predict(ratings.map(r => (r.user, 
r.product))).map(r => ((r.user,
       r.product), r.rating))
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
index 4935d11..ff9663a 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -59,23 +59,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: 
RDD[(Array[T], Array[T])]
   def precisionAt(k: Int): Double = {
     require(k > 0, "ranking position k should be positive")
     predictionAndLabels.map { case (pred, lab) =>
-      val labSet = lab.toSet
-
-      if (labSet.nonEmpty) {
-        val n = math.min(pred.length, k)
-        var i = 0
-        var cnt = 0
-        while (i < n) {
-          if (labSet.contains(pred(i))) {
-            cnt += 1
-          }
-          i += 1
-        }
-        cnt.toDouble / k
-      } else {
-        logWarning("Empty ground truth set, check input data")
-        0.0
-      }
+      countRelevantItemRatio(pred, lab, k, k)
     }.mean()
   }
 
@@ -157,6 +141,63 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: 
RDD[(Array[T], Array[T])]
     }.mean()
   }
 
+  /**
+   * Compute the average recall of all the queries, truncated at ranking position k.
+   *
+   * For each query, only the first k predicted items (or all of them, if fewer) are
+   * inspected, and the recall is computed as #(relevant items retrieved) / #(ground truth set).
+   * Unlike precisionAt, the denominator does not depend on k, so for a ranking without
+   * duplicate items the value reaches 1.0 exactly when all ground truth items are in the top k.
+   *
+   * A query with an empty ground truth set contributes zero recall and logs a warning.
+   *
+   * See the following paper for detail:
+   *
+   * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
+   *
+   * @param k the position to compute the truncated recall, must be positive
+   * @return the average recall at the first k ranking positions
+   */
+  @Since("3.0.0")
+  def recallAt(k: Int): Double = {
+    require(k > 0, "ranking position k should be positive")
+    predictionAndLabels.map { case (pred, lab) =>
+      countRelevantItemRatio(pred, lab, k, lab.toSet.size)
+    }.mean()
+  }
+
+  /**
+   * Returns the relevant item ratio computed as #(relevant items retrieved) / denominator,
+   * where only the first k predicted items (or all of them, if fewer) are inspected.
+   * If a query has an empty ground truth set, the value will be zero and a log
+   * warning is generated.
+   *
+   * precisionAt passes k as the denominator; recallAt passes the ground truth set size.
+   *
+   * Every inspected position whose item is in the ground truth counts as a hit, so a
+   * predicted ranking that repeats a relevant item counts it more than once.
+   *
+   * @param pred predicted ranking
+   * @param lab ground truth
+   * @param k use the top k predicted ranking, must be positive
+   * @param denominator the denominator of the ratio
+   * @return relevant item ratio at the first k ranking positions
+   */
+  private def countRelevantItemRatio(
+      pred: Array[T],
+      lab: Array[T],
+      k: Int,
+      denominator: Int): Double = {
+    val labSet = lab.toSet
+    if (labSet.nonEmpty) {
+      // take(k) also copes with rankings shorter than k.
+      val cnt = pred.iterator.take(k).count(labSet.contains)
+      cnt.toDouble / denominator
+    } else {
+      logWarning("Empty ground truth set, check input data")
+      0.0
+    }
+  }
 }
 
 object RankingMetrics {
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
index f334be2..1969098 100644
--- 
a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.mllib.util.TestingUtils._
 
 class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
 
-  test("Ranking metrics: MAP, NDCG") {
+  test("Ranking metrics: MAP, NDCG, Recall") {
     val predictionAndLabels = sc.parallelize(
       Seq(
         (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)),
@@ -49,9 +49,17 @@ class RankingMetricsSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps)
     assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps)
     assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)
+
+    assert(metrics.recallAt(1) ~== 1.0/15 absTol eps)
+    assert(metrics.recallAt(2) ~== 8.0/45 absTol eps)
+    assert(metrics.recallAt(3) ~== 11.0/45 absTol eps)
+    assert(metrics.recallAt(4) ~== 11.0/45 absTol eps)
+    assert(metrics.recallAt(5) ~== 16.0/45 absTol eps)
+    assert(metrics.recallAt(10) ~== 2.0/3 absTol eps)
+    assert(metrics.recallAt(15) ~== 2.0/3 absTol eps)
   }
 
-  test("MAP, NDCG with few predictions (SPARK-14886)") {
+  test("MAP, NDCG, Recall with few predictions (SPARK-14886)") {
     val predictionAndLabels = sc.parallelize(
       Seq(
         (Array(1, 6, 2), Array(1, 2, 3, 4, 5)),
@@ -64,6 +72,8 @@ class RankingMetricsSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     assert(metrics.precisionAt(2) ~== 0.25 absTol eps)
     assert(metrics.ndcgAt(1) ~== 0.5 absTol eps)
     assert(metrics.ndcgAt(2) ~== 0.30657 absTol eps)
+    assert(metrics.recallAt(1) ~== 0.1 absTol eps)
+    assert(metrics.recallAt(2) ~== 0.1 absTol eps)
   }
 
 }
diff --git a/python/pyspark/mllib/evaluation.py 
b/python/pyspark/mllib/evaluation.py
index 5d8d20d..171c62c 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -373,6 +373,12 @@ class RankingMetrics(JavaModelWrapper):
     0.33...
     >>> metrics.ndcgAt(10)
     0.48...
+    >>> metrics.recallAt(1)
+    0.06...
+    >>> metrics.recallAt(5)
+    0.35...
+    >>> metrics.recallAt(15)
+    0.66...
 
     .. versionadded:: 1.4.0
     """
@@ -422,6 +428,20 @@ class RankingMetrics(JavaModelWrapper):
         """
         return self.call("ndcgAt", int(k))
 
+    @since('3.0.0')
+    def recallAt(self, k):
+        """
+        Compute the average recall of all the queries, truncated at ranking position k.
+
+        For each query, only the first k predicted items (or all of them, if fewer) are
+        inspected, and the recall is #(relevant items retrieved) / #(ground truth set).
+        The denominator does not depend on k, so recall never decreases as k grows.
+
+        If a query has an empty ground truth set, zero will be used as recall together
+        with a log warning.
+        """
+        return self.call("recallAt", int(k))
+
 
 class MultilabelMetrics(JavaModelWrapper):
     """


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to