Repository: spark
Updated Branches:
  refs/heads/master 1f25d8683 -> 52facb006
[SPARK-14371][MLLIB] OnlineLDAOptimizer should not collect stats for each doc in mini-batch to driver

# What changes were proposed in this pull request?

As proposed by jkbradley, `gammat` is no longer collected to the driver. The per-partition
statistics (and, when `optimizeDocConcentration` is set, the `logphat` contribution) are instead
combined on the executors with `treeAggregate`, so driver memory no longer scales with the number
of documents in the mini-batch.

# How was this patch tested?

Existing test suite.

Author: Valeriy Avanesov <avane...@wias-berlin.de>
Author: Valeriy Avanesov <acop...@gmail.com>

Closes #18924 from akopich/master.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/52facb00
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/52facb00
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/52facb00

Branch: refs/heads/master
Commit: 52facb0062a4253fa45ac0c633d0510a9b684a62
Parents: 1f25d86
Author: Valeriy Avanesov <avane...@wias-berlin.de>
Authored: Wed Oct 18 10:46:46 2017 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Wed Oct 18 10:46:46 2017 -0700

----------------------------------------------------------------------
 .../spark/mllib/clustering/LDAOptimizer.scala | 82 ++++++++++++++------
 1 file changed, 57 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/52facb00/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index d633893..693a2a3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -26,6 +26,7 @@ import breeze.stats.distributions.{Gamma, RandBasis}
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.graphx._
 import org.apache.spark.graphx.util.PeriodicGraphCheckpointer
+import org.apache.spark.internal.Logging
 import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
@@ -259,7 +260,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
  */
 @Since("1.4.0")
 @DeveloperApi
-final class OnlineLDAOptimizer extends LDAOptimizer {
+final class OnlineLDAOptimizer extends LDAOptimizer with Logging {
 
   // LDA common parameters
   private var k: Int = 0
@@ -462,31 +463,61 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
     val expElogbetaBc = batch.sparkContext.broadcast(expElogbeta)
     val alpha = this.alpha.asBreeze
     val gammaShape = this.gammaShape
-
-    val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs =>
+    val optimizeDocConcentration = this.optimizeDocConcentration
+    // If and only if optimizeDocConcentration is set true,
+    // we calculate logphat in the same pass as other statistics.
+    // No calculation of logphat happens otherwise.
+    val logphatPartOptionBase = () => if (optimizeDocConcentration) {
+      Some(BDV.zeros[Double](k))
+    } else {
+      None
+    }
+
+    val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitions { docs =>
       val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0)
 
       val stat = BDM.zeros[Double](k, vocabSize)
-      var gammaPart = List[BDV[Double]]()
+      val logphatPartOption = logphatPartOptionBase()
+      var nonEmptyDocCount: Long = 0L
       nonEmptyDocs.foreach { case (_, termCounts: Vector) =>
+        nonEmptyDocCount += 1
         val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference(
           termCounts, expElogbetaBc.value, alpha, gammaShape, k)
-        stat(::, ids) := stat(::, ids).toDenseMatrix + sstats
-        gammaPart = gammad :: gammaPart
+        stat(::, ids) := stat(::, ids) + sstats
+        logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad))
       }
-      Iterator((stat, gammaPart))
-    }.persist(StorageLevel.MEMORY_AND_DISK)
-    val statsSum: BDM[Double] = stats.map(_._1).treeAggregate(BDM.zeros[Double](k, vocabSize))(
-      _ += _, _ += _)
-    val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat(
-      stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*)
-    stats.unpersist()
+      Iterator((stat, logphatPartOption, nonEmptyDocCount))
+    }
+
+    val elementWiseSum = (
+        u: (BDM[Double], Option[BDV[Double]], Long),
+        v: (BDM[Double], Option[BDV[Double]], Long)) => {
+      u._1 += v._1
+      u._2.foreach(_ += v._2.get)
+      (u._1, u._2, u._3 + v._3)
+    }
+
+    val (statsSum: BDM[Double], logphatOption: Option[BDV[Double]], nonEmptyDocsN: Long) = stats
+      .treeAggregate((BDM.zeros[Double](k, vocabSize), logphatPartOptionBase(), 0L))(
+        elementWiseSum, elementWiseSum
+      )
+
     expElogbetaBc.destroy(false)
 
-    val batchResult = statsSum *:* expElogbeta.t
+    if (nonEmptyDocsN == 0) {
+      logWarning("No non-empty documents were submitted in the batch.")
+      // Therefore, there is no need to update any of the model parameters
+      return this
+    }
+
+    val batchResult = statsSum *:* expElogbeta.t
     // Note that this is an optimization to avoid batch.count
-    updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt)
-    if (optimizeDocConcentration) updateAlpha(gammat)
+    val batchSize = (miniBatchFraction * corpusSize).ceil.toInt
+    updateLambda(batchResult, batchSize)
+
+    logphatOption.foreach(_ /= nonEmptyDocsN.toDouble)
+    logphatOption.foreach(updateAlpha(_, nonEmptyDocsN))
+
     this
   }
 
@@ -503,21 +534,22 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
   }
 
   /**
-   * Update alpha based on `gammat`, the inferred topic distributions for documents in the
-   * current mini-batch. Uses Newton-Raphson method.
+   * Update alpha based on `logphat`.
+   * Uses Newton-Raphson method.
    * @see Section 3.3, Huang: Maximum Likelihood Estimation of Dirichlet Distribution Parameters
    *      (http://jonathan-huang.org/research/dirichlet/dirichlet.pdf)
+   * @param logphat Expectation of estimated log-posterior distribution of
+   *                topics in a document averaged over the batch.
+   * @param nonEmptyDocsN number of non-empty documents
    */
-  private def updateAlpha(gammat: BDM[Double]): Unit = {
+  private def updateAlpha(logphat: BDV[Double], nonEmptyDocsN: Double): Unit = {
     val weight = rho()
-    val N = gammat.rows.toDouble
     val alpha = this.alpha.asBreeze.toDenseVector
 
-    val logphat: BDV[Double] =
-      sum(LDAUtils.dirichletExpectation(gammat)(::, breeze.linalg.*)).t / N
-    val gradf = N * (-LDAUtils.dirichletExpectation(alpha) + logphat)
-
-    val c = N * trigamma(sum(alpha))
-    val q = -N * trigamma(alpha)
+    val gradf = nonEmptyDocsN * (-LDAUtils.dirichletExpectation(alpha) + logphat)
+
+    val c = nonEmptyDocsN * trigamma(sum(alpha))
+    val q = -nonEmptyDocsN * trigamma(alpha)
 
     val b = sum(gradf / q) / (1D / c + sum(1D / q))
 
     val dalpha = -(gradf - b) / q
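The heart of the patch is the switch from collecting per-document `gammad` vectors to the driver
to a single `treeAggregate` over per-partition accumulators. Below is a minimal, self-contained
sketch of that pattern, not the patch itself; `k`, `vocabSize`, `docStats`, and
`TreeAggregateSketch` are illustrative stand-ins, not names from the commit:

```scala
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

object TreeAggregateSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("tree-aggregate-sketch"))

    val k = 3          // number of topics (illustrative)
    val vocabSize = 5  // vocabulary size (illustrative)

    // Toy per-document statistics: a k x vocabSize sufficient-statistics
    // matrix plus a length-k vector, standing in for (sstats, gammad).
    val docStats: RDD[(BDM[Double], BDV[Double])] =
      sc.parallelize(Seq.fill(8)((BDM.ones[Double](k, vocabSize), BDV.ones[Double](k))))

    // Driver-side anti-pattern (what the patch removes): pulling every
    // per-document vector to the driver costs O(numDocs * k) driver memory.
    //   val gammat = docStats.map(_._2).collect()

    // Executor-side pattern (what the patch adopts): fold each partition's
    // documents into one accumulator, then merge accumulators in a tree.
    // In-place += on the accumulator is safe because treeAggregate hands
    // each partition its own deserialized copy of the zero value.
    val zero = (BDM.zeros[Double](k, vocabSize), BDV.zeros[Double](k), 0L)
    val (statSum, vecSum, docCount) = docStats.treeAggregate(zero)(
      seqOp = (acc, doc) => { acc._1 += doc._1; acc._2 += doc._2; (acc._1, acc._2, acc._3 + 1L) },
      combOp = (a, b) => { a._1 += b._1; a._2 += b._2; (a._1, a._2, a._3 + b._3) })

    println(s"aggregated $docCount docs; driver received one k x vocabSize matrix plus a k-vector")
    println(s"vector sum = $vecSum; stat sum is ${statSum.rows} x ${statSum.cols}")
    sc.stop()
  }
}
```

With this shape the driver receives one (matrix, vector, count) tuple per merge rather than one
vector per document, and the non-empty documents get counted in the same pass instead of via a
separate `batch.count` action.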
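On the `updateAlpha` side, the new code reads as the Newton-Raphson step for Dirichlet
maximum-likelihood estimation from Section 3.3 of the Huang paper cited in the doc comment. The
mapping below is my reading of the code, not text from the commit: write Ψ and Ψ' for the digamma
and trigamma functions, N for `nonEmptyDocsN`, and note that `LDAUtils.dirichletExpectation(alpha)`
computes Ψ(α_k) − Ψ(Σ_j α_j):

```latex
\begin{aligned}
g_k &= N\Big(\Psi\big(\textstyle\sum_j \alpha_j\big) - \Psi(\alpha_k)
      + \overline{\log \hat p}_k\Big) &&\quad\texttt{gradf}\\
c &= N\,\Psi'\big(\textstyle\sum_j \alpha_j\big), \qquad
    q_k = -N\,\Psi'(\alpha_k) &&\quad\texttt{c, q}\\
b &= \frac{\sum_k g_k/q_k}{1/c + \sum_k 1/q_k} &&\quad\texttt{b}\\
\Delta\alpha_k &= -\,\frac{g_k - b}{q_k} &&\quad\texttt{dalpha}
\end{aligned}
```

Because `logphat` now arrives as a batch average (it is divided by `nonEmptyDocsN` before the
call), multiplying by N restores the batch-total gradient that the old `gammat`-based code
computed, so only a k-vector and a count ever need to reach the driver.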