This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch branch-3.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push: new 0060279f733 [SPARK-44871][SQL][3.4] Fix percentile_disc behaviour 0060279f733 is described below commit 0060279f733989b03aca2bbb0624dfc0c3193aae Author: Peter Toth <peter.t...@gmail.com> AuthorDate: Tue Aug 22 19:27:15 2023 +0300 [SPARK-44871][SQL][3.4] Fix percentile_disc behaviour ### What changes were proposed in this pull request? This PR fixes `percentile_disc()` function as currently it returns incorrect results in some cases. E.g.: ``` SELECT percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0, percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1, percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2, percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3, percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4, percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5, percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6, percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7, percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8, percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9, percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10 FROM VALUES (0), (1), (2), (3), (4) AS v(a) ``` currently returns: ``` +---+---+---+---+---+---+---+---+---+---+---+ | p0| p1| p2| p3| p4| p5| p6| p7| p8| p9|p10| +---+---+---+---+---+---+---+---+---+---+---+ |0.0|0.0|0.0|1.0|1.0|2.0|2.0|2.0|3.0|3.0|4.0| +---+---+---+---+---+---+---+---+---+---+---+ ``` but after this PR it returns the correct results: ``` +---+---+---+---+---+---+---+---+---+---+---+ | p0| p1| p2| p3| p4| p5| p6| p7| p8| p9|p10| +---+---+---+---+---+---+---+---+---+---+---+ |0.0|0.0|0.0|1.0|1.0|2.0|2.0|3.0|3.0|4.0|4.0| +---+---+---+---+---+---+---+---+---+---+---+ ``` ### Why are the changes needed? Bugfix. ### Does this PR introduce _any_ user-facing change? Yes, fixes a correctness bug, but the old behaviour can be restored with `spark.sql.legacy.percentileDiscCalculation=true`. ### How was this patch tested? Added new UTs. 
Closes #42610 from peter-toth/SPARK-44871-fix-percentile-disc-behaviour-3.4. Authored-by: Peter Toth <peter.t...@gmail.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- .../expressions/aggregate/percentiles.scala | 39 +++++-- .../org/apache/spark/sql/internal/SQLConf.scala | 10 ++ .../resources/sql-tests/inputs/percentiles.sql | 77 +++++++++++++- .../sql-tests/results/percentiles.sql.out | 116 +++++++++++++++++++++ 4 files changed, 234 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala index 8447a5f9b51..da04c5a1c8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/percentiles.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike, UnaryLike} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.types.TypeCollection.NumericAndAnsiInterval import org.apache.spark.util.collection.OpenHashMap @@ -168,11 +169,8 @@ abstract class PercentileBase val accumulatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) { case ((key1, count1), (key2, count2)) => (key2, count1 + count2) }.tail - val maxPosition = accumulatedCounts.last._2 - 1 - percentages.map { percentile => - getPercentile(accumulatedCounts, maxPosition * percentile) - } + percentages.map(getPercentile(accumulatedCounts, _)) } private def generateOutput(percentiles: Seq[Double]): Any = { @@ -195,8 +193,11 @@ abstract class PercentileBase * This function has been based upon similar function from HIVE * 
`org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`. */ - private def getPercentile( - accumulatedCounts: Seq[(AnyRef, Long)], position: Double): Double = { + protected def getPercentile( + accumulatedCounts: Seq[(AnyRef, Long)], + percentile: Double): Double = { + val position = (accumulatedCounts.last._2 - 1) * percentile + // We may need to do linear interpolation to get the exact percentile val lower = position.floor.toLong val higher = position.ceil.toLong @@ -219,6 +220,7 @@ abstract class PercentileBase } if (discrete) { + // We end up here only if spark.sql.legacy.percentileDiscCalculation=true toDoubleValue(lowerKey) } else { // Linear interpolation to get the exact percentile @@ -388,7 +390,9 @@ case class PercentileDisc( percentageExpression: Expression, reverse: Boolean = false, mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends PercentileBase with BinaryLike[Expression] { + inputAggBufferOffset: Int = 0, + legacyCalculation: Boolean = SQLConf.get.getConf(SQLConf.LEGACY_PERCENTILE_DISC_CALCULATION)) + extends PercentileBase with BinaryLike[Expression] { val frequencyExpression: Expression = Literal(1L) @@ -416,4 +420,25 @@ case class PercentileDisc( child = newLeft, percentageExpression = newRight ) + + override protected def getPercentile( + accumulatedCounts: Seq[(AnyRef, Long)], + percentile: Double): Double = { + if (legacyCalculation) { + super.getPercentile(accumulatedCounts, percentile) + } else { + // `percentile_disc(p)` returns the value with the smallest `cume_dist()` value given that is + // greater than or equal to `p` so `position` here is `p` adjusted by max position. + val position = accumulatedCounts.last._2 * percentile + + val higher = position.ceil.toLong + + // Use binary search to find the higher position. 
+ val countsArray = accumulatedCounts.map(_._2).toArray[Long] + val higherIndex = binarySearchCount(countsArray, 0, accumulatedCounts.size, higher) + val higherKey = accumulatedCounts(higherIndex)._1 + + toDoubleValue(higherKey) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3b18cfee2a0..45d04c9720e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4175,6 +4175,16 @@ object SQLConf { .booleanConf .createWithDefault(true) + val LEGACY_PERCENTILE_DISC_CALCULATION = buildConf("spark.sql.legacy.percentileDiscCalculation") + .internal() + .doc("If true, the old bogus percentile_disc calculation is used. The old calculation " + + "incorrectly mapped the requested percentile to the sorted range of values in some cases " + + "and so returned incorrect results. Also, the new implementation is faster as it doesn't " + + "contain the interpolation logic that the old percentile_cont based one did.") + .version("3.3.4") + .booleanConf + .createWithDefault(false) + /** * Holds information about keys that have been deprecated. 
* diff --git a/sql/core/src/test/resources/sql-tests/inputs/percentiles.sql b/sql/core/src/test/resources/sql-tests/inputs/percentiles.sql index c55c300b5e8..87c5d4be90c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/percentiles.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/percentiles.sql @@ -299,4 +299,79 @@ SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY dt2) FROM intervals GROUP BY k -ORDER BY k; \ No newline at end of file +ORDER BY k; + +-- SPARK-44871: Fix percentile_disc behaviour +SELECT + percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0, + percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1, + percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2, + percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3, + percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4, + percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5, + percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6, + percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7, + percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8, + percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9, + percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10 +FROM VALUES (0) AS v(a); + +SELECT + percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0, + percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1, + percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2, + percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3, + percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4, + percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5, + percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6, + percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7, + percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8, + percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9, + percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10 +FROM VALUES (0), (1) AS v(a); + +SELECT + percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0, + percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1, + percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as 
p2, + percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3, + percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4, + percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5, + percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6, + percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7, + percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8, + percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9, + percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10 +FROM VALUES (0), (1), (2) AS v(a); + +SELECT + percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0, + percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1, + percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2, + percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3, + percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4, + percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5, + percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6, + percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7, + percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8, + percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9, + percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10 +FROM VALUES (0), (1), (2), (3), (4) AS v(a); + +SET spark.sql.legacy.percentileDiscCalculation = true; + +SELECT + percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0, + percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1, + percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2, + percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3, + percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4, + percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5, + percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6, + percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7, + percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8, + percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9, + percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10 +FROM VALUES (0), (1), (2), (3), (4) AS v(a); + +SET spark.sql.legacy.percentileDiscCalculation = false; diff --git 
a/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out b/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out index cd99ded56bf..54d2164621f 100644 --- a/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out @@ -730,3 +730,119 @@ struct<k:int,median(dt2):interval day to second,percentile(dt2, 0.5, 1):interval 2 0 00:22:30.000000000 0 00:22:30.000000000 0 00:22:30.000000000 3 0 01:00:00.000000000 0 01:00:00.000000000 0 01:00:00.000000000 4 NULL NULL NULL + + +-- !query +SELECT + percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0, + percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1, + percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2, + percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3, + percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4, + percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5, + percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6, + percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7, + percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8, + percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9, + percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10 +FROM VALUES (0) AS v(a) +-- !query schema +struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double> +-- !query output +0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + + +-- !query +SELECT + percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0, + percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1, + percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2, + percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3, + percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4, + percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5, + percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6, + percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7, + percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8, + percentile_disc(0.9) WITHIN GROUP 
(ORDER BY a) as p9, + percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10 +FROM VALUES (0), (1) AS v(a) +-- !query schema +struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double> +-- !query output +0.0 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0 + + +-- !query +SELECT + percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0, + percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1, + percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2, + percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3, + percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4, + percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5, + percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6, + percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7, + percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8, + percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9, + percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10 +FROM VALUES (0), (1), (2) AS v(a) +-- !query schema +struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double> +-- !query output +0.0 0.0 0.0 0.0 1.0 1.0 1.0 2.0 2.0 2.0 2.0 + + +-- !query +SELECT + percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0, + percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1, + percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2, + percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3, + percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4, + percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5, + percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6, + percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7, + percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8, + percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9, + percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10 +FROM VALUES (0), (1), (2), (3), (4) AS v(a) +-- !query schema 
+struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double> +-- !query output +0.0 0.0 0.0 1.0 1.0 2.0 2.0 3.0 3.0 4.0 4.0 + + +-- !query +SET spark.sql.legacy.percentileDiscCalculation = true +-- !query schema +struct<key:string,value:string> +-- !query output +spark.sql.legacy.percentileDiscCalculation true + + +-- !query +SELECT + percentile_disc(0.0) WITHIN GROUP (ORDER BY a) as p0, + percentile_disc(0.1) WITHIN GROUP (ORDER BY a) as p1, + percentile_disc(0.2) WITHIN GROUP (ORDER BY a) as p2, + percentile_disc(0.3) WITHIN GROUP (ORDER BY a) as p3, + percentile_disc(0.4) WITHIN GROUP (ORDER BY a) as p4, + percentile_disc(0.5) WITHIN GROUP (ORDER BY a) as p5, + percentile_disc(0.6) WITHIN GROUP (ORDER BY a) as p6, + percentile_disc(0.7) WITHIN GROUP (ORDER BY a) as p7, + percentile_disc(0.8) WITHIN GROUP (ORDER BY a) as p8, + percentile_disc(0.9) WITHIN GROUP (ORDER BY a) as p9, + percentile_disc(1.0) WITHIN GROUP (ORDER BY a) as p10 +FROM VALUES (0), (1), (2), (3), (4) AS v(a) +-- !query schema +struct<p0:double,p1:double,p2:double,p3:double,p4:double,p5:double,p6:double,p7:double,p8:double,p9:double,p10:double> +-- !query output +0.0 0.0 0.0 1.0 1.0 2.0 2.0 2.0 3.0 3.0 4.0 + + +-- !query +SET spark.sql.legacy.percentileDiscCalculation = false +-- !query schema +struct<key:string,value:string> +-- !query output +spark.sql.legacy.percentileDiscCalculation false --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org