Repository: spark
Updated Branches:
  refs/heads/master 104232580 -> 0b076d4cb


[SPARK-17219][ML] enhanced NaN value handling in Bucketizer

## What changes were proposed in this pull request?

This PR is a follow-up enhancement of the PR with commit
ID 57dc326bd00cf0a49da971e9c573c48ae28acaa2.
NaN is a special value that is commonly treated as invalid, but there are cases
where NaN values are meaningful and therefore need special handling. This change
gives users three options for dealing with NaN values: reserve an extra bucket for
them, drop the rows that contain them, or report an error, by setting handleInvalid
to "keep", "skip", or "error" (default), respectively.

Before:
'''
val bucketizer: Bucketizer = new Bucketizer()
  .setInputCol("feature")
  .setOutputCol("result")
  .setSplits(splits)
'''

After:
'''
val bucketizer: Bucketizer = new Bucketizer()
  .setInputCol("feature")
  .setOutputCol("result")
  .setSplits(splits)
  .setHandleInvalid("keep")
'''
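
For illustration only, a minimal end-to-end sketch of the three modes; the
SparkSession setup, sample values, and column names below are not part of this patch:
'''
import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("BucketizerNaN").getOrCreate()
import spark.implicits._

// One NaN value; the splits cover all finite values via -inf/+inf.
val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
val df = Seq(-0.9, -0.2, 0.3, 0.8, Double.NaN).toDF("feature")

val bucketizer = new Bucketizer()
  .setInputCol("feature")
  .setOutputCol("result")
  .setSplits(splits)

// "keep": the NaN row lands in an extra bucket with index splits.length - 1 (4.0 here).
bucketizer.setHandleInvalid("keep").transform(df).show()

// "skip": rows containing NaN are dropped before bucketing (4 rows remain).
bucketizer.setHandleInvalid("skip").transform(df).show()

// "error" (the default): transform fails with a SparkException on the NaN row.
// bucketizer.setHandleInvalid("error").transform(df).collect()
'''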

## How was this patch tested?
Tests added in QuantileDiscretizerSuite, BucketizerSuite and DataFrameStatSuite

Signed-off-by: VinceShieh <vincent.xie@intel.com>

Author: VinceShieh <vincent....@intel.com>
Author: Vincent Xie <vincent....@intel.com>
Author: Joseph K. Bradley <jos...@databricks.com>

Closes #15428 from VinceShieh/spark-17219_followup.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0b076d4c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0b076d4c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0b076d4c

Branch: refs/heads/master
Commit: 0b076d4cb6afde2946124e6411ed6a6ce7b8b1a7
Parents: 1042325
Author: VinceShieh <vincent....@intel.com>
Authored: Thu Oct 27 11:52:15 2016 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Thu Oct 27 11:52:15 2016 -0700

----------------------------------------------------------------------
 docs/ml-features.md                             | 15 +++--
 .../apache/spark/ml/feature/Bucketizer.scala    | 71 ++++++++++++++++++--
 .../spark/ml/feature/QuantileDiscretizer.scala  | 47 +++++++++++--
 .../spark/ml/feature/BucketizerSuite.scala      | 26 +++++--
 .../ml/feature/QuantileDiscretizerSuite.scala   | 35 +++++++---
 python/pyspark/ml/feature.py                    |  5 --
 .../apache/spark/sql/DataFrameStatSuite.scala   |  4 ++
 7 files changed, 161 insertions(+), 42 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0b076d4c/docs/ml-features.md
----------------------------------------------------------------------
diff --git a/docs/ml-features.md b/docs/ml-features.md
index a7f710f..64c6a16 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -1103,11 +1103,16 @@ for more details on the API.
 
 `QuantileDiscretizer` takes a column with continuous features and outputs a 
column with binned
 categorical features. The number of bins is set by the `numBuckets` parameter. 
It is possible
-that the number of buckets used will be less than this value, for example, if 
there are too few
-distinct values of the input to create enough distinct quantiles. Note also 
that NaN values are
-handled specially and placed into their own bucket. For example, if 4 buckets 
are used, then
-non-NaN data will be put into buckets[0-3], but NaNs will be counted in a 
special bucket[4].
-The bin ranges are chosen using an approximate algorithm (see the 
documentation for
+that the number of buckets used will be smaller than this value, for example, 
if there are too few
+distinct values of the input to create enough distinct quantiles.
+
+NaN values: Note also that QuantileDiscretizer
+will raise an error when it finds NaN values in the dataset, but the user can 
also choose to either
+keep or remove NaN values within the dataset by setting `handleInvalid`. If 
the user chooses to keep
+NaN values, they will be handled specially and placed into their own bucket, 
for example, if 4 buckets
+are used, then non-NaN data will be put into buckets[0-3], but NaNs will be 
counted in a special bucket[4].
+
+Algorithm: The bin ranges are chosen using an approximate algorithm (see the 
documentation for
 
[approxQuantile](api/scala/index.html#org.apache.spark.sql.DataFrameStatFunctions)
 for a
 detailed description). The precision of the approximation can be controlled 
with the
 `relativeError` parameter. When set to zero, exact quantiles are calculated
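
As a hedged usage sketch of the behavior documented above (not part of the patch;
it assumes an active SparkSession with spark.implicits._ in scope, and the data and
column names are made up):
'''
import org.apache.spark.ml.feature.QuantileDiscretizer

val data = Seq(-1.0, 0.5, 1.5, 3.0, Double.NaN).toDF("hour")

val discretizer = new QuantileDiscretizer()
  .setInputCol("hour")
  .setOutputCol("result")
  .setNumBuckets(3)
  .setHandleInvalid("keep")   // NaN rows get their own extra bucket

// With "keep", non-NaN rows fall into the regular buckets and the NaN row is
// assigned the extra bucket after them; "skip" would drop the NaN row, and the
// default "error" would raise an exception when a NaN value is found.
discretizer.fit(data).transform(data).show()
'''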

http://git-wip-us.apache.org/repos/asf/spark/blob/0b076d4c/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index ec0ea05..1143f0f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -27,6 +27,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql._
+import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 
@@ -46,6 +47,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") 
override val uid: String
    * also includes y. Splits should be of length >= 3 and strictly increasing.
    * Values at -inf, inf must be explicitly provided to cover all Double 
values;
    * otherwise, values outside the splits specified will be treated as errors.
+   *
+   * See also [[handleInvalid]], which can optionally create an additional 
bucket for NaN values.
+   *
    * @group param
    */
   @Since("1.4.0")
@@ -73,15 +77,47 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") 
override val uid: String
   @Since("1.4.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
+  /**
+   * Param for how to handle invalid entries. Options are skip (filter out 
rows with
+   * invalid values), error (throw an error), or keep (keep invalid values in 
a special additional
+   * bucket).
+   * Default: "error"
+   * @group param
+   */
+  @Since("2.1.0")
+  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", 
"how to handle" +
+    "invalid entries. Options are skip (filter out rows with invalid values), 
" +
+    "error (throw an error), or keep (keep invalid values in a special 
additional bucket).",
+    ParamValidators.inArray(Bucketizer.supportedHandleInvalid))
+
+  /** @group getParam */
+  @Since("2.1.0")
+  def getHandleInvalid: String = $(handleInvalid)
+
+  /** @group setParam */
+  @Since("2.1.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+  setDefault(handleInvalid, Bucketizer.ERROR_INVALID)
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema)
-    val bucketizer = udf { feature: Double =>
-      Bucketizer.binarySearchForBuckets($(splits), feature)
+    val (filteredDataset, keepInvalid) = {
+      if (getHandleInvalid == Bucketizer.SKIP_INVALID) {
+        // "skip" NaN option is set, will filter out NaN values in the dataset
+        (dataset.na.drop().toDF(), false)
+      } else {
+        (dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID)
+      }
+    }
+
+    val bucketizer: UserDefinedFunction = udf { (feature: Double) =>
+      Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid)
     }
-    val newCol = bucketizer(dataset($(inputCol)))
-    val newField = prepOutputField(dataset.schema)
-    dataset.withColumn($(outputCol), newCol, newField.metadata)
+
+    val newCol = bucketizer(filteredDataset($(inputCol)))
+    val newField = prepOutputField(filteredDataset.schema)
+    filteredDataset.withColumn($(outputCol), newCol, newField.metadata)
   }
 
   private def prepOutputField(schema: StructType): StructField = {
@@ -106,6 +142,12 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") 
override val uid: String
 @Since("1.6.0")
 object Bucketizer extends DefaultParamsReadable[Bucketizer] {
 
+  private[feature] val SKIP_INVALID: String = "skip"
+  private[feature] val ERROR_INVALID: String = "error"
+  private[feature] val KEEP_INVALID: String = "keep"
+  private[feature] val supportedHandleInvalid: Array[String] =
+    Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
+
   /**
    * We require splits to be of length >= 3 and to be in strictly increasing 
order.
    * No NaN split should be accepted.
@@ -126,11 +168,26 @@ object Bucketizer extends 
DefaultParamsReadable[Bucketizer] {
 
   /**
    * Binary searching in several buckets to place each data point.
+   * @param splits array of split points
+   * @param feature data point
+   * @param keepInvalid NaN flag.
+   *                    Set "true" to make an extra bucket for NaN values;
+   *                    Set "false" to report an error for NaN values
+   * @return bucket for each data point
    * @throws SparkException if a feature is < splits.head or > splits.last
    */
-  private[feature] def binarySearchForBuckets(splits: Array[Double], feature: 
Double): Double = {
+
+  private[feature] def binarySearchForBuckets(
+      splits: Array[Double],
+      feature: Double,
+      keepInvalid: Boolean): Double = {
     if (feature.isNaN) {
-      splits.length - 1
+      if (keepInvalid) {
+        splits.length - 1
+      } else {
+        throw new SparkException("Bucketizer encountered NaN value. To handle 
or skip NaNs," +
+          " try setting Bucketizer.handleInvalid.")
+      }
     } else if (feature == splits.last) {
       splits.length - 2
     } else {
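
A standalone sketch of the combined lookup, for illustration only: the NaN branch
mirrors the change above, while the rest approximates the pre-existing binary search
rather than reproducing the private Spark method:
'''
import java.util.Arrays
import org.apache.spark.SparkException

def bucketFor(splits: Array[Double], feature: Double, keepInvalid: Boolean): Double = {
  if (feature.isNaN) {
    if (keepInvalid) {
      (splits.length - 1).toDouble        // extra bucket reserved for NaN values
    } else {
      throw new SparkException("NaN value encountered; set handleInvalid to keep or skip.")
    }
  } else if (feature == splits.last) {
    (splits.length - 2).toDouble          // the upper bound belongs to the last regular bucket
  } else {
    val idx = Arrays.binarySearch(splits, feature)
    if (idx >= 0) {
      idx.toDouble                        // exact match on a split point
    } else {
      val insertPos = -idx - 1
      if (insertPos == 0 || insertPos == splits.length) {
        throw new SparkException(s"Feature value $feature out of Bucketizer bounds.")
      }
      (insertPos - 1).toDouble            // bucket whose lower bound precedes the feature
    }
  }
}
'''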

http://git-wip-us.apache.org/repos/asf/spark/blob/0b076d4c/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 05e034d..b9e01dd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -36,6 +36,9 @@ private[feature] trait QuantileDiscretizerBase extends Params
   /**
    * Number of buckets (quantiles, or categories) into which data points are 
grouped. Must
    * be >= 2.
+   *
+   * See also [[handleInvalid]], which can optionally create an additional 
bucket for NaN values.
+   *
    * default: 2
    * @group param
    */
@@ -61,17 +64,41 @@ private[feature] trait QuantileDiscretizerBase extends 
Params
 
   /** @group getParam */
   def getRelativeError: Double = getOrDefault(relativeError)
+
+  /**
+   * Param for how to handle invalid entries. Options are skip (filter out 
rows with
+   * invalid values), error (throw an error), or keep (keep invalid values in 
a special additional
+   * bucket).
+   * Default: "error"
+   * @group param
+   */
+  @Since("2.1.0")
+  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", 
"how to handle" +
+    "invalid entries. Options are skip (filter out rows with invalid values), 
" +
+    "error (throw an error), or keep (keep invalid values in a special 
additional bucket).",
+    ParamValidators.inArray(Bucketizer.supportedHandleInvalid))
+  setDefault(handleInvalid, Bucketizer.ERROR_INVALID)
+
+  /** @group getParam */
+  @Since("2.1.0")
+  def getHandleInvalid: String = $(handleInvalid)
+
 }
 
 /**
  * `QuantileDiscretizer` takes a column with continuous features and outputs a 
column with binned
  * categorical features. The number of bins can be set using the `numBuckets` 
parameter. It is
- * possible that the number of buckets used will be less than this value, for 
example, if there
- * are too few distinct values of the input to create enough distinct 
quantiles. Note also that
- * NaN values are handled specially and placed into their own bucket. For 
example, if 4 buckets
- * are used, then non-NaN data will be put into buckets(0-3), but NaNs will be 
counted in a special
- * bucket(4).
- * The bin ranges are chosen using an approximate algorithm (see the 
documentation for
+ * possible that the number of buckets used will be smaller than this value, 
for example, if there
+ * are too few distinct values of the input to create enough distinct 
quantiles.
+ *
+ * NaN handling: Note also that
+ * QuantileDiscretizer will raise an error when it finds NaN values in the 
dataset, but the user can
+ * also choose to either keep or remove NaN values within the dataset by 
setting `handleInvalid`.
+ * If the user chooses to keep NaN values, they will be handled specially and 
placed into their own
+ * bucket, for example, if 4 buckets are used, then non-NaN data will be put 
into buckets[0-3],
+ * but NaNs will be counted in a special bucket[4].
+ *
+ * Algorithm: The bin ranges are chosen using an approximate algorithm (see 
the documentation for
  * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile 
approxQuantile]]
  * for a detailed description). The precision of the approximation can be 
controlled with the
  * `relativeError` parameter. The lower and upper bin bounds will be 
`-Infinity` and `+Infinity`,
@@ -100,6 +127,10 @@ final class QuantileDiscretizer @Since("1.6.0") 
(@Since("1.6.0") override val ui
   @Since("1.6.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
+  /** @group setParam */
+  @Since("2.1.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
     SchemaUtils.checkNumericType(schema, $(inputCol))
@@ -124,7 +155,9 @@ final class QuantileDiscretizer @Since("1.6.0") 
(@Since("1.6.0") override val ui
       log.warn(s"Some quantiles were identical. Bucketing to 
${distinctSplits.length - 1}" +
         s" buckets as a result.")
     }
-    val bucketizer = new Bucketizer(uid).setSplits(distinctSplits.sorted)
+    val bucketizer = new Bucketizer(uid)
+      .setSplits(distinctSplits.sorted)
+      .setHandleInvalid($(handleInvalid))
     copyValues(bucketizer.setParent(this))
   }
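
A brief sketch of how the propagated setting looks from the caller's side
(illustrative only; it assumes spark.implicits._ and the ml.feature imports are in scope):
'''
val df = Seq(0.1, 0.4, 1.2, 1.5, Double.NaN).toDF("input")

val discretizer = new QuantileDiscretizer()
  .setInputCol("input")
  .setOutputCol("result")
  .setNumBuckets(3)
  .setHandleInvalid("skip")

val model: Bucketizer = discretizer.fit(df)   // splits are computed via approxQuantile
assert(model.getHandleInvalid == "skip")      // the setting carries over to the Bucketizer
model.transform(df).show()                    // the NaN row is filtered out before bucketing
'''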
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0b076d4c/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 87cdceb..aac2913 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -99,21 +99,32 @@ class BucketizerSuite extends SparkFunSuite with 
MLlibTestSparkContext with Defa
       .setOutputCol("result")
       .setSplits(splits)
 
+    bucketizer.setHandleInvalid("keep")
     bucketizer.transform(dataFrame).select("result", 
"expected").collect().foreach {
       case Row(x: Double, y: Double) =>
         assert(x === y,
           s"The feature value is not correct after bucketing.  Expected $y but 
found $x")
     }
+
+    bucketizer.setHandleInvalid("skip")
+    val skipResults: Array[Double] = bucketizer.transform(dataFrame)
+      .select("result").as[Double].collect()
+    assert(skipResults.length === 7)
+    assert(skipResults.forall(_ !== 4.0))
+
+    bucketizer.setHandleInvalid("error")
+    withClue("Bucketizer should throw error when setHandleInvalid=error and 
given NaN values") {
+      intercept[SparkException] {
+        bucketizer.transform(dataFrame).collect()
+      }
+    }
   }
 
   test("Bucket continuous features, with NaN splits") {
     val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, 
Double.PositiveInfinity, Double.NaN)
-    withClue("Invalid NaN split was not caught as an invalid split!") {
+    withClue("Invalid NaN split was not caught during Bucketizer 
initialization") {
       intercept[IllegalArgumentException] {
-        val bucketizer: Bucketizer = new Bucketizer()
-          .setInputCol("feature")
-          .setOutputCol("result")
-          .setSplits(splits)
+        new Bucketizer().setSplits(splits)
       }
     }
   }
@@ -138,7 +149,8 @@ class BucketizerSuite extends SparkFunSuite with 
MLlibTestSparkContext with Defa
     val data = Array.fill(100)(Random.nextDouble())
     val splits: Array[Double] = Double.NegativeInfinity +:
       Array.fill(10)(Random.nextDouble()).sorted :+ Double.PositiveInfinity
-    val bsResult = Vectors.dense(data.map(x => 
Bucketizer.binarySearchForBuckets(splits, x)))
+    val bsResult = Vectors.dense(data.map(x =>
+      Bucketizer.binarySearchForBuckets(splits, x, false)))
     val lsResult = Vectors.dense(data.map(x => 
BucketizerSuite.linearSearchForBuckets(splits, x)))
     assert(bsResult ~== lsResult absTol 1e-5)
   }
@@ -169,7 +181,7 @@ private object BucketizerSuite extends SparkFunSuite {
   /** Check all values in splits, plus values between all splits. */
   def checkBinarySearch(splits: Array[Double]): Unit = {
     def testFeature(feature: Double, expectedBucket: Double): Unit = {
-      assert(Bucketizer.binarySearchForBuckets(splits, feature) === 
expectedBucket,
+      assert(Bucketizer.binarySearchForBuckets(splits, feature, false) === 
expectedBucket,
         s"Expected feature value $feature to be in bucket $expectedBucket with 
splits:" +
           s" ${splits.mkString(", ")}")
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/0b076d4c/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 6822594..f219f77 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -17,10 +17,10 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql._
 import org.apache.spark.sql.functions.udf
 
 class QuantileDiscretizerSuite
@@ -76,20 +76,33 @@ class QuantileDiscretizerSuite
     import spark.implicits._
 
     val numBuckets = 3
-    val df = sc.parallelize(Array(1.0, 1.0, 1.0, Double.NaN))
-      .map(Tuple1.apply).toDF("input")
+    val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, 
Double.NaN, Double.NaN)
+    val expectedKeep = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0)
+    val expectedSkip = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0)
+
     val discretizer = new QuantileDiscretizer()
       .setInputCol("input")
       .setOutputCol("result")
       .setNumBuckets(numBuckets)
 
-    // Reserve extra one bucket for NaN
-    val expectedNumBuckets = discretizer.fit(df).getSplits.length - 1
-    val result = discretizer.fit(df).transform(df)
-    val observedNumBuckets = result.select("result").distinct.count
-    assert(observedNumBuckets == expectedNumBuckets,
-      s"Observed number of buckets are not correct." +
-        s" Expected $expectedNumBuckets but found $observedNumBuckets")
+    withClue("QuantileDiscretizer with handleInvalid=error should throw 
exception for NaN values") {
+      val dataFrame: DataFrame = validData.toSeq.toDF("input")
+      intercept[SparkException] {
+        discretizer.fit(dataFrame).transform(dataFrame).collect()
+      }
+    }
+
+    List(("keep", expectedKeep), ("skip", expectedSkip)).foreach{
+      case(u, v) =>
+        discretizer.setHandleInvalid(u)
+        val dataFrame: DataFrame = validData.zip(v).toSeq.toDF("input", 
"expected")
+        val result = discretizer.fit(dataFrame).transform(dataFrame)
+        result.select("result", "expected").collect().foreach {
+          case Row(x: Double, y: Double) =>
+            assert(x === y,
+              s"The feature value is not correct after bucketing.  Expected $y 
but found $x")
+        }
+    }
   }
 
   test("Test transform method on unseen data") {

http://git-wip-us.apache.org/repos/asf/spark/blob/0b076d4c/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 7683360..94afe82 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -1155,11 +1155,6 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, 
HasOutputCol, JavaMLReadab
 
     `QuantileDiscretizer` takes a column with continuous features and outputs 
a column with binned
     categorical features. The number of bins can be set using the 
:py:attr:`numBuckets` parameter.
-    It is possible that the number of buckets used will be less than this 
value, for example, if
-    there are too few distinct values of the input to create enough distinct 
quantiles. Note also
-    that NaN values are handled specially and placed into their own bucket. 
For example, if 4
-    buckets are used, then non-NaN data will be put into buckets(0-3), but 
NaNs will be counted in
-    a special bucket(4).
     The bin ranges are chosen using an approximate algorithm (see the 
documentation for
     :py:meth:`~.DataFrameStatFunctions.approxQuantile` for a detailed 
description).
     The precision of the approximation can be controlled with the

http://git-wip-us.apache.org/repos/asf/spark/blob/0b076d4c/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 73026c7..1383208 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -150,6 +150,10 @@ class DataFrameStatSuite extends QueryTest with 
SharedSQLContext {
       assert(math.abs(d1 - 2 * q1 * n) < error_double)
       assert(math.abs(d2 - 2 * q2 * n) < error_double)
     }
+    // test approxQuantile on NaN values
+    val dfNaN = Seq(Double.NaN, 1.0, Double.NaN, Double.NaN).toDF("input")
+    val resNaN = dfNaN.stat.approxQuantile("input", Array(q1, q2), 
epsilons.head)
+    assert(resNaN.count(_.isNaN) === 0)
   }
 
   test("crosstab") {

