Repository: spark
Updated Branches:
  refs/heads/master d0bfc6733 -> a7b46c627


[SPARK-20307][SPARKR] SparkR: pass on setHandleInvalid to spark.mllib functions 
that use StringIndexer

## What changes were proposed in this pull request?

For the randomForest classifier, if the test data contains unseen labels, 
prediction throws an error. StringIndexer already implements handleInvalid 
logic; this patch adds a new method to set the handleInvalid option of the 
underlying StringIndexer that RFormula uses to index string feature columns.

This patch should also apply to the other classifiers. This PR focuses on the 
main logic and the randomForest classifier; I will do a follow-up PR for the 
other classifiers.

## How was this patch tested?

Added a new unit test that reproduces the error case reported in the JIRA.

Author: wangmiao1981 <wm...@hotmail.com>

Closes #18496 from wangmiao1981/handle.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a7b46c62
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a7b46c62
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a7b46c62

Branch: refs/heads/master
Commit: a7b46c627b5d2461257f337139a29f23350e0c77
Parents: d0bfc67
Author: wangmiao1981 <wm...@hotmail.com>
Authored: Fri Jul 7 23:51:32 2017 -0700
Committer: Felix Cheung <felixche...@apache.org>
Committed: Fri Jul 7 23:51:32 2017 -0700

----------------------------------------------------------------------
 R/pkg/R/mllib_tree.R                            | 11 +++++++--
 R/pkg/tests/fulltests/test_mllib_tree.R         | 17 +++++++++++++
 .../org/apache/spark/ml/feature/RFormula.scala  | 25 ++++++++++++++++++++
 .../r/RandomForestClassificationWrapper.scala   |  4 +++-
 .../spark/ml/feature/StringIndexerSuite.scala   |  2 +-
 5 files changed, 55 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a7b46c62/R/pkg/R/mllib_tree.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R
index 2f1220a..75b1a74 100644
--- a/R/pkg/R/mllib_tree.R
+++ b/R/pkg/R/mllib_tree.R
@@ -374,6 +374,10 @@ setMethod("write.ml", signature(object = 
"GBTClassificationModel", path = "chara
 #'                     nodes. If TRUE, the algorithm will cache node IDs for 
each instance. Caching
 #'                     can speed up training of deeper trees. Users can set 
how often should the
 #'                     cache be checkpointed or disable it by setting 
checkpointInterval.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL 
values) in classification model.
+#'        Supported options: "skip" (filter out rows with invalid data),
+#'                           "error" (throw an error), "keep" (put invalid 
data in a special additional
+#'                           bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @aliases spark.randomForest,SparkDataFrame,formula-method
 #' @return \code{spark.randomForest} returns a fitted Random Forest model.
@@ -409,7 +413,8 @@ setMethod("spark.randomForest", signature(data = 
"SparkDataFrame", formula = "fo
                    maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL,
                    featureSubsetStrategy = "auto", seed = NULL, 
subsamplingRate = 1.0,
                    minInstancesPerNode = 1, minInfoGain = 0.0, 
checkpointInterval = 10,
-                   maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+                   maxMemoryInMB = 256, cacheNodeIds = FALSE,
+                   handleInvalid = c("error", "keep", "skip")) {
             type <- match.arg(type)
             formula <- paste(deparse(formula), collapse = "")
             if (!is.null(seed)) {
@@ -430,6 +435,7 @@ setMethod("spark.randomForest", signature(data = 
"SparkDataFrame", formula = "fo
                      new("RandomForestRegressionModel", jobj = jobj)
                    },
                    classification = {
+                     handleInvalid <- match.arg(handleInvalid)
                      if (is.null(impurity)) impurity <- "gini"
                      impurity <- match.arg(impurity, c("gini", "entropy"))
                      jobj <- 
callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper",
@@ -439,7 +445,8 @@ setMethod("spark.randomForest", signature(data = 
"SparkDataFrame", formula = "fo
                                          as.numeric(minInfoGain), 
as.integer(checkpointInterval),
                                          as.character(featureSubsetStrategy), 
seed,
                                          as.numeric(subsamplingRate),
-                                         as.integer(maxMemoryInMB), 
as.logical(cacheNodeIds))
+                                         as.integer(maxMemoryInMB), 
as.logical(cacheNodeIds),
+                                         handleInvalid)
                      new("RandomForestClassificationModel", jobj = jobj)
                    }
             )

http://git-wip-us.apache.org/repos/asf/spark/blob/a7b46c62/R/pkg/tests/fulltests/test_mllib_tree.R
----------------------------------------------------------------------
diff --git a/R/pkg/tests/fulltests/test_mllib_tree.R 
b/R/pkg/tests/fulltests/test_mllib_tree.R
index 9b3fc8d..66a0693 100644
--- a/R/pkg/tests/fulltests/test_mllib_tree.R
+++ b/R/pkg/tests/fulltests/test_mllib_tree.R
@@ -212,6 +212,23 @@ test_that("spark.randomForest", {
   expect_equal(length(grep("1.0", predictions)), 50)
   expect_equal(length(grep("2.0", predictions)), 50)
 
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+                    someString = base::sample(c("this", "that"), 10, replace = 
TRUE),
+                    stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
+                          maxDepth = 10, maxBins = 10, numTrees = 10)
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
+                             maxDepth = 10, maxBins = 10, numTrees = 10,
+                             handleInvalid = "skip")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "character")
+
   # spark.randomForest classification can work on libsvm data
   if (windows_with_hadoop()) {
     data <- 
read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),

http://git-wip-us.apache.org/repos/asf/spark/blob/a7b46c62/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 4b44878..61aa646 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -132,6 +132,30 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override 
val uid: String)
   @Since("1.5.0")
   def getFormula: String = $(formula)
 
+  /**
+   * Param for how to handle invalid data (unseen labels or NULL values).
+   * Options are 'skip' (filter out rows with invalid data),
+   * 'error' (throw an error), or 'keep' (put invalid data in a special 
additional
+   * bucket, at index numLabels).
+   * Default: "error"
+   * @group param
+   */
+  @Since("2.3.0")
+  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", 
"How to handle " +
+    "invalid data (unseen labels or NULL values). " +
+    "Options are 'skip' (filter out rows with invalid data), error (throw an 
error), " +
+    "or 'keep' (put invalid data in a special additional bucket, at index 
numLabels).",
+    ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
+  setDefault(handleInvalid, StringIndexer.ERROR_INVALID)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+  /** @group getParam */
+  @Since("2.3.0")
+  def getHandleInvalid: String = $(handleInvalid)
+
   /** @group setParam */
   @Since("1.5.0")
   def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@@ -197,6 +221,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override 
val uid: String)
             .setInputCol(term)
             .setOutputCol(indexCol)
             .setStringOrderType($(stringIndexerOrderType))
+            .setHandleInvalid($(handleInvalid))
           prefixesToRewrite(indexCol + "_") = term + "_"
           (term, indexCol)
         case _ =>

http://git-wip-us.apache.org/repos/asf/spark/blob/a7b46c62/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
index 8a83d4e..132345f 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
@@ -78,11 +78,13 @@ private[r] object RandomForestClassifierWrapper extends 
MLReadable[RandomForestC
       seed: String,
       subsamplingRate: Double,
       maxMemoryInMB: Int,
-      cacheNodeIds: Boolean): RandomForestClassifierWrapper = {
+      cacheNodeIds: Boolean,
+      handleInvalid: String): RandomForestClassifierWrapper = {
 
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a7b46c62/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 806a927..027b1fb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, 
MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.functions.col


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to