Repository: spark
Updated Branches:
  refs/heads/master c83e03948 -> d138aa8ee


[SPARK-6705][MLLIB] Add fit intercept api to ml logisticregression

I have the fit intercept enabled by default for logistic regression; I
wonder what others think here. I understand that it enables allocation
by default, which is undesirable, but one needs a very strong reason
for not having an intercept term enabled, so it is the safer default
from a statistical standpoint.

Explicitly modeling the intercept by adding a column of all 1s does not
work. I believe the reason is that the API for
LogisticRegressionWithLBFGS forces column normalization, and a column of
all 1s has 0 variance, so dividing by 0 kills it.

Author: Omede Firouz <ofir...@palantir.com>

Closes #5301 from oefirouz/addIntercept and squashes the following commits:

9f1286b [Omede Firouz] [SPARK-6705][MLLIB] Add fitInterceptTerm to LogisticRegression
1d6bd6f [Omede Firouz] [SPARK-6705][MLLIB] Add a fit intercept term to ML LogisticRegression
9963509 [Omede Firouz] [MLLIB] Add fitIntercept to LogisticRegression
2257fca [Omede Firouz] [MLLIB] Add fitIntercept param to logistic regression
329c1e2 [Omede Firouz] [MLLIB] Add fit intercept term
bd9663c [Omede Firouz] [MLLIB] Add fit intercept api to ml logisticregression


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d138aa8e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d138aa8e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d138aa8e

Branch: refs/heads/master
Commit: d138aa8ee23f4450242da3ac70a493229a90c76b
Parents: c83e039
Author: Omede Firouz <ofir...@palantir.com>
Authored: Tue Apr 7 23:36:31 2015 -0400
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Tue Apr 7 23:36:31 2015 -0400

----------------------------------------------------------------------
 .../spark/ml/classification/LogisticRegression.scala    |  8 ++++++--
 .../scala/org/apache/spark/ml/param/sharedParams.scala  | 12 ++++++++++++
 .../ml/classification/LogisticRegressionSuite.scala     |  9 +++++++++
 3 files changed, 27 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d138aa8e/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 49c00f7..3462574 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -31,7 +31,7 @@ import org.apache.spark.storage.StorageLevel
  * Params for logistic regression.
  */
 private[classification] trait LogisticRegressionParams extends 
ProbabilisticClassifierParams
-  with HasRegParam with HasMaxIter with HasThreshold
+  with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold
 
 
 /**
@@ -56,6 +56,9 @@ class LogisticRegression
   def setMaxIter(value: Int): this.type = set(maxIter, value)
 
   /** @group setParam */
+  def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
+
+  /** @group setParam */
   def setThreshold(value: Double): this.type = set(threshold, value)
 
   override protected def train(dataset: DataFrame, paramMap: ParamMap): 
LogisticRegressionModel = {
@@ -67,7 +70,8 @@ class LogisticRegression
     }
 
     // Train model
-    val lr = new LogisticRegressionWithLBFGS
+    val lr = new LogisticRegressionWithLBFGS()
+      .setIntercept(paramMap(fitIntercept))
     lr.optimizer
       .setRegParam(paramMap(regParam))
       .setNumIterations(paramMap(maxIter))

http://git-wip-us.apache.org/repos/asf/spark/blob/d138aa8e/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala 
b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
index 5d660d1..0739fdb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
@@ -106,6 +106,18 @@ private[ml] trait HasProbabilityCol extends Params {
   def getProbabilityCol: String = get(probabilityCol)
 }
 
+private[ml] trait HasFitIntercept extends Params {
+  /**
+   * param for fitting the intercept term, defaults to true
+   * @group param
+   */
+  val fitIntercept: BooleanParam =
+    new BooleanParam(this, "fitIntercept", "indicates whether to fit an 
intercept term", Some(true))
+
+  /** @group getParam */
+  def getFitIntercept: Boolean = get(fitIntercept)
+}
+
 private[ml] trait HasThreshold extends Params {
   /**
    * param for threshold in (binary) prediction

http://git-wip-us.apache.org/repos/asf/spark/blob/d138aa8e/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index b3d1bfc..35d8c2e 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -46,6 +46,7 @@ class LogisticRegressionSuite extends FunSuite with 
MLlibTestSparkContext {
     assert(lr.getPredictionCol == "prediction")
     assert(lr.getRawPredictionCol == "rawPrediction")
     assert(lr.getProbabilityCol == "probability")
+    assert(lr.getFitIntercept == true)
     val model = lr.fit(dataset)
     model.transform(dataset)
       .select("label", "probability", "prediction", "rawPrediction")
@@ -55,6 +56,14 @@ class LogisticRegressionSuite extends FunSuite with 
MLlibTestSparkContext {
     assert(model.getPredictionCol == "prediction")
     assert(model.getRawPredictionCol == "rawPrediction")
     assert(model.getProbabilityCol == "probability")
+    assert(model.intercept !== 0.0)
+  }
+
+  test("logistic regression doesn't fit intercept when fitIntercept is off") {
+    val lr = new LogisticRegression
+    lr.setFitIntercept(false)
+    val model = lr.fit(dataset)
+    assert(model.intercept === 0.0)
   }
 
   test("logistic regression with setters") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to