Repository: spark Updated Branches: refs/heads/branch-1.5 253e3eb71 -> 7fdd7cf09
[SPARK-12685][MLLIB][BACKPORT TO 1.4] word2vec trainWordsCount gets overflow jira: https://issues.apache.org/jira/browse/SPARK-12685 master PR: https://github.com/apache/spark/pull/10627 The word2vec log reports trainWordsCount = -785727483 during computation over a large dataset, i.e. the Int counter has overflowed. The priority is raised because the overflowed value affects the computation itself, through the learning-rate update: alpha = learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) Author: Yuhao Yang <hhb...@gmail.com> Closes #10721 from hhbyyh/branch-1.4. (cherry picked from commit 7bd2564192f51f6229cf759a2bafc22134479955) Signed-off-by: Joseph K. Bradley <jos...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7fdd7cf0 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7fdd7cf0 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7fdd7cf0 Branch: refs/heads/branch-1.5 Commit: 7fdd7cf09a48b7cbe0c0de11482a6aa6a574d9a7 Parents: 253e3eb Author: Yuhao Yang <hhb...@gmail.com> Authored: Wed Jan 13 11:53:25 2016 -0800 Committer: Joseph K. 
Bradley <jos...@databricks.com> Committed: Wed Jan 13 11:53:45 2016 -0800 ---------------------------------------------------------------------- .../main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/7fdd7cf0/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 131a862..1a9ac47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -146,7 +146,7 @@ class Word2Vec extends Serializable with Logging { /** context words from [-window, window] */ private val window = 5 - private var trainWordsCount = 0 + private var trainWordsCount = 0L private var vocabSize = 0 @transient private var vocab: Array[VocabWord] = null @transient private var vocabHash = mutable.HashMap.empty[String, Int] @@ -154,13 +154,13 @@ class Word2Vec extends Serializable with Logging { private def learnVocab(words: RDD[String]): Unit = { vocab = words.map(w => (w, 1)) .reduceByKey(_ + _) + .filter(_._2 >= minCount) .map(x => VocabWord( x._1, x._2, new Array[Int](MAX_CODE_LENGTH), new Array[Int](MAX_CODE_LENGTH), 0)) - .filter(_.cn >= minCount) .collect() .sortWith((a, b) => a.cn > b.cn) @@ -174,7 +174,7 @@ class Word2Vec extends Serializable with Logging { trainWordsCount += vocab(a).cn a += 1 } - logInfo("trainWordsCount = " + trainWordsCount) + logInfo(s"vocabSize = $vocabSize, trainWordsCount = $trainWordsCount") } private def createExpTable(): Array[Float] = { @@ -324,7 +324,7 @@ class Word2Vec extends Serializable with Logging { val random = new XORShiftRandom(seed ^ ((idx + 
1) << 16) ^ ((-k - 1) << 8)) val syn0Modify = new Array[Int](vocabSize) val syn1Modify = new Array[Int](vocabSize) - val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) { + val model = iter.foldLeft((syn0Global, syn1Global, 0L, 0L)) { case ((syn0, syn1, lastWordCount, wordCount), sentence) => var lwc = lastWordCount var wc = wordCount --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org