Repository: spark
Updated Branches:
  refs/heads/branch-1.4 92ccc5ba3 -> 97d4cd074


[SPARK-8049] [MLLIB] drop tmp col from OneVsRest output

The temporary column should be dropped after we get the prediction column. 
harsha2010

Author: Xiangrui Meng <m...@databricks.com>

Closes #6592 from mengxr/SPARK-8049 and squashes the following commits:

1d89107 [Xiangrui Meng] use SparkFunSuite
6ee70de [Xiangrui Meng] drop tmp col from OneVsRest output

(cherry picked from commit 89f21f66b5549524d1a6e4fb576a4f80d9fef903)
Signed-off-by: Xiangrui Meng <m...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/97d4cd07
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/97d4cd07
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/97d4cd07

Branch: refs/heads/branch-1.4
Commit: 97d4cd07406a3d2fd5be83b009988d8bc320b524
Parents: 92ccc5b
Author: Xiangrui Meng <m...@databricks.com>
Authored: Tue Jun 2 16:51:17 2015 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Tue Jun 2 16:53:26 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/classification/OneVsRest.scala      | 1 +
 .../org/apache/spark/ml/classification/OneVsRestSuite.scala | 9 +++++++++
 2 files changed, 10 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/97d4cd07/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 7b726da..825f9ed 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -131,6 +131,7 @@ final class OneVsRestModel private[ml] (
     // output label and label metadata as prediction
     val labelUdf = callUDF(label, DoubleType, col(accColName))
     aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
+      .drop(accColName)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/97d4cd07/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 770b568..1b354d0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -94,6 +94,15 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
     val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
     ova.fit(datasetWithLabelMetadata)
   }
+
+  test("SPARK-8049: OneVsRest shouldn't output temp columns") {
+    val logReg = new LogisticRegression()
+      .setMaxIter(1)
+    val ovr = new OneVsRest()
+      .setClassifier(logReg)
+    val output = ovr.fit(dataset).transform(dataset)
+    assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
+  }
 }
 
 private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to