This is an automated email from the ASF dual-hosted git repository. srowen pushed a commit to branch branch-2.3 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.3 by this push: new 220f29a [SPARK-28081][ML] Handle large vocab counts in word2vec 220f29a is described below commit 220f29a6f5b681a67a7e9a9351f25389c303b956 Author: Sean Owen <sean.o...@databricks.com> AuthorDate: Tue Jun 18 20:27:43 2019 -0500 [SPARK-28081][ML] Handle large vocab counts in word2vec ## What changes were proposed in this pull request? The word2vec logic fails if a corpus has a word with count > 1e9. We should be able to handle very large counts generally better here by using longs to count. This takes over https://github.com/apache/spark/pull/24814 ## How was this patch tested? Existing tests. Closes #24893 from srowen/SPARK-28081. Authored-by: Sean Owen <sean.o...@databricks.com> Signed-off-by: Sean Owen <sean.o...@databricks.com> (cherry picked from commit e96dd82f12f2b6d93860e23f4f98a86c3faf57c5) Signed-off-by: Sean Owen <sean.o...@databricks.com> --- .../src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index b8c306d..d5b91df 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -45,7 +45,7 @@ import org.apache.spark.util.random.XORShiftRandom */ private case class VocabWord( var word: String, - var cn: Int, + var cn: Long, var point: Array[Int], var code: Array[Int], var codeLen: Int @@ -194,7 +194,7 @@ class Word2Vec extends Serializable with Logging { new Array[Int](MAX_CODE_LENGTH), 0)) .collect() - .sortWith((a, b) => a.cn > b.cn) + .sortBy(_.cn)(Ordering[Long].reverse) vocabSize = vocab.length require(vocabSize > 0, "The vocabulary size should be > 0. 
You may need to check " + @@ -232,7 +232,7 @@ class Word2Vec extends Serializable with Logging { a += 1 } while (a < 2 * vocabSize) { - count(a) = 1e9.toInt + count(a) = Long.MaxValue a += 1 } var pos1 = vocabSize - 1 @@ -267,6 +267,8 @@ class Word2Vec extends Serializable with Logging { min2i = pos2 pos2 += 1 } + assert(count(min1i) < Long.MaxValue) + assert(count(min2i) < Long.MaxValue) count(vocabSize + a) = count(min1i) + count(min2i) parentNode(min1i) = vocabSize + a parentNode(min2i) = vocabSize + a --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org