Repository: spark
Updated Branches:
  refs/heads/branch-1.6 d59a08f7c -> abe8f991a
[SPARK-12874][ML] ML StringIndexer does not protect itself from column name duplication

## What changes were proposed in this pull request?

`StringIndexerModel.transform` did not validate its input schema, so ML StringIndexer did not protect itself from column name duplication: an output column name that already exists in the dataset was not rejected. This patch calls `validateAndTransformSchema(dataset.schema)` in `transform`, so such a collision now fails with an `IllegalArgumentException`. Schema validation in `StringIndexer` and `StringIndexerModel` could still be improved more broadly, but that is better addressed in a separate issue.

## How was this patch tested?

Added a unit test to `StringIndexerSuite`.

Author: Yu ISHIKAWA <yuu.ishik...@gmail.com>

Closes #11370 from yu-iskw/SPARK-12874.

(cherry picked from commit 14e2700de29d06460179a94cc9816bcd37344cf7)
Signed-off-by: Xiangrui Meng <m...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/abe8f991
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/abe8f991
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/abe8f991

Branch: refs/heads/branch-1.6
Commit: abe8f991a32bef92fbbcd2911836bb7d8e61ca97
Parents: d59a08f
Author: Yu ISHIKAWA <yuu.ishik...@gmail.com>
Authored: Thu Feb 25 13:21:33 2016 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Thu Feb 25 13:23:44 2016 -0800

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/StringIndexer.scala      |  1 +
 .../org/apache/spark/ml/feature/StringIndexerSuite.scala | 11 +++++++++++
 2 files changed, 12 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/abe8f991/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 5c40c35..b3413a1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -149,6 +149,7 @@ class StringIndexerModel (
         "Skip StringIndexerModel.")
       return dataset
     }
+    validateAndTransformSchema(dataset.schema)
 
     val indexer = udf { label: String =>
       if (labelToIndex.contains(label)) {


http://git-wip-us.apache.org/repos/asf/spark/blob/abe8f991/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 749bfac..26f4613 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -118,6 +118,17 @@ class StringIndexerSuite
     assert(indexerModel.transform(df).eq(df))
   }
 
+  test("StringIndexerModel can't overwrite output column") {
+    val df = sqlContext.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
+    val indexer = new StringIndexer()
+      .setInputCol("input")
+      .setOutputCol("output")
+      .fit(df)
+    intercept[IllegalArgumentException] {
+      indexer.transform(df)
+    }
+  }
+
   test("StringIndexer read/write") {
     val t = new StringIndexer()
       .setInputCol("myInputCol")
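The control flow that the diff context suggests for `StringIndexerModel.transform` after this patch (skip when the input column is missing, otherwise validate the schema before indexing) can be sketched without Spark. This is an illustration under stated assumptions, not Spark's implementation: `Field`, `Schema`, `validateOutputCol` and `transformSchemaSketch` are hypothetical names, the schema is modeled as a plain list of named fields instead of Spark's `StructType`, and only the output-column collision is checked. Scala's `require` throws `IllegalArgumentException`, which is the exception the new test intercepts.

```scala
// Spark-free sketch of the schema check added to StringIndexerModel.transform.
// Field, Schema, validateOutputCol and transformSchemaSketch are hypothetical
// stand-ins for illustration only; they are not Spark APIs.
object StringIndexerSchemaSketch {

  // Hypothetical simplified schema: column names plus a type label.
  final case class Field(name: String, dataType: String)
  type Schema = Seq[Field]

  // Rejects an output column name that already exists in the schema.
  // require throws IllegalArgumentException when the condition is false.
  def validateOutputCol(schema: Schema, outputCol: String): Unit =
    require(!schema.exists(_.name == outputCol),
      s"Output column $outputCol already exists.")

  // Returns the schema after indexing:
  //  - if inputCol is missing, the schema is returned unchanged (the "Skip" branch);
  //  - otherwise outputCol is validated, then appended as a double column.
  def transformSchemaSketch(schema: Schema, inputCol: String, outputCol: String): Schema =
    if (!schema.exists(_.name == inputCol)) {
      schema
    } else {
      validateOutputCol(schema, outputCol)
      schema :+ Field(outputCol, "double")
    }

  def main(args: Array[String]): Unit = {
    // Same column names as the new unit test: "output" already exists.
    val schema = Seq(Field("input", "int"), Field("output", "int"))
    val rejected =
      try { transformSchemaSketch(schema, "input", "output"); false }
      catch { case _: IllegalArgumentException => true }
    println(s"duplicate output column rejected: $rejected")

    // A fresh output column name is accepted and appended.
    println(transformSchemaSketch(schema, "input", "inputIndex").map(_.name).mkString(", "))
  }
}
```

Running `main` prints `duplicate output column rejected: true` followed by `input, output, inputIndex`. Failing inside `transform` surfaces the conflicting name at the point where it is introduced, instead of letting a duplicated column name through.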