Repository: spark
Updated Branches:
  refs/heads/branch-1.5 560ec1268 -> 2beea65bf


[SPARK-9245] [MLLIB] LDA topic assignments

For each (document, term) pair, return the top topic. Note that instances of a
(doc, term) pair within a document (a.k.a. "tokens") are exchangeable, so we should
provide an estimate per (document, term) pair rather than per token.

CC: rotationsymmetry mengxr

Author: Joseph K. Bradley <jos...@databricks.com>

Closes #8329 from jkbradley/lda-topic-assignments.

(cherry picked from commit eaafe139f881d6105996373c9b11f2ccd91b5b3e)
Signed-off-by: Xiangrui Meng <m...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2beea65b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2beea65b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2beea65b

Branch: refs/heads/branch-1.5
Commit: 2beea65bfbbf4a94ad6b7ca5e4c24f59089f6099
Parents: 560ec12
Author: Joseph K. Bradley <jos...@databricks.com>
Authored: Thu Aug 20 15:01:31 2015 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Thu Aug 20 15:01:37 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/clustering/LDAModel.scala       | 51 ++++++++++++++++++--
 .../spark/mllib/clustering/LDAOptimizer.scala   |  2 +-
 .../spark/mllib/clustering/JavaLDASuite.java    |  7 +++
 .../spark/mllib/clustering/LDASuite.scala       | 21 +++++++-
 4 files changed, 74 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2beea65b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index b70e380..6bc68a4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.mllib.clustering
 
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argtopk, normalize, sum}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax, argtopk, normalize, sum}
 import breeze.numerics.{exp, lgamma}
 import org.apache.hadoop.fs.Path
 import org.json4s.DefaultFormats
@@ -438,7 +438,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
       Loader.checkSchema[Data](dataFrame.schema)
       val topics = dataFrame.collect()
       val vocabSize = topics(0).getAs[Vector](0).size
-      val k = topics.size
+      val k = topics.length
 
       val brzTopics = BDM.zeros[Double](vocabSize, k)
       topics.foreach { case Row(vec: Vector, ind: Int) =>
@@ -610,6 +610,50 @@ class DistributedLDAModel private[clustering] (
     }
   }
 
+  /**
+   * Return the top topic for each (doc, term) pair.  I.e., for each document, what is the most
+   * likely topic generating each term?
+   *
+   * @return RDD of (doc ID, assignment of top topic index for each term),
+   *         where the assignment is specified via a pair of zippable arrays
+   *         (term indices, topic indices).  Note that terms will be omitted if not present in
+   *         the document.
+   */
+  lazy val topicAssignments: RDD[(Long, Array[Int], Array[Int])] = {
+    // For reference, compare the below code with the core part of EMLDAOptimizer.next().
+    val eta = topicConcentration
+    val W = vocabSize
+    val alpha = docConcentration(0)
+    val N_k = globalTopicTotals
+    val sendMsg: EdgeContext[TopicCounts, TokenCount, (Array[Int], Array[Int])] => Unit =
+      (edgeContext) => {
+        // E-STEP: Compute gamma_{wjk} (smoothed topic distributions).
+        val scaledTopicDistribution: TopicCounts =
+          computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha)
+        // For this (doc j, term w), send top topic k to doc vertex.
+        val topTopic: Int = argmax(scaledTopicDistribution)
+        val term: Int = index2term(edgeContext.dstId)
+        edgeContext.sendToSrc((Array(term), Array(topTopic)))
+      }
+    val mergeMsg: ((Array[Int], Array[Int]), (Array[Int], Array[Int])) => (Array[Int], Array[Int]) =
+      (terms_topics0, terms_topics1) => {
+        (terms_topics0._1 ++ terms_topics1._1, terms_topics0._2 ++ terms_topics1._2)
+      }
+    // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts.
+    val perDocAssignments =
+      graph.aggregateMessages[(Array[Int], Array[Int])](sendMsg, mergeMsg).filter(isDocumentVertex)
+    perDocAssignments.map { case (docID: Long, (terms: Array[Int], topics: Array[Int])) =>
+      // TODO: Avoid zip, which is inefficient.
+      val (sortedTerms, sortedTopics) = terms.zip(topics).sortBy(_._1).unzip
+      (docID, sortedTerms.toArray, sortedTopics.toArray)
+    }
+  }
+
+  /** Java-friendly version of [[topicAssignments]] */
+  lazy val javaTopicAssignments: JavaRDD[(java.lang.Long, Array[Int], Array[Int])] = {
+    topicAssignments.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Int])]].toJavaRDD()
+  }
+
   // TODO
   // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
 
@@ -849,10 +893,9 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
     val classNameV1_0 = SaveLoadV1_0.thisClassName
 
     val model = (loadedClassName, loadedVersion) match {
-      case (className, "1.0") if className == classNameV1_0 => {
+      case (className, "1.0") if className == classNameV1_0 =>
         DistributedLDAModel.SaveLoadV1_0.load(sc, path, vocabSize, docConcentration,
           topicConcentration, iterationTimes.toArray, gammaShape)
-      }
       case _ => throw new Exception(
         s"DistributedLDAModel.load did not recognize model with (className, format version):" +
           s"($loadedClassName, $loadedVersion).  Supported: ($classNameV1_0, 1.0)")

http://git-wip-us.apache.org/repos/asf/spark/blob/2beea65b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 360241c..cb517f9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -167,7 +167,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
         edgeContext.sendToDst((false, scaledTopicDistribution))
         edgeContext.sendToSrc((false, scaledTopicDistribution))
       }
-    // This is a hack to detect whether we could modify the values in-place.
+    // The Boolean is a hack to detect whether we could modify the values in-place.
     // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438)
     val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) =
       (m0, m1) => {

http://git-wip-us.apache.org/repos/asf/spark/blob/2beea65b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index 6e91cde..3fea359 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -134,6 +134,13 @@ public class JavaLDASuite implements Serializable {
     double[] topicWeights = topTopics._3();
     assertEquals(3, topicIndices.length);
     assertEquals(3, topicWeights.length);
+
+    // Check: topTopicAssignments
+    Tuple3<Long, int[], int[]> topicAssignment = model.javaTopicAssignments().first();
+    Long docId2 = topicAssignment._1();
+    int[] termIndices2 = topicAssignment._2();
+    int[] topicIndices2 = topicAssignment._3();
+    assertEquals(termIndices2.length, topicIndices2.length);
   }
 
   @Test

http://git-wip-us.apache.org/repos/asf/spark/blob/2beea65b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index 99e2849..8a714f9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -135,17 +135,34 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     }
 
     // Top 3 documents per topic
-    model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach {case (t1, t2) =>
+    model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach { case (t1, t2) =>
       assert(t1._1 === t2._1)
       assert(t1._2 === t2._2)
     }
 
     // All documents per topic
     val q = tinyCorpus.length
-    model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach {case (t1, t2) =>
+    model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach { case (t1, t2) =>
       assert(t1._1 === t2._1)
       assert(t1._2 === t2._2)
     }
+
+    // Check: topTopicAssignments
+    // Make sure it assigns a topic to each term appearing in each doc.
+    val topTopicAssignments: Map[Long, (Array[Int], Array[Int])] =
+      model.topicAssignments.collect().map(x => x._1 -> (x._2, x._3)).toMap
+    assert(topTopicAssignments.keys.max < tinyCorpus.length)
+    tinyCorpus.foreach { case (docID: Long, doc: Vector) =>
+      if (topTopicAssignments.contains(docID)) {
+        val (inds, vals) = topTopicAssignments(docID)
+        assert(inds.length === doc.numNonzeros)
+        // For "term" in actual doc,
+        // check that it has a topic assigned.
+        doc.foreachActive((term, wcnt) => assert(wcnt === 0 || inds.contains(term)))
+      } else {
+        assert(doc.numNonzeros === 0)
+      }
+    }
   }
 
   test("vertex indexing") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to