Repository: spark
Updated Branches:
  refs/heads/master 79ff85363 -> 9cff67f34


[MINOR][ML] Correct test cases of LoR raw2prediction & probability2prediction.

## What changes were proposed in this pull request?
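Correct the test cases for `LogisticRegression` `raw2prediction` and
`probability2prediction`. Clearing only `probabilityCol` forces the model to
use `raw2prediction`, not `probability2prediction` as the old test comment
claimed. To exercise `probability2prediction`, the tests now also clear
`rawPredictionCol`.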

## How was this patch tested?
Changed unit tests.

Author: Yanbo Liang <yblia...@gmail.com>

Closes #16407 from yanboliang/raw-probability.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9cff67f3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9cff67f3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9cff67f3

Branch: refs/heads/master
Commit: 9cff67f3465bc6ffe1b5abee9501e3c17f8fd194
Parents: 79ff853
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Wed Dec 28 01:24:18 2016 -0800
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Wed Dec 28 01:24:18 2016 -0800

----------------------------------------------------------------------
 .../LogisticRegressionSuite.scala               | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9cff67f3/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 9c4c59a..f8bcbee 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -359,8 +359,16 @@ class LogisticRegressionSuite
         assert(pred == predFromProb)
     }
 
-    // force it to use probability2prediction
+    // force it to use raw2prediction
     model.setProbabilityCol("")
+    val resultsUsingRaw2Predict =
+      
model.transform(smallMultinomialDataset).select("prediction").as[Double].collect()
+    
resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach
 {
+      case (pred1, pred2) => assert(pred1 === pred2)
+    }
+
+    // force it to use probability2prediction
+    model.setRawPredictionCol("")
     val resultsUsingProb2Predict =
       
model.transform(smallMultinomialDataset).select("prediction").as[Double].collect()
     
resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach
 {
@@ -405,8 +413,16 @@ class LogisticRegressionSuite
         assert(pred == predFromProb)
     }
 
-    // force it to use probability2prediction
+    // force it to use raw2prediction
     model.setProbabilityCol("")
+    val resultsUsingRaw2Predict =
+      
model.transform(smallBinaryDataset).select("prediction").as[Double].collect()
+    
resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach
 {
+      case (pred1, pred2) => assert(pred1 === pred2)
+    }
+
+    // force it to use probability2prediction
+    model.setRawPredictionCol("")
     val resultsUsingProb2Predict =
       
model.transform(smallBinaryDataset).select("prediction").as[Double].collect()
     
resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach
 {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to