Repository: spark
Updated Branches:
  refs/heads/master 79ff85363 -> 9cff67f34
[MINOR][ML] Correct test cases of LoR raw2prediction & probability2prediction.

## What changes were proposed in this pull request?
Correct test cases of ```LogisticRegression``` raw2prediction & probability2prediction.

## How was this patch tested?
Changed unit tests.

Author: Yanbo Liang <yblia...@gmail.com>

Closes #16407 from yanboliang/raw-probability.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9cff67f3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9cff67f3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9cff67f3

Branch: refs/heads/master
Commit: 9cff67f3465bc6ffe1b5abee9501e3c17f8fd194
Parents: 79ff853
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Wed Dec 28 01:24:18 2016 -0800
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Wed Dec 28 01:24:18 2016 -0800

----------------------------------------------------------------------
 .../LogisticRegressionSuite.scala | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/9cff67f3/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 9c4c59a..f8bcbee 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -359,8 +359,16 @@ class LogisticRegressionSuite
         assert(pred == predFromProb)
     }

-    // force it to use probability2prediction
+    // force it to use raw2prediction
     model.setProbabilityCol("")
+    val resultsUsingRaw2Predict =
+      model.transform(smallMultinomialDataset).select("prediction").as[Double].collect()
+    resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach {
+      case (pred1, pred2) => assert(pred1 === pred2)
+    }
+
+    // force it to use probability2prediction
+    model.setRawPredictionCol("")
     val resultsUsingProb2Predict =
       model.transform(smallMultinomialDataset).select("prediction").as[Double].collect()
     resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach {
@@ -405,8 +413,16 @@ class LogisticRegressionSuite
         assert(pred == predFromProb)
     }

-    // force it to use probability2prediction
+    // force it to use raw2prediction
     model.setProbabilityCol("")
+    val resultsUsingRaw2Predict =
+      model.transform(smallBinaryDataset).select("prediction").as[Double].collect()
+    resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach {
+      case (pred1, pred2) => assert(pred1 === pred2)
+    }
+
+    // force it to use probability2prediction
+    model.setRawPredictionCol("")
     val resultsUsingProb2Predict =
       model.transform(smallBinaryDataset).select("prediction").as[Double].collect()
     resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach {
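The corrected comments depend on how a probabilistic classifier decides which column feeds the prediction. What follows is a minimal, self-contained Scala model of that choice. It assumes the dispatch order of Spark 2.x's ProbabilisticClassificationModel.transform: raw scores first, then probabilities, then the features. The names PredictionPathSketch, choosePath, argmax and softmax are hypothetical helpers for illustration, not Spark APIs, and the model ignores custom thresholds.

```scala
// Minimal, hypothetical model of the prediction-path choice; these names are
// illustrative only and are not part of Spark's API.
object PredictionPathSketch {

  sealed trait PredictionPath
  case object Raw2Prediction extends PredictionPath
  case object Probability2Prediction extends PredictionPath
  case object PredictFromFeatures extends PredictionPath

  // Assumed order: use raw scores if that column is set, else probabilities,
  // else fall back to predicting directly from the features.
  def choosePath(rawPredictionCol: String, probabilityCol: String): PredictionPath =
    if (rawPredictionCol.nonEmpty) Raw2Prediction
    else if (probabilityCol.nonEmpty) Probability2Prediction
    else PredictFromFeatures

  // Index of the largest entry (the first one if there are ties).
  def argmax(xs: Array[Double]): Int = xs.indices.maxBy(i => xs(i))

  // Numerically stable softmax: maps raw margins to class probabilities.
  def softmax(raw: Array[Double]): Array[Double] = {
    val m = raw.max
    val exps = raw.map(r => math.exp(r - m))
    val total = exps.sum
    exps.map(_ / total)
  }

  def main(args: Array[String]): Unit = {
    // With both columns set, the raw path wins.
    println(choosePath("rawPrediction", "probability")) // Raw2Prediction
    // Clearing only probabilityCol still leaves the raw path in use.
    println(choosePath("rawPrediction", ""))            // Raw2Prediction
    // Only an empty rawPredictionCol with a set probabilityCol uses probabilities.
    println(choosePath("", "probability"))              // Probability2Prediction
    // With both columns cleared, the sketch falls back to the features.
    println(choosePath("", ""))                         // PredictFromFeatures

    // Without custom thresholds, softmax preserves the ordering of the raw
    // scores, so both paths select the same class.
    val raw = Array(0.3, 1.7, -0.4)
    println(argmax(raw) == argmax(softmax(raw)))        // true
  }
}
```

Under this assumed ordering, clearing only probabilityCol still routes predictions through the raw scores, which is what the corrected "force it to use raw2prediction" comment describes. The probability path is taken only when rawPredictionCol is empty and probabilityCol is not. Because softmax is monotone in the raw scores, both paths pick the same class, so the test can expect identical predictions.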