Github user srowen commented on a diff in the pull request:

    https://github.com/apache/spark/pull/3833#discussion_r22932041
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala ---
    @@ -81,14 +136,111 @@ class LogisticGradient extends Gradient {
           label: Double,
           weights: Vector,
           cumGradient: Vector): Double = {
    -    val margin = -1.0 * dot(data, weights)
    -    val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
    -    axpy(gradientMultiplier, data, cumGradient)
    -    if (label > 0) {
    -      // The following is equivalent to log(1 + exp(margin)) but more 
numerically stable.
    -      MLUtils.log1pExp(margin)
    -    } else {
    -      MLUtils.log1pExp(margin) - margin
    +    assert((weights.size % data.size) == 0)
    +    val dataSize = data.size
    +
    +    // (n + 1) is number of classes
    +    val n = weights.size / dataSize
    +    n match {
    +      case 1 =>
    +        /**
    +         * For Binary Logistic Regression.
    +         *
    +         * Although the loss and gradient calculation for the multinomial 
one is more generalized,
    +         * and the multinomial one can also be used in the binary case, we 
still implement a specialized
    +         * binary version for performance reasons.
    +         */
    +        val margin = -1.0 * dot(data, weights)
    +        val multiplier = (1.0 / (1.0 + math.exp(margin))) - label
    +        axpy(multiplier, data, cumGradient)
    +        if (label > 0) {
    +          // The following is equivalent to log(1 + exp(margin)) but more 
numerically stable.
    +          MLUtils.log1pExp(margin)
    +        } else {
    +          MLUtils.log1pExp(margin) - margin
    +        }
    +      case _ =>
    +        /**
    +         * For Multinomial Logistic Regression.
    +         */
    +        val weightsArray = weights match {
    +          case dv: DenseVector => dv.values
    +          case _ =>
    +            throw new IllegalArgumentException(
    +              s"weights only supports dense vector but got type 
${weights.getClass}.")
    +        }
    +        val cumGradientArray = cumGradient match {
    +          case dv: DenseVector => dv.values
    +          case _ =>
    +            throw new IllegalArgumentException(
    +              s"cumGradient only supports dense vector but got type 
${cumGradient.getClass}.")
    +        }
    +        val margins = Array.ofDim[Double](n)
    +
    +        // marginY is margins(label - 1) in the formula.
    +        var marginY = 0.0
    +        var maxMargin = Double.NegativeInfinity
    +        var maxMarginIndex = 0
    +        var sum = 0.0
    +
    +        var i = 0
    +        while (i < n) {
    --- End diff --
    
    Same comment about making this more compact with Scala syntax.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to