Repository: spark Updated Branches: refs/heads/branch-2.0 bd609b0b7 -> 2604eadcf
[SPARK-15390] fix broadcast with 100 millions rows ## What changes were proposed in this pull request? When broadcasting a table with more than 100 million rows (which ideally should not happen), the computed size of the memory needed overflows. This PR fixes the overflow by converting the value to Long when calculating the memory size. It also adds more checks in broadcast so that reasonable error messages are shown. ## How was this patch tested? Added a test. Author: Davies Liu <dav...@databricks.com> Closes #13182 from davies/fix_broadcast. (cherry picked from commit 9308bf119204015c8733fab0c2aef70ff2e41d74) Signed-off-by: Davies Liu <davies....@gmail.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2604eadc Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2604eadc Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2604eadc Branch: refs/heads/branch-2.0 Commit: 2604eadcfad20bebe6bd73fa8da36cd631e92e55 Parents: bd609b0 Author: Davies Liu <dav...@databricks.com> Authored: Thu May 19 11:45:18 2016 -0700 Committer: Davies Liu <davies....@gmail.com> Committed: Thu May 19 11:45:29 2016 -0700 ---------------------------------------------------------------------- .../execution/exchange/BroadcastExchangeExec.scala | 13 +++++++++++-- .../spark/sql/execution/joins/HashedRelation.scala | 5 +++-- .../sql/execution/joins/HashedRelationSuite.scala | 15 +++++++++++++++ 3 files changed, 29 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/2604eadc/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index b6ecd3c..d3081ba 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.exchange import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ -import org.apache.spark.broadcast +import org.apache.spark.{broadcast, SparkException} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -72,9 +72,18 @@ case class BroadcastExchangeExec( val beforeCollect = System.nanoTime() // Note that we use .executeCollect() because we don't want to convert data to Scala types val input: Array[InternalRow] = child.executeCollect() + if (input.length >= 512000000) { + throw new SparkException( + s"Cannot broadcast the table with more than 512 millions rows: ${input.length} rows") + } val beforeBuild = System.nanoTime() longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000 - longMetric("dataSize") += input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum + val dataSize = input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum + longMetric("dataSize") += dataSize + if (dataSize >= (8L << 30)) { + throw new SparkException( + s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB") + } // Construct and broadcast the relation. 
val relation = mode.transform(input) http://git-wip-us.apache.org/repos/asf/spark/blob/2604eadc/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index cb41457..cd6b97a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -410,9 +410,10 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap private def init(): Unit = { if (mm != null) { + require(capacity < 512000000, "Cannot broadcast more than 512 millions rows") var n = 1 while (n < capacity) n *= 2 - ensureAcquireMemory(n * 2 * 8 + (1 << 20)) + ensureAcquireMemory(n * 2L * 8 + (1 << 20)) array = new Array[Long](n * 2) mask = n * 2 - 2 page = new Array[Long](1 << 17) // 1M bytes @@ -788,7 +789,7 @@ private[joins] object LongHashedRelation { sizeEstimate: Int, taskMemoryManager: TaskMemoryManager): LongHashedRelation = { - val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate) + val map = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate) val keyGenerator = UnsafeProjection.create(key) // Create a mapping of key -> rows http://git-wip-us.apache.org/repos/asf/spark/blob/2604eadc/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index b7b08dc..a5b5654 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -212,4 +212,19 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { assert(longRelation.estimatedSize > (2L << 30)) longRelation.close() } + + test("build HashedRelation with more than 100 millions rows") { + val unsafeProj = UnsafeProjection.create( + Seq(BoundReference(0, IntegerType, false), + BoundReference(1, StringType, true))) + val unsafeRow = unsafeProj(InternalRow(0, UTF8String.fromString(" " * 100))) + val key = Seq(BoundReference(0, IntegerType, false)) + val rows = (0 until (1 << 10)).iterator.map { i => + unsafeRow.setInt(0, i % 1000000) + unsafeRow.setInt(1, i) + unsafeRow + } + val m = LongHashedRelation(rows, key, 100 << 20, mm) + m.close() + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org