mgaido91 commented on a change in pull request #25160: [SPARK-28399][ML] implement RobustScaler URL: https://github.com/apache/spark/pull/25160#discussion_r305575336
##########
File path: mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
##########
@@ -134,77 +136,49 @@ class StandardScalerModel @Since("1.3.0") (
   @Since("1.1.0")
   override def transform(vector: Vector): Vector = {
     require(mean.size == vector.size)
-    if (withMean) {
-      // Must have a copy of the values since it will be modified in place
-      val values = vector match {
-        // specially handle DenseVector because its toArray does not clone already
-        case d: DenseVector => d.values.clone()
-        case v: Vector => v.toArray
-      }
-      val newValues = transformWithMean(values)
-      Vectors.dense(newValues)
-    } else if (withStd) {
-      vector match {
-        case DenseVector(values) =>
-          val newValues = transformDenseWithStd(values)
-          Vectors.dense(newValues)
-        case SparseVector(size, indices, values) =>
-          val (newIndices, newValues) = transformSparseWithStd(indices, values)
-          Vectors.sparse(size, newIndices, newValues)
-        case other =>
-          throw new UnsupportedOperationException(
-            s"Only sparse and dense vectors are supported but got ${other.getClass}.")
-      }
-    } else {
-      // Note that it's safe since we always assume that the data in RDD should be immutable.
-      vector
-    }
-  }
-
-  private[spark] def transformWithMean(values: Array[Double]): Array[Double] = {
-    // By default, Scala generates Java methods for member variables. So every time when
-    // the member variables are accessed, `invokespecial` will be called which is expensive.
-    // This can be avoid by having a local reference of `shift`.
-    val localShift = shift
-    val size = values.length
-    if (withStd) {
-      var i = 0
-      while (i < size) {
-        values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0
-        i += 1
-      }
-    } else {
-      var i = 0
-      while (i < size) {
-        values(i) -= localShift(i)
-        i += 1
-      }
-    }
-    values
-  }
-
-  private[spark] def transformDenseWithStd(values: Array[Double]): Array[Double] = {
-    val size = values.length
-    val newValues = values.clone()
-    var i = 0
-    while(i < size) {
-      newValues(i) *= (if (std(i) != 0.0) 1.0 / std(i) else 0.0)
-      i += 1
-    }
-    newValues
-  }
-  private[spark] def transformSparseWithStd(indices: Array[Int],
-      values: Array[Double]): (Array[Int], Array[Double]) = {
-    // For sparse vector, the `index` array inside sparse vector object will not be changed,
-    // so we can re-use it to save memory.
-    val nnz = values.length
-    val newValues = values.clone()
-    var i = 0
-    while (i < nnz) {
-      newValues(i) *= (if (std(indices(i)) != 0.0) 1.0 / std(indices(i)) else 0.0)
-      i += 1
+    (withMean, withStd) match {
+      case (true, true) =>
+        // By default, Scala generates Java methods for member variables. So every time when
+        // the member variables are accessed, `invokespecial` will be called which is expensive.
+        // This can be avoid by having a local reference of `shift`.
+        val localShift = shift
+        val localScale = scale
+        val values = vector match {
+          // specially handle DenseVector because its toArray does not clone already
+          case d: DenseVector => d.values.clone()
+          case v: Vector => v.toArray
+        }
+        val newValues = NewStandardScalerModel
+          .transformWithBoth(localShift, localScale, values)
+        Vectors.dense(newValues)
+
+      case (true, false) =>
+        val localShift = shift
+        val values = vector match {
+          case d: DenseVector => d.values.clone()
+          case v: Vector => v.toArray
+        }
+        val newValues = NewStandardScalerModel

Review comment:
nit: this can go on one line

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org