Github user srowen commented on a diff in the pull request:

    https://github.com/apache/spark/pull/14299#discussion_r71703375
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala ---
    @@ -313,133 +313,139 @@ class Word2Vec extends Serializable with Logging {
         val expTable = sc.broadcast(createExpTable())
         val bcVocab = sc.broadcast(vocab)
         val bcVocabHash = sc.broadcast(vocabHash)
    -    // each partition is a collection of sentences,
    -    // will be translated into arrays of Index integer
    -    val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>
    -      // Each sentence will map to 0 or more Array[Int]
    -      sentenceIter.flatMap { sentence =>
    -        // Sentence of words, some of which map to a word index
    -        val wordIndexes = sentence.flatMap(bcVocabHash.value.get)
    -        // break wordIndexes into trunks of maxSentenceLength when has more
    -        wordIndexes.grouped(maxSentenceLength).map(_.toArray)
    +
    +    try {
    +      // each partition is a collection of sentences,
    +      // will be translated into arrays of Index integer
    +      val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>
    +        // Each sentence will map to 0 or more Array[Int]
    +        sentenceIter.flatMap { sentence =>
    +          // Sentence of words, some of which map to a word index
    +          val wordIndexes = sentence.flatMap(bcVocabHash.value.get)
    +          // break wordIndexes into trunks of maxSentenceLength when has more
    +          wordIndexes.grouped(maxSentenceLength).map(_.toArray)
    +        }
           }
    -    }
     
    -    val newSentences = sentences.repartition(numPartitions).cache()
    -    val initRandom = new XORShiftRandom(seed)
    +      val newSentences = sentences.repartition(numPartitions).cache()
    +      val initRandom = new XORShiftRandom(seed)
     
    -    if (vocabSize.toLong * vectorSize >= Int.MaxValue) {
    -      throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" +
    -        " to avoid an OOM. You are highly recommended to make your vocabSize*vectorSize, " +
    -        "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue`.")
    -    }
    +      if (vocabSize.toLong * vectorSize >= Int.MaxValue) {
    +        throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" +
    +          " to avoid an OOM. You are highly recommended to make your vocabSize*vectorSize, " +
    +          "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue`.")
    +      }
     
    -    val syn0Global =
    -      Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
    -    val syn1Global = new Array[Float](vocabSize * vectorSize)
    -    var alpha = learningRate
    -
    -    for (k <- 1 to numIterations) {
    -      val bcSyn0Global = sc.broadcast(syn0Global)
    -      val bcSyn1Global = sc.broadcast(syn1Global)
    -      val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
    -        val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
    -        val syn0Modify = new Array[Int](vocabSize)
    -        val syn1Modify = new Array[Int](vocabSize)
    -        val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0L, 0L)) {
    -          case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
    -            var lwc = lastWordCount
    -            var wc = wordCount
    -            if (wordCount - lastWordCount > 10000) {
    -              lwc = wordCount
    -              // TODO: discount by iteration?
    -              alpha =
    -                learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
    -              if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001
    -              logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
    -            }
    -            wc += sentence.length
    -            var pos = 0
    -            while (pos < sentence.length) {
    -              val word = sentence(pos)
    -              val b = random.nextInt(window)
    -              // Train Skip-gram
    -              var a = b
    -              while (a < window * 2 + 1 - b) {
    -                if (a != window) {
    -                  val c = pos - window + a
    -                  if (c >= 0 && c < sentence.length) {
    -                    val lastWord = sentence(c)
    -                    val l1 = lastWord * vectorSize
    -                    val neu1e = new Array[Float](vectorSize)
    -                    // Hierarchical softmax
    -                    var d = 0
    -                    while (d < bcVocab.value(word).codeLen) {
    -                      val inner = bcVocab.value(word).point(d)
    -                      val l2 = inner * vectorSize
    -                      // Propagate hidden -> output
    -                      var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
    -                      if (f > -MAX_EXP && f < MAX_EXP) {
    -                        val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
    -                        f = expTable.value(ind)
    -                        val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
    -                        blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
    -                        blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
    -                        syn1Modify(inner) += 1
    +      val syn0Global =
    +        Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
    +      val syn1Global = new Array[Float](vocabSize * vectorSize)
    +      var alpha = learningRate
    +
    +      for (k <- 1 to numIterations) {
    +        val bcSyn0Global = sc.broadcast(syn0Global)
    +        val bcSyn1Global = sc.broadcast(syn1Global)
    +        val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
    +          val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
    +          val syn0Modify = new Array[Int](vocabSize)
    +          val syn1Modify = new Array[Int](vocabSize)
    +          val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0L, 0L)) {
    +            case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
    +              var lwc = lastWordCount
    +              var wc = wordCount
    +              if (wordCount - lastWordCount > 10000) {
    +                lwc = wordCount
    +                // TODO: discount by iteration?
    +                alpha =
    +                  learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
    +                if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001
    +                logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
    +              }
    +              wc += sentence.length
    +              var pos = 0
    +              while (pos < sentence.length) {
    +                val word = sentence(pos)
    +                val b = random.nextInt(window)
    +                // Train Skip-gram
    +                var a = b
    +                while (a < window * 2 + 1 - b) {
    +                  if (a != window) {
    +                    val c = pos - window + a
    +                    if (c >= 0 && c < sentence.length) {
    +                      val lastWord = sentence(c)
    +                      val l1 = lastWord * vectorSize
    +                      val neu1e = new Array[Float](vectorSize)
    +                      // Hierarchical softmax
    +                      var d = 0
    +                      while (d < bcVocab.value(word).codeLen) {
    +                        val inner = bcVocab.value(word).point(d)
    +                        val l2 = inner * vectorSize
    +                        // Propagate hidden -> output
    +                        var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
    +                        if (f > -MAX_EXP && f < MAX_EXP) {
    +                          val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
    +                          f = expTable.value(ind)
    +                          val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
    +                          blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
    +                          blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
    +                          syn1Modify(inner) += 1
    +                        }
    +                        d += 1
                           }
    -                      d += 1
    +                      blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
    +                      syn0Modify(lastWord) += 1
                         }
    -                    blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
    -                    syn0Modify(lastWord) += 1
                       }
    +                  a += 1
                     }
    -                a += 1
    +                pos += 1
                   }
    -              pos += 1
    +              (syn0, syn1, lwc, wc)
    +          }
    +          val syn0Local = model._1
    +          val syn1Local = model._2
    +          // Only output modified vectors.
    +          Iterator.tabulate(vocabSize) { index =>
    +            if (syn0Modify(index) > 0) {
    +              Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
    +            } else {
    +              None
    +            }
    +          }.flatten ++ Iterator.tabulate(vocabSize) { index =>
    +            if (syn1Modify(index) > 0) {
    +              Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
    +            } else {
    +              None
                 }
    -            (syn0, syn1, lwc, wc)
    +          }.flatten
             }
    -        val syn0Local = model._1
    -        val syn1Local = model._2
    -        // Only output modified vectors.
    -        Iterator.tabulate(vocabSize) { index =>
    -          if (syn0Modify(index) > 0) {
    -            Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
    -          } else {
    -            None
    -          }
    -        }.flatten ++ Iterator.tabulate(vocabSize) { index =>
    -          if (syn1Modify(index) > 0) {
    -            Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
    -          } else {
    -            None
    -          }
    -        }.flatten
    -      }
    -      val synAgg = partial.reduceByKey { case (v1, v2) =>
    +        val synAgg = partial.reduceByKey { case (v1, v2) =>
               blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
               v1
    -      }.collect()
    -      var i = 0
    -      while (i < synAgg.length) {
    -        val index = synAgg(i)._1
    -        if (index < vocabSize) {
    -          Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
    -        } else {
    -          Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
    +        }.collect()
    +        var i = 0
    +        while (i < synAgg.length) {
    +          val index = synAgg(i)._1
    +          if (index < vocabSize) {
    +            Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
    +          } else {
    +            Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
    +          }
    +          i += 1
             }
    -        i += 1
    +        bcSyn0Global.unpersist(false)
    +        bcSyn1Global.unpersist(false)
           }
    -      bcSyn0Global.unpersist(false)
    -      bcSyn1Global.unpersist(false)
    -    }
    -    newSentences.unpersist()
    -    expTable.destroy()
    -    bcVocab.destroy()
    -    bcVocabHash.destroy()
    +      newSentences.unpersist()
     
    -    val wordArray = vocab.map(_.word)
    -    new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)
    +      val wordArray = vocab.map(_.word)
    +      new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)
    +    }
    +    finally
    --- End diff --
    
    Not sure this will pass the style checker. My only hesitation is that if we do this in one place, we should really do it in 100 places, and the exception path here is not a usual one.
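
    For context: the try/finally being discussed guarantees that the broadcast variables (expTable, bcVocab, bcVocabHash) are destroyed even if training throws. A minimal sketch of how that cleanup could be factored out once instead of being repeated at every call site; `withBroadcasts` is a hypothetical helper, not existing Spark API:

        import org.apache.spark.broadcast.Broadcast

        // Hypothetical helper: evaluate `body`, then destroy the given
        // broadcast variables whether or not `body` threw.
        def withBroadcasts[T](broadcasts: Broadcast[_]*)(body: => T): T = {
          try {
            body
          } finally {
            broadcasts.foreach(_.destroy())
          }
        }

    Under those assumptions, fit() could end with something like
    `withBroadcasts(expTable, bcVocab, bcVocabHash) { ...training... }`, keeping
    the cleanup concern in one place rather than 100.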

