Repository: spark
Updated Branches:
  refs/heads/master 2a29a60da -> 1e44dd004


[SPARK-3181][ML] Implement huber loss for LinearRegression.

## What changes were proposed in this pull request?
MLlib ```LinearRegression``` supports _huber_ loss in addition to _leastSquares_ 
loss. The huber loss objective function is:
![image](https://user-images.githubusercontent.com/1962026/29554124-9544d198-8750-11e7-8afa-33579ec419d5.png)
Refer to Eq. (6) and Eq. (8) in [A robust hybrid of lasso and ridge 
regression](http://statweb.stanford.edu/~owen/reports/hhu.pdf). This objective 
is jointly convex as a function of (w, σ) ∈ R × (0,∞), so we can use 
L-BFGS-B to solve it.
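In text form (transcribing the linked image; the same formula appears in the 
Javadoc added by this patch):

```latex
\min_{w, \sigma} \frac{1}{2n} \sum_{i=1}^n \left( \sigma + H_m\!\left(\frac{X_i w - y_i}{\sigma}\right) \sigma \right) + \frac{1}{2} \lambda \lVert w \rVert_2^2,
\quad \text{where } H_m(z) = \begin{cases} z^2, & \text{if } |z| < \epsilon \\ 2\epsilon |z| - \epsilon^2, & \text{otherwise} \end{cases}
```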

The current implementation is a straightforward port of the Python scikit-learn 
[```HuberRegressor```](http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.HuberRegressor.html). 
There are some differences:
* We use mean loss (```lossSum/weightSum```), but sklearn uses total loss 
(```lossSum```).
* We multiply the loss function and L2 regularization by 1/2. Multiplying the 
whole formula by a constant factor does not affect the result; we just stay 
consistent with the _leastSquares_ loss.

So if fitting without regularization, MLlib and sklearn produce the same 
output. If fitting with regularization, MLlib's ```regParam``` should be set to 
sklearn's ```alpha``` divided by the number of instances to match the output 
of sklearn.
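As a minimal usage sketch (the ```training``` DataFrame and the sklearn 
```alpha``` value here are illustrative, not from this patch):

```scala
import org.apache.spark.ml.regression.LinearRegression

// Sketch: fit with huber loss; to match sklearn's HuberRegressor(alpha=a)
// on n instances, set regParam = a / n (see the tests below, where
// alpha=210 over 1000 instances corresponds to regParam=0.21).
val n = training.count()
val sklearnAlpha = 210.0
val lr = new LinearRegression()
  .setLoss("huber")
  .setEpsilon(1.35)                 // sklearn's default shape parameter
  .setRegParam(sklearnAlpha / n)    // L2 only; L1/elastic net is rejected
val model = lr.fit(training)
println(s"scale (sigma) = ${model.scale}")
```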

## How was this patch tested?
Unit tests.

Author: Yanbo Liang <yblia...@gmail.com>

Closes #19020 from yanboliang/spark-3181.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1e44dd00
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1e44dd00
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1e44dd00

Branch: refs/heads/master
Commit: 1e44dd004425040912f2cf16362d2c13f12e1689
Parents: 2a29a60
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Wed Dec 13 21:19:14 2017 -0800
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Wed Dec 13 21:19:14 2017 -0800

----------------------------------------------------------------------
 .../ml/optim/aggregator/HuberAggregator.scala   | 150 ++++++++++
 .../ml/param/shared/SharedParamsCodeGen.scala   |   3 +-
 .../spark/ml/param/shared/sharedParams.scala    |  17 ++
 .../spark/ml/regression/LinearRegression.scala  | 299 +++++++++++++++----
 .../optim/aggregator/HuberAggregatorSuite.scala | 170 +++++++++++
 .../ml/regression/LinearRegressionSuite.scala   | 244 ++++++++++++++-
 project/MimaExcludes.scala                      |   5 +
 7 files changed, 823 insertions(+), 65 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1e44dd00/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HuberAggregator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HuberAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HuberAggregator.scala
new file mode 100644
index 0000000..13f64d2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HuberAggregator.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.optim.aggregator
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.linalg.Vector
+
+/**
+ * HuberAggregator computes the gradient and loss for a huber loss function,
+ * as used in robust regression, for samples in sparse or dense vectors in an online fashion.
+ *
+ * The huber loss function is based on:
+ * <a href="http://statweb.stanford.edu/~owen/reports/hhu.pdf">Art B. Owen (2006),
+ * A robust hybrid of lasso and ridge regression</a>.
+ *
+ * Two HuberAggregators can be merged together to produce a summary of the loss and gradient
+ * of the corresponding joint dataset.
+ *
+ * The huber loss function is given by
+ *
+ * <blockquote>
+ *   $$
+ *   \begin{align}
+ *   \min_{w, \sigma}\frac{1}{2n}{\sum_{i=1}^n\left(\sigma +
+ *   H_m\left(\frac{X_{i}w - y_{i}}{\sigma}\right)\sigma\right) + \frac{1}{2}\lambda {||w||_2}^2}
+ *   \end{align}
+ *   $$
+ * </blockquote>
+ *
+ * where
+ *
+ * <blockquote>
+ *   $$
+ *   \begin{align}
+ *   H_m(z) = \begin{cases}
+ *            z^2, & \text {if } |z| &lt; \epsilon, \\
+ *            2\epsilon|z| - \epsilon^2, & \text{otherwise}
+ *            \end{cases}
+ *   \end{align}
+ *   $$
+ * </blockquote>
+ *
+ * It is advised to set the parameter $\epsilon$ to 1.35 to achieve 95% statistical efficiency
+ * for normally distributed data. Please refer to chapter 2 of
+ * <a href="http://statweb.stanford.edu/~owen/reports/hhu.pdf">
+ * A robust hybrid of lasso and ridge regression</a> for more detail.
+ *
+ * @param fitIntercept Whether to fit an intercept term.
+ * @param epsilon The shape parameter to control the amount of robustness.
+ * @param bcFeaturesStd The broadcast standard deviation values of the features.
+ * @param bcParameters The broadcast parameters, including three parts: the regression
+ *                     coefficients corresponding to the features, the intercept
+ *                     (if fitIntercept is true) and the scale parameter (sigma).
+ */
+private[ml] class HuberAggregator(
+    fitIntercept: Boolean,
+    epsilon: Double,
+    bcFeaturesStd: Broadcast[Array[Double]])(bcParameters: Broadcast[Vector])
+  extends DifferentiableLossAggregator[Instance, HuberAggregator] {
+
+  protected override val dim: Int = bcParameters.value.size
+  private val numFeatures: Int = if (fitIntercept) dim - 2 else dim - 1
+  private val sigma: Double = bcParameters.value(dim - 1)
+  private val intercept: Double = if (fitIntercept) {
+    bcParameters.value(dim - 2)
+  } else {
+    0.0
+  }
+
+  /**
+   * Add a new training instance to this HuberAggregator, and update the loss and gradient
+   * of the objective function.
+   *
+   * @param instance The instance of data point to be added.
+   * @return This HuberAggregator object.
+   */
+  def add(instance: Instance): HuberAggregator = {
+    instance match { case Instance(label, weight, features) =>
+      require(numFeatures == features.size, s"Dimensions mismatch when adding new sample." +
+        s" Expecting $numFeatures but got ${features.size}.")
+      require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
+
+      if (weight == 0.0) return this
+      val localFeaturesStd = bcFeaturesStd.value
+      val localCoefficients = bcParameters.value.toArray.slice(0, numFeatures)
+      val localGradientSumArray = gradientSumArray
+
+      val margin = {
+        var sum = 0.0
+        features.foreachActive { (index, value) =>
+          if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+            sum += localCoefficients(index) * (value / localFeaturesStd(index))
+          }
+        }
+        if (fitIntercept) sum += intercept
+        sum
+      }
+      val linearLoss = label - margin
+
+      if (math.abs(linearLoss) <= sigma * epsilon) {
+        lossSum += 0.5 * weight * (sigma + math.pow(linearLoss, 2.0) / sigma)
+        val linearLossDivSigma = linearLoss / sigma
+
+        features.foreachActive { (index, value) =>
+          if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+            localGradientSumArray(index) +=
+              -1.0 * weight * linearLossDivSigma * (value / localFeaturesStd(index))
+          }
+        }
+        if (fitIntercept) {
+          localGradientSumArray(dim - 2) += -1.0 * weight * linearLossDivSigma
+        }
+        localGradientSumArray(dim - 1) += 0.5 * weight * (1.0 - math.pow(linearLossDivSigma, 2.0))
+      } else {
+        val sign = if (linearLoss >= 0) -1.0 else 1.0
+        lossSum += 0.5 * weight *
+          (sigma + 2.0 * epsilon * math.abs(linearLoss) - sigma * epsilon * epsilon)
+
+        features.foreachActive { (index, value) =>
+          if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+            localGradientSumArray(index) +=
+              weight * sign * epsilon * (value / localFeaturesStd(index))
+          }
+        }
+        if (fitIntercept) {
+          localGradientSumArray(dim - 2) += weight * sign * epsilon
+        }
+        localGradientSumArray(dim - 1) += 0.5 * weight * (1.0 - epsilon * epsilon)
+      }
+
+      weightSum += weight
+      this
+    }
+  }
+}
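
Downstream, a HuberAggregator is consumed through the shared
DifferentiableLossAggregator machinery. A rough sketch of the aggregation
pattern (assumed for illustration; the actual driver is RDDLossFunction):

```scala
// Each partition folds instances into an aggregator via add(), and partial
// aggregators are combined via merge(); loss and gradient come from the
// DifferentiableLossAggregator base trait.
val huberAgg = instances.treeAggregate(
  new HuberAggregator(fitIntercept, epsilon, bcFeaturesStd)(bcParameters))(
  seqOp = (agg, instance) => agg.add(instance),
  combOp = (a, b) => a.merge(b),
  depth = aggregationDepth)
val (loss, gradient) = (huberAgg.loss, huberAgg.gradient)
```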

http://git-wip-us.apache.org/repos/asf/spark/blob/1e44dd00/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index a267bbc..a5d57a1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -88,7 +88,8 @@ private[shared] object SharedParamsCodeGen {
         "during tuning. If set to false, then only the single best sub-model 
will be available " +
         "after fitting. If set to true, then all sub-models will be available. 
Warning: For " +
         "large models, collecting all sub-models can cause OOMs on the Spark 
driver",
-        Some("false"), isExpertParam = true)
+        Some("false"), isExpertParam = true),
+      ParamDesc[String]("loss", "the loss function to be optimized", 
finalFields = false)
     )
 
     val code = genSharedParams(params)

http://git-wip-us.apache.org/repos/asf/spark/blob/1e44dd00/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 0004f08..13425da 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -487,4 +487,21 @@ trait HasCollectSubModels extends Params {
   /** @group expertGetParam */
   final def getCollectSubModels: Boolean = $(collectSubModels)
 }
+
+/**
+ * Trait for shared param loss. This trait may be changed or
+ * removed between minor versions.
+ */
+@DeveloperApi
+trait HasLoss extends Params {
+
+  /**
+   * Param for the loss function to be optimized.
+   * @group param
+   */
+  val loss: Param[String] = new Param[String](this, "loss", "the loss function to be optimized")
+
+  /** @group getParam */
+  final def getLoss: String = $(loss)
+}
 // scalastyle:on
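
Concrete estimators can narrow this shared param with their own validation; a
hypothetical sketch of the pattern (LinearRegression in this patch does exactly
this, which is why ```loss``` is generated as a non-final field):

```scala
import org.apache.spark.ml.param.{Param, Params, ParamValidators}
import org.apache.spark.ml.param.shared.HasLoss

// Hypothetical trait overriding the shared `loss` param to restrict it
// to a fixed set of supported values.
private[ml] trait MyRegressorParams extends Params with HasLoss {
  final override val loss: Param[String] = new Param[String](this, "loss",
    "the loss function to be optimized (supported: squaredError, huber)",
    ParamValidators.inArray[String](Array("squaredError", "huber")))
}
```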

http://git-wip-us.apache.org/repos/asf/spark/blob/1e44dd00/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index da6bcf0..a5873d0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.regression
 import scala.collection.mutable
 
 import breeze.linalg.{DenseVector => BDV}
-import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
+import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN}
 import breeze.stats.distributions.StudentsT
 import org.apache.hadoop.fs.Path
 
@@ -32,9 +32,9 @@ import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.linalg.BLAS._
 import org.apache.spark.ml.optim.WeightedLeastSquares
-import org.apache.spark.ml.optim.aggregator.LeastSquaresAggregator
+import org.apache.spark.ml.optim.aggregator.{HuberAggregator, LeastSquaresAggregator}
 import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
-import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.evaluation.RegressionMetrics
@@ -44,8 +44,9 @@ import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.VersionUtils.majorMinorVersion
 
 /**
  * Params for linear regression.
@@ -53,7 +54,7 @@ import org.apache.spark.storage.StorageLevel
 private[regression] trait LinearRegressionParams extends PredictorParams
     with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
     with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
-    with HasAggregationDepth {
+    with HasAggregationDepth with HasLoss {
 
   import LinearRegression._
 
@@ -69,25 +70,105 @@ private[regression] trait LinearRegressionParams extends PredictorParams
     "The solver algorithm for optimization. Supported options: " +
       s"${supportedSolvers.mkString(", ")}. (Default auto)",
     ParamValidators.inArray[String](supportedSolvers))
+
+  /**
+   * The loss function to be optimized.
+   * Supported options: "squaredError" and "huber".
+   * Default: "squaredError"
+   *
+   * @group param
+   */
+  @Since("2.3.0")
+  final override val loss: Param[String] = new Param[String](this, "loss", "The loss function to" +
+    s" be optimized. Supported options: ${supportedLosses.mkString(", ")}. (Default squaredError)",
+    ParamValidators.inArray[String](supportedLosses))
+
+  /**
+   * The shape parameter to control the amount of robustness. Must be &gt; 1.0.
+   * At larger values of epsilon, the huber criterion becomes more similar to least squares
+   * regression; for small values of epsilon, the criterion is more similar to L1 regression.
+   * Default is 1.35 to get as much robustness as possible while retaining
+   * 95% statistical efficiency for normally distributed data. It matches sklearn
+   * HuberRegressor and is "M" from <a href="http://statweb.stanford.edu/~owen/reports/hhu.pdf">
+   * A robust hybrid of lasso and ridge regression</a>.
+   * Only valid when "loss" is "huber".
+   *
+   * @group expertParam
+   */
+  @Since("2.3.0")
+  final val epsilon = new DoubleParam(this, "epsilon", "The shape parameter to control the " +
+    "amount of robustness. Must be > 1.0.", ParamValidators.gt(1.0))
+
+  /** @group getExpertParam */
+  @Since("2.3.0")
+  def getEpsilon: Double = $(epsilon)
+
+  override protected def validateAndTransformSchema(
+      schema: StructType,
+      fitting: Boolean,
+      featuresDataType: DataType): StructType = {
+    if ($(loss) == Huber) {
+      require($(solver) != Normal, "LinearRegression with huber loss doesn't support " +
+        "normal solver, please change solver to auto or l-bfgs.")
+      require($(elasticNetParam) == 0.0, "LinearRegression with huber loss only supports " +
+        s"L2 regularization, but got elasticNetParam = $getElasticNetParam.")
+
+    }
+    super.validateAndTransformSchema(schema, fitting, featuresDataType)
+  }
 }
 
 /**
  * Linear regression.
  *
- * The learning objective is to minimize the squared error, with regularization.
- * The specific squared error loss function used is:
- *
- * <blockquote>
- *    $$
- *    L = 1/2n ||A coefficients - y||^2^
- *    $$
- * </blockquote>
+ * The learning objective is to minimize the specified loss function, with regularization.
+ * This supports two kinds of loss:
+ *  - squaredError (a.k.a. squared loss)
+ *  - huber (a hybrid of squared error for relatively small errors and absolute error for
+ *  relatively large ones, and we estimate the scale parameter from training data)
  *
  * This supports multiple types of regularization:
  *  - none (a.k.a. ordinary least squares)
  *  - L2 (ridge regression)
  *  - L1 (Lasso)
  *  - L2 + L1 (elastic net)
+ *
+ * The squared error objective function is:
+ *
+ * <blockquote>
+ *   $$
+ *   \begin{align}
+ *   \min_{w}\frac{1}{2n}{\sum_{i=1}^n(X_{i}w - y_{i})^{2} +
+ *   \lambda\left[\frac{1-\alpha}{2}{||w||_{2}}^{2} + \alpha{||w||_{1}}\right]}
+ *   \end{align}
+ *   $$
+ * </blockquote>
+ *
+ * The huber objective function is:
+ *
+ * <blockquote>
+ *   $$
+ *   \begin{align}
+ *   \min_{w, \sigma}\frac{1}{2n}{\sum_{i=1}^n\left(\sigma +
+ *   H_m\left(\frac{X_{i}w - y_{i}}{\sigma}\right)\sigma\right) + \frac{1}{2}\lambda {||w||_2}^2}
+ *   \end{align}
+ *   $$
+ * </blockquote>
+ *
+ * where
+ *
+ * <blockquote>
+ *   $$
+ *   \begin{align}
+ *   H_m(z) = \begin{cases}
+ *            z^2, & \text {if } |z| &lt; \epsilon, \\
+ *            2\epsilon|z| - \epsilon^2, & \text{otherwise}
+ *            \end{cases}
+ *   \end{align}
+ *   $$
+ * </blockquote>
+ *
+ * Note: Fitting with huber loss only supports none and L2 regularization.
  */
 @Since("1.3.0")
 class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String)
@@ -142,6 +223,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
    * For alpha in (0,1), the penalty is a combination of L1 and L2.
    * Default is 0.0 which is an L2 penalty.
    *
+   * Note: Fitting with huber loss only supports none and L2 regularization,
+   * so an exception is thrown if this param is set to a non-zero value.
+   *
    * @group setParam
    */
   @Since("1.4.0")
@@ -190,6 +274,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
   *    The Normal Equations solver will be used when possible, but this will automatically fall
   *    back to iterative optimization methods when needed.
    *
+   * Note: Fitting with huber loss doesn't support normal solver,
+   * so an exception is thrown if this param is set to "normal".
    * @group setParam
    */
   @Since("1.6.0")
@@ -208,6 +294,26 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
   def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
   setDefault(aggregationDepth -> 2)
 
+  /**
+   * Sets the value of param [[loss]].
+   * Default is "squaredError".
+   *
+   * @group setParam
+   */
+  @Since("2.3.0")
+  def setLoss(value: String): this.type = set(loss, value)
+  setDefault(loss -> SquaredError)
+
+  /**
+   * Sets the value of param [[epsilon]].
+   * Default is 1.35.
+   *
+   * @group setExpertParam
+   */
+  @Since("2.3.0")
+  def setEpsilon(value: Double): this.type = set(epsilon, value)
+  setDefault(epsilon -> 1.35)
+
   override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
     // Extract the number of features before deciding optimization solver.
     val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
@@ -220,12 +326,12 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     }
 
     val instr = Instrumentation.create(this, dataset)
-    instr.logParams(labelCol, featuresCol, weightCol, predictionCol, solver, tol,
-      elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth)
+    instr.logParams(labelCol, featuresCol, weightCol, predictionCol, solver, tol, elasticNetParam,
+      fitIntercept, maxIter, regParam, standardization, aggregationDepth, loss, epsilon)
     instr.logNumFeatures(numFeatures)
 
-    if (($(solver) == Auto &&
-      numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == Normal) {
+    if ($(loss) == SquaredError && (($(solver) == Auto &&
+      numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == Normal)) {
       // For low dimensional data, WeightedLeastSquares is more efficient since the
       // training algorithm only requires one pass through the data. (SPARK-10668)
 
@@ -330,12 +436,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
 
     // Since we implicitly do the feature scaling when we compute the cost function
     // to improve the convergence, the effective regParam will be changed.
-    val effectiveRegParam = $(regParam) / yStd
+    val effectiveRegParam = $(loss) match {
+      case SquaredError => $(regParam) / yStd
+      case Huber => $(regParam)
+    }
     val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam
     val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
 
-    val getAggregatorFunc = new LeastSquaresAggregator(yStd, yMean, $(fitIntercept),
-      bcFeaturesStd, bcFeaturesMean)(_)
     val getFeaturesStd = (j: Int) => if (j >= 0 && j < numFeatures) featuresStd(j) else 0.0
     val regularization = if (effectiveL2RegParam != 0.0) {
       val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures
@@ -344,33 +451,58 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     } else {
       None
     }
-    val costFun = new RDDLossFunction(instances, getAggregatorFunc, regularization,
-      $(aggregationDepth))
 
-    val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
-      new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
-    } else {
-      val standardizationParam = $(standardization)
-      def effectiveL1RegFun = (index: Int) => {
-        if (standardizationParam) {
-          effectiveL1RegParam
+    val costFun = $(loss) match {
+      case SquaredError =>
+        val getAggregatorFunc = new LeastSquaresAggregator(yStd, yMean, $(fitIntercept),
+          bcFeaturesStd, bcFeaturesMean)(_)
+        new RDDLossFunction(instances, getAggregatorFunc, regularization, $(aggregationDepth))
+      case Huber =>
+        val getAggregatorFunc = new HuberAggregator($(fitIntercept), $(epsilon), bcFeaturesStd)(_)
+        new RDDLossFunction(instances, getAggregatorFunc, regularization, $(aggregationDepth))
+    }
+
+    val optimizer = $(loss) match {
+      case SquaredError =>
+        if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
+          new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
         } else {
-          // If `standardization` is false, we still standardize the data
-          // to improve the rate of convergence; as a result, we have to
-          // perform this reverse standardization by penalizing each component
-          // differently to get effectively the same objective function when
-          // the training dataset is not standardized.
-          if (featuresStd(index) != 0.0) effectiveL1RegParam / featuresStd(index) else 0.0
+          val standardizationParam = $(standardization)
+          def effectiveL1RegFun = (index: Int) => {
+            if (standardizationParam) {
+              effectiveL1RegParam
+            } else {
+              // If `standardization` is false, we still standardize the data
+              // to improve the rate of convergence; as a result, we have to
+              // perform this reverse standardization by penalizing each component
+              // differently to get effectively the same objective function when
+              // the training dataset is not standardized.
+              if (featuresStd(index) != 0.0) effectiveL1RegParam / featuresStd(index) else 0.0
+            }
+          }
+          new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegFun, $(tol))
         }
-      }
-      new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegFun, $(tol))
+      case Huber =>
+        val dim = if ($(fitIntercept)) numFeatures + 2 else numFeatures + 1
+        val lowerBounds = BDV[Double](Array.fill(dim)(Double.MinValue))
+        // Optimize huber loss in space "\sigma > 0"
+        lowerBounds(dim - 1) = Double.MinPositiveValue
+        val upperBounds = BDV[Double](Array.fill(dim)(Double.MaxValue))
+        new BreezeLBFGSB(lowerBounds, upperBounds, $(maxIter), 10, $(tol))
+    }
+
+    val initialValues = $(loss) match {
+      case SquaredError =>
+        Vectors.zeros(numFeatures)
+      case Huber =>
+        val dim = if ($(fitIntercept)) numFeatures + 2 else numFeatures + 1
+        Vectors.dense(Array.fill(dim)(1.0))
     }
 
-    val initialCoefficients = Vectors.zeros(numFeatures)
     val states = optimizer.iterations(new CachedDiffFunction(costFun),
-      initialCoefficients.asBreeze.toDenseVector)
+      initialValues.asBreeze.toDenseVector)
 
-    val (coefficients, objectiveHistory) = {
+    val (coefficients, intercept, scale, objectiveHistory) = {
       /*
         Note that in Linear Regression, the objective history (loss + regularization) returned
         from optimizer is computed in the scaled space given by the following formula.
@@ -396,35 +528,54 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
       bcFeaturesMean.destroy(blocking = false)
       bcFeaturesStd.destroy(blocking = false)
 
+      val parameters = state.x.toArray.clone()
+
       /*
         The coefficients are trained in the scaled space; we're converting them back to
         the original space.
        */
-      val rawCoefficients = state.x.toArray.clone()
+      val rawCoefficients: Array[Double] = $(loss) match {
+        case SquaredError => parameters
+        case Huber => parameters.slice(0, numFeatures)
+      }
+
       var i = 0
       val len = rawCoefficients.length
+      val multiplier = $(loss) match {
+        case SquaredError => yStd
+        case Huber => 1.0
+      }
       while (i < len) {
-        rawCoefficients(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 }
+        rawCoefficients(i) *= { if (featuresStd(i) != 0.0) multiplier / featuresStd(i) else 0.0 }
         i += 1
       }
 
-      (Vectors.dense(rawCoefficients).compressed, arrayBuilder.result())
-    }
+      val interceptValue: Double = if ($(fitIntercept)) {
+        $(loss) match {
+          case SquaredError =>
+            /*
+            The intercept of squared error in R's GLMNET is computed using closed form
+            after the coefficients are converged. See the following discussion for detail.
+            http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
+            */
+            yMean - dot(Vectors.dense(rawCoefficients), Vectors.dense(featuresMean))
+          case Huber => parameters(numFeatures)
+        }
+      } else {
+        0.0
+      }
 
-    /*
-       The intercept in R's GLMNET is computed using closed form after the coefficients are
-       converged. See the following discussion for detail.
-       http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
-     */
-    val intercept = if ($(fitIntercept)) {
-      yMean - dot(coefficients, Vectors.dense(featuresMean))
-    } else {
-      0.0
+      val scaleValue: Double = $(loss) match {
+        case SquaredError => 1.0
+        case Huber => parameters.last
+      }
+
+      (Vectors.dense(rawCoefficients).compressed, interceptValue, scaleValue, arrayBuilder.result())
     }
 
     if (handlePersistence) instances.unpersist()
 
-    val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept))
+    val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept, scale))
     // Handle possible missing or invalid prediction columns
     val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
 
@@ -471,6 +622,15 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] {
 
   /** Set of solvers that LinearRegression supports. */
   private[regression] val supportedSolvers = Array(Auto, Normal, LBFGS)
+
+  /** String name for "squaredError". */
+  private[regression] val SquaredError = "squaredError"
+
+  /** String name for "huber". */
+  private[regression] val Huber = "huber"
+
+  /** Set of loss function names that LinearRegression supports. */
+  private[regression] val supportedLosses = Array(SquaredError, Huber)
 }
 
 /**
@@ -480,10 +640,14 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] {
 class LinearRegressionModel private[ml] (
     @Since("1.4.0") override val uid: String,
     @Since("2.0.0") val coefficients: Vector,
-    @Since("1.3.0") val intercept: Double)
+    @Since("1.3.0") val intercept: Double,
+    @Since("2.3.0") val scale: Double)
   extends RegressionModel[Vector, LinearRegressionModel]
   with LinearRegressionParams with MLWritable {
 
+  def this(uid: String, coefficients: Vector, intercept: Double) =
+    this(uid, coefficients, intercept, 1.0)
+
   private var trainingSummary: Option[LinearRegressionTrainingSummary] = None
 
   override val numFeatures: Int = coefficients.size
@@ -570,13 +734,13 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
   private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel)
     extends MLWriter with Logging {
 
-    private case class Data(intercept: Double, coefficients: Vector)
+    private case class Data(intercept: Double, coefficients: Vector, scale: Double)
 
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
       DefaultParamsWriter.saveMetadata(instance, path, sc)
-      // Save model data: intercept, coefficients
-      val data = Data(instance.intercept, instance.coefficients)
+      // Save model data: intercept, coefficients, scale
+      val data = Data(instance.intercept, instance.coefficients, instance.scale)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
@@ -592,11 +756,20 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
 
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.format("parquet").load(dataPath)
-      val Row(intercept: Double, coefficients: Vector) =
-        MLUtils.convertVectorColumnsToML(data, "coefficients")
-          .select("intercept", "coefficients")
-          .head()
-      val model = new LinearRegressionModel(metadata.uid, coefficients, intercept)
+      val (majorVersion, minorVersion) = majorMinorVersion(metadata.sparkVersion)
+      val model = if (majorVersion < 2 || (majorVersion == 2 && minorVersion <= 2)) {
+        // Spark 2.2 and before
+        val Row(intercept: Double, coefficients: Vector) =
+          MLUtils.convertVectorColumnsToML(data, "coefficients")
+            .select("intercept", "coefficients")
+            .head()
+        new LinearRegressionModel(metadata.uid, coefficients, intercept)
+      } else {
+        // Spark 2.3 and later
+        val Row(intercept: Double, coefficients: Vector, scale: Double) =
+          data.select("intercept", "coefficients", "scale").head()
+        new LinearRegressionModel(metadata.uid, coefficients, intercept, scale)
+      }
 
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
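
The optimizer change above keeps sigma strictly positive by handing L-BFGS-B a
lower bound of Double.MinPositiveValue for the last coordinate. A standalone toy
sketch of the Breeze API involved (assuming LBFGSB's default maxIter/m/tolerance
arguments; not code from this patch):

```scala
import breeze.linalg.DenseVector
import breeze.optimize.{DiffFunction, LBFGSB}

// Toy sketch: minimize (x - 3)^2 subject to x > 0, mirroring how the patch
// bounds only the sigma coordinate.
val lower = DenseVector(Double.MinPositiveValue)
val upper = DenseVector(Double.MaxValue)
val f = new DiffFunction[DenseVector[Double]] {
  def calculate(x: DenseVector[Double]): (Double, DenseVector[Double]) =
    (math.pow(x(0) - 3.0, 2), DenseVector(2.0 * (x(0) - 3.0)))
}
val xOpt = new LBFGSB(lower, upper).minimize(f, DenseVector(1.0))
// xOpt(0) is approximately 3.0
```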

http://git-wip-us.apache.org/repos/asf/spark/blob/1e44dd00/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HuberAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HuberAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HuberAggregatorSuite.scala
new file mode 100644
index 0000000..718ffa2
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HuberAggregatorSuite.scala
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.optim.aggregator
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class HuberAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  import DifferentiableLossAggregatorSuite.getRegressionSummarizers
+
+  @transient var instances: Array[Instance] = _
+  @transient var instancesConstantFeature: Array[Instance] = _
+  @transient var instancesConstantFeatureFiltered: Array[Instance] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    instances = Array(
+      Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)),
+      Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)),
+      Instance(2.0, 0.3, Vectors.dense(4.0, 0.5))
+    )
+    instancesConstantFeature = Array(
+      Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)),
+      Instance(1.0, 0.5, Vectors.dense(1.0, 1.0)),
+      Instance(2.0, 0.3, Vectors.dense(1.0, 0.5))
+    )
+    instancesConstantFeatureFiltered = Array(
+      Instance(0.0, 0.1, Vectors.dense(2.0)),
+      Instance(1.0, 0.5, Vectors.dense(1.0)),
+      Instance(2.0, 0.3, Vectors.dense(0.5))
+    )
+  }
+
+  /** Get summary statistics for some data and create a new HuberAggregator. */
+  private def getNewAggregator(
+      instances: Array[Instance],
+      parameters: Vector,
+      fitIntercept: Boolean,
+      epsilon: Double): HuberAggregator = {
+    val (featuresSummarizer, _) = getRegressionSummarizers(instances)
+    val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
+    val bcFeaturesStd = spark.sparkContext.broadcast(featuresStd)
+    val bcParameters = spark.sparkContext.broadcast(parameters)
+    new HuberAggregator(fitIntercept, epsilon, bcFeaturesStd)(bcParameters)
+  }
+
+  test("aggregator add method should check input size") {
+    val parameters = Vectors.dense(1.0, 2.0, 3.0, 4.0)
+    val agg = getNewAggregator(instances, parameters, fitIntercept = true, epsilon = 1.35)
+    withClue("HuberAggregator features dimension must match parameters dimension") {
+      intercept[IllegalArgumentException] {
+        agg.add(Instance(1.0, 1.0, Vectors.dense(2.0)))
+      }
+    }
+  }
+
+  test("negative weight") {
+    val parameters = Vectors.dense(1.0, 2.0, 3.0, 4.0)
+    val agg = getNewAggregator(instances, parameters, fitIntercept = true, epsilon = 1.35)
+    withClue("HuberAggregator does not support negative instance weights.") {
+      intercept[IllegalArgumentException] {
+        agg.add(Instance(1.0, -1.0, Vectors.dense(2.0, 1.0)))
+      }
+    }
+  }
+
+  test("check sizes") {
+    val paramWithIntercept = Vectors.dense(1.0, 2.0, 3.0, 4.0)
+    val paramWithoutIntercept = Vectors.dense(1.0, 2.0, 4.0)
+    val aggIntercept = getNewAggregator(instances, paramWithIntercept,
+      fitIntercept = true, epsilon = 1.35)
+    val aggNoIntercept = getNewAggregator(instances, paramWithoutIntercept,
+      fitIntercept = false, epsilon = 1.35)
+    instances.foreach(aggIntercept.add)
+    instances.foreach(aggNoIntercept.add)
+
+    assert(aggIntercept.gradient.size === 4)
+    assert(aggNoIntercept.gradient.size === 3)
+  }
+
+  test("check correctness") {
+    val parameters = Vectors.dense(1.0, 2.0, 3.0, 4.0)
+    val numFeatures = 2
+    val (featuresSummarizer, _) = getRegressionSummarizers(instances)
+    val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
+    val epsilon = 1.35
+    val weightSum = instances.map(_.weight).sum
+
+    val agg = getNewAggregator(instances, parameters, fitIntercept = true, epsilon)
+    instances.foreach(agg.add)
+
+    // compute expected loss sum
+    val coefficients = parameters.toArray.slice(0, 2)
+    val intercept = parameters(2)
+    val sigma = parameters(3)
+    val stdCoef = coefficients.indices.map(i => coefficients(i) / featuresStd(i)).toArray
+    val lossSum = instances.map { case Instance(label, weight, features) =>
+      val margin = BLAS.dot(Vectors.dense(stdCoef), features) + intercept
+      val linearLoss = label - margin
+      if (math.abs(linearLoss) <= sigma * epsilon) {
+        0.5 * weight * (sigma + math.pow(linearLoss, 2.0) / sigma)
+      } else {
+        0.5 * weight * (sigma + 2.0 * epsilon * math.abs(linearLoss) - sigma * epsilon * epsilon)
+      }
+    }.sum
+    val loss = lossSum / weightSum
+
+    // compute expected gradients
+    val gradientCoef = new Array[Double](numFeatures + 2)
+    instances.foreach { case Instance(label, weight, features) =>
+      val margin = BLAS.dot(Vectors.dense(stdCoef), features) + intercept
+      val linearLoss = label - margin
+      if (math.abs(linearLoss) <= sigma * epsilon) {
+        features.toArray.indices.foreach { i =>
+          gradientCoef(i) +=
+            -1.0 * weight * (linearLoss / sigma) * (features(i) / featuresStd(i))
+        }
+        gradientCoef(2) += -1.0 * weight * (linearLoss / sigma)
+        gradientCoef(3) += 0.5 * weight * (1.0 - math.pow(linearLoss / sigma, 2.0))
+      } else {
+        val sign = if (linearLoss >= 0) -1.0 else 1.0
+        features.toArray.indices.foreach { i =>
+          gradientCoef(i) += weight * sign * epsilon * (features(i) / featuresStd(i))
+        }
+        gradientCoef(2) += weight * sign * epsilon
+        gradientCoef(3) += 0.5 * weight * (1.0 - epsilon * epsilon)
+      }
+    }
+    val gradient = Vectors.dense(gradientCoef.map(_ / weightSum))
+
+    assert(loss ~== agg.loss relTol 0.01)
+    assert(gradient ~== agg.gradient relTol 0.01)
+  }
+
+  test("check with zero standard deviation") {
+    val parameters = Vectors.dense(1.0, 2.0, 3.0, 4.0)
+    val parametersFiltered = Vectors.dense(2.0, 3.0, 4.0)
+    val aggConstantFeature = getNewAggregator(instancesConstantFeature, parameters,
+      fitIntercept = true, epsilon = 1.35)
+    val aggConstantFeatureFiltered = getNewAggregator(instancesConstantFeatureFiltered,
+      parametersFiltered, fitIntercept = true, epsilon = 1.35)
+    instances.foreach(aggConstantFeature.add)
+    instancesConstantFeatureFiltered.foreach(aggConstantFeatureFiltered.add)
+    // constant features should not affect gradient
+    def validateGradient(grad: Vector, gradFiltered: Vector): Unit = {
+      assert(grad(0) === 0.0)
+      assert(grad(1) ~== gradFiltered(0) relTol 0.01)
+    }
+
+    validateGradient(aggConstantFeature.gradient, aggConstantFeatureFiltered.gradient)
+  }
+}
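
For reference, the expected gradients computed by hand in this suite follow from
differentiating the per-instance objective term weight * 0.5 * (sigma +
H_m(r / sigma) * sigma), with residual r = y - (Xw + b), on each branch:

```latex
% Quadratic branch, |r| \le \sigma\epsilon: loss = \tfrac{1}{2}(\sigma + r^2/\sigma)
\frac{\partial}{\partial w_j} = -\frac{r}{\sigma}\, x_j, \qquad
\frac{\partial}{\partial \sigma} = \frac{1}{2}\left(1 - \frac{r^2}{\sigma^2}\right)

% Linear branch, |r| > \sigma\epsilon: loss = \tfrac{1}{2}(\sigma + 2\epsilon|r| - \sigma\epsilon^2)
\frac{\partial}{\partial w_j} = -\operatorname{sign}(r)\, \epsilon\, x_j, \qquad
\frac{\partial}{\partial \sigma} = \frac{1}{2}\left(1 - \epsilon^2\right)
```

These match the gradientCoef updates above; the test's `sign` variable is
-sign(r), so `weight * sign * epsilon` equals the linear-branch coefficient
gradient.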

http://git-wip-us.apache.org/repos/asf/spark/blob/1e44dd00/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index aec5ac0..9bb2895 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -41,6 +41,7 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
   @transient var datasetWithWeight: DataFrame = _
   @transient var datasetWithWeightConstantLabel: DataFrame = _
   @transient var datasetWithWeightZeroLabel: DataFrame = _
+  @transient var datasetWithOutlier: DataFrame = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
@@ -107,6 +108,16 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
       Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)),
       Instance(0.0, 4.0, Vectors.dense(3.0, 13.0))
     ), 2).toDF()
+
+    datasetWithOutlier = {
+      val inlierData = LinearDataGenerator.generateLinearInput(
+        intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3),
+        xVariance = Array(0.7, 1.2), nPoints = 900, seed, eps = 0.1)
+      val outlierData = LinearDataGenerator.generateLinearInput(
+        intercept = -2.1, weights = Array(0.6, -1.2), xMean = Array(0.9, -1.3),
+        xVariance = Array(1.5, 0.8), nPoints = 100, seed, eps = 0.1)
+      sc.parallelize(inlierData ++ outlierData, 2).map(_.asML).toDF()
+    }
   }
 
   /**
@@ -127,6 +138,10 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
     datasetWithSparseFeature.rdd.map { case Row(label: Double, features: Vector) =>
       label + "," + features.toArray.mkString(",")
     }.repartition(1).saveAsTextFile("target/tmp/LinearRegressionSuite/datasetWithSparseFeature")
+
+    datasetWithOutlier.rdd.map { case Row(label: Double, features: Vector) =>
+      label + "," + features.toArray.mkString(",")
+    }.repartition(1).saveAsTextFile("target/tmp/LinearRegressionSuite/datasetWithOutlier")
   }
 
   test("params") {
@@ -144,7 +159,9 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
     assert(lir.getElasticNetParam === 0.0)
     assert(lir.getFitIntercept)
     assert(lir.getStandardization)
-    assert(lir.getSolver == "auto")
+    assert(lir.getSolver === "auto")
+    assert(lir.getLoss === "squaredError")
+    assert(lir.getEpsilon === 1.35)
     val model = lir.fit(datasetWithDenseFeature)
 
     MLTestingUtils.checkCopyAndUids(lir, model)
@@ -160,11 +177,27 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
     assert(model.getFeaturesCol === "features")
     assert(model.getPredictionCol === "prediction")
     assert(model.intercept !== 0.0)
+    assert(model.scale === 1.0)
     assert(model.hasParent)
     val numFeatures = datasetWithDenseFeature.select("features").first().getAs[Vector](0).size
     assert(model.numFeatures === numFeatures)
   }
 
+  test("linear regression: illegal params") {
+    withClue("LinearRegression with huber loss only supports L2 
regularization") {
+      intercept[IllegalArgumentException] {
+        new LinearRegression().setLoss("huber").setElasticNetParam(0.5)
+          .fit(datasetWithDenseFeature)
+      }
+    }
+
+    withClue("LinearRegression with huber loss doesn't support normal solver") 
{
+      intercept[IllegalArgumentException] {
+        new LinearRegression().setLoss("huber").setSolver("normal").fit(datasetWithDenseFeature)
+      }
+    }
+  }
+
   test("linear regression handles singular matrices") {
     // check for both constant columns with intercept (zero std) and collinear
     val singularDataConstantColumn = sc.parallelize(Seq(
@@ -837,6 +870,7 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
       (1.0, 0.21, true, true)
     )
 
+    // For squaredError loss
     for (solver <- Seq("auto", "l-bfgs", "normal");
          (elasticNetParam, regParam, fitIntercept, standardization) <- testParams) {
       val estimator = new LinearRegression()
@@ -852,6 +886,22 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
       MLTestingUtils.testOversamplingVsWeighting[LinearRegressionModel, LinearRegression](
         datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals, seed)
     }
+
+    // For huber loss
+    for ((_, regParam, fitIntercept, standardization) <- testParams) {
+      val estimator = new LinearRegression()
+        .setLoss("huber")
+        .setFitIntercept(fitIntercept)
+        .setStandardization(standardization)
+        .setRegParam(regParam)
+      MLTestingUtils.testArbitrarilyScaledWeights[LinearRegressionModel, LinearRegression](
+        datasetWithOutlier.as[LabeledPoint], estimator, modelEquals)
+      MLTestingUtils.testOutliersWithSmallWeights[LinearRegressionModel, LinearRegression](
+        datasetWithOutlier.as[LabeledPoint], estimator, numClasses, modelEquals,
+        outlierRatio = 3)
+      MLTestingUtils.testOversamplingVsWeighting[LinearRegressionModel, LinearRegression](
+        datasetWithOutlier.as[LabeledPoint], estimator, modelEquals, seed)
+    }
   }
 
   test("linear regression model with l-bfgs with big feature datasets") {
@@ -1004,6 +1054,198 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
       }
     }
   }
+
+  test("linear regression (huber loss) with intercept without regularization") 
{
+    val trainer1 = (new LinearRegression).setLoss("huber")
+      .setFitIntercept(true).setStandardization(true)
+    val trainer2 = (new LinearRegression).setLoss("huber")
+      .setFitIntercept(true).setStandardization(false)
+
+    val model1 = trainer1.fit(datasetWithOutlier)
+    val model2 = trainer2.fit(datasetWithOutlier)
+
+    /*
+      Using the following Python code to load the data and train the model using
+      the scikit-learn package.
+
+      import pandas as pd
+      import numpy as np
+      from sklearn.linear_model import HuberRegressor
+      df = pd.read_csv("path", header = None)
+      X = df[df.columns[1:3]]
+      y = np.array(df[df.columns[0]])
+      huber = HuberRegressor(fit_intercept=True, alpha=0.0, max_iter=100, epsilon=1.35)
+      huber.fit(X, y)
+
+      >>> huber.coef_
+      array([ 4.68998007,  7.19429011])
+      >>> huber.intercept_
+      6.3002404351083037
+      >>> huber.scale_
+      0.077810159205220747
+     */
+    val coefficientsPy = Vectors.dense(4.68998007, 7.19429011)
+    val interceptPy = 6.30024044
+    val scalePy = 0.07781016
+
+    assert(model1.coefficients ~= coefficientsPy relTol 1E-3)
+    assert(model1.intercept ~== interceptPy relTol 1E-3)
+    assert(model1.scale ~== scalePy relTol 1E-3)
+
+    // Without regularization, training with or without standardization
+    // converges to the same solution.
+    assert(model2.coefficients ~= coefficientsPy relTol 1E-3)
+    assert(model2.intercept ~== interceptPy relTol 1E-3)
+    assert(model2.scale ~== scalePy relTol 1E-3)
+  }
+
+  test("linear regression (huber loss) without intercept without 
regularization") {
+    val trainer1 = (new LinearRegression).setLoss("huber")
+      .setFitIntercept(false).setStandardization(true)
+    val trainer2 = (new LinearRegression).setLoss("huber")
+      .setFitIntercept(false).setStandardization(false)
+
+    val model1 = trainer1.fit(datasetWithOutlier)
+    val model2 = trainer2.fit(datasetWithOutlier)
+
+    /*
+      huber = HuberRegressor(fit_intercept=False, alpha=0.0, max_iter=100, epsilon=1.35)
+      huber.fit(X, y)
+
+      >>> huber.coef_
+      array([ 6.71756703,  5.08873222])
+      >>> huber.intercept_
+      0.0
+      >>> huber.scale_
+      2.5560209922722317
+     */
+    val coefficientsPy = Vectors.dense(6.71756703, 5.08873222)
+    val interceptPy = 0.0
+    val scalePy = 2.55602099
+
+    assert(model1.coefficients ~= coefficientsPy relTol 1E-3)
+    assert(model1.intercept === interceptPy)
+    assert(model1.scale ~== scalePy relTol 1E-3)
+
+    // Without regularization, training with or without standardization
+    // converges to the same solution.
+    assert(model2.coefficients ~= coefficientsPy relTol 1E-3)
+    assert(model2.intercept === interceptPy)
+    assert(model2.scale ~== scalePy relTol 1E-3)
+  }
+
+  test("linear regression (huber loss) with intercept with L2 regularization") 
{
+    val trainer1 = (new LinearRegression).setLoss("huber")
+      .setFitIntercept(true).setRegParam(0.21).setStandardization(true)
+    val trainer2 = (new LinearRegression).setLoss("huber")
+      .setFitIntercept(true).setRegParam(0.21).setStandardization(false)
+
+    val model1 = trainer1.fit(datasetWithOutlier)
+    val model2 = trainer2.fit(datasetWithOutlier)
+
+    /*
+      Since scikit-learn HuberRegressor does not support standardization,
+      we do it manually outside of the estimator.
+
+      xStd = np.std(X, axis=0)
+      scaledX = X / xStd
+      huber = HuberRegressor(fit_intercept=True, alpha=210, max_iter=100, epsilon=1.35)
+      huber.fit(scaledX, y)
+
+      >>> np.array(huber.coef_ / xStd)
+      array([ 1.97732633,  3.38816722])
+      >>> huber.intercept_
+      3.7527581430531227
+      >>> huber.scale_
+      3.787363673371801
+     */
+    val coefficientsPy1 = Vectors.dense(1.97732633, 3.38816722)
+    val interceptPy1 = 3.75275814
+    val scalePy1 = 3.78736367
+
+    assert(model1.coefficients ~= coefficientsPy1 relTol 1E-2)
+    assert(model1.intercept ~== interceptPy1 relTol 1E-2)
+    assert(model1.scale ~== scalePy1 relTol 1E-2)
+
+    /*
+      huber = HuberRegressor(fit_intercept=True, alpha=210, max_iter=100, epsilon=1.35)
+      huber.fit(X, y)
+
+      >>> huber.coef_
+      array([ 1.73346444,  3.63746999])
+      >>> huber.intercept_
+      4.3017134790781739
+      >>> huber.scale_
+      3.6472742809286793
+     */
+    val coefficientsPy2 = Vectors.dense(1.73346444, 3.63746999)
+    val interceptPy2 = 4.30171347
+    val scalePy2 = 3.64727428
+
+    assert(model2.coefficients ~= coefficientsPy2 relTol 1E-3)
+    assert(model2.intercept ~== interceptPy2 relTol 1E-3)
+    assert(model2.scale ~== scalePy2 relTol 1E-3)
+  }
+
+  test("linear regression (huber loss) without intercept with L2 
regularization") {
+    val trainer1 = (new LinearRegression).setLoss("huber")
+      .setFitIntercept(false).setRegParam(0.21).setStandardization(true)
+    val trainer2 = (new LinearRegression).setLoss("huber")
+      .setFitIntercept(false).setRegParam(0.21).setStandardization(false)
+
+    val model1 = trainer1.fit(datasetWithOutlier)
+    val model2 = trainer2.fit(datasetWithOutlier)
+
+    /*
+      Since scikit-learn HuberRegressor does not support standardization,
+      we do it manually outside of the estimator.
+
+      xStd = np.std(X, axis=0)
+      scaledX = X / xStd
+      huber = HuberRegressor(fit_intercept=False, alpha=210, max_iter=100, epsilon=1.35)
+      huber.fit(scaledX, y)
+
+      >>> np.array(huber.coef_ / xStd)
+      array([ 2.59679008,  2.26973102])
+      >>> huber.intercept_
+      0.0
+      >>> huber.scale_
+      4.5766311924091791
+     */
+    val coefficientsPy1 = Vectors.dense(2.59679008, 2.26973102)
+    val interceptPy1 = 0.0
+    val scalePy1 = 4.57663119
+
+    assert(model1.coefficients ~= coefficientsPy1 relTol 1E-2)
+    assert(model1.intercept === interceptPy1)
+    assert(model1.scale ~== scalePy1 relTol 1E-2)
+
+    /*
+      huber = HuberRegressor(fit_intercept=False, alpha=210, max_iter=100, epsilon=1.35)
+      huber.fit(X, y)
+
+      >>> huber.coef_
+      array([ 2.28423908,  2.25196887])
+      >>> huber.intercept_
+      0.0
+      >>> huber.scale_
+      4.5979643506051753
+     */
+    val coefficientsPy2 = Vectors.dense(2.28423908, 2.25196887)
+    val interceptPy2 = 0.0
+    val scalePy2 = 4.59796435
+
+    assert(model2.coefficients ~= coefficientsPy2 relTol 1E-3)
+    assert(model2.intercept === interceptPy2)
+    assert(model2.scale ~== scalePy2 relTol 1E-3)
+  }
+
+  test("huber loss model match squared error for large epsilon") {
+    val trainer1 = new LinearRegression().setLoss("huber").setEpsilon(1E5)
+    val model1 = trainer1.fit(datasetWithOutlier)
+    val trainer2 = new LinearRegression()
+    val model2 = trainer2.fit(datasetWithOutlier)
+    assert(model1.coefficients ~== model2.coefficients relTol 1E-3)
+    assert(model1.intercept ~== model2.intercept relTol 1E-3)
+  }
 }
 
 object LinearRegressionSuite {
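
Since the model now carries a scale value (the fitted sigma; 1.0 for squared
error), it round-trips through ML persistence. A usage sketch (the save path and
```training``` DataFrame are illustrative):

```scala
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}

// Sketch: persist and reload a huber model. Models written by Spark <= 2.2
// load with the default scale of 1.0 (see the reader change above).
val model = new LinearRegression().setLoss("huber").fit(training)
model.write.overwrite().save("/tmp/huber-lr")
val loaded = LinearRegressionModel.load("/tmp/huber-lr")
assert(loaded.scale == model.scale)
```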

http://git-wip-us.apache.org/repos/asf/spark/blob/1e44dd00/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 9be01f6..9902fed 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -1065,6 +1065,11 @@ object MimaExcludes {
       // [SPARK-21680][ML][MLLIB] optimize Vector compress
       ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.toSparseWithSize"),
       ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Vector.toSparseWithSize")
+    ) ++ Seq(
+      // [SPARK-3181][ML] Implement huber loss for LinearRegression.
+      ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasLoss.org$apache$spark$ml$param$shared$HasLoss$_setter_$loss_="),
+      ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasLoss.getLoss"),
+      ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasLoss.loss")
     )
   }
 

