zhengruifeng commented on a change in pull request #32822: URL: https://github.com/apache/spark/pull/32822#discussion_r649624193
########## File path: mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala ########## @@ -94,4 +95,43 @@ private[spark] object Utils { math.log1p(math.exp(x)) } } + + /** + * Perform in-place softmax conversion. + */ + def softmax(values: Array[Double]): Unit = { + var maxValue = Double.MinValue + var i = 0 + while (i < values.length) { + val value = values(i) + if (value.isPosInfinity) { + java.util.Arrays.fill(values, 0) + values(i) = 1.0 + return + } else if (value > maxValue) { + maxValue = value + } + i += 1 + } + + var sum = 0.0 + i = 0 + if (maxValue > 0) { Review comment: sounds reasonable ########## File path: mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala ########## @@ -94,4 +95,43 @@ private[spark] object Utils { math.log1p(math.exp(x)) } } + + /** + * Perform in-place softmax conversion. + */ + def softmax(values: Array[Double]): Unit = { + var maxValue = Double.MinValue + var i = 0 + while (i < values.length) { + val value = values(i) + if (value.isPosInfinity) { + java.util.Arrays.fill(values, 0) + values(i) = 1.0 + return + } else if (value > maxValue) { + maxValue = value + } + i += 1 + } + + var sum = 0.0 + i = 0 + if (maxValue > 0) { Review comment: ```py def softmax(X, copy=True): """ Calculate the softmax function. The softmax function is calculated by np.exp(X) / np.sum(np.exp(X), axis=1) This will cause overflow when large values are exponentiated. Hence the largest value in each row is subtracted from each data point to prevent this. Parameters ---------- X : array-like of float of shape (M, N) Argument to the logistic function. copy : bool, default=True Copy X or not. Returns ------- out : ndarray of shape (M, N) Softmax function evaluated at every point in x. """ if copy: X = np.copy(X) max_prob = np.max(X, axis=1).reshape((-1, 1)) X -= max_prob np.exp(X, X) sum_prob = np.sum(X, axis=1).reshape((-1, 1)) X /= sum_prob return X ``` softmax in scikit-learn does not check whether the maxvalue is positive or not -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org