Repository: spark
Updated Branches:
  refs/heads/branch-2.3 f3efbfa4b -> 0663b6119


http://git-wip-us.apache.org/repos/asf/spark/blob/0663b611/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 775a04d..df24367 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -17,17 +17,14 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
 
-class StringIndexerSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -46,19 +43,23 @@ class StringIndexerSuite
       .setInputCol("label")
       .setOutputCol("labelIndex")
     val indexerModel = indexer.fit(df)
-
     MLTestingUtils.checkCopyAndUids(indexer, indexerModel)
-
-    val transformed = indexerModel.transform(df)
-    val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
-      .asInstanceOf[NominalAttribute]
-    assert(attr.values.get === Array("a", "c", "b"))
-    val output = transformed.select("id", "labelIndex").rdd.map { r =>
-      (r.getInt(0), r.getDouble(1))
-    }.collect().toSet
     // a -> 0, b -> 2, c -> 1
-    val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
-    assert(output === expected)
+    val expected = Seq(
+      (0, 0.0),
+      (1, 2.0),
+      (2, 1.0),
+      (3, 0.0),
+      (4, 0.0),
+      (5, 1.0)
+    ).toDF("id", "labelIndex")
+
+    testTransformerByGlobalCheckFunc[(Int, String)](df, indexerModel, "id", "labelIndex") { rows =>
+      val attr = Attribute.fromStructField(rows.head.schema("labelIndex"))
+        .asInstanceOf[NominalAttribute]
+      assert(attr.values.get === Array("a", "c", "b"))
+      assert(rows.seq === expected.collect().toSeq)
+    }
   }
 
   test("StringIndexerUnseen") {
@@ -70,36 +71,38 @@ class StringIndexerSuite
       .setInputCol("label")
       .setOutputCol("labelIndex")
       .fit(df)
+
     // Verify we throw by default with unseen values
-    intercept[SparkException] {
-      indexer.transform(df2).collect()
-    }
+    testTransformerByInterceptingException[(Int, String)](
+      df2,
+      indexer,
+      "Unseen label:",
+      "labelIndex")
 
-    indexer.setHandleInvalid("skip")
     // Verify that we skip the c record
-    val transformedSkip = indexer.transform(df2)
-    val attrSkip = Attribute.fromStructField(transformedSkip.schema("labelIndex"))
-      .asInstanceOf[NominalAttribute]
-    assert(attrSkip.values.get === Array("b", "a"))
-    val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r =>
-      (r.getInt(0), r.getDouble(1))
-    }.collect().toSet
     // a -> 1, b -> 0
-    val expectedSkip = Set((0, 1.0), (1, 0.0))
-    assert(outputSkip === expectedSkip)
+    indexer.setHandleInvalid("skip")
+
+    val expectedSkip = Seq((0, 1.0), (1, 0.0)).toDF()
+    testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id", "labelIndex") { rows =>
+      val attrSkip = Attribute.fromStructField(rows.head.schema("labelIndex"))
+        .asInstanceOf[NominalAttribute]
+      assert(attrSkip.values.get === Array("b", "a"))
+      assert(rows.seq === expectedSkip.collect().toSeq)
+    }
 
     indexer.setHandleInvalid("keep")
-    // Verify that we keep the unseen records
-    val transformedKeep = indexer.transform(df2)
-    val attrKeep = Attribute.fromStructField(transformedKeep.schema("labelIndex"))
-      .asInstanceOf[NominalAttribute]
-    assert(attrKeep.values.get === Array("b", "a", "__unknown"))
-    val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r =>
-      (r.getInt(0), r.getDouble(1))
-    }.collect().toSet
+
     // a -> 1, b -> 0, c -> 2, d -> 3
-    val expectedKeep = Set((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0))
-    assert(outputKeep === expectedKeep)
+    val expectedKeep = Seq((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)).toDF()
+
+    // Verify that we keep the unseen records
+    testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id", "labelIndex") { rows =>
+      val attrKeep = Attribute.fromStructField(rows.head.schema("labelIndex"))
+        .asInstanceOf[NominalAttribute]
+      assert(attrKeep.values.get === Array("b", "a", "__unknown"))
+      assert(rows === expectedKeep.collect().toSeq)
+    }
   }
 
   test("StringIndexer with a numeric input column") {
@@ -109,16 +112,14 @@ class StringIndexerSuite
       .setInputCol("label")
       .setOutputCol("labelIndex")
       .fit(df)
-    val transformed = indexer.transform(df)
-    val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
-      .asInstanceOf[NominalAttribute]
-    assert(attr.values.get === Array("100", "300", "200"))
-    val output = transformed.select("id", "labelIndex").rdd.map { r =>
-      (r.getInt(0), r.getDouble(1))
-    }.collect().toSet
     // 100 -> 0, 200 -> 2, 300 -> 1
-    val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
-    assert(output === expected)
+    val expected = Seq((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)).toDF()
+    testTransformerByGlobalCheckFunc[(Int, String)](df, indexer, "id", "labelIndex") { rows =>
+      val attr = Attribute.fromStructField(rows.head.schema("labelIndex"))
+        .asInstanceOf[NominalAttribute]
+      assert(attr.values.get === Array("100", "300", "200"))
+      assert(rows === expected.collect().toSeq)
+    }
   }
 
   test("StringIndexer with NULLs") {
@@ -133,37 +134,36 @@ class StringIndexerSuite
 
     withClue("StringIndexer should throw error when setHandleInvalid=error " +
       "when given NULL values") {
-      intercept[SparkException] {
-        indexer.setHandleInvalid("error")
-        indexer.fit(df).transform(df2).collect()
-      }
+      indexer.setHandleInvalid("error")
+      testTransformerByInterceptingException[(Int, String)](
+        df2,
+        indexer.fit(df),
+        "StringIndexer encountered NULL value.",
+        "labelIndex")
     }
 
     indexer.setHandleInvalid("skip")
-    val transformedSkip = indexer.fit(df).transform(df2)
-    val attrSkip = Attribute
-      .fromStructField(transformedSkip.schema("labelIndex"))
-      .asInstanceOf[NominalAttribute]
-    assert(attrSkip.values.get === Array("b", "a"))
-    val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r =>
-      (r.getInt(0), r.getDouble(1))
-    }.collect().toSet
+    val modelSkip = indexer.fit(df)
     // a -> 1, b -> 0
-    val expectedSkip = Set((0, 1.0), (1, 0.0))
-    assert(outputSkip === expectedSkip)
+    val expectedSkip = Seq((0, 1.0), (1, 0.0)).toDF()
+    testTransformerByGlobalCheckFunc[(Int, String)](df2, modelSkip, "id", "labelIndex") { rows =>
+      val attrSkip =
+        Attribute.fromStructField(rows.head.schema("labelIndex")).asInstanceOf[NominalAttribute]
+      assert(attrSkip.values.get === Array("b", "a"))
+      assert(rows === expectedSkip.collect().toSeq)
+    }
 
     indexer.setHandleInvalid("keep")
-    val transformedKeep = indexer.fit(df).transform(df2)
-    val attrKeep = Attribute
-      .fromStructField(transformedKeep.schema("labelIndex"))
-      .asInstanceOf[NominalAttribute]
-    assert(attrKeep.values.get === Array("b", "a", "__unknown"))
-    val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r =>
-      (r.getInt(0), r.getDouble(1))
-    }.collect().toSet
     // a -> 1, b -> 0, null -> 2
-    val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0))
-    assert(outputKeep === expectedKeep)
+    val expectedKeep = Seq((0, 1.0), (1, 0.0), (3, 2.0)).toDF()
+    val modelKeep = indexer.fit(df)
+    testTransformerByGlobalCheckFunc[(Int, String)](df2, modelKeep, "id", "labelIndex") { rows =>
+      val attrKeep = Attribute
+        .fromStructField(rows.head.schema("labelIndex"))
+        .asInstanceOf[NominalAttribute]
+      assert(attrKeep.values.get === Array("b", "a", "__unknown"))
+      assert(rows === expectedKeep.collect().toSeq)
+    }
   }
 
   test("StringIndexerModel should keep silent if the input column does not 
exist.") {
@@ -171,7 +171,9 @@ class StringIndexerSuite
       .setInputCol("label")
       .setOutputCol("labelIndex")
     val df = spark.range(0L, 10L).toDF()
-    assert(indexerModel.transform(df).collect().toSet === df.collect().toSet)
+    testTransformerByGlobalCheckFunc[Long](df, indexerModel, "id") { rows =>
+      assert(rows.toSet === df.collect().toSet)
+    }
   }
 
   test("StringIndexerModel can't overwrite output column") {
@@ -188,9 +190,12 @@ class StringIndexerSuite
       .setOutputCol("indexedInput")
       .fit(df)
 
-    intercept[IllegalArgumentException] {
-      indexer.setOutputCol("output").transform(df)
-    }
+    testTransformerByInterceptingException[(Int, String)](
+      df,
+      indexer.setOutputCol("output"),
+      "Output column output already exists.",
+      "labelIndex")
+
   }
 
   test("StringIndexer read/write") {
@@ -223,7 +228,8 @@ class StringIndexerSuite
       .setInputCol("index")
       .setOutputCol("actual")
       .setLabels(labels)
-    idxToStr0.transform(df0).select("actual", "expected").collect().foreach {
+
+    testTransformer[(Int, String)](df0, idxToStr0, "actual", "expected") {
       case Row(actual, expected) =>
         assert(actual === expected)
     }
@@ -234,7 +240,8 @@ class StringIndexerSuite
     val idxToStr1 = new IndexToString()
       .setInputCol("indexWithAttr")
       .setOutputCol("actual")
-    idxToStr1.transform(df1).select("actual", "expected").collect().foreach {
+
+    testTransformer[(Int, String)](df1, idxToStr1, "actual", "expected") {
       case Row(actual, expected) =>
         assert(actual === expected)
     }
@@ -252,9 +259,10 @@ class StringIndexerSuite
       .setInputCol("labelIndex")
       .setOutputCol("sameLabel")
       .setLabels(indexer.labels)
-    idx2str.transform(transformed).select("label", "sameLabel").collect().foreach {
-      case Row(a: String, b: String) =>
-        assert(a === b)
+
+    testTransformer[(Int, String, Double)](transformed, idx2str, "sameLabel", "label") {
+      case Row(sameLabel, label) =>
+        assert(sameLabel === label)
     }
   }
 
@@ -286,10 +294,11 @@ class StringIndexerSuite
       .setInputCol("label")
       .setOutputCol("labelIndex")
       .fit(df)
-    val transformed = indexer.transform(df)
-    val attrs =
-      NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true)
-    assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex")
+    testTransformerByGlobalCheckFunc[(Int, String)](df, indexer, "labelIndex") { rows =>
+      val attrs =
+        NominalAttribute.decodeStructField(rows.head.schema("labelIndex"), preserveName = true)
+      assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex")
+    }
   }
 
   test("StringIndexer order types") {
@@ -299,18 +308,17 @@ class StringIndexerSuite
       .setInputCol("label")
       .setOutputCol("labelIndex")
 
-    val expected = Seq(Set((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)),
-      Set((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)),
-      Set((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)),
-      Set((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0)))
+    val expected = Seq(Seq((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)),
+      Seq((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)),
+      Seq((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)),
+      Seq((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0)))
 
     var idx = 0
     for (orderType <- StringIndexer.supportedStringOrderType) {
-      val transformed = indexer.setStringOrderType(orderType).fit(df).transform(df)
-      val output = transformed.select("id", "labelIndex").rdd.map { r =>
-        (r.getInt(0), r.getDouble(1))
-      }.collect().toSet
-      assert(output === expected(idx))
+      val model = indexer.setStringOrderType(orderType).fit(df)
+      testTransformerByGlobalCheckFunc[(Int, String)](df, model, "id", "labelIndex") { rows =>
+        assert(rows === expected(idx).toDF().collect().toSeq)
+      }
       idx += 1
     }
   }
@@ -328,7 +336,11 @@ class StringIndexerSuite
       .setOutputCol("CITYIndexed")
       .fit(dfNoBristol)
 
-    val dfWithIndex = model.transform(dfNoBristol)
-    assert(dfWithIndex.filter($"CITYIndexed" === 1.0).count == 1)
+    testTransformerByGlobalCheckFunc[(String, String, String)](
+      dfNoBristol,
+      model,
+      "CITYIndexed") { rows =>
+      assert(rows.toList.count(_.getDouble(0) == 1.0) === 1)
+    }
   }
 }
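
Note on the pattern used throughout the StringIndexerSuite conversion above: the expected output is materialized as a DataFrame and the assertions move into a check function handed to testTransformerByGlobalCheckFunc, which MLTest applies to the collected output of both a batch transform and a MemoryStream-backed streaming transform. A minimal sketch of that shape, with illustrative data and column names rather than the exact values from the patch:

    // Hypothetical usage inside a suite that extends MLTest; "id"/"labelIndex"
    // mirror the StringIndexer tests above.
    val expected = Seq((0, 0.0), (1, 1.0)).toDF("id", "labelIndex")
    testTransformerByGlobalCheckFunc[(Int, String)](df, indexerModel, "id", "labelIndex") { rows =>
      // rows are the collected output rows; column metadata survives on rows.head.schema
      assert(rows === expected.collect().toSeq)
    }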

http://git-wip-us.apache.org/repos/asf/spark/blob/0663b611/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index c895659..be59b0a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -19,16 +19,14 @@ package org.apache.spark.ml.feature
 
 import scala.beans.BeanInfo
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
+import org.apache.spark.sql.{DataFrame, Row}
 
 @BeanInfo
 case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
 
-class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class TokenizerSuite extends MLTest with DefaultReadWriteTest {
 
   test("params") {
     ParamsSuite.checkParams(new Tokenizer)
@@ -42,12 +40,17 @@ class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
   }
 }
 
-class RegexTokenizerSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class RegexTokenizerSuite extends MLTest with DefaultReadWriteTest {
 
-  import org.apache.spark.ml.feature.RegexTokenizerSuite._
   import testImplicits._
 
+  def testRegexTokenizer(t: RegexTokenizer, dataframe: DataFrame): Unit = {
+    testTransformer[(String, Seq[String])](dataframe, t, "tokens", "wantedTokens") {
+      case Row(tokens, wantedTokens) =>
+        assert(tokens === wantedTokens)
+    }
+  }
+
   test("params") {
     ParamsSuite.checkParams(new RegexTokenizer)
   }
@@ -105,14 +108,3 @@ class RegexTokenizerSuite
   }
 }
 
-object RegexTokenizerSuite extends SparkFunSuite {
-
-  def testRegexTokenizer(t: RegexTokenizer, dataset: Dataset[_]): Unit = {
-    t.transform(dataset)
-      .select("tokens", "wantedTokens")
-      .collect()
-      .foreach { case Row(tokens, wantedTokens) =>
-        assert(tokens === wantedTokens)
-      }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/0663b611/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 69a7b75..e5675e3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -19,18 +19,16 @@ package org.apache.spark.ml.feature
 
 import scala.beans.{BeanInfo, BeanProperty}
 
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.{SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 
-class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
-  with DefaultReadWriteTest with Logging {
+class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging {
 
   import testImplicits._
   import VectorIndexerSuite.FeatureData
@@ -128,18 +126,27 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
 
     MLTestingUtils.checkCopyAndUids(vectorIndexer, model)
 
-    model.transform(densePoints1) // should work
-    model.transform(sparsePoints1) // should work
+    testTransformer[FeatureData](densePoints1, model, "indexed") { _ => }
+    testTransformer[FeatureData](sparsePoints1, model, "indexed") { _ => }
+
     // If the data is local Dataset, it throws AssertionError directly.
-    intercept[AssertionError] {
-      model.transform(densePoints2).collect()
-      logInfo("Did not throw error when fit, transform were called on vectors 
of different lengths")
+    withClue("Did not throw error when fit, transform were called on " +
+      "vectors of different lengths") {
+      testTransformerByInterceptingException[FeatureData](
+        densePoints2,
+        model,
+        "VectorIndexerModel expected vector of length 3 but found length 4",
+        "indexed")
     }
     // If the data is distributed Dataset, it throws SparkException
     // which is the wrapper of AssertionError.
-    intercept[SparkException] {
-      model.transform(densePoints2.repartition(2)).collect()
-      logInfo("Did not throw error when fit, transform were called on vectors 
of different lengths")
+    withClue("Did not throw error when fit, transform were called " +
+      "on vectors of different lengths") {
+      testTransformerByInterceptingException[FeatureData](
+        densePoints2.repartition(2),
+        model,
+        "VectorIndexerModel expected vector of length 3 but found length 4",
+        "indexed")
     }
     intercept[SparkException] {
       vectorIndexer.fit(badPoints)
@@ -178,46 +185,48 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
         val categoryMaps = model.categoryMaps
         // Chose correct categorical features
         assert(categoryMaps.keys.toSet === categoricalFeatures)
-        val transformed = model.transform(data).select("indexed")
-        val indexedRDD: RDD[Vector] = transformed.rdd.map(_.getAs[Vector](0))
-        val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed"))
-        assert(featureAttrs.name === "indexed")
-        assert(featureAttrs.attributes.get.length === model.numFeatures)
-        categoricalFeatures.foreach { feature: Int =>
-          val origValueSet = collectedData.map(_(feature)).toSet
-          val targetValueIndexSet = Range(0, origValueSet.size).toSet
-          val catMap = categoryMaps(feature)
-          assert(catMap.keys.toSet === origValueSet) // Correct categories
-          assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices
-          if (origValueSet.contains(0.0)) {
-            assert(catMap(0.0) === 0) // value 0 gets index 0
-          }
-          // Check transformed data
-          assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet)
-          // Check metadata
-          val featureAttr = featureAttrs(feature)
-          assert(featureAttr.index.get === feature)
-          featureAttr match {
-            case attr: BinaryAttribute =>
-              assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString))
-            case attr: NominalAttribute =>
-              assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString))
-              assert(attr.isOrdinal.get === false)
-            case _ =>
-              throw new RuntimeException(errMsg + s". Categorical feature $feature failed" +
-                s" metadata check. Found feature attribute: $featureAttr.")
+        testTransformerByGlobalCheckFunc[FeatureData](data, model, "indexed") { rows =>
+          val transformed = rows.map { r => Tuple1(r.getAs[Vector](0)) }.toDF("indexed")
+          val indexedRDD: RDD[Vector] = transformed.rdd.map(_.getAs[Vector](0))
+          val featureAttrs = AttributeGroup.fromStructField(rows.head.schema("indexed"))
+          assert(featureAttrs.name === "indexed")
+          assert(featureAttrs.attributes.get.length === model.numFeatures)
+          categoricalFeatures.foreach { feature: Int =>
+            val origValueSet = collectedData.map(_(feature)).toSet
+            val targetValueIndexSet = Range(0, origValueSet.size).toSet
+            val catMap = categoryMaps(feature)
+            assert(catMap.keys.toSet === origValueSet) // Correct categories
+            assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices
+            if (origValueSet.contains(0.0)) {
+              assert(catMap(0.0) === 0) // value 0 gets index 0
+            }
+            // Check transformed data
+            assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet)
+            // Check metadata
+            val featureAttr = featureAttrs(feature)
+            assert(featureAttr.index.get === feature)
+            featureAttr match {
+              case attr: BinaryAttribute =>
+                assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString))
+              case attr: NominalAttribute =>
+                assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString))
+                assert(attr.isOrdinal.get === false)
+              case _ =>
+                throw new RuntimeException(errMsg + s". Categorical feature $feature failed" +
+                  s" metadata check. Found feature attribute: $featureAttr.")
+            }
           }
-        }
-        // Check numerical feature metadata.
-        Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature))
-          .foreach { feature: Int =>
-          val featureAttr = featureAttrs(feature)
-          featureAttr match {
-            case attr: NumericAttribute =>
-              assert(featureAttr.index.get === feature)
-            case _ =>
-              throw new RuntimeException(errMsg + s". Numerical feature $feature failed" +
-                s" metadata check. Found feature attribute: $featureAttr.")
+          // Check numerical feature metadata.
+          Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature))
+            .foreach { feature: Int =>
+            val featureAttr = featureAttrs(feature)
+            featureAttr match {
+              case attr: NumericAttribute =>
+                assert(featureAttr.index.get === feature)
+              case _ =>
+                throw new RuntimeException(errMsg + s". Numerical feature $feature failed" +
+                  s" metadata check. Found feature attribute: $featureAttr.")
+            }
           }
         }
       } catch {
@@ -236,25 +245,32 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
       (sparsePoints1, sparsePoints1TestInvalid))) {
       val vectorIndexer = getIndexer.setMaxCategories(4).setHandleInvalid("error")
       val model = vectorIndexer.fit(points)
-      intercept[SparkException] {
-        model.transform(pointsTestInvalid).collect()
-      }
+      testTransformerByInterceptingException[FeatureData](
+        pointsTestInvalid,
+        model,
+        "VectorIndexer encountered invalid value",
+        "indexed")
       val vectorIndexer1 = getIndexer.setMaxCategories(4).setHandleInvalid("skip")
       val model1 = vectorIndexer1.fit(points)
-      val invalidTransformed1 = model1.transform(pointsTestInvalid).select("indexed")
-        .collect().map(_(0))
-      val transformed1 = model1.transform(points).select("indexed").collect().map(_(0))
-      assert(transformed1 === invalidTransformed1)
-
+      val expected = Seq(
+        Vectors.dense(1.0, 2.0, 0.0),
+        Vectors.dense(0.0, 1.0, 2.0),
+        Vectors.dense(0.0, 0.0, 1.0),
+        Vectors.dense(1.0, 3.0, 2.0))
+      testTransformerByGlobalCheckFunc[FeatureData](pointsTestInvalid, model1, "indexed") { rows =>
+        assert(rows.map(_(0)) == expected)
+      }
+      testTransformerByGlobalCheckFunc[FeatureData](points, model1, "indexed") { rows =>
+        assert(rows.map(_(0)) == expected)
+      }
       val vectorIndexer2 = getIndexer.setMaxCategories(4).setHandleInvalid("keep")
       val model2 = vectorIndexer2.fit(points)
-      val invalidTransformed2 = model2.transform(pointsTestInvalid).select("indexed")
-        .collect().map(_(0))
-      assert(invalidTransformed2 === transformed1 ++ Array(
-        Vectors.dense(2.0, 2.0, 0.0),
-        Vectors.dense(0.0, 4.0, 2.0),
-        Vectors.dense(1.0, 3.0, 3.0))
-      )
+      testTransformerByGlobalCheckFunc[FeatureData](pointsTestInvalid, model2, "indexed") { rows =>
+        assert(rows.map(_(0)) == expected ++ Array(
+          Vectors.dense(2.0, 2.0, 0.0),
+          Vectors.dense(0.0, 4.0, 2.0),
+          Vectors.dense(1.0, 3.0, 3.0)))
+      }
     }
   }
 
@@ -263,12 +279,12 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
       val points = data.collect().map(_.getAs[Vector](0))
       val vectorIndexer = getIndexer.setMaxCategories(maxCategories)
       val model = vectorIndexer.fit(data)
-      val indexedPoints =
-        model.transform(data).select("indexed").rdd.map(_.getAs[Vector](0)).collect()
-      points.zip(indexedPoints).foreach {
-        case (orig: SparseVector, indexed: SparseVector) =>
-          assert(orig.indices.length == indexed.indices.length)
-        case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen
+      testTransformerByGlobalCheckFunc[FeatureData](data, model, "indexed") { rows =>
+        points.zip(rows.map(_(0))).foreach {
+          case (orig: SparseVector, indexed: SparseVector) =>
+            assert(orig.indices.length == indexed.indices.length)
+          case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen
+        }
       }
     }
     checkSparsity(sparsePoints1, maxCategories = 2)
@@ -286,17 +302,18 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
     val vectorIndexer = getIndexer.setMaxCategories(2)
     val model = vectorIndexer.fit(densePoints1WithMeta)
     // Check that ML metadata are preserved.
-    val indexedPoints = model.transform(densePoints1WithMeta)
-    val transAttributes: Array[Attribute] =
-      AttributeGroup.fromStructField(indexedPoints.schema("indexed")).attributes.get
-    featureAttributes.zip(transAttributes).foreach { case (orig, trans) =>
-      assert(orig.name === trans.name)
-      (orig, trans) match {
-        case (orig: NumericAttribute, trans: NumericAttribute) =>
-          assert(orig.max.nonEmpty && orig.max === trans.max)
-        case _ =>
+    testTransformerByGlobalCheckFunc[FeatureData](densePoints1WithMeta, model, "indexed") { rows =>
+      val transAttributes: Array[Attribute] =
+        AttributeGroup.fromStructField(rows.head.schema("indexed")).attributes.get
+      featureAttributes.zip(transAttributes).foreach { case (orig, trans) =>
+        assert(orig.name === trans.name)
+        (orig, trans) match {
+          case (orig: NumericAttribute, trans: NumericAttribute) =>
+            assert(orig.max.nonEmpty && orig.max === trans.max)
+          case _ =>
           // do nothing
          // TODO: Once input features marked as categorical are handled correctly, check that here.
+        }
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/0663b611/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala
index f6c9a76..d89d10b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala
@@ -17,17 +17,15 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.Pipeline
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.streaming.StreamTest
 
 class VectorSizeHintSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+  extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -40,16 +38,23 @@ class VectorSizeHintSuite
     val data = Seq((Vectors.dense(1, 2), 0)).toDF("vector", "intValue")
 
     val noSizeTransformer = new VectorSizeHint().setInputCol("vector")
-    intercept[NoSuchElementException] (noSizeTransformer.transform(data))
+    testTransformerByInterceptingException[(Vector, Int)](
+      data,
+      noSizeTransformer,
+      "Failed to find a default value for size",
+      "vector")
     intercept[NoSuchElementException] (noSizeTransformer.transformSchema(data.schema))
 
     val noInputColTransformer = new VectorSizeHint().setSize(2)
-    intercept[NoSuchElementException] (noInputColTransformer.transform(data))
+    testTransformerByInterceptingException[(Vector, Int)](
+      data,
+      noInputColTransformer,
+      "Failed to find a default value for inputCol",
+      "vector")
     intercept[NoSuchElementException] (noInputColTransformer.transformSchema(data.schema))
   }
 
   test("Adding size to column of vectors.") {
-
     val size = 3
     val vectorColName = "vector"
     val denseVector = Vectors.dense(1, 2, 3)
@@ -66,12 +71,15 @@ class VectorSizeHintSuite
         .setInputCol(vectorColName)
         .setSize(size)
         .setHandleInvalid(handleInvalid)
-      val withSize = transformer.transform(dataFrame)
-      assert(
-        AttributeGroup.fromStructField(withSize.schema(vectorColName)).size == size,
-        "Transformer did not add expected size data.")
-      val numRows = withSize.collect().length
-      assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.")
+      testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataFrame, transformer, vectorColName) {
+        rows => {
+          assert(
+            AttributeGroup.fromStructField(rows.head.schema(vectorColName)).size == size,
+            "Transformer did not add expected size data.")
+          val numRows = rows.length
+          assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.")
+        }
+      }
     }
   }
 
@@ -93,14 +101,16 @@ class VectorSizeHintSuite
         .setInputCol(vectorColName)
         .setSize(size)
         .setHandleInvalid(handleInvalid)
-      val withSize = transformer.transform(dataFrameWithMetadata)
-
-      val newGroup = AttributeGroup.fromStructField(withSize.schema(vectorColName))
-      assert(newGroup.size === size, "Column has incorrect size metadata.")
-      assert(
-        newGroup.attributes.get === group.attributes.get,
-        "VectorSizeHint did not preserve attributes.")
-      withSize.collect
+      testTransformerByGlobalCheckFunc[(Int, Int, Int, Vector)](
+        dataFrameWithMetadata,
+        transformer,
+        vectorColName) { rows =>
+          val newGroup = AttributeGroup.fromStructField(rows.head.schema(vectorColName))
+          assert(newGroup.size === size, "Column has incorrect size metadata.")
+          assert(
+            newGroup.attributes.get === group.attributes.get,
+            "VectorSizeHint did not preserve attributes.")
+      }
     }
   }
 
@@ -120,7 +130,11 @@ class VectorSizeHintSuite
         .setInputCol(vectorColName)
         .setSize(size)
         .setHandleInvalid(handleInvalid)
-      intercept[IllegalArgumentException](transformer.transform(dataFrameWithMetadata))
+      testTransformerByInterceptingException[(Int, Int, Int, Vector)](
+        dataFrameWithMetadata,
+        transformer,
+        "Trying to set size of vectors in `vector` to 4 but size already set to 3.",
+        vectorColName)
     }
   }
 
@@ -136,18 +150,36 @@ class VectorSizeHintSuite
       .setHandleInvalid("error")
       .setSize(3)
 
-    intercept[SparkException](sizeHint.transform(dataWithNull).collect())
-    intercept[SparkException](sizeHint.transform(dataWithShort).collect())
+    testTransformerByInterceptingException[Tuple1[Vector]](
+      dataWithNull,
+      sizeHint,
+      "Got null vector in VectorSizeHint",
+      "vector")
+
+    testTransformerByInterceptingException[Tuple1[Vector]](
+      dataWithShort,
+      sizeHint,
+      "VectorSizeHint Expecting a vector of size 3 but got 1",
+      "vector")
 
     sizeHint.setHandleInvalid("skip")
-    assert(sizeHint.transform(dataWithNull).count() === 1)
-    assert(sizeHint.transform(dataWithShort).count() === 1)
+    testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithNull, sizeHint, "vector") { rows =>
+      assert(rows.length === 1)
+    }
+    testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithShort, sizeHint, "vector") { rows =>
+      assert(rows.length === 1)
+    }
 
     sizeHint.setHandleInvalid("optimistic")
-    assert(sizeHint.transform(dataWithNull).count() === 2)
-    assert(sizeHint.transform(dataWithShort).count() === 2)
+    testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithNull, sizeHint, "vector") { rows =>
+      assert(rows.length === 2)
+    }
+    testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithShort, sizeHint, "vector") { rows =>
+      assert(rows.length === 2)
+    }
   }
 
+
   test("read/write") {
     val sizeHint = new VectorSizeHint()
       .setInputCol("myInputCol")

http://git-wip-us.apache.org/repos/asf/spark/blob/0663b611/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
index 1746ce5..3d90f9d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -17,16 +17,16 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
 import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.types.{StructField, StructType}
 
-class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class VectorSlicerSuite extends MLTest with DefaultReadWriteTest {
+
+  import testImplicits._
 
   test("params") {
     val slicer = new VectorSlicer().setInputCol("feature")
@@ -84,12 +84,12 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De
 
     val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result")
 
-    def validateResults(df: DataFrame): Unit = {
-      df.select("result", "expected").collect().foreach { case Row(vec1: 
Vector, vec2: Vector) =>
+    def validateResults(rows: Seq[Row]): Unit = {
+      rows.foreach { case Row(vec1: Vector, vec2: Vector) =>
         assert(vec1 === vec2)
       }
-      val resultMetadata = AttributeGroup.fromStructField(df.schema("result"))
-      val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected"))
+      val resultMetadata = AttributeGroup.fromStructField(rows.head.schema("result"))
+      val expectedMetadata = AttributeGroup.fromStructField(rows.head.schema("expected"))
       assert(resultMetadata.numAttributes === expectedMetadata.numAttributes)
       resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) =>
         assert(a === b)
@@ -97,13 +97,16 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De
     }
 
     vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty)
-    validateResults(vectorSlicer.transform(df))
+    testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")(
+      validateResults)
 
     vectorSlicer.setIndices(Array(1)).setNames(Array("f4"))
-    validateResults(vectorSlicer.transform(df))
+    testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")(
+      validateResults)
 
     vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4"))
-    validateResults(vectorSlicer.transform(df))
+    testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")(
+      validateResults)
   }
 
   test("read/write") {

http://git-wip-us.apache.org/repos/asf/spark/blob/0663b611/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 10682ba..b59c4e7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -17,17 +17,17 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Row
 import org.apache.spark.util.Utils
 
-class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class Word2VecSuite extends MLTest with DefaultReadWriteTest {
+
+  import testImplicits._
 
   test("params") {
     ParamsSuite.checkParams(new Word2Vec)
@@ -36,10 +36,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
   }
 
   test("Word2Vec") {
-
-    val spark = this.spark
-    import spark.implicits._
-
     val sentence = "a b " * 100 + "a c " * 10
     val numOfWords = sentence.split(" ").size
     val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
@@ -70,17 +66,13 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     // These expectations are just magic values, characterizing the current
     // behavior.  The test needs to be updated to be more general, see SPARK-11502
     val magicExp = Vectors.dense(0.30153007534417237, -0.6833061711354689, 0.5116530778733167)
-    model.transform(docDF).select("result", "expected").collect().foreach {
+    testTransformer[(Seq[String], Vector)](docDF, model, "result", "expected") {
       case Row(vector1: Vector, vector2: Vector) =>
         assert(vector1 ~== magicExp absTol 1E-5, "Transformed vector is different with expected.")
     }
   }
 
   test("getVectors") {
-
-    val spark = this.spark
-    import spark.implicits._
-
     val sentence = "a b " * 100 + "a c " * 10
     val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
     val docDF = doc.zip(doc).toDF("text", "alsotext")
@@ -119,9 +111,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
 
   test("findSynonyms") {
 
-    val spark = this.spark
-    import spark.implicits._
-
     val sentence = "a b " * 100 + "a c " * 10
     val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
     val docDF = doc.zip(doc).toDF("text", "alsotext")
@@ -154,9 +143,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
 
   test("window size") {
 
-    val spark = this.spark
-    import spark.implicits._
-
     val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10
     val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
     val docDF = doc.zip(doc).toDF("text", "alsotext")
@@ -227,8 +213,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
   }
 
   test("Word2Vec works with input that is non-nullable (NGram)") {
-    val spark = this.spark
-    import spark.implicits._
 
     val sentence = "a q s t q s t b b b s t m s t m q "
     val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text")
@@ -243,7 +227,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
       .fit(ngramDF)
 
     // Just test that this transformation succeeds
-    model.transform(ngramDF).collect()
+    testTransformerByGlobalCheckFunc[(Seq[String], Seq[String])](ngramDF, model, "result") { _ => }
   }
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/0663b611/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
index 17678aa..795fd0e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
@@ -22,9 +22,10 @@ import java.io.File
 import org.scalatest.Suite
 
 import org.apache.spark.SparkContext
-import org.apache.spark.ml.{PipelineModel, Transformer}
+import org.apache.spark.ml.Transformer
 import org.apache.spark.sql.{DataFrame, Encoder, Row}
 import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.streaming.StreamTest
 import org.apache.spark.sql.test.TestSparkSession
 import org.apache.spark.util.Utils
@@ -62,8 +63,10 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
 
     val columnNames = dataframe.schema.fieldNames
     val stream = MemoryStream[A]
-    val streamDF = stream.toDS().toDF(columnNames: _*)
-
+    val columnsWithMetadata = dataframe.schema.map { structField =>
+      col(structField.name).as(structField.name, structField.metadata)
+    }
+    val streamDF = stream.toDS().toDF(columnNames: _*).select(columnsWithMetadata: _*)
     val data = dataframe.as[A].collect()
 
     val streamOutput = transformer.transform(streamDF)
@@ -108,5 +111,29 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
       otherResultCols: _*)(globalCheckFunction)
     testTransformerOnDF(dataframe, transformer, firstResultCol,
       otherResultCols: _*)(globalCheckFunction)
+    }
+
+  def testTransformerByInterceptingException[A : Encoder](
+    dataframe: DataFrame,
+    transformer: Transformer,
+    expectedMessagePart : String,
+    firstResultCol: String) {
+
+    def hasExpectedMessage(exception: Throwable): Boolean =
+      exception.getMessage.contains(expectedMessagePart) ||
+        (exception.getCause != null && exception.getCause.getMessage.contains(expectedMessagePart))
+
+    withClue(s"""Expected message part "${expectedMessagePart}" is not found in DF test.""") {
+      val exceptionOnDf = intercept[Throwable] {
+        testTransformerOnDF(dataframe, transformer, firstResultCol)(_ => Unit)
+      }
+      assert(hasExpectedMessage(exceptionOnDf))
+    }
+    withClue(s"""Expected message part "${expectedMessagePart}" is not found in stream test.""") {
+      val exceptionOnStreamData = intercept[Throwable] {
+        testTransformerOnStreamData(dataframe, transformer, firstResultCol)(_ => Unit)
+      }
+      assert(hasExpectedMessage(exceptionOnStreamData))
+    }
   }
 }
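
The new testTransformerByInterceptingException helper added here is what replaces the bare intercept[SparkException] / intercept[IllegalArgumentException] calls in the suites above: it exercises the transformer on both the DataFrame path and the streaming path, and asserts only that the thrown exception, or its cause, contains the expected message fragment. A sketch of a call site, with illustrative data and message text mirroring the StringIndexerSuite change above:

    // Hypothetical usage from a suite that extends MLTest.
    testTransformerByInterceptingException[(Int, String)](
      df2,              // input DataFrame containing an unseen label
      indexerModel,     // transformer that is expected to fail
      "Unseen label:",  // fragment expected in the exception message or its cause
      "labelIndex")     // first result column selected from the output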

