Repository: spark Updated Branches: refs/heads/master 22014e6fb -> f4344582b
[SPARK-14497][ML] Use top instead of sortBy() to get top N frequent words as dict in CountVectorizer ## What changes were proposed in this pull request? Replace sortBy() with top() to calculate the top N frequent words as the dictionary. ## How was this patch tested? Existing unit tests. The terms with the same TF would be sorted in descending order. The test would fail if it hard-coded the order of terms with the same TF in the dictionary, like "c", "d"... Author: fwang1 <desperado...@gmail.com> Closes #12265 from lionelfeng/master. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f4344582 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f4344582 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f4344582 Branch: refs/heads/master Commit: f4344582ba28983bf3892d08e11236f090f5bf92 Parents: 22014e6 Author: fwang1 <desperado...@gmail.com> Authored: Sun Apr 10 01:13:25 2016 -0700 Committer: Xiangrui Meng <m...@databricks.com> Committed: Sun Apr 10 01:13:25 2016 -0700 ---------------------------------------------------------------------- .../org/apache/spark/ml/feature/CountVectorizer.scala | 14 ++++---------- .../spark/ml/feature/CountVectorizerSuite.scala | 7 ++++--- 2 files changed, 8 insertions(+), 13 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/f4344582/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index f1be971..00abbbe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -170,16 +170,10 @@ class CountVectorizer(override val uid: String) (word, count) }.cache() val 
fullVocabSize = wordCounts.count() - val vocab: Array[String] = { - val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) { - // Use all terms - wordCounts.collect().sortBy(-_._2) - } else { - // Sort terms to select vocab - wordCounts.sortBy(_._2, ascending = false).take(vocSize) - } - tmpSortedWC.map(_._1) - } + + val vocab = wordCounts + .top(math.min(fullVocabSize, vocSize).toInt)(Ordering.by(_._2)) + .map(_._1) require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.") copyValues(new CountVectorizerModel(uid, vocab).setParent(this)) http://git-wip-us.apache.org/repos/asf/spark/blob/f4344582/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index ff0de06..7641e3b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -59,14 +59,15 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext (0, split("a b c d e"), Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))), (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))), - (2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))), - (3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0))))) + (2, split("c c"), Vectors.sparse(5, Seq((2, 2.0)))), + (3, split("d"), Vectors.sparse(5, Seq((3, 1.0)))), + (4, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0))))) ).toDF("id", "words", "expected") val cv = new CountVectorizer() .setInputCol("words") .setOutputCol("features") .fit(df) - assert(cv.vocabulary === Array("a", "b", "c", "d", "e")) + assert(cv.vocabulary.toSet === Set("a", "b", "c", "d", "e")) cv.transform(df).select("features", 
"expected").collect().foreach { case Row(features: Vector, expected: Vector) => --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org