Github user jkbradley commented on a diff in the pull request: https://github.com/apache/spark/pull/11419#discussion_r57983007 --- Diff: mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala --- @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.clustering + +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkException +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.{IntParam, ParamMap, Params} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, GaussianMixtureModel => MLlibGMModel} +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.stat.distribution.MultivariateGaussian +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{IntegerType, StructType} + + +/** + * Common params for GaussianMixture and GaussianMixtureModel + */ +private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter with HasFeaturesCol + with HasSeed with HasPredictionCol with HasProbabilityCol with HasTol { + + /** + * Set the number of clusters to create (k). Must be > 1. Default: 2. + * @group param + */ + @Since("2.0.0") + final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1) + + /** @group getParam */ + @Since("2.0.0") + def getK: Int = $(k) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @return output schema + */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT) + } +} + +/** + * :: Experimental :: + * Model fitted by GaussianMixture. + * @param parentModel a model trained by spark.mllib.clustering.GaussianMixture. 
+ */ +@Since("2.0.0") +@Experimental +class GaussianMixtureModel private[ml] ( + @Since("2.0.0") override val uid: String, + private val parentModel: MLlibGMModel) + extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable { + + @Since("2.0.0") + override def copy(extra: ParamMap): GaussianMixtureModel = { + val copied = new GaussianMixtureModel(uid, parentModel) + copyValues(copied, extra) + } + + @Since("2.0.0") + override def transform(dataset: DataFrame): DataFrame = { + val predUDF = udf((vector: Vector) => predict(vector)) + val probUDF = udf((vector: Vector) => predictProbability(vector)) + dataset.withColumn($(predictionCol), predUDF(col($(featuresCol)))) + .withColumn($(probabilityCol), probUDF(col($(featuresCol)))) + } + + @Since("2.0.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + private[clustering] def predict(features: Vector): Int = parentModel.predict(features) + + private[clustering] def predictProbability(features: Vector): Vector = { + Vectors.dense(parentModel.predictSoft(features)) + } + + @Since("2.0.0") + def weights: Array[Double] = parentModel.weights + + @Since("2.0.0") + def gaussians: Array[MultivariateGaussian] = parentModel.gaussians + + @Since("2.0.0") + override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this) + + private var trainingSummary: Option[GaussianMixtureSummary] = None + + private[clustering] def setSummary(summary: GaussianMixtureSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** + * Gets summary of model on training set. An exception is + * thrown if `trainingSummary == None`. 
+ */ + @Since("2.0.0") + def summary: GaussianMixtureSummary = trainingSummary match { + case Some(summ) => summ + case None => + throw new SparkException( + s"No training summary available for the ${this.getClass.getSimpleName}", + new NullPointerException()) + } +} + +@Since("2.0.0") +object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { + + @Since("2.0.0") + override def read: MLReader[GaussianMixtureModel] = new GaussianMixtureModelReader + + @Since("2.0.0") + override def load(path: String): GaussianMixtureModel = super.load(path) + + /** [[MLWriter]] instance for [[GaussianMixtureModel]] */ + private[GaussianMixtureModel] class GaussianMixtureModelWriter( + instance: GaussianMixtureModel) extends MLWriter { + + private case class Data(weights: Array[Double], mus: Array[Vector], sigmas: Array[Matrix]) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: weights and gaussians + val weights = instance.weights + val gaussians = instance.gaussians + val mus = gaussians.map(_.mu) + val sigmas = gaussians.map(_.sigma) + val data = Data(weights, mus, sigmas) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class GaussianMixtureModelReader extends MLReader[GaussianMixtureModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[GaussianMixtureModel].getName + + override def load(path: String): GaussianMixtureModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val df = sqlContext.read.parquet(dataPath) + val weights = df.select("weights").head().get(0).asInstanceOf[Seq[Double]].toArray + val mus = df.select("mus").head().get(0).asInstanceOf[Seq[Vector]].toArray + val sigmas = 
df.select("sigmas").head().get(0).asInstanceOf[Seq[Matrix]].toArray + require(mus.length == sigmas.length) + val gaussians = (mus zip sigmas).map { + case (mu, sigma) => + new MultivariateGaussian(mu, sigma) + } + val model = new GaussianMixtureModel(metadata.uid, new MLlibGMModel(weights, gaussians)) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} + +/** + * :: Experimental :: + * GaussianMixture clustering. + */ +@Since("2.0.0") +@Experimental +class GaussianMixture @Since("2.0.0") ( + @Since("2.0.0") override val uid: String) + extends Estimator[GaussianMixtureModel] with GaussianMixtureParams with DefaultParamsWritable { + + setDefault( + k -> 2, + maxIter -> 100, + tol -> 0.01) + + @Since("2.0.0") + override def copy(extra: ParamMap): GaussianMixture = defaultCopy(extra) + + @Since("2.0.0") + def this() = this(Identifiable.randomUID("GaussianMixture")) + + /** @group setParam */ + @Since("2.0.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("2.0.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("2.0.0") + def setProbabilityCol(value: String): this.type = set(probabilityCol, value) + + /** @group setParam */ + @Since("2.0.0") + def setK(value: Int): this.type = set(k, value) + + /** @group setParam */ + @Since("2.0.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + @Since("2.0.0") + def setTol(value: Double): this.type = set(tol, value) + + /** @group setParam */ + @Since("2.0.0") + def setSeed(value: Long): this.type = set(seed, value) + + @Since("2.0.0") + override def fit(dataset: DataFrame): GaussianMixtureModel = { + val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } + + val algo = new MLlibGM() + .setK($(k)) + .setMaxIterations($(maxIter)) + .setSeed($(seed)) + .setConvergenceTol($(tol)) + val parentModel = 
algo.run(rdd) + val model = copyValues(new GaussianMixtureModel(uid, parentModel).setParent(this)) + val summary = new GaussianMixtureSummary(model.transform(dataset), + $(predictionCol), $(probabilityCol), $(featuresCol)) + model.setSummary(summary) + } + + @Since("2.0.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } +} + +@Since("2.0.0") +object GaussianMixture extends DefaultParamsReadable[GaussianMixture] { + + @Since("2.0.0") + override def load(path: String): GaussianMixture = super.load(path) +} + +class GaussianMixtureSummary private[clustering] ( + @Since("2.0.0") @transient val predictions: DataFrame, + @Since("2.0.0") val predictionCol: String, + @Since("2.0.0") val probabilityCol: String, + @Since("2.0.0") val featuresCol: String) extends Serializable { + + /** + * Cluster centers of the transformed data. + */ + @Since("2.0.0") + @transient lazy val cluster: DataFrame = predictions.select(predictionCol) + + /** + * Probability of each cluster. + */ + @Since("2.0.0") + @transient lazy val probability: DataFrame = predictions.select(probabilityCol) + + /** + * Size of each cluster. + */ + @Since("2.0.0") + lazy val size: Array[Int] = cluster.rdd.map { --- End diff -- I'd prefer to call this clusterSizes. I'll make a JIRA for updating KMeansSummary too.
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and wishes so, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org