Repository: spark Updated Branches: refs/heads/master c0686668a -> 3c0d2e552
[SPARK-9246] [MLLIB] DistributedLDAModel predict top docs per topic Add topDocumentsPerTopic to DistributedLDAModel. Add ScalaDoc and unit tests. Author: Meihua Wu <meihu...@umich.edu> Closes #7769 from rotationsymmetry/SPARK-9246 and squashes the following commits: 1029e79c [Meihua Wu] clean up code comments a023b82 [Meihua Wu] Update tests to use Long for doc index. 91e5998 [Meihua Wu] Use Long for doc index. b9f70cf [Meihua Wu] Revise topDocumentsPerTopic 26ff3f6 [Meihua Wu] Add topDocumentsPerTopic, scala doc and unit tests Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3c0d2e55 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3c0d2e55 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3c0d2e55 Branch: refs/heads/master Commit: 3c0d2e55210735e0df2f8febb5f63c224af230e3 Parents: c068666 Author: Meihua Wu <meihu...@umich.edu> Authored: Fri Jul 31 13:01:10 2015 -0700 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Fri Jul 31 13:01:10 2015 -0700 ---------------------------------------------------------------------- .../spark/mllib/clustering/LDAModel.scala | 37 ++++++++++++++++++++ .../spark/mllib/clustering/LDASuite.scala | 22 ++++++++++++ 2 files changed, 59 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/3c0d2e55/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index ff7035d..0cdac84 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -516,6 +516,43 @@ class DistributedLDAModel private[clustering] ( } } + /** + * Return the top 
documents for each topic + * + * This is approximate; it may not return exactly the top-weighted documents for each topic. + * To get a more precise set of top documents, increase maxDocumentsPerTopic. + * + * @param maxDocumentsPerTopic Maximum number of documents to collect for each topic. + * @return Array over topics. Each element is represented as a pair of matching arrays: + * (IDs for the documents, weights of the topic in these documents). + * For each topic, documents are sorted in order of decreasing topic weights. + */ + def topDocumentsPerTopic(maxDocumentsPerTopic: Int): Array[(Array[Long], Array[Double])] = { + val numTopics = k + val topicsInQueues: Array[BoundedPriorityQueue[(Double, Long)]] = + topicDistributions.mapPartitions { docVertices => + // For this partition, collect candidate top docs for each topic in bounded queues + // (at most maxDocumentsPerTopic entries each): + // queues(topic) = queue of (weight of topic in doc, doc ID). + val queues = + Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Long)](maxDocumentsPerTopic)) + for ((docId, docTopics) <- docVertices) { + var topic = 0 + while (topic < numTopics) { + queues(topic) += (docTopics(topic) -> docId) + topic += 1 + } + } + Iterator(queues) + }.treeReduce { (q1, q2) => + q1.zip(q2).foreach { case (a, b) => a ++= b } + q1 + } + topicsInQueues.map { q => + val (docTopics, docs) = q.toArray.sortBy(-_._1).unzip + (docs.toArray, docTopics.toArray) + } + } + // TODO // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? 
http://git-wip-us.apache.org/repos/asf/spark/blob/3c0d2e55/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 79d2a1c..f2b9470 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -122,6 +122,28 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { // Check: log probabilities assert(model.logLikelihood < 0.0) assert(model.logPrior < 0.0) + + // Check: topDocumentsPerTopic + // Compare it with top documents per topic derived from topicDistributions + val topDocsByTopicDistributions = { n: Int => + Range(0, k).map { topic => + val (doc, docWeights) = topicDistributions.sortBy(-_._2(topic)).take(n).unzip + (doc.toArray, docWeights.map(_(topic)).toArray) + }.toArray + } + + // Top 3 documents per topic + model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach {case (t1, t2) => + assert(t1._1 === t2._1) + assert(t1._2 === t2._2) + } + + // All documents per topic + val q = tinyCorpus.length + model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach {case (t1, t2) => + assert(t1._1 === t2._1) + assert(t1._2 === t2._2) + } } test("vertex indexing") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org