GitHub user felixcheung commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20907#discussion_r178425007
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala ---
    @@ -185,6 +187,47 @@ class KMeansModel private[ml] (
       }
     }
     
    +/** Helper class for storing model data */
    +private case class ClusterData(clusterIdx: Int, clusterCenter: Vector)
    +
    +
    +/** A writer for KMeans that handles the "internal" (or default) format */
    +private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegister {
    +
    +  override def format(): String = "internal"
    +  override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel"
    +
    +  override def write(path: String, sparkSession: SparkSession,
    +    optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
    +    val instance = stage.asInstanceOf[KMeansModel]
    +    val sc = sparkSession.sparkContext
    +    // Save metadata and Params
    +    DefaultParamsWriter.saveMetadata(instance, path, sc)
    +    // Save model data: cluster centers
    +    val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map {
    +      case (center, idx) =>
    +        ClusterData(idx, center)
    --- End diff --
    
    👍 


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to