Github user sethah commented on a diff in the pull request:

    https://github.com/apache/spark/pull/17673#discussion_r143048261
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/ml/feature/Word2VecCBOWSolver.scala ---
    @@ -0,0 +1,344 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.feature
    +
    +import com.github.fommil.netlib.BLAS.{getInstance => blas}
    +
    +import org.apache.spark.internal.Logging
    +import org.apache.spark.mllib.feature
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.util.random.XORShiftRandom
    +
    +object Word2VecCBOWSolver extends Logging {
    +  // learning rate is updated for every batch of size batchSize
    +  private val batchSize = 10000
    +
    +  // exponent to which the unigram distribution is raised
    +  private val power = 0.75
    +
    +  private val MAX_EXP = 6
    +
    +  case class Vocabulary(
    +    totalWordCount: Long,
    +    vocabMap: Map[String, Int],
    +    unigramTable: Array[Int],
    +    samplingTable: Array[Float])
    +
    +  /**
    +   * This method implements the Word2Vec Continuous Bag Of Words (CBOW) model with
    +   * negative sampling optimization, using BLAS to vectorize operations where applicable.
    +   * The algorithm is parallelized in the same way as the skip-gram based 
estimation.
    +   * We divide input data into N equally sized random partitions.
    +   * We then generate initial weights and broadcast them to the N 
partitions. This way
    +   * all the partitions start with the same initial weights. We then run N 
independent
    +   * estimations that each estimate a model on a partition. The weights learned
    +   * from each of the N models are averaged and then rebroadcast.
    +   * This process is repeated `maxIter` number of times.
    +   *
    +   * @param input An RDD of sentences, where each sentence is an iterable of word strings.
    +   * @return Estimated word2vec model
    +   */
    +  def fitCBOW[S <: Iterable[String]](
    +      word2Vec: Word2Vec,
    +      input: RDD[S]): feature.Word2VecModel = {
    +
    +    val negativeSamples = word2Vec.getNegativeSamples
    +    val sample = word2Vec.getSample
    +
    +    val Vocabulary(totalWordCount, vocabMap, uniTable, sampleTable) =
    +      generateVocab(input, word2Vec.getMinCount, sample, 
word2Vec.getUnigramTableSize)
    +    val vocabSize = vocabMap.size
    +
    +    assert(negativeSamples < vocabSize, s"Vocab size ($vocabSize) cannot 
be smaller" +
    +      s" than negative samples($negativeSamples)")
    +
    +    val seed = word2Vec.getSeed
    +    val initRandom = new XORShiftRandom(seed)
    +
    +    val vectorSize = word2Vec.getVectorSize
    +    val syn0Global = Array.fill(vocabSize * 
vectorSize)(initRandom.nextFloat - 0.5f)
    +    val syn1Global = Array.fill(vocabSize * vectorSize)(0.0f)
    +
    +    val sc = input.context
    +
    +    val vocabMapBroadcast = sc.broadcast(vocabMap)
    +    val unigramTableBroadcast = sc.broadcast(uniTable)
    +    val sampleTableBroadcast = sc.broadcast(sampleTable)
    +
    +    val windowSize = word2Vec.getWindowSize
    +    val maxSentenceLength = word2Vec.getMaxSentenceLength
    +    val numPartitions = word2Vec.getNumPartitions
    +
    +    val digitSentences = input.flatMap { sentence =>
    +      val wordIndexes = sentence.flatMap(vocabMapBroadcast.value.get)
    +      wordIndexes.grouped(maxSentenceLength).map(_.toArray)
    +    }.repartition(numPartitions).cache()
    +
    +    val learningRate = word2Vec.getStepSize
    +
    +    val wordsPerPartition = totalWordCount / numPartitions
    +
    +    logInfo(s"VocabSize: ${vocabMap.size}, TotalWordCount: 
$totalWordCount")
    +
    +    val maxIter = word2Vec.getMaxIter
    +    for {iteration <- 1 to maxIter} {
    +      logInfo(s"Starting iteration: $iteration")
    +      val iterationStartTime = System.nanoTime()
    +
    +      val syn0bc = sc.broadcast(syn0Global)
    +      val syn1bc = sc.broadcast(syn1Global)
    +
    +      val partialFits = digitSentences.mapPartitionsWithIndex { case (i_, 
iter) =>
    +        logInfo(s"Iteration: $iteration, Partition: $i_")
    +        val random = new XORShiftRandom(seed ^ ((i_ + 1) << 16) ^ 
((-iteration - 1) << 8))
    +        val contextWordPairs = iter.flatMap { s =>
    +          val doSample = sample > Double.MinPositiveValue
    +          generateContextWordPairs(s, windowSize, doSample, 
sampleTableBroadcast.value, random)
    +        }
    +
    +        val groupedBatches = contextWordPairs.grouped(batchSize)
    +
    +        val negLabels = 1.0f +: Array.fill(negativeSamples)(0.0f)
    +        val syn0 = syn0bc.value
    +        val syn1 = syn1bc.value
    +        val unigramTable = unigramTableBroadcast.value
    +
    +        // initialize intermediate arrays
    +        val contextVec = new Array[Float](vectorSize)
    +        val l2Vectors = new Array[Float](vectorSize * (negativeSamples + 
1))
    +        val gb = new Array[Float](negativeSamples + 1)
    +        val neu1e = new Array[Float](vectorSize)
    +        val wordIndices = new Array[Int](negativeSamples + 1)
    +
    +        val time = System.nanoTime
    +        var batchTime = System.nanoTime
    +        var idx = -1L
    +        for (batch <- groupedBatches) {
    +          idx = idx + 1
    +
    +          val wordRatio =
    +            idx.toFloat * batchSize /
    +              (maxIter * (wordsPerPartition.toFloat + 1)) + ((iteration - 
1).toFloat / maxIter)
    +          val alpha = math.max(learningRate * 0.0001, learningRate * (1 - 
wordRatio)).toFloat
    +
    +          if(idx % 10 == 0 && idx > 0) {
    +            logInfo(s"Partition: $i_, wordRatio = $wordRatio, alpha = 
$alpha")
    +            val wordCount = batchSize * idx
    +            val timeTaken = (System.nanoTime - time) / 1e6
    +            val batchWordCount = 10 * batchSize
    +            val currentBatchTime = (System.nanoTime - batchTime) / 1e6
    +            batchTime = System.nanoTime
    +            logDebug(s"Partition: $i_, Batch time: $currentBatchTime ms, 
batch speed: " +
    +              s"${batchWordCount / currentBatchTime * 1000} words/s")
    +            logDebug(s"Partition: $i_, Cumulative time: $timeTaken ms, 
cumulative speed: " +
    +              s"${wordCount / timeTaken * 1000} words/s")
    +          }
    +
    +          val errors = for ((contextIds, word) <- batch) yield {
    +            // initialize vectors to 0
    +            zeroVector(contextVec)
    +            zeroVector(l2Vectors)
    +            zeroVector(gb)
    +            zeroVector(neu1e)
    +
    +            val scale = 1.0f / contextIds.length
    +
    +            // feed forward
    +            contextIds.foreach { c =>
    +              blas.saxpy(vectorSize, scale, syn0, c * vectorSize, 1, 
contextVec, 0, 1)
    +            }
    +
    +            generateNegativeSamples(random, word, unigramTable, 
negativeSamples, wordIndices)
    +
    +            Iterator.range(0, wordIndices.length).foreach { i =>
    +              Array.copy(syn1, vectorSize * wordIndices(i), l2Vectors, 
vectorSize * i, vectorSize)
    +            }
    +
    +            // propagating hidden to output in batch
    +            val rows = negativeSamples + 1
    +            val cols = vectorSize
    +            blas.sgemv("T", cols, rows, 1.0f, l2Vectors, 0, cols, 
contextVec, 0, 1, 0.0f, gb, 0, 1)
    +
    +            Iterator.range(0, negativeSamples + 1).foreach { i =>
    +              if (gb(i) > -MAX_EXP && gb(i) < MAX_EXP) {
    +                val v = 1.0f / (1 + math.exp(-gb(i)).toFloat)
    +                // computing error gradient
    +                val err = (negLabels(i) - v) * alpha
    +                // update hidden -> output layer, syn1
    +                blas.saxpy(vectorSize, err, contextVec, 0, 1, syn1, 
wordIndices(i) * vectorSize, 1)
    +                // update for word vectors
    --- End diff --
    
    `// accumulate gradients for the cumulative context vector`?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to