Repository: spark
Updated Branches:
  refs/heads/branch-1.6 70febe224 -> af86c38db


[SPARK-11847][ML] Model export/import for spark.ml: LDA

Add read/write support to LDA, similar to ALS.

save/load for ml.LocalLDAModel is done.
For DistributedLDAModel, I'm not sure whether we can invoke save on the
mllib.DistributedLDAModel directly; I'll send an update after some testing.
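
A rough usage sketch of the new read/write API (illustrative only, not part of
this patch; the dataset, parameter values, and save path are assumptions):

  import org.apache.spark.ml.clustering.{LDA, LocalLDAModel}

  // `dataset` is assumed to be a DataFrame with a "features" column of Vectors.
  val lda = new LDA().setK(10).setMaxIter(5)
  val model = lda.fit(dataset)

  // Round-trip the fitted model (default online optimizer => LocalLDAModel).
  model.write.overwrite().save("/tmp/lda-model")    // path is illustrative
  val restored = LocalLDAModel.load("/tmp/lda-model")
  assert(restored.vocabSize == model.vocabSize)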

Author: Yuhao Yang <hhb...@gmail.com>

Closes #9894 from hhbyyh/ldaMLsave.

(cherry picked from commit 52bc25c8e26d4be250d8ff7864067528f4f98592)
Signed-off-by: Xiangrui Meng <m...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/af86c38d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/af86c38d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/af86c38d

Branch: refs/heads/branch-1.6
Commit: af86c38db7676c4dfc2724d5f86f0f5f3a22e349
Parents: 70febe2
Author: Yuhao Yang <hhb...@gmail.com>
Authored: Tue Nov 24 09:56:17 2015 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Tue Nov 24 09:56:24 2015 -0800

----------------------------------------------------------------------
 .../org/apache/spark/ml/clustering/LDA.scala    | 110 ++++++++++++++++++-
 .../spark/mllib/clustering/LDAModel.scala       |   4 +-
 .../apache/spark/ml/clustering/LDASuite.scala   |  44 +++++++-
 3 files changed, 150 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/af86c38d/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 92e0581..830510b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -17,12 +17,13 @@
 
 package org.apache.spark.ml.clustering
 
+import org.apache.hadoop.fs.Path
 import org.apache.spark.Logging
 import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter}
 import org.apache.spark.ml.param._
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
    EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
    LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
@@ -322,7 +323,7 @@ sealed abstract class LDAModel private[ml] (
     @Since("1.6.0") override val uid: String,
     @Since("1.6.0") val vocabSize: Int,
     @Since("1.6.0") @transient protected val sqlContext: SQLContext)
-  extends Model[LDAModel] with LDAParams with Logging {
+  extends Model[LDAModel] with LDAParams with Logging with MLWritable {
 
   // NOTE to developers:
  //  This abstraction should contain all important functionality for basic LDA usage.
@@ -486,6 +487,64 @@ class LocalLDAModel private[ml] (
 
   @Since("1.6.0")
   override def isDistributed: Boolean = false
+
+  @Since("1.6.0")
+  override def write: MLWriter = new LocalLDAModel.LocalLDAModelWriter(this)
+}
+
+
+@Since("1.6.0")
+object LocalLDAModel extends MLReadable[LocalLDAModel] {
+
+  private[LocalLDAModel]
+  class LocalLDAModelWriter(instance: LocalLDAModel) extends MLWriter {
+
+    private case class Data(
+        vocabSize: Int,
+        topicsMatrix: Matrix,
+        docConcentration: Vector,
+        topicConcentration: Double,
+        gammaShape: Double)
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val oldModel = instance.oldLocalModel
+      val data = Data(instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration,
+        oldModel.topicConcentration, oldModel.gammaShape)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class LocalLDAModelReader extends MLReader[LocalLDAModel] {
+
+    private val className = classOf[LocalLDAModel].getName
+
+    override def load(path: String): LocalLDAModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath)
+        .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration",
+          "gammaShape")
+        .head()
+      val vocabSize = data.getAs[Int](0)
+      val topicsMatrix = data.getAs[Matrix](1)
+      val docConcentration = data.getAs[Vector](2)
+      val topicConcentration = data.getAs[Double](3)
+      val gammaShape = data.getAs[Double](4)
+      val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration,
+        gammaShape)
+      val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sqlContext)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: MLReader[LocalLDAModel] = new LocalLDAModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): LocalLDAModel = super.load(path)
 }
 
 
@@ -562,6 +621,45 @@ class DistributedLDAModel private[ml] (
    */
   @Since("1.6.0")
   lazy val logPrior: Double = oldDistributedModel.logPrior
+
+  @Since("1.6.0")
+  override def write: MLWriter = new DistributedLDAModel.DistributedWriter(this)
+}
+
+
+@Since("1.6.0")
+object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
+
+  private[DistributedLDAModel]
+  class DistributedWriter(instance: DistributedLDAModel) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val modelPath = new Path(path, "oldModel").toString
+      instance.oldDistributedModel.save(sc, modelPath)
+    }
+  }
+
+  private class DistributedLDAModelReader extends MLReader[DistributedLDAModel] {
+
+    private val className = classOf[DistributedLDAModel].getName
+
+    override def load(path: String): DistributedLDAModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val modelPath = new Path(path, "oldModel").toString
+      val oldModel = OldDistributedLDAModel.load(sc, modelPath)
+      val model = new DistributedLDAModel(
+        metadata.uid, oldModel.vocabSize, oldModel, sqlContext, None)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: MLReader[DistributedLDAModel] = new DistributedLDAModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): DistributedLDAModel = super.load(path)
 }
 
 
@@ -593,7 +691,8 @@ class DistributedLDAModel private[ml] (
 @Since("1.6.0")
 @Experimental
 class LDA @Since("1.6.0") (
-    @Since("1.6.0") override val uid: String) extends Estimator[LDAModel] with LDAParams {
+    @Since("1.6.0") override val uid: String)
+  extends Estimator[LDAModel] with LDAParams with DefaultParamsWritable {
 
   @Since("1.6.0")
   def this() = this(Identifiable.randomUID("lda"))
@@ -695,7 +794,7 @@ class LDA @Since("1.6.0") (
 }
 
 
-private[clustering] object LDA {
+private[clustering] object LDA extends DefaultParamsReadable[LDA] {
 
   /** Get dataset for spark.mllib LDA */
  def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = {
@@ -706,4 +805,7 @@ private[clustering] object LDA {
         (docId, features)
       }
   }
+
+  @Since("1.6.0")
+  override def load(path: String): LDA = super.load(path)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/af86c38d/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index cd520f0..7384d06 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -187,11 +187,11 @@ abstract class LDAModel private[clustering] extends Saveable {
  * @param topics Inferred topics (vocabSize x k matrix).
  */
 @Since("1.3.0")
-class LocalLDAModel private[clustering] (
+class LocalLDAModel private[spark] (
     @Since("1.3.0") val topics: Matrix,
     @Since("1.5.0") override val docConcentration: Vector,
     @Since("1.5.0") override val topicConcentration: Double,
-    override protected[clustering] val gammaShape: Double = 100)
+    override protected[spark] val gammaShape: Double = 100)
   extends LDAModel with Serializable {
 
   @Since("1.3.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/af86c38d/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index b634d31..97dbfd9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -18,9 +18,10 @@
 package org.apache.spark.ml.clustering
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 
 
@@ -39,10 +40,24 @@ object LDASuite {
     }.map(v => new TestRow(v))
     sql.createDataFrame(rdd)
   }
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "k" -> 3,
+    "maxIter" -> 2,
+    "checkpointInterval" -> 30,
+    "learningOffset" -> 1023.0,
+    "learningDecay" -> 0.52,
+    "subsamplingRate" -> 0.051
+  )
 }
 
 
-class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
+class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   val k: Int = 5
   val vocabSize: Int = 30
@@ -218,4 +233,29 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     val lp = model.logPrior
     assert(lp <= 0.0 && lp != Double.NegativeInfinity)
   }
+
+  test("read/write LocalLDAModel") {
+    def checkModelData(model: LDAModel, model2: LDAModel): Unit = {
+      assert(model.vocabSize === model2.vocabSize)
+      assert(Vectors.dense(model.topicsMatrix.toArray) ~==
+        Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6)
+      assert(Vectors.dense(model.getDocConcentration) ~==
+        Vectors.dense(model2.getDocConcentration) absTol 1e-6)
+    }
+    val lda = new LDA()
+    testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData)
+  }
+
+  test("read/write DistributedLDAModel") {
+    def checkModelData(model: LDAModel, model2: LDAModel): Unit = {
+      assert(model.vocabSize === model2.vocabSize)
+      assert(Vectors.dense(model.topicsMatrix.toArray) ~==
+        Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6)
+      assert(Vectors.dense(model.getDocConcentration) ~==
+        Vectors.dense(model2.getDocConcentration) absTol 1e-6)
+    }
+    val lda = new LDA()
+    testEstimatorAndModelReadWrite(lda, dataset,
+      LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
+  }
 }

