This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new feae21c6445 [SPARK-39446][MLLIB][FOLLOWUP] Modify ranking metrics for java and python feae21c6445 is described below commit feae21c6445f8767bf5f62bb54f6c61a8df4e0c1 Author: uchiiii <uchiku...@gmail.com> AuthorDate: Wed Jun 29 17:36:48 2022 +0800 [SPARK-39446][MLLIB][FOLLOWUP] Modify ranking metrics for java and python ### What changes were proposed in this pull request? - Updated `RankingMetrics` for Java and Python - Modified the interface for Java and Python - Added test for Java ### Why are the changes needed? - To expose the change in https://github.com/apache/spark/pull/36843 to Java and Python. - To update the document for Java and Python. ### Does this PR introduce _any_ user-facing change? - Java users can use a JavaRDD of (predicted ranking, ground truth set, relevance value of ground truth set) for `RankingMetrics` ### How was this patch tested? - Added test for Java Closes #37019 from uchiiii/modify_ranking_metrics_for_java_and_python. Authored-by: uchiiii <uchiku...@gmail.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../spark/mllib/evaluation/RankingMetrics.scala | 16 ++++++++++--- .../mllib/evaluation/JavaRankingMetricsSuite.java | 27 ++++++++++++++++++++++ python/pyspark/mllib/evaluation.py | 14 ++++++++--- 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index 87a17f57caf..6ff8262c498 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -267,12 +267,22 @@ object RankingMetrics { /** * Creates a [[RankingMetrics]] instance (for Java users). * @param predictionAndLabels a JavaRDD of (predicted ranking, ground truth set) pairs + * or (predicted ranking, ground truth set, + * relevance value of ground truth set). + * Since 3.4.0, it supports ndcg evaluation with relevance value. */ @Since("1.4.0") - def of[E, T <: jl.Iterable[E]](predictionAndLabels: JavaRDD[(T, T)]): RankingMetrics[E] = { + def of[E, T <: jl.Iterable[E], A <: jl.Iterable[Double]]( + predictionAndLabels: JavaRDD[_ <: Product]): RankingMetrics[E] = { implicit val tag = JavaSparkContext.fakeClassTag[E] - val rdd = predictionAndLabels.rdd.map { case (predictions, labels) => - (predictions.asScala.toArray, labels.asScala.toArray) + val rdd = predictionAndLabels.rdd.map { + case (predictions, labels) => + (predictions.asInstanceOf[T].asScala.toArray, labels.asInstanceOf[T].asScala.toArray) + case (predictions, labels, rels) => + ( + predictions.asInstanceOf[T].asScala.toArray, + labels.asInstanceOf[T].asScala.toArray, + rels.asInstanceOf[A].asScala.toArray) } new RankingMetrics(rdd) } diff --git a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java index 50822c61fdc..4dcb2920610 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java @@ -22,7 +22,9 @@ import java.util.Arrays; import java.util.List; import scala.Tuple2; +import scala.Tuple3; import scala.Tuple2$; +import scala.Tuple3$; import org.junit.Assert; import org.junit.Test; @@ -32,6 +34,8 @@ import org.apache.spark.api.java.JavaRDD; public class JavaRankingMetricsSuite extends SharedSparkSession { private transient JavaRDD<Tuple2<List<Integer>, List<Integer>>> predictionAndLabels; + private transient JavaRDD<Tuple3<List<Integer>, List<Integer>, List<Double>>> + predictionLabelsAndRelevance; @Override public void setUp() throws IOException { @@ -43,6 +47,22 @@ public class JavaRankingMetricsSuite extends SharedSparkSession { Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)), Tuple2$.MODULE$.apply( Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2); + predictionLabelsAndRelevance = jsc.parallelize(Arrays.asList( + Tuple3$.MODULE$.apply( + Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), + Arrays.asList(1, 2, 3, 4, 5), + Arrays.asList(3.0, 2.0, 1.0, 1.0, 1.0) + ), + Tuple3$.MODULE$.apply( + Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), + Arrays.asList(1, 2, 3), + Arrays.asList(2.0, 0.0, 0.0) + ), + Tuple3$.MODULE$.apply( + Arrays.asList(1, 2, 3, 4, 5), + Arrays.<Integer>asList(), + Arrays.<Double>asList() + )), 3); } @Test @@ -51,4 +71,11 @@ public class JavaRankingMetricsSuite extends SharedSparkSession { Assert.assertEquals(0.355026, metrics.meanAveragePrecision(), 1e-5); Assert.assertEquals(0.75 / 3.0, metrics.precisionAt(4), 1e-5); } + + @Test + public void rankingMetricsWithRelevance() { + RankingMetrics<?> metrics = RankingMetrics.of(predictionLabelsAndRelevance); + Assert.assertEquals(0.355026, metrics.meanAveragePrecision(), 1e-5); + Assert.assertEquals(0.511959, metrics.ndcgAt(3), 1e-5); + } } diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 1003ba68c5f..cee61a1b241 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Generic, List, Optional, Tuple, TypeVar +from typing import Generic, List, Optional, Tuple, TypeVar, Union import sys @@ -418,7 +418,10 @@ class RankingMetrics(JavaModelWrapper, Generic[T]): Parameters ---------- predictionAndLabels : :py:class:`pyspark.RDD` - an RDD of (predicted ranking, ground truth set) pairs. + an RDD of (predicted ranking, ground truth set) pairs + or (predicted ranking, ground truth set, + relevance value of ground truth set). + Since 3.4.0, it supports ndcg evaluation with relevance value. Examples -------- @@ -451,7 +454,12 @@ class RankingMetrics(JavaModelWrapper, Generic[T]): 0.66... """ - def __init__(self, predictionAndLabels: RDD[Tuple[List[T], List[T]]]): + def __init__( + self, + predictionAndLabels: Union[ + RDD[Tuple[List[T], List[T]]], RDD[Tuple[List[T], List[T], List[float]]] + ], + ): sc = predictionAndLabels.ctx sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org