mgaido91 commented on a change in pull request #25160: [SPARK-28399][ML] 
implement RobustScaler
URL: https://github.com/apache/spark/pull/25160#discussion_r305575336
 
 

 ##########
 File path: 
mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
 ##########
 @@ -134,77 +136,49 @@ class StandardScalerModel @Since("1.3.0") (
   @Since("1.1.0")
   override def transform(vector: Vector): Vector = {
     require(mean.size == vector.size)
-    if (withMean) {
-      // Must have a copy of the values since it will be modified in place
-      val values = vector match {
-        // specially handle DenseVector because its toArray does not clone 
already
-        case d: DenseVector => d.values.clone()
-        case v: Vector => v.toArray
-      }
-      val newValues = transformWithMean(values)
-      Vectors.dense(newValues)
-    } else if (withStd) {
-      vector match {
-        case DenseVector(values) =>
-          val newValues = transformDenseWithStd(values)
-          Vectors.dense(newValues)
-        case SparseVector(size, indices, values) =>
-          val (newIndices, newValues) = transformSparseWithStd(indices, values)
-          Vectors.sparse(size, newIndices, newValues)
-        case other =>
-          throw new UnsupportedOperationException(
-            s"Only sparse and dense vectors are supported but got 
${other.getClass}.")
-      }
-    } else {
-      // Note that it's safe since we always assume that the data in RDD 
should be immutable.
-      vector
-    }
-  }
-
-  private[spark] def transformWithMean(values: Array[Double]): Array[Double] = 
{
-    // By default, Scala generates Java methods for member variables. So every 
time when
-    // the member variables are accessed, `invokespecial` will be called which 
is expensive.
-    // This can be avoided by having a local reference to `shift`.
-    val localShift = shift
-    val size = values.length
-    if (withStd) {
-      var i = 0
-      while (i < size) {
-        values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / 
std(i)) else 0.0
-        i += 1
-      }
-    } else {
-      var i = 0
-      while (i < size) {
-        values(i) -= localShift(i)
-        i += 1
-      }
-    }
-    values
-  }
-
-  private[spark] def transformDenseWithStd(values: Array[Double]): 
Array[Double] = {
-    val size = values.length
-    val newValues = values.clone()
-    var i = 0
-    while(i < size) {
-      newValues(i) *= (if (std(i) != 0.0) 1.0 / std(i) else 0.0)
-      i += 1
-    }
-    newValues
-  }
 
-  private[spark] def transformSparseWithStd(indices: Array[Int],
-                                            values: Array[Double]): 
(Array[Int], Array[Double]) = {
-    // For sparse vector, the `index` array inside sparse vector object will 
not be changed,
-    // so we can re-use it to save memory.
-    val nnz = values.length
-    val newValues = values.clone()
-    var i = 0
-    while (i < nnz) {
-      newValues(i) *= (if (std(indices(i)) != 0.0) 1.0 / std(indices(i)) else 
0.0)
-      i += 1
+    (withMean, withStd) match {
+      case (true, true) =>
+        // By default, Scala generates Java methods for member variables. So 
every time when
+        // the member variables are accessed, `invokespecial` will be called 
which is expensive.
+        // This can be avoided by having a local reference to `shift`.
+        val localShift = shift
+        val localScale = scale
+        val values = vector match {
+          // handle DenseVector specially because its toArray does not already clone
+          case d: DenseVector => d.values.clone()
+          case v: Vector => v.toArray
+        }
+        val newValues = NewStandardScalerModel
+          .transformWithBoth(localShift, localScale, values)
+        Vectors.dense(newValues)
+
+      case (true, false) =>
+        val localShift = shift
+        val values = vector match {
+          case d: DenseVector => d.values.clone()
+          case v: Vector => v.toArray
+        }
+        val newValues = NewStandardScalerModel
 
 Review comment:
   nit: this can go on one line

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to