Github user yanboliang commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19185#discussion_r138048004
  
    --- Diff: python/pyspark/ml/tests.py ---
    @@ -1478,6 +1478,40 @@ def test_logistic_regression_summary(self):
             sameSummary = model.evaluate(df)
             self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)
     
    +    def test_multiclass_logistic_regression_summary(self):
    +        df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
    +                                         (0.0, 2.0, Vectors.sparse(1, [], 
[])),
    +                                         (2.0, 2.0, Vectors.dense(2.0)),
    +                                         (2.0, 2.0, Vectors.dense(1.9))],
    +                                        ["label", "weight", "features"])
    +        lr = LogisticRegression(maxIter=5, regParam=0.01, 
weightCol="weight", fitIntercept=False)
    +        model = lr.fit(df)
    +        self.assertTrue(model.hasSummary)
    +        s = model.summary
    +        # test that api is callable and returns expected types
    +        self.assertTrue(isinstance(s.predictions, DataFrame))
    +        self.assertEqual(s.probabilityCol, "probability")
    +        self.assertEqual(s.labelCol, "label")
    +        self.assertEqual(s.featuresCol, "features")
    +        self.assertEqual(s.predictionCol, "prediction")
    +        objHist = s.objectiveHistory
    +        self.assertTrue(isinstance(objHist, list) and 
isinstance(objHist[0], float))
    +        self.assertGreater(s.totalIterations, 0)
    +        self.assertTrue(isinstance(s.labels, list))
    +        self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
    +        self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
    +        self.assertTrue(isinstance(s.precisionByLabel, list))
    +        self.assertTrue(isinstance(s.recallByLabel, list))
    +        self.assertTrue(isinstance(s.fMeasureByLabel, list))
    +        self.assertAlmostEqual(s.accuracy, 0.75, 2)
    +        self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2)
    +        self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2)
    +        self.assertAlmostEqual(s.weightedRecall, 0.75, 2)
    +        self.assertAlmostEqual(s.weightedPrecision, 0.583, 2)
    +        self.assertAlmostEqual(s.weightedFMeasure, 0.65, 2)
    --- End diff --
    
    We need to add these checks to the above `test_logistic_regression_summary` and rename it to `test_binary_logistic_regression_summary`, since the binary logistic regression summary has these variables as well.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to