Repository: spark Updated Branches: refs/heads/master 0d74bd7fd -> 80f3bcb58
[SPARK-5652][Mllib] Use broadcasted weights in LogisticRegressionModel `LogisticRegressionModel`'s `predictPoint` should directly use broadcasted weights. This PR also fixes the compilation errors of two unit test suites: `JavaLogisticRegressionSuite` and `JavaLinearRegressionSuite`. Author: Liang-Chi Hsieh <vii...@gmail.com> Closes #4429 from viirya/use_bcvalue and squashes the following commits: 5a797e5 [Liang-Chi Hsieh] Use broadcasted weights. Fix compilation error. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/80f3bcb5 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/80f3bcb5 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/80f3bcb5 Branch: refs/heads/master Commit: 80f3bcb58f836cfe1829c85bdd349c10525c8a5e Parents: 0d74bd7 Author: Liang-Chi Hsieh <vii...@gmail.com> Authored: Fri Feb 6 11:22:11 2015 -0800 Committer: Xiangrui Meng <m...@databricks.com> Committed: Fri Feb 6 11:22:11 2015 -0800 ---------------------------------------------------------------------- .../spark/mllib/classification/LogisticRegression.scala | 8 ++++---- .../spark/ml/classification/JavaLogisticRegressionSuite.java | 4 ++-- .../spark/ml/regression/JavaLinearRegressionSuite.java | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/80f3bcb5/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index a668e7a..9a391bf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ 
-33,7 +33,7 @@ import org.apache.spark.rdd.RDD * * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. (Only used in Binary Logistic Regression. - * In Multinomial Logistic Regression, the intercepts will not be a single values, + * In Multinomial Logistic Regression, the intercepts will not be a single value, * so the intercepts will be part of the weights.) * @param numFeatures the dimension of the features. * @param numClasses the number of possible outcomes for k classes classification problem in @@ -107,7 +107,7 @@ class LogisticRegressionModel ( // If dataMatrix and weightMatrix have the same dimension, it's binary logistic regression. if (numClasses == 2) { require(numFeatures == weightMatrix.size) - val margin = dot(weights, dataMatrix) + intercept + val margin = dot(weightMatrix, dataMatrix) + intercept val score = 1.0 / (1.0 + math.exp(-margin)) threshold match { case Some(t) => if (score > t) 1.0 else 0.0 @@ -116,11 +116,11 @@ class LogisticRegressionModel ( } else { val dataWithBiasSize = weightMatrix.size / (numClasses - 1) - val weightsArray = weights match { + val weightsArray = weightMatrix match { case dv: DenseVector => dv.values case _ => throw new IllegalArgumentException( - s"weights only supports dense vector but got type ${weights.getClass}.") + s"weights only supports dense vector but got type ${weightMatrix.getClass}.") } val margins = (0 until numClasses - 1).map { i => http://git-wip-us.apache.org/repos/asf/spark/blob/80f3bcb5/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 2628402..d4b6644 100644 --- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -84,7 +84,7 @@ public class JavaLogisticRegressionSuite implements Serializable { .setThreshold(0.6) .setProbabilityCol("myProbability"); LogisticRegressionModel model = lr.fit(dataset); - assert(model.fittingParamMap().apply(lr.maxIter()) == 10); + assert(model.fittingParamMap().apply(lr.maxIter()).equals(10)); assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0)); assert(model.fittingParamMap().apply(lr.threshold()).equals(0.6)); assert(model.getThreshold() == 0.6); @@ -109,7 +109,7 @@ public class JavaLogisticRegressionSuite implements Serializable { // Call fit() with new params, and check as many params as we can. LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); - assert(model2.fittingParamMap().apply(lr.maxIter()) == 5); + assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5)); assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1)); assert(model2.fittingParamMap().apply(lr.threshold()).equals(0.4)); assert(model2.getThreshold() == 0.4); http://git-wip-us.apache.org/repos/asf/spark/blob/80f3bcb5/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index 5bd616e..40d5a92 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -76,13 +76,13 @@ public class JavaLinearRegressionSuite implements Serializable { .setMaxIter(10) .setRegParam(1.0); 
LinearRegressionModel model = lr.fit(dataset); - assert(model.fittingParamMap().apply(lr.maxIter()) == 10); + assert(model.fittingParamMap().apply(lr.maxIter()).equals(10)); assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0)); // Call fit() with new params, and check as many params as we can. LinearRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred")); - assert(model2.fittingParamMap().apply(lr.maxIter()) == 5); + assert(model2.fittingParamMap().apply(lr.maxIter()).equals(5)); assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1)); assert(model2.getPredictionCol().equals("thePred")); } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org