This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new e38310c74e6 Revert "[SPARK-43393][SQL] Address sequence expression overflow bug"
e38310c74e6 is described below

commit e38310c74e6cae8c8c8489ffcbceb80ed37a7cae
Author: Dongjoon Hyun <dh...@apple.com>
AuthorDate: Wed Nov 15 09:12:42 2023 -0800

    Revert "[SPARK-43393][SQL] Address sequence expression overflow bug"
    
    This reverts commit 41a7a4a3233772003aef380428acd9eaf39b9a93.
---
 .../expressions/collectionOperations.scala         | 48 ++++++-------------
 .../expressions/CollectionExpressionsSuite.scala   | 56 ++--------------------
 2 files changed, 20 insertions(+), 84 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index c3c235fba67..ade4a6c5be7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -22,8 +22,6 @@ import java.util.Comparator
 import scala.collection.mutable
 import scala.reflect.ClassTag
 
-import org.apache.spark.QueryContext
-import org.apache.spark.SparkException.internalError
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -42,6 +40,7 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SQLOpenHashSet
 import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
 import org.apache.spark.unsafe.types.{ByteArray, CalendarInterval, UTF8String}
 
 /**
@@ -3081,34 +3080,6 @@ case class Sequence(
 }
 
 object Sequence {
-  private def prettyName: String = "sequence"
-
-  def sequenceLength(start: Long, stop: Long, step: Long): Int = {
-    try {
-      val delta = Math.subtractExact(stop, start)
-      if (delta == Long.MinValue && step == -1L) {
-        // We must special-case division of Long.MinValue by -1 to catch potential unchecked
-        // overflow in next operation. Division does not have a builtin overflow check. We
-        // previously special-case div-by-zero.
-        throw new ArithmeticException("Long overflow (Long.MinValue / -1)")
-      }
-      val len = if (stop == start) 1L else Math.addExact(1L, (delta / step))
-      if (len > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-        throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(prettyName, len)
-      }
-      len.toInt
-    } catch {
-      // We handle overflows in the previous try block by raising an appropriate exception.
-      case _: ArithmeticException =>
-        val safeLen =
-          BigInt(1) + (BigInt(stop) - BigInt(start)) / BigInt(step)
-        if (safeLen > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-          throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(prettyName, safeLen)
-        }
-        throw internalError("Unreachable code reached.")
-      case e: Exception => throw e
-    }
-  }
 
   private type LessThanOrEqualFn = (Any, Any) => Boolean
 
@@ -3480,7 +3451,13 @@ object Sequence {
         || (estimatedStep == num.zero && start == stop),
       s"Illegal sequence boundaries: $start to $stop by $step")
 
-    sequenceLength(start.toLong, stop.toLong, estimatedStep.toLong)
+    val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / estimatedStep.toLong
+
+    require(
+      len <= MAX_ROUNDED_ARRAY_LENGTH,
+      s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
+
+    len.toInt
   }
 
   private def genSequenceLengthCode(
@@ -3490,7 +3467,7 @@ object Sequence {
       step: String,
       estimatedStep: String,
       len: String): String = {
-    val calcFn = classOf[Sequence].getName + ".sequenceLength"
+    val longLen = ctx.freshName("longLen")
     s"""
        |if (!(($estimatedStep > 0 && $start <= $stop) ||
        |  ($estimatedStep < 0 && $start >= $stop) ||
@@ -3498,7 +3475,12 @@ object Sequence {
        |  throw new IllegalArgumentException(
        |    "Illegal sequence boundaries: " + $start + " to " + $stop + " by " 
+ $step);
        |}
-       |int $len = $calcFn((long) $start, (long) $stop, (long) $estimatedStep);
+       |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $estimatedStep;
+       |if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) {
+       |  throw new IllegalArgumentException(
+       |    "Too long sequence: " + $longLen + ". Should be <= 
$MAX_ROUNDED_ARRAY_LENGTH");
+       |}
+       |int $len = (int) $longLen;
        """.stripMargin
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index d001006c58c..1787f6ac72d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{outstandingZoneIds,
 import org.apache.spark.sql.catalyst.util.IntervalUtils._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
 import org.apache.spark.unsafe.types.UTF8String
 
 class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -769,6 +769,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
 
     // test sequence boundaries checking
 
+    checkExceptionInExpression[IllegalArgumentException](
+      new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)),
+      EmptyRow, s"Too long sequence: 4294967296. Should be <= 
$MAX_ROUNDED_ARRAY_LENGTH")
+
     checkExceptionInExpression[IllegalArgumentException](
       new Sequence(Literal(1), Literal(2), Literal(0)), EmptyRow, "boundaries: 1 to 2 by 0")
     checkExceptionInExpression[IllegalArgumentException](
@@ -778,56 +782,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkExceptionInExpression[IllegalArgumentException](
       new Sequence(Literal(1), Literal(2), Literal(-1)), EmptyRow, "boundaries: 1 to 2 by -1")
 
-    // SPARK-43393: test Sequence overflow checking
-    checkErrorInExpression[SparkRuntimeException](
-      new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)),
-      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
-      parameters = Map(
-        "numberOfElements" -> (BigInt(Int.MaxValue) - BigInt { Int.MinValue } + 1).toString,
-        "functionName" -> toSQLId("sequence"),
-        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
-        "parameter" -> toSQLId("count")))
-    checkErrorInExpression[SparkRuntimeException](
-      new Sequence(Literal(0L), Literal(Long.MaxValue), Literal(1L)),
-      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
-      parameters = Map(
-        "numberOfElements" -> (BigInt(Long.MaxValue) + 1).toString,
-        "functionName" -> toSQLId("sequence"),
-        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
-        "parameter" -> toSQLId("count")))
-    checkErrorInExpression[SparkRuntimeException](
-      new Sequence(Literal(0L), Literal(Long.MinValue), Literal(-1L)),
-      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
-      parameters = Map(
-        "numberOfElements" -> ((0 - BigInt(Long.MinValue)) + 1).toString(),
-        "functionName" -> toSQLId("sequence"),
-        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
-        "parameter" -> toSQLId("count")))
-    checkErrorInExpression[SparkRuntimeException](
-      new Sequence(Literal(Long.MinValue), Literal(Long.MaxValue), Literal(1L)),
-      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
-      parameters = Map(
-        "numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString,
-        "functionName" -> toSQLId("sequence"),
-        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
-        "parameter" -> toSQLId("count")))
-    checkErrorInExpression[SparkRuntimeException](
-      new Sequence(Literal(Long.MaxValue), Literal(Long.MinValue), Literal(-1L)),
-      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
-      parameters = Map(
-        "numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString,
-        "functionName" -> toSQLId("sequence"),
-        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
-        "parameter" -> toSQLId("count")))
-    checkErrorInExpression[SparkRuntimeException](
-      new Sequence(Literal(Long.MaxValue), Literal(-1L), Literal(-1L)),
-      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
-      parameters = Map(
-        "numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { -1L } + 1).toString,
-        "functionName" -> toSQLId("sequence"),
-        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
-        "parameter" -> toSQLId("count")))
-
     // test sequence with one element (zero step or equal start and stop)
 
     checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(-1)), Seq(1))
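
For context on what this revert restores: branch-3.5 goes back to computing the sequence length with plain Long arithmetic, which can silently wrap for extreme bounds, while the reverted SPARK-43393 patch used checked arithmetic (Math.subtractExact / Math.addExact). Below is a minimal, standalone Scala sketch of the difference. It is illustrative only: the object, method names, and constant are not part of the commit, and MaxRoundedArrayLength merely mirrors ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.

object SequenceOverflowSketch {
  // Mirrors ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH (Integer.MAX_VALUE - 15).
  val MaxRoundedArrayLength: Int = Int.MaxValue - 15

  // Unchecked length computation, as restored by this revert:
  // (stop - start) silently wraps around for extreme Long bounds.
  def naiveLength(start: Long, stop: Long, step: Long): Long =
    if (start == stop) 1L else 1L + (stop - start) / step

  // Checked variant in the spirit of the reverted patch:
  // subtractExact/addExact throw ArithmeticException instead of wrapping.
  def checkedLength(start: Long, stop: Long, step: Long): Long = {
    val delta = Math.subtractExact(stop, start) // throws on Long overflow
    if (delta == Long.MinValue && step == -1L) {
      // Long.MinValue / -1 is the one Long division that overflows.
      throw new ArithmeticException("Long overflow (Long.MinValue / -1)")
    }
    if (start == stop) 1L else Math.addExact(1L, delta / step)
  }

  def main(args: Array[String]): Unit = {
    // 1 - Long.MinValue wraps, so the unchecked length is a negative
    // nonsense value instead of an error.
    println(naiveLength(Long.MinValue, 1L, 1L))   // -9223372036854775806
    try checkedLength(Long.MinValue, 1L, 1L)
    catch { case e: ArithmeticException => println(s"caught: $e") }

    // The magic number in the restored test: Int.MaxValue - Int.MinValue + 1
    // = 2^32 = 4294967296, which exceeds MaxRoundedArrayLength (2^31 - 16).
    println(BigInt(Int.MaxValue) - BigInt(Int.MinValue) + 1) // 4294967296
  }
}

Note that the restored require(...) check still rejects oversized sequences whenever the intermediate arithmetic does not overflow; what the reverted patch added was overflow-safe computation of the length itself.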


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
