Repository: spark
Updated Branches:
  refs/heads/branch-2.1 ab865cfd9 -> 1c3f1da82


[SPARK-18326][SPARKR][ML] Review SparkR ML wrappers API for 2.1

## What changes were proposed in this pull request?
Review the SparkR ML wrapper APIs for the 2.1 release. This addresses two main issues:
* Remove `probabilityCol` from the argument lists of `spark.logit` and
`spark.randomForest`. The column is only used when making predictions, so it
should be an argument of `predict` instead; this will be addressed in
[SPARK-18618](https://issues.apache.org/jira/browse/SPARK-18618) in the next
release cycle.
* Rename the `reg` argument of `spark.als` to `regParam` to make it consistent with MLlib.

## How was this patch tested?
Existing tests.

Author: Yanbo Liang <yblia...@gmail.com>

Closes #16169 from yanboliang/spark-18326.

(cherry picked from commit 97255497d885f0f8ccfc808e868bc8aa5e4d1063)
Signed-off-by: Yanbo Liang <yblia...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1c3f1da8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1c3f1da8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1c3f1da8

Branch: refs/heads/branch-2.1
Commit: 1c3f1da82356426b6b550fee67e66dc82eaf1c85
Parents: ab865cf
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Wed Dec 7 20:23:28 2016 -0800
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Wed Dec 7 20:23:45 2016 -0800

----------------------------------------------------------------------
 R/pkg/R/mllib.R                                 | 23 +++++++++-----------
 R/pkg/inst/tests/testthat/test_mllib.R          |  4 ++--
 .../spark/ml/r/LogisticRegressionWrapper.scala  |  4 +---
 .../r/RandomForestClassificationWrapper.scala   |  2 --
 4 files changed, 13 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1c3f1da8/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 074e9cb..632e4ad 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -733,7 +733,6 @@ setMethod("predict", signature(object = "KMeansModel"),
 #'                  excepting that at most one value may be 0. The class with 
largest value p/t is predicted, where p
 #'                  is the original probability of that class and t is the 
class's threshold.
 #' @param weightCol The weight column name.
-#' @param probabilityCol column name for predicted class conditional 
probabilities.
 #' @param ... additional arguments passed to the method.
 #' @return \code{spark.logit} returns a fitted logistic regression model
 #' @rdname spark.logit
@@ -772,7 +771,7 @@ setMethod("predict", signature(object = "KMeansModel"),
 setMethod("spark.logit", signature(data = "SparkDataFrame", formula = 
"formula"),
           function(data, formula, regParam = 0.0, elasticNetParam = 0.0, 
maxIter = 100,
                    tol = 1E-6, family = "auto", standardization = TRUE,
-                   thresholds = 0.5, weightCol = NULL, probabilityCol = 
"probability") {
+                   thresholds = 0.5, weightCol = NULL) {
             formula <- paste(deparse(formula), collapse = "")
 
             if (is.null(weightCol)) {
@@ -784,7 +783,7 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", 
formula = "formula")
                                 as.numeric(elasticNetParam), 
as.integer(maxIter),
                                 as.numeric(tol), as.character(family),
                                 as.logical(standardization), 
as.array(thresholds),
-                                as.character(weightCol), 
as.character(probabilityCol))
+                                as.character(weightCol))
             new("LogisticRegressionModel", jobj = jobj)
           })
 
@@ -1425,7 +1424,7 @@ setMethod("predict", signature(object = 
"GaussianMixtureModel"),
 #' @param userCol column name for user ids. Ids must be (or can be coerced 
into) integers.
 #' @param itemCol column name for item ids. Ids must be (or can be coerced 
into) integers.
 #' @param rank rank of the matrix factorization (> 0).
-#' @param reg regularization parameter (>= 0).
+#' @param regParam regularization parameter (>= 0).
 #' @param maxIter maximum number of iterations (>= 0).
 #' @param nonnegative logical value indicating whether to apply nonnegativity 
constraints.
 #' @param implicitPrefs logical value indicating whether to use implicit 
preference.
@@ -1464,21 +1463,21 @@ setMethod("predict", signature(object = 
"GaussianMixtureModel"),
 #'
 #' # set other arguments
 #' modelS <- spark.als(df, "rating", "user", "item", rank = 20,
-#'                     reg = 0.1, nonnegative = TRUE)
+#'                     regParam = 0.1, nonnegative = TRUE)
 #' statsS <- summary(modelS)
 #' }
 #' @note spark.als since 2.1.0
 setMethod("spark.als", signature(data = "SparkDataFrame"),
           function(data, ratingCol = "rating", userCol = "user", itemCol = 
"item",
-                   rank = 10, reg = 0.1, maxIter = 10, nonnegative = FALSE,
+                   rank = 10, regParam = 0.1, maxIter = 10, nonnegative = 
FALSE,
                    implicitPrefs = FALSE, alpha = 1.0, numUserBlocks = 10, 
numItemBlocks = 10,
                    checkpointInterval = 10, seed = 0) {
 
             if (!is.numeric(rank) || rank <= 0) {
               stop("rank should be a positive number.")
             }
-            if (!is.numeric(reg) || reg < 0) {
-              stop("reg should be a nonnegative number.")
+            if (!is.numeric(regParam) || regParam < 0) {
+              stop("regParam should be a nonnegative number.")
             }
             if (!is.numeric(maxIter) || maxIter <= 0) {
               stop("maxIter should be a positive number.")
@@ -1486,7 +1485,7 @@ setMethod("spark.als", signature(data = "SparkDataFrame"),
 
             jobj <- callJStatic("org.apache.spark.ml.r.ALSWrapper",
                                 "fit", data@sdf, ratingCol, userCol, itemCol, 
as.integer(rank),
-                                reg, as.integer(maxIter), implicitPrefs, 
alpha, nonnegative,
+                                regParam, as.integer(maxIter), implicitPrefs, 
alpha, nonnegative,
                                 as.integer(numUserBlocks), 
as.integer(numItemBlocks),
                                 as.integer(checkpointInterval), 
as.integer(seed))
             new("ALSModel", jobj = jobj)
@@ -1684,8 +1683,6 @@ print.summary.KSTest <- function(x, ...) {
 #'                     nodes. If TRUE, the algorithm will cache node IDs for 
each instance. Caching
 #'                     can speed up training of deeper trees. Users can set 
how often should the
 #'                     cache be checkpointed or disable it by setting 
checkpointInterval.
-#' @param probabilityCol column name for predicted class conditional 
probabilities, only for
-#'                       classification.
 #' @param ... additional arguments passed to the method.
 #' @aliases spark.randomForest,SparkDataFrame,formula-method
 #' @return \code{spark.randomForest} returns a fitted Random Forest model.
@@ -1720,7 +1717,7 @@ setMethod("spark.randomForest", signature(data = 
"SparkDataFrame", formula = "fo
                    maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL,
                    featureSubsetStrategy = "auto", seed = NULL, 
subsamplingRate = 1.0,
                    minInstancesPerNode = 1, minInfoGain = 0.0, 
checkpointInterval = 10,
-                   maxMemoryInMB = 256, cacheNodeIds = FALSE, probabilityCol = 
"probability") {
+                   maxMemoryInMB = 256, cacheNodeIds = FALSE) {
             type <- match.arg(type)
             formula <- paste(deparse(formula), collapse = "")
             if (!is.null(seed)) {
@@ -1749,7 +1746,7 @@ setMethod("spark.randomForest", signature(data = 
"SparkDataFrame", formula = "fo
                                          impurity, 
as.integer(minInstancesPerNode),
                                          as.numeric(minInfoGain), 
as.integer(checkpointInterval),
                                          as.character(featureSubsetStrategy), 
seed,
-                                         as.numeric(subsamplingRate), 
as.character(probabilityCol),
+                                         as.numeric(subsamplingRate),
                                          as.integer(maxMemoryInMB), 
as.logical(cacheNodeIds))
                      new("RandomForestClassificationModel", jobj = jobj)
                    }

http://git-wip-us.apache.org/repos/asf/spark/blob/1c3f1da8/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R 
b/R/pkg/inst/tests/testthat/test_mllib.R
index 9f810be..db1e4dc 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -926,10 +926,10 @@ test_that("spark.posterior and spark.perplexity", {
 
 test_that("spark.als", {
   data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 
4.0),
-  list(2, 1, 1.0), list(2, 2, 5.0))
+               list(2, 1, 1.0), list(2, 2, 5.0))
   df <- createDataFrame(data, c("user", "item", "score"))
   model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = 
"item",
-  rank = 10, maxIter = 5, seed = 0, reg = 0.1)
+                     rank = 10, maxIter = 5, seed = 0, regParam = 0.1)
   stats <- summary(model)
   expect_equal(stats$rank, 10)
   test <- createDataFrame(list(list(0, 2), list(1, 0), list(2, 0)), c("user", 
"item"))

http://git-wip-us.apache.org/repos/asf/spark/blob/1c3f1da8/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
index 7f0f3ce..645bc72 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
@@ -96,8 +96,7 @@ private[r] object LogisticRegressionWrapper
       family: String,
       standardization: Boolean,
       thresholds: Array[Double],
-      weightCol: String,
-      probabilityCol: String
+      weightCol: String
       ): LogisticRegressionWrapper = {
 
     val rFormula = new RFormula()
@@ -123,7 +122,6 @@ private[r] object LogisticRegressionWrapper
       .setWeightCol(weightCol)
       .setFeaturesCol(rFormula.getFeaturesCol)
       .setLabelCol(rFormula.getLabelCol)
-      .setProbabilityCol(probabilityCol)
       .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
 
     if (thresholds.length > 1) {

http://git-wip-us.apache.org/repos/asf/spark/blob/1c3f1da8/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
index 0b860e5..366f375 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
@@ -76,7 +76,6 @@ private[r] object RandomForestClassifierWrapper extends 
MLReadable[RandomForestC
       featureSubsetStrategy: String,
       seed: String,
       subsamplingRate: Double,
-      probabilityCol: String,
       maxMemoryInMB: Int,
       cacheNodeIds: Boolean): RandomForestClassifierWrapper = {
 
@@ -102,7 +101,6 @@ private[r] object RandomForestClassifierWrapper extends 
MLReadable[RandomForestC
       .setSubsamplingRate(subsamplingRate)
       .setMaxMemoryInMB(maxMemoryInMB)
       .setCacheNodeIds(cacheNodeIds)
-      .setProbabilityCol(probabilityCol)
       .setFeaturesCol(rFormula.getFeaturesCol)
       .setLabelCol(rFormula.getLabelCol)
       .setPredictionCol(PREDICTED_LABEL_INDEX_COL)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to