Repository: spark
Updated Branches:
  refs/heads/master 673a80d22 -> 92c0eaf34


[SPARK-17086][ML] Fix InvalidArgumentException issue in QuantileDiscretizer when some quantiles are duplicated

## What changes were proposed in this pull request?

When QuantileDiscretizer is fit on a numeric column with duplicated elements, the quantiles returned by approxQuantile may contain duplicates. In that case we take only the unique values as the split points passed to Bucketizer.
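The deduplication step can be sketched in plain Scala, without Spark. This is a minimal sketch that assumes the quantiles have already been computed. The object `DistinctSplitsSketch`, the helper `toDistinctSplits`, and the sample quantile values are hypothetical; they only mirror the logic of the QuantileDiscretizer.scala hunk in this commit (outer split points replaced by negative/positive infinity, then `distinct` and `sorted`).

```scala
object DistinctSplitsSketch {
  // Hypothetical helper mirroring the patched fit logic:
  // open up the outer edges, then drop repeated split points.
  def toDistinctSplits(quantiles: Array[Double]): Array[Double] = {
    require(quantiles.length >= 2, "need at least two quantiles")
    val splits = quantiles.clone()
    splits(0) = Double.NegativeInfinity
    splits(splits.length - 1) = Double.PositiveInfinity
    // distinct keeps the first occurrence of each value; sorted guarantees ascending order
    splits.distinct.sorted
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical quantiles for numBuckets = 5 on a column holding only 1.0, 2.0 and 3.0
    val quantiles = Array(1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
    val splits = toDistinctSplits(quantiles)
    val requested = quantiles.length - 1
    val actual = splits.length - 1
    if (actual != requested) {
      println(s"Some quantiles were identical. Bucketing to $actual buckets instead of $requested.")
    }
    println(splits.mkString("[", ", ", "]"))
  }
}
```

With these sample values the sketch keeps four buckets instead of the five requested, which is the "fewer buckets than requested" trade-off the fix accepts in exchange for valid splits.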

## How was this patch tested?

A unit test is added to QuantileDiscretizerSuite.

Previously, QuantileDiscretizer.fit threw an IllegalArgumentException when it called setSplits on a list of splits with duplicated elements, because Bucketizer.setSplits only accepts a numeric vector of two or more unique, strictly increasing cut points. Deduplicating the quantiles satisfies this requirement, although it may produce fewer buckets than requested.
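The split constraint can be illustrated with a hypothetical check. This sketch assumes the requirement is "two or more cut points, strictly increasing", as described above; it is not Bucketizer's actual validation code, and `SplitCheckSketch`, `isValidSplits`, and `numBuckets` are illustrative names only.

```scala
object SplitCheckSketch {
  // Hypothetical stand-in for the split validation: at least two cut points,
  // each strictly greater than the one before it (so duplicates are rejected).
  def isValidSplits(splits: Array[Double]): Boolean =
    splits.length >= 2 && (1 until splits.length).forall(i => splits(i - 1) < splits(i))

  // n + 1 split points define n buckets
  def numBuckets(splits: Array[Double]): Int = splits.length - 1

  def main(args: Array[String]): Unit = {
    // Hypothetical raw splits containing a duplicated quantile (2.0)
    val raw = Array(Double.NegativeInfinity, 1.0, 2.0, 2.0, 3.0, Double.PositiveInfinity)
    val deduped = raw.distinct.sorted
    println(s"raw valid: ${isValidSplits(raw)}, buckets requested: ${numBuckets(raw)}")
    println(s"deduped valid: ${isValidSplits(deduped)}, buckets produced: ${numBuckets(deduped)}")
  }
}
```

The raw splits fail the check because of the repeated 2.0, while the deduplicated splits pass with one bucket fewer.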

Signed-off-by: VinceShieh <vincent.xieintel.com>

Author: VinceShieh <vincent....@intel.com>

Closes #14747 from VinceShieh/SPARK-17086.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/92c0eaf3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/92c0eaf3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/92c0eaf3

Branch: refs/heads/master
Commit: 92c0eaf348b42b3479610da0be761013f9d81c54
Parents: 673a80d
Author: VinceShieh <vincent....@intel.com>
Authored: Wed Aug 24 10:16:58 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Wed Aug 24 10:16:58 2016 +0100

----------------------------------------------------------------------
 .../spark/ml/feature/QuantileDiscretizer.scala   |  7 ++++++-
 .../ml/feature/QuantileDiscretizerSuite.scala    | 19 +++++++++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/92c0eaf3/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 558a7bb..e098008 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -114,7 +114,12 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
     splits(0) = Double.NegativeInfinity
     splits(splits.length - 1) = Double.PositiveInfinity
 
-    val bucketizer = new Bucketizer(uid).setSplits(splits)
+    val distinctSplits = splits.distinct
+    if (splits.length != distinctSplits.length) {
+      log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" +
+        s" buckets as a result.")
+    }
+    val bucketizer = new Bucketizer(uid).setSplits(distinctSplits.sorted)
     copyValues(bucketizer.setParent(this))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/92c0eaf3/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index b73dbd6..18f1e89 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -52,6 +52,25 @@ class QuantileDiscretizerSuite
       "Bucket sizes are not within expected relative error tolerance.")
   }
 
+  test("Test Bucketizer on duplicated splits") {
+    val spark = this.spark
+    import spark.implicits._
+
+    val datasetSize = 12
+    val numBuckets = 5
+    val df = sc.parallelize(Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0))
+      .map(Tuple1.apply).toDF("input")
+    val discretizer = new QuantileDiscretizer()
+      .setInputCol("input")
+      .setOutputCol("result")
+      .setNumBuckets(numBuckets)
+    val result = discretizer.fit(df).transform(df)
+
+    val observedNumBuckets = result.select("result").distinct.count
+    assert(2 <= observedNumBuckets && observedNumBuckets <= numBuckets,
+      "Observed number of buckets are not within expected range.")
+  }
+
   test("Test transform method on unseen data") {
     val spark = this.spark
     import spark.implicits._

