Github user akopich commented on a diff in the pull request:

https://github.com/apache/spark/pull/18924#discussion_r142625490

--- Diff: mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala ---
@@ -462,36 +462,55 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
     val expElogbetaBc = batch.sparkContext.broadcast(expElogbeta)
     val alpha = this.alpha.asBreeze
     val gammaShape = this.gammaShape
+    val optimizeDocConcentration = this.optimizeDocConcentration
+    // We calculate logphat in the same pass as other statistics, but we only need
+    // it if we are optimizing docConcentration
+    val logphatPartOptionBase = () => if (optimizeDocConcentration) Some(BDV.zeros[Double](k))
+                                      else None
-    val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs =>
+    val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitions { docs =>
       val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0)
       val stat = BDM.zeros[Double](k, vocabSize)
-      var gammaPart = List[BDV[Double]]()
+      val logphatPartOption = logphatPartOptionBase()
+      var nonEmptyDocCount : Long = 0L
       nonEmptyDocs.foreach { case (_, termCounts: Vector) =>
+        nonEmptyDocCount += 1
         val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference(
           termCounts, expElogbetaBc.value, alpha, gammaShape, k)
-        stat(::, ids) := stat(::, ids).toDenseMatrix + sstats
-        gammaPart = gammad :: gammaPart
+        stat(::, ids) := stat(::, ids) + sstats
+        logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad))
       }
-      Iterator((stat, gammaPart))
-    }.persist(StorageLevel.MEMORY_AND_DISK)
-    val statsSum: BDM[Double] = stats.map(_._1).treeAggregate(BDM.zeros[Double](k, vocabSize))(
-      _ += _, _ += _)
-    val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat(
-      stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*)
-    stats.unpersist()
-    expElogbetaBc.destroy(false)
-    val batchResult = statsSum *:* expElogbeta.t
+      Iterator((stat, logphatPartOption, nonEmptyDocCount))
+    }
+
+    val elementWiseSum = (u : (BDM[Double], Option[BDV[Double]], Long),
+                          v : (BDM[Double], Option[BDV[Double]], Long)) => {
+      u._1 += v._1
+      u._2.foreach(_ += v._2.get)
+      (u._1, u._2, u._3 + v._3)
+    }
+
+    val (statsSum: BDM[Double], logphatOption: Option[BDV[Double]], nonEmptyDocsN : Long) = stats
+      .treeAggregate((BDM.zeros[Double](k, vocabSize), logphatPartOptionBase(), 0L))(
+        elementWiseSum, elementWiseSum
+      )
+    val batchResult = statsSum *:* expElogbeta.t
     // Note that this is an optimization to avoid batch.count
-    updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt)
-    if (optimizeDocConcentration) updateAlpha(gammat)
+    val batchSize = (miniBatchFraction * corpusSize).ceil.toInt
+    updateLambda(batchResult, batchSize)
+
+    logphatOption.foreach(_ /= batchSize.toDouble)
+    logphatOption.foreach(updateAlpha(_, nonEmptyDocsN))
+
+    expElogbetaBc.destroy(false)
+
     this
   }

   /**
-   * Update lambda based on the batch submitted. batchSize can be different for each iteration.
+   * Update lambda based on the batch submitted. nonEmptyDocsN can be different for each iteration.
--- End diff --

Thanks. Comment reverted.
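For readers following the diff, the core of the change is that per-partition results are now a `(BDM[Double], Option[BDV[Double]], Long)` tuple (sufficient statistics, optional logphat part, non-empty document count) combined with a single `elementWiseSum` passed to `treeAggregate` as both the seqOp and combOp. Below is a minimal, hedged sketch of that combiner pattern using plain Breeze and a local `foldLeft` instead of an RDD; the object name, toy dimensions, and fake partition data are illustrative only and are not part of the PR.

```scala
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV}

object ElementWiseSumSketch {
  // Each "partition" yields (sufficient stats, optional logphat part, non-empty doc count).
  type Stat = (BDM[Double], Option[BDV[Double]], Long)

  // In-place combiner mirroring the PR's elementWiseSum: matrix and optional vector
  // are added element-wise, counts are summed. Assumes the Option is consistently
  // Some or None on both sides, as guaranteed by logphatPartOptionBase() in the PR.
  val elementWiseSum: (Stat, Stat) => Stat = (u, v) => {
    u._1 += v._1
    u._2.foreach(_ += v._2.get)
    (u._1, u._2, u._3 + v._3)
  }

  def main(args: Array[String]): Unit = {
    val k = 3          // number of topics (illustrative)
    val vocabSize = 5  // vocabulary size (illustrative)

    // Two fake per-partition results standing in for RDD partitions.
    val partitionStats = Seq(
      (BDM.ones[Double](k, vocabSize), Some(BDV.ones[Double](k)), 2L),
      (BDM.ones[Double](k, vocabSize), Some(BDV.ones[Double](k)), 3L)
    )

    // Locally this is a foldLeft; on an RDD the same function is passed twice
    // to treeAggregate (as seqOp and combOp), exactly as in the diff above.
    val zero: Stat = (BDM.zeros[Double](k, vocabSize), Some(BDV.zeros[Double](k)), 0L)
    val (statsSum, logphatOption, nonEmptyDocsN) =
      partitionStats.foldLeft(zero)(elementWiseSum)

    println(s"nonEmptyDocsN = $nonEmptyDocsN")     // 5
    println(s"logphat = ${logphatOption.get}")     // DenseVector(2.0, 2.0, 2.0)
    println(s"statsSum(0,0) = ${statsSum(0, 0)}")  // 2.0
  }
}
```

Note that the combiner mutates its left argument, which is fine for `treeAggregate` since the zero value is materialized per task; whether logphat is accumulated at all is decided once up front by the optional zero vector, so no per-document branching on `optimizeDocConcentration` is needed.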