Repository: spark
Updated Branches:
  refs/heads/master 762bacc16 -> 551def5d6


[SPARK-9789] [ML] Added logreg threshold param back

Reinstated LogisticRegression.threshold Param for binary compatibility.
Calling setThreshold clears any user-set thresholds, and vice versa; if both
Params are set (e.g., via a ParamMap), they must be equivalent.

CC: mengxr dbtsai feynmanliang

Author: Joseph K. Bradley <jos...@databricks.com>

Closes #8079 from jkbradley/logreg-reinstate-threshold.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/551def5d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/551def5d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/551def5d

Branch: refs/heads/master
Commit: 551def5d6972440365bd7436d484a67138d9a8f3
Parents: 762bacc
Author: Joseph K. Bradley <jos...@databricks.com>
Authored: Wed Aug 12 14:27:13 2015 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Wed Aug 12 14:27:13 2015 -0700

----------------------------------------------------------------------
 .../ml/classification/LogisticRegression.scala  | 127 +++++++++++++++----
 .../ml/param/shared/SharedParamsCodeGen.scala   |   4 +-
 .../spark/ml/param/shared/sharedParams.scala    |   6 +-
 .../JavaLogisticRegressionSuite.java            |   7 +-
 .../LogisticRegressionSuite.scala               |  33 +++--
 python/pyspark/ml/classification.py             |  98 ++++++++------
 6 files changed, 199 insertions(+), 76 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/551def5d/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index f55134d..5bcd711 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -34,8 +34,7 @@ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -43,44 +42,115 @@ import org.apache.spark.storage.StorageLevel
  */
 private[classification] trait LogisticRegressionParams extends 
ProbabilisticClassifierParams
   with HasRegParam with HasElasticNetParam with HasMaxIter with 
HasFitIntercept with HasTol
-  with HasStandardization {
+  with HasStandardization with HasThreshold {
 
   /**
-   * Version of setThresholds() for binary classification, available for 
backwards
-   * compatibility.
+   * Set threshold in binary classification, in range [0, 1].
    *
-   * Calling this with threshold p will effectively call 
`setThresholds(Array(1-p, p))`.
+   * If the estimated probability of class label 1 is > threshold, then 
predict 1, else 0.
+   * A high threshold encourages the model to predict 0 more often;
+   * a low threshold encourages the model to predict 1 more often.
+   *
+   * Note: Calling this with threshold p is equivalent to calling 
`setThresholds(Array(1-p, p))`.
+   *       When [[setThreshold()]] is called, any user-set value for 
[[thresholds]] will be cleared.
+   *       If both [[threshold]] and [[thresholds]] are set in a ParamMap, 
then they must be
+   *       equivalent.
+   *
+   * Default is 0.5.
+   * @group setParam
+   */
+  def setThreshold(value: Double): this.type = {
+    if (isSet(thresholds)) clear(thresholds)
+    set(threshold, value)
+  }
+
+  /**
+   * Get threshold for binary classification.
+   *
+   * If [[threshold]] is set, returns that value.
+   * Otherwise, if [[thresholds]] is set with length 2 (i.e., binary 
classification),
+   * this returns the equivalent threshold: {{{1 / (1 + thresholds(0) / 
thresholds(1))}}}.
+   * Otherwise, returns [[threshold]] default value.
+   *
+   * @group getParam
+   * @throws IllegalArgumentException if [[thresholds]] is set to an array of 
length other than 2.
+   */
+  override def getThreshold: Double = {
+    checkThresholdConsistency()
+    if (isSet(thresholds)) {
+      val ts = $(thresholds)
+      require(ts.length == 2, "Logistic Regression getThreshold only applies 
to" +
+        " binary classification, but thresholds has length != 2.  thresholds: 
" + ts.mkString(","))
+      1.0 / (1.0 + ts(0) / ts(1))
+    } else {
+      $(threshold)
+    }
+  }
+
+  /**
+   * Set thresholds in multiclass (or binary) classification to adjust the 
probability of
+   * predicting each class. Array must have length equal to the number of 
classes, with values >= 0.
+   * The class with largest value p/t is predicted, where p is the original 
probability of that
+   * class and t is the class' threshold.
+   *
+   * Note: When [[setThresholds()]] is called, any user-set value for 
[[threshold]] will be cleared.
+   *       If both [[threshold]] and [[thresholds]] are set in a ParamMap, 
then they must be
+   *       equivalent.
    *
-   * Default is effectively 0.5.
    * @group setParam
    */
-  def setThreshold(value: Double): this.type = set(thresholds, Array(1.0 - 
value, value))
+  def setThresholds(value: Array[Double]): this.type = {
+    if (isSet(threshold)) clear(threshold)
+    set(thresholds, value)
+  }
 
   /**
-   * Version of [[getThresholds()]] for binary classification, available for 
backwards
-   * compatibility.
+   * Get thresholds for binary or multiclass classification.
+   *
+   * If [[thresholds]] is set, return its value.
+   * Otherwise, if [[threshold]] is set, return the equivalent thresholds for 
binary
+   * classification: (1-threshold, threshold).
+   * If neither are set, throw an exception.
    *
-   * Param thresholds must have length 2 (or not be specified).
-   * This returns {{{1 / (1 + thresholds(0) / thresholds(1))}}}.
    * @group getParam
    */
-  def getThreshold: Double = {
-    if (isDefined(thresholds)) {
-      val thresholdValues = $(thresholds)
-      assert(thresholdValues.length == 2, "Logistic Regression getThreshold 
only applies to" +
-        " binary classification, but thresholds has length != 2." +
-        s"  thresholds: ${thresholdValues.mkString(",")}")
-      1.0 / (1.0 + thresholdValues(0) / thresholdValues(1))
+  override def getThresholds: Array[Double] = {
+    checkThresholdConsistency()
+    if (!isSet(thresholds) && isSet(threshold)) {
+      val t = $(threshold)
+      Array(1-t, t)
     } else {
-      0.5
+      $(thresholds)
+    }
+  }
+
+  /**
+   * If [[threshold]] and [[thresholds]] are both set, ensures they are 
consistent.
+   * @throws IllegalArgumentException if [[threshold]] and [[thresholds]] are 
not equivalent
+   */
+  protected def checkThresholdConsistency(): Unit = {
+    if (isSet(threshold) && isSet(thresholds)) {
+      val ts = $(thresholds)
+      require(ts.length == 2, "Logistic Regression found inconsistent values 
for threshold and" +
+        s" thresholds.  Param threshold is set (${$(threshold)}), indicating 
binary" +
+        s" classification, but Param thresholds is set with length 
${ts.length}." +
+        " Clear one Param value to fix this problem.")
+      val t = 1.0 / (1.0 + ts(0) / ts(1))
+      require(math.abs($(threshold) - t) < 1E-5, "Logistic Regression 
getThreshold found" +
+        s" inconsistent values for threshold (${$(threshold)}) and thresholds 
(equivalent to $t)")
     }
   }
+
+  override def validateParams(): Unit = {
+    checkThresholdConsistency()
+  }
 }
 
 /**
  * :: Experimental ::
  * Logistic regression.
- * Currently, this class only supports binary classification.
+ * Currently, this class only supports binary classification.  It will support 
multiclass
+ * in the future.
  */
 @Experimental
 class LogisticRegression(override val uid: String)
@@ -128,7 +198,7 @@ class LogisticRegression(override val uid: String)
    * Whether to fit an intercept term.
    * Default is true.
    * @group setParam
-   * */
+   */
   def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
   setDefault(fitIntercept -> true)
 
@@ -140,7 +210,7 @@ class LogisticRegression(override val uid: String)
    * is applied. In R's GLMNET package, the default behavior is true as well.
    * Default is true.
    * @group setParam
-   * */
+   */
   def setStandardization(value: Boolean): this.type = set(standardization, 
value)
   setDefault(standardization -> true)
 
@@ -148,6 +218,10 @@ class LogisticRegression(override val uid: String)
 
   override def getThreshold: Double = super.getThreshold
 
+  override def setThresholds(value: Array[Double]): this.type = 
super.setThresholds(value)
+
+  override def getThresholds: Array[Double] = super.getThresholds
+
   override protected def train(dataset: DataFrame): LogisticRegressionModel = {
     // Extract columns from data.  If dataset is persisted, do not persist 
oldDataset.
     val instances = extractLabeledPoints(dataset).map {
@@ -314,6 +388,10 @@ class LogisticRegressionModel private[ml] (
 
   override def getThreshold: Double = super.getThreshold
 
+  override def setThresholds(value: Array[Double]): this.type = 
super.setThresholds(value)
+
+  override def getThresholds: Array[Double] = super.getThresholds
+
   /** Margin (rawPrediction) for class label 1.  For binary classification 
only. */
   private val margin: Vector => Double = (features) => {
     BLAS.dot(features, weights) + intercept
@@ -364,6 +442,7 @@ class LogisticRegressionModel private[ml] (
    * The behavior of this can be adjusted using [[thresholds]].
    */
   override protected def predict(features: Vector): Double = {
+    // Note: We should use getThreshold instead of $(threshold) since 
getThreshold is overridden.
     if (score(features) > getThreshold) 1 else 0
   }
 
@@ -393,6 +472,7 @@ class LogisticRegressionModel private[ml] (
   }
 
   override protected def raw2prediction(rawPrediction: Vector): Double = {
+    // Note: We should use getThreshold instead of $(threshold) since 
getThreshold is overridden.
     val t = getThreshold
     val rawThreshold = if (t == 0.0) {
       Double.NegativeInfinity
@@ -405,6 +485,7 @@ class LogisticRegressionModel private[ml] (
   }
 
   override protected def probability2prediction(probability: Vector): Double = 
{
+    // Note: We should use getThreshold instead of $(threshold) since 
getThreshold is overridden.
     if (probability(1) > getThreshold) 1 else 0
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/551def5d/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index da4c076..9e12f18 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -45,14 +45,14 @@ private[shared] object SharedParamsCodeGen {
         " These probabilities should be treated as confidences, not precise 
probabilities.",
         Some("\"probability\"")),
       ParamDesc[Double]("threshold",
-        "threshold in binary classification prediction, in range [0, 1]",
+        "threshold in binary classification prediction, in range [0, 1]", 
Some("0.5"),
         isValid = "ParamValidators.inRange(0, 1)", finalMethods = false),
       ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class 
classification" +
         " to adjust the probability of predicting each class." +
         " Array must have length equal to the number of classes, with values 
>= 0." +
         " The class with largest value p/t is predicted, where p is the 
original probability" +
         " of that class and t is the class' threshold.",
-        isValid = "(t: Array[Double]) => t.forall(_ >= 0)"),
+        isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = 
false),
       ParamDesc[String]("inputCol", "input column name"),
       ParamDesc[Array[String]]("inputCols", "input column names"),
       ParamDesc[String]("outputCol", "output column name", Some("uid + 
\"__output\"")),

http://git-wip-us.apache.org/repos/asf/spark/blob/551def5d/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 23e2b6c..a17d4ea 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -139,7 +139,7 @@ private[ml] trait HasProbabilityCol extends Params {
 }
 
 /**
- * Trait for shared param threshold.
+ * Trait for shared param threshold (default: 0.5).
  */
 private[ml] trait HasThreshold extends Params {
 
@@ -149,6 +149,8 @@ private[ml] trait HasThreshold extends Params {
    */
   final val threshold: DoubleParam = new DoubleParam(this, "threshold", 
"threshold in binary classification prediction, in range [0, 1]", 
ParamValidators.inRange(0, 1))
 
+  setDefault(threshold, 0.5)
+
   /** @group getParam */
   def getThreshold: Double = $(threshold)
 }
@@ -165,7 +167,7 @@ private[ml] trait HasThresholds extends Params {
   final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, 
"thresholds", "Thresholds in multi-class classification to adjust the 
probability of predicting each class. Array must have length equal to the 
number of classes, with values >= 0. The class with largest value p/t is 
predicted, where p is the original probability of that class and t is the 
class' threshold.", (t: Array[Double]) => t.forall(_ >= 0))
 
   /** @group getParam */
-  final def getThresholds: Array[Double] = $(thresholds)
+  def getThresholds: Array[Double] = $(thresholds)
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/551def5d/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index 7e9aa38..618b95b 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -100,9 +100,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
       assert(r.getDouble(0) == 0.0);
     }
     // Call transform with params, and check that the params worked.
-    double[] thresholds = {1.0, 0.0};
-    model.transform(
-      dataset, model.thresholds().w(thresholds), 
model.probabilityCol().w("myProb"))
+    model.transform(dataset, model.threshold().w(0.0), 
model.probabilityCol().w("myProb"))
       .registerTempTable("predNotAllZero");
     DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM 
predNotAllZero");
     boolean foundNonZero = false;
@@ -112,9 +110,8 @@ public class JavaLogisticRegressionSuite implements 
Serializable {
     assert(foundNonZero);
 
     // Call fit() with new params, and check as many params as we can.
-    double[] thresholds2 = {0.6, 0.4};
     LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), 
lr.regParam().w(0.1),
-        lr.thresholds().w(thresholds2), lr.probabilityCol().w("theProb"));
+        lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
     LogisticRegression parent2 = (LogisticRegression) model2.parent();
     assert(parent2.getMaxIter() == 5);
     assert(parent2.getRegParam() == 0.1);

http://git-wip-us.apache.org/repos/asf/spark/blob/551def5d/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 8c3d459..e354e16 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -94,12 +94,13 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
   test("setThreshold, getThreshold") {
     val lr = new LogisticRegression
     // default
-    withClue("LogisticRegression should not have thresholds set by default") {
-      intercept[java.util.NoSuchElementException] {
+    assert(lr.getThreshold === 0.5, "LogisticRegression.threshold should 
default to 0.5")
+    withClue("LogisticRegression should not have thresholds set by default.") {
+      intercept[java.util.NoSuchElementException] { // Note: The exception 
type may change in future
         lr.getThresholds
       }
     }
-    // Set via thresholds.
+    // Set via threshold.
     // Intuition: Large threshold or large thresholds(1) makes class 0 more 
likely.
     lr.setThreshold(1.0)
     assert(lr.getThresholds === Array(0.0, 1.0))
@@ -107,10 +108,26 @@ class LogisticRegressionSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     assert(lr.getThresholds === Array(1.0, 0.0))
     lr.setThreshold(0.5)
     assert(lr.getThresholds === Array(0.5, 0.5))
-    // Test getThreshold
-    lr.setThresholds(Array(0.3, 0.7))
+    // Set via thresholds
+    val lr2 = new LogisticRegression
+    lr2.setThresholds(Array(0.3, 0.7))
     val expectedThreshold = 1.0 / (1.0 + 0.3 / 0.7)
-    assert(lr.getThreshold ~== expectedThreshold relTol 1E-7)
+    assert(lr2.getThreshold ~== expectedThreshold relTol 1E-7)
+    // thresholds and threshold must be consistent
+    lr2.setThresholds(Array(0.1, 0.2, 0.3))
+    withClue("getThreshold should throw error if thresholds has length != 2.") 
{
+      intercept[IllegalArgumentException] {
+        lr2.getThreshold
+      }
+    }
+    // thresholds and threshold must be consistent: values
+    withClue("fit with ParamMap should throw error if threshold, thresholds do 
not match.") {
+      intercept[IllegalArgumentException] {
+        val lr2model = lr2.fit(dataset,
+          lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> 
(expectedThreshold / 2.0))
+        lr2model.getThreshold
+      }
+    }
   }
 
   test("logistic regression doesn't fit intercept when fitIntercept is off") {
@@ -145,7 +162,7 @@ class LogisticRegressionSuite extends SparkFunSuite with 
MLlibTestSparkContext {
       s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.")
     // Call transform with params, and check that the params worked.
     val predNotAllZero =
-      model.transform(dataset, model.thresholds -> Array(1.0, 0.0),
+      model.transform(dataset, model.threshold -> 0.0,
         model.probabilityCol -> "myProb")
         .select("prediction", "myProb")
         .collect()
@@ -153,8 +170,8 @@ class LogisticRegressionSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     assert(predNotAllZero.exists(_ !== 0.0))
 
     // Call fit() with new params, and check as many params as we can.
+    lr.setThresholds(Array(0.6, 0.4))
     val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1,
-      lr.thresholds -> Array(0.6, 0.4),
       lr.probabilityCol -> "theProb")
     val parent2 = model2.parent.asInstanceOf[LogisticRegression]
     assert(parent2.getMaxIter === 5)

http://git-wip-us.apache.org/repos/asf/spark/blob/551def5d/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 6702dce..83f808e 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -76,19 +76,21 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPredicti
                        " Array must have length equal to the number of 
classes, with values >= 0." +
                        " The class with largest value p/t is predicted, where 
p is the original" +
                        " probability of that class and t is the class' 
threshold.")
+    threshold = Param(Params._dummy(), "threshold",
+                      "Threshold in binary classification prediction, in range 
[0, 1]." +
+                      " If threshold and thresholds are both set, they must 
match.")
 
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", 
predictionCol="prediction",
                  maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, 
fitIntercept=True,
-                 threshold=None, thresholds=None,
+                 threshold=0.5, thresholds=None,
                  probabilityCol="probability", 
rawPredictionCol="rawPrediction"):
         """
         __init__(self, featuresCol="features", labelCol="label", 
predictionCol="prediction", \
                  maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, 
fitIntercept=True, \
-                 threshold=None, thresholds=None, \
+                 threshold=0.5, thresholds=None, \
                  probabilityCol="probability", 
rawPredictionCol="rawPrediction")
-        Param thresholds overrides Param threshold; threshold is provided
-        for backwards compatibility and only applies to binary classification.
+        If the threshold and thresholds Params are both set, they must be 
equivalent.
         """
         super(LogisticRegression, self).__init__()
         self._java_obj = self._new_java_obj(
@@ -101,7 +103,11 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPredicti
                   "the penalty is an L2 penalty. For alpha = 1, it is an L1 
penalty.")
         #: param for whether to fit an intercept term.
         self.fitIntercept = Param(self, "fitIntercept", "whether to fit an 
intercept term.")
-        #: param for threshold in binary classification prediction, in range 
[0, 1].
+        #: param for threshold in binary classification, in range [0, 1].
+        self.threshold = Param(self, "threshold",
+                               "Threshold in binary classification prediction, 
in range [0, 1]." +
+                               " If threshold and thresholds are both set, 
they must match.")
+        #: param for thresholds or cutoffs in binary or multiclass 
classification
         self.thresholds = \
             Param(self, "thresholds",
                   "Thresholds in multi-class classification" +
@@ -110,29 +116,28 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPredicti
                   " The class with largest value p/t is predicted, where p is 
the original" +
                   " probability of that class and t is the class' threshold.")
         self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, 
tol=1E-6,
-                         fitIntercept=True)
+                         fitIntercept=True, threshold=0.5)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
+        self._checkThresholdConsistency()
 
     @keyword_only
     def setParams(self, featuresCol="features", labelCol="label", 
predictionCol="prediction",
                   maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, 
fitIntercept=True,
-                  threshold=None, thresholds=None,
+                  threshold=0.5, thresholds=None,
                   probabilityCol="probability", 
rawPredictionCol="rawPrediction"):
         """
         setParams(self, featuresCol="features", labelCol="label", 
predictionCol="prediction", \
                   maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, 
fitIntercept=True, \
-                  threshold=None, thresholds=None, \
+                  threshold=0.5, thresholds=None, \
                   probabilityCol="probability", 
rawPredictionCol="rawPrediction")
         Sets params for logistic regression.
-        Param thresholds overrides Param threshold; threshold is provided
-        for backwards compatibility and only applies to binary classification.
+        If the threshold and thresholds Params are both set, they must be 
equivalent.
         """
-        # Under the hood we use thresholds so translate threshold to 
thresholds if applicable
-        if thresholds is None and threshold is not None:
-            kwargs[thresholds] = [1-threshold, threshold]
         kwargs = self.setParams._input_kwargs
-        return self._set(**kwargs)
+        self._set(**kwargs)
+        self._checkThresholdConsistency()
+        return self
 
     def _create_model(self, java_model):
         return LogisticRegressionModel(java_model)
@@ -165,44 +170,65 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, 
HasLabelCol, HasPredicti
 
     def setThreshold(self, value):
         """
-        Sets the value of :py:attr:`thresholds` using [1-value, value].
+        Sets the value of :py:attr:`threshold`.
+        Clears value of :py:attr:`thresholds` if it has been set.
+        """
+        self._paramMap[self.threshold] = value
+        if self.isSet(self.thresholds):
+            del self._paramMap[self.thresholds]
+        return self
 
-        >>> lr = LogisticRegression()
-        >>> lr.getThreshold()
-        0.5
-        >>> lr.setThreshold(0.6)
-        LogisticRegression_...
-        >>> abs(lr.getThreshold() - 0.6) < 1e-5
-        True
+    def getThreshold(self):
+        """
+        Gets the value of threshold or its default value.
         """
-        return self.setThresholds([1-value, value])
+        self._checkThresholdConsistency()
+        if self.isSet(self.thresholds):
+            ts = self.getOrDefault(self.thresholds)
+            if len(ts) != 2:
+                raise ValueError("Logistic Regression getThreshold only 
applies to" +
+                                 " binary classification, but thresholds has 
length != 2." +
+                                 "  thresholds: " + ",".join(ts))
+            return 1.0/(1.0 + ts[0]/ts[1])
+        else:
+            return self.getOrDefault(self.threshold)
 
     def setThresholds(self, value):
         """
         Sets the value of :py:attr:`thresholds`.
+        Clears value of :py:attr:`threshold` if it has been set.
         """
         self._paramMap[self.thresholds] = value
+        if self.isSet(self.threshold):
+            del self._paramMap[self.threshold]
         return self
 
     def getThresholds(self):
         """
-        Gets the value of thresholds or its default value.
+        If :py:attr:`thresholds` is set, return its value.
+        Otherwise, if :py:attr:`threshold` is set, return the equivalent 
thresholds for binary
+        classification: (1-threshold, threshold).
+        If neither are set, throw an error.
         """
-        return self.getOrDefault(self.thresholds)
+        self._checkThresholdConsistency()
+        if not self.isSet(self.thresholds) and self.isSet(self.threshold):
+            t = self.getOrDefault(self.threshold)
+            return [1.0-t, t]
+        else:
+            return self.getOrDefault(self.thresholds)
 
-    def getThreshold(self):
-        """
-        Gets the value of threshold or its default value.
-        """
-        if self.isDefined(self.thresholds):
-            thresholds = self.getOrDefault(self.thresholds)
-            if len(thresholds) != 2:
+    def _checkThresholdConsistency(self):
+        if self.isSet(self.threshold) and self.isSet(self.thresholds):
+            ts = self.getParam(self.thresholds)
+            if len(ts) != 2:
                 raise ValueError("Logistic Regression getThreshold only 
applies to" +
                                  " binary classification, but thresholds has 
length != 2." +
-                                 "  thresholds: " + ",".join(thresholds))
-            return 1.0/(1.0+thresholds[0]/thresholds[1])
-        else:
-            return 0.5
+                                 " thresholds: " + ",".join(ts))
+            t = 1.0/(1.0 + ts[0]/ts[1])
+            t2 = self.getParam(self.threshold)
+            if abs(t2 - t) >= 1E-5:
+                raise ValueError("Logistic Regression getThreshold found 
inconsistent values for" +
+                                 " threshold (%g) and thresholds (equivalent 
to %g)" % (t2, t))
 
 
 class LogisticRegressionModel(JavaModel):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to