Repository: spark
Updated Branches:
  refs/heads/master 0fff8eb3e -> 3e1d120ce


[SPARK-11867] Add save/load for kmeans and naive bayes

https://issues.apache.org/jira/browse/SPARK-11867

Author: Xusen Yin <yinxu...@gmail.com>

Closes #9849 from yinxusen/SPARK-11867.
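
As a quick orientation, here is a minimal usage sketch of the persistence API this patch adds to ml.classification.NaiveBayes and ml.clustering.KMeans. The input DataFrame df (with "label" and "features" columns) and the save paths are assumed for illustration only; they are not part of the patch:

    import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
    import org.apache.spark.ml.clustering.{KMeans, KMeansModel}

    // df: assumed DataFrame with "label" and "features" columns
    val nbModel = new NaiveBayes().setSmoothing(0.1).fit(df)
    nbModel.write.save("/tmp/nbModel")                   // Params metadata + pi/theta as Parquet
    val nbLoaded = NaiveBayesModel.load("/tmp/nbModel")  // restores uid, Params, pi and theta

    val kmModel = new KMeans().setK(3).fit(df)
    kmModel.save("/tmp/kmModel")                         // MLWritable convenience for write.save
    val kmLoaded = KMeansModel.load("/tmp/kmModel")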


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3e1d120c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3e1d120c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3e1d120c

Branch: refs/heads/master
Commit: 3e1d120cedb4bd9e1595e95d4d531cf61da6684d
Parents: 0fff8eb
Author: Xusen Yin <yinxu...@gmail.com>
Authored: Thu Nov 19 23:43:18 2015 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Thu Nov 19 23:43:18 2015 -0800

----------------------------------------------------------------------
 .../spark/ml/classification/NaiveBayes.scala    | 68 ++++++++++++++++++--
 .../org/apache/spark/ml/clustering/KMeans.scala | 67 +++++++++++++++++--
 .../ml/classification/NaiveBayesSuite.scala     | 47 ++++++++++++--
 .../spark/ml/clustering/KMeansSuite.scala       | 41 +++++++++---
 4 files changed, 195 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3e1d120c/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index a14dcec..c512a2c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -17,12 +17,15 @@
 
 package org.apache.spark.ml.classification
 
+import org.apache.hadoop.fs.Path
+
 import org.apache.spark.SparkException
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
-import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes, NaiveBayesModel => OldNaiveBayesModel}
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes}
+import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel}
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
@@ -72,7 +75,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams {
 @Experimental
 class NaiveBayes(override val uid: String)
   extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
-  with NaiveBayesParams {
+  with NaiveBayesParams with DefaultParamsWritable {
 
   def this() = this(Identifiable.randomUID("nb"))
 
@@ -102,6 +105,13 @@ class NaiveBayes(override val uid: String)
   override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra)
 }
 
+@Since("1.6.0")
+object NaiveBayes extends DefaultParamsReadable[NaiveBayes] {
+
+  @Since("1.6.0")
+  override def load(path: String): NaiveBayes = super.load(path)
+}
+
 /**
  * :: Experimental ::
  * Model produced by [[NaiveBayes]]
@@ -114,7 +124,8 @@ class NaiveBayesModel private[ml] (
     override val uid: String,
     val pi: Vector,
     val theta: Matrix)
-  extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams {
+  extends ProbabilisticClassificationModel[Vector, NaiveBayesModel]
+  with NaiveBayesParams with MLWritable {
 
   import OldNaiveBayes.{Bernoulli, Multinomial}
 
@@ -203,12 +214,15 @@ class NaiveBayesModel private[ml] (
     s"NaiveBayesModel (uid=$uid) with ${pi.size} classes"
   }
 
+  @Since("1.6.0")
+  override def write: MLWriter = new NaiveBayesModel.NaiveBayesModelWriter(this)
 }
 
-private[ml] object NaiveBayesModel {
+@Since("1.6.0")
+object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
 
   /** Convert a model from the old API */
-  def fromOld(
+  private[ml] def fromOld(
       oldModel: OldNaiveBayesModel,
       parent: NaiveBayes): NaiveBayesModel = {
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb")
@@ -218,4 +232,44 @@ private[ml] object NaiveBayesModel {
       oldModel.theta.flatten, true)
     new NaiveBayesModel(uid, pi, theta)
   }
+
+  @Since("1.6.0")
+  override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): NaiveBayesModel = super.load(path)
+
+  /** [[MLWriter]] instance for [[NaiveBayesModel]] */
+  private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter {
+
+    private case class Data(pi: Vector, theta: Matrix)
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: pi, theta
+      val data = Data(instance.pi, instance.theta)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class NaiveBayesModelReader extends MLReader[NaiveBayesModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[NaiveBayesModel].getName
+
+    override def load(path: String): NaiveBayesModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("pi", "theta").head()
+      val pi = data.getAs[Vector](0)
+      val theta = data.getAs[Matrix](1)
+      val model = new NaiveBayesModel(metadata.uid, pi, theta)
+
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
 }
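
For reference, the writer above uses the shared ml.util layout: Params metadata is saved by DefaultParamsWriter under the model directory, and the model data (pi, theta) is written as a single-row Parquet file under data/. A hedged inspection sketch, assuming a save path of /tmp/nbModel and mirroring NaiveBayesModelReader.load:

    import org.apache.spark.mllib.linalg.{Matrix, Vector}
    // read back the persisted model data the same way the reader does
    val row = sqlContext.read.parquet("/tmp/nbModel/data").select("pi", "theta").head()
    val pi = row.getAs[Vector](0)
    val theta = row.getAs[Matrix](1)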

http://git-wip-us.apache.org/repos/asf/spark/blob/3e1d120c/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 509be63..71e9684 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -17,10 +17,12 @@
 
 package org.apache.spark.ml.clustering
 
-import org.apache.spark.annotation.{Since, Experimental}
-import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap}
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
+import org.apache.spark.ml.util._
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
@@ -28,7 +30,6 @@ import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
 import org.apache.spark.sql.{DataFrame, Row}
 
-
 /**
  * Common params for KMeans and KMeansModel
  */
@@ -94,7 +95,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
 @Experimental
 class KMeansModel private[ml] (
     @Since("1.5.0") override val uid: String,
-    private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams {
+    private val parentModel: MLlibKMeansModel)
+  extends Model[KMeansModel] with KMeansParams with MLWritable {
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): KMeansModel = {
@@ -129,6 +131,52 @@ class KMeansModel private[ml] (
     val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }
     parentModel.computeCost(data)
   }
+
+  @Since("1.6.0")
+  override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
+}
+
+@Since("1.6.0")
+object KMeansModel extends MLReadable[KMeansModel] {
+
+  @Since("1.6.0")
+  override def read: MLReader[KMeansModel] = new KMeansModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): KMeansModel = super.load(path)
+
+  /** [[MLWriter]] instance for [[KMeansModel]] */
+  private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
+
+    private case class Data(clusterCenters: Array[Vector])
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: cluster centers
+      val data = Data(instance.clusterCenters)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class KMeansModelReader extends MLReader[KMeansModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[KMeansModel].getName
+
+    override def load(path: String): KMeansModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head()
+      val clusterCenters = data.getAs[Seq[Vector]](0).toArray
+      val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
+
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
 }
 
 /**
@@ -141,7 +189,7 @@ class KMeansModel private[ml] (
 @Experimental
 class KMeans @Since("1.5.0") (
     @Since("1.5.0") override val uid: String)
-  extends Estimator[KMeansModel] with KMeansParams {
+  extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable {
 
   setDefault(
     k -> 2,
@@ -210,3 +258,10 @@ class KMeans @Since("1.5.0") (
   }
 }
 
+@Since("1.6.0")
+object KMeans extends DefaultParamsReadable[KMeans] {
+
+  @Since("1.6.0")
+  override def load(path: String): KMeans = super.load(path)
+}
+

http://git-wip-us.apache.org/repos/asf/spark/blob/3e1d120c/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 98bc951..082a6bc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -21,15 +21,30 @@ import breeze.linalg.{Vector => BV}
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.mllib.classification.NaiveBayes.{Multinomial, Bernoulli}
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial}
+import org.apache.spark.mllib.classification.NaiveBayesSuite._
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.mllib.classification.NaiveBayesSuite._
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{DataFrame, Row}
+
+class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+  @transient var dataset: DataFrame = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    val pi = Array(0.5, 0.1, 0.4).map(math.log)
+    val theta = Array(
+      Array(0.70, 0.10, 0.10, 0.10), // label 0
+      Array(0.10, 0.70, 0.10, 0.10), // label 1
+      Array(0.10, 0.10, 0.70, 0.10)  // label 2
+    ).map(_.map(math.log))
 
-class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
+    dataset = sqlContext.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42))
+  }
 
   def validatePrediction(predictionAndLabels: DataFrame): Unit = {
     val numOfErrorPredictions = predictionAndLabels.collect().count {
@@ -161,4 +176,26 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
       .select("features", "probability")
     validateProbabilities(featureAndProbabilities, model, "bernoulli")
   }
+
+  test("read/write") {
+    def checkModelData(model: NaiveBayesModel, model2: NaiveBayesModel): Unit = {
+      assert(model.pi === model2.pi)
+      assert(model.theta === model2.theta)
+    }
+    val nb = new NaiveBayes()
+    testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
+  }
+}
+
+object NaiveBayesSuite {
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "predictionCol" -> "myPrediction",
+    "smoothing" -> 0.1
+  )
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3e1d120c/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index c05f905..2724e51 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.clustering
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -25,16 +26,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
 
 private[clustering] case class TestRow(features: Vector)
 
-object KMeansSuite {
-  def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
-    val sc = sql.sparkContext
-    val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
-      .map(v => new TestRow(v))
-    sql.createDataFrame(rdd)
-  }
-}
-
-class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
+class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   final val k = 5
   @transient var dataset: DataFrame = _
@@ -106,4 +98,33 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(clusters === Set(0, 1, 2, 3, 4))
     assert(model.computeCost(dataset) < 0.1)
   }
+
+  test("read/write") {
+    def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
+      assert(model.clusterCenters === model2.clusterCenters)
+    }
+    val kmeans = new KMeans()
+    testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
+  }
+}
+
+object KMeansSuite {
+  def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
+    val sc = sql.sparkContext
+    val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
+      .map(v => new TestRow(v))
+    sql.createDataFrame(rdd)
+  }
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "predictionCol" -> "myPrediction",
+    "k" -> 3,
+    "maxIter" -> 2,
+    "tol" -> 0.01
+  )
 }

