Repository: spark
Updated Branches:
  refs/heads/master e4065376d -> 1f86e795b


[SPARK-19616][SPARKR] weightCol and aggregationDepth should be improved for some SparkR APIs

## What changes were proposed in this pull request?

This is a follow-up PR to #16800.

While working on SPARK-19456, we found that an empty string ("") should be treated as a NULL column name and should not be set on the underlying model. aggregationDepth should also be exposed as an expert parameter.

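For illustration, here is a minimal sketch of the intended behavior after this
change (the data frame and column names are hypothetical; the calls follow the
signatures in the diff below):

    library(SparkR)
    sparkR.session()

    # Hypothetical training data; "weight" is an ordinary numeric column.
    df <- createDataFrame(data.frame(label = c(0, 0, 0, 1, 1),
                                     feature = c(1.1, 2.3, 2.9, 4.2, 5.0),
                                     weight = c(2, 2, 2, 1, 1)))

    # weightCol = NULL and weightCol = "" are now equivalent: neither sets a
    # weight column on the underlying estimator.
    m1 <- spark.logit(df, label ~ feature, weightCol = NULL)
    m2 <- spark.logit(df, label ~ feature, weightCol = "")

    # A real column name still enables instance weighting, and the expert
    # aggregationDepth parameter (greater than or equal to 2) is now exposed.
    m3 <- spark.logit(df, label ~ feature, weightCol = "weight",
                      aggregationDepth = 2)
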
## How was this patch tested?
Existing tests.

Author: wm...@hotmail.com <wm...@hotmail.com>

Closes #16945 from wangmiao1981/svc.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1f86e795
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1f86e795
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1f86e795

Branch: refs/heads/master
Commit: 1f86e795b87ba93640062f29e87a032924d94b2a
Parents: e406537
Author: wm...@hotmail.com <wm...@hotmail.com>
Authored: Wed Feb 22 11:50:24 2017 -0800
Committer: Felix Cheung <felixche...@apache.org>
Committed: Wed Feb 22 11:50:24 2017 -0800

----------------------------------------------------------------------
 R/pkg/R/generics.R                              |  2 +-
 R/pkg/R/mllib_classification.R                  | 13 +++++++----
 R/pkg/R/mllib_regression.R                      | 24 +++++++++++++-------
 .../tests/testthat/test_mllib_classification.R  | 10 +++++++-
 .../ml/r/AFTSurvivalRegressionWrapper.scala     |  6 ++++-
 .../r/GeneralizedLinearRegressionWrapper.scala  |  4 +++-
 .../spark/ml/r/IsotonicRegressionWrapper.scala  |  3 ++-
 .../spark/ml/r/LogisticRegressionWrapper.scala  |  7 ++++--
 8 files changed, 50 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1f86e795/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 11940d3..647cbbd 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1406,7 +1406,7 @@ setGeneric("spark.randomForest",
 
 #' @rdname spark.survreg
 #' @export
-setGeneric("spark.survreg", function(data, formula) { 
standardGeneric("spark.survreg") })
+setGeneric("spark.survreg", function(data, formula, ...) { 
standardGeneric("spark.survreg") })
 
 #' @rdname spark.svmLinear
 #' @export

http://git-wip-us.apache.org/repos/asf/spark/blob/1f86e795/R/pkg/R/mllib_classification.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R
index fa0d795..05bb952 100644
--- a/R/pkg/R/mllib_classification.R
+++ b/R/pkg/R/mllib_classification.R
@@ -207,6 +207,9 @@ function(object, path, overwrite = FALSE) {
 #'                  excepting that at most one value may be 0. The class with largest value p/t is predicted, where p
 #'                  is the original probability of that class and t is the class's threshold.
 #' @param weightCol The weight column name.
+#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
+#'                         or the number of partitions are large, this param could be adjusted to a larger size.
+#'                         This is an expert parameter. Default value should be good for most cases.
 #' @param ... additional arguments passed to the method.
 #' @return \code{spark.logit} returns a fitted logistic regression model.
 #' @rdname spark.logit
@@ -245,11 +248,13 @@ function(object, path, overwrite = FALSE) {
 setMethod("spark.logit", signature(data = "SparkDataFrame", formula = 
"formula"),
           function(data, formula, regParam = 0.0, elasticNetParam = 0.0, 
maxIter = 100,
                    tol = 1E-6, family = "auto", standardization = TRUE,
-                   thresholds = 0.5, weightCol = NULL) {
+                   thresholds = 0.5, weightCol = NULL, aggregationDepth = 2) {
             formula <- paste(deparse(formula), collapse = "")
 
-            if (is.null(weightCol)) {
-              weightCol <- ""
+            if (!is.null(weightCol) && weightCol == "") {
+              weightCol <- NULL
+            } else if (!is.null(weightCol)) {
+              weightCol <- as.character(weightCol)
             }
 
             jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit",
@@ -257,7 +262,7 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
                                 as.numeric(elasticNetParam), as.integer(maxIter),
                                 as.numeric(tol), as.character(family),
                                 as.logical(standardization), as.array(thresholds),
-                                as.character(weightCol))
+                                weightCol, as.integer(aggregationDepth))
             new("LogisticRegressionModel", jobj = jobj)
           })
 

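The weightCol normalization above is repeated verbatim in spark.glm and
spark.isoreg below. Pulled out on its own, the pattern reads as follows
(a standalone sketch; normalizeWeightCol is a hypothetical name, not a helper
that exists in SparkR):

    # Sketch of the weightCol normalization applied in this patch;
    # normalizeWeightCol is a hypothetical name, not a SparkR function.
    normalizeWeightCol <- function(weightCol) {
      if (!is.null(weightCol) && weightCol == "") {
        NULL                       # "" now means "no weight column"
      } else if (!is.null(weightCol)) {
        as.character(weightCol)    # coerce anything else to a column name
      } else {
        NULL                       # NULL stays NULL
      }
    }

    normalizeWeightCol(NULL)      # NULL
    normalizeWeightCol("")        # NULL
    normalizeWeightCol("weight")  # "weight"
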
http://git-wip-us.apache.org/repos/asf/spark/blob/1f86e795/R/pkg/R/mllib_regression.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R
index 96ee220..ac0578c 100644
--- a/R/pkg/R/mllib_regression.R
+++ b/R/pkg/R/mllib_regression.R
@@ -102,14 +102,16 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
             }
 
             formula <- paste(deparse(formula), collapse = "")
-            if (is.null(weightCol)) {
-              weightCol <- ""
+            if (!is.null(weightCol) && weightCol == "") {
+              weightCol <- NULL
+            } else if (!is.null(weightCol)) {
+              weightCol <- as.character(weightCol)
             }
 
             # For known families, Gamma is upper-cased
             jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
                                 "fit", formula, data@sdf, tolower(family$family), family$link,
-                                tol, as.integer(maxIter), as.character(weightCol), regParam)
+                                tol, as.integer(maxIter), weightCol, regParam)
             new("GeneralizedLinearRegressionModel", jobj = jobj)
           })
 
@@ -305,13 +307,15 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"
           function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) {
             formula <- paste(deparse(formula), collapse = "")
 
-            if (is.null(weightCol)) {
-              weightCol <- ""
+            if (!is.null(weightCol) && weightCol == "") {
+              weightCol <- NULL
+            } else if (!is.null(weightCol)) {
+              weightCol <- as.character(weightCol)
             }
 
             jobj <- callJStatic("org.apache.spark.ml.r.IsotonicRegressionWrapper", "fit",
                                 data@sdf, formula, as.logical(isotonic), as.integer(featureIndex),
-                                as.character(weightCol))
+                                weightCol)
             new("IsotonicRegressionModel", jobj = jobj)
           })
 
@@ -372,6 +376,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
 #' @param formula a symbolic description of the model to be fitted. Currently only a few formula
 #'                operators are supported, including '~', ':', '+', and '-'.
 #'                Note that operator '.' is not supported currently.
+#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
+#'                         or the number of partitions are large, this param could be adjusted to a larger size.
+#'                         This is an expert parameter. Default value should be good for most cases.
+#' @param ... additional arguments passed to the method.
+#' @param ... additional arguments passed to the method.
 #' @return \code{spark.survreg} returns a fitted AFT survival regression model.
 #' @rdname spark.survreg
 #' @seealso survival: \url{https://cran.r-project.org/package=survival}
@@ -396,10 +404,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
 #' }
 #' @note spark.survreg since 2.0.0
 setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = 
"formula"),
-          function(data, formula) {
+          function(data, formula, aggregationDepth = 2) {
             formula <- paste(deparse(formula), collapse = "")
             jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
-                                "fit", formula, data@sdf)
+                                "fit", formula, data@sdf, as.integer(aggregationDepth))
             new("AFTSurvivalRegressionModel", jobj = jobj)
           })
 

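With the spark.survreg generic extended to accept "...", the new expert
parameter can be passed straight through. A usage sketch, mirroring the
existing spark.survreg example from the SparkR documentation:

    library(SparkR)
    library(survival)

    # The ovarian dataset ships with the survival package.
    df <- createDataFrame(ovarian)

    # aggregationDepth (greater than or equal to 2) tunes treeAggregate;
    # raising it can help when the feature dimension or the number of
    # partitions is large.
    model <- spark.survreg(df, Surv(futime, fustat) ~ ecog_ps + rx,
                           aggregationDepth = 4)
    summary(model)
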
http://git-wip-us.apache.org/repos/asf/spark/blob/1f86e795/R/pkg/inst/tests/testthat/test_mllib_classification.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R
index 620f528..459254d 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_classification.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R
@@ -211,7 +211,15 @@ test_that("spark.logit", {
   df <- createDataFrame(data)
   model <- spark.logit(df, label ~ feature)
   prediction <- collect(select(predict(model, df), "prediction"))
-  expect_equal(prediction$prediction, c("0.0", "0.0", "1.0", "1.0", "0.0"))
+  expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0"))
+
+  # Test prediction with weightCol
+  weight <- c(2.0, 2.0, 2.0, 1.0, 1.0)
+  data2 <- as.data.frame(cbind(label, feature, weight))
+  df2 <- createDataFrame(data2)
+  model2 <- spark.logit(df2, label ~ feature, weightCol = "weight")
+  prediction2 <- collect(select(predict(model2, df2), "prediction"))
+  expect_equal(sort(prediction2$prediction), c("0.0", "0.0", "0.0", "0.0", "0.0"))
 })
 
 test_that("spark.mlp", {

http://git-wip-us.apache.org/repos/asf/spark/blob/1f86e795/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
index bd965ac..0bf543d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
@@ -82,7 +82,10 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
   }
 
 
-  def fit(formula: String, data: DataFrame): AFTSurvivalRegressionWrapper = {
+  def fit(
+      formula: String,
+      data: DataFrame,
+      aggregationDepth: Int): AFTSurvivalRegressionWrapper = {
 
     val (rewritedFormula, censorCol) = formulaRewrite(formula)
 
@@ -100,6 +103,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
       .setCensorCol(censorCol)
       .setFitIntercept(rFormula.hasIntercept)
       .setFeaturesCol(rFormula.getFeaturesCol)
+      .setAggregationDepth(aggregationDepth)
 
     val pipeline = new Pipeline()
       .setStages(Array(rFormulaModel, aft))

http://git-wip-us.apache.org/repos/asf/spark/blob/1f86e795/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index 78f401f..cbd6cd1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -87,9 +87,11 @@ private[r] object GeneralizedLinearRegressionWrapper
       .setFitIntercept(rFormula.hasIntercept)
       .setTol(tol)
       .setMaxIter(maxIter)
-      .setWeightCol(weightCol)
       .setRegParam(regParam)
       .setFeaturesCol(rFormula.getFeaturesCol)
+
+    if (weightCol != null) glr.setWeightCol(weightCol)
+
     val pipeline = new Pipeline()
       .setStages(Array(rFormulaModel, glr))
       .fit(data)

http://git-wip-us.apache.org/repos/asf/spark/blob/1f86e795/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
index 4863231..d31ebb4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
@@ -74,9 +74,10 @@ private[r] object IsotonicRegressionWrapper
     val isotonicRegression = new IsotonicRegression()
       .setIsotonic(isotonic)
       .setFeatureIndex(featureIndex)
-      .setWeightCol(weightCol)
       .setFeaturesCol(rFormula.getFeaturesCol)
 
+    if (weightCol != null) isotonicRegression.setWeightCol(weightCol)
+
     val pipeline = new Pipeline()
       .setStages(Array(rFormulaModel, isotonicRegression))
       .fit(data)

http://git-wip-us.apache.org/repos/asf/spark/blob/1f86e795/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
index 645bc72..c96f99c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
@@ -96,7 +96,8 @@ private[r] object LogisticRegressionWrapper
       family: String,
       standardization: Boolean,
       thresholds: Array[Double],
-      weightCol: String
+      weightCol: String,
+      aggregationDepth: Int
       ): LogisticRegressionWrapper = {
 
     val rFormula = new RFormula()
@@ -119,10 +120,10 @@ private[r] object LogisticRegressionWrapper
       .setFitIntercept(fitIntercept)
       .setFamily(family)
       .setStandardization(standardization)
-      .setWeightCol(weightCol)
       .setFeaturesCol(rFormula.getFeaturesCol)
       .setLabelCol(rFormula.getLabelCol)
       .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
+      .setAggregationDepth(aggregationDepth)
 
     if (thresholds.length > 1) {
       lr.setThresholds(thresholds)
@@ -130,6 +131,8 @@ private[r] object LogisticRegressionWrapper
       lr.setThreshold(thresholds(0))
     }
 
+    if (weightCol != null) lr.setWeightCol(weightCol)
+
     val idxToStr = new IndexToString()
       .setInputCol(PREDICTED_LABEL_INDEX_COL)
       .setOutputCol(PREDICTED_LABEL_COL)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
