GitHub user WeichenXu123 commented on a diff in the pull request:

https://github.com/apache/spark/pull/20686#discussion_r172408009

    --- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala ---
    @@ -324,19 +352,46 @@ class QuantileDiscretizerSuite
           .setStages(Array(discretizerForCol1, discretizerForCol2, discretizerForCol3))
           .fit(df)
     
    -    val resultForMultiCols = plForMultiCols.transform(df)
    -      .select("result1", "result2", "result3")
    -      .collect()
    -
    -    val resultForSingleCol = plForSingleCol.transform(df)
    -      .select("result1", "result2", "result3")
    -      .collect()
    +    val expected = Seq(
    +      (0.0, 0.0, 0.0),
    +      (0.0, 0.0, 1.0),
    +      (0.0, 0.0, 1.0),
    +      (0.0, 1.0, 2.0),
    +      (0.0, 1.0, 2.0),
    +      (0.0, 1.0, 2.0),
    +      (0.0, 1.0, 3.0),
    +      (0.0, 2.0, 4.0),
    +      (0.0, 2.0, 4.0),
    +      (1.0, 2.0, 5.0),
    +      (1.0, 2.0, 5.0),
    +      (1.0, 2.0, 5.0),
    +      (1.0, 3.0, 6.0),
    +      (1.0, 3.0, 6.0),
    +      (1.0, 3.0, 7.0),
    +      (1.0, 4.0, 8.0),
    +      (1.0, 4.0, 8.0),
    +      (1.0, 4.0, 9.0),
    +      (1.0, 4.0, 9.0),
    +      (1.0, 4.0, 9.0)
    +    ).toDF("result1", "result2", "result3")
    +      .collect().toSeq
    --- End diff --

But I'd prefer to avoid hardcoding a big literal array, so that the code is easier to maintain. I think the following code is enough:

```scala
// Use the single-column pipeline's output as the reference result.
val expected = plForSingleCol.transform(df).select("result1", "result2", "result3").collect()

// Check that the multi-column pipeline reproduces those rows.
// `===` uses ScalaTest's default Equality, which compares a Seq against an Array structurally.
testTransformerByGlobalCheckFunc[(Double, Double, Double)](
  df,
  plForMultiCols,
  "result1",
  "result2",
  "result3") { rows =>
  assert(rows === expected)
}
```

There is a similar case here: https://github.com/apache/spark/pull/20121#discussion_r172288890