This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new a697725d99a [SPARK-46274][SQL] Fix Range operator computeStats() to check long validity before converting a697725d99a is described below commit a697725d99a0177a2b1fbb0607e859ac10af1c4e Author: Nick Young <nick.yo...@databricks.com> AuthorDate: Wed Dec 6 15:20:19 2023 -0800 [SPARK-46274][SQL] Fix Range operator computeStats() to check long validity before converting ### What changes were proposed in this pull request? The Range operator's `computeStats()` function unsafely converts its `BigInt` element count to `Long`, which causes failures downstream in statistics estimation. This change adds a bounds check (`isValidLong`) before the conversion; when the count does not fit in a `Long`, the column's min/max statistics are left unset instead of crashing. ### Why are the changes needed? Without the bounds check, downstream statistics estimation crashes (fails loudly) for ranges whose element count exceeds the `Long` range; fixing it avoids the crash and keeps the code clean. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? UT: added `range with invalid long value` to `BasicStatsEstimationSuite`. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #44191 from n-young-db/range-compute-stats. 
Authored-by: Nick Young <nick.yo...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit 9fd575ae46f8a4dbd7da18887a44c693d8788332) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../catalyst/plans/logical/basicLogicalOperators.scala | 12 +++++++----- .../statsEstimation/BasicStatsEstimationSuite.scala | 16 ++++++++++++++++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index b4d7716a566..58c03ee72d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -1063,10 +1063,12 @@ case class Range( if (numElements == 0) { Statistics(sizeInBytes = 0, rowCount = Some(0)) } else { - val (minVal, maxVal) = if (step > 0) { - (start, start + (numElements - 1) * step) + val (minVal, maxVal) = if (!numElements.isValidLong) { + (None, None) + } else if (step > 0) { + (Some(start), Some(start + (numElements.toLong - 1) * step)) } else { - (start + (numElements - 1) * step, start) + (Some(start + (numElements.toLong - 1) * step), Some(start)) } val histogram = if (conf.histogramEnabled) { @@ -1077,8 +1079,8 @@ case class Range( val colStat = ColumnStat( distinctCount = Some(numElements), - max = Some(maxVal), - min = Some(minVal), + max = maxVal, + min = minVal, nullCount = Some(0), avgLen = Some(LongType.defaultSize), maxLen = Some(LongType.defaultSize), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 33e521eb65a..d1276615c5f 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -176,6 +176,22 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { expectedStatsCboOff = rangeStats, extraConfig) } +test("range with invalid long value") { + val numElements = BigInt(Long.MaxValue) - BigInt(Long.MinValue) + val range = Range(Long.MinValue, Long.MaxValue, 1, None) + val rangeAttrs = AttributeMap(range.output.map(attr => + (attr, ColumnStat( + distinctCount = Some(numElements), + nullCount = Some(0), + maxLen = Some(LongType.defaultSize), + avgLen = Some(LongType.defaultSize))))) + val rangeStats = Statistics( + sizeInBytes = numElements * 8, + rowCount = Some(numElements), + attributeStats = rangeAttrs) + checkStats(range, rangeStats, rangeStats) +} + test("windows") { val windows = plan.window(Seq(min(attribute).as("sum_attr")), Seq(attribute), Nil) val windowsStats = Statistics(sizeInBytes = plan.size.get * (4 + 4 + 8) / (4 + 8)) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org