This is an automated email from the ASF dual-hosted git repository. srowen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new f465a3d943e [SPARK-39446][MLLIB][FOLLOWUP] Modify constructor of RankingMetrics class f465a3d943e is described below commit f465a3d943ea692b9ba377fcfcf17012c3bea29f Author: uchiiii <uchiku...@gmail.com> AuthorDate: Sat Jun 25 14:16:45 2022 -0500 [SPARK-39446][MLLIB][FOLLOWUP] Modify constructor of RankingMetrics class ### What changes were proposed in this pull request? - Merged the two constructors into one using `RDD[_ <: Product]`. ### Why are the changes needed? - To make code simpler. - To support even more inputs. - ~~The previous code treats `rel` as an empty array when `rel` is not provided, which is not that beautiful. This change removes this.~~ ### Does this PR introduce _any_ user-facing change? NO ### How was this patch tested? Closes #36920 from uchiiii/modify_ranking_metrics. Authored-by: uchiiii <uchiku...@gmail.com> Signed-off-by: Sean Owen <sro...@gmail.com> --- .../spark/mllib/evaluation/RankingMetrics.scala | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index 7fccff9a24e..87a17f57caf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -38,16 +38,14 @@ import org.apache.spark.rdd.RDD * Since 3.4.0, it supports ndcg evaluation with relevance value. 
*/ @Since("1.2.0") -class RankingMetrics[T: ClassTag] @Since("3.4.0") ( - predictionAndLabels: RDD[(Array[T], Array[T], Array[Double])]) +class RankingMetrics[T: ClassTag] @Since("1.2.0") (predictionAndLabels: RDD[_ <: Product]) extends Logging with Serializable { - @Since("1.2.0") - def this(predictionAndLabelsWithoutRelevance: => RDD[(Array[T], Array[T])]) = { - this(predictionAndLabelsWithoutRelevance.map { - case (pred, lab) => (pred, lab, Array.empty[Double]) - }) + private val rdd = predictionAndLabels.map { + case (pred: Array[T], lab: Array[T]) => (pred, lab, Array.empty[Double]) + case (pred: Array[T], lab: Array[T], rel: Array[Double]) => (pred, lab, rel) + case _ => throw new IllegalArgumentException(s"Expected RDD of tuples or triplets") } /** @@ -70,7 +68,7 @@ class RankingMetrics[T: ClassTag] @Since("3.4.0") ( @Since("1.2.0") def precisionAt(k: Int): Double = { require(k > 0, "ranking position k should be positive") - predictionAndLabels.map { case (pred, lab, _) => + rdd.map { case (pred, lab, _) => countRelevantItemRatio(pred, lab, k, k) }.mean() } @@ -82,7 +80,7 @@ class RankingMetrics[T: ClassTag] @Since("3.4.0") ( */ @Since("1.2.0") lazy val meanAveragePrecision: Double = { - predictionAndLabels.map { case (pred, lab, _) => + rdd.map { case (pred, lab, _) => val labSet = lab.toSet val k = math.max(pred.length, labSet.size) averagePrecision(pred, labSet, k) @@ -99,7 +97,7 @@ class RankingMetrics[T: ClassTag] @Since("3.4.0") ( @Since("3.0.0") def meanAveragePrecisionAt(k: Int): Double = { require(k > 0, "ranking position k should be positive") - predictionAndLabels.map { case (pred, lab, _) => + rdd.map { case (pred, lab, _) => averagePrecision(pred, lab.toSet, k) }.mean() } @@ -154,7 +152,7 @@ class RankingMetrics[T: ClassTag] @Since("3.4.0") ( @Since("1.2.0") def ndcgAt(k: Int): Double = { require(k > 0, "ranking position k should be positive") - predictionAndLabels.map { case (pred, lab, rel) => + rdd.map { case (pred, lab, rel) => val useBinary 
= rel.isEmpty val labSet = lab.toSet val relMap = lab.zip(rel).toMap @@ -224,7 +222,7 @@ class RankingMetrics[T: ClassTag] @Since("3.4.0") ( @Since("3.0.0") def recallAt(k: Int): Double = { require(k > 0, "ranking position k should be positive") - predictionAndLabels.map { case (pred, lab, _) => + rdd.map { case (pred, lab, _) => countRelevantItemRatio(pred, lab, k, lab.toSet.size) }.mean() } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org