Repository: spark
Updated Branches:
  refs/heads/master 8f0e88df0 -> 7f99a05e6


[SPARK-22422][ML] Add Adjusted R2 to RegressionMetrics

## What changes were proposed in this pull request?

I added adjusted R2 (`r2adj` in `LinearRegressionSummary`) as a regression
metric. It is available in all major statistical analysis tools.

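In practice, R2 is rarely read on its own because it can be misleading. Adding
more parameters never decreases R2; it only increases or stays the same. As a
result, comparing models by R2 alone favors larger models and encourages
overfitting. Adjusted R2 addresses this by accounting for the number of
parameters. It scales the residual sum of squares by the residual degrees of
freedom, so an extra parameter raises adjusted R2 only if it reduces the error
enough to pay for itself.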
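
Below is a minimal standalone sketch of the computation the patch adds as
`LinearRegressionSummary.r2adj`. It assumes n observations and p feature
coefficients, plus one extra degree of freedom when an intercept is fitted:
R2_adj = 1 - (1 - R2) * (n - k) / (n - p - k), with k = 1 for an intercept and
0 otherwise. `AdjustedR2Sketch` and `adjustedR2` are hypothetical names, not
Spark API. The R2 values in `main` are made-up numbers chosen only to show the
penalty.

```scala
object AdjustedR2Sketch {
  /** Hypothetical helper mirroring the r2adj formula in this patch.
    * @param r2              plain coefficient of determination
    * @param numInstances    number of observations (n)
    * @param numCoefficients number of feature coefficients, excluding the intercept (p)
    * @param fitIntercept    whether an intercept was fitted (one more degree of freedom)
    */
  def adjustedR2(r2: Double, numInstances: Long, numCoefficients: Int,
      fitIntercept: Boolean): Double = {
    val interceptDOF = if (fitIntercept) 1 else 0
    // Residual degrees of freedom; the formula is undefined when this is not positive.
    val residualDOF = numInstances - numCoefficients - interceptDOF
    require(residualDOF > 0, "need more instances than fitted parameters")
    1 - (1 - r2) * (numInstances - interceptDOF) / residualDOF
  }

  def main(args: Array[String]): Unit = {
    // Made-up example: 20 observations; two extra features raise R2 only slightly.
    val oneFeature = adjustedR2(0.80, 20, 1, fitIntercept = true)
    val threeFeatures = adjustedR2(0.805, 20, 3, fitIntercept = true)
    println(f"1 feature:  R2 = 0.800, adjusted R2 = $oneFeature%.6f")
    println(f"3 features: R2 = 0.805, adjusted R2 = $threeFeatures%.6f")
  }
}
```

Running `main` gives an adjusted R2 of about 0.789 for the one-feature model and
about 0.768 for the three-feature model. R2 rose by 0.005, but the two extra
parameters cost more than they explained.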

## How was this patch tested?

- Added a new unit test; it passes.
- `./dev/run-tests`: all tests passed.

Author: test <joseph.p...@quetica.com>
Author: tengpeng <tengp...@users.noreply.github.com>

Closes #19638 from tengpeng/master.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7f99a05e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7f99a05e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7f99a05e

Branch: refs/heads/master
Commit: 7f99a05e6ff258fc2192130451aa8aa1304bfe93
Parents: 8f0e88d
Author: test <joseph.p...@quetica.com>
Authored: Wed Nov 15 10:13:01 2017 -0600
Committer: Sean Owen <so...@cloudera.com>
Committed: Wed Nov 15 10:13:01 2017 -0600

----------------------------------------------------------------------
 .../spark/ml/regression/LinearRegression.scala       | 15 +++++++++++++++
 .../spark/ml/regression/LinearRegressionSuite.scala  |  6 ++++++
 2 files changed, 21 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7f99a05e/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index df1aa60..da6bcf0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -722,6 +722,21 @@ class LinearRegressionSummary private[regression] (
   @Since("1.5.0")
   val r2: Double = metrics.r2
 
+  /**
+   * Returns Adjusted R^2^, the adjusted coefficient of determination.
+   * Reference: <a href="https://en.wikipedia.org/wiki/Coefficient_of_determination#Adjusted_R2">
+   * Wikipedia coefficient of determination</a>
+   *
+   * @note This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`.
+   * This will change in later Spark versions.
+   */
+  @Since("2.3.0")
+  val r2adj: Double = {
+    val interceptDOF = if (privateModel.getFitIntercept) 1 else 0
+    1 - (1 - r2) * (numInstances - interceptDOF) /
+      (numInstances - privateModel.coefficients.size - interceptDOF)
+  }
+
   /** Residuals (label - predicted value) */
   @Since("1.5.0")
   @transient lazy val residuals: DataFrame = {

http://git-wip-us.apache.org/repos/asf/spark/blob/7f99a05e/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index f470dca..0e0be58 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -764,6 +764,11 @@ class LinearRegressionSuite
           (Intercept) 6.3022157  0.0018600    3388   <2e-16 ***
           V2          4.6982442  0.0011805    3980   <2e-16 ***
           V3          7.1994344  0.0009044    7961   <2e-16 ***
+
+          # R code for r2adj
+          lm_fit <- lm(V1 ~ V2 + V3, data = d1)
+          summary(lm_fit)$adj.r.squared
+          [1] 0.9998736
           ---
 
           ....
@@ -771,6 +776,7 @@ class LinearRegressionSuite
       assert(model.summary.meanSquaredError ~== 0.00985449 relTol 1E-4)
       assert(model.summary.meanAbsoluteError ~== 0.07961668 relTol 1E-4)
       assert(model.summary.r2 ~== 0.9998737 relTol 1E-4)
+      assert(model.summary.r2adj ~== 0.9998736  relTol 1E-4)
 
      // Normal solver uses "WeightedLeastSquares". If no regularization is applied or only L2
      // regularization is applied, this algorithm uses a direct solver and does not generate an

