Repository: spark
Updated Branches:
  refs/heads/master a19a1bb59 -> f7082ac12


[SPARK-17704][ML][MLLIB] ChiSqSelector performance improvement.

## What changes were proposed in this pull request?
Several performance improvements for `ChiSqSelector`:
1. Keep `selectedFeatures` sorted in ascending order.
`ChiSqSelectorModel.transform` needs `selectedFeatures` to be sorted in order to
make predictions. We should sort it once when training the model rather than on
every prediction, since users usually train a model once and then use it to make
predictions many times.
2. When training an `fpr`-type `ChiSqSelectorModel`, it is not necessary to
sort the chi-squared test results by statistic.

## How was this patch tested?
Existing unit tests.

Author: Yanbo Liang <yblia...@gmail.com>

Closes #15277 from yanboliang/spark-17704.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f7082ac1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f7082ac1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f7082ac1

Branch: refs/heads/master
Commit: f7082ac12518ae84d6d1d4b7330a9f12cf95e7c1
Parents: a19a1bb
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Thu Sep 29 04:30:42 2016 -0700
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Thu Sep 29 04:30:42 2016 -0700

----------------------------------------------------------------------
 .../spark/mllib/feature/ChiSqSelector.scala     | 45 +++++++++++++-------
 project/MimaExcludes.scala                      |  3 --
 2 files changed, 30 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f7082ac1/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 0f7c6e8..706ce78 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -35,12 +35,24 @@ import org.apache.spark.sql.{Row, SparkSession}
 /**
  * Chi Squared selector model.
  *
- * @param selectedFeatures list of indices to select (filter).
+ * @param selectedFeatures list of indices to select (filter). Must be ordered 
asc
  */
 @Since("1.3.0")
 class ChiSqSelectorModel @Since("1.3.0") (
   @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer 
with Saveable {
 
+  require(isSorted(selectedFeatures), "Array has to be sorted asc")
+
+  protected def isSorted(array: Array[Int]): Boolean = {
+    var i = 1
+    val len = array.length
+    while (i < len) {
+      if (array(i) < array(i-1)) return false
+      i += 1
+    }
+    true
+  }
+
   /**
    * Applies transformation on a vector.
    *
@@ -57,22 +69,21 @@ class ChiSqSelectorModel @Since("1.3.0") (
    * Preserves the order of filtered features the same as their indices are 
stored.
    * Might be moved to Vector as .slice
    * @param features vector
-   * @param filterIndices indices of features to filter
+   * @param filterIndices indices of features to filter, must be ordered asc
    */
   private def compress(features: Vector, filterIndices: Array[Int]): Vector = {
-    val orderedIndices = filterIndices.sorted
     features match {
       case SparseVector(size, indices, values) =>
-        val newSize = orderedIndices.length
+        val newSize = filterIndices.length
         val newValues = new ArrayBuilder.ofDouble
         val newIndices = new ArrayBuilder.ofInt
         var i = 0
         var j = 0
         var indicesIdx = 0
         var filterIndicesIdx = 0
-        while (i < indices.length && j < orderedIndices.length) {
+        while (i < indices.length && j < filterIndices.length) {
           indicesIdx = indices(i)
-          filterIndicesIdx = orderedIndices(j)
+          filterIndicesIdx = filterIndices(j)
           if (indicesIdx == filterIndicesIdx) {
             newIndices += j
             newValues += values(i)
@@ -90,7 +101,7 @@ class ChiSqSelectorModel @Since("1.3.0") (
         Vectors.sparse(newSize, newIndices.result(), newValues.result())
       case DenseVector(values) =>
         val values = features.toArray
-        Vectors.dense(orderedIndices.map(i => values(i)))
+        Vectors.dense(filterIndices.map(i => values(i)))
       case other =>
         throw new UnsupportedOperationException(
           s"Only sparse and dense vectors are supported but got 
${other.getClass}.")
@@ -220,18 +231,22 @@ class ChiSqSelector @Since("2.1.0") () extends 
Serializable {
   @Since("1.3.0")
   def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
     val chiSqTestResult = Statistics.chiSqTest(data)
-      .zipWithIndex.sortBy { case (res, _) => -res.statistic }
     val features = selectorType match {
-      case ChiSqSelector.KBest => chiSqTestResult
-        .take(numTopFeatures)
-      case ChiSqSelector.Percentile => chiSqTestResult
-        .take((chiSqTestResult.length * percentile).toInt)
-      case ChiSqSelector.FPR => chiSqTestResult
-        .filter{ case (res, _) => res.pValue < alpha }
+      case ChiSqSelector.KBest =>
+        chiSqTestResult.zipWithIndex
+          .sortBy { case (res, _) => -res.statistic }
+          .take(numTopFeatures)
+      case ChiSqSelector.Percentile =>
+        chiSqTestResult.zipWithIndex
+          .sortBy { case (res, _) => -res.statistic }
+          .take((chiSqTestResult.length * percentile).toInt)
+      case ChiSqSelector.FPR =>
+        chiSqTestResult.zipWithIndex
+          .filter{ case (res, _) => res.pValue < alpha }
       case errorType =>
         throw new IllegalStateException(s"Unknown ChiSqSelector Type: 
$errorType")
     }
-    val indices = features.map { case (_, indices) => indices }
+    val indices = features.map { case (_, indices) => indices }.sorted
     new ChiSqSelectorModel(indices)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f7082ac1/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 8024fbd..4db3edb 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -818,9 +818,6 @@ object MimaExcludes {
       // [SPARK-17163] Unify logistic regression interface. Private 
constructor has new signature.
       
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this")
     ) ++ Seq(
-      // [SPARK-17017] Add chiSquare selector based on False Positive Rate 
(FPR) test
-      
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.ChiSqSelectorModel.isSorted")
-    ) ++ Seq(
       // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce 
RPC call time
       
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext")
     ) ++ Seq(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to