Repository: spark
Updated Branches:
  refs/heads/master d4a637cd4 -> 85941ecf2


[SPARK-11569][ML] Fix StringIndexer to handle null value properly

## What changes were proposed in this pull request?

This PR enhances StringIndexer so that it handles NULL values.

Before this PR, StringIndexer threw an exception whenever it encountered a
NULL value.
With this PR:
- handleInvalid=error: Throw an exception as before
- handleInvalid=skip: Filter out rows with null values, as well as rows with
unseen labels
- handleInvalid=keep: Put null values, like unseen labels, in a special
additional bucket at index numLabels

BTW, I noticed someone was trying to solve the same problem (#9920), but it
seems to have had no progress or response for a long time. Would you mind
giving me a chance to solve it? I'm eager to help. :-)

## How was this patch tested?

New unit test in StringIndexerSuite ("StringIndexer with NULLs") covering
handleInvalid=error, skip, and keep.

Author: Menglong TAN <tanmengl...@renrenche.com>
Author: Menglong TAN <tanmengl...@gmail.com>

Closes #17233 from crackcell/11569_StringIndexer_NULL.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/85941ecf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/85941ecf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/85941ecf

Branch: refs/heads/master
Commit: 85941ecf28362f35718ebcd3a22dbb17adb49154
Parents: d4a637c
Author: Menglong TAN <tanmengl...@renrenche.com>
Authored: Tue Mar 14 07:45:42 2017 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Tue Mar 14 07:45:42 2017 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/feature/StringIndexer.scala | 54 ++++++++++++--------
 .../spark/ml/feature/StringIndexerSuite.scala   | 45 ++++++++++++++++
 2 files changed, 77 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/85941ecf/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 810b02f..99321bc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -39,20 +39,21 @@ import org.apache.spark.util.collection.OpenHashMap
 private[feature] trait StringIndexerBase extends Params with HasInputCol with 
HasOutputCol {
 
   /**
-   * Param for how to handle unseen labels. Options are 'skip' (filter out 
rows with
-   * unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in 
a special additional
-   * bucket, at index numLabels.
+   * Param for how to handle invalid data (unseen labels or NULL values).
+   * Options are 'skip' (filter out rows with invalid data),
+   * 'error' (throw an error), or 'keep' (put invalid data in a special 
additional
+   * bucket, at index numLabels).
    * Default: "error"
    * @group param
    */
   @Since("1.6.0")
   val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", 
"how to handle " +
-    "unseen labels. Options are 'skip' (filter out rows with unseen labels), " 
+
-    "error (throw an error), or 'keep' (put unseen labels in a special 
additional bucket, " +
-    "at index numLabels).",
+    "invalid data (unseen labels or NULL values). " +
+    "Options are 'skip' (filter out rows with invalid data), error (throw an 
error), " +
+    "or 'keep' (put invalid data in a special additional bucket, at index 
numLabels).",
     ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
 
-  setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL)
+  setDefault(handleInvalid, StringIndexer.ERROR_INVALID)
 
   /** @group getParam */
   @Since("1.6.0")
@@ -106,7 +107,7 @@ class StringIndexer @Since("1.4.0") (
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): StringIndexerModel = {
     transformSchema(dataset.schema, logging = true)
-    val counts = dataset.select(col($(inputCol)).cast(StringType))
+    val counts = 
dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType))
       .rdd
       .map(_.getString(0))
       .countByValue()
@@ -125,11 +126,11 @@ class StringIndexer @Since("1.4.0") (
 
 @Since("1.6.0")
 object StringIndexer extends DefaultParamsReadable[StringIndexer] {
-  private[feature] val SKIP_UNSEEN_LABEL: String = "skip"
-  private[feature] val ERROR_UNSEEN_LABEL: String = "error"
-  private[feature] val KEEP_UNSEEN_LABEL: String = "keep"
+  private[feature] val SKIP_INVALID: String = "skip"
+  private[feature] val ERROR_INVALID: String = "error"
+  private[feature] val KEEP_INVALID: String = "keep"
   private[feature] val supportedHandleInvalids: Array[String] =
-    Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL)
+    Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
 
   @Since("1.6.0")
   override def load(path: String): StringIndexer = super.load(path)
@@ -188,7 +189,7 @@ class StringIndexerModel (
     transformSchema(dataset.schema, logging = true)
 
     val filteredLabels = getHandleInvalid match {
-      case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown"
+      case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
       case _ => labels
     }
 
@@ -196,22 +197,31 @@ class StringIndexerModel (
       .withName($(outputCol)).withValues(filteredLabels).toMetadata()
     // If we are skipping invalid records, filter them out.
     val (filteredDataset, keepInvalid) = getHandleInvalid match {
-      case StringIndexer.SKIP_UNSEEN_LABEL =>
+      case StringIndexer.SKIP_INVALID =>
         val filterer = udf { label: String =>
           labelToIndex.contains(label)
         }
-        (dataset.where(filterer(dataset($(inputCol)))), false)
-      case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_UNSEEN_LABEL)
+        
(dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), 
false)
+      case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID)
     }
 
     val indexer = udf { label: String =>
-      if (labelToIndex.contains(label)) {
-        labelToIndex(label)
-      } else if (keepInvalid) {
-        labels.length
+      if (label == null) {
+        if (keepInvalid) {
+          labels.length
+        } else {
+          throw new SparkException("StringIndexer encountered NULL value. To 
handle or skip " +
+            "NULLS, try setting StringIndexer.handleInvalid.")
+        }
       } else {
-        throw new SparkException(s"Unseen label: $label.  To handle unseen 
labels, " +
-          s"set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.")
+        if (labelToIndex.contains(label)) {
+          labelToIndex(label)
+        } else if (keepInvalid) {
+          labels.length
+        } else {
+          throw new SparkException(s"Unseen label: $label.  To handle unseen 
labels, " +
+            s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.")
+        }
       }
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/85941ecf/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 188dffb..8d9042b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -122,6 +122,51 @@ class StringIndexerSuite
     assert(output === expected)
   }
 
+  test("StringIndexer with NULLs") {
+    val data: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (2, "b"), (3, null))
+    val data2: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (3, null))
+    val df = data.toDF("id", "label")
+    val df2 = data2.toDF("id", "label")
+
+    val indexer = new StringIndexer()
+      .setInputCol("label")
+      .setOutputCol("labelIndex")
+
+    withClue("StringIndexer should throw error when setHandleInvalid=error " +
+      "when given NULL values") {
+      intercept[SparkException] {
+        indexer.setHandleInvalid("error")
+        indexer.fit(df).transform(df2).collect()
+      }
+    }
+
+    indexer.setHandleInvalid("skip")
+    val transformedSkip = indexer.fit(df).transform(df2)
+    val attrSkip = Attribute
+      .fromStructField(transformedSkip.schema("labelIndex"))
+      .asInstanceOf[NominalAttribute]
+    assert(attrSkip.values.get === Array("b", "a"))
+    val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r =>
+      (r.getInt(0), r.getDouble(1))
+    }.collect().toSet
+    // a -> 1, b -> 0
+    val expectedSkip = Set((0, 1.0), (1, 0.0))
+    assert(outputSkip === expectedSkip)
+
+    indexer.setHandleInvalid("keep")
+    val transformedKeep = indexer.fit(df).transform(df2)
+    val attrKeep = Attribute
+      .fromStructField(transformedKeep.schema("labelIndex"))
+      .asInstanceOf[NominalAttribute]
+    assert(attrKeep.values.get === Array("b", "a", "__unknown"))
+    val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r =>
+      (r.getInt(0), r.getDouble(1))
+    }.collect().toSet
+    // a -> 1, b -> 0, null -> 2
+    val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0))
+    assert(outputKeep === expectedKeep)
+  }
+
   test("StringIndexerModel should keep silent if the input column does not 
exist.") {
     val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
       .setInputCol("label")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to