Repository: spark
Updated Branches:
  refs/heads/branch-2.0 11854e5a1 -> 182991edd


[SPARK-16802] [SQL] fix overflow in LongToUnsafeRowMap

## What changes were proposed in this pull request?

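This patch fixes an overflow in LongToUnsafeRowMap when the range of keys is
very wide (the key is much smaller than minKey; for example, key is
Long.MinValue and minKey is > 0). In that case `key - minKey` overflows, so the
index derived from it could pass the old `idx >= 0` bounds check.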

## How was this patch tested?

Added a regression test (which also covers SPARK-16740).

Author: Davies Liu <dav...@databricks.com>

Closes #14464 from davies/fix_overflow.

(cherry picked from commit 9d4e6212fa8d434089d32bff1217f39919abe44d)
Signed-off-by: Davies Liu <davies....@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/182991ed
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/182991ed
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/182991ed

Branch: refs/heads/branch-2.0
Commit: 182991eddebfaeb60fc44ecb8fc457ea6dd6f56a
Parents: 11854e5
Author: Davies Liu <dav...@databricks.com>
Authored: Thu Aug 4 11:20:17 2016 -0700
Committer: Davies Liu <davies....@gmail.com>
Committed: Thu Aug 4 11:20:29 2016 -0700

----------------------------------------------------------------------
 .../sql/execution/joins/HashedRelation.scala    | 16 ++++---
 .../execution/joins/HashedRelationSuite.scala   | 45 ++++++++++++++++++++
 2 files changed, 55 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/182991ed/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index cf4454c..0897573 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -459,9 +459,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: 
TaskMemoryManager, cap
    */
   def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = {
     if (isDense) {
-      val idx = (key - minKey).toInt
-      if (idx >= 0 && key <= maxKey && array(idx) > 0) {
-        return getRow(array(idx), resultRow)
+      if (key >= minKey && key <= maxKey) {
+        val value = array((key - minKey).toInt)
+        if (value > 0) {
+          return getRow(value, resultRow)
+        }
       }
     } else {
       var pos = firstSlot(key)
@@ -497,9 +499,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: 
TaskMemoryManager, cap
    */
   def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
     if (isDense) {
-      val idx = (key - minKey).toInt
-      if (idx >=0 && key <= maxKey && array(idx) > 0) {
-        return valueIter(array(idx), resultRow)
+      if (key >= minKey && key <= maxKey) {
+        val value = array((key - minKey).toInt)
+        if (value > 0) {
+          return valueIter(value, resultRow)
+        }
       }
     } else {
       var pos = firstSlot(key)

http://git-wip-us.apache.org/repos/asf/spark/blob/182991ed/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 40864c8..1196f5e 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -152,6 +152,51 @@ class HashedRelationSuite extends SparkFunSuite with 
SharedSQLContext {
     }
   }
 
+  test("LongToUnsafeRowMap with very wide range") {
+    val taskMemoryManager = new TaskMemoryManager(
+      new StaticMemoryManager(
+        new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+        Long.MaxValue,
+        Long.MaxValue,
+        1),
+      0)
+    val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, 
false)))
+
+    {
+      // SPARK-16740
+      val keys = Seq(0L, Long.MaxValue, Long.MaxValue)
+      val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+      keys.foreach { k =>
+        map.append(k, unsafeProj(InternalRow(k)))
+      }
+      map.optimize()
+      val row = unsafeProj(InternalRow(0L)).copy()
+      keys.foreach { k =>
+        assert(map.getValue(k, row) eq row)
+        assert(row.getLong(0) === k)
+      }
+      map.free()
+    }
+
+
+    {
+      // SPARK-16802
+      val keys = Seq(Long.MaxValue, Long.MaxValue - 10)
+      val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+      keys.foreach { k =>
+        map.append(k, unsafeProj(InternalRow(k)))
+      }
+      map.optimize()
+      val row = unsafeProj(InternalRow(0L)).copy()
+      keys.foreach { k =>
+        assert(map.getValue(k, row) eq row)
+        assert(row.getLong(0) === k)
+      }
+      assert(map.getValue(Long.MinValue, row) eq null)
+      map.free()
+    }
+  }
+
   test("Spark-14521") {
     val ser = new KryoSerializer(
       (new SparkConf).set("spark.kryo.referenceTracking", 
"false")).newInstance()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to