Repository: spark Updated Branches: refs/heads/branch-2.0 7fbac48f0 -> d99d90982
[SPARK-16750][FOLLOW-UP][ML] Add transformSchema for StringIndexer/VectorAssembler and fix failed tests. ## What changes were proposed in this pull request? This is a follow-up to #14378. When adding ```transformSchema``` to all estimators and transformers, I found that some tests for ```StringIndexer``` and ```VectorAssembler``` failed. So I moved this part of the work into a separate PR to make it easier to review. After adding ```transformSchema```, the corresponding tests should throw ```IllegalArgumentException``` during schema validation. It is more efficient to throw the exception at the start of ```fit``` or ```transform``` than partway through the computation. ## How was this patch tested? Modified unit tests. Author: Yanbo Liang <yblia...@gmail.com> Closes #14455 from yanboliang/transformSchema. (cherry picked from commit 6cbde337a539e5bb170d0eb81f715a95ee9c9af3) Signed-off-by: Sean Owen <so...@cloudera.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d99d9098 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d99d9098 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d99d9098 Branch: refs/heads/branch-2.0 Commit: d99d90982bb7bd2c9783dd007d5e16aa8703df6d Parents: 7fbac48 Author: Yanbo Liang <yblia...@gmail.com> Authored: Fri Aug 5 22:07:59 2016 +0100 Committer: Sean Owen <so...@cloudera.com> Committed: Fri Aug 5 22:08:08 2016 +0100 ---------------------------------------------------------------------- .../org/apache/spark/ml/feature/StringIndexer.scala | 4 +++- .../org/apache/spark/ml/feature/VectorAssembler.scala | 1 + .../apache/spark/ml/feature/StringIndexerSuite.scala | 12 ++++++++++-- .../apache/spark/ml/feature/VectorAssemblerSuite.scala | 4 ++-- 4 files changed, 16 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/spark/blob/d99d9098/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index fe79e2e..80fe467 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -85,6 +85,7 @@ class StringIndexer @Since("1.4.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): StringIndexerModel = { + transformSchema(dataset.schema, logging = true) val counts = dataset.select(col($(inputCol)).cast(StringType)) .rdd .map(_.getString(0)) @@ -160,7 +161,7 @@ class StringIndexerModel ( "Skip StringIndexerModel.") return dataset.toDF } - validateAndTransformSchema(dataset.schema) + transformSchema(dataset.schema, logging = true) val indexer = udf { label: String => if (labelToIndex.contains(label)) { @@ -305,6 +306,7 @@ class IndexToString private[ml] (@Since("1.5.0") override val uid: String) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) val inputColSchema = dataset.schema($(inputCol)) // If the labels array is empty use column metadata val values = if (!isDefined(labels) || $(labels).isEmpty) { http://git-wip-us.apache.org/repos/asf/spark/blob/d99d9098/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 142a2ae..ca90053 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ 
-51,6 +51,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) // Schema transformation. val schema = dataset.schema lazy val first = dataset.toDF.first() http://git-wip-us.apache.org/repos/asf/spark/blob/d99d9098/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index c221d4a..b478fea 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -120,12 +120,20 @@ class StringIndexerSuite test("StringIndexerModel can't overwrite output column") { val df = spark.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output") + intercept[IllegalArgumentException] { + new StringIndexer() + .setInputCol("input") + .setOutputCol("output") + .fit(df) + } + val indexer = new StringIndexer() .setInputCol("input") - .setOutputCol("output") + .setOutputCol("indexedInput") .fit(df) + intercept[IllegalArgumentException] { - indexer.transform(df) + indexer.setOutputCol("output").transform(df) } } http://git-wip-us.apache.org/repos/asf/spark/blob/d99d9098/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 14973e7..561493f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -74,10 
+74,10 @@ class VectorAssemblerSuite val assembler = new VectorAssembler() .setInputCols(Array("a", "b", "c")) .setOutputCol("features") - val thrown = intercept[SparkException] { + val thrown = intercept[IllegalArgumentException] { assembler.transform(df) } - assert(thrown.getMessage contains "VectorAssembler does not support the StringType type") + assert(thrown.getMessage contains "Data type StringType is not supported") } test("ML attributes") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org