Repository: spark
Updated Branches:
  refs/heads/branch-2.0 1fba7595f -> dcbe85ff2


[SPARK-14844][ML] Add setFeaturesCol and setPredictionCol to KMeansModel

## What changes were proposed in this pull request?

This change adds the setFeaturesCol and setPredictionCol setter methods to
KMeansModel in the ML library.

## How was this patch tested?

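By running KMeansSuite, which now includes a test that calls transform with
non-default feature and prediction column names.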

Author: Dominik Jastrzębski <dominik.jastrzeb...@codilime.com>

Closes #12609 from dominik-jastrzebski/master.

(cherry picked from commit abecbcd5e9598471b705a2f701731af1adc9d48b)
Signed-off-by: Nick Pentreath <ni...@za.ibm.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/dcbe85ff
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/dcbe85ff
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/dcbe85ff

Branch: refs/heads/branch-2.0
Commit: dcbe85ff20736915bba6c3269221c3367ee798c5
Parents: 1fba759
Author: Dominik Jastrzębski <dominik.jastrzeb...@codilime.com>
Authored: Wed May 4 14:25:51 2016 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Wed May 4 14:26:11 2016 +0200

----------------------------------------------------------------------
 .../org/apache/spark/ml/clustering/KMeans.scala      |  8 ++++++++
 .../org/apache/spark/ml/clustering/KMeansSuite.scala | 15 +++++++++++++++
 2 files changed, 23 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/dcbe85ff/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 7c9ac02..42a2539 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -105,6 +105,14 @@ class KMeansModel private[ml] (
     copyValues(copied, extra)
   }
 
+  /** @group setParam */
+  @Since("2.0.0")
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  @Since("2.0.0")
+  def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     val predictUDF = udf((vector: Vector) => predict(vector))

http://git-wip-us.apache.org/repos/asf/spark/blob/dcbe85ff/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 2ca386e..241d219 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -117,6 +117,21 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     assert(clusterSizes.forall(_ >= 0))
   }
 
+  test("KMeansModel transform with non-default feature and prediction cols") {
+    val featuresColName = "kmeans_model_features"
+    val predictionColName = "kmeans_model_prediction"
+
+    val model = new KMeans().setK(k).setSeed(1).fit(dataset)
+    model.setFeaturesCol(featuresColName).setPredictionCol(predictionColName)
+
+    val transformed = model.transform(dataset.withColumnRenamed("features", featuresColName))
+    Seq(featuresColName, predictionColName).foreach { column =>
+      assert(transformed.columns.contains(column))
+    }
+    assert(model.getFeaturesCol == featuresColName)
+    assert(model.getPredictionCol == predictionColName)
+  }
+
   test("read/write") {
     def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
       assert(model.clusterCenters === model2.clusterCenters)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to