Repository: spark
Updated Branches:
  refs/heads/master 363091100 -> 4be337583


[SPARK-15767][ML][SPARKR] Decision Tree wrapper in SparkR

## What changes were proposed in this pull request?
Add a Decision Tree wrapper, `spark.decisionTree`, to SparkR that fits regression and classification models.

## How was this patch tested?
Added unit tests in `R/pkg/inst/tests/testthat/test_mllib_tree.R`.

Author: Zheng RuiFeng <ruife...@foxmail.com>

Closes #17981 from zhengruifeng/dt_r.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4be33758
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4be33758
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4be33758

Branch: refs/heads/master
Commit: 4be33758354e1f95fd1d82a5482f3f00218e8c91
Parents: 3630911
Author: Zheng RuiFeng <ruife...@foxmail.com>
Authored: Mon May 22 10:40:49 2017 -0700
Committer: Felix Cheung <felixche...@apache.org>
Committed: Mon May 22 10:40:49 2017 -0700

----------------------------------------------------------------------
 R/pkg/NAMESPACE                                 |   5 +
 R/pkg/R/generics.R                              |   5 +
 R/pkg/R/mllib_tree.R                            | 240 +++++++++++++++++++
 R/pkg/R/mllib_utils.R                           |  14 +-
 R/pkg/inst/tests/testthat/test_mllib_tree.R     |  86 +++++++
 .../r/DecisionTreeClassificationWrapper.scala   | 152 ++++++++++++
 .../ml/r/DecisionTreeRegressionWrapper.scala    | 137 +++++++++++
 .../scala/org/apache/spark/ml/r/RWrappers.scala |   4 +
 8 files changed, 639 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4be33758/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 5c074d3..4e3fe00 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -63,6 +63,7 @@ exportMethods("glm",
               "spark.als",
               "spark.kstest",
               "spark.logit",
+              "spark.decisionTree",
               "spark.randomForest",
               "spark.gbt",
               "spark.bisectingKmeans",
@@ -414,6 +415,8 @@ export("as.DataFrame",
        "print.summary.GeneralizedLinearRegressionModel",
        "read.ml",
        "print.summary.KSTest",
+       "print.summary.DecisionTreeRegressionModel",
+       "print.summary.DecisionTreeClassificationModel",
        "print.summary.RandomForestRegressionModel",
        "print.summary.RandomForestClassificationModel",
        "print.summary.GBTRegressionModel",
@@ -452,6 +455,8 @@ S3method(print, structField)
 S3method(print, structType)
 S3method(print, summary.GeneralizedLinearRegressionModel)
 S3method(print, summary.KSTest)
+S3method(print, summary.DecisionTreeRegressionModel)
+S3method(print, summary.DecisionTreeClassificationModel)
 S3method(print, summary.RandomForestRegressionModel)
 S3method(print, summary.RandomForestClassificationModel)
 S3method(print, summary.GBTRegressionModel)

http://git-wip-us.apache.org/repos/asf/spark/blob/4be33758/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 514ca99..5630d0c 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1506,6 +1506,11 @@ setGeneric("spark.mlp", function(data, formula, ...) { 
standardGeneric("spark.ml
 #' @export
 setGeneric("spark.naiveBayes", function(data, formula, ...) { 
standardGeneric("spark.naiveBayes") })
 
+#' @rdname spark.decisionTree
+#' @export
+setGeneric("spark.decisionTree",
+           function(data, formula, ...) { standardGeneric("spark.decisionTree") })
+
 #' @rdname spark.randomForest
 #' @export
 setGeneric("spark.randomForest",

http://git-wip-us.apache.org/repos/asf/spark/blob/4be33758/R/pkg/R/mllib_tree.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R
index 82279be..2f1220a 100644
--- a/R/pkg/R/mllib_tree.R
+++ b/R/pkg/R/mllib_tree.R
@@ -45,6 +45,20 @@ setClass("RandomForestRegressionModel", representation(jobj 
= "jobj"))
 #' @note RandomForestClassificationModel since 2.1.0
 setClass("RandomForestClassificationModel", representation(jobj = "jobj"))
 
+#' S4 class that represents a DecisionTreeRegressionModel
+#'
+#' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel
+#' @export
+#' @note DecisionTreeRegressionModel since 2.3.0
+setClass("DecisionTreeRegressionModel", representation(jobj = "jobj"))
+
+#' S4 class that represents a DecisionTreeClassificationModel
+#'
+#' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel
+#' @export
+#' @note DecisionTreeClassificationModel since 2.3.0
+setClass("DecisionTreeClassificationModel", representation(jobj = "jobj"))
+
 # Create the summary of a tree ensemble model (eg. Random Forest, GBT)
 summary.treeEnsemble <- function(model) {
   jobj <- model@jobj
@@ -81,6 +95,36 @@ print.summary.treeEnsemble <- function(x) {
   invisible(x)
 }
 
+# Create the summary list of a decision tree model; featureImportances is returned as a string
+summary.decisionTree <- function(model) {
+  jobj <- model@jobj
+  formula <- callJMethod(jobj, "formula")
+  numFeatures <- callJMethod(jobj, "numFeatures")
+  features <- callJMethod(jobj, "features")
+  featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString")
+  maxDepth <- callJMethod(jobj, "maxDepth")
+  list(formula = formula,
+       numFeatures = numFeatures,
+       features = features,
+       featureImportances = featureImportances,
+       maxDepth = maxDepth,
+       jobj = jobj)
+}
+
+# Prints the summary of decision tree models, then the JVM model's "summary" string
+print.summary.decisionTree <- function(x) {
+  jobj <- x$jobj
+  cat("Formula: ", x$formula)
+  cat("\nNumber of features: ", x$numFeatures)
+  cat("\nFeatures: ", unlist(x$features))
+  cat("\nFeature importances: ", x$featureImportances)
+  cat("\nMax Depth: ", x$maxDepth)
+
+  summaryStr <- callJMethod(jobj, "summary")
+  cat("\n", summaryStr, "\n")
+  invisible(x)
+}
+
 #' Gradient Boosted Tree Model for Regression and Classification
 #'
 #' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or 
Classification model on a
@@ -499,3 +543,199 @@ setMethod("write.ml", signature(object = 
"RandomForestClassificationModel", path
           function(object, path, overwrite = FALSE) {
             write_internal(object, path, overwrite)
           })
+
+#' Decision Tree Model for Regression and Classification
+#'
+#' \code{spark.decisionTree} fits a Decision Tree Regression model or Classification model on
+#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Decision Tree
+#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
+#' save/load fitted models.
+#' For more details, see
+#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-regression}{
+#' Decision Tree Regression} and
+#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier}{
+#' Decision Tree Classification}
+#'
+#' @param data a SparkDataFrame for training.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
+#'                operators are supported, including '~', ':', '+', and '-'.
+#' @param type type of model, one of "regression" or "classification", to fit
+#' @param maxDepth Maximum depth of the tree (>= 0).
+#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing
+#'                how to split on features at each node. More bins give higher granularity. Must be
+#'                >= 2 and >= number of categories in any categorical feature.
+#' @param impurity Criterion used for information gain calculation.
+#'                 For regression, must be "variance". For classification, must be either
+#'                 "entropy" or "gini", default is "gini".
+#' @param seed integer seed for random number generation.
+#' @param minInstancesPerNode Minimum number of instances each child must have after split.
+#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
+#' @param checkpointInterval Checkpoint interval (>= 1), or -1 to disable checkpointing.
+#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation.
+#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with
+#'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
+#'                     can speed up training of deeper trees. Users can set how often the
+#'                     cache should be checkpointed or disable it by setting checkpointInterval.
+#' @param ... additional arguments passed to the method.
+#' @aliases spark.decisionTree,SparkDataFrame,formula-method
+#' @return \code{spark.decisionTree} returns a fitted Decision Tree model.
+#' @rdname spark.decisionTree
+#' @name spark.decisionTree
+#' @export
+#' @examples
+#' \dontrun{
+#' # fit a Decision Tree Regression Model
+#' df <- createDataFrame(longley)
+#' model <- spark.decisionTree(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16)
+#'
+#' # get the summary of the model
+#' summary(model)
+#'
+#' # make predictions
+#' predictions <- predict(model, df)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#'
+#' # fit a Decision Tree Classification Model
+#' t <- as.data.frame(Titanic)
+#' df <- createDataFrame(t)
+#' model <- spark.decisionTree(df, Survived ~ Freq + Age, "classification")
+#' }
+#' @note spark.decisionTree since 2.3.0
+setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"),
+          function(data, formula, type = c("regression", "classification"),
+                   maxDepth = 5, maxBins = 32, impurity = NULL, seed = NULL,
+                   minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
+                   maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+            type <- match.arg(type)
+            formula <- paste(deparse(formula), collapse = "")
+            if (!is.null(seed)) {
+              seed <- as.character(as.integer(seed))
+            }
+            switch(type,
+                   regression = {
+                     if (is.null(impurity)) impurity <- "variance"
+                     impurity <- match.arg(impurity, "variance")
+                     jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeRegressorWrapper",
+                                         "fit", data@sdf, formula, as.integer(maxDepth),
+                                         as.integer(maxBins), impurity,
+                                         as.integer(minInstancesPerNode), as.numeric(minInfoGain),
+                                         as.integer(checkpointInterval), seed,
+                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                     new("DecisionTreeRegressionModel", jobj = jobj)
+                   },
+                   classification = {
+                     if (is.null(impurity)) impurity <- "gini"
+                     impurity <- match.arg(impurity, c("gini", "entropy"))
+                     jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper",
+                                         "fit", data@sdf, formula, as.integer(maxDepth),
+                                         as.integer(maxBins), impurity,
+                                         as.integer(minInstancesPerNode), as.numeric(minInfoGain),
+                                         as.integer(checkpointInterval), seed,
+                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                     new("DecisionTreeClassificationModel", jobj = jobj)
+                   }
+            )
+          })
+
+#  Get the summary of a Decision Tree Regression Model
+
+#' @return \code{summary} returns summary information of the fitted model, which is a list.
+#'         The list of components includes \code{formula} (formula),
+#'         \code{numFeatures} (number of features), \code{features} (list of features),
+#'         \code{featureImportances} (feature importances), and \code{maxDepth} (max tree depth).
+#' @rdname spark.decisionTree
+#' @aliases summary,DecisionTreeRegressionModel-method
+#' @export
+#' @note summary(DecisionTreeRegressionModel) since 2.3.0
+setMethod("summary", signature(object = "DecisionTreeRegressionModel"),
+          function(object) {
+            ans <- summary.decisionTree(object)
+            class(ans) <- "summary.DecisionTreeRegressionModel"
+            ans
+          })
+
+#  Prints the summary of a Decision Tree Regression Model
+
+#' @param x summary object of Decision Tree regression model or classification model
+#'          returned by \code{summary}.
+#' @rdname spark.decisionTree
+#' @export
+#' @note print.summary.DecisionTreeRegressionModel since 2.3.0
+print.summary.DecisionTreeRegressionModel <- function(x, ...) {
+  print.summary.decisionTree(x)
+}
+
+#  Get the summary of a Decision Tree Classification Model
+
+#' @rdname spark.decisionTree
+#' @aliases summary,DecisionTreeClassificationModel-method
+#' @export
+#' @note summary(DecisionTreeClassificationModel) since 2.3.0
+setMethod("summary", signature(object = "DecisionTreeClassificationModel"),
+          function(object) {
+            ans <- summary.decisionTree(object)
+            class(ans) <- "summary.DecisionTreeClassificationModel"
+            ans
+          })
+
+#  Prints the summary of a Decision Tree Classification Model
+
+#' @rdname spark.decisionTree
+#' @export
+#' @note print.summary.DecisionTreeClassificationModel since 2.3.0
+print.summary.DecisionTreeClassificationModel <- function(x, ...) {
+  print.summary.decisionTree(x)
+}
+
+#  Makes predictions from a Decision Tree Regression model or Classification model
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing the predicted values or labels in
+#'         a column named "prediction".
+#' @rdname spark.decisionTree
+#' @aliases predict,DecisionTreeRegressionModel-method
+#' @export
+#' @note predict(DecisionTreeRegressionModel) since 2.3.0
+setMethod("predict", signature(object = "DecisionTreeRegressionModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
+#' @rdname spark.decisionTree
+#' @aliases predict,DecisionTreeClassificationModel-method
+#' @export
+#' @note predict(DecisionTreeClassificationModel) since 2.3.0
+setMethod("predict", signature(object = "DecisionTreeClassificationModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
+#  Save the Decision Tree Regression or Classification model to the input path.
+
+#' @param object A fitted Decision Tree regression model or classification model.
+#' @param path The directory where the model is saved.
+#' @param overwrite Whether to overwrite the output path if it already exists. Default is FALSE,
+#'                  which means an exception is thrown if the output path exists.
+#'
+#' @aliases write.ml,DecisionTreeRegressionModel,character-method
+#' @rdname spark.decisionTree
+#' @export
+#' @note write.ml(DecisionTreeRegressionModel, character) since 2.3.0
+setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })
+
+#' @aliases write.ml,DecisionTreeClassificationModel,character-method
+#' @rdname spark.decisionTree
+#' @export
+#' @note write.ml(DecisionTreeClassificationModel, character) since 2.3.0
+setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })

http://git-wip-us.apache.org/repos/asf/spark/blob/4be33758/R/pkg/R/mllib_utils.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R
index 5dfef86..a53c92c 100644
--- a/R/pkg/R/mllib_utils.R
+++ b/R/pkg/R/mllib_utils.R
@@ -32,8 +32,9 @@
 #' @rdname write.ml
 #' @name write.ml
 #' @export
-#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, 
\link{spark.gaussianMixture},
-#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, 
\link{spark.isoreg},
+#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, 
\link{spark.decisionTree},
+#' @seealso \link{spark.gaussianMixture}, \link{spark.gbt},
+#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg},
 #' @seealso \link{spark.kmeans},
 #' @seealso \link{spark.lda}, \link{spark.logit},
 #' @seealso \link{spark.mlp}, \link{spark.naiveBayes},
@@ -48,8 +49,9 @@ NULL
 #' @rdname predict
 #' @name predict
 #' @export
-#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, 
\link{spark.gaussianMixture},
-#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, 
\link{spark.isoreg},
+#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, 
\link{spark.decisionTree},
+#' @seealso \link{spark.gaussianMixture}, \link{spark.gbt},
+#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg},
 #' @seealso \link{spark.kmeans},
 #' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes},
 #' @seealso \link{spark.randomForest}, \link{spark.survreg}, 
\link{spark.svmLinear}
@@ -110,6 +112,10 @@ read.ml <- function(path) {
     new("RandomForestRegressionModel", jobj = jobj)
   } else if (isInstanceOf(jobj, 
"org.apache.spark.ml.r.RandomForestClassifierWrapper")) {
     new("RandomForestClassificationModel", jobj = jobj)
+  } else if (isInstanceOf(jobj, 
"org.apache.spark.ml.r.DecisionTreeRegressorWrapper")) {
+    new("DecisionTreeRegressionModel", jobj = jobj)
+  } else if (isInstanceOf(jobj, 
"org.apache.spark.ml.r.DecisionTreeClassifierWrapper")) {
+    new("DecisionTreeClassificationModel", jobj = jobj)
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTRegressorWrapper")) {
     new("GBTRegressionModel", jobj = jobj)
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTClassifierWrapper")) 
{

http://git-wip-us.apache.org/repos/asf/spark/blob/4be33758/R/pkg/inst/tests/testthat/test_mllib_tree.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R 
b/R/pkg/inst/tests/testthat/test_mllib_tree.R
index 146bc28..b283e73 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_tree.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_tree.R
@@ -209,4 +209,90 @@ test_that("spark.randomForest", {
   expect_equal(summary(model)$numFeatures, 4)
 })
 
+test_that("spark.decisionTree", {
+  # regression
+  data <- suppressWarnings(createDataFrame(longley))
+  model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16)
+
+  predictions <- collect(predict(model, data))
+  expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
+                                         63.221, 63.639, 64.989, 63.761,
+                                         66.019, 67.857, 68.169, 66.513,
+                                         68.655, 69.564, 69.331, 70.551),
+               tolerance = 1e-4)
+
+  stats <- summary(model)
+  expect_equal(stats$maxDepth, 5)
+  expect_error(capture.output(stats), NA)
+  expect_true(length(capture.output(stats)) > 6)
+
+  modelPath <- tempfile(pattern = "spark-decisionTreeRegression", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+  expect_equal(stats$formula, stats2$formula)
+  expect_equal(stats$numFeatures, stats2$numFeatures)
+  expect_equal(stats$features, stats2$features)
+  expect_equal(stats$featureImportances, stats2$featureImportances)
+  expect_equal(stats$maxDepth, stats2$maxDepth)
+
+  unlink(modelPath, recursive = TRUE)
+
+  # classification
+  data <- suppressWarnings(createDataFrame(iris))
+  model <- spark.decisionTree(data, Species ~ Petal_Length + Petal_Width, "classification",
+                              maxDepth = 5, maxBins = 16)
+
+  stats <- summary(model)
+  expect_equal(stats$numFeatures, 2)
+  expect_equal(stats$maxDepth, 5)
+  expect_error(capture.output(stats), NA)
+  expect_true(length(capture.output(stats)) > 6)
+  # Test string prediction values
+  predictions <- collect(predict(model, data))$prediction
+  expect_equal(length(grep("setosa", predictions)), 50)
+  expect_equal(length(grep("versicolor", predictions)), 50)
+
+  modelPath <- tempfile(pattern = "spark-decisionTreeClassification", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+  for (field in c("formula", "numFeatures", "features", "featureImportances", "maxDepth")) {
+    expect_equal(stats[[field]], stats2[[field]])
+  }
+
+  unlink(modelPath, recursive = TRUE)
+
+  # Test numeric response variable
+  labelToIndex <- function(species) {
+    switch(as.character(species),
+      setosa = 0.0,
+      versicolor = 1.0,
+      virginica = 2.0
+    )
+  }
+  iris$NumericSpecies <- lapply(iris$Species, labelToIndex)
+  data <- suppressWarnings(createDataFrame(iris[-5]))
+  model <- spark.decisionTree(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification",
+                              maxDepth = 5, maxBins = 16)
+  stats <- summary(model)
+  expect_equal(stats$numFeatures, 2)
+  expect_equal(stats$maxDepth, 5)
+
+  # Test numeric prediction values
+  predictions <- collect(predict(model, data))$prediction
+  expect_equal(length(grep("1.0", predictions, fixed = TRUE)), 50)
+  expect_equal(length(grep("2.0", predictions, fixed = TRUE)), 50)
+
+  # spark.decisionTree classification can work on libsvm data
+  data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
+                source = "libsvm")
+  model <- spark.decisionTree(data, label ~ features, "classification")
+  expect_equal(summary(model)$numFeatures, 4)
+})
+
 sparkR.session.stop()

http://git-wip-us.apache.org/repos/asf/spark/blob/4be33758/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
new file mode 100644
index 0000000..7f59825
--- /dev/null
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
+import org.apache.spark.ml.feature.{IndexToString, RFormula}
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.r.RWrapperUtils._
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class DecisionTreeClassifierWrapper private (
+  val pipeline: PipelineModel,
+  val formula: String,
+  val features: Array[String]) extends MLWritable {
+
+  import DecisionTreeClassifierWrapper._
+
+  private val dtcModel: DecisionTreeClassificationModel =
+    pipeline.stages(1).asInstanceOf[DecisionTreeClassificationModel]
+
+  lazy val numFeatures: Int = dtcModel.numFeatures
+  lazy val featureImportances: Vector = dtcModel.featureImportances
+  lazy val maxDepth: Int = dtcModel.getMaxDepth
+
+  def summary: String = dtcModel.toDebugString
+
+  def transform(dataset: Dataset[_]): DataFrame = {
+    pipeline.transform(dataset)
+      .drop(PREDICTED_LABEL_INDEX_COL)
+      .drop(dtcModel.getFeaturesCol)
+      .drop(dtcModel.getLabelCol)
+  }
+
+  override def write: MLWriter = new
+      DecisionTreeClassifierWrapper.DecisionTreeClassifierWrapperWriter(this)
+}
+
+private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeClassifierWrapper] {
+
+  val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
+  val PREDICTED_LABEL_COL = "prediction"
+
+  def fit(  // scalastyle:ignore
+      data: DataFrame,
+      formula: String,
+      maxDepth: Int,
+      maxBins: Int,
+      impurity: String,
+      minInstancesPerNode: Int,
+      minInfoGain: Double,
+      checkpointInterval: Int,
+      seed: String,
+      maxMemoryInMB: Int,
+      cacheNodeIds: Boolean): DecisionTreeClassifierWrapper = {
+
+    val rFormula = new RFormula()
+      .setFormula(formula)
+      .setForceIndexLabel(true)
+    checkDataColumns(rFormula, data)
+    val rFormulaModel = rFormula.fit(data)
+
+    // feature names, plus the label strings IndexToString uses to map predictions back
+    val (features, labels) = getFeaturesAndLabels(rFormulaModel, data)
+
+    // assemble and fit the pipeline
+    val dtc = new DecisionTreeClassifier()
+      .setMaxDepth(maxDepth)
+      .setMaxBins(maxBins)
+      .setImpurity(impurity)
+      .setMinInstancesPerNode(minInstancesPerNode)
+      .setMinInfoGain(minInfoGain)
+      .setCheckpointInterval(checkpointInterval)
+      .setMaxMemoryInMB(maxMemoryInMB)
+      .setCacheNodeIds(cacheNodeIds)
+      .setFeaturesCol(rFormula.getFeaturesCol)
+      .setLabelCol(rFormula.getLabelCol)
+      .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
+    if (seed != null && seed.length > 0) dtc.setSeed(seed.toLong)
+
+    val idxToStr = new IndexToString()
+      .setInputCol(PREDICTED_LABEL_INDEX_COL)
+      .setOutputCol(PREDICTED_LABEL_COL)
+      .setLabels(labels)
+
+    val pipeline = new Pipeline()
+      .setStages(Array(rFormulaModel, dtc, idxToStr))
+      .fit(data)
+
+    new DecisionTreeClassifierWrapper(pipeline, formula, features)
+  }
+
+  override def read: MLReader[DecisionTreeClassifierWrapper] =
+    new DecisionTreeClassifierWrapperReader
+
+  override def load(path: String): DecisionTreeClassifierWrapper = super.load(path)
+
+  class DecisionTreeClassifierWrapperWriter(instance: DecisionTreeClassifierWrapper)
+    extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("formula" -> instance.formula) ~
+        ("features" -> instance.features.toSeq)
+      val rMetadataJson: String = compact(render(rMetadata))
+
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class DecisionTreeClassifierWrapperReader extends MLReader[DecisionTreeClassifierWrapper] {
+
+    override def load(path: String): DecisionTreeClassifierWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+      val pipeline = PipelineModel.load(pipelinePath)
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val formula = (rMetadata \ "formula").extract[String]
+      val features = (rMetadata \ "features").extract[Array[String]]
+
+      new DecisionTreeClassifierWrapper(pipeline, formula, features)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/4be33758/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressionWrapper.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressionWrapper.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressionWrapper.scala
new file mode 100644
index 0000000..de712d6
--- /dev/null
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressionWrapper.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class DecisionTreeRegressorWrapper private (
+  val pipeline: PipelineModel,
+  val formula: String,
+  val features: Array[String]) extends MLWritable {
+
+  private val dtrModel: DecisionTreeRegressionModel =
+    pipeline.stages(1).asInstanceOf[DecisionTreeRegressionModel]
+
+  lazy val numFeatures: Int = dtrModel.numFeatures
+  lazy val featureImportances: Vector = dtrModel.featureImportances
+  lazy val maxDepth: Int = dtrModel.getMaxDepth
+
+  def summary: String = dtrModel.toDebugString
+
+  def transform(dataset: Dataset[_]): DataFrame = {
+    pipeline.transform(dataset).drop(dtrModel.getFeaturesCol)
+  }
+
+  override def write: MLWriter = new
+      DecisionTreeRegressorWrapper.DecisionTreeRegressorWrapperWriter(this)
+}
+
+private[r] object DecisionTreeRegressorWrapper extends MLReadable[DecisionTreeRegressorWrapper] {
+  def fit(  // scalastyle:ignore
+      data: DataFrame,
+      formula: String,
+      maxDepth: Int,
+      maxBins: Int,
+      impurity: String,
+      minInstancesPerNode: Int,
+      minInfoGain: Double,
+      checkpointInterval: Int,
+      seed: String,
+      maxMemoryInMB: Int,
+      cacheNodeIds: Boolean): DecisionTreeRegressorWrapper = {
+
+    val rFormula = new RFormula()
+      .setFormula(formula)
+    RWrapperUtils.checkDataColumns(rFormula, data)
+    val rFormulaModel = rFormula.fit(data)
+
+    // get feature names from the attributes of the RFormula output features column
+    val schema = rFormulaModel.transform(data).schema
+    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
+      .attributes.get
+    val features = featureAttrs.map(_.name.get)
+
+    // assemble and fit the pipeline
+    val dtr = new DecisionTreeRegressor()
+      .setMaxDepth(maxDepth)
+      .setMaxBins(maxBins)
+      .setImpurity(impurity)
+      .setMinInstancesPerNode(minInstancesPerNode)
+      .setMinInfoGain(minInfoGain)
+      .setCheckpointInterval(checkpointInterval)
+      .setMaxMemoryInMB(maxMemoryInMB)
+      .setCacheNodeIds(cacheNodeIds)
+      .setFeaturesCol(rFormula.getFeaturesCol)
+    if (seed != null && seed.length > 0) dtr.setSeed(seed.toLong)
+
+    val pipeline = new Pipeline()
+      .setStages(Array(rFormulaModel, dtr))
+      .fit(data)
+
+    new DecisionTreeRegressorWrapper(pipeline, formula, features)
+  }
+
+  override def read: MLReader[DecisionTreeRegressorWrapper] = new DecisionTreeRegressorWrapperReader
+
+  override def load(path: String): DecisionTreeRegressorWrapper = super.load(path)
+
+  class DecisionTreeRegressorWrapperWriter(instance: DecisionTreeRegressorWrapper)
+    extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("formula" -> instance.formula) ~
+        ("features" -> instance.features.toSeq)
+      val rMetadataJson: String = compact(render(rMetadata))
+
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class DecisionTreeRegressorWrapperReader extends MLReader[DecisionTreeRegressorWrapper] {
+
+    override def load(path: String): DecisionTreeRegressorWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+      val pipeline = PipelineModel.load(pipelinePath)
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val formula = (rMetadata \ "formula").extract[String]
+      val features = (rMetadata \ "features").extract[Array[String]]
+
+      new DecisionTreeRegressorWrapper(pipeline, formula, features)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/4be33758/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index b30ce12..ba6445a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -60,6 +60,10 @@ private[r] object RWrappers extends MLReader[Object] {
         RandomForestRegressorWrapper.load(path)
       case "org.apache.spark.ml.r.RandomForestClassifierWrapper" =>
         RandomForestClassifierWrapper.load(path)
+      case "org.apache.spark.ml.r.DecisionTreeRegressorWrapper" =>
+        DecisionTreeRegressorWrapper.load(path)
+      case "org.apache.spark.ml.r.DecisionTreeClassifierWrapper" =>
+        DecisionTreeClassifierWrapper.load(path)
       case "org.apache.spark.ml.r.GBTRegressorWrapper" =>
         GBTRegressorWrapper.load(path)
       case "org.apache.spark.ml.r.GBTClassifierWrapper" =>


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to