Repository: spark
Updated Branches:
  refs/heads/master 9693b0d5a -> a0af0e351


[SPARK-11898][MLLIB] Use broadcast for the global tables in Word2Vec

jira: https://issues.apache.org/jira/browse/SPARK-11898
syn0Global and syn1Global in Word2Vec are large objects, with a combined size of
(vocabSize * vectorSize * 8) bytes, yet they are passed to workers using basic
task serialization.
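As a quick check of that size formula, the sketch below computes the combined
payload of the two tables. It assumes only what the diff shows (both tables are
Array[Float] with vocabSize * vectorSize elements) plus the 4-byte JVM Float, and
it ignores JVM array headers. GlobalTableSize and combinedBytes are hypothetical
names used for illustration, not part of Spark.

```scala
// Back-of-the-envelope size of the two global weight tables in Word2Vec.
// Assumption: syn0Global and syn1Global are each Array[Float] with
// vocabSize * vectorSize elements, and a Float occupies 4 bytes.
// JVM array headers are ignored.
object GlobalTableSize {
  val BytesPerFloat: Long = 4L

  /** Combined payload size in bytes of syn0Global and syn1Global (hypothetical helper). */
  def combinedBytes(vocabSize: Long, vectorSize: Long): Long =
    2L * vocabSize * vectorSize * BytesPerFloat

  def main(args: Array[String]): Unit = {
    // Example configuration: a 1M-word vocabulary with vectorSize = 100.
    val bytes = combinedBytes(vocabSize = 1000000L, vectorSize = 100L)
    println(s"$bytes bytes (~${bytes / 1000000L} MB)") // prints "800000000 bytes (~800 MB)"
  }
}
```

Long arithmetic is used so the product cannot overflow for large vocabularies.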

Using broadcast variables can greatly improve performance. My benchmark shows
that, for a 1M-word vocabulary and the default vectorSize of 100, switching to
broadcast:

1. decreases worker memory consumption by 45%.
2. decreases running time by 40%.

This will also help raise the upper limit on the model size Word2Vec can handle.
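A rough model of why broadcasting reduces worker memory: a value captured in a
task closure is deserialized separately by every task, whereas a broadcast value
is cached once per executor and read by all tasks running there. The sketch below
assumes that simplified model and a hypothetical executor running 4 tasks
concurrently. It is not a reproduction of the benchmark above, whose numbers also
depend on the rest of each task's working set; ExecutorFootprint and its methods
are hypothetical names.

```scala
// Simplified model (an assumption, not a measurement) of how many copies of
// the global tables occupy one executor JVM during a Word2Vec iteration.
object ExecutorFootprint {
  /** Closure capture: every running task deserializes its own copy of the tables. */
  def closureBytes(tableBytes: Long, concurrentTasks: Int): Long =
    tableBytes * concurrentTasks

  /** Broadcast: tasks on the same executor share one cached deserialized value. */
  def broadcastBytes(tableBytes: Long, concurrentTasks: Int): Long =
    if (concurrentTasks > 0) tableBytes else 0L

  def main(args: Array[String]): Unit = {
    val tableBytes = 800000000L // syn0Global + syn1Global for 1M words x 100 dims
    val tasks = 4               // hypothetical: 4 concurrent tasks per executor
    println(s"closure:   ${closureBytes(tableBytes, tasks)} bytes")
    println(s"broadcast: ${broadcastBytes(tableBytes, tasks)} bytes")
  }
}
```

Under this model the closure footprint grows linearly with the number of
concurrent tasks per executor while the broadcast footprint stays constant,
which is consistent with broadcasting letting larger models fit in executor
memory.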

Author: Yuhao Yang <hhb...@gmail.com>

Closes #9878 from hhbyyh/w2vBC.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a0af0e35
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a0af0e35
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a0af0e35

Branch: refs/heads/master
Commit: a0af0e351e45a8be47a6f65efd132eaa4a00c9e4
Parents: 9693b0d
Author: Yuhao Yang <hhb...@gmail.com>
Authored: Tue Dec 1 09:26:58 2015 +0000
Committer: Sean Owen <so...@cloudera.com>
Committed: Tue Dec 1 09:26:58 2015 +0000

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/mllib/feature/Word2Vec.scala  | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a0af0e35/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index a47f27b..655ac0b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -316,12 +316,15 @@ class Word2Vec extends Serializable with Logging {
       Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
     val syn1Global = new Array[Float](vocabSize * vectorSize)
     var alpha = learningRate
+
     for (k <- 1 to numIterations) {
+      val bcSyn0Global = sc.broadcast(syn0Global)
+      val bcSyn1Global = sc.broadcast(syn1Global)
       val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
         val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
         val syn0Modify = new Array[Int](vocabSize)
         val syn1Modify = new Array[Int](vocabSize)
-        val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
+        val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0, 0)) {
           case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
             var lwc = lastWordCount
             var wc = wordCount
@@ -405,6 +408,8 @@ class Word2Vec extends Serializable with Logging {
         }
         i += 1
       }
+      bcSyn0Global.unpersist(false)
+      bcSyn1Global.unpersist(false)
     }
     newSentences.unpersist()
 

