Repository: spark
Updated Branches:
  refs/heads/master 19af298bb -> d4a912243


[SPARK-16710][SPARKR][ML] spark.glm should support weightCol

## What changes were proposed in this pull request?
Training GLMs on weighted datasets is a very important use case, but it is not 
currently supported by SparkR. In native R, users can pass the ```weights``` 
argument to specify the weights vector. For ```spark.glm```, we can pass in 
```weightCol```, which is consistent with MLlib.

## How was this patch tested?
Unit test.

Author: Yanbo Liang <yblia...@gmail.com>

Closes #14346 from yanboliang/spark-16710.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d4a91224
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d4a91224
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d4a91224

Branch: refs/heads/master
Commit: d4a9122430d6c3aeaaee32aa09d314016ff6ddc7
Parents: 19af298
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Wed Aug 10 10:53:48 2016 -0700
Committer: Shivaram Venkataraman <shiva...@cs.berkeley.edu>
Committed: Wed Aug 10 10:53:48 2016 -0700

----------------------------------------------------------------------
 R/pkg/R/mllib.R                                 | 15 +++++++++----
 R/pkg/inst/tests/testthat/test_mllib.R          | 22 ++++++++++++++++++++
 .../r/GeneralizedLinearRegressionWrapper.scala  |  4 +++-
 3 files changed, 36 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d4a91224/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 50c601f..25d9f07 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -91,6 +91,8 @@ NULL
 #'               
\url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
 #' @param tol Positive convergence tolerance of iterations.
 #' @param maxIter Integer giving the maximal number of IRLS iterations.
+#' @param weightCol The weight column name. If this is not set or NULL, we 
treat all instance
+#'                  weights as 1.0.
 #' @aliases spark.glm,SparkDataFrame,formula-method
 #' @return \code{spark.glm} returns a fitted generalized linear model
 #' @rdname spark.glm
@@ -119,7 +121,7 @@ NULL
 #' @note spark.glm since 2.0.0
 #' @seealso \link{glm}, \link{read.ml}
 setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
-          function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25) 
{
+          function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, 
weightCol = NULL) {
             if (is.character(family)) {
               family <- get(family, mode = "function", envir = parent.frame())
             }
@@ -132,10 +134,13 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", 
formula = "formula"),
             }
 
             formula <- paste(deparse(formula), collapse = "")
+            if (is.null(weightCol)) {
+              weightCol <- ""
+            }
 
             jobj <- 
callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
                                 "fit", formula, data@sdf, family$family, 
family$link,
-                                tol, as.integer(maxIter))
+                                tol, as.integer(maxIter), weightCol)
             return(new("GeneralizedLinearRegressionModel", jobj = jobj))
           })
 
@@ -151,6 +156,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", 
formula = "formula"),
 #'               
\url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
 #' @param epsilon Positive convergence tolerance of iterations.
 #' @param maxit Integer giving the maximal number of IRLS iterations.
+#' @param weightCol The weight column name. If this is not set or NULL, we 
treat all instance
+#'                  weights as 1.0.
 #' @return \code{glm} returns a fitted generalized linear model.
 #' @rdname glm
 #' @export
@@ -165,8 +172,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", 
formula = "formula"),
 #' @note glm since 1.5.0
 #' @seealso \link{spark.glm}
 setMethod("glm", signature(formula = "formula", family = "ANY", data = 
"SparkDataFrame"),
-          function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 
25) {
-            spark.glm(data, formula, family, tol = epsilon, maxIter = maxit)
+          function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 
25, weightCol = NULL) {
+            spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, 
weightCol = weightCol)
           })
 
 #  Returns the summary of a model produced by glm() or spark.glm(), similarly 
to R's summary().

http://git-wip-us.apache.org/repos/asf/spark/blob/d4a91224/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R 
b/R/pkg/inst/tests/testthat/test_mllib.R
index ab390a8..bc18224 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -118,6 +118,28 @@ test_that("spark.glm summary", {
   expect_equal(stats$df.residual, rStats$df.residual)
   expect_equal(stats$aic, rStats$aic)
 
+  # Test spark.glm works with weighted dataset
+  a1 <- c(0, 1, 2, 3)
+  a2 <- c(5, 2, 1, 3)
+  w <- c(1, 2, 3, 4)
+  b <- c(1, 0, 1, 0)
+  data <- as.data.frame(cbind(a1, a2, w, b))
+  df <- suppressWarnings(createDataFrame(data))
+
+  stats <- summary(spark.glm(df, b ~ a1 + a2, family = "binomial", weightCol = 
"w"))
+  rStats <- summary(glm(b ~ a1 + a2, family = "binomial", data = data, weights 
= w))
+
+  coefs <- unlist(stats$coefficients)
+  rCoefs <- unlist(rStats$coefficients)
+  expect_true(all(abs(rCoefs - coefs) < 1e-3))
+  expect_true(all(rownames(stats$coefficients) == c("(Intercept)", "a1", 
"a2")))
+  expect_equal(stats$dispersion, rStats$dispersion)
+  expect_equal(stats$null.deviance, rStats$null.deviance)
+  expect_equal(stats$deviance, rStats$deviance)
+  expect_equal(stats$df.null, rStats$df.null)
+  expect_equal(stats$df.residual, rStats$df.residual)
+  expect_equal(stats$aic, rStats$aic)
+
   # Test summary works on base GLM models
   baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
   baseSummary <- summary(baseModel)

http://git-wip-us.apache.org/repos/asf/spark/blob/d4a91224/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index 5642abc..0d3181d 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -68,7 +68,8 @@ private[r] object GeneralizedLinearRegressionWrapper
       family: String,
       link: String,
       tol: Double,
-      maxIter: Int): GeneralizedLinearRegressionWrapper = {
+      maxIter: Int,
+      weightCol: String): GeneralizedLinearRegressionWrapper = {
     val rFormula = new RFormula()
       .setFormula(formula)
     val rFormulaModel = rFormula.fit(data)
@@ -84,6 +85,7 @@ private[r] object GeneralizedLinearRegressionWrapper
       .setFitIntercept(rFormula.hasIntercept)
       .setTol(tol)
       .setMaxIter(maxIter)
+      .setWeightCol(weightCol)
     val pipeline = new Pipeline()
       .setStages(Array(rFormulaModel, glr))
       .fit(data)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to