Repository: spark
Updated Branches:
  refs/heads/master 2881a2d1d -> b6879b8b3


[SPARK-16137][SPARKR] randomForest for R

## What changes were proposed in this pull request?

Random Forest regression and classification for R (spark.randomForest).
Clean up and reorder generics.R.

## How was this patch tested?

Manual tests and unit tests.

Author: Felix Cheung <felixcheun...@hotmail.com>

Closes #15607 from felixcheung/rrandomforest.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b6879b8b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b6879b8b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b6879b8b

Branch: refs/heads/master
Commit: b6879b8b3518c71c23262554fcb0fdad60287011
Parents: 2881a2d
Author: Felix Cheung <felixcheun...@hotmail.com>
Authored: Sun Oct 30 16:19:19 2016 -0700
Committer: Felix Cheung <felixche...@apache.org>
Committed: Sun Oct 30 16:19:19 2016 -0700

----------------------------------------------------------------------
 R/pkg/NAMESPACE                                 |   9 +-
 R/pkg/R/generics.R                              |  66 ++---
 R/pkg/R/mllib.R                                 | 252 ++++++++++++++++++-
 R/pkg/inst/tests/testthat/test_mllib.R          |  68 +++++
 .../scala/org/apache/spark/ml/r/RWrappers.scala |   4 +
 .../r/RandomForestClassificationWrapper.scala   | 147 +++++++++++
 .../ml/r/RandomForestRegressionWrapper.scala    | 144 +++++++++++
 7 files changed, 656 insertions(+), 34 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b6879b8b/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 7a89c01..9cd6269 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -44,7 +44,8 @@ exportMethods("glm",
               "spark.gaussianMixture",
               "spark.als",
               "spark.kstest",
-              "spark.logit")
+              "spark.logit",
+              "spark.randomForest")
 
 # Job group lifecycle management methods
 export("setJobGroup",
@@ -350,7 +351,9 @@ export("as.DataFrame",
        "uncacheTable",
        "print.summary.GeneralizedLinearRegressionModel",
        "read.ml",
-       "print.summary.KSTest")
+       "print.summary.KSTest",
+       "print.summary.RandomForestRegressionModel",
+       "print.summary.RandomForestClassificationModel")
 
 export("structField",
        "structField.jobj",
@@ -375,6 +378,8 @@ S3method(print, structField)
 S3method(print, structType)
 S3method(print, summary.GeneralizedLinearRegressionModel)
 S3method(print, summary.KSTest)
+S3method(print, summary.RandomForestRegressionModel)
+S3method(print, summary.RandomForestClassificationModel)
 S3method(structField, character)
 S3method(structField, jobj)
 S3method(structType, jobj)

http://git-wip-us.apache.org/repos/asf/spark/blob/b6879b8b/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 107e1c6..0271b26 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1310,9 +1310,11 @@ setGeneric("window", function(x, ...) { 
standardGeneric("window") })
 #' @export
 setGeneric("year", function(x) { standardGeneric("year") })
 
-#' @rdname spark.glm
+###################### Spark.ML Methods ##########################
+
+#' @rdname fitted
 #' @export
-setGeneric("spark.glm", function(data, formula, ...) { 
standardGeneric("spark.glm") })
+setGeneric("fitted")
 
 #' @param x,y For \code{glm}: logical values indicating whether the response 
vector
 #'          and model matrix used in the fitting process should be returned as
@@ -1332,13 +1334,38 @@ setGeneric("predict", function(object, ...) { 
standardGeneric("predict") })
 #' @export
 setGeneric("rbind", signature = "...")
 
+#' @rdname spark.als
+#' @export
+setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") })
+
+#' @rdname spark.gaussianMixture
+#' @export
+setGeneric("spark.gaussianMixture",
+           function(data, formula, ...) { 
standardGeneric("spark.gaussianMixture") })
+
+#' @rdname spark.glm
+#' @export
+setGeneric("spark.glm", function(data, formula, ...) { 
standardGeneric("spark.glm") })
+
+#' @rdname spark.isoreg
+#' @export
+setGeneric("spark.isoreg", function(data, formula, ...) { 
standardGeneric("spark.isoreg") })
+
 #' @rdname spark.kmeans
 #' @export
 setGeneric("spark.kmeans", function(data, formula, ...) { 
standardGeneric("spark.kmeans") })
 
-#' @rdname fitted
+#' @rdname spark.kstest
 #' @export
-setGeneric("fitted")
+setGeneric("spark.kstest", function(data, ...) { 
standardGeneric("spark.kstest") })
+
+#' @rdname spark.lda
+#' @export
+setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") })
+
+#' @rdname spark.logit
+#' @export
+setGeneric("spark.logit", function(data, formula, ...) { 
standardGeneric("spark.logit") })
 
 #' @rdname spark.mlp
 #' @export
@@ -1348,13 +1375,14 @@ setGeneric("spark.mlp", function(data, ...) { 
standardGeneric("spark.mlp") })
 #' @export
 setGeneric("spark.naiveBayes", function(data, formula, ...) { 
standardGeneric("spark.naiveBayes") })
 
-#' @rdname spark.survreg
+#' @rdname spark.randomForest
 #' @export
-setGeneric("spark.survreg", function(data, formula) { 
standardGeneric("spark.survreg") })
+setGeneric("spark.randomForest",
+           function(data, formula, ...) { 
standardGeneric("spark.randomForest") })
 
-#' @rdname spark.lda
+#' @rdname spark.survreg
 #' @export
-setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") })
+setGeneric("spark.survreg", function(data, formula) { 
standardGeneric("spark.survreg") })
 
 #' @rdname spark.lda
 #' @export
@@ -1364,20 +1392,6 @@ setGeneric("spark.posterior", function(object, newData) 
{ standardGeneric("spark
 #' @export
 setGeneric("spark.perplexity", function(object, data) { 
standardGeneric("spark.perplexity") })
 
-#' @rdname spark.isoreg
-#' @export
-setGeneric("spark.isoreg", function(data, formula, ...) { 
standardGeneric("spark.isoreg") })
-
-#' @rdname spark.gaussianMixture
-#' @export
-setGeneric("spark.gaussianMixture",
-           function(data, formula, ...) {
-             standardGeneric("spark.gaussianMixture")
-           })
-
-#' @rdname spark.logit
-#' @export
-setGeneric("spark.logit", function(data, formula, ...) { 
standardGeneric("spark.logit") })
 
 #' @param object a fitted ML model object.
 #' @param path the directory where the model is saved.
@@ -1385,11 +1399,3 @@ setGeneric("spark.logit", function(data, formula, ...) { 
standardGeneric("spark.
 #' @rdname write.ml
 #' @export
 setGeneric("write.ml", function(object, path, ...) { 
standardGeneric("write.ml") })
-
-#' @rdname spark.als
-#' @export
-setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") })
-
-#' @rdname spark.kstest
-#' @export
-setGeneric("spark.kstest", function(data, ...) { 
standardGeneric("spark.kstest") })

http://git-wip-us.apache.org/repos/asf/spark/blob/b6879b8b/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 629f284..7a220b8 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -102,6 +102,20 @@ setClass("KSTest", representation(jobj = "jobj"))
 #' @note LogisticRegressionModel since 2.1.0
 setClass("LogisticRegressionModel", representation(jobj = "jobj"))
 
+#' S4 class that represents a RandomForestRegressionModel
+#'
+#' @param jobj a Java object reference to the backing Scala RandomForestRegressorWrapper,
+#'             which holds the fitted Random Forest regression pipeline
+#' @export
+#' @note RandomForestRegressionModel since 2.1.0
+setClass("RandomForestRegressionModel", representation(jobj = "jobj"))
+
+#' S4 class that represents a RandomForestClassificationModel
+#'
+#' @param jobj a Java object reference to the backing Scala RandomForestClassifierWrapper,
+#'             which holds the fitted Random Forest classification pipeline
+#' @export
+#' @note RandomForestClassificationModel since 2.1.0
+setClass("RandomForestClassificationModel", representation(jobj = "jobj"))
+
 #' Saves the MLlib model to the input path
 #'
 #' Saves the MLlib model to the input path. For more information, see the 
specific
@@ -112,7 +126,7 @@ setClass("LogisticRegressionModel", representation(jobj = 
"jobj"))
 #' @seealso \link{spark.glm}, \link{glm},
 #' @seealso \link{spark.als}, \link{spark.gaussianMixture}, 
\link{spark.isoreg}, \link{spark.kmeans},
 #' @seealso \link{spark.lda}, \link{spark.logit}, \link{spark.mlp}, 
\link{spark.naiveBayes},
-#' @seealso \link{spark.survreg}
+#' @seealso \link{spark.randomForest}, \link{spark.survreg},
 #' @seealso \link{read.ml}
 NULL
 
@@ -125,7 +139,8 @@ NULL
 #' @export
 #' @seealso \link{spark.glm}, \link{glm},
 #' @seealso \link{spark.als}, \link{spark.gaussianMixture}, 
\link{spark.isoreg}, \link{spark.kmeans},
-#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, 
\link{spark.survreg}
+#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes},
+#' @seealso \link{spark.randomForest}, \link{spark.survreg}
 NULL
 
 write_internal <- function(object, path, overwrite = FALSE) {
@@ -1122,6 +1137,10 @@ read.ml <- function(path) {
     new("ALSModel", jobj = jobj)
   } else if (isInstanceOf(jobj, 
"org.apache.spark.ml.r.LogisticRegressionWrapper")) {
     new("LogisticRegressionModel", jobj = jobj)
+  } else if (isInstanceOf(jobj, 
"org.apache.spark.ml.r.RandomForestRegressorWrapper")) {
+    new("RandomForestRegressionModel", jobj = jobj)
+  } else if (isInstanceOf(jobj, 
"org.apache.spark.ml.r.RandomForestClassifierWrapper")) {
+    new("RandomForestClassificationModel", jobj = jobj)
   } else {
     stop("Unsupported model: ", jobj)
   }
@@ -1617,3 +1636,232 @@ print.summary.KSTest <- function(x, ...) {
   cat(summaryStr, "\n")
   invisible(x)
 }
+
+#' Random Forest Model for Regression and Classification
+#'
+#' \code{spark.randomForest} fits a Random Forest Regression model or Classification model on
+#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Random Forest
+#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
+#' save/load fitted models.
+#' For more details, see
+#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html}{Random Forest}
+#'
+#' @param data a SparkDataFrame for training.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
+#'                operators are supported, including '~', ':', '+', and '-'.
+#' @param type type of model to fit, one of "regression" or "classification".
+#'             (default = "regression")
+#' @param maxDepth Maximum depth of the tree (>= 0). (default = 5)
+#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing
+#'                how to split on features at each node. More bins give higher granularity. Must be
+#'                >= 2 and >= number of categories in any categorical feature. (default = 32)
+#' @param numTrees Number of trees to train (>= 1). (default = 20)
+#' @param impurity Criterion used for information gain calculation.
+#'                 For regression, must be "variance" (used when \code{impurity} is NULL).
+#'                 For classification, must be one of "entropy" and "gini"
+#'                 ("gini" is used when \code{impurity} is NULL). (default = NULL)
+#' @param minInstancesPerNode Minimum number of instances each child must have after split.
+#'                            (default = 1)
+#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
+#'                    (default = 0.0)
+#' @param checkpointInterval Checkpoint interval (>= 1), or -1 to disable checkpointing.
+#'                           (default = 10)
+#' @param featureSubsetStrategy The number of features to consider for splits at each tree node.
+#'        Supported options: "auto", "all", "onethird", "sqrt", "log2", (0.0-1.0], [1-n].
+#'        (default = "auto")
+#' @param seed integer seed for random number generation. If NULL (the default), no seed is
+#'             passed on and Spark's default seed is used.
+#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in
+#'                        range (0, 1]. (default = 1.0)
+#' @param probabilityCol column name for predicted class conditional probabilities, only used for
+#'                       classification. (default = "probability")
+#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. (default = 256)
+#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with
+#'                     nodes. (default = FALSE)
+#' @param ... additional arguments passed to the method.
+#' @aliases spark.randomForest,SparkDataFrame,formula-method
+#' @return \code{spark.randomForest} returns a fitted Random Forest model.
+#' @rdname spark.randomForest
+#' @name spark.randomForest
+#' @export
+#' @examples
+#' \dontrun{
+#' # fit a Random Forest Regression Model
+#' df <- createDataFrame(longley)
+#' model <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16)
+#'
+#' # get the summary of the model
+#' summary(model)
+#'
+#' # make predictions
+#' predictions <- predict(model, df)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#'
+#' # fit a Random Forest Classification Model
+#' df <- createDataFrame(iris)
+#' model <- spark.randomForest(df, Species ~ Petal_Length + Petal_Width, "classification")
+#' }
+#' @note spark.randomForest since 2.1.0
+setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"),
+          function(data, formula, type = c("regression", "classification"),
+                   maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL,
+                   minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
+                   featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0,
+                   probabilityCol = "probability", maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+            type <- match.arg(type)
+            # deparse() may split a long formula into several strings; join them into one
+            formula <- paste(deparse(formula), collapse = "")
+            # The seed travels as a string so that NULL ("not set") can be told apart from a
+            # real value; the Scala wrapper only sets a seed for a non-empty string.
+            if (!is.null(seed)) {
+              seed <- as.character(as.integer(seed))
+            }
+            switch(type,
+                   regression = {
+                     if (is.null(impurity)) impurity <- "variance"
+                     impurity <- match.arg(impurity, "variance")
+                     jobj <- callJStatic("org.apache.spark.ml.r.RandomForestRegressorWrapper",
+                                         "fit", data@sdf, formula, as.integer(maxDepth),
+                                         as.integer(maxBins), as.integer(numTrees),
+                                         impurity, as.integer(minInstancesPerNode),
+                                         as.numeric(minInfoGain), as.integer(checkpointInterval),
+                                         as.character(featureSubsetStrategy), seed,
+                                         as.numeric(subsamplingRate),
+                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                     new("RandomForestRegressionModel", jobj = jobj)
+                   },
+                   classification = {
+                     if (is.null(impurity)) impurity <- "gini"
+                     impurity <- match.arg(impurity, c("gini", "entropy"))
+                     # unlike regression, the classifier also receives probabilityCol
+                     jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper",
+                                         "fit", data@sdf, formula, as.integer(maxDepth),
+                                         as.integer(maxBins), as.integer(numTrees),
+                                         impurity, as.integer(minInstancesPerNode),
+                                         as.numeric(minInfoGain), as.integer(checkpointInterval),
+                                         as.character(featureSubsetStrategy), seed,
+                                         as.numeric(subsamplingRate), as.character(probabilityCol),
+                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                     new("RandomForestClassificationModel", jobj = jobj)
+                   }
+            )
+          })
+
+# Prediction for Random Forest models: both classes delegate to the shared predict_internal.
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
+#'         "prediction".
+#' @rdname spark.randomForest
+#' @aliases predict,RandomForestRegressionModel-method
+#' @export
+#' @note predict(randomForestRegressionModel) since 2.1.0
+setMethod("predict", signature(object = "RandomForestRegressionModel"),
+          function(object, newData) predict_internal(object, newData))
+
+#' @rdname spark.randomForest
+#' @aliases predict,RandomForestClassificationModel-method
+#' @export
+#' @note predict(randomForestClassificationModel) since 2.1.0
+setMethod("predict", signature(object = "RandomForestClassificationModel"),
+          function(object, newData) predict_internal(object, newData))
+
+# Persistence for Random Forest models: both classes delegate to the shared write_internal.
+
+#' @param object a fitted Random Forest regression or classification model.
+#' @param path the directory to save the model to.
+#' @param overwrite whether to replace an existing output path. Defaults to FALSE, in which case
+#'                  an exception is thrown if the output path already exists.
+#'
+#' @aliases write.ml,RandomForestRegressionModel,character-method
+#' @rdname spark.randomForest
+#' @export
+#' @note write.ml(RandomForestRegressionModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = "character"),
+          function(object, path, overwrite = FALSE)
+            write_internal(object = object, path = path, overwrite = overwrite))
+
+#' @aliases write.ml,RandomForestClassificationModel,character-method
+#' @rdname spark.randomForest
+#' @export
+#' @note write.ml(RandomForestClassificationModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "RandomForestClassificationModel", path = "character"),
+          function(object, path, overwrite = FALSE)
+            write_internal(object = object, path = path, overwrite = overwrite))
+
+#  Collects the summary fields shared by RandomForestRegressionModel and
+#  RandomForestClassificationModel; both summary() methods call this helper.
+#  Returns a plain list: formula, numFeatures, features, featureImportances (as a string),
+#  numTrees, treeWeights, and the backing jobj (kept so printing can query the model).
+summary.randomForest <- function(model) {
+  jobj <- model@jobj
+  formula <- callJMethod(jobj, "formula")
+  numFeatures <- callJMethod(jobj, "numFeatures")
+  features <- callJMethod(jobj, "features")
+  # featureImportances is a JVM Vector; bring back its string form
+  featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString")
+  numTrees <- callJMethod(jobj, "numTrees")
+  treeWeights <- callJMethod(jobj, "treeWeights")
+  list(formula = formula,
+       numFeatures = numFeatures,
+       features = features,
+       featureImportances = featureImportances,
+       numTrees = numTrees,
+       treeWeights = treeWeights,
+       jobj = jobj)
+}
+
+#' @return \code{summary} returns summary information of the fitted model, which is a list.
+#'         The list of components includes \code{formula} (formula),
+#'         \code{numFeatures} (number of features), \code{features} (list of features),
+#'         \code{featureImportances} (feature importances, as a string),
+#'         \code{numTrees} (number of trees), and \code{treeWeights} (tree weights).
+#' @rdname spark.randomForest
+#' @aliases summary,RandomForestRegressionModel-method
+#' @export
+#' @note summary(RandomForestRegressionModel) since 2.1.0
+setMethod("summary", signature(object = "RandomForestRegressionModel"),
+          function(object) {
+            ans <- summary.randomForest(object)
+            # the S3 class routes print() to print.summary.RandomForestRegressionModel
+            class(ans) <- "summary.RandomForestRegressionModel"
+            ans
+          })
+
+#  Get the summary of a RandomForestClassificationModel
+
+#' @rdname spark.randomForest
+#' @aliases summary,RandomForestClassificationModel-method
+#' @export
+#' @note summary(RandomForestClassificationModel) since 2.1.0
+setMethod("summary", signature(object = "RandomForestClassificationModel"),
+          function(object) {
+            ans <- summary.randomForest(object)
+            # the S3 class routes print() to print.summary.RandomForestClassificationModel
+            class(ans) <- "summary.RandomForestClassificationModel"
+            ans
+          })
+
+#  Prints the summary of a Random Forest Regression or Classification model: the collected
+#  summary fields, followed by the model description from the Scala wrapper's summary method.
+#  Returns x invisibly.
+print.summary.randomForest <- function(x) {
+  jobj <- x$jobj
+  cat("Formula: ", x$formula)
+  cat("\nNumber of features: ", x$numFeatures)
+  cat("\nFeatures: ", unlist(x$features))
+  cat("\nFeature importances: ", x$featureImportances)
+  cat("\nNumber of trees: ", x$numTrees)
+  cat("\nTree weights: ", unlist(x$treeWeights))
+
+  # per-tree description text produced on the JVM side
+  summaryStr <- callJMethod(jobj, "summary")
+  cat("\n", summaryStr, "\n")
+  invisible(x)
+}
+
+#' @param x summary object of a Random Forest regression or classification model, as returned
+#'          by \code{summary}.
+#' @rdname spark.randomForest
+#' @export
+#' @note print.summary.RandomForestRegressionModel since 2.1.0
+print.summary.RandomForestRegressionModel <- function(x, ...) print.summary.randomForest(x)
+
+#  S3 print method for the summary of a Random Forest Classification Model
+
+#' @rdname spark.randomForest
+#' @export
+#' @note print.summary.RandomForestClassificationModel since 2.1.0
+print.summary.RandomForestClassificationModel <- function(x, ...) print.summary.randomForest(x)

http://git-wip-us.apache.org/repos/asf/spark/blob/b6879b8b/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R 
b/R/pkg/inst/tests/testthat/test_mllib.R
index 6d1fccc..db98d0e 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -871,4 +871,72 @@ test_that("spark.kstest", {
   expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:")
 })
 
+test_that("spark.randomForest Regression", {
+  data <- suppressWarnings(createDataFrame(longley))
+  model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
+                              numTrees = 1)
+
+  predictions <- collect(predict(model, data))
+  expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
+                                         63.221, 63.639, 64.989, 63.761,
+                                         66.019, 67.857, 68.169, 66.513,
+                                         68.655, 69.564, 69.331, 70.551),
+               tolerance = 1e-4)
+
+  stats <- summary(model)
+  expect_equal(stats$numTrees, 1)
+  expect_error(capture.output(stats), NA)
+  expect_true(length(capture.output(stats)) > 6)
+
+  model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
+                              numTrees = 20, seed = 123)
+  predictions <- collect(predict(model, data))
+  expect_equal(predictions$prediction, c(60.379, 61.096, 60.636, 62.258,
+                                         63.736, 64.296, 64.868, 64.300,
+                                         66.709, 67.697, 67.966, 67.252,
+                                         68.866, 69.593, 69.195, 69.658),
+               tolerance = 1e-4)
+  stats <- summary(model)
+  expect_equal(stats$numTrees, 20)
+
+  # save/load round trip; writing to an existing path without overwrite must fail
+  modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+  expect_equal(stats$formula, stats2$formula)
+  expect_equal(stats$numFeatures, stats2$numFeatures)
+  expect_equal(stats$features, stats2$features)
+  expect_equal(stats$featureImportances, stats2$featureImportances)
+  expect_equal(stats$numTrees, stats2$numTrees)
+  expect_equal(stats$treeWeights, stats2$treeWeights)
+
+  # write.ml creates a directory, which unlink() only removes with recursive = TRUE
+  unlink(modelPath, recursive = TRUE)
+})
+
+test_that("spark.randomForest Classification", {
+  data <- suppressWarnings(createDataFrame(iris))
+  model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification",
+                              maxDepth = 5, maxBins = 16)
+
+  stats <- summary(model)
+  expect_equal(stats$numFeatures, 2)
+  expect_equal(stats$numTrees, 20)
+  expect_error(capture.output(stats), NA)
+  expect_true(length(capture.output(stats)) > 6)
+
+  # save/load round trip; writing to an existing path without overwrite must fail
+  modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+  # compare the fields that a Random Forest summary() actually returns
+  expect_equal(stats$formula, stats2$formula)
+  expect_equal(stats$numFeatures, stats2$numFeatures)
+  expect_equal(stats$features, stats2$features)
+  expect_equal(stats$featureImportances, stats2$featureImportances)
+  expect_equal(stats$numTrees, stats2$numTrees)
+  expect_equal(stats$treeWeights, stats2$treeWeights)
+
+  # write.ml creates a directory, which unlink() only removes with recursive = TRUE
+  unlink(modelPath, recursive = TRUE)
+})
+
 sparkR.session.stop()

http://git-wip-us.apache.org/repos/asf/spark/blob/b6879b8b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index 1df3662..0e09e18 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -56,6 +56,10 @@ private[r] object RWrappers extends MLReader[Object] {
         ALSWrapper.load(path)
       case "org.apache.spark.ml.r.LogisticRegressionWrapper" =>
         LogisticRegressionWrapper.load(path)
+      case "org.apache.spark.ml.r.RandomForestRegressorWrapper" =>
+        RandomForestRegressorWrapper.load(path)
+      case "org.apache.spark.ml.r.RandomForestClassifierWrapper" =>
+        RandomForestClassifierWrapper.load(path)
       case _ =>
         throw new SparkException(s"SparkR read.ml does not support load 
$className")
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/b6879b8b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
new file mode 100644
index 0000000..b0088dd
--- /dev/null
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.classification.{RandomForestClassificationModel, 
RandomForestClassifier}
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+/**
+ * R-facing wrapper around a fitted two-stage pipeline (an RFormula model followed by a
+ * RandomForestClassificationModel). Exposes the statistics read by SparkR's summary().
+ *
+ * @param pipeline fitted pipeline whose stage 1 is the random forest classification model
+ * @param formula the R formula string the model was fitted with
+ * @param features names of the feature columns produced by the RFormula stage
+ */
+private[r] class RandomForestClassifierWrapper private (
+  val pipeline: PipelineModel,
+  val formula: String,
+  val features: Array[String]) extends MLWritable {
+
+  // Stage 0 is the RFormula model; stage 1 is the fitted random forest.
+  private val rfcModel: RandomForestClassificationModel =
+    pipeline.stages(1).asInstanceOf[RandomForestClassificationModel]
+
+  lazy val numFeatures: Int = rfcModel.numFeatures
+  lazy val featureImportances: Vector = rfcModel.featureImportances
+  lazy val numTrees: Int = rfcModel.getNumTrees
+  lazy val treeWeights: Array[Double] = rfcModel.treeWeights
+
+  /** Debug-string description of the forest (`toDebugString`), printed by SparkR's summary. */
+  def summary: String = rfcModel.toDebugString
+
+  /** Runs the pipeline and drops the intermediate features column before returning to R. */
+  def transform(dataset: Dataset[_]): DataFrame = {
+    pipeline.transform(dataset).drop(rfcModel.getFeaturesCol)
+  }
+
+  override def write: MLWriter =
+    new RandomForestClassifierWrapper.RandomForestClassifierWrapperWriter(this)
+}
+
+/**
+ * Companion of RandomForestClassifierWrapper: the `fit` entry point invoked from SparkR, plus
+ * MLReadable support so that saved wrappers can be loaded back.
+ */
+private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestClassifierWrapper] {
+  /**
+   * Fits an RFormula -> RandomForestClassifier pipeline on `data`. Invoked from SparkR's
+   * spark.randomForest(type = "classification") through callJStatic.
+   *
+   * @param data training data
+   * @param formula R model formula, already deparsed to a single string on the R side
+   * @param seed random seed encoded as a string; null or "" leaves the classifier's seed unset
+   * @return wrapper holding the fitted pipeline, the formula and the feature names
+   *
+   * The remaining parameters are forwarded unchanged to the RandomForestClassifier setters of
+   * the same name.
+   */
+  def fit(  // scalastyle:ignore
+      data: DataFrame,
+      formula: String,
+      maxDepth: Int,
+      maxBins: Int,
+      numTrees: Int,
+      impurity: String,
+      minInstancesPerNode: Int,
+      minInfoGain: Double,
+      checkpointInterval: Int,
+      featureSubsetStrategy: String,
+      seed: String,
+      subsamplingRate: Double,
+      probabilityCol: String,
+      maxMemoryInMB: Int,
+      cacheNodeIds: Boolean): RandomForestClassifierWrapper = {
+
+    val rFormula = new RFormula()
+      .setFormula(formula)
+    RWrapperUtils.checkDataColumns(rFormula, data)
+    val rFormulaModel = rFormula.fit(data)
+
+    // get feature names from output schema
+    // NOTE(review): the two `.get` calls fail with NoSuchElementException if the features
+    // column lacks per-attribute metadata or an attribute is unnamed; confirm RFormula always
+    // provides both.
+    val schema = rFormulaModel.transform(data).schema
+    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
+      .attributes.get
+    val features = featureAttrs.map(_.name.get)
+
+    // assemble and fit the pipeline
+    val rfc = new RandomForestClassifier()
+      .setMaxDepth(maxDepth)
+      .setMaxBins(maxBins)
+      .setNumTrees(numTrees)
+      .setImpurity(impurity)
+      .setMinInstancesPerNode(minInstancesPerNode)
+      .setMinInfoGain(minInfoGain)
+      .setCheckpointInterval(checkpointInterval)
+      .setFeatureSubsetStrategy(featureSubsetStrategy)
+      .setSubsamplingRate(subsamplingRate)
+      .setMaxMemoryInMB(maxMemoryInMB)
+      .setCacheNodeIds(cacheNodeIds)
+      .setProbabilityCol(probabilityCol)
+      .setFeaturesCol(rFormula.getFeaturesCol)
+    // R sends the seed as a string so that "not set" (NULL on the R side) is distinguishable
+    if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong)
+
+    // Stage order matters: RandomForestClassifierWrapper reads the forest from stages(1).
+    val pipeline = new Pipeline()
+      .setStages(Array(rFormulaModel, rfc))
+      .fit(data)
+
+    new RandomForestClassifierWrapper(pipeline, formula, features)
+  }
+
+  override def read: MLReader[RandomForestClassifierWrapper] =
+    new RandomForestClassifierWrapperReader
+
+  override def load(path: String): RandomForestClassifierWrapper = super.load(path)
+
+  /**
+   * Saves a wrapper as two parts under `path`: "rMetadata" (JSON text with the wrapper class
+   * name, the formula and the feature names) and "pipeline" (the fitted PipelineModel).
+   */
+  class RandomForestClassifierWrapperWriter(instance: RandomForestClassifierWrapper)
+    extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("formula" -> instance.formula) ~
+        ("features" -> instance.features.toSeq)
+      val rMetadataJson: String = compact(render(rMetadata))
+
+      // a single partition so the metadata is written as one JSON line
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  /** Reads back the layout written by RandomForestClassifierWrapperWriter. */
+  class RandomForestClassifierWrapperReader extends MLReader[RandomForestClassifierWrapper] {
+
+    override def load(path: String): RandomForestClassifierWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+      val pipeline = PipelineModel.load(pipelinePath)
+
+      // Only the formula and feature names are restored here; the stored "class" field is
+      // not read by this reader (presumably it is used by RWrappers to pick this reader).
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val formula = (rMetadata \ "formula").extract[String]
+      val features = (rMetadata \ "features").extract[Array[String]]
+
+      new RandomForestClassifierWrapper(pipeline, formula, features)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b6879b8b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
new file mode 100644
index 0000000..c887440
--- /dev/null
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.regression.{RandomForestRegressionModel, 
RandomForestRegressor}
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+/**
+ * R-facing wrapper around a fitted random forest regression pipeline.
+ *
+ * Stage 1 of `pipeline` must be a [[RandomForestRegressionModel]]; the companion's `fit`
+ * builds the pipeline as [RFormula model, random forest].
+ *
+ * @param pipeline fitted pipeline (RFormula model followed by the random forest model)
+ * @param formula  the R model formula used for fitting
+ * @param features feature names taken from the features column's attribute metadata
+ */
+private[r] class RandomForestRegressorWrapper private (
+  val pipeline: PipelineModel,
+  val formula: String,
+  val features: Array[String]) extends MLWritable {
+
+  // The fitted random forest regressor is the second pipeline stage.
+  private val rfrModel: RandomForestRegressionModel =
+    pipeline.stages(1).asInstanceOf[RandomForestRegressionModel]
+
+  lazy val numFeatures: Int = rfrModel.numFeatures
+  lazy val featureImportances: Vector = rfrModel.featureImportances
+  lazy val numTrees: Int = rfrModel.getNumTrees
+  lazy val treeWeights: Array[Double] = rfrModel.treeWeights
+
+  /** Debug string of the fitted forest (`toDebugString`), used as the R summary text. */
+  def summary: String = rfrModel.toDebugString
+
+  /** Runs the pipeline on `dataset` and drops the model's features column from the output. */
+  def transform(dataset: Dataset[_]): DataFrame = {
+    pipeline.transform(dataset).drop(rfrModel.getFeaturesCol)
+  }
+
+  override def write: MLWriter = new
+      RandomForestRegressorWrapper.RandomForestRegressorWrapperWriter(this)
+}
+
+/**
+ * Companion of [[RandomForestRegressorWrapper]]: fits the R-facing random forest
+ * regression pipeline and handles saving and loading it.
+ */
+private[r] object RandomForestRegressorWrapper extends MLReadable[RandomForestRegressorWrapper] {
+
+  /**
+   * Fits an RFormula + RandomForestRegressor pipeline on `data`.
+   *
+   * Apart from `data`, `formula` and `seed`, every parameter is passed unchanged to the
+   * same-named setter of [[RandomForestRegressor]].
+   *
+   * @param data    training data
+   * @param formula R model formula describing the label and features
+   * @param seed    random seed as a decimal string; null or "" leaves the seed unset
+   * @return wrapper around the fitted two-stage pipeline
+   */
+  def fit(  // scalastyle:ignore
+      data: DataFrame,
+      formula: String,
+      maxDepth: Int,
+      maxBins: Int,
+      numTrees: Int,
+      impurity: String,
+      minInstancesPerNode: Int,
+      minInfoGain: Double,
+      checkpointInterval: Int,
+      featureSubsetStrategy: String,
+      seed: String,
+      subsamplingRate: Double,
+      maxMemoryInMB: Int,
+      cacheNodeIds: Boolean): RandomForestRegressorWrapper = {
+    val rFormula = new RFormula().setFormula(formula)
+    RWrapperUtils.checkDataColumns(rFormula, data)
+    val rFormulaModel = rFormula.fit(data)
+
+    // Feature names are read from the attribute metadata of the RFormula
+    // model's output features column.
+    val featuresField = rFormulaModel.transform(data).schema(rFormulaModel.getFeaturesCol)
+    val features = AttributeGroup.fromStructField(featuresField)
+      .attributes.get
+      .map(_.name.get)
+
+    val regressor = new RandomForestRegressor()
+      .setMaxDepth(maxDepth)
+      .setMaxBins(maxBins)
+      .setNumTrees(numTrees)
+      .setImpurity(impurity)
+      .setMinInstancesPerNode(minInstancesPerNode)
+      .setMinInfoGain(minInfoGain)
+      .setCheckpointInterval(checkpointInterval)
+      .setFeatureSubsetStrategy(featureSubsetStrategy)
+      .setSubsamplingRate(subsamplingRate)
+      .setMaxMemoryInMB(maxMemoryInMB)
+      .setCacheNodeIds(cacheNodeIds)
+      .setFeaturesCol(rFormula.getFeaturesCol)
+    // Only override the seed when a non-empty seed string was supplied.
+    Option(seed).filter(_.nonEmpty).foreach(s => regressor.setSeed(s.toLong))
+
+    val pipelineModel = new Pipeline()
+      .setStages(Array(rFormulaModel, regressor))
+      .fit(data)
+
+    new RandomForestRegressorWrapper(pipelineModel, formula, features)
+  }
+
+  override def read: MLReader[RandomForestRegressorWrapper] =
+    new RandomForestRegressorWrapperReader
+
+  override def load(path: String): RandomForestRegressorWrapper = super.load(path)
+
+  /**
+   * Persists a [[RandomForestRegressorWrapper]] under `path`:
+   *  - `<path>/rMetadata`: a JSON record (class name, formula, feature names) saved as a
+   *    one-partition text file;
+   *  - `<path>/pipeline`: the wrapped [[PipelineModel]].
+   */
+  class RandomForestRegressorWrapperWriter(instance: RandomForestRegressorWrapper)
+    extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      // R-side metadata needed to rebuild the wrapper on load.
+      val metadata = ("class" -> instance.getClass.getName) ~
+        ("formula" -> instance.formula) ~
+        ("features" -> instance.features.toSeq)
+      val metadataJson: String = compact(render(metadata))
+
+      // A single partition keeps the metadata in one part file.
+      val metadataDir = new Path(path, "rMetadata").toString
+      sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataDir)
+
+      val pipelineDir = new Path(path, "pipeline").toString
+      instance.pipeline.save(pipelineDir)
+    }
+  }
+
+  /**
+   * Restores a [[RandomForestRegressorWrapper]] from `path`: the [[PipelineModel]] is
+   * reloaded from `<path>/pipeline`, and the formula and feature names are parsed from the
+   * first line of the JSON text stored in `<path>/rMetadata`.
+   */
+  class RandomForestRegressorWrapperReader extends MLReader[RandomForestRegressorWrapper] {
+
+    override def load(path: String): RandomForestRegressorWrapper = {
+      implicit val format = DefaultFormats
+
+      val pipeline = PipelineModel.load(new Path(path, "pipeline").toString)
+
+      val metadataLine = sc.textFile(new Path(path, "rMetadata").toString, 1).first()
+      val metadata = parse(metadataLine)
+      val formula = (metadata \ "formula").extract[String]
+      val features = (metadata \ "features").extract[Array[String]]
+
+      new RandomForestRegressorWrapper(pipeline, formula, features)
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to