Github user yogeshg commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20829#discussion_r176266756
  
    --- Diff: 
mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala ---
    @@ -147,4 +149,72 @@ class VectorAssemblerSuite
           .filter(vectorUDF($"features") > 1)
           .count() == 1)
       }
    +
    +  test("assemble should keep nulls") {
    +    import org.apache.spark.ml.feature.VectorAssembler.assemble
    +    assert(assemble(Seq(1, 1), true)(1.0, null) === Vectors.dense(1.0, 
Double.NaN))
    +    assert(assemble(Seq(1, 2), true)(1.0, null) === Vectors.dense(1.0, 
Double.NaN, Double.NaN))
    +    assert(assemble(Seq(1), true)(null) === Vectors.dense(Double.NaN))
    +    assert(assemble(Seq(2), true)(null) === Vectors.dense(Double.NaN, 
Double.NaN))
    +  }
    +
    +  test("assemble should throw errors") {
    +    import org.apache.spark.ml.feature.VectorAssembler.assemble
    +    intercept[SparkException](assemble(Seq(1, 1), false)(1.0, null) ===
    +      Vectors.dense(1.0, Double.NaN))
    +    intercept[SparkException](assemble(Seq(1, 2), false)(1.0, null) ===
    +      Vectors.dense(1.0, Double.NaN, Double.NaN))
    +    intercept[SparkException](assemble(Seq(1), false)(null) === 
Vectors.dense(Double.NaN))
    +    intercept[SparkException](assemble(Seq(2), false)(null) ===
    +      Vectors.dense(Double.NaN, Double.NaN))
    +  }
    +
    +  test("get lengths function") {
    +    val df = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, 
Long)](
    +      (1, 2, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, 
Array(1), Array(3.0)), 7L),
    +      (2, 1, 0.0, null, "a", Vectors.sparse(2, Array(1), Array(3.0)), 6L),
    +      (3, 3, null, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, 
Array(1), Array(3.0)), 8L),
    +      (4, 4, null, null, "a", Vectors.sparse(2, Array(1), Array(3.0)), 9L)
    +    ).toDF("id1", "id2", "x", "y", "name", "z", "n")
    +    assert(VectorAssembler.getLengthsFromFirst(df, Seq("y")).exists(_ == 
"y" -> 2))
    +    
intercept[NullPointerException](VectorAssembler.getLengthsFromFirst(df.sort("id2"),
 Seq("y")))
    +    intercept[NoSuchElementException](
    +      VectorAssembler.getLengthsFromFirst(df.filter("id1 > 4"), Seq("y")))
    +
    +    assert(VectorAssembler.getLengths(
    +      df.sort("id2"), Seq("y"), VectorAssembler.SKIP_INVALID).exists(_ == 
"y" -> 2))
    +    intercept[NullPointerException](VectorAssembler.getLengths(
    +      df.sort("id2"), Seq("y"), VectorAssembler.ERROR_INVALID))
    +    intercept[RuntimeException](VectorAssembler.getLengths(
    +      df.sort("id2"), Seq("y"), VectorAssembler.KEEP_INVALID))
    +  }
    +
    +  test("Handle Invalid should behave properly") {
    +    val df = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, 
Long)](
    --- End diff --
    
    Thanks, good idea! This helped me catch the `na.drop()` bug that 
might drop everything.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org

Reply via email to