Repository: spark
Updated Branches:
  refs/heads/branch-1.6 838f956c9 -> f061d2539


[SPARK-11712][ML] Make spark.ml LDAModel be abstract

Per discussion in the initial Pipelines LDA PR 
[https://github.com/apache/spark/pull/9513], we should make LDAModel abstract 
and create a LocalLDAModel. This code simplification should be done before the 
1.6 release to ensure API compatibility in future releases.

CC feynmanliang mengxr

Author: Joseph K. Bradley <jos...@databricks.com>

Closes #9678 from jkbradley/lda-pipelines-2.

(cherry picked from commit dcb896fd8cec83483f700ee985c352be61cdf233)
Signed-off-by: Joseph K. Bradley <jos...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f061d253
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f061d253
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f061d253

Branch: refs/heads/branch-1.6
Commit: f061d2539931a2feab7229b95d3a38671ec4c0d3
Parents: 838f956
Author: Joseph K. Bradley <jos...@databricks.com>
Authored: Thu Nov 12 17:03:19 2015 -0800
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Thu Nov 12 17:03:30 2015 -0800

----------------------------------------------------------------------
 .../org/apache/spark/ml/clustering/LDA.scala    | 180 ++++++++++---------
 .../apache/spark/ml/clustering/LDASuite.scala   |   4 +-
 2 files changed, 96 insertions(+), 88 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f061d253/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index f66233e..92e0581 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -314,31 +314,31 @@ private[clustering] trait LDAParams extends Params with 
HasFeaturesCol with HasM
  * Model fitted by [[LDA]].
  *
  * @param vocabSize  Vocabulary size (number of terms or words in the 
vocabulary)
- * @param oldLocalModel  Underlying spark.mllib model.
- *                       If this model was produced by Online LDA, then this 
is the
- *                       only model representation.
- *                       If this model was produced by EM, then this local
- *                       representation may be built lazily.
  * @param sqlContext  Used to construct local DataFrames for returning query 
results
  */
 @Since("1.6.0")
 @Experimental
-class LDAModel private[ml] (
+sealed abstract class LDAModel private[ml] (
     @Since("1.6.0") override val uid: String,
     @Since("1.6.0") val vocabSize: Int,
-    @Since("1.6.0") protected var oldLocalModel: Option[OldLocalLDAModel],
     @Since("1.6.0") @transient protected val sqlContext: SQLContext)
   extends Model[LDAModel] with LDAParams with Logging {
 
-  /** Returns underlying spark.mllib model */
+  // NOTE to developers:
+  //  This abstraction should contain all important functionality for basic 
LDA usage.
+  //  Specializations of this class can contain expert-only functionality.
+
+  /**
+   * Underlying spark.mllib model.
+   * If this model was produced by Online LDA, then this is the only model 
representation.
+   * If this model was produced by EM, then this local representation may be 
built lazily.
+   */
   @Since("1.6.0")
-  protected def getModel: OldLDAModel = oldLocalModel match {
-    case Some(m) => m
-    case None =>
-      // Should never happen.
-      throw new RuntimeException("LDAModel required local model format," +
-        " but the underlying model is missing.")
-  }
+  protected def oldLocalModel: OldLocalLDAModel
+
+  /** Returns underlying spark.mllib model, which may be local or distributed 
*/
+  @Since("1.6.0")
+  protected def getModel: OldLDAModel
 
   /**
    * The features for LDA should be a [[Vector]] representing the word counts 
in a document.
@@ -352,16 +352,17 @@ class LDAModel private[ml] (
   @Since("1.6.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
-  @Since("1.6.0")
-  override def copy(extra: ParamMap): LDAModel = {
-    val copied = new LDAModel(uid, vocabSize, oldLocalModel, sqlContext)
-    copyValues(copied, extra).setParent(parent)
-  }
-
+  /**
+   * Transforms the input dataset.
+   *
+   * WARNING: If this model is an instance of [[DistributedLDAModel]] 
(produced when [[optimizer]]
+   *          is set to "em"), this involves collecting a large 
[[topicsMatrix]] to the driver.
+   *          This implementation may be changed in the future.
+   */
   @Since("1.6.0")
   override def transform(dataset: DataFrame): DataFrame = {
     if ($(topicDistributionCol).nonEmpty) {
-      val t = 
udf(oldLocalModel.get.getTopicDistributionMethod(sqlContext.sparkContext))
+      val t = 
udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext))
       dataset.withColumn($(topicDistributionCol), t(col($(featuresCol))))
     } else {
       logWarning("LDAModel.transform was called without any output columns. 
Set an output column" +
@@ -388,56 +389,50 @@ class LDAModel private[ml] (
    * This is a matrix of size vocabSize x k, where each column is a topic.
    * No guarantees are given about the ordering of the topics.
    *
-   * WARNING: If this model is actually a [[DistributedLDAModel]] instance 
from EM,
-   *          then this method could involve collecting a large amount of data 
to the driver
-   *          (on the order of vocabSize x k).
+   * WARNING: If this model is actually a [[DistributedLDAModel]] instance 
produced by
+   *          the Expectation-Maximization ("em") [[optimizer]], then this 
method could involve
+   *          collecting a large amount of data to the driver (on the order of 
vocabSize x k).
    */
   @Since("1.6.0")
-  def topicsMatrix: Matrix = getModel.topicsMatrix
+  def topicsMatrix: Matrix = oldLocalModel.topicsMatrix
 
   /** Indicates whether this instance is of type [[DistributedLDAModel]] */
   @Since("1.6.0")
-  def isDistributed: Boolean = false
+  def isDistributed: Boolean
 
   /**
    * Calculates a lower bound on the log likelihood of the entire corpus.
    *
    * See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
    *
-   * WARNING: If this model was learned via a [[DistributedLDAModel]], this 
involves collecting
-   *          a large [[topicsMatrix]] to the driver.  This implementation may 
be changed in the
-   *          future.
+   * WARNING: If this model is an instance of [[DistributedLDAModel]] 
(produced when [[optimizer]]
+   *          is set to "em"), this involves collecting a large 
[[topicsMatrix]] to the driver.
+   *          This implementation may be changed in the future.
    *
    * @param dataset  test corpus to use for calculating log likelihood
    * @return variational lower bound on the log likelihood of the entire corpus
    */
   @Since("1.6.0")
-  def logLikelihood(dataset: DataFrame): Double = oldLocalModel match {
-    case Some(m) =>
-      val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
-      m.logLikelihood(oldDataset)
-    case None =>
-      // Should never happen.
-      throw new RuntimeException("LocalLDAModel.logLikelihood was called," +
-        " but the underlying model is missing.")
+  def logLikelihood(dataset: DataFrame): Double = {
+    val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
+    oldLocalModel.logLikelihood(oldDataset)
   }
 
   /**
    * Calculate an upper bound on perplexity.  (Lower is better.)
    * See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
    *
+   * WARNING: If this model is an instance of [[DistributedLDAModel]] 
(produced when [[optimizer]]
+   *          is set to "em"), this involves collecting a large 
[[topicsMatrix]] to the driver.
+   *          This implementation may be changed in the future.
+   *
    * @param dataset test corpus to use for calculating perplexity
    * @return Variational upper bound on log perplexity per token.
    */
   @Since("1.6.0")
-  def logPerplexity(dataset: DataFrame): Double = oldLocalModel match {
-    case Some(m) =>
-      val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
-      m.logPerplexity(oldDataset)
-    case None =>
-      // Should never happen.
-      throw new RuntimeException("LocalLDAModel.logPerplexity was called," +
-        " but the underlying model is missing.")
+  def logPerplexity(dataset: DataFrame): Double = {
+    val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
+    oldLocalModel.logPerplexity(oldDataset)
   }
 
   /**
@@ -468,10 +463,43 @@ class LDAModel private[ml] (
 /**
  * :: Experimental ::
  *
- * Distributed model fitted by [[LDA]] using Expectation-Maximization (EM).
+ * Local (non-distributed) model fitted by [[LDA]].
+ *
+ * This model stores the inferred topics only; it does not store info about 
the training dataset.
+ */
+@Since("1.6.0")
+@Experimental
+class LocalLDAModel private[ml] (
+    uid: String,
+    vocabSize: Int,
+    @Since("1.6.0") override protected val oldLocalModel: OldLocalLDAModel,
+    sqlContext: SQLContext)
+  extends LDAModel(uid, vocabSize, sqlContext) {
+
+  @Since("1.6.0")
+  override def copy(extra: ParamMap): LocalLDAModel = {
+    val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext)
+    copyValues(copied, extra).setParent(parent).asInstanceOf[LocalLDAModel]
+  }
+
+  override protected def getModel: OldLDAModel = oldLocalModel
+
+  @Since("1.6.0")
+  override def isDistributed: Boolean = false
+}
+
+
+/**
+ * :: Experimental ::
+ *
+ * Distributed model fitted by [[LDA]].
+ * This type of model is currently only produced by Expectation-Maximization 
(EM).
  *
  * This model stores the inferred topics, the full training dataset, and the 
topic distribution
  * for each training document.
+ *
+ * @param oldLocalModelOption  Used to implement [[oldLocalModel]] as a lazy 
val, but keeping
+ *                             [[copy()]] cheap.
  */
 @Since("1.6.0")
 @Experimental
@@ -479,59 +507,39 @@ class DistributedLDAModel private[ml] (
     uid: String,
     vocabSize: Int,
     private val oldDistributedModel: OldDistributedLDAModel,
-    sqlContext: SQLContext)
-  extends LDAModel(uid, vocabSize, None, sqlContext) {
+    sqlContext: SQLContext,
+    private var oldLocalModelOption: Option[OldLocalLDAModel])
+  extends LDAModel(uid, vocabSize, sqlContext) {
+
+  override protected def oldLocalModel: OldLocalLDAModel = {
+    if (oldLocalModelOption.isEmpty) {
+      oldLocalModelOption = Some(oldDistributedModel.toLocal)
+    }
+    oldLocalModelOption.get
+  }
+
+  override protected def getModel: OldLDAModel = oldDistributedModel
 
   /**
    * Convert this distributed model to a local representation.  This discards 
info about the
    * training dataset.
+   *
+   * WARNING: This involves collecting a large [[topicsMatrix]] to the driver.
    */
   @Since("1.6.0")
-  def toLocal: LDAModel = {
-    if (oldLocalModel.isEmpty) {
-      oldLocalModel = Some(oldDistributedModel.toLocal)
-    }
-    new LDAModel(uid, vocabSize, oldLocalModel, sqlContext)
-  }
-
-  @Since("1.6.0")
-  override protected def getModel: OldLDAModel = oldDistributedModel
+  def toLocal: LocalLDAModel = new LocalLDAModel(uid, vocabSize, 
oldLocalModel, sqlContext)
 
   @Since("1.6.0")
   override def copy(extra: ParamMap): DistributedLDAModel = {
-    val copied = new DistributedLDAModel(uid, vocabSize, oldDistributedModel, 
sqlContext)
-    if (oldLocalModel.nonEmpty) copied.oldLocalModel = oldLocalModel
+    val copied =
+      new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext, 
oldLocalModelOption)
     copyValues(copied, extra).setParent(parent)
     copied
   }
 
   @Since("1.6.0")
-  override def topicsMatrix: Matrix = {
-    if (oldLocalModel.isEmpty) {
-      oldLocalModel = Some(oldDistributedModel.toLocal)
-    }
-    super.topicsMatrix
-  }
-
-  @Since("1.6.0")
   override def isDistributed: Boolean = true
 
-  @Since("1.6.0")
-  override def logLikelihood(dataset: DataFrame): Double = {
-    if (oldLocalModel.isEmpty) {
-      oldLocalModel = Some(oldDistributedModel.toLocal)
-    }
-    super.logLikelihood(dataset)
-  }
-
-  @Since("1.6.0")
-  override def logPerplexity(dataset: DataFrame): Double = {
-    if (oldLocalModel.isEmpty) {
-      oldLocalModel = Some(oldDistributedModel.toLocal)
-    }
-    super.logPerplexity(dataset)
-  }
-
   /**
    * Log likelihood of the observed tokens in the training set,
    * given the current parameter estimates:
@@ -673,9 +681,9 @@ class LDA @Since("1.6.0") (
     val oldModel = oldLDA.run(oldData)
     val newModel = oldModel match {
       case m: OldLocalLDAModel =>
-        new LDAModel(uid, m.vocabSize, Some(m), dataset.sqlContext)
+        new LocalLDAModel(uid, m.vocabSize, m, dataset.sqlContext)
       case m: OldDistributedLDAModel =>
-        new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext)
+        new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext, None)
     }
     copyValues(newModel).setParent(this)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/f061d253/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index edb9274..b634d31 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -156,7 +156,7 @@ class LDASuite extends SparkFunSuite with 
MLlibTestSparkContext {
 
     MLTestingUtils.checkCopy(model)
 
-    assert(!model.isInstanceOf[DistributedLDAModel])
+    assert(model.isInstanceOf[LocalLDAModel])
     assert(model.vocabSize === vocabSize)
     assert(model.estimatedDocConcentration.size === k)
     assert(model.topicsMatrix.numRows === vocabSize)
@@ -210,7 +210,7 @@ class LDASuite extends SparkFunSuite with 
MLlibTestSparkContext {
     assert(model.isDistributed)
 
     val localModel = model.toLocal
-    assert(!localModel.isInstanceOf[DistributedLDAModel])
+    assert(localModel.isInstanceOf[LocalLDAModel])
 
     // training logLikelihood, logPrior
     val ll = model.trainingLogLikelihood


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to