This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 41a7a4a32337 [SPARK-43393][SQL] Address sequence expression overflow bug
41a7a4a32337 is described below

commit 41a7a4a3233772003aef380428acd9eaf39b9a93
Author: Deepayan Patra <deepayan.pa...@databricks.com>
AuthorDate: Wed Nov 15 14:27:34 2023 +0800

    [SPARK-43393][SQL] Address sequence expression overflow bug
    
    Spark has a (long-standing) overflow bug in the `sequence` expression.
    
    Consider the following operations:
    ```
    spark.sql("CREATE TABLE foo (l LONG);")
    spark.sql(s"INSERT INTO foo VALUES (${Long.MaxValue});")
    spark.sql("SELECT sequence(0, l) FROM foo;").collect()
    ```
    
    The result of these operations will be:
    ```
    Array[org.apache.spark.sql.Row] = Array([WrappedArray()])
    ```
    an unintended consequence of overflow.
    
    The sequence is applied to the values `0` and `Long.MaxValue` with a step size of `1`,
    using the length computation defined
    [here](https://github.com/apache/spark/blob/16411188c7ba6cb19c46a2bd512b2485a4c03e2c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L3451).
    In this calculation, with `start = 0`, `stop = Long.MaxValue`, and `step = 1`,
    the calculated `len` overflows to `Long.MinValue`. The computation, in binary, looks like:
    
    ```
      0111111111111111111111111111111111111111111111111111111111111111
    - 0000000000000000000000000000000000000000000000000000000000000000
    ------------------------------------------------------------------
      0111111111111111111111111111111111111111111111111111111111111111
    / 0000000000000000000000000000000000000000000000000000000000000001
    ------------------------------------------------------------------
      0111111111111111111111111111111111111111111111111111111111111111
    + 0000000000000000000000000000000000000000000000000000000000000001
    ------------------------------------------------------------------
      1000000000000000000000000000000000000000000000000000000000000000
    ```
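
    For illustration, here is a minimal Scala sketch (not part of the patch) that reproduces
    the length computation with these example values; the variable names simply mirror the
    formula linked above:

    ```scala
    // Plain Long arithmetic wraps silently on overflow.
    val start = 0L
    val stop = Long.MaxValue
    val step = 1L

    // Mirrors: len = 1 + (stop - start) / step
    val len = if (stop == start) 1L else 1L + (stop - start) / step
    // (stop - start) / step == Long.MaxValue; adding 1 wraps to Long.MinValue,
    // which is the bit pattern shown above.
    ```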
    
    The following
    [check](https://github.com/apache/spark/blob/16411188c7ba6cb19c46a2bd512b2485a4c03e2c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L3454)
    passes because the overflowed `Long.MinValue` is negative and therefore still
    `<= MAX_ROUNDED_ARRAY_LENGTH`. The subsequent cast to `toInt` uses this representation and
    [truncates the upper bits](https://github.com/apache/spark/blob/16411188c7ba6cb19c46a2bd512b2485a4c03e2c/sql/catalyst/src/main/scala/org/apache/spa [...]
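
    For illustration only (not part of the patch): the effect of that truncation can be seen
    directly in Scala. Casting `Long.MinValue` to `Int` keeps only the low 32 bits, all of
    which are zero, so the computed length becomes `0` and an empty array is produced.

    ```scala
    // The overflowed length from the computation above.
    val overflowedLen = Long.MinValue

    // Casting to Int drops the upper 32 bits; for Long.MinValue the only set bit
    // is in the upper half, so the resulting length is 0, hence the empty sequence.
    val len: Int = overflowedLen.toInt
    assert(len == 0)
    ```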
    
    Other overflows are similarly problematic.
    
    This PR addresses the issue by checking numeric operations in the length computation for overflow.
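
    The checks rely on the exact arithmetic in `java.lang.Math` (`subtractExact`, `addExact`),
    which throws `ArithmeticException` instead of wrapping silently. A small illustrative
    sketch, with values matching the example above (not code from the patch):

    ```scala
    import java.lang.Math.{addExact, subtractExact}

    val delta = subtractExact(Long.MaxValue, 0L)  // fits in a Long, no exception

    // With plain `+` this would wrap to Long.MinValue; addExact throws instead,
    // so the overflow can be caught and reported as a proper error.
    try {
      addExact(1L, delta)
    } catch {
      case e: ArithmeticException => println(s"overflow detected: ${e.getMessage}")
    }
    ```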
    
    There is a correctness bug from overflow in the `sequence` expression.
    
    No user-facing change.
    
    Tests added in `CollectionExpressionsSuite.scala`.
    
    Closes #41072 from thepinetree/spark-sequence-overflow.
    
    Authored-by: Deepayan Patra <deepayan.pa...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit afc4c49927cb7f0f2a7f24a42c4fe497796dd9e3)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../expressions/collectionOperations.scala         | 48 +++++++++++++------
 .../expressions/CollectionExpressionsSuite.scala   | 56 ++++++++++++++++++++--
 2 files changed, 84 insertions(+), 20 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index ade4a6c5be72..c3c235fba677 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -22,6 +22,8 @@ import java.util.Comparator
 import scala.collection.mutable
 import scala.reflect.ClassTag
 
+import org.apache.spark.QueryContext
+import org.apache.spark.SparkException.internalError
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -40,7 +42,6 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SQLOpenHashSet
 import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.array.ByteArrayMethods
-import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
 import org.apache.spark.unsafe.types.{ByteArray, CalendarInterval, UTF8String}
 
 /**
@@ -3080,6 +3081,34 @@ case class Sequence(
 }
 
 object Sequence {
+  private def prettyName: String = "sequence"
+
+  def sequenceLength(start: Long, stop: Long, step: Long): Int = {
+    try {
+      val delta = Math.subtractExact(stop, start)
+      if (delta == Long.MinValue && step == -1L) {
+        // We must special-case division of Long.MinValue by -1 to catch potential unchecked
+        // overflow in next operation. Division does not have a builtin overflow check. We
+        // previously special-case div-by-zero.
+        throw new ArithmeticException("Long overflow (Long.MinValue / -1)")
+      }
+      val len = if (stop == start) 1L else Math.addExact(1L, (delta / step))
+      if (len > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+        throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(prettyName, len)
+      }
+      len.toInt
+    } catch {
+      // We handle overflows in the previous try block by raising an appropriate exception.
+      case _: ArithmeticException =>
+        val safeLen =
+          BigInt(1) + (BigInt(stop) - BigInt(start)) / BigInt(step)
+        if (safeLen > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+          throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(prettyName, safeLen)
+        }
+        throw internalError("Unreachable code reached.")
+      case e: Exception => throw e
+    }
+  }
 
   private type LessThanOrEqualFn = (Any, Any) => Boolean
 
@@ -3451,13 +3480,7 @@ object Sequence {
         || (estimatedStep == num.zero && start == stop),
       s"Illegal sequence boundaries: $start to $stop by $step")
 
-    val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / estimatedStep.toLong
-
-    require(
-      len <= MAX_ROUNDED_ARRAY_LENGTH,
-      s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
-
-    len.toInt
+    sequenceLength(start.toLong, stop.toLong, estimatedStep.toLong)
   }
 
   private def genSequenceLengthCode(
@@ -3467,7 +3490,7 @@ object Sequence {
       step: String,
       estimatedStep: String,
       len: String): String = {
-    val longLen = ctx.freshName("longLen")
+    val calcFn = classOf[Sequence].getName + ".sequenceLength"
     s"""
        |if (!(($estimatedStep > 0 && $start <= $stop) ||
        |  ($estimatedStep < 0 && $start >= $stop) ||
@@ -3475,12 +3498,7 @@ object Sequence {
        |  throw new IllegalArgumentException(
        |    "Illegal sequence boundaries: " + $start + " to " + $stop + " by " 
+ $step);
        |}
-       |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $estimatedStep;
-       |if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) {
-       |  throw new IllegalArgumentException(
-       |    "Too long sequence: " + $longLen + ". Should be <= 
$MAX_ROUNDED_ARRAY_LENGTH");
-       |}
-       |int $len = (int) $longLen;
+       |int $len = $calcFn((long) $start, (long) $stop, (long) $estimatedStep);
        """.stripMargin
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 1787f6ac72dd..d001006c58cf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{outstandingZoneIds,
 import org.apache.spark.sql.catalyst.util.IntervalUtils._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types.UTF8String
 
 class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -769,10 +769,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
 
     // test sequence boundaries checking
 
-    checkExceptionInExpression[IllegalArgumentException](
-      new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)),
-      EmptyRow, s"Too long sequence: 4294967296. Should be <= 
$MAX_ROUNDED_ARRAY_LENGTH")
-
     checkExceptionInExpression[IllegalArgumentException](
      new Sequence(Literal(1), Literal(2), Literal(0)), EmptyRow, "boundaries: 1 to 2 by 0")
     checkExceptionInExpression[IllegalArgumentException](
@@ -782,6 +778,56 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkExceptionInExpression[IllegalArgumentException](
       new Sequence(Literal(1), Literal(2), Literal(-1)), EmptyRow, "boundaries: 1 to 2 by -1")
 
+    // SPARK-43393: test Sequence overflow checking
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)),
+      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
+      parameters = Map(
+        "numberOfElements" -> (BigInt(Int.MaxValue) - BigInt { Int.MinValue } 
+ 1).toString,
+        "functionName" -> toSQLId("sequence"),
+        "maxRoundedArrayLength" -> 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
+        "parameter" -> toSQLId("count")))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(0L), Literal(Long.MaxValue), Literal(1L)),
+      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
+      parameters = Map(
+        "numberOfElements" -> (BigInt(Long.MaxValue) + 1).toString,
+        "functionName" -> toSQLId("sequence"),
+        "maxRoundedArrayLength" -> 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
+        "parameter" -> toSQLId("count")))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(0L), Literal(Long.MinValue), Literal(-1L)),
+      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
+      parameters = Map(
+        "numberOfElements" -> ((0 - BigInt(Long.MinValue)) + 1).toString(),
+        "functionName" -> toSQLId("sequence"),
+        "maxRoundedArrayLength" -> 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
+        "parameter" -> toSQLId("count")))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Long.MinValue), Literal(Long.MaxValue), Literal(1L)),
+      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
+      parameters = Map(
+        "numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue 
} + 1).toString,
+        "functionName" -> toSQLId("sequence"),
+        "maxRoundedArrayLength" -> 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
+        "parameter" -> toSQLId("count")))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Long.MaxValue), Literal(Long.MinValue), Literal(-1L)),
+      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
+      parameters = Map(
+        "numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue 
} + 1).toString,
+        "functionName" -> toSQLId("sequence"),
+        "maxRoundedArrayLength" -> 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
+        "parameter" -> toSQLId("count")))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Long.MaxValue), Literal(-1L), Literal(-1L)),
+      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
+      parameters = Map(
+        "numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { -1L } + 
1).toString,
+        "functionName" -> toSQLId("sequence"),
+        "maxRoundedArrayLength" -> 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
+        "parameter" -> toSQLId("count")))
+
     // test sequence with one element (zero step or equal start and stop)
 
     checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(-1)), Seq(1))

