Repository: spark Updated Branches: refs/heads/master 673a80d22 -> 92c0eaf34
[SPARK-17086][ML] Fix InvalidArgumentException issue in QuantileDiscretizer when some quantiles are duplicated ## What changes were proposed in this pull request? In cases when QuantileDiscretizerSuite is called upon a numeric array with duplicated elements, we will take the unique elements generated from approxQuantiles as input for Bucketizer. ## How was this patch tested? An unit test is added in QuantileDiscretizerSuite QuantileDiscretizer.fit will throw an illegal exception when calling setSplits on a list of splits with duplicated elements. Bucketizer.setSplits should only accept either a numeric vector of two or more unique cut points, although that may produce less number of buckets than requested. Signed-off-by: VinceShieh <vincent.xieintel.com> Author: VinceShieh <vincent....@intel.com> Closes #14747 from VinceShieh/SPARK-17086. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/92c0eaf3 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/92c0eaf3 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/92c0eaf3 Branch: refs/heads/master Commit: 92c0eaf348b42b3479610da0be761013f9d81c54 Parents: 673a80d Author: VinceShieh <vincent....@intel.com> Authored: Wed Aug 24 10:16:58 2016 +0100 Committer: Sean Owen <so...@cloudera.com> Committed: Wed Aug 24 10:16:58 2016 +0100 ---------------------------------------------------------------------- .../spark/ml/feature/QuantileDiscretizer.scala | 7 ++++++- .../ml/feature/QuantileDiscretizerSuite.scala | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/92c0eaf3/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 558a7bb..e098008 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -114,7 +114,12 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui splits(0) = Double.NegativeInfinity splits(splits.length - 1) = Double.PositiveInfinity - val bucketizer = new Bucketizer(uid).setSplits(splits) + val distinctSplits = splits.distinct + if (splits.length != distinctSplits.length) { + log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" + + s" buckets as a result.") + } + val bucketizer = new Bucketizer(uid).setSplits(distinctSplits.sorted) copyValues(bucketizer.setParent(this)) } http://git-wip-us.apache.org/repos/asf/spark/blob/92c0eaf3/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index b73dbd6..18f1e89 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -52,6 +52,25 @@ class QuantileDiscretizerSuite "Bucket sizes are not within expected relative error tolerance.") } + test("Test Bucketizer on duplicated splits") { + val spark = this.spark + import spark.implicits._ + + val datasetSize = 12 + val numBuckets = 5 + val df = sc.parallelize(Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0)) + .map(Tuple1.apply).toDF("input") + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(numBuckets) + val result = discretizer.fit(df).transform(df) + + val observedNumBuckets = result.select("result").distinct.count + assert(2 <= observedNumBuckets && observedNumBuckets <= numBuckets, + "Observed number of buckets are not within expected range.") + } + test("Test transform method on unseen data") { val spark = this.spark import spark.implicits._ --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org