Repository: spark
Updated Branches:
  refs/heads/master fc4b792d2 -> d9cf9c21f


[SPARK-11912][ML] ml.feature.PCA minor refactor

Like [SPARK-11852](https://issues.apache.org/jira/browse/SPARK-11852), ```k``` 
is a param, so it should be saved under ```metadata/``` only, rather than under 
both ```data/``` and ```metadata/```. This refactors the constructor of 
```ml.feature.PCAModel``` to take only ```pc```, constructing the 
```mllib.feature.PCAModel``` inside ```transform``` instead.
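
For illustration, a minimal sketch of the refactored API in use (the DataFrame ```df``` and its ```features``` column are assumptions for this example, not part of the patch):

```scala
import org.apache.spark.ml.feature.PCA

// df is assumed to be a DataFrame with a Vector column named "features".
val pca = new PCA()
  .setInputCol("features")
  .setOutputCol("pcaFeatures")
  .setK(3)

// The fitted PCAModel now stores only `pc`; `k` lives among the params.
val model = pca.fit(df)

// transform() builds the mllib.feature.PCAModel internally from k and pc.
val projected = model.transform(df)
```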

Author: Yanbo Liang <yblia...@gmail.com>

Closes #9897 from yanboliang/spark-11912.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d9cf9c21
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d9cf9c21
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d9cf9c21

Branch: refs/heads/master
Commit: d9cf9c21fc6b1aa22e68d66760afd42c4e1c18b8
Parents: fc4b792
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Sun Nov 22 21:56:07 2015 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Sun Nov 22 21:56:07 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/feature/PCA.scala | 23 +++++++--------
 .../org/apache/spark/ml/feature/PCASuite.scala  | 31 ++++++++------------
 2 files changed, 24 insertions(+), 30 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d9cf9c21/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 32d7afe..aa88cb0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -73,7 +73,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
     val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v}
     val pca = new feature.PCA(k = $(k))
     val pcaModel = pca.fit(input)
-    copyValues(new PCAModel(uid, pcaModel).setParent(this))
+    copyValues(new PCAModel(uid, pcaModel.pc).setParent(this))
   }
 
   override def transformSchema(schema: StructType): StructType = {
@@ -99,18 +99,17 @@ object PCA extends DefaultParamsReadable[PCA] {
 /**
  * :: Experimental ::
  * Model fitted by [[PCA]].
+ *
+ * @param pc A principal components Matrix. Each column is one principal component.
  */
 @Experimental
 class PCAModel private[ml] (
     override val uid: String,
-    pcaModel: feature.PCAModel)
+    val pc: DenseMatrix)
   extends Model[PCAModel] with PCAParams with MLWritable {
 
   import PCAModel._
 
-  /** a principal components Matrix. Each column is one principal component. */
-  val pc: DenseMatrix = pcaModel.pc
-
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
 
@@ -124,6 +123,7 @@ class PCAModel private[ml] (
    */
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
+    val pcaModel = new feature.PCAModel($(k), pc)
     val pcaOp = udf { pcaModel.transform _ }
     dataset.withColumn($(outputCol), pcaOp(col($(inputCol))))
   }
@@ -139,7 +139,7 @@ class PCAModel private[ml] (
   }
 
   override def copy(extra: ParamMap): PCAModel = {
-    val copied = new PCAModel(uid, pcaModel)
+    val copied = new PCAModel(uid, pc)
     copyValues(copied, extra).setParent(parent)
   }
 
@@ -152,11 +152,11 @@ object PCAModel extends MLReadable[PCAModel] {
 
   private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter {
 
-    private case class Data(k: Int, pc: DenseMatrix)
+    private case class Data(pc: DenseMatrix)
 
     override protected def saveImpl(path: String): Unit = {
       DefaultParamsWriter.saveMetadata(instance, path, sc)
-      val data = Data(instance.getK, instance.pc)
+      val data = Data(instance.pc)
       val dataPath = new Path(path, "data").toString
       sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
@@ -169,11 +169,10 @@ object PCAModel extends MLReadable[PCAModel] {
     override def load(path: String): PCAModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val Row(k: Int, pc: DenseMatrix) = sqlContext.read.parquet(dataPath)
-        .select("k", "pc")
+      val Row(pc: DenseMatrix) = sqlContext.read.parquet(dataPath)
+        .select("pc")
         .head()
-      val oldModel = new feature.PCAModel(k, pc)
-      val model = new PCAModel(metadata.uid, oldModel)
+      val model = new PCAModel(metadata.uid, pc)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/d9cf9c21/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
index 5a21cd2..edab21e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
@@ -32,7 +32,7 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
   test("params") {
     ParamsSuite.checkParams(new PCA)
     val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix]
-    val model = new PCAModel("pca", new OldPCAModel(2, mat))
+    val model = new PCAModel("pca", mat)
     ParamsSuite.checkParams(model)
   }
 
@@ -66,23 +66,18 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
     }
   }
 
-  test("read/write") {
+  test("PCA read/write") {
+    val t = new PCA()
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+      .setK(3)
+    testDefaultReadWrite(t)
+  }
 
-    def checkModelData(model1: PCAModel, model2: PCAModel): Unit = {
-      assert(model1.pc === model2.pc)
-    }
-    val allParams: Map[String, Any] = Map(
-      "k" -> 3,
-      "inputCol" -> "features",
-      "outputCol" -> "pca_features"
-    )
-    val data = Seq(
-      (0.0, Vectors.sparse(5, Seq((1, 1.0), (3, 7.0)))),
-      (1.0, Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)),
-      (2.0, Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0))
-    )
-    val df = sqlContext.createDataFrame(data).toDF("id", "features")
-    val pca = new PCA().setK(3)
-    testEstimatorAndModelReadWrite(pca, df, allParams, checkModelData)
+  test("PCAModel read/write") {
+    val instance = new PCAModel("myPCAModel",
+      Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix])
+    val newInstance = testDefaultReadWrite(instance)
+    assert(newInstance.pc === instance.pc)
   }
 }
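
For reference, the net effect on the persisted format (an illustrative sketch; the root path name is hypothetical):

```
myPCAModelPath/
  metadata/   JSON written by DefaultParamsWriter.saveMetadata; params, including k, live here
  data/       Parquet with a single pc column (previously both k and pc)
```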

