Github user chiwanpark commented on a diff in the pull request: https://github.com/apache/flink/pull/1985#discussion_r65291552 --- Diff: flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/RegularizationPenalty.scala --- @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.optimization + +import org.apache.flink.ml.math.{Vector, BLAS} +import org.apache.flink.ml.math.Breeze._ +import breeze.linalg.{norm => BreezeNorm} + +/** Represents a type of regularization penalty + * + * Regularization penalties are used to restrict the optimization problem to solutions with + * certain desirable characteristics, such as sparsity for the L1 penalty, or penalizing large + * weights for the L2 penalty. + * + * The regularization term, `R(w)` is added to the objective function, `f(w) = L(w) + lambda*R(w)` + * where lambda is the regularization parameter used to tune the amount of regularization applied. 
+ */ +trait RegularizationPenalty extends Serializable { + + /** Calculates the new weights based on the gradient and regularization penalty + * + * @param weightVector The weights to be updated + * @param gradient The gradient used to update the weights + * @param regularizationConstant The regularization parameter to be applied + * @param learningRate The effective step size for this iteration + * @return Updated weights + */ + def takeStep( + weightVector: Vector, + gradient: Vector, + regularizationConstant: Double, + learningRate: Double) + : Vector + + /** Adds regularization to the loss value + * + * @param oldLoss The loss to be updated + * @param weightVector The gradient used to update the loss + * @param regularizationConstant The regularization parameter to be applied + * @return Updated loss + */ + def regLoss(oldLoss: Double, weightVector: Vector, regularizationConstant: Double): Double + +} + + +/** `L_2` regularization penalty. + * + * The regularization function is the square of the L2 norm `1/2*||w||_2^2` + * with `w` being the weight vector. The function penalizes large weights, + * favoring solutions with more small weights rather than few large ones. + */ +object L2Regularization extends RegularizationPenalty { + + /** Calculates the new weights based on the gradient and L2 regularization penalty + * + * The updated weight is `w - learningRate *(gradient + lambda * w)` where + * `w` is the weight vector, and `lambda` is the regularization parameter. 
+ * + * @param weightVector The weights to be updated + * @param gradient The gradient according to which we will update the weights + * @param regularizationConstant The regularization parameter to be applied + * @param learningRate The effective step size for this iteration + * @return Updated weights + */ + override def takeStep( + weightVector: Vector, + gradient: Vector, + regularizationConstant: Double, + learningRate: Double) + : Vector = { + // add the gradient of the L2 regularization + BLAS.axpy(regularizationConstant, weightVector, gradient) + + // update the weights according to the learning rate + BLAS.axpy(-learningRate, gradient, weightVector) + + weightVector + } + + /** Adds regularization to the loss value + * + * The updated loss is `oldLoss + lambda * 1/2*||w||_2^2` where + * `w` is the weight vector, and `lambda` is the regularization parameter + * + * @param oldLoss The loss to be updated + * @param weightVector The gradient used to update the loss + * @param regularizationConstant The regularization parameter to be applied + * @return Updated loss + */ + override def regLoss(oldLoss: Double, weightVector: Vector, regularizationConstant: Double) + : Double = { + val squareNorm = BLAS.dot(weightVector, weightVector) + oldLoss + regularizationConstant * 0.5 * squareNorm + } +} + +/** `L_1` regularization penalty. + * + * The regularization function is the `L1` norm `||w||_1` with `w` being the weight vector. + * The `L_1` penalty can be used to drive a number of the solution coefficients to 0, thereby + * producing sparse solutions. 
+ * + */ +object L1Regularization extends RegularizationPenalty { + + /** Calculates the new weights based on the gradient and regularization penalty + * + * The updated weight `w - learningRate * gradient` is shrunk towards zero + * by applying the proximal operator `signum(w) * max(0.0, abs(w) - shrinkageVal)` + * where `w` is the weight vector, `lambda` is the regularization parameter, + * and `shrinkageVal` is `lambda*learningRate`. + * + * @param weightVector The weights to be updated + * @param gradient The gradient according to which we will update the weights + * @param regularizationConstant The regularization parameter to be applied + * @param learningRate The effective step size for this iteration + * @return Updated weights + */ + override def takeStep( + weightVector: Vector, + gradient: Vector, + regularizationConstant: Double, + learningRate: Double) + : Vector = { + // Update weight vector with gradient. + BLAS.axpy(-learningRate, gradient, weightVector) + + // Apply proximal operator (soft thresholding) + val shrinkageVal = regularizationConstant * learningRate + var i = 0 + while (i < weightVector.size) { + val wi = weightVector(i) + weightVector(i) = scala.math.signum(wi) * + scala.math.max(0.0, scala.math.abs(wi) - shrinkageVal) --- End diff -- We can change `scala.math` to `math`.
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and would like it to be, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. ---