Repository: spark Updated Branches: refs/heads/master 7f741905b -> 2ff0e79a8
[SPARK-8467] [MLLIB] [PYSPARK] Add LDAModel.describeTopics() in Python Could jkbradley and davies review it? - Create a wrapper class: `LDAModelWrapper` for `LDAModel`. This is needed because we can't handle the return value of `describeTopics` in Scala from PySpark directly: `Array[(Array[Int], Array[Double])]` is too complicated to convert. - Add `loadLDAModel` in `PythonMLLibAPI`. This is needed because `LDAModel` in Scala is an abstract class, so we need to call `load` of `DistributedLDAModel`. [[SPARK-8467] Add LDAModel.describeTopics() in Python - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-8467) Author: Yu ISHIKAWA <yuu.ishik...@gmail.com> Closes #8643 from yu-iskw/SPARK-8467-2. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2ff0e79a Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2ff0e79a Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2ff0e79a Branch: refs/heads/master Commit: 2ff0e79a8647cca5c9c57f613a07e739ac4f677e Parents: 7f74190 Author: Yu ISHIKAWA <yuu.ishik...@gmail.com> Authored: Fri Nov 6 22:56:29 2015 -0800 Committer: Davies Liu <davies....@gmail.com> Committed: Fri Nov 6 22:56:29 2015 -0800 ---------------------------------------------------------------------- .../mllib/api/python/LDAModelWrapper.scala | 46 ++++++++++++++++++++ .../spark/mllib/api/python/PythonMLLibAPI.scala | 13 +++++- python/pyspark/mllib/clustering.py | 33 +++++++------- 3 files changed, 75 insertions(+), 17 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/2ff0e79a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala new file mode 100644 index 0000000..63282ee --- 
/dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.api.python + +import scala.collection.JavaConverters + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.clustering.LDAModel +import org.apache.spark.mllib.linalg.Matrix + +/** + * Wrapper around LDAModel to provide helper methods in Python + */ +private[python] class LDAModelWrapper(model: LDAModel) { + + def topicsMatrix(): Matrix = model.topicsMatrix + + def vocabSize(): Int = model.vocabSize + + def describeTopics(): Array[Byte] = describeTopics(this.model.vocabSize) + + def describeTopics(maxTermsPerTopic: Int): Array[Byte] = { + val topics = model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) => + val jTerms = JavaConverters.seqAsJavaListConverter(terms).asJava + val jTermWeights = JavaConverters.seqAsJavaListConverter(termWeights).asJava + Array[Any](jTerms, jTermWeights) + } + SerDe.dumps(JavaConverters.seqAsJavaListConverter(topics).asJava) + } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) +} 
http://git-wip-us.apache.org/repos/asf/spark/blob/2ff0e79a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 40c4180..54b03a9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -517,7 +517,7 @@ private[python] class PythonMLLibAPI extends Serializable { topicConcentration: Double, seed: java.lang.Long, checkpointInterval: Int, - optimizer: String): LDAModel = { + optimizer: String): LDAModelWrapper = { val algo = new LDA() .setK(k) .setMaxIterations(maxIterations) @@ -535,7 +535,16 @@ private[python] class PythonMLLibAPI extends Serializable { case _ => throw new IllegalArgumentException("input values contains invalid type value.") } } - algo.run(documents) + val model = algo.run(documents) + new LDAModelWrapper(model) + } + + /** + * Load a LDA model + */ + def loadLDAModel(jsc: JavaSparkContext, path: String): LDAModelWrapper = { + val model = DistributedLDAModel.load(jsc.sc, path) + new LDAModelWrapper(model) } http://git-wip-us.apache.org/repos/asf/spark/blob/2ff0e79a/python/pyspark/mllib/clustering.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 8629aa5..12081f8 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -671,7 +671,7 @@ class StreamingKMeans(object): return dstream.mapValues(lambda x: self._model.predict(x)) -class LDAModel(JavaModelWrapper): +class LDAModel(JavaModelWrapper, JavaSaveable, Loader): """ A clustering model derived from the LDA method. @@ -691,9 +691,14 @@ class LDAModel(JavaModelWrapper): ... 
[2, SparseVector(2, {0: 1.0})], ... ] >>> rdd = sc.parallelize(data) - >>> model = LDA.train(rdd, k=2) + >>> model = LDA.train(rdd, k=2, seed=1) >>> model.vocabSize() 2 + >>> model.describeTopics() + [([1, 0], [0.5..., 0.49...]), ([0, 1], [0.5..., 0.49...])] + >>> model.describeTopics(1) + [([1], [0.5...]), ([0], [0.5...])] + >>> topics = model.topicsMatrix() >>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]]) >>> assert_almost_equal(topics, topics_expect, 1) @@ -724,18 +729,17 @@ class LDAModel(JavaModelWrapper): """Vocabulary size (number of terms or terms in the vocabulary)""" return self.call("vocabSize") - @since('1.5.0') - def save(self, sc, path): - """Save the LDAModel on to disk. + @since('1.6.0') + def describeTopics(self, maxTermsPerTopic=None): + """Return the topics described by weighted terms. - :param sc: SparkContext - :param path: str, path to where the model needs to be stored. + WARNING: If vocabSize and k are large, this can return a large object! """ - if not isinstance(sc, SparkContext): - raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) - if not isinstance(path, basestring): - raise TypeError("path should be a basestring, got type %s" % type(path)) - self._java_model.save(sc._jsc.sc(), path) + if maxTermsPerTopic is None: + topics = self.call("describeTopics") + else: + topics = self.call("describeTopics", maxTermsPerTopic) + return topics @classmethod @since('1.5.0') @@ -749,9 +753,8 @@ class LDAModel(JavaModelWrapper): raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) if not isinstance(path, basestring): raise TypeError("path should be a basestring, got type %s" % type(path)) - java_model = sc._jvm.org.apache.spark.mllib.clustering.DistributedLDAModel.load( - sc._jsc.sc(), path) - return cls(java_model) + model = callMLlibFunc("loadLDAModel", sc, path) + return LDAModel(model) class LDA(object): --------------------------------------------------------------------- To unsubscribe, e-mail: 
commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org