zhengruifeng commented on a change in pull request #32822:
URL: https://github.com/apache/spark/pull/32822#discussion_r649624193



##########
File path: mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala
##########
@@ -94,4 +95,43 @@ private[spark] object Utils {
       math.log1p(math.exp(x))
     }
   }
+
+  /**
+   * Perform in-place softmax conversion.
+   */
+  def softmax(values: Array[Double]): Unit = {
+    var maxValue = Double.MinValue
+    var i = 0
+    while (i < values.length) {
+      val value = values(i)
+      if (value.isPosInfinity) {
+        java.util.Arrays.fill(values, 0)
+        values(i) = 1.0
+        return
+      } else if (value > maxValue) {
+        maxValue = value
+      }
+      i += 1
+    }
+
+    var sum = 0.0
+    i = 0
+    if (maxValue > 0) {

Review comment:
       sounds reasonable

##########
File path: mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala
##########
@@ -94,4 +95,43 @@ private[spark] object Utils {
       math.log1p(math.exp(x))
     }
   }
+
+  /**
+   * Perform in-place softmax conversion.
+   */
+  def softmax(values: Array[Double]): Unit = {
+    var maxValue = Double.MinValue
+    var i = 0
+    while (i < values.length) {
+      val value = values(i)
+      if (value.isPosInfinity) {
+        java.util.Arrays.fill(values, 0)
+        values(i) = 1.0
+        return
+      } else if (value > maxValue) {
+        maxValue = value
+      }
+      i += 1
+    }
+
+    var sum = 0.0
+    i = 0
+    if (maxValue > 0) {

Review comment:
       ```py
   def softmax(X, copy=True):
       """
       Calculate the softmax function.
   
       The softmax function is calculated by
       np.exp(X) / np.sum(np.exp(X), axis=1)
   
       This will cause overflow when large values are exponentiated.
       Hence the largest value in each row is subtracted from each data
       point to prevent this.
   
       Parameters
       ----------
       X : array-like of float of shape (M, N)
           Argument to the logistic function.
   
       copy : bool, default=True
           Copy X or not.
   
       Returns
       -------
       out : ndarray of shape (M, N)
           Softmax function evaluated at every point in x.
       """
       if copy:
           X = np.copy(X)
       max_prob = np.max(X, axis=1).reshape((-1, 1))
       X -= max_prob
       np.exp(X, X)
       sum_prob = np.sum(X, axis=1).reshape((-1, 1))
       X /= sum_prob
       return X
   ```
   
    softmax in scikit-learn does not check whether the max value is positive or not




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org

Reply via email to