Repository: spark Updated Branches: refs/heads/master 5b0596648 -> 524827f06
[SPARK-14712][ML] LogisticRegressionModel.toString should summarize model ## What changes were proposed in this pull request? [SPARK-14712](https://issues.apache.org/jira/browse/SPARK-14712) spark.mllib's LogisticRegressionModel overrides toString to print a brief summary of the model. We should do the same in spark.ml and override __repr__ in PySpark. ## How was this patch tested? Unit test in LogisticRegressionSuite.scala; Python doctest in pyspark.ml.classification.py. Author: bravo-zhang <mzhang1...@gmail.com> Closes #18826 from bravo-zhang/spark-14712. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/524827f0 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/524827f0 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/524827f0 Branch: refs/heads/master Commit: 524827f0626281847582ec3056982db7eb83f8b1 Parents: 5b05966 Author: bravo-zhang <mzhang1...@gmail.com> Authored: Thu Jun 28 12:40:39 2018 -0700 Committer: Holden Karau <hol...@pigscanfly.ca> Committed: Thu Jun 28 12:40:39 2018 -0700 ---------------------------------------------------------------------- .../apache/spark/ml/classification/LogisticRegression.scala | 5 +++++ .../spark/ml/classification/LogisticRegressionSuite.scala | 6 ++++++ python/pyspark/ml/classification.py | 5 +++++ python/pyspark/mllib/classification.py | 3 +++ 4 files changed, 19 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/524827f0/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 06ca37b..92e342e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1202,6 +1202,11 @@ class LogisticRegressionModel private[spark] ( */ @Since("1.6.0") override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) + + override def toString: String = { + s"LogisticRegressionModel: " + + s"uid = ${super.toString}, numClasses = $numClasses, numFeatures = $numFeatures" + } } http://git-wip-us.apache.org/repos/asf/spark/blob/524827f0/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 36b7e51..75c2aeb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -2751,6 +2751,12 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { assert(model.getFamily === family) } } + + test("toString") { + val model = new LogisticRegressionModel("logReg", Vectors.dense(0.1, 0.2, 0.3), 0.0) + val expected = "LogisticRegressionModel: uid = logReg, numClasses = 2, numFeatures = 3" + assert(model.toString === expected) + } } object LogisticRegressionSuite { http://git-wip-us.apache.org/repos/asf/spark/blob/524827f0/python/pyspark/ml/classification.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 1754c48..d5963f4 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -239,6 +239,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti True >>> blorModel.intercept == model2.intercept True + >>> model2 + 
LogisticRegressionModel: uid = ..., numClasses = 2, numFeatures = 2 .. versionadded:: 1.3.0 """ @@ -562,6 +564,9 @@ class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable java_blr_summary = self._call_java("evaluate", dataset) return BinaryLogisticRegressionSummary(java_blr_summary) + def __repr__(self): + return self._call_java("toString") + class LogisticRegressionSummary(JavaWrapper): """ http://git-wip-us.apache.org/repos/asf/spark/blob/524827f0/python/pyspark/mllib/classification.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index bb28198..e00ed95 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -258,6 +258,9 @@ class LogisticRegressionModel(LinearClassificationModel): model.setThreshold(threshold) return model + def __repr__(self): + return self._call_java("toString") + class LogisticRegressionWithSGD(object): """ --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org