This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch branch-2.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.3 by this push:
     new 220f29a  [SPARK-28081][ML] Handle large vocab counts in word2vec
220f29a is described below

commit 220f29a6f5b681a67a7e9a9351f25389c303b956
Author: Sean Owen <sean.o...@databricks.com>
AuthorDate: Tue Jun 18 20:27:43 2019 -0500

    [SPARK-28081][ML] Handle large vocab counts in word2vec
    
    ## What changes were proposed in this pull request?
    
    The word2vec logic fails if a corpus has a word with count > 1e9. We
    should be able to handle very large counts more robustly here by
    counting with longs. (See the illustrative sketch after the diff below.)
    
    This takes over https://github.com/apache/spark/pull/24814
    
    ## How was this patch tested?
    
    Existing tests.
    
    Closes #24893 from srowen/SPARK-28081.
    
    Authored-by: Sean Owen <sean.o...@databricks.com>
    Signed-off-by: Sean Owen <sean.o...@databricks.com>
    (cherry picked from commit e96dd82f12f2b6d93860e23f4f98a86c3faf57c5)
    Signed-off-by: Sean Owen <sean.o...@databricks.com>
---
 .../src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala  | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index b8c306d..d5b91df 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -45,7 +45,7 @@ import org.apache.spark.util.random.XORShiftRandom
  */
 private case class VocabWord(
   var word: String,
-  var cn: Int,
+  var cn: Long,
   var point: Array[Int],
   var code: Array[Int],
   var codeLen: Int
@@ -194,7 +194,7 @@ class Word2Vec extends Serializable with Logging {
         new Array[Int](MAX_CODE_LENGTH),
         0))
       .collect()
-      .sortWith((a, b) => a.cn > b.cn)
+      .sortBy(_.cn)(Ordering[Long].reverse)
 
     vocabSize = vocab.length
     require(vocabSize > 0, "The vocabulary size should be > 0. You may need to check " +
@@ -232,7 +232,7 @@ class Word2Vec extends Serializable with Logging {
       a += 1
     }
     while (a < 2 * vocabSize) {
-      count(a) = 1e9.toInt
+      count(a) = Long.MaxValue
       a += 1
     }
     var pos1 = vocabSize - 1
@@ -267,6 +267,8 @@ class Word2Vec extends Serializable with Logging {
         min2i = pos2
         pos2 += 1
       }
+      assert(count(min1i) < Long.MaxValue)
+      assert(count(min2i) < Long.MaxValue)
       count(vocabSize + a) = count(min1i) + count(min2i)
       parentNode(min1i) = vocabSize + a
       parentNode(min2i) = vocabSize + a
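
For readers following along, below is a minimal, self-contained Scala sketch, hypothetical and not taken from the patch, of the sentinel-and-merge pattern the diff above touches. The object name, array layout, and sample counts are illustrative assumptions; the point is that with Long counts a Long.MaxValue sentinel always dominates any real word count (unlike the old 1e9.toInt sentinel), and merging two counts larger than 1e9 cannot overflow the way an Int-based count could.

    // Hypothetical sketch: mimics the sentinel-and-merge step of the
    // binary tree construction shown in the diff, using Long counts.
    object Word2VecCountSketch {
      def main(args: Array[String]): Unit = {
        // Two word counts above 1e9 -- the case the patch fixes.
        val wordCounts: Array[Long] = Array(1500000000L, 1200000000L)
        val vocabSize = wordCounts.length

        // Real counts first, then sentinel slots for internal tree nodes
        // that have not been assigned a merged count yet.
        val count = new Array[Long](2 * vocabSize)
        Array.copy(wordCounts, 0, count, 0, vocabSize)
        var a = vocabSize
        while (a < 2 * vocabSize) {
          count(a) = Long.MaxValue  // always larger than any real count
          a += 1
        }

        // Merge the two smallest real counts into an internal node.
        val Array(min1, min2) = count.take(vocabSize).sorted.take(2)
        assert(min1 < Long.MaxValue && min2 < Long.MaxValue)  // real counts only
        count(vocabSize) = min1 + min2  // 2700000000, which would overflow Int

        println(s"merged count           = ${count(vocabSize)}")  // 2700000000
        println(s"old Int sentinel (1e9) = ${1e9.toInt}")         // 1000000000
      }
    }

The patch follows the same idea: the sentinel becomes Long.MaxValue so no realistic count can reach it, and the added assertions fail fast if a sentinel slot were ever picked as a minimum, instead of silently building a corrupt tree.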


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
