Repository: spark
Updated Branches:
  refs/heads/master 0d74bd7fd -> 80f3bcb58


[SPARK-5652][Mllib] Use broadcasted weights in LogisticRegressionModel

`LogisticRegressionModel`'s `predictPoint` should directly use the broadcasted
weights. This PR also fixes the compilation errors in two unit test suites:
`JavaLogisticRegressionSuite` and `JavaLinearRegressionSuite`.

Author: Liang-Chi Hsieh <vii...@gmail.com>

Closes #4429 from viirya/use_bcvalue and squashes the following commits:

5a797e5 [Liang-Chi Hsieh] Use broadcasted weights. Fix compilation error.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/80f3bcb5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/80f3bcb5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/80f3bcb5

Branch: refs/heads/master
Commit: 80f3bcb58f836cfe1829c85bdd349c10525c8a5e
Parents: 0d74bd7
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Fri Feb 6 11:22:11 2015 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Fri Feb 6 11:22:11 2015 -0800

----------------------------------------------------------------------
 .../spark/mllib/classification/LogisticRegression.scala      | 8 ++++----
 .../spark/ml/classification/JavaLogisticRegressionSuite.java | 4 ++--
 .../spark/ml/regression/JavaLinearRegressionSuite.java       | 4 ++--
 3 files changed, 8 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/80f3bcb5/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
 
b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index a668e7a..9a391bf 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD
  *
  * @param weights Weights computed for every feature.
  * @param intercept Intercept computed for this model. (Only used in Binary 
Logistic Regression.
- *                  In Multinomial Logistic Regression, the intercepts will 
not be a single values,
+ *                  In Multinomial Logistic Regression, the intercepts will 
not be a single value,
  *                  so the intercepts will be part of the weights.)
  * @param numFeatures the dimension of the features.
  * @param numClasses the number of possible outcomes for k classes 
classification problem in
@@ -107,7 +107,7 @@ class LogisticRegressionModel (
     // If dataMatrix and weightMatrix have the same dimension, it's binary 
logistic regression.
     if (numClasses == 2) {
       require(numFeatures == weightMatrix.size)
-      val margin = dot(weights, dataMatrix) + intercept
+      val margin = dot(weightMatrix, dataMatrix) + intercept
       val score = 1.0 / (1.0 + math.exp(-margin))
       threshold match {
         case Some(t) => if (score > t) 1.0 else 0.0
@@ -116,11 +116,11 @@ class LogisticRegressionModel (
     } else {
       val dataWithBiasSize = weightMatrix.size / (numClasses - 1)
 
-      val weightsArray = weights match {
+      val weightsArray = weightMatrix match {
         case dv: DenseVector => dv.values
         case _ =>
           throw new IllegalArgumentException(
-            s"weights only supports dense vector but got type 
${weights.getClass}.")
+            s"weights only supports dense vector but got type 
${weightMatrix.getClass}.")
       }
 
       val margins = (0 until numClasses - 1).map { i =>

http://git-wip-us.apache.org/repos/asf/spark/blob/80f3bcb5/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index 2628402..d4b6644 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -84,7 +84,7 @@ public class JavaLogisticRegressionSuite implements 
Serializable {
       .setThreshold(0.6)
       .setProbabilityCol("myProbability");
     LogisticRegressionModel model = lr.fit(dataset);
-    assert(model.fittingParamMap().apply(lr.maxIter()) == 10);
+    assert(model.fittingParamMap().apply(lr.maxIter()).equals(10));
     assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));
     assert(model.fittingParamMap().apply(lr.threshold()).equals(0.6));
     assert(model.getThreshold() == 0.6);
@@ -109,7 +109,7 @@ public class JavaLogisticRegressionSuite implements 
Serializable {
     // Call fit() with new params, and check as many params as we can.
     LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), 
lr.regParam().w(0.1),
         lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
-    assert(model2.fittingParamMap().apply(lr.maxIter()) == 5);
+    assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5));
     assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
     assert(model2.fittingParamMap().apply(lr.threshold()).equals(0.4));
     assert(model2.getThreshold() == 0.4);

http://git-wip-us.apache.org/repos/asf/spark/blob/80f3bcb5/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index 5bd616e..40d5a92 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -76,13 +76,13 @@ public class JavaLinearRegressionSuite implements 
Serializable {
         .setMaxIter(10)
         .setRegParam(1.0);
     LinearRegressionModel model = lr.fit(dataset);
-    assert(model.fittingParamMap().apply(lr.maxIter()) == 10);
+    assert(model.fittingParamMap().apply(lr.maxIter()).equals(10));
     assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));
 
     // Call fit() with new params, and check as many params as we can.
     LinearRegressionModel model2 =
         lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), 
lr.predictionCol().w("thePred"));
-    assert(model2.fittingParamMap().apply(lr.maxIter()) == 5);
+    assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5));
     assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
     assert(model2.getPredictionCol().equals("thePred"));
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to