Github user gaborgsomogyi commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20362#discussion_r170072578
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala ---
    @@ -413,34 +411,36 @@ class ALSSuite
           .setSeed(0)
         val alpha = als.getAlpha
         val model = als.fit(training.toDF())
    -    val predictions = model.transform(test.toDF()).select("rating", 
"prediction").rdd.map {
    -      case Row(rating: Float, prediction: Float) =>
    -        (rating.toDouble, prediction.toDouble)
    +    testTransformerByGlobalCheckFunc[Rating[Int]](test.toDF(), model, 
"rating", "prediction") {
    +        case rows: Seq[Row] =>
    +          val predictions = rows.map(row => (row.getFloat(0).toDouble, 
row.getFloat(1).toDouble))
    +
    +          val rmse =
    +            if (implicitPrefs) {
    +              // TODO: Use a better (rank-based?) evaluation metric for 
implicit feedback.
    +              // We limit the ratings and the predictions to interval [0, 
1] and compute the
    +              // weighted RMSE with the confidence scores as weights.
    +              val (totalWeight, weightedSumSq) = predictions.map { case 
(rating, prediction) =>
    +                val confidence = 1.0 + alpha * math.abs(rating)
    +                val rating01 = math.max(math.min(rating, 1.0), 0.0)
    +                val prediction01 = math.max(math.min(prediction, 1.0), 0.0)
    +                val err = prediction01 - rating01
    +                (confidence, confidence * err * err)
    +              }.reduce[(Double, Double)] { case ((c0, e0), (c1, e1)) =>
    +                (c0 + c1, e0 + e1)
    +              }
    +              math.sqrt(weightedSumSq / totalWeight)
    +            } else {
    +              val errorSquares = predictions.map { case (rating, 
prediction) =>
    +                val err = rating - prediction
    +                err * err
    +              }
    +              val mse = errorSquares.sum / errorSquares.length
    +              math.sqrt(mse)
    +            }
    +          logInfo(s"Test RMSE is $rmse.")
    +          assert(rmse < targetRMSE)
         }
    -    val rmse =
    --- End diff --
    
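    This is mostly a straight move of the existing code. The one exception is that the local `Seq` of predictions has no `mean` function, so I computed the mean squared error by hand as `errorSquares.sum / errorSquares.length`.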


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to