Github user holdenk commented on a diff in the pull request: https://github.com/apache/spark/pull/21942#discussion_r216022250 --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala --- @@ -160,15 +160,88 @@ class StandardScalerModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean)) - - // TODO: Make the transformer natively in ml framework to avoid extra conversion. - val transformer: Vector => Vector = v => scaler.transform(OldVectors.fromML(v)).asML + val transformer: Vector => Vector = v => transform(v) val scale = udf(transformer) dataset.withColumn($(outputCol), scale(col($(inputCol)))) } + /** + * Since `shift` will be only used in `withMean` branch, we have it as + * `lazy val` so it will be evaluated in that branch. Note that we don't + * want to create this array multiple times in `transform` function. + */ + private lazy val shift: Array[Double] = mean.toArray + + /** + * Applies standardization transformation on a vector. + * + * @param vector Vector to be standardized. + * @return Standardized vector. If the std of a column is zero, it will return default `0.0` + * for the column with zero std. + */ + private[spark] def transform(vector: Vector): Vector = { + require(mean.size == vector.size) + if ($(withMean)) { + /** + * By default, Scala generates Java methods for member variables. So every time + * member variables are accessed, `invokespecial` is called. This is an expensive + * operation, and can be avoided by having a local reference of `shift`. + */ + val localShift = shift + /** Must have a copy of the values since they will be modified in place. */ + val values = vector match { + /** Handle DenseVector specially because its `toArray` method does not clone values. 
*/ + case d: DenseVector => d.values.clone() + case v: Vector => v.toArray + } + val size = values.length + if ($(withStd)) { + var i = 0 + while (i < size) { + values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0 + i += 1 + } + } else { + var i = 0 + while (i < size) { + values(i) -= localShift(i) + i += 1 + } + } + Vectors.dense(values) + } else if ($(withStd)) { --- End diff -- Maybe leave a comment here noting that this branch is `withStd` without `withMean`, since the nested `if`/`else if` flow can get a bit confusing when tracing the code by hand.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org