Repository: spark Updated Branches: refs/heads/master a19a1bb59 -> f7082ac12
[SPARK-17704][ML][MLLIB] ChiSqSelector performance improvement. ## What changes were proposed in this pull request? Several performance improvements for ```ChiSqSelector```: 1. Keep ```selectedFeatures``` sorted in ascending order. ```ChiSqSelectorModel.transform``` needs ```selectedFeatures``` to be ordered to make a prediction. We should sort it when training the model rather than when making a prediction, since users usually train a model once and use it to make predictions multiple times. 2. When training an ```fpr```-type ```ChiSqSelectorModel```, it's not necessary to sort the ChiSq test result by statistic. ## How was this patch tested? Existing unit tests. Author: Yanbo Liang <yblia...@gmail.com> Closes #15277 from yanboliang/spark-17704. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f7082ac1 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f7082ac1 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f7082ac1 Branch: refs/heads/master Commit: f7082ac12518ae84d6d1d4b7330a9f12cf95e7c1 Parents: a19a1bb Author: Yanbo Liang <yblia...@gmail.com> Authored: Thu Sep 29 04:30:42 2016 -0700 Committer: Yanbo Liang <yblia...@gmail.com> Committed: Thu Sep 29 04:30:42 2016 -0700 ---------------------------------------------------------------------- .../spark/mllib/feature/ChiSqSelector.scala | 45 +++++++++++++------- project/MimaExcludes.scala | 3 -- 2 files changed, 30 insertions(+), 18 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/f7082ac1/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 0f7c6e8..706ce78 100644 --- 
a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -35,12 +35,24 @@ import org.apache.spark.sql.{Row, SparkSession} /** * Chi Squared selector model. * - * @param selectedFeatures list of indices to select (filter). + * @param selectedFeatures list of indices to select (filter). Must be ordered asc */ @Since("1.3.0") class ChiSqSelectorModel @Since("1.3.0") ( @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable { + require(isSorted(selectedFeatures), "Array has to be sorted asc") + + protected def isSorted(array: Array[Int]): Boolean = { + var i = 1 + val len = array.length + while (i < len) { + if (array(i) < array(i-1)) return false + i += 1 + } + true + } + /** * Applies transformation on a vector. * @@ -57,22 +69,21 @@ class ChiSqSelectorModel @Since("1.3.0") ( * Preserves the order of filtered features the same as their indices are stored. * Might be moved to Vector as .slice * @param features vector - * @param filterIndices indices of features to filter + * @param filterIndices indices of features to filter, must be ordered asc */ private def compress(features: Vector, filterIndices: Array[Int]): Vector = { - val orderedIndices = filterIndices.sorted features match { case SparseVector(size, indices, values) => - val newSize = orderedIndices.length + val newSize = filterIndices.length val newValues = new ArrayBuilder.ofDouble val newIndices = new ArrayBuilder.ofInt var i = 0 var j = 0 var indicesIdx = 0 var filterIndicesIdx = 0 - while (i < indices.length && j < orderedIndices.length) { + while (i < indices.length && j < filterIndices.length) { indicesIdx = indices(i) - filterIndicesIdx = orderedIndices(j) + filterIndicesIdx = filterIndices(j) if (indicesIdx == filterIndicesIdx) { newIndices += j newValues += values(i) @@ -90,7 +101,7 @@ class ChiSqSelectorModel @Since("1.3.0") ( Vectors.sparse(newSize, 
newIndices.result(), newValues.result()) case DenseVector(values) => val values = features.toArray - Vectors.dense(orderedIndices.map(i => values(i))) + Vectors.dense(filterIndices.map(i => values(i))) case other => throw new UnsupportedOperationException( s"Only sparse and dense vectors are supported but got ${other.getClass}.") @@ -220,18 +231,22 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { @Since("1.3.0") def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { val chiSqTestResult = Statistics.chiSqTest(data) - .zipWithIndex.sortBy { case (res, _) => -res.statistic } val features = selectorType match { - case ChiSqSelector.KBest => chiSqTestResult - .take(numTopFeatures) - case ChiSqSelector.Percentile => chiSqTestResult - .take((chiSqTestResult.length * percentile).toInt) - case ChiSqSelector.FPR => chiSqTestResult - .filter{ case (res, _) => res.pValue < alpha } + case ChiSqSelector.KBest => + chiSqTestResult.zipWithIndex + .sortBy { case (res, _) => -res.statistic } + .take(numTopFeatures) + case ChiSqSelector.Percentile => + chiSqTestResult.zipWithIndex + .sortBy { case (res, _) => -res.statistic } + .take((chiSqTestResult.length * percentile).toInt) + case ChiSqSelector.FPR => + chiSqTestResult.zipWithIndex + .filter{ case (res, _) => res.pValue < alpha } case errorType => throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") } - val indices = features.map { case (_, indices) => indices } + val indices = features.map { case (_, indices) => indices }.sorted new ChiSqSelectorModel(indices) } } http://git-wip-us.apache.org/repos/asf/spark/blob/f7082ac1/project/MimaExcludes.scala ---------------------------------------------------------------------- diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8024fbd..4db3edb 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -818,9 +818,6 @@ object MimaExcludes { // [SPARK-17163] Unify logistic regression interface. 
Private constructor has new signature. ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this") ) ++ Seq( - // [SPARK-17017] Add chiSquare selector based on False Positive Rate (FPR) test - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.ChiSqSelectorModel.isSorted") - ) ++ Seq( // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext") ) ++ Seq( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org