Repository: spark Updated Branches: refs/heads/master dcbb22943 -> 8d8641f12
[SPARK-21854] Added LogisticRegressionTrainingSummary for MultinomialLogisticRegression in Python API ## What changes were proposed in this pull request? Added LogisticRegressionTrainingSummary for MultinomialLogisticRegression in Python API ## How was this patch tested? Added unit test Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Ming Jiang <mji...@fanatics.com> Author: Ming Jiang <jmw...@gmail.com> Author: jmwdpk <jmw...@gmail.com> Closes #19185 from jmwdpk/SPARK-21854. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8d8641f1 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8d8641f1 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8d8641f1 Branch: refs/heads/master Commit: 8d8641f12250b0a9d370ff9354407c27af7cfcf4 Parents: dcbb229 Author: Ming Jiang <mji...@fanatics.com> Authored: Thu Sep 14 13:53:28 2017 +0800 Committer: Yanbo Liang <yblia...@gmail.com> Committed: Thu Sep 14 13:53:28 2017 +0800 ---------------------------------------------------------------------- .../LogisticRegressionSuite.scala | 12 ++ python/pyspark/ml/classification.py | 120 ++++++++++++++++++- python/pyspark/ml/tests.py | 55 ++++++++- 3 files changed, 183 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/8d8641f1/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index d43c7cd..14f5508 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -2416,6 +2416,18 @@ class LogisticRegressionSuite blorSummary.recallByThreshold.collect() === sameBlorSummary.recallByThreshold.collect()) assert( blorSummary.precisionByThreshold.collect() === sameBlorSummary.precisionByThreshold.collect()) + assert(blorSummary.labels === sameBlorSummary.labels) + assert(blorSummary.truePositiveRateByLabel === sameBlorSummary.truePositiveRateByLabel) + assert(blorSummary.falsePositiveRateByLabel === sameBlorSummary.falsePositiveRateByLabel) + assert(blorSummary.precisionByLabel === sameBlorSummary.precisionByLabel) + assert(blorSummary.recallByLabel === sameBlorSummary.recallByLabel) + assert(blorSummary.fMeasureByLabel === sameBlorSummary.fMeasureByLabel) + assert(blorSummary.accuracy === sameBlorSummary.accuracy) + assert(blorSummary.weightedTruePositiveRate === sameBlorSummary.weightedTruePositiveRate) + assert(blorSummary.weightedFalsePositiveRate === sameBlorSummary.weightedFalsePositiveRate) + assert(blorSummary.weightedRecall === sameBlorSummary.weightedRecall) + assert(blorSummary.weightedPrecision === sameBlorSummary.weightedPrecision) + assert(blorSummary.weightedFMeasure === sameBlorSummary.weightedFMeasure) lr.setFamily("multinomial") val mlorModel = lr.fit(smallMultinomialDataset) http://git-wip-us.apache.org/repos/asf/spark/blob/8d8641f1/python/pyspark/ml/classification.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index fbb9e7f..0caafa6 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -529,9 +529,11 @@ class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable trained on the training set. An exception is thrown if `trainingSummary is None`. 
""" if self.hasSummary: - java_blrt_summary = self._call_java("summary") - # Note: Once multiclass is added, update this to return correct summary - return BinaryLogisticRegressionTrainingSummary(java_blrt_summary) + java_lrt_summary = self._call_java("summary") + if self.numClasses <= 2: + return BinaryLogisticRegressionTrainingSummary(java_lrt_summary) + else: + return LogisticRegressionTrainingSummary(java_lrt_summary) else: raise RuntimeError("No training summary available for this %s" % self.__class__.__name__) @@ -587,6 +589,14 @@ class LogisticRegressionSummary(JavaWrapper): return self._call_java("probabilityCol") @property + @since("2.3.0") + def predictionCol(self): + """ + Field in "predictions" which gives the prediction of each class. + """ + return self._call_java("predictionCol") + + @property @since("2.0.0") def labelCol(self): """ @@ -604,6 +614,110 @@ class LogisticRegressionSummary(JavaWrapper): """ return self._call_java("featuresCol") + @property + @since("2.3.0") + def labels(self): + """ + Returns the sequence of labels in ascending order. This order matches the order used + in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel. + + Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}, However, if the + training set is missing a label, then all of the arrays over labels + (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the + expected numClasses. + """ + return self._call_java("labels") + + @property + @since("2.3.0") + def truePositiveRateByLabel(self): + """ + Returns true positive rate for each label (category). + """ + return self._call_java("truePositiveRateByLabel") + + @property + @since("2.3.0") + def falsePositiveRateByLabel(self): + """ + Returns false positive rate for each label (category). 
+ """ + return self._call_java("falsePositiveRateByLabel") + + @property + @since("2.3.0") + def precisionByLabel(self): + """ + Returns precision for each label (category). + """ + return self._call_java("precisionByLabel") + + @property + @since("2.3.0") + def recallByLabel(self): + """ + Returns recall for each label (category). + """ + return self._call_java("recallByLabel") + + @since("2.3.0") + def fMeasureByLabel(self, beta=1.0): + """ + Returns f-measure for each label (category). + """ + return self._call_java("fMeasureByLabel", beta) + + @property + @since("2.3.0") + def accuracy(self): + """ + Returns accuracy. + (equal to the total number of correctly classified instances + out of the total number of instances.) + """ + return self._call_java("accuracy") + + @property + @since("2.3.0") + def weightedTruePositiveRate(self): + """ + Returns weighted true positive rate. + (equal to weighted recall) + """ + return self._call_java("weightedTruePositiveRate") + + @property + @since("2.3.0") + def weightedFalsePositiveRate(self): + """ + Returns weighted false positive rate. + """ + return self._call_java("weightedFalsePositiveRate") + + @property + @since("2.3.0") + def weightedRecall(self): + """ + Returns weighted averaged recall. + (equal to weighted true positive rate) + """ + return self._call_java("weightedRecall") + + @property + @since("2.3.0") + def weightedPrecision(self): + """ + Returns weighted averaged precision. + """ + return self._call_java("weightedPrecision") + + @since("2.3.0") + def weightedFMeasure(self, beta=1.0): + """ + Returns weighted averaged f-measure. 
+ """ + return self._call_java("weightedFMeasure", beta) + @inherit_doc class LogisticRegressionTrainingSummary(LogisticRegressionSummary): http://git-wip-us.apache.org/repos/asf/spark/blob/8d8641f1/python/pyspark/ml/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index c66cd76..8b8bcc7 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1451,7 +1451,7 @@ class TrainingSummaryTest(SparkSessionTestCase): sameSummary = model.evaluate(df) self.assertAlmostEqual(sameSummary.deviance, s.deviance) - def test_logistic_regression_summary(self): + def test_binary_logistic_regression_summary(self): df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"]) @@ -1464,20 +1464,73 @@ class TrainingSummaryTest(SparkSessionTestCase): self.assertEqual(s.probabilityCol, "probability") self.assertEqual(s.labelCol, "label") self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") objHist = s.objectiveHistory self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.labels, list)) + self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.precisionByLabel, list)) + self.assertTrue(isinstance(s.recallByLabel, list)) + self.assertTrue(isinstance(s.fMeasureByLabel(), list)) + self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) self.assertTrue(isinstance(s.roc, DataFrame)) self.assertAlmostEqual(s.areaUnderROC, 1.0, 2) self.assertTrue(isinstance(s.pr, DataFrame)) self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame)) self.assertTrue(isinstance(s.precisionByThreshold, DataFrame)) self.assertTrue(isinstance(s.recallByThreshold, DataFrame)) + 
self.assertAlmostEqual(s.accuracy, 1.0, 2) + self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2) + self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2) + self.assertAlmostEqual(s.weightedRecall, 1.0, 2) + self.assertAlmostEqual(s.weightedPrecision, 1.0, 2) + self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2) + self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2) # test evaluation (with training dataset) produces a summary with same values # one check is enough to verify a summary is returned, Scala version runs full test sameSummary = model.evaluate(df) self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) + def test_multiclass_logistic_regression_summary(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], [])), + (2.0, 2.0, Vectors.dense(2.0)), + (2.0, 2.0, Vectors.dense(1.9))], + ["label", "weight", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) + model = lr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.labels, list)) + self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.precisionByLabel, list)) + self.assertTrue(isinstance(s.recallByLabel, list)) + self.assertTrue(isinstance(s.fMeasureByLabel(), list)) + self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) + self.assertAlmostEqual(s.accuracy, 0.75, 2) + 
self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2) + self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2) + self.assertAlmostEqual(s.weightedRecall, 0.75, 2) + self.assertAlmostEqual(s.weightedPrecision, 0.583, 2) + self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2) + self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned, Scala version runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.accuracy, s.accuracy) + def test_gaussian_mixture_summary(self): data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), (Vectors.sparse(1, [], []),)] --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org