albertusk95 commented on a change in pull request #17673: [SPARK-20372] [ML] Word2Vec Continuous Bag of Words model
URL: https://github.com/apache/spark/pull/17673#discussion_r305571436
 
 

 ##########
 File path: mllib/src/main/scala/org/apache/spark/ml/feature/impl/Word2VecCBOWSolver.scala
 ##########
 @@ -0,0 +1,371 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature.impl
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.feature.Word2Vec
+import org.apache.spark.mllib.feature
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.random.XORShiftRandom
+
+private[feature] object Word2VecCBOWSolver extends Logging {
+  // learning rate is updated for every batch of size batchSize
+  private val batchSize = 10000
+
+  // exponent used to flatten the unigram distribution when building the negative-sampling table
+  private val power = 0.75
+
+  // maximum magnitude for arguments to the sigmoid/exp approximation, as in the original word2vec
+  private val MAX_EXP = 6
+
+  case class Vocabulary(
+    totalWordCount: Long,
+    vocabMap: Map[String, Int],
+    unigramTable: Array[Int],
+    samplingTable: Array[Float])
+
+  /**
+   * This method implements the Word2Vec Continuous Bag of Words (CBOW) estimation with
+   * negative sampling optimization, using level 1 and level 2 BLAS to vectorize operations
+   * where applicable.
+   * The algorithm is parallelized in the same way as the skip-gram based estimation.
+   * We divide the input data into N equally sized random partitions.
+   * We then generate initial weights and broadcast them to the N partitions. This way
+   * all the partitions start with the same initial weights. We then run N independent
+   * estimations that each estimate a model on a partition. The weights learned
+   * from each of the N models are averaged, and the averaged weights are rebroadcast.
+   * This process is repeated `maxIter` times.
+   *
+   * @param input An RDD of sentences, where each sentence is an iterable of words.
+   * @return Estimated word2vec model
+   */
+  def fitCBOW[S <: Iterable[String]](
+      word2Vec: Word2Vec,
+      input: RDD[S]): feature.Word2VecModel = {
+
+    val numNegativeSamples = word2Vec.getNumNegativeSamples
+    val samplingThreshold = word2Vec.getSamplingThreshold
+
+    val Vocabulary(totalWordCount, vocabMap, uniTable, sampleTable) =
+      generateVocab(input, word2Vec.getMinCount, samplingThreshold, word2Vec.getUnigramTableSize)
+    val vocabSize = sampleTable.length
+
+    assert(numNegativeSamples < vocabSize, s"The number of negative samples " +
+      s"($numNegativeSamples) must be smaller than the vocabulary size ($vocabSize)")
+
+    val seed = word2Vec.getSeed
+    val initRandom = new XORShiftRandom(seed)
+
+    val vectorSize = word2Vec.getVectorSize
+    val syn0Global = Array.fill(vocabSize * vectorSize)(initRandom.nextFloat - 0.5f)
+    val syn1Global = Array.fill(vocabSize * vectorSize)(0.0f)
+
+    val sc = input.context
+
+    val vocabMapBroadcast = sc.broadcast(vocabMap)
+    val unigramTableBroadcast = sc.broadcast(uniTable)
+    val sampleTableBroadcast = sc.broadcast(sampleTable)
+
+    val windowSize = word2Vec.getWindowSize
+    val maxSentenceLength = word2Vec.getMaxSentenceLength
+    val numPartitions = word2Vec.getNumPartitions
+
+    val digitSentences = input.flatMap { sentence =>
+      val wordIndexes = sentence.flatMap(vocabMapBroadcast.value.get)
+      wordIndexes.grouped(maxSentenceLength).map(_.toArray)
+    }.repartition(numPartitions).cache()
+
+    val learningRate = word2Vec.getStepSize
+
+    val wordsPerPartition = totalWordCount / numPartitions
+
+    logInfo(s"VocabSize: ${vocabMap.size}, TotalWordCount: $totalWordCount")
+
+    val maxIter = word2Vec.getMaxIter
+    for (iteration <- 1 to maxIter) {
+      logInfo(s"Starting iteration: $iteration")
+      val iterationStartTime = System.nanoTime()
+
+      val syn0bc = sc.broadcast(syn0Global)
+      val syn1bc = sc.broadcast(syn1Global)
+
+      val partialFits = digitSentences.mapPartitionsWithIndex { case (partIndex, sentenceIter) =>
+        logInfo(s"Iteration: $iteration, Partition: $partIndex")
+        val random = new XORShiftRandom(seed ^ ((partIndex + 1) << 16) ^ ((-iteration - 1) << 8))
+        // subsampling of frequent words is only enabled for a positive sampling threshold
+        val doSample = samplingThreshold > Double.MinPositiveValue
+        val contextWordPairs = sentenceIter.flatMap { s =>
+          generateContextWordPairs(s, windowSize, doSample, sampleTableBroadcast.value, random)
+        }
+
+        val groupedBatches = contextWordPairs.grouped(batchSize)
+
+        // one label per training example: 1.0 for the positive word, 0.0 for each negative sample
+        val negLabels = 1.0f +: Array.fill(numNegativeSamples)(0.0f)
+        val syn0 = syn0bc.value
+        val syn1 = syn1bc.value
+        val unigramTable = unigramTableBroadcast.value
+
+        // initialize intermediate arrays
+        val contextVec = new Array[Float](vectorSize)
+        val layer2Vectors = new Array[Float](vectorSize * (numNegativeSamples + 1))
+        val errGradients = new Array[Float](numNegativeSamples + 1)
+        val layer1Updates = new Array[Float](vectorSize)
+        val trainingWords = new Array[Int](numNegativeSamples + 1)
+
+        val time = System.nanoTime()
+        var batchTime = System.nanoTime()
+
+        for ((batch, idx) <- groupedBatches.zipWithIndex) {
+          // linearly decay the learning rate as training progresses,
+          // with a floor at 0.01% of the initial rate
+          val wordRatio =
+            idx.toFloat * batchSize /
+              (maxIter * (wordsPerPartition.toFloat + 1)) + ((iteration - 1).toFloat / maxIter)
+          val alpha = math.max(learningRate * 0.0001, learningRate * (1 - wordRatio)).toFloat
+
+          if ((idx + 1) % 10 == 0) {
 
 Review comment:
   I think it'd be better to store `10` in a variable with an intuitive name, for the sake of clarity.
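
   For example, a sketch of the suggested change (the name `lossLogInterval` is only illustrative, not from the PR):

   ```scala
   // hypothetical named constant for the reporting cadence
   private val lossLogInterval = 10

   // ...inside the batch loop:
   if ((idx + 1) % lossLogInterval == 0) {
     // body as in the PR (truncated in this excerpt)
   }
   ```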

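For context, the per-iteration parameter averaging that the scaladoc above describes could be sketched roughly as below. `trainOnPartition` and all other names here are illustrative stand-ins, not code from this PR:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

// One "train N independent models, then average" round, assuming the
// per-partition estimation is factored out as `trainOnPartition`
// (a hypothetical stand-in for the real per-partition training loop).
def averagedIteration(
    sc: SparkContext,
    sentences: RDD[Array[Int]],
    weights: Array[Float],
    trainOnPartition: (Iterator[Array[Int]], Array[Float]) => Array[Float]): Array[Float] = {
  val numPartitions = sentences.getNumPartitions
  val bc = sc.broadcast(weights)  // all partitions start from the same weights
  val summed = sentences
    .mapPartitions { iter =>
      // each partition trains its own private copy of the weights
      Iterator.single(trainOnPartition(iter, bc.value.clone()))
    }
    .treeReduce { (a, b) =>
      // element-wise sum of the partial models
      var i = 0
      while (i < a.length) { a(i) += b(i); i += 1 }
      a
    }
  bc.destroy()
  // average the N partial models before they are rebroadcast
  summed.map(_ / numPartitions)
}
```

This mirrors what the `for (iteration <- 1 to maxIter)` loop above does with `syn0Global` and `syn1Global`, modulo the details elided from this excerpt.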
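Relatedly, the `unigramTable` consumed above is built by `generateVocab`, which is not shown in this excerpt. Below is a generic sketch of the standard word2vec table construction, in which each word gets a share of the table proportional to its count raised to `power` (0.75); this is the conventional technique, not necessarily the PR's exact code:

```scala
// Build a negative-sampling table: word i occupies a fraction of the
// table proportional to pow(count(i), power), as in the original word2vec.
def buildUnigramTable(counts: Array[Long], tableSize: Int, power: Double = 0.75): Array[Int] = {
  val weights = counts.map(c => math.pow(c.toDouble, power))
  val total = weights.sum
  val table = new Array[Int](tableSize)
  var word = 0
  var cumulative = weights(0) / total
  for (i <- 0 until tableSize) {
    table(i) = word
    // advance to the next word once its cumulative share is exhausted
    if (i.toDouble / tableSize > cumulative && word < counts.length - 1) {
      word += 1
      cumulative += weights(word) / total
    }
  }
  table
}
```

Drawing a negative sample is then just `table(random.nextInt(tableSize))`, so frequent words are picked more often, but with their dominance dampened by the 0.75 exponent.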