Repository: spark Updated Branches: refs/heads/branch-2.0 1fba7595f -> dcbe85ff2
[SPARK-14844][ML] Add setFeaturesCol and setPredictionCol to KMeansModel ## What changes were proposed in this pull request? Introduction of setFeaturesCol and setPredictionCol methods to KMeansModel in ML library. ## How was this patch tested? By running KMeansSuite. Author: Dominik Jastrzębski <dominik.jastrzeb...@codilime.com> Closes #12609 from dominik-jastrzebski/master. (cherry picked from commit abecbcd5e9598471b705a2f701731af1adc9d48b) Signed-off-by: Nick Pentreath <ni...@za.ibm.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/dcbe85ff Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/dcbe85ff Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/dcbe85ff Branch: refs/heads/branch-2.0 Commit: dcbe85ff20736915bba6c3269221c3367ee798c5 Parents: 1fba759 Author: Dominik Jastrzębski <dominik.jastrzeb...@codilime.com> Authored: Wed May 4 14:25:51 2016 +0200 Committer: Nick Pentreath <ni...@za.ibm.com> Committed: Wed May 4 14:26:11 2016 +0200 ---------------------------------------------------------------------- .../org/apache/spark/ml/clustering/KMeans.scala | 8 ++++++++ .../org/apache/spark/ml/clustering/KMeansSuite.scala | 15 +++++++++++++++ 2 files changed, 23 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/dcbe85ff/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 7c9ac02..42a2539 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -105,6 +105,14 @@ class KMeansModel private[ml] ( copyValues(copied, extra) } + /** @group setParam */ + @Since("2.0.0") + def 
setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("2.0.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { val predictUDF = udf((vector: Vector) => predict(vector)) http://git-wip-us.apache.org/repos/asf/spark/blob/dcbe85ff/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 2ca386e..241d219 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -117,6 +117,21 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(clusterSizes.forall(_ >= 0)) } + test("KMeansModel transform with non-default feature and prediction cols") { + val featuresColName = "kmeans_model_features" + val predictionColName = "kmeans_model_prediction" + + val model = new KMeans().setK(k).setSeed(1).fit(dataset) + model.setFeaturesCol(featuresColName).setPredictionCol(predictionColName) + + val transformed = model.transform(dataset.withColumnRenamed("features", featuresColName)) + Seq(featuresColName, predictionColName).foreach { column => + assert(transformed.columns.contains(column)) + } + assert(model.getFeaturesCol == featuresColName) + assert(model.getPredictionCol == predictionColName) + } + test("read/write") { def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { assert(model.clusterCenters === model2.clusterCenters) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org