Repository: spark
Updated Branches:
  refs/heads/branch-2.0 7fbac48f0 -> d99d90982


[SPARK-16750][FOLLOW-UP][ML] Add transformSchema for StringIndexer/VectorAssembler and fix failed tests.

## What changes were proposed in this pull request?
This is a follow-up to #14378. While adding ```transformSchema``` for all 
estimators and transformers, I found that tests failed for ```StringIndexer``` 
and ```VectorAssembler```, so I split that part of the work into this separate 
PR to make it easier to review.
With ```transformSchema``` in place, the corresponding tests should expect an 
```IllegalArgumentException``` during schema validation. It is more efficient to 
throw the exception at the start of ```fit``` or ```transform``` than partway 
through processing.
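
As a rough illustration of the new fail-fast behavior (not part of this patch), 
the sketch below reproduces the two failure modes covered by the modified tests: 
a ```StringIndexer``` whose output column already exists now fails in ```fit```, 
and a ```VectorAssembler``` fed a ```StringType``` column now fails at the start 
of ```transform```, both with ```IllegalArgumentException```. The SparkSession 
setup, the sample data, and the ```showFailure``` helper are illustrative 
assumptions, not code from this commit.

```scala
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.sql.SparkSession

object TransformSchemaDemo {

  // Tiny stand-in for ScalaTest's intercept (hypothetical helper): runs a block
  // and reports the schema-validation error it throws.
  private def showFailure(block: => Any): Unit =
    try {
      block
      println("Unexpectedly succeeded")
    } catch {
      case e: IllegalArgumentException =>
        println(s"Rejected at schema validation: ${e.getMessage}")
    }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("transformSchema-demo")
      .getOrCreate()
    import spark.implicits._

    // StringIndexer: the requested output column already exists, so the
    // transformSchema call at the start of fit() rejects the schema immediately,
    // before any label counting runs.
    val df1 = Seq((1, 2), (3, 4)).toDF("input", "output")
    showFailure {
      new StringIndexer().setInputCol("input").setOutputCol("output").fit(df1)
    }

    // VectorAssembler: a StringType input column is rejected by the
    // transformSchema call at the start of transform(), rather than surfacing
    // later while rows are being assembled.
    val df2 = Seq(("a", 1.0)).toDF("text", "num")
    showFailure {
      new VectorAssembler()
        .setInputCols(Array("text", "num"))
        .setOutputCol("features")
        .transform(df2)
    }

    spark.stop()
  }
}
```

Before this change, the ```VectorAssembler``` case only surfaced as a 
```SparkException``` while rows were being assembled, which is why the old test 
asserted on a different message.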

## How was this patch tested?
Modified unit tests.

Author: Yanbo Liang <yblia...@gmail.com>

Closes #14455 from yanboliang/transformSchema.

(cherry picked from commit 6cbde337a539e5bb170d0eb81f715a95ee9c9af3)
Signed-off-by: Sean Owen <so...@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d99d9098
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d99d9098
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d99d9098

Branch: refs/heads/branch-2.0
Commit: d99d90982bb7bd2c9783dd007d5e16aa8703df6d
Parents: 7fbac48
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Fri Aug 5 22:07:59 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Fri Aug 5 22:08:08 2016 +0100

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/StringIndexer.scala     |  4 +++-
 .../org/apache/spark/ml/feature/VectorAssembler.scala   |  1 +
 .../apache/spark/ml/feature/StringIndexerSuite.scala    | 12 ++++++++++--
 .../apache/spark/ml/feature/VectorAssemblerSuite.scala  |  4 ++--
 4 files changed, 16 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d99d9098/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index fe79e2e..80fe467 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -85,6 +85,7 @@ class StringIndexer @Since("1.4.0") (
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): StringIndexerModel = {
+    transformSchema(dataset.schema, logging = true)
     val counts = dataset.select(col($(inputCol)).cast(StringType))
       .rdd
       .map(_.getString(0))
@@ -160,7 +161,7 @@ class StringIndexerModel (
         "Skip StringIndexerModel.")
       return dataset.toDF
     }
-    validateAndTransformSchema(dataset.schema)
+    transformSchema(dataset.schema, logging = true)
 
     val indexer = udf { label: String =>
       if (labelToIndex.contains(label)) {
@@ -305,6 +306,7 @@ class IndexToString private[ml] (@Since("1.5.0") override val uid: String)
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
     val inputColSchema = dataset.schema($(inputCol))
     // If the labels array is empty use column metadata
     val values = if (!isDefined(labels) || $(labels).isEmpty) {

http://git-wip-us.apache.org/repos/asf/spark/blob/d99d9098/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 142a2ae..ca90053 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -51,6 +51,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
     // Schema transformation.
     val schema = dataset.schema
     lazy val first = dataset.toDF.first()

http://git-wip-us.apache.org/repos/asf/spark/blob/d99d9098/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index c221d4a..b478fea 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -120,12 +120,20 @@ class StringIndexerSuite
 
   test("StringIndexerModel can't overwrite output column") {
     val df = spark.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
+    intercept[IllegalArgumentException] {
+      new StringIndexer()
+        .setInputCol("input")
+        .setOutputCol("output")
+        .fit(df)
+    }
+
     val indexer = new StringIndexer()
       .setInputCol("input")
-      .setOutputCol("output")
+      .setOutputCol("indexedInput")
       .fit(df)
+
     intercept[IllegalArgumentException] {
-      indexer.transform(df)
+      indexer.setOutputCol("output").transform(df)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d99d9098/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index 14973e7..561493f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -74,10 +74,10 @@ class VectorAssemblerSuite
     val assembler = new VectorAssembler()
       .setInputCols(Array("a", "b", "c"))
       .setOutputCol("features")
-    val thrown = intercept[SparkException] {
+    val thrown = intercept[IllegalArgumentException] {
       assembler.transform(df)
     }
-    assert(thrown.getMessage contains "VectorAssembler does not support the StringType type")
+    assert(thrown.getMessage contains "Data type StringType is not supported")
   }
 
   test("ML attributes") {

