Github user jkbradley commented on a diff in the pull request: https://github.com/apache/spark/pull/20686#discussion_r173592635 --- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala --- @@ -70,36 +71,37 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") .fit(df) + // Verify we throw by default with unseen values - intercept[SparkException] { - indexer.transform(df2).collect() - } + testTransformerByInterceptingException[(Int, String)]( + df2, + indexer, + "Unseen label:", + "labelIndex") indexer.setHandleInvalid("skip") - // Verify that we skip the c record - val transformedSkip = indexer.transform(df2) - val attrSkip = Attribute.fromStructField(transformedSkip.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrSkip.values.get === Array("b", "a")) - val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet - // a -> 1, b -> 0 - val expectedSkip = Set((0, 1.0), (1, 0.0)) - assert(outputSkip === expectedSkip) + + testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id", "labelIndex") { rows => + val attrSkip = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrSkip.values.get === Array("b", "a")) + // Verify that we skip the c record + // a -> 1, b -> 0 + val expectedSkip = Seq((0, 1.0), (1, 0.0)).toDF() --- End diff -- This can be moved outside of the testTransformerByGlobalCheckFunc method.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org