Repository: spark Updated Branches: refs/heads/master ca00cc70d -> 0fa5b7cac
[SPARK-21690][ML] one-pass imputer ## What changes were proposed in this pull request? parallelize the computation of all columns performance tests: |numColumns| Mean(Old) | Median(Old) | Mean(RDD) | Median(RDD) | Mean(DF) | Median(DF) | |------|----------|------------|----------|------------|----------|------------| |1|0.0771394713|0.0658712813|0.080779802|0.048165981499999996|0.10525509870000001|0.0499620203| |10|0.7234340630999999|0.5954440414|0.0867935197|0.13263428659999998|0.09255724889999999|0.1573943635| |100|7.3756451568|6.2196631259|0.1911931552|0.8625376817000001|0.5557462431|1.7216837982000002| ## How was this patch tested? existing tests Author: Zheng RuiFeng <ruife...@foxmail.com> Closes #18902 from zhengruifeng/parallelize_imputer. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0fa5b7ca Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0fa5b7ca Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0fa5b7ca Branch: refs/heads/master Commit: 0fa5b7cacca4e867dd9f787cc2801616967932a4 Parents: ca00cc7 Author: Zheng RuiFeng <ruife...@foxmail.com> Authored: Wed Sep 13 20:12:21 2017 +0800 Committer: Yanbo Liang <yblia...@gmail.com> Committed: Wed Sep 13 20:12:21 2017 +0800 ---------------------------------------------------------------------- .../org/apache/spark/ml/feature/Imputer.scala | 56 ++++++++++++++------ 1 file changed, 41 insertions(+), 15 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/0fa5b7ca/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 9e023b9..1f36ece 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -133,23 +133,49 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) override def fit(dataset: Dataset[_]): ImputerModel = { transformSchema(dataset.schema, logging = true) val spark = dataset.sparkSession - import spark.implicits._ - val surrogates = $(inputCols).map { inputCol => - val ic = col(inputCol) - val filtered = dataset.select(ic.cast(DoubleType)) - .filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN) - if(filtered.take(1).length == 0) { - throw new SparkException(s"surrogate cannot be computed. " + - s"All the values in $inputCol are Null, Nan or missingValue(${$(missingValue)})") - } - val surrogate = $(strategy) match { - case Imputer.mean => filtered.select(avg(inputCol)).as[Double].first() - case Imputer.median => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001).head - } - surrogate + + val cols = $(inputCols).map { inputCol => + when(col(inputCol).equalTo($(missingValue)), null) + .when(col(inputCol).isNaN, null) + .otherwise(col(inputCol)) + .cast("double") + .as(inputCol) + } + + val results = $(strategy) match { + case Imputer.mean => + // Function avg will ignore null automatically. + // For a column only containing null, avg will return null. + val row = dataset.select(cols.map(avg): _*).head() + Array.range(0, $(inputCols).length).map { i => + if (row.isNullAt(i)) { + Double.NaN + } else { + row.getDouble(i) + } + } + + case Imputer.median => + // Function approxQuantile will ignore null automatically. + // For a column only containing null, approxQuantile will return an empty array. + dataset.select(cols: _*).stat.approxQuantile($(inputCols), Array(0.5), 0.001) + .map { array => + if (array.isEmpty) { + Double.NaN + } else { + array.head + } + } + } + + val emptyCols = $(inputCols).zip(results).filter(_._2.isNaN).map(_._1) + if (emptyCols.nonEmpty) { + throw new SparkException(s"surrogate cannot be computed. 
" + + s"All the values in ${emptyCols.mkString(",")} are Null, Nan or " + + s"missingValue(${$(missingValue)})") } - val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(surrogates))) + val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(results))) val schema = StructType($(inputCols).map(col => StructField(col, DoubleType, nullable = false))) val surrogateDF = spark.createDataFrame(rows, schema) copyValues(new ImputerModel(uid, surrogateDF).setParent(this)) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org