Repository: spark
Updated Branches:
  refs/heads/branch-2.0 bd609b0b7 -> 2604eadcf


[SPARK-15390] Fix broadcast of tables with more than 100 million rows

## What changes were proposed in this pull request?

When broadcasting a table with more than 100 million rows (which ideally should not happen), 
the computed size of the required memory overflows the 32-bit Int range.

This PR fixes the overflow by performing the memory-size calculation in Long 
instead of Int.

It also adds more checks in broadcast so that oversized inputs fail with clear error messages.

## How was this patch tested?

Added a unit test.

Author: Davies Liu <dav...@databricks.com>

Closes #13182 from davies/fix_broadcast.

(cherry picked from commit 9308bf119204015c8733fab0c2aef70ff2e41d74)
Signed-off-by: Davies Liu <davies....@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2604eadc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2604eadc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2604eadc

Branch: refs/heads/branch-2.0
Commit: 2604eadcfad20bebe6bd73fa8da36cd631e92e55
Parents: bd609b0
Author: Davies Liu <dav...@databricks.com>
Authored: Thu May 19 11:45:18 2016 -0700
Committer: Davies Liu <davies....@gmail.com>
Committed: Thu May 19 11:45:29 2016 -0700

----------------------------------------------------------------------
 .../execution/exchange/BroadcastExchangeExec.scala   | 13 +++++++++++--
 .../spark/sql/execution/joins/HashedRelation.scala   |  5 +++--
 .../sql/execution/joins/HashedRelationSuite.scala    | 15 +++++++++++++++
 3 files changed, 29 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2604eadc/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index b6ecd3c..d3081ba 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.exchange
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration._
 
-import org.apache.spark.broadcast
+import org.apache.spark.{broadcast, SparkException}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
@@ -72,9 +72,18 @@ case class BroadcastExchangeExec(
         val beforeCollect = System.nanoTime()
         // Note that we use .executeCollect() because we don't want to convert 
data to Scala types
         val input: Array[InternalRow] = child.executeCollect()
+        if (input.length >= 512000000) {
+          throw new SparkException(
+            s"Cannot broadcast the table with more than 512 millions rows: 
${input.length} rows")
+        }
         val beforeBuild = System.nanoTime()
         longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
-        longMetric("dataSize") += 
input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+        val dataSize = 
input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+        longMetric("dataSize") += dataSize
+        if (dataSize >= (8L << 30)) {
+          throw new SparkException(
+            s"Cannot broadcast the table that is larger than 8GB: ${dataSize 
>> 30} GB")
+        }
 
         // Construct and broadcast the relation.
         val relation = mode.transform(input)

http://git-wip-us.apache.org/repos/asf/spark/blob/2604eadc/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index cb41457..cd6b97a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -410,9 +410,10 @@ private[execution] final class LongToUnsafeRowMap(val mm: 
TaskMemoryManager, cap
 
   private def init(): Unit = {
     if (mm != null) {
+      require(capacity < 512000000, "Cannot broadcast more than 512 millions 
rows")
       var n = 1
       while (n < capacity) n *= 2
-      ensureAcquireMemory(n * 2 * 8 + (1 << 20))
+      ensureAcquireMemory(n * 2L * 8 + (1 << 20))
       array = new Array[Long](n * 2)
       mask = n * 2 - 2
       page = new Array[Long](1 << 17)  // 1M bytes
@@ -788,7 +789,7 @@ private[joins] object LongHashedRelation {
       sizeEstimate: Int,
       taskMemoryManager: TaskMemoryManager): LongHashedRelation = {
 
-    val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(taskMemoryManager, 
sizeEstimate)
+    val map = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate)
     val keyGenerator = UnsafeProjection.create(key)
 
     // Create a mapping of key -> rows

http://git-wip-us.apache.org/repos/asf/spark/blob/2604eadc/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index b7b08dc..a5b5654 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -212,4 +212,19 @@ class HashedRelationSuite extends SparkFunSuite with 
SharedSQLContext {
     assert(longRelation.estimatedSize > (2L << 30))
     longRelation.close()
   }
+
+  test("build HashedRelation with more than 100 millions rows") {
+    val unsafeProj = UnsafeProjection.create(
+      Seq(BoundReference(0, IntegerType, false),
+        BoundReference(1, StringType, true)))
+    val unsafeRow = unsafeProj(InternalRow(0, UTF8String.fromString(" " * 
100)))
+    val key = Seq(BoundReference(0, IntegerType, false))
+    val rows = (0 until (1 << 10)).iterator.map { i =>
+      unsafeRow.setInt(0, i % 1000000)
+      unsafeRow.setInt(1, i)
+      unsafeRow
+    }
+    val m = LongHashedRelation(rows, key, 100 << 20, mm)
+    m.close()
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to