GitHub user imatiach-msft commented on a diff in the pull request: https://github.com/apache/spark/pull/17086#discussion_r231273689 --- Diff: mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala --- @@ -39,21 +46,28 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl private[mllib] def this(predictionAndLabels: DataFrame) = this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) - private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue() - private lazy val labelCount: Long = labelCountByClass.values.sum - private lazy val tpByClass: Map[Double, Int] = predictionAndLabels - .map { case (prediction, label) => - (label, if (label == prediction) 1 else 0) + private lazy val labelCountByClass: Map[Double, Double] = + predLabelsWeight.map { + case (prediction: Double, label: Double, weight: Double) => + (label, weight) + }.reduceByKey(_ + _).collect().toMap + private lazy val labelCount: Double = labelCountByClass.values.sum + private lazy val tpByClass: Map[Double, Double] = predLabelsWeight + .map { --- End diff -- Done. I'm not sure whether it is more efficient; it seemed to take the same time for me, but I didn't test extensively.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org