Github user ymazari commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20367#discussion_r163465302
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala ---
    @@ -119,6 +119,41 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
         }
       }
     
    +  test("CountVectorizer maxDF") {
    +    val df = Seq(
    +      (0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0), (2, 
1.0)))),
    +      (1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
    +      (2, split("a b"), Vectors.sparse(3, Seq((0, 1.0)))),
    +      (3, split("a"), Vectors.sparse(3, Seq()))
    +    ).toDF("id", "words", "expected")
    +
    +    // maxDF: ignore terms that appear in more than 3 documents
    +    val cvModel = new CountVectorizer()
    +      .setInputCol("words")
    +      .setOutputCol("features")
    +      .setMaxDF(3)
    +      .fit(df)
    +    assert(cvModel.vocabulary === Array("b", "c", "d"))
    +
    +    cvModel.transform(df).select("features", "expected").collect().foreach 
{
    +      case Row(features: Vector, expected: Vector) =>
    +        assert(features ~== expected absTol 1e-14)
    +    }
    +
    +    // maxDF: ignore terms that appear in more than 75% of the documents
    +    val cvModel2 = new CountVectorizer()
    +      .setInputCol("words")
    +      .setOutputCol("features")
    +      .setMaxDF(0.75)
    +      .fit(df)
    +    assert(cvModel2.vocabulary === Array("b", "c", "d"))
    +
    +    cvModel2.transform(df).select("features", 
"expected").collect().foreach {
    +      case Row(features: Vector, expected: Vector) =>
    +        assert(features ~== expected absTol 1e-14)
    --- End diff --
    
    Done.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to