Repository: spark Updated Branches: refs/heads/master af6ece33d -> b88cb63da
[SPARK-17704][ML][MLLIB] ChiSqSelector performance improvement. ## What changes were proposed in this pull request? Partial revert of #15277 to instead sort and store input to model rather than require sorted input ## How was this patch tested? Existing tests. Author: Sean Owen <so...@cloudera.com> Closes #15299 from srowen/SPARK-17704.2. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b88cb63d Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b88cb63d Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b88cb63d Branch: refs/heads/master Commit: b88cb63da39786c07cb4bfa70afed32ec5eb3286 Parents: af6ece3 Author: Sean Owen <so...@cloudera.com> Authored: Sat Oct 1 16:10:39 2016 -0400 Committer: Sean Owen <so...@cloudera.com> Committed: Sat Oct 1 16:10:39 2016 -0400 ---------------------------------------------------------------------- .../apache/spark/ml/feature/ChiSqSelector.scala | 2 +- .../spark/mllib/feature/ChiSqSelector.scala | 22 ++++++++++---------- python/pyspark/ml/feature.py | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/b88cb63d/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 9c131a4..d0385e2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -193,7 +193,7 @@ final class ChiSqSelectorModel private[ml] ( import ChiSqSelectorModel._ - /** list of indices to select (filter). Must be ordered asc */ + /** list of indices to select (filter). */ @Since("1.6.0") val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures http://git-wip-us.apache.org/repos/asf/spark/blob/b88cb63d/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 706ce78..c305b36 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -35,14 +35,15 @@ import org.apache.spark.sql.{Row, SparkSession} /** * Chi Squared selector model. * - * @param selectedFeatures list of indices to select (filter). Must be ordered asc + * @param selectedFeatures list of indices to select (filter). */ @Since("1.3.0") class ChiSqSelectorModel @Since("1.3.0") ( @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable { - require(isSorted(selectedFeatures), "Array has to be sorted asc") + private val filterIndices = selectedFeatures.sorted + @deprecated("not intended for subclasses to use", "2.1.0") protected def isSorted(array: Array[Int]): Boolean = { var i = 1 val len = array.length @@ -61,7 +62,7 @@ class ChiSqSelectorModel @Since("1.3.0") ( */ @Since("1.3.0") override def transform(vector: Vector): Vector = { - compress(vector, selectedFeatures) + compress(vector) } /** @@ -69,9 +70,8 @@ class ChiSqSelectorModel @Since("1.3.0") ( * Preserves the order of filtered features the same as their indices are stored. * Might be moved to Vector as .slice * @param features vector - * @param filterIndices indices of features to filter, must be ordered asc */ - private def compress(features: Vector, filterIndices: Array[Int]): Vector = { + private def compress(features: Vector): Vector = { features match { case SparseVector(size, indices, values) => val newSize = filterIndices.length @@ -230,23 +230,23 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { */ @Since("1.3.0") def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { - val chiSqTestResult = Statistics.chiSqTest(data) + val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex val features = selectorType match { case ChiSqSelector.KBest => - chiSqTestResult.zipWithIndex + chiSqTestResult .sortBy { case (res, _) => -res.statistic } .take(numTopFeatures) case ChiSqSelector.Percentile => - chiSqTestResult.zipWithIndex + chiSqTestResult .sortBy { case (res, _) => -res.statistic } .take((chiSqTestResult.length * percentile).toInt) case ChiSqSelector.FPR => - chiSqTestResult.zipWithIndex - .filter{ case (res, _) => res.pValue < alpha } + chiSqTestResult + .filter { case (res, _) => res.pValue < alpha } case errorType => throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") } - val indices = features.map { case (_, indices) => indices }.sorted + val indices = features.map { case (_, index) => index } new ChiSqSelectorModel(indices) } } http://git-wip-us.apache.org/repos/asf/spark/blob/b88cb63d/python/pyspark/ml/feature.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 12a1384..64b21ca 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2705,7 +2705,7 @@ class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable): @since("2.0.0") def selectedFeatures(self): """ - List of indices to select (filter). Must be ordered asc. + List of indices to select (filter). """ return self._call_java("selectedFeatures") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org