Repository: spark Updated Branches: refs/heads/branch-1.4 92ccc5ba3 -> 97d4cd074
[SPARK-8049] [MLLIB] drop tmp col from OneVsRest output The temporary column should be dropped after we get the prediction column. harsha2010 Author: Xiangrui Meng <m...@databricks.com> Closes #6592 from mengxr/SPARK-8049 and squashes the following commits: 1d89107 [Xiangrui Meng] use SparkFunSuite 6ee70de [Xiangrui Meng] drop tmp col from OneVsRest output (cherry picked from commit 89f21f66b5549524d1a6e4fb576a4f80d9fef903) Signed-off-by: Xiangrui Meng <m...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/97d4cd07 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/97d4cd07 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/97d4cd07 Branch: refs/heads/branch-1.4 Commit: 97d4cd07406a3d2fd5be83b009988d8bc320b524 Parents: 92ccc5b Author: Xiangrui Meng <m...@databricks.com> Authored: Tue Jun 2 16:51:17 2015 -0700 Committer: Xiangrui Meng <m...@databricks.com> Committed: Tue Jun 2 16:53:26 2015 -0700 ---------------------------------------------------------------------- .../org/apache/spark/ml/classification/OneVsRest.scala | 1 + .../org/apache/spark/ml/classification/OneVsRestSuite.scala | 9 +++++++++ 2 files changed, 10 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/97d4cd07/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 7b726da..825f9ed 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -131,6 +131,7 @@ final class OneVsRestModel private[ml] ( // output label and label metadata as prediction val labelUdf = callUDF(label, DoubleType, col(accColName)) aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata)) + .drop(accColName) } } http://git-wip-us.apache.org/repos/asf/spark/blob/97d4cd07/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 770b568..1b354d0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -94,6 +94,15 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext { val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features) ova.fit(datasetWithLabelMetadata) } + + test("SPARK-8049: OneVsRest shouldn't output temp columns") { + val logReg = new LogisticRegression() + .setMaxIter(1) + val ovr = new OneVsRest() + .setClassifier(logReg) + val output = ovr.fit(dataset).transform(dataset) + assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) + } } private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org