[ https://issues.apache.org/jira/browse/SPARK-14409?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=15824774#comment-15824774 ]
Roberto Mirizzi commented on SPARK-14409: ----------------------------------------- I implemented the RankingEvaluator to be used with ALS. Here's the code
{code:java}
package org.apache.spark.ml.evaluation

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators, Params}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.RankingMetrics
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType}

/**
 * :: Experimental ::
 * Evaluator for ranking, which expects four input columns: user, item, prediction and label.
 *
 * For each user, the top-`MaxK` items by descending prediction score are compared against the
 * ground-truth set of "good" items (those whose label exceeds `goodThreshold`), and the chosen
 * ranking metric is computed with [[RankingMetrics]].
 */
@Experimental
final class RankingEvaluator(override val uid: String)
  extends Evaluator with HasUserCol with HasItemCol with HasPredictionCol with HasLabelCol
    with DefaultParamsWritable {

  import RankingEvaluator._

  def this() = this(Identifiable.randomUID("rankEval"))

  /**
   * Param for metric name in evaluation. Supports:
   *  - `"map"` (default): mean average precision
   *  - `"p@k"`: precision@k (1 <= k <= 10)
   *  - `"ndcg@k"`: normalized discounted cumulative gain@k (1 <= k <= 10)
   *
   * @group param
   */
  val metricName: Param[String] = {
    // Generate "map", "p@1".."p@10", "ndcg@1".."ndcg@10" instead of hand-listing all 21 values.
    val supported = "map" +: (1 to MaxK).flatMap(k => Seq(s"p@$k", s"ndcg@$k"))
    val allowedParams = ParamValidators.inArray(supported.toArray)
    new Param(this, "metricName",
      s"metric name in evaluation (${supported.mkString("|")})", allowedParams)
  }

  /**
   * Param for the label threshold above which an item counts as relevant ("good") in the
   * ground-truth ranking. Kept as a string-typed Param to preserve the original
   * persistence format.
   *
   * @group param
   */
  val goodThreshold: Param[String] =
    new Param(this, "goodThreshold", "threshold for good labels")

  /** @group getParam */
  def getMetricName: String = $(metricName)

  /** @group setParam */
  def setMetricName(value: String): this.type = set(metricName, value)

  /** @group getParam */
  def getGoodThreshold: Double = $(goodThreshold).toDouble

  /** @group setParam */
  def setGoodThreshold(value: Double): this.type = set(goodThreshold, value.toString)

  /** @group setParam */
  def setUserCol(value: String): this.type = set(userCol, value)

  /** @group setParam */
  def setItemCol(value: String): this.type = set(itemCol, value)

  /** @group setParam */
  def setLabelCol(value: String): this.type = set(labelCol, value)

  /** @group setParam */
  def setPredictionCol(value: String): this.type = set(predictionCol, value)

  setDefault(metricName -> "map", goodThreshold -> "0")

  /**
   * Computes the configured ranking metric over `dataset`.
   *
   * Users with no predictions or no relevant items are excluded (inner join), matching the
   * original implementation. Note: the previous version also called `.show()` on the two
   * intermediate Datasets, which triggered extra Spark jobs and printed to stdout on every
   * evaluation; that debug output has been removed.
   */
  override def evaluate(dataset: Dataset[_]): Double = {
    val spark = dataset.sparkSession
    import spark.implicits._

    val schema = dataset.schema
    SchemaUtils.checkNumericType(schema, $(userCol))
    SchemaUtils.checkNumericType(schema, $(itemCol))
    SchemaUtils.checkColumnTypes(schema, $(labelCol), Seq(DoubleType, FloatType))
    SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType))

    val windowByUserRankByPrediction =
      Window.partitionBy(col($(userCol))).orderBy(col($(predictionCol)).desc)
    val windowByUserRankByRating =
      Window.partitionBy(col($(userCol))).orderBy(col($(labelCol)).desc)

    // Top-MaxK recommended items per user, ordered by descending prediction score.
    val predictionDataset = dataset.select(
        col($(userCol)).cast(IntegerType),
        col($(itemCol)).cast(IntegerType),
        col($(predictionCol)).cast(FloatType),
        row_number().over(windowByUserRankByPrediction).as("rank"))
      .where(col("rank") <= MaxK)
      .groupBy(col($(userCol)))
      .agg(collect_list(col($(itemCol))).as("prediction_list"))
      .withColumnRenamed($(userCol), "predicted_userId")
      .as[(Int, Array[Int])]

    // Ground-truth relevant items per user: those with label strictly above the threshold.
    // Compare against the parsed Double rather than the raw String param.
    val actualDataset = dataset.select(
        col($(userCol)).cast(IntegerType),
        col($(itemCol)).cast(IntegerType),
        row_number().over(windowByUserRankByRating))
      .where(col($(labelCol)).cast(DoubleType) > getGoodThreshold)
      .groupBy(col($(userCol)))
      .agg(collect_list(col($(itemCol))).as("actual_list"))
      .withColumnRenamed($(userCol), "actual_userId")
      .as[(Int, Array[Int])]

    // Inner join: only users present on both sides contribute to the metric.
    val predictionAndLabels = actualDataset
      .join(predictionDataset,
        actualDataset("actual_userId") === predictionDataset("predicted_userId"))
      .select("prediction_list", "actual_list")
      .as[(Array[Int], Array[Int])]
      .rdd

    val metrics = new RankingMetrics[Int](predictionAndLabels)

    // Parse "p@k" / "ndcg@k" instead of enumerating all 21 cases; the metricName
    // validator guarantees only supported values reach this match.
    $(metricName) match {
      case "map" => metrics.meanAveragePrecision
      case MetricAtK("p", k) => metrics.precisionAt(k.toInt)
      case MetricAtK("ndcg", k) => metrics.ndcgAt(k.toInt)
    }
  }

  // Every supported metric (MAP, precision@k, NDCG@k) improves as it increases, so there is
  // no need to enumerate all 21 metric names here.
  override def isLargerBetter: Boolean = true

  override def copy(extra: ParamMap): RankingEvaluator = defaultCopy(extra)
}

object RankingEvaluator extends DefaultParamsReadable[RankingEvaluator] {

  /** Maximum supported ranking cutoff k for p@k / ndcg@k (and the recommendation list size). */
  private[evaluation] val MaxK = 10

  /** Matches metric names of the form "p@k" or "ndcg@k"; Scala regex patterns are anchored. */
  private[evaluation] val MetricAtK = "(p|ndcg)@(\\d+)".r

  override def load(path: String): RankingEvaluator = super.load(path)
}

/**
 * Trait for shared param userCol (default: "user").
 */
// NOTE: the original declared these traits private[evaluator], but the enclosing package is
// org.apache.spark.ml.evaluation — "evaluator" is not an enclosing package, so that qualifier
// does not compile. Fixed to private[evaluation].
private[evaluation] trait HasUserCol extends Params {

  /**
   * Param for user column name.
   *
   * @group param
   */
  final val userCol: Param[String] = new Param[String](this, "userCol", "user column name")

  setDefault(userCol, "user")

  /** @group getParam */
  final def getUserCol: String = $(userCol)
}

/**
 * Trait for shared param itemCol (default: "item").
 */
private[evaluation] trait HasItemCol extends Params {

  /**
   * Param for item column name.
   *
   * @group param
   */
  final val itemCol: Param[String] = new Param[String](this, "itemCol", "item column name")

  setDefault(itemCol, "item")

  /** @group getParam */
  final def getItemCol: String = $(itemCol)
}

/**
 * Trait for shared param labelCol (default: "label").
 */
private[evaluation] trait HasLabelCol extends Params {

  /**
   * Param for label column name.
   *
   * @group param
   */
  final val labelCol: Param[String] = new Param[String](this, "labelCol", "label column name")

  setDefault(labelCol, "label")

  /** @group getParam */
  final def getLabelCol: String = $(labelCol)
}

/**
 * Trait for shared param predictionCol (default: "prediction").
 */
private[evaluation] trait HasPredictionCol extends Params {

  /**
   * Param for prediction column name.
   *
   * @group param
   */
  final val predictionCol: Param[String] =
    new Param[String](this, "predictionCol", "prediction column name")

  setDefault(predictionCol, "prediction")

  /** @group getParam */
  final def getPredictionCol: String = $(predictionCol)
}
{code}
> Investigate adding a RankingEvaluator to ML
> -------------------------------------------
>
>                 Key: SPARK-14409
>                 URL: https://issues.apache.org/jira/browse/SPARK-14409
>             Project: Spark
>          Issue Type: New Feature
>          Components: ML
>            Reporter: Nick Pentreath
>            Priority: Minor
>
> {{mllib.evaluation}} contains a {{RankingMetrics}} class, while there is no
> {{RankingEvaluator}} in {{ml.evaluation}}. Such an evaluator can be useful
> for recommendation evaluation (and can be useful in other settings
> potentially).
> Should be thought about in conjunction with adding the "recommendAll" methods
> in SPARK-13857, so that top-k ranking metrics can be used in cross-validators.

-- This message was sent by Atlassian JIRA (v6.3.4#6332)

---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org
For additional commands, e-mail: issues-h...@spark.apache.org