Repository: spark
Updated Branches:
  refs/heads/master dcbb22943 -> 8d8641f12


[SPARK-21854] Added LogisticRegressionTrainingSummary for 
MultinomialLogisticRegression in Python API

## What changes were proposed in this pull request?

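Added `LogisticRegressionTrainingSummary` support for multinomial logistic
regression to the Python API. `LogisticRegressionModel.summary` now returns a
`LogisticRegressionTrainingSummary` when the model has more than two classes
and a `BinaryLogisticRegressionTrainingSummary` otherwise.
`LogisticRegressionSummary` also gains `predictionCol`, `labels`, and the
per-label and weighted multiclass metrics.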

## How was this patch tested?

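Added a Python unit test for the multiclass summary
(`test_multiclass_logistic_regression_summary`), extended the binary summary
test, and added assertions for the new metrics to the Scala
`LogisticRegressionSuite`.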

Please review http://spark.apache.org/contributing.html before opening a pull 
request.

Author: Ming Jiang <mji...@fanatics.com>
Author: Ming Jiang <jmw...@gmail.com>
Author: jmwdpk <jmw...@gmail.com>

Closes #19185 from jmwdpk/SPARK-21854.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8d8641f1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8d8641f1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8d8641f1

Branch: refs/heads/master
Commit: 8d8641f12250b0a9d370ff9354407c27af7cfcf4
Parents: dcbb229
Author: Ming Jiang <mji...@fanatics.com>
Authored: Thu Sep 14 13:53:28 2017 +0800
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Thu Sep 14 13:53:28 2017 +0800

----------------------------------------------------------------------
 .../LogisticRegressionSuite.scala               |  12 ++
 python/pyspark/ml/classification.py             | 120 ++++++++++++++++++-
 python/pyspark/ml/tests.py                      |  55 ++++++++-
 3 files changed, 183 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8d8641f1/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index d43c7cd..14f5508 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -2416,6 +2416,18 @@ class LogisticRegressionSuite
       blorSummary.recallByThreshold.collect() === 
sameBlorSummary.recallByThreshold.collect())
     assert(
       blorSummary.precisionByThreshold.collect() === 
sameBlorSummary.precisionByThreshold.collect())
+    assert(blorSummary.labels === sameBlorSummary.labels)
+    assert(blorSummary.truePositiveRateByLabel === 
sameBlorSummary.truePositiveRateByLabel)
+    assert(blorSummary.falsePositiveRateByLabel === 
sameBlorSummary.falsePositiveRateByLabel)
+    assert(blorSummary.precisionByLabel === sameBlorSummary.precisionByLabel)
+    assert(blorSummary.recallByLabel === sameBlorSummary.recallByLabel)
+    assert(blorSummary.fMeasureByLabel === sameBlorSummary.fMeasureByLabel)
+    assert(blorSummary.accuracy === sameBlorSummary.accuracy)
+    assert(blorSummary.weightedTruePositiveRate === 
sameBlorSummary.weightedTruePositiveRate)
+    assert(blorSummary.weightedFalsePositiveRate === 
sameBlorSummary.weightedFalsePositiveRate)
+    assert(blorSummary.weightedRecall === sameBlorSummary.weightedRecall)
+    assert(blorSummary.weightedPrecision === sameBlorSummary.weightedPrecision)
+    assert(blorSummary.weightedFMeasure === sameBlorSummary.weightedFMeasure)
 
     lr.setFamily("multinomial")
     val mlorModel = lr.fit(smallMultinomialDataset)

http://git-wip-us.apache.org/repos/asf/spark/blob/8d8641f1/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py 
b/python/pyspark/ml/classification.py
index fbb9e7f..0caafa6 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -529,9 +529,11 @@ class LogisticRegressionModel(JavaModel, 
JavaClassificationModel, JavaMLWritable
         trained on the training set. An exception is thrown if 
`trainingSummary is None`.
         """
         if self.hasSummary:
-            java_blrt_summary = self._call_java("summary")
-            # Note: Once multiclass is added, update this to return correct 
summary
-            return BinaryLogisticRegressionTrainingSummary(java_blrt_summary)
+            java_lrt_summary = self._call_java("summary")
+            if self.numClasses <= 2:
+                return 
BinaryLogisticRegressionTrainingSummary(java_lrt_summary)
+            else:
+                return LogisticRegressionTrainingSummary(java_lrt_summary)
         else:
             raise RuntimeError("No training summary available for this %s" %
                                self.__class__.__name__)
@@ -587,6 +589,14 @@ class LogisticRegressionSummary(JavaWrapper):
         return self._call_java("probabilityCol")
 
     @property
+    @since("2.3.0")
+    def predictionCol(self):
+        """
+        Field in "predictions" which gives the prediction of each class.
+        """
+        return self._call_java("predictionCol")
+
+    @property
     @since("2.0.0")
     def labelCol(self):
         """
@@ -604,6 +614,110 @@ class LogisticRegressionSummary(JavaWrapper):
         """
         return self._call_java("featuresCol")
 
+    @property
+    @since("2.3.0")
+    def labels(self):
+        """
+        Returns the sequence of labels in ascending order. This order matches 
the order used
+        in metrics which are specified as arrays over labels, e.g., 
truePositiveRateByLabel.
+
+        Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. 
However, if the
+        training set is missing a label, then all of the arrays over labels
+        (e.g., from truePositiveRateByLabel) will be of length numClasses-1 
instead of the
+        expected numClasses.
+        """
+        return self._call_java("labels")
+
+    @property
+    @since("2.3.0")
+    def truePositiveRateByLabel(self):
+        """
+        Returns true positive rate for each label (category).
+        """
+        return self._call_java("truePositiveRateByLabel")
+
+    @property
+    @since("2.3.0")
+    def falsePositiveRateByLabel(self):
+        """
+        Returns false positive rate for each label (category).
+        """
+        return self._call_java("falsePositiveRateByLabel")
+
+    @property
+    @since("2.3.0")
+    def precisionByLabel(self):
+        """
+        Returns precision for each label (category).
+        """
+        return self._call_java("precisionByLabel")
+
+    @property
+    @since("2.3.0")
+    def recallByLabel(self):
+        """
+        Returns recall for each label (category).
+        """
+        return self._call_java("recallByLabel")
+
+    @since("2.3.0")
+    def fMeasureByLabel(self, beta=1.0):
+        """
+        Returns f-measure for each label (category).
+        """
+        return self._call_java("fMeasureByLabel", beta)
+
+    @property
+    @since("2.3.0")
+    def accuracy(self):
+        """
+        Returns accuracy.
+        (equal to the number of correctly classified instances
+        divided by the total number of instances.)
+        """
+        return self._call_java("accuracy")
+
+    @property
+    @since("2.3.0")
+    def weightedTruePositiveRate(self):
+        """
+        Returns weighted true positive rate.
+        (equal to weighted recall)
+        """
+        return self._call_java("weightedTruePositiveRate")
+
+    @property
+    @since("2.3.0")
+    def weightedFalsePositiveRate(self):
+        """
+        Returns weighted false positive rate.
+        """
+        return self._call_java("weightedFalsePositiveRate")
+
+    @property
+    @since("2.3.0")
+    def weightedRecall(self):
+        """
+        Returns weighted averaged recall.
+        (equal to weighted true positive rate)
+        """
+        return self._call_java("weightedRecall")
+
+    @property
+    @since("2.3.0")
+    def weightedPrecision(self):
+        """
+        Returns weighted averaged precision.
+        """
+        return self._call_java("weightedPrecision")
+
+    @since("2.3.0")
+    def weightedFMeasure(self, beta=1.0):
+        """
+        Returns weighted averaged f-measure.
+        """
+        return self._call_java("weightedFMeasure", beta)
+
 
 @inherit_doc
 class LogisticRegressionTrainingSummary(LogisticRegressionSummary):

http://git-wip-us.apache.org/repos/asf/spark/blob/8d8641f1/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index c66cd76..8b8bcc7 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1451,7 +1451,7 @@ class TrainingSummaryTest(SparkSessionTestCase):
         sameSummary = model.evaluate(df)
         self.assertAlmostEqual(sameSummary.deviance, s.deviance)
 
-    def test_logistic_regression_summary(self):
+    def test_binary_logistic_regression_summary(self):
         df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
                                          (0.0, 2.0, Vectors.sparse(1, [], 
[]))],
                                         ["label", "weight", "features"])
@@ -1464,20 +1464,73 @@ class TrainingSummaryTest(SparkSessionTestCase):
         self.assertEqual(s.probabilityCol, "probability")
         self.assertEqual(s.labelCol, "label")
         self.assertEqual(s.featuresCol, "features")
+        self.assertEqual(s.predictionCol, "prediction")
         objHist = s.objectiveHistory
         self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], 
float))
         self.assertGreater(s.totalIterations, 0)
+        self.assertTrue(isinstance(s.labels, list))
+        self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
+        self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
+        self.assertTrue(isinstance(s.precisionByLabel, list))
+        self.assertTrue(isinstance(s.recallByLabel, list))
+        self.assertTrue(isinstance(s.fMeasureByLabel(), list))
+        self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list))
         self.assertTrue(isinstance(s.roc, DataFrame))
         self.assertAlmostEqual(s.areaUnderROC, 1.0, 2)
         self.assertTrue(isinstance(s.pr, DataFrame))
         self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame))
         self.assertTrue(isinstance(s.precisionByThreshold, DataFrame))
         self.assertTrue(isinstance(s.recallByThreshold, DataFrame))
+        self.assertAlmostEqual(s.accuracy, 1.0, 2)
+        self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2)
+        self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2)
+        self.assertAlmostEqual(s.weightedRecall, 1.0, 2)
+        self.assertAlmostEqual(s.weightedPrecision, 1.0, 2)
+        self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2)
+        self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2)
         # test evaluation (with training dataset) produces a summary with same 
values
         # one check is enough to verify a summary is returned, Scala version 
runs full test
         sameSummary = model.evaluate(df)
         self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)
 
+    def test_multiclass_logistic_regression_summary(self):
+        df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+                                         (0.0, 2.0, Vectors.sparse(1, [], [])),
+                                         (2.0, 2.0, Vectors.dense(2.0)),
+                                         (2.0, 2.0, Vectors.dense(1.9))],
+                                        ["label", "weight", "features"])
+        lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", 
fitIntercept=False)
+        model = lr.fit(df)
+        self.assertTrue(model.hasSummary)
+        s = model.summary
+        # test that api is callable and returns expected types
+        self.assertTrue(isinstance(s.predictions, DataFrame))
+        self.assertEqual(s.probabilityCol, "probability")
+        self.assertEqual(s.labelCol, "label")
+        self.assertEqual(s.featuresCol, "features")
+        self.assertEqual(s.predictionCol, "prediction")
+        objHist = s.objectiveHistory
+        self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], 
float))
+        self.assertGreater(s.totalIterations, 0)
+        self.assertTrue(isinstance(s.labels, list))
+        self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
+        self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
+        self.assertTrue(isinstance(s.precisionByLabel, list))
+        self.assertTrue(isinstance(s.recallByLabel, list))
+        self.assertTrue(isinstance(s.fMeasureByLabel(), list))
+        self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list))
+        self.assertAlmostEqual(s.accuracy, 0.75, 2)
+        self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2)
+        self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2)
+        self.assertAlmostEqual(s.weightedRecall, 0.75, 2)
+        self.assertAlmostEqual(s.weightedPrecision, 0.583, 2)
+        self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2)
+        self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2)
+        # test evaluation (with training dataset) produces a summary with same 
values
+        # one check is enough to verify a summary is returned, Scala version 
runs full test
+        sameSummary = model.evaluate(df)
+        self.assertAlmostEqual(sameSummary.accuracy, s.accuracy)
+
     def test_gaussian_mixture_summary(self):
         data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), 
(Vectors.dense(10.0),),
                 (Vectors.sparse(1, [], []),)]


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to