Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/19715#discussion_r150401807 --- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala --- @@ -146,4 +146,172 @@ class QuantileDiscretizerSuite val model = discretizer.fit(df) assert(model.hasParent) } + + test("Multiple Columns: Test observed number of buckets and their sizes match expected values") { + val spark = this.spark + import spark.implicits._ + + val datasetSize = 100000 + val numBuckets = 5 + val data1 = Array.range(1, 100001, 1).map(_.toDouble) + val data2 = Array.range(1, 200000, 2).map(_.toDouble) + val data = (0 until 100000).map { idx => + (data1(idx), data2(idx)) + } + val df: DataFrame = data.toSeq.toDF("input1", "input2") + + val discretizer = new QuantileDiscretizer() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("result1", "result2")) + .setNumBuckets(numBuckets) + assert(discretizer.isQuantileDiscretizeMultipleColumns()) + val result = discretizer.fit(df).transform(df) + + val relativeError = discretizer.getRelativeError + val isGoodBucket = udf { + (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize) + } + + for (i <- 1 to 2) { + val observedNumBuckets = result.select("result" + i).distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") + + val numGoodBuckets = result.groupBy("result" + i).count.filter(isGoodBucket($"count")).count + assert(numGoodBuckets === numBuckets, + "Bucket sizes are not within expected relative error tolerance.") + } + } + + test("Multiple Columns: Test on data with high proportion of duplicated values") { + val spark = this.spark + import spark.implicits._ + + val numBuckets = 5 + val expectedNumBucket = 3 + val data1 = Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0) + val data2 = Array(1.0, 2.0, 3.0, 1.0, 1.0, 1.0, 1.0, 3.0, 2.0, 3.0, 
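1.0, 2.0) + val data = (0 until data1.length).map { idx => + (data1(idx), data2(idx)) + } + val df: DataFrame = data.toSeq.toDF("input1", "input2") --- End diff -- nit: Remove the explicit `DataFrame` type annotation on `val df`; `toDF` already returns a `DataFrame`, so the type is inferred.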
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org