Repository: spark
Updated Branches:
  refs/heads/master 39eb3bb1e -> 2f7395670


[SPARK-17697][ML] Fixed bug in summary calculations that pattern match against label without casting

## What changes were proposed in this pull request?
When LogisticRegression.evaluate or GeneralizedLinearRegression.evaluate is called 
on a Dataset whose label column is not of DoubleType, the summary calculations 
pattern match the label against Double and throw a MatchError. This fix casts 
the label column to DoubleType before matching, so any numeric label type is accepted.

## How was this patch tested?
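Added unit tests that call evaluate on datasets whose label column has a non-Double 
numeric type (LongType for LogisticRegression, FloatType for 
GeneralizedLinearRegression). They check that the resulting summary metrics match 
those computed with Double labels.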

Author: Bryan Cutler <cutl...@gmail.com>

Closes #15288 from BryanCutler/binaryLOR-numericCheck-SPARK-17697.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2f739567
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2f739567
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2f739567

Branch: refs/heads/master
Commit: 2f739567080d804a942cfcca0e22f91ab7cbea36
Parents: 39eb3bb
Author: Bryan Cutler <cutl...@gmail.com>
Authored: Thu Sep 29 16:31:30 2016 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Thu Sep 29 16:31:30 2016 -0700

----------------------------------------------------------------------
 .../ml/classification/LogisticRegression.scala  |  2 +-
 .../GeneralizedLinearRegression.scala           | 11 +++++----
 .../LogisticRegressionSuite.scala               | 18 +++++++++++++-
 .../GeneralizedLinearRegressionSuite.scala      | 25 ++++++++++++++++++++
 4 files changed, 49 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2f739567/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 5ab63d1..329961a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -1169,7 +1169,7 @@ class BinaryLogisticRegressionSummary private[classification] (
   // TODO: Allow the user to vary the number of bins using a setBins method in
   // BinaryClassificationMetrics. For now the default is set to 100.
   @transient private val binaryMetrics = new BinaryClassificationMetrics(
-    predictions.select(probabilityCol, labelCol).rdd.map {
+    predictions.select(col(probabilityCol), 
col(labelCol).cast(DoubleType)).rdd.map {
       case Row(score: Vector, label: Double) => (score(1), label)
     }, 100
   )

http://git-wip-us.apache.org/repos/asf/spark/blob/2f739567/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 02b27fb..bb9e150 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -992,7 +992,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
     } else {
       link.unlink(0.0)
     }
-    predictions.select(col(model.getLabelCol), w).rdd.map {
+    predictions.select(col(model.getLabelCol).cast(DoubleType), w).rdd.map {
       case Row(y: Double, weight: Double) =>
         family.deviance(y, wtdmu, weight)
     }.sum()
@@ -1004,7 +1004,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
   @Since("2.0.0")
   lazy val deviance: Double = {
     val w = weightCol
-    predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map {
+    predictions.select(col(model.getLabelCol).cast(DoubleType), 
col(predictionCol), w).rdd.map {
       case Row(label: Double, pred: Double, weight: Double) =>
         family.deviance(label, pred, weight)
     }.sum()
@@ -1030,9 +1030,10 @@ class GeneralizedLinearRegressionSummary private[regression] (
   lazy val aic: Double = {
     val w = weightCol
     val weightSum = predictions.select(w).agg(sum(w)).first().getDouble(0)
-    val t = predictions.select(col(model.getLabelCol), col(predictionCol), 
w).rdd.map {
-      case Row(label: Double, pred: Double, weight: Double) =>
-        (label, pred, weight)
+    val t = predictions.select(
+      col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map {
+        case Row(label: Double, pred: Double, weight: Double) =>
+          (label, pred, weight)
     }
     family.aic(t, deviance, numInstances, weightSum) + 2 * rank
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/2f739567/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 8451e60..42b5675 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -32,7 +32,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{Dataset, Row}
-import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.LongType
 
 class LogisticRegressionSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -1776,6 +1777,21 @@ class LogisticRegressionSuite
       summary.precisionByThreshold.collect() === 
sameSummary.precisionByThreshold.collect())
   }
 
+  test("evaluate with labels that are not doubles") {
+    // A LongType label column must not cause a MatchError and must yield the same AUC as Double labels
+    val lr = new LogisticRegression()
+      .setMaxIter(1)
+      .setRegParam(1.0)
+    val model = lr.fit(smallBinaryDataset)
+    val summary = 
model.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary]
+
+    val longLabelData = 
smallBinaryDataset.select(col(model.getLabelCol).cast(LongType),
+      col(model.getFeaturesCol))
+    val longSummary = 
model.evaluate(longLabelData).asInstanceOf[BinaryLogisticRegressionSummary]
+
+    assert(summary.areaUnderROC === longSummary.areaUnderROC)
+  }
+
+
   test("statistics on training data") {
     // Test that loss is monotonically decreasing.
     val lr = new LogisticRegression()

http://git-wip-us.apache.org/repos/asf/spark/blob/2f739567/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 937aa7d..ac1ef5f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.mllib.random._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.FloatType
 
 class GeneralizedLinearRegressionSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -1067,6 +1068,30 @@ class GeneralizedLinearRegressionSuite
       idx += 1
     }
   }
+
+  test("evaluate with labels that are not doubles") {
+    // Evaluate with a dataset that contains Labels not as doubles to verify 
correct casting
+    val dataset = Seq(
+      Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+      Instance(19.0, 1.0, Vectors.dense(1.0, 7.0)),
+      Instance(23.0, 1.0, Vectors.dense(2.0, 11.0)),
+      Instance(29.0, 1.0, Vectors.dense(3.0, 13.0))
+    ).toDF()
+
+    val trainer = new GeneralizedLinearRegression()
+      .setMaxIter(1)
+    val model = trainer.fit(dataset)
+    assert(model.hasSummary)
+    val summary = model.summary
+
+    val longLabelDataset = 
dataset.select(col(model.getLabelCol).cast(FloatType),
+      col(model.getFeaturesCol))  // NOTE(review): label is cast to FloatType, despite the 'long' in the name
+    val evalSummary = model.evaluate(longLabelDataset)
+    // nullDeviance, deviance and aic each match the label as a Double; results must equal the Double-label summary
+    assert(evalSummary.nullDeviance === summary.nullDeviance)
+    assert(evalSummary.deviance === summary.deviance)
+    assert(evalSummary.aic === summary.aic)
+  }
 }
 
 object GeneralizedLinearRegressionSuite {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to