This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new e38310c74e6 Revert "[SPARK-43393][SQL] Address sequence expression overflow bug" e38310c74e6 is described below commit e38310c74e6cae8c8c8489ffcbceb80ed37a7cae Author: Dongjoon Hyun <dh...@apple.com> AuthorDate: Wed Nov 15 09:12:42 2023 -0800 Revert "[SPARK-43393][SQL] Address sequence expression overflow bug" This reverts commit 41a7a4a3233772003aef380428acd9eaf39b9a93. --- .../expressions/collectionOperations.scala | 48 ++++++------------- .../expressions/CollectionExpressionsSuite.scala | 56 ++-------------------- 2 files changed, 20 insertions(+), 84 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c3c235fba67..ade4a6c5be7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -22,8 +22,6 @@ import java.util.Comparator import scala.collection.mutable import scala.reflect.ClassTag -import org.apache.spark.QueryContext -import org.apache.spark.SparkException.internalError import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch @@ -42,6 +40,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SQLOpenHashSet import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH import org.apache.spark.unsafe.types.{ByteArray, CalendarInterval, UTF8String} /** @@ -3081,34 +3080,6 @@ case class Sequence( } object Sequence { - private def prettyName: String = "sequence" - - def sequenceLength(start: Long, stop: Long, step: Long): Int = { - try { - val delta = Math.subtractExact(stop, start) - if (delta == Long.MinValue && step == -1L) { - // We must special-case division of Long.MinValue by -1 to catch potential unchecked - // overflow in next operation. Division does not have a builtin overflow check. We - // previously special-case div-by-zero. - throw new ArithmeticException("Long overflow (Long.MinValue / -1)") - } - val len = if (stop == start) 1L else Math.addExact(1L, (delta / step)) - if (len > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(prettyName, len) - } - len.toInt - } catch { - // We handle overflows in the previous try block by raising an appropriate exception. - case _: ArithmeticException => - val safeLen = - BigInt(1) + (BigInt(stop) - BigInt(start)) / BigInt(step) - if (safeLen > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(prettyName, safeLen) - } - throw internalError("Unreachable code reached.") - case e: Exception => throw e - } - } private type LessThanOrEqualFn = (Any, Any) => Boolean @@ -3480,7 +3451,13 @@ object Sequence { || (estimatedStep == num.zero && start == stop), s"Illegal sequence boundaries: $start to $stop by $step") - sequenceLength(start.toLong, stop.toLong, estimatedStep.toLong) + val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / estimatedStep.toLong + + require( + len <= MAX_ROUNDED_ARRAY_LENGTH, + s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH") + + len.toInt } private def genSequenceLengthCode( @@ -3490,7 +3467,7 @@ object Sequence { step: String, estimatedStep: String, len: String): String = { - val calcFn = classOf[Sequence].getName + ".sequenceLength" + val longLen = ctx.freshName("longLen") s""" |if (!(($estimatedStep > 0 && $start <= $stop) || | ($estimatedStep < 0 && $start >= $stop) || @@ -3498,7 +3475,12 @@ object Sequence { | throw new IllegalArgumentException( | "Illegal sequence boundaries: " + $start + " to " + $stop + " by " + $step); |} - |int $len = $calcFn((long) $start, (long) $stop, (long) $estimatedStep); + |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $estimatedStep; + |if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) { + | throw new IllegalArgumentException( + | "Too long sequence: " + $longLen + ". Should be <= $MAX_ROUNDED_ARRAY_LENGTH"); + |} + |int $len = (int) $longLen; """.stripMargin } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index d001006c58c..1787f6ac72d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{outstandingZoneIds, import org.apache.spark.sql.catalyst.util.IntervalUtils._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH import org.apache.spark.unsafe.types.UTF8String class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -769,6 +769,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper // test sequence boundaries checking + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)), + EmptyRow, s"Too long sequence: 4294967296. Should be <= $MAX_ROUNDED_ARRAY_LENGTH") + checkExceptionInExpression[IllegalArgumentException]( new Sequence(Literal(1), Literal(2), Literal(0)), EmptyRow, "boundaries: 1 to 2 by 0") checkExceptionInExpression[IllegalArgumentException]( @@ -778,56 +782,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkExceptionInExpression[IllegalArgumentException]( new Sequence(Literal(1), Literal(2), Literal(-1)), EmptyRow, "boundaries: 1 to 2 by -1") - // SPARK-43393: test Sequence overflow checking - checkErrorInExpression[SparkRuntimeException]( - new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)), - errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", - parameters = Map( - "numberOfElements" -> (BigInt(Int.MaxValue) - BigInt { Int.MinValue } + 1).toString, - "functionName" -> toSQLId("sequence"), - "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(), - "parameter" -> toSQLId("count"))) - checkErrorInExpression[SparkRuntimeException]( - new Sequence(Literal(0L), Literal(Long.MaxValue), Literal(1L)), - errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", - parameters = Map( - "numberOfElements" -> (BigInt(Long.MaxValue) + 1).toString, - "functionName" -> toSQLId("sequence"), - "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(), - "parameter" -> toSQLId("count"))) - checkErrorInExpression[SparkRuntimeException]( - new Sequence(Literal(0L), Literal(Long.MinValue), Literal(-1L)), - errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", - parameters = Map( - "numberOfElements" -> ((0 - BigInt(Long.MinValue)) + 1).toString(), - "functionName" -> toSQLId("sequence"), - "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(), - "parameter" -> toSQLId("count"))) - checkErrorInExpression[SparkRuntimeException]( - new Sequence(Literal(Long.MinValue), Literal(Long.MaxValue), Literal(1L)), - errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", - parameters = Map( - "numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString, - "functionName" -> toSQLId("sequence"), - "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(), - "parameter" -> toSQLId("count"))) - checkErrorInExpression[SparkRuntimeException]( - new Sequence(Literal(Long.MaxValue), Literal(Long.MinValue), Literal(-1L)), - errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", - parameters = Map( - "numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString, - "functionName" -> toSQLId("sequence"), - "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(), - "parameter" -> toSQLId("count"))) - checkErrorInExpression[SparkRuntimeException]( - new Sequence(Literal(Long.MaxValue), Literal(-1L), Literal(-1L)), - errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", - parameters = Map( - "numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { -1L } + 1).toString, - "functionName" -> toSQLId("sequence"), - "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(), - "parameter" -> toSQLId("count"))) - // test sequence with one element (zero step or equal start and stop) checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(-1)), Seq(1)) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org