Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/23146#discussion_r238104839 --- Diff: mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala --- @@ -82,7 +82,72 @@ private[ml] class L2Regularization( } (0.5 * sum * regParam, Vectors.dense(gradient)) case _: SparseVector => - throw new IllegalArgumentException("Sparse coefficients are not currently supported.") + throw new IllegalArgumentException( + "Sparse coefficients are not currently supported.") + } + } +} + + +/** + * Implements regularization for Maximum A Posteriori (MAP) optimization + * based on prior means (coefficients) and precisions. + * + * @param priorMean Prior coefficients (multivariate mean). + * @param priorPrecisions Prior precisions. + * @param regParam The magnitude of the regularization. + * @param shouldApply A function (Int => Boolean) indicating whether a given index should have + * regularization applied to it. Usually we don't apply regularization to + * the intercept. + * @param applyFeaturesStd Option for a function which maps coefficient index (column major) to the + * feature standard deviation. Since we always standardize the data during + * training, if `standardization` is false, we have to reverse + * standardization by penalizing each component differently by this param. + * If `standardization` is true, this should be `None`. 
+ */ +private[ml] class PriorRegularization( + priorMean: Array[Double], + priorPrecisions: Array[Double], + override val regParam: Double, + shouldApply: Int => Boolean, + applyFeaturesStd: Option[Int => Double]) + extends DifferentiableRegularization[Vector] { + + override def calculate(coefficients: Vector): (Double, Vector) = { + coefficients match { + case dv: DenseVector => + var sum = 0.0 + val gradient = new Array[Double](dv.size) + dv.values.indices.filter(shouldApply).foreach { j => + val coef = coefficients(j) + val priorCoef = priorMean(j) + val priorPrecision = priorPrecisions(j) + applyFeaturesStd match { + case Some(getStd) => + // If `standardization` is false, we still standardize the data + // to improve the rate of convergence; as a result, we have to + // perform this reverse standardization by penalizing each component + // differently to get effectively the same objective function when + // the training dataset is not standardized. + val std = getStd(j) + if (std != 0.0) { + val temp = (coef - priorCoef) / (std * std) + sum += (coef - priorCoef) * temp * priorPrecision + gradient(j) = regParam * priorPrecision * temp + } else { + 0.0 --- End diff -- Who consumes `0.0`?
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org