GitHub user akopich commented on a diff in the pull request:

    https://github.com/apache/spark/pull/18924#discussion_r142625490
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala ---
    @@ -462,36 +462,55 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
         val expElogbetaBc = batch.sparkContext.broadcast(expElogbeta)
         val alpha = this.alpha.asBreeze
         val gammaShape = this.gammaShape
    +    val optimizeDocConcentration = this.optimizeDocConcentration
    +    // We calculate logphat in the same pass as other statistics, but we 
only need
    +    // it if we are optimizing docConcentration
    +    val logphatPartOptionBase = () => if (optimizeDocConcentration) 
Some(BDV.zeros[Double](k))
    +                                      else None
     
    -    val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions 
{ docs =>
    +    val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = 
batch.mapPartitions { docs =>
           val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0)
     
           val stat = BDM.zeros[Double](k, vocabSize)
    -      var gammaPart = List[BDV[Double]]()
    +      val logphatPartOption = logphatPartOptionBase()
    +      var nonEmptyDocCount : Long = 0L
           nonEmptyDocs.foreach { case (_, termCounts: Vector) =>
    +        nonEmptyDocCount += 1
             val (gammad, sstats, ids) = 
OnlineLDAOptimizer.variationalTopicInference(
               termCounts, expElogbetaBc.value, alpha, gammaShape, k)
    -        stat(::, ids) := stat(::, ids).toDenseMatrix + sstats
    -        gammaPart = gammad :: gammaPart
    +        stat(::, ids) := stat(::, ids) + sstats
    +        logphatPartOption.foreach(_ += 
LDAUtils.dirichletExpectation(gammad))
           }
    -      Iterator((stat, gammaPart))
    -    }.persist(StorageLevel.MEMORY_AND_DISK)
    -    val statsSum: BDM[Double] = 
stats.map(_._1).treeAggregate(BDM.zeros[Double](k, vocabSize))(
    -      _ += _, _ += _)
    -    val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat(
    -      stats.map(_._2).flatMap(list => 
list).collect().map(_.toDenseMatrix): _*)
    -    stats.unpersist()
    -    expElogbetaBc.destroy(false)
    -    val batchResult = statsSum *:* expElogbeta.t
    +      Iterator((stat, logphatPartOption, nonEmptyDocCount))
    +    }
    +
    +    val elementWiseSum = (u : (BDM[Double], Option[BDV[Double]], Long),
    +                                 v : (BDM[Double], Option[BDV[Double]], 
Long)) => {
    +      u._1 += v._1
    +      u._2.foreach(_ += v._2.get)
    +      (u._1, u._2, u._3 + v._3)
    +    }
    +
    +    val (statsSum: BDM[Double], logphatOption: Option[BDV[Double]], 
nonEmptyDocsN : Long) = stats
    +      .treeAggregate((BDM.zeros[Double](k, vocabSize), 
logphatPartOptionBase(), 0L))(
    +        elementWiseSum, elementWiseSum
    +      )
     
    +    val batchResult = statsSum *:* expElogbeta.t
         // Note that this is an optimization to avoid batch.count
    -    updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt)
    -    if (optimizeDocConcentration) updateAlpha(gammat)
    +    val batchSize = (miniBatchFraction * corpusSize).ceil.toInt
    +    updateLambda(batchResult, batchSize)
    +
    +    logphatOption.foreach(_ /= batchSize.toDouble)
    +    logphatOption.foreach(updateAlpha(_, nonEmptyDocsN))
    +
    +    expElogbetaBc.destroy(false)
    +
         this
       }
     
       /**
    -   * Update lambda based on the batch submitted. batchSize can be 
different for each iteration.
    +   * Update lambda based on the batch submitted. nonEmptyDocsN can be 
different for each iteration.
    --- End diff --
    
    Thanks. Comment reverted. 


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to