Github user WeichenXu123 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/17086#discussion_r183647533
  
    --- Diff: 
mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
 ---
    @@ -95,4 +95,95 @@ class MulticlassMetricsSuite extends SparkFunSuite with 
MLlibTestSparkContext {
           ((4.0 / 9) * f2measure0 + (4.0 / 9) * f2measure1 + (1.0 / 9) * 
f2measure2)) < delta)
         assert(metrics.labels.sameElements(labels))
       }
    +
    +  test("Multiclass evaluation metrics with weights") {
    +    /*
    +     * Confusion matrix for 3-class classification with total 9 instances 
with 2 weights:
    +     * |2 * w1|1 * w2         |1 * w1| true class0 (4 instances)
    +     * |1 * w2|2 * w1 + 1 * w2|0     | true class1 (4 instances)
    +     * |0     |0              |1 * w2| true class2 (1 instance)
    +     */
    +    val w1 = 2.2
    +    val w2 = 1.5
    +    val tw = 2.0 * w1 + 1.0 * w2 + 1.0 * w1 + 1.0 * w2 + 2.0 * w1 + 1.0 * 
w2 + 1.0 * w2
    +    val confusionMatrix = Matrices.dense(3, 3,
    +      Array(2 * w1, 1 * w2, 0, 1 * w2, 2 * w1 + 1 * w2, 0, 1 * w1, 0, 1 * 
w2))
    +    val labels = Array(0.0, 1.0, 2.0)
    +    val predictionAndLabelsWithWeights = sc.parallelize(
    +      Seq((0.0, 0.0, w1), (0.0, 1.0, w2), (0.0, 0.0, w1), (1.0, 0.0, w2),
    +        (1.0, 1.0, w1), (1.0, 1.0, w2), (1.0, 1.0, w1), (2.0, 2.0, w2),
    +        (2.0, 0.0, w1)), 2)
    +    val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights)
    +    val delta = 0.0000001
    +    val tpRate0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1)
    +    val tpRate1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2)
    +    val tpRate2 = (1.0 * w2) / (1.0 * w2 + 0)
    +    val fpRate0 = (1.0 * w2) / (tw - (2.0 * w1 + 1.0 * w2 + 1.0 * w1))
    +    val fpRate1 = (1.0 * w2) / (tw - (1.0 * w2 + 2.0 * w1 + 1.0 * w2))
    +    val fpRate2 = (1.0 * w1) / (tw - (1.0 * w2))
    +    val precision0 = (2.0 * w1) / (2 * w1 + 1 * w2)
    +    val precision1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * 
w2)
    +    val precision2 = (1.0 * w2) / (1 * w1 + 1 * w2)
    +    val recall0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1)
    +    val recall1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2)
    +    val recall2 = (1.0 * w2) / (1.0 * w2 + 0)
    +    val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
    +    val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
    +    val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
    +    val f2measure0 = (1 + 2 * 2) * precision0 * recall0 / (2 * 2 * 
precision0 + recall0)
    +    val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * 
precision1 + recall1)
    +    val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * 
precision2 + recall2)
    +
    +    
assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray))
    --- End diff --
    
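    Don't compare the flattened arrays via `toArray` (`sameElements` requires exact
    floating-point equality). Use the tolerance-based matrix comparison instead:
    `assert(metrics.confusionMatrix ~== confusionMatrix relTol delta)`.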


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to