Repository: spark
Updated Branches:
  refs/heads/branch-2.0 b1a9c41e8 -> 594a2cf6f


[SPARK-17792][ML] L-BFGS solver for linear regression does not accept general 
numeric label column types

## What changes were proposed in this pull request?

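Before, `LinearRegression` computed `instances` in two separate places that 
were meant to be equivalent, but only one of them cast the label column to 
`DoubleType`. The other (used by the L-BFGS path) pattern-matched the label as 
a `Double` without casting, so it failed on other numeric label types. This 
patch consolidates the computation into a single place and always casts the 
label column to `DoubleType`.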

## How was this patch tested?

Extended the existing numeric-label-type unit test to run against every solver 
(`auto`, `l-bfgs`, `normal`). This test failed before this patch.

Author: sethah <seth.hendrickso...@gmail.com>

Closes #15364 from sethah/linreg_numeric_type.

(cherry picked from commit 3713bb199142c5e06e2e527c99650f02f41f47b1)
Signed-off-by: Yanbo Liang <yblia...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/594a2cf6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/594a2cf6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/594a2cf6

Branch: refs/heads/branch-2.0
Commit: 594a2cf6f7c74c54127b8c3947aadbe0052b404c
Parents: b1a9c41
Author: sethah <seth.hendrickso...@gmail.com>
Authored: Thu Oct 6 21:10:17 2016 -0700
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Thu Oct 6 21:14:44 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/regression/LinearRegression.scala     | 17 ++++++-----------
 .../ml/regression/LinearRegressionSuite.scala      |  8 +++++---
 2 files changed, 11 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/594a2cf6/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index f82f2c3..600bbcb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -163,17 +163,18 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") 
override val uid: String
     val numFeatures = 
dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else 
col($(weightCol))
 
+    val instances: RDD[Instance] = dataset.select(
+      col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+      case Row(label: Double, weight: Double, features: Vector) =>
+        Instance(label, weight, features)
+    }
+
     if (($(solver) == "auto" && $(elasticNetParam) == 0.0 &&
       numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == 
"normal") {
       require($(elasticNetParam) == 0.0, "Only L2 regularization can be used 
when normal " +
         "solver is used.'")
       // For low dimensional data, WeightedLeastSquares is more efficiently 
since the
       // training algorithm only requires one pass through the data. 
(SPARK-10668)
-      val instances: RDD[Instance] = dataset.select(
-        col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
-          case Row(label: Double, weight: Double, features: Vector) =>
-            Instance(label, weight, features)
-      }
 
       val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam),
         $(standardization), true)
@@ -196,12 +197,6 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") 
override val uid: String
       return lrModel.setSummary(trainingSummary)
     }
 
-    val instances: RDD[Instance] =
-      dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
-        case Row(label: Double, weight: Double, features: Vector) =>
-          Instance(label, weight, features)
-      }
-
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/594a2cf6/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 265f2f4..df67a3a 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -1019,12 +1019,14 @@ class LinearRegressionSuite
   }
 
   test("should support all NumericType labels and not support other types") {
-    val lr = new LinearRegression().setMaxIter(1)
-    MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression](
-      lr, spark, isClassification = false) { (expected, actual) =>
+    for (solver <- Seq("auto", "l-bfgs", "normal")) {
+      val lr = new LinearRegression().setMaxIter(1).setSolver(solver)
+      MLTestingUtils.checkNumericTypes[LinearRegressionModel, 
LinearRegression](
+        lr, spark, isClassification = false) { (expected, actual) =>
         assert(expected.intercept === actual.intercept)
         assert(expected.coefficients === actual.coefficients)
       }
+    }
   }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to