Github user imatiach-msft commented on a diff in the pull request: https://github.com/apache/spark/pull/17086#discussion_r231294624 --- Diff: mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala --- @@ -39,21 +46,28 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl private[mllib] def this(predictionAndLabels: DataFrame) = this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) - private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue() - private lazy val labelCount: Long = labelCountByClass.values.sum - private lazy val tpByClass: Map[Double, Int] = predictionAndLabels - .map { case (prediction, label) => - (label, if (label == prediction) 1 else 0) + private lazy val labelCountByClass: Map[Double, Double] = + predLabelsWeight.map { + case (prediction: Double, label: Double, weight: Double) => + (label, weight) + }.reduceByKey(_ + _).collect().toMap + private lazy val labelCount: Double = labelCountByClass.values.sum + private lazy val tpByClass: Map[Double, Double] = predLabelsWeight + .map { --- End diff -- It looks like this will actually cause tests to fail: if we filter everything out first, the key may be missing from the resulting map, whereas we want the key to still be present with a value of 0.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org