Github user jkbradley commented on a diff in the pull request: https://github.com/apache/spark/pull/20829#discussion_r177559587 --- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala --- @@ -147,4 +159,88 @@ class VectorAssemblerSuite .filter(vectorUDF($"features") > 1) .count() == 1) } + + test("assemble should keep nulls when keepInvalid is true") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + assert(assemble(Array(1, 1), true)(1.0, null) === Vectors.dense(1.0, Double.NaN)) + assert(assemble(Array(1, 2), true)(1.0, null) === Vectors.dense(1.0, Double.NaN, Double.NaN)) + assert(assemble(Array(1), true)(null) === Vectors.dense(Double.NaN)) + assert(assemble(Array(2), true)(null) === Vectors.dense(Double.NaN, Double.NaN)) + } + + test("assemble should throw errors when keepInvalid is false") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + intercept[SparkException](assemble(Array(1, 1), false)(1.0, null)) + intercept[SparkException](assemble(Array(1, 2), false)(1.0, null)) + intercept[SparkException](assemble(Array(1), false)(null)) + intercept[SparkException](assemble(Array(2), false)(null)) + } + + test("get lengths functions") { + import org.apache.spark.ml.feature.VectorAssembler._ + val df = dfWithNulls + assert(getVectorLengthsFromFirstRow(df, Seq("y")) === Map("y" -> 2)) + assert(intercept[NullPointerException](getVectorLengthsFromFirstRow(df.sort("id2"), Seq("y"))) + .getMessage.contains("VectorSizeHint")) + assert(intercept[NoSuchElementException](getVectorLengthsFromFirstRow(df.filter("id1 > 4"), + Seq("y"))).getMessage.contains("VectorSizeHint")) + + assert(getLengths(df.sort("id2"), Seq("y"), SKIP_INVALID).exists(_ == "y" -> 2)) + assert(intercept[NullPointerException](getLengths(df.sort("id2"), Seq("y"), ERROR_INVALID)) + .getMessage.contains("VectorSizeHint")) + assert(intercept[RuntimeException](getLengths(df.sort("id2"), Seq("y"), KEEP_INVALID)) + 
.getMessage.contains("VectorSizeHint")) + } + + test("Handle Invalid should behave properly") { + val assembler = new VectorAssembler() + .setInputCols(Array("x", "y", "z", "n")) + .setOutputCol("features") + + def run_with_metadata(mode: String, additional_filter: String = "true"): Dataset[_] = { + val attributeY = new AttributeGroup("y", 2) + val subAttributesOfZ = Array(NumericAttribute.defaultAttr, NumericAttribute.defaultAttr) --- End diff -- `subAttributesOfZ` is unused.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org