Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/11419#discussion_r57983012
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala ---
    @@ -0,0 +1,291 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.clustering
    +
    +import org.apache.hadoop.fs.Path
    +
    +import org.apache.spark.SparkException
    +import org.apache.spark.annotation.{Experimental, Since}
    +import org.apache.spark.ml.{Estimator, Model}
    +import org.apache.spark.ml.param.{IntParam, ParamMap, Params}
    +import org.apache.spark.ml.param.shared._
    +import org.apache.spark.ml.util._
    +import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, 
GaussianMixtureModel => MLlibGMModel}
    +import org.apache.spark.mllib.linalg._
    +import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
    +import org.apache.spark.sql.{DataFrame, Row}
    +import org.apache.spark.sql.functions.{col, udf}
    +import org.apache.spark.sql.types.{IntegerType, StructType}
    +
    +
    +/**
    + * Common params for GaussianMixture and GaussianMixtureModel
    + */
    +private[clustering] trait GaussianMixtureParams extends Params with 
HasMaxIter with HasFeaturesCol
    +  with HasSeed with HasPredictionCol with HasProbabilityCol with HasTol {
    +
    +  /**
    +   * Set the number of clusters to create (k). Must be > 1. Default: 2.
    +   * @group param
    +   */
    +  @Since("2.0.0")
    +  final val k = new IntParam(this, "k", "number of clusters to create", 
(x: Int) => x > 1)
    +
    +  /** @group getParam */
    +  @Since("2.0.0")
    +  def getK: Int = $(k)
    +
    +  /**
    +   * Validates and transforms the input schema.
    +   * @param schema input schema
    +   * @return output schema
    +   */
    +  protected def validateAndTransformSchema(schema: StructType): StructType 
= {
    +    validateParams()
    +    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
    +    SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
    +    SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT)
    +  }
    +}
    +
    +/**
    + * :: Experimental ::
    + * Model fitted by GaussianMixture.
    + * @param parentModel a model trained by 
spark.mllib.clustering.GaussianMixture.
    + */
    +@Since("2.0.0")
    +@Experimental
    +class GaussianMixtureModel private[ml] (
    +    @Since("2.0.0") override val uid: String,
    +    private val parentModel: MLlibGMModel)
    +  extends Model[GaussianMixtureModel] with GaussianMixtureParams with 
MLWritable {
    +
    +  @Since("2.0.0")
    +  override def copy(extra: ParamMap): GaussianMixtureModel = {
    +    val copied = new GaussianMixtureModel(uid, parentModel)
    +    copyValues(copied, extra)
    +  }
    +
    +  @Since("2.0.0")
    +  override def transform(dataset: DataFrame): DataFrame = {
    +    val predUDF = udf((vector: Vector) => predict(vector))
    +    val probUDF = udf((vector: Vector) => predictProbability(vector))
    +    dataset.withColumn($(predictionCol), predUDF(col($(featuresCol))))
    +      .withColumn($(probabilityCol), probUDF(col($(featuresCol))))
    +  }
    +
    +  @Since("2.0.0")
    +  override def transformSchema(schema: StructType): StructType = {
    +    validateAndTransformSchema(schema)
    +  }
    +
    +  private[clustering] def predict(features: Vector): Int = 
parentModel.predict(features)
    +
    +  private[clustering] def predictProbability(features: Vector): Vector = {
    +    Vectors.dense(parentModel.predictSoft(features))
    +  }
    +
    +  @Since("2.0.0")
    +  def weights: Array[Double] = parentModel.weights
    +
    +  @Since("2.0.0")
    +  def gaussians: Array[MultivariateGaussian] = parentModel.gaussians
    +
    +  @Since("2.0.0")
    +  override def write: MLWriter = new 
GaussianMixtureModel.GaussianMixtureModelWriter(this)
    +
    +  private var trainingSummary: Option[GaussianMixtureSummary] = None
    +
    +  private[clustering] def setSummary(summary: GaussianMixtureSummary): 
this.type = {
    +    this.trainingSummary = Some(summary)
    +    this
    +  }
    +
    +  /**
    +   * Gets summary of model on training set. An exception is
    +   * thrown if `trainingSummary == None`.
    +   */
    +  @Since("2.0.0")
    +  def summary: GaussianMixtureSummary = trainingSummary match {
    +    case Some(summ) => summ
    +    case None =>
    +      throw new SparkException(
    +        s"No training summary available for the 
${this.getClass.getSimpleName}",
    +        new NullPointerException())
    +  }
    +}
    +
    +@Since("2.0.0")
    +object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
    +
    +  @Since("2.0.0")
    +  override def read: MLReader[GaussianMixtureModel] = new 
GaussianMixtureModelReader
    +
    +  @Since("2.0.0")
    +  override def load(path: String): GaussianMixtureModel = super.load(path)
    +
    +  /** [[MLWriter]] instance for [[GaussianMixtureModel]] */
    +  private[GaussianMixtureModel] class GaussianMixtureModelWriter(
    +      instance: GaussianMixtureModel) extends MLWriter {
    +
    +    private case class Data(weights: Array[Double], mus: Array[Vector], 
sigmas: Array[Matrix])
    +
    +    override protected def saveImpl(path: String): Unit = {
    +      // Save metadata and Params
    +      DefaultParamsWriter.saveMetadata(instance, path, sc)
    +      // Save model data: weights and gaussians
    +      val weights = instance.weights
    +      val gaussians = instance.gaussians
    +      val mus = gaussians.map(_.mu)
    +      val sigmas = gaussians.map(_.sigma)
    +      val data = Data(weights, mus, sigmas)
    +      val dataPath = new Path(path, "data").toString
    +      
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
    +    }
    +  }
    +
    +  private class GaussianMixtureModelReader extends 
MLReader[GaussianMixtureModel] {
    +
    +    /** Checked against metadata when loading model */
    +    private val className = classOf[GaussianMixtureModel].getName
    +
    +    override def load(path: String): GaussianMixtureModel = {
    +      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
    +
    +      val dataPath = new Path(path, "data").toString
    +      val df = sqlContext.read.parquet(dataPath)
    +      val weights = 
df.select("weights").head().get(0).asInstanceOf[Seq[Double]].toArray
    +      val mus = 
df.select("mus").head().get(0).asInstanceOf[Seq[Vector]].toArray
    +      val sigmas = 
df.select("sigmas").head().get(0).asInstanceOf[Seq[Matrix]].toArray
    +      require(mus.length == sigmas.length)
    +      val gaussians = (mus zip sigmas).map {
    +        case (mu, sigma) =>
    +          new MultivariateGaussian(mu, sigma)
    +      }
    +      val model = new GaussianMixtureModel(metadata.uid, new 
MLlibGMModel(weights, gaussians))
    +
    +      DefaultParamsReader.getAndSetParams(model, metadata)
    +      model
    +    }
    +  }
    +}
    +
    +/**
    + * :: Experimental ::
    + * GaussianMixture clustering.
    + */
    +@Since("2.0.0")
    +@Experimental
    +class GaussianMixture @Since("2.0.0") (
    +    @Since("2.0.0") override val uid: String)
    +  extends Estimator[GaussianMixtureModel] with GaussianMixtureParams with 
DefaultParamsWritable {
    +
    +  setDefault(
    +    k -> 2,
    +    maxIter -> 100,
    +    tol -> 0.01)
    +
    +  @Since("2.0.0")
    +  override def copy(extra: ParamMap): GaussianMixture = defaultCopy(extra)
    +
    +  @Since("2.0.0")
    +  def this() = this(Identifiable.randomUID("GaussianMixture"))
    +
    +  /** @group setParam */
    +  @Since("2.0.0")
    +  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
    +
    +  /** @group setParam */
    +  @Since("2.0.0")
    +  def setPredictionCol(value: String): this.type = set(predictionCol, 
value)
    +
    +  /** @group setParam */
    +  @Since("2.0.0")
    +  def setProbabilityCol(value: String): this.type = set(probabilityCol, 
value)
    +
    +  /** @group setParam */
    +  @Since("2.0.0")
    +  def setK(value: Int): this.type = set(k, value)
    +
    +  /** @group setParam */
    +  @Since("2.0.0")
    +  def setMaxIter(value: Int): this.type = set(maxIter, value)
    +
    +  /** @group setParam */
    +  @Since("2.0.0")
    +  def setTol(value: Double): this.type = set(tol, value)
    +
    +  /** @group setParam */
    +  @Since("2.0.0")
    +  def setSeed(value: Long): this.type = set(seed, value)
    +
    +  @Since("2.0.0")
    +  override def fit(dataset: DataFrame): GaussianMixtureModel = {
    +    val rdd = dataset.select(col($(featuresCol))).rdd.map { case 
Row(point: Vector) => point }
    +
    +    val algo = new MLlibGM()
    +      .setK($(k))
    +      .setMaxIterations($(maxIter))
    +      .setSeed($(seed))
    +      .setConvergenceTol($(tol))
    +    val parentModel = algo.run(rdd)
    +    val model = copyValues(new GaussianMixtureModel(uid, 
parentModel).setParent(this))
    +    val summary = new GaussianMixtureSummary(model.transform(dataset),
    +      $(predictionCol), $(probabilityCol), $(featuresCol))
    +    model.setSummary(summary)
    +  }
    +
    +  @Since("2.0.0")
    +  override def transformSchema(schema: StructType): StructType = {
    +    validateAndTransformSchema(schema)
    +  }
    +}
    +
    +@Since("2.0.0")
    +object GaussianMixture extends DefaultParamsReadable[GaussianMixture] {
    +
    +  @Since("2.0.0")
    +  override def load(path: String): GaussianMixture = super.load(path)
    +}
    +
    +class GaussianMixtureSummary private[clustering] (
    +    @Since("2.0.0") @transient val predictions: DataFrame,
    +    @Since("2.0.0") val predictionCol: String,
    +    @Since("2.0.0") val probabilityCol: String,
    +    @Since("2.0.0") val featuresCol: String) extends Serializable {
    +
    +  /**
    +   * Predicted cluster index for each row of the transformed data.
    +   */
    +  @Since("2.0.0")
    +  @transient lazy val cluster: DataFrame = 
predictions.select(predictionCol)
    +
    +  /**
    +   * Probability of each cluster.
    +   */
    +  @Since("2.0.0")
    +  @transient lazy val probability: DataFrame = 
predictions.select(probabilityCol)
    +
    +  /**
    +   * Size of each cluster.
    +   */
    +  @Since("2.0.0")
    +  lazy val size: Array[Int] = cluster.rdd.map {
    +    case Row(clusterIdx: Int) => (clusterIdx, 1)
    --- End diff --
    
    Easier to do this within DataFrames:
    ```
    cluster.groupBy(predictionCol).count()
      .select(predictionCol, "count")
      .collect()
      .map { case Row(cluster: Int, count: Long) => cluster -> count }
      .toMap
    ```


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to