Github user hhbyyh commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19020#discussion_r147326457
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala ---
    @@ -344,33 +449,58 @@ class LinearRegression @Since("1.3.0") 
(@Since("1.3.0") override val uid: String
         } else {
           None
         }
    -    val costFun = new RDDLossFunction(instances, getAggregatorFunc, 
regularization,
    -      $(aggregationDepth))
     
    -    val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 
0.0) {
    -      new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
    -    } else {
    -      val standardizationParam = $(standardization)
    -      def effectiveL1RegFun = (index: Int) => {
    -        if (standardizationParam) {
    -          effectiveL1RegParam
    +    val costFun = $(loss) match {
    +      case SquaredError =>
    +        val getAggregatorFunc = new LeastSquaresAggregator(yStd, yMean, 
$(fitIntercept),
    +          bcFeaturesStd, bcFeaturesMean)(_)
    +        new RDDLossFunction(instances, getAggregatorFunc, regularization, 
$(aggregationDepth))
    +      case Huber =>
    +        val getAggregatorFunc = new HuberAggregator($(fitIntercept), 
$(epsilon), bcFeaturesStd)(_)
    +        new RDDLossFunction(instances, getAggregatorFunc, regularization, 
$(aggregationDepth))
    +    }
    +
    +    val optimizer = $(loss) match {
    +      case SquaredError =>
    +        if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
    +          new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
             } else {
    -          // If `standardization` is false, we still standardize the data
    -          // to improve the rate of convergence; as a result, we have to
    -          // perform this reverse standardization by penalizing each 
component
    -          // differently to get effectively the same objective function 
when
    -          // the training dataset is not standardized.
    -          if (featuresStd(index) != 0.0) effectiveL1RegParam / 
featuresStd(index) else 0.0
    +          val standardizationParam = $(standardization)
    +          def effectiveL1RegFun = (index: Int) => {
    +            if (standardizationParam) {
    +              effectiveL1RegParam
    +            } else {
    +              // If `standardization` is false, we still standardize the 
data
    +              // to improve the rate of convergence; as a result, we have 
to
    +              // perform this reverse standardization by penalizing each 
component
    +              // differently to get effectively the same objective 
function when
    +              // the training dataset is not standardized.
    +              if (featuresStd(index) != 0.0) effectiveL1RegParam / 
featuresStd(index) else 0.0
    +            }
    +          }
    +          new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, 
effectiveL1RegFun, $(tol))
             }
    -      }
    -      new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegFun, 
$(tol))
    +      case Huber =>
    +        val dim = if ($(fitIntercept)) numFeatures + 2 else numFeatures + 1
    +        val lowerBounds = BDV[Double](Array.fill(dim)(Double.MinValue))
    +        // Optimize huber loss in space "\sigma > 0"
    +        lowerBounds(dim - 1) = Double.MinPositiveValue
    +        val upperBounds = BDV[Double](Array.fill(dim)(Double.MaxValue))
    +        new BreezeLBFGSB(lowerBounds, upperBounds, $(maxIter), 10, $(tol))
    +    }
    +
    +    val initialValues = $(loss) match {
    +      case SquaredError =>
    +        Vectors.zeros(numFeatures)
    +      case Huber =>
    +        val dim = if ($(fitIntercept)) numFeatures + 2 else numFeatures + 1
    +        Vectors.dense(Array.fill(dim)(1.0))
    --- End diff --
    
    Out of curiosity, is there a reference for this default initial value (every entry, including sigma, set to 1.0 for the Huber loss)?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to