Repository: spark
Updated Branches:
  refs/heads/branch-1.5 e7329ab31 -> 29756ff11


http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 9201d1e..450ab7b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -57,8 +57,9 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     }
     val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
     val func = (c: TaskContext, i: Iterator[String]) => i.next()
-    val task = new ResultTask[String, String](0, 0,
-      sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
+    val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array)
+    val task = new ResultTask[String, String](
+      0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty)
     intercept[RuntimeException] {
       task.run(0, 0, null)
     }
@@ -66,7 +67,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
   }
 
   test("all TaskCompletionListeners should be called even if some fail") {
-    val context = new TaskContextImpl(0, 0, 0, 0, null, null)
+    val context = TaskContext.empty()
     val listener = mock(classOf[TaskCompletionListener])
     context.addTaskCompletionListener(_ => throw new Exception("blah"))
     context.addTaskCompletionListener(listener)
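
Both changes in this hunk move the test toward the new internal accumulator
plumbing: the task constructor gains an extra internal-accumulators argument
(passed as Seq.empty here), and tests obtain a fully wired context from
TaskContext.empty() instead of hand-building a TaskContextImpl with null
fields. A minimal sketch of the test-side usage, assuming code that compiles
inside the org.apache.spark package (the empty() factory is not public API):

    import org.apache.spark.TaskContext

    // Replaces: new TaskContextImpl(0, 0, 0, 0, null, null)
    val context = TaskContext.empty()
    context.addTaskCompletionListener(_ => println("task completed"))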

http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 3abb99c..f7cc4bb 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -136,7 +136,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
 /**
  * A Task implementation that results in a large serialized task.
  */
-class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) {
+class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) {
   val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024)
   val random = new Random(0)
   random.nextBytes(randomBuffer)

http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
index db718ec..05b3afe 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
@@ -138,7 +138,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext {
       shuffleHandle,
       reduceId,
       reduceId + 1,
-      new TaskContextImpl(0, 0, 0, 0, null, null),
+      TaskContext.empty(),
       blockManager,
       mapOutputTracker)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index cf8bd8a..828153b 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -29,7 +29,7 @@ import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
 import org.scalatest.PrivateMethodTester
 
-import org.apache.spark.{SparkFunSuite, TaskContextImpl}
+import org.apache.spark.{SparkFunSuite, TaskContext}
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.shuffle.BlockFetchingListener
@@ -95,7 +95,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     )
 
     val iterator = new ShuffleBlockFetcherIterator(
-      new TaskContextImpl(0, 0, 0, 0, null, null),
+      TaskContext.empty(),
       transfer,
       blockManager,
       blocksByAddress,
@@ -165,7 +165,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
       (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
 
-    val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null)
+    val taskContext = TaskContext.empty()
     val iterator = new ShuffleBlockFetcherIterator(
       taskContext,
       transfer,
@@ -227,7 +227,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
       (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
 
-    val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null)
+    val taskContext = TaskContext.empty()
     val iterator = new ShuffleBlockFetcherIterator(
       taskContext,
       transfer,

http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
new file mode 100644
index 0000000..98f9314
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.xml.Node
+
+import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}
+
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite, Success}
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.scheduler._
+import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab}
+import org.apache.spark.ui.scope.RDDOperationGraphListener
+
+class StagePageSuite extends SparkFunSuite with LocalSparkContext {
+
+  test("peak execution memory only displayed if unsafe is enabled") {
+    val unsafeConf = "spark.sql.unsafe.enabled"
+    val conf = new SparkConf().set(unsafeConf, "true")
+    val html = renderStagePage(conf).toString().toLowerCase
+    val targetString = "peak execution memory"
+    assert(html.contains(targetString))
+    // Disable unsafe and make sure it's not there
+    val conf2 = new SparkConf().set(unsafeConf, "false")
+    val html2 = renderStagePage(conf2).toString().toLowerCase
+    assert(!html2.contains(targetString))
+  }
+
+  /**
+   * Render a stage page started with the given conf and return the HTML.
+   * This also runs a dummy stage to populate the page with useful content.
+   */
+  private def renderStagePage(conf: SparkConf): Seq[Node] = {
+    val jobListener = new JobProgressListener(conf)
+    val graphListener = new RDDOperationGraphListener(conf)
+    val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
+    val request = mock(classOf[HttpServletRequest])
+    when(tab.conf).thenReturn(conf)
+    when(tab.progressListener).thenReturn(jobListener)
+    when(tab.operationGraphListener).thenReturn(graphListener)
+    when(tab.appName).thenReturn("testing")
+    when(tab.headerTabs).thenReturn(Seq.empty)
+    when(request.getParameter("id")).thenReturn("0")
+    when(request.getParameter("attempt")).thenReturn("0")
+    val page = new StagePage(tab)
+
+    // Simulate a stage in job progress listener
+    val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details")
+    val taskInfo = new TaskInfo(0, 0, 0, 0, "0", "localhost", TaskLocality.ANY, false)
+    jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
+    jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
+    taskInfo.markSuccessful()
+    jobListener.onTaskEnd(
+      SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty))
+    jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo))
+    page.render(request)
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 9c362f0..12e9baf 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -399,4 +399,19 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
     sc.stop()
   }
 
+  test("external aggregation updates peak execution memory") {
+    val conf = createSparkConf(loadDefaults = false)
+      .set("spark.shuffle.memoryFraction", "0.001")
+      .set("spark.shuffle.manager", "hash") // make sure we're not also using ExternalSorter
+    sc = new SparkContext("local", "test", conf)
+    // No spilling
+    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map without spilling") {
+      sc.parallelize(1 to 10, 2).map { i => (i, i) }.reduceByKey(_ + _).count()
+    }
+    // With spilling
+    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map with spilling") {
+      sc.parallelize(1 to 1000 * 1000, 2).map { i => (i, i) }.reduceByKey(_ + _).count()
+    }
+  }
+
 }
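
The new test relies on a helper whose usage shape recurs in ExternalSorterSuite,
SQLQuerySuite, and TungstenSortSuite below: run a job inside the block, then
assert that the peak execution memory accumulator was set for that job. A
minimal usage sketch, assuming a running SparkContext named sc and Spark's test
classes (AccumulatorSuite) on the classpath:

    import org.apache.spark.AccumulatorSuite

    // The job inside the block must exercise an operator that reports
    // peak execution memory (here, a shuffle via repartition).
    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "example job") {
      sc.parallelize(1 to 1000, 2).repartition(10).count()
    }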

http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index 986cd86..bdb0f4d 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -692,7 +692,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
     sortWithoutBreakingSortingContracts(createSparkConf(true, false))
   }
 
-  def sortWithoutBreakingSortingContracts(conf: SparkConf) {
+  private def sortWithoutBreakingSortingContracts(conf: SparkConf) {
     conf.set("spark.shuffle.memoryFraction", "0.01")
     conf.set("spark.shuffle.manager", "sort")
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
@@ -743,5 +743,15 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
     }
 
     sorter2.stop()
- }
+  }
+
+  test("sorting updates peak execution memory") {
+    val conf = createSparkConf(loadDefaults = false, kryo = false)
+      .set("spark.shuffle.manager", "sort")
+    sc = new SparkContext("local", "test", conf)
+    // Avoid aggregating here to make sure we're not also using ExternalAppendOnlyMap
+    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sorter") {
+      sc.parallelize(1 to 1000, 2).repartition(100).count()
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 5e4c623..193906d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -106,6 +106,13 @@ final class UnsafeExternalRowSorter {
     sorter.spill();
   }
 
+  /**
+   * Return the peak memory used so far, in bytes.
+   */
+  public long getPeakMemoryUsage() {
+    return sorter.getPeakMemoryUsedBytes();
+  }
+
   private void cleanupResources() {
     sorter.freeMemory();
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 9e2c933..43d06ce 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -209,6 +209,14 @@ public final class UnsafeFixedWidthAggregationMap {
   }
 
   /**
+   * The memory used by this map's managed structures, in bytes.
+   * Note that this is also the peak memory used by this map, since the map is append-only.
+   */
+  public long getMemoryUsage() {
+    return map.getTotalMemoryConsumption();
+  }
+
+  /**
   * Free the memory associated with this map. This is idempotent and can be called multiple times.
    */
   public void free() {

http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index cd87b8d..bf4905d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
 
 import java.io.IOException
 
-import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -263,11 +263,12 @@ case class GeneratedAggregate(
         assert(iter.hasNext, "There should be at least one row for this path")
         log.info("Using Unsafe-based aggregator")
         val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m")
+        val taskContext = TaskContext.get()
         val aggregationMap = new UnsafeFixedWidthAggregationMap(
           newAggregationBuffer(EmptyRow),
           aggregationBufferSchema,
           groupKeySchema,
-          TaskContext.get.taskMemoryManager(),
+          taskContext.taskMemoryManager(),
           SparkEnv.get.shuffleMemoryManager,
           1024 * 16, // initial capacity
           pageSizeBytes,
@@ -284,6 +285,10 @@ case class GeneratedAggregate(
           updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
         }
 
+        // Record memory used in the process
+        taskContext.internalMetricsToAccumulators(
+          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(aggregationMap.getMemoryUsage)
+
         new Iterator[InternalRow] {
           private[this] val mapIterator = aggregationMap.iterator()
           private[this] val resultProjection = resultProjectionBuilder()
@@ -300,7 +305,7 @@ case class GeneratedAggregate(
               } else {
                 // This is the last element in the iterator, so let's free the buffer. Before we do,
                 // though, we need to make a defensive copy of the result so that we don't return an
-                // object that might contain dangling pointers to the freed memory
+                // object that might contain dangling pointers to the freed memory.
                 val resultCopy = result.copy()
                 aggregationMap.free()
                 resultCopy

http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 624efc1..e73e252 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
 import scala.concurrent._
 import scala.concurrent.duration._
 
+import org.apache.spark.{InternalAccumulator, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -70,7 +71,14 @@ case class BroadcastHashJoin(
     val broadcastRelation = Await.result(broadcastFuture, timeout)
 
     streamedPlan.execute().mapPartitions { streamedIter =>
-      hashJoin(streamedIter, broadcastRelation.value)
+      val hashedRelation = broadcastRelation.value
+      hashedRelation match {
+        case unsafe: UnsafeHashedRelation =>
+          TaskContext.get().internalMetricsToAccumulators(
+            InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+        case _ =>
+      }
+      hashJoin(streamedIter, hashedRelation)
     }
   }
 }
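
The same guard reappears in BroadcastHashOuterJoin and BroadcastLeftSemiJoinHash
below: only an UnsafeHashedRelation knows its executor-side size (see
getUnsafeSize in the HashedRelation.scala hunk further down), so other relations
contribute nothing. A sketch of the shared shape, assuming it lives in the
org.apache.spark.sql.execution.joins package so the package-private relation
types are visible (BroadcastSizeMetric is a hypothetical helper name, not part
of the commit):

    package org.apache.spark.sql.execution.joins

    import org.apache.spark.{InternalAccumulator, TaskContext}

    object BroadcastSizeMetric {
      // Record the executor-side memory of a broadcast hash relation, but
      // only when it is backed by the unsafe binary map.
      def record(relation: HashedRelation): Unit = relation match {
        case unsafe: UnsafeHashedRelation =>
          TaskContext.get().internalMetricsToAccumulators(
            InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
        case _ => // Java-hash-map backed relations do not report a size here
      }
    }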

http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
index 309716a..c35e439 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
 import scala.concurrent._
 import scala.concurrent.duration._
 
+import org.apache.spark.{InternalAccumulator, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -75,6 +76,13 @@ case class BroadcastHashOuterJoin(
       val hashTable = broadcastRelation.value
       val keyGenerator = streamedKeyGenerator
 
+      hashTable match {
+        case unsafe: UnsafeHashedRelation =>
+          TaskContext.get().internalMetricsToAccumulators(
+            InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+        case _ =>
+      }
+
       joinType match {
         case LeftOuter =>
           streamedIter.flatMap(currentRow => {

http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index a605939..5bd06fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.joins
 
+import org.apache.spark.{InternalAccumulator, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -51,7 +52,14 @@ case class BroadcastLeftSemiJoinHash(
       val broadcastedRelation = sparkContext.broadcast(hashRelation)
 
       left.execute().mapPartitions { streamIter =>
-        hashSemiJoin(streamIter, broadcastedRelation.value)
+        val hashedRelation = broadcastedRelation.value
+        hashedRelation match {
+          case unsafe: UnsafeHashedRelation =>
+            TaskContext.get().internalMetricsToAccumulators(
+              InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+          case _ =>
+        }
+        hashSemiJoin(streamIter, hashedRelation)
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index cc8bbfd..58b4236 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -183,8 +183,27 @@ private[joins] final class UnsafeHashedRelation(
   private[joins] def this() = this(null)  // Needed for serialization
 
   // Use BytesToBytesMap in executor for better performance (it's created when deserialization)
+  // This is used in broadcast joins and distributed mode only
   @transient private[this] var binaryMap: BytesToBytesMap = _
 
+  /**
+   * Return the size of the unsafe map on the executors.
+   *
+   * For broadcast joins, this hashed relation is bigger on the driver because it is
+   * represented as a Java hash map there. While serializing the map to the executors,
+   * however, we rehash the contents in a binary map to reduce the memory footprint on
+   * the executors.
+   *
+   * For non-broadcast joins or in local mode, return 0.
+   */
+  def getUnsafeSize: Long = {
+    if (binaryMap != null) {
+      binaryMap.getTotalMemoryConsumption
+    } else {
+      0
+    }
+  }
+
   override def get(key: InternalRow): Seq[InternalRow] = {
     val unsafeKey = key.asInstanceOf[UnsafeRow]
 
@@ -214,7 +233,7 @@ private[joins] final class UnsafeHashedRelation(
       }
 
     } else {
-      // Use the JavaHashMap in Local mode or ShuffleHashJoin
+      // Use the Java HashMap in local mode or for non-broadcast joins (e.g. ShuffleHashJoin)
       hashTable.get(unsafeKey)
     }
   }
@@ -316,6 +335,7 @@ private[joins] object UnsafeHashedRelation {
       keyGenerator: UnsafeProjection,
       sizeEstimate: Int): HashedRelation = {
 
+    // Use a Java hash table here because unsafe maps expect fixed size records
     val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
 
     // Create a mapping of buildKeys -> rows

http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
index 92cf328..3192b6e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution
 
+import org.apache.spark.{InternalAccumulator, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
@@ -76,6 +77,11 @@ case class ExternalSort(
       val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering))
       sorter.insertAll(iterator.map(r => (r.copy(), null)))
       val baseIterator = sorter.iterator.map(_._1)
+      val context = TaskContext.get()
+      context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+      context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+      context.internalMetricsToAccumulators(
+        InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
       // TODO(marmbrus): The complex type signature below thwarts inference for no reason.
       CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop())
     }, preservesPartitioning = true)
@@ -137,7 +143,11 @@ case class TungstenSort(
       if (testSpillFrequency > 0) {
         sorter.setTestSpillFrequency(testSpillFrequency)
       }
-      sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
+      val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
+      val taskContext = TaskContext.get()
+      taskContext.internalMetricsToAccumulators(
+        InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage)
+      sortedIterator
     }, preservesPartitioning = true)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index f1abae0..29dfcf2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -21,6 +21,7 @@ import java.sql.Timestamp
 
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.DefaultParserDialect
 import org.apache.spark.sql.catalyst.errors.DialectException
@@ -258,6 +259,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
     }
   }
 
+  private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = {
+    val df = sql(sqlText)
+    // First, check if we have GeneratedAggregate.
+    val hasGeneratedAgg = df.queryExecution.executedPlan
+      .collect { case _: GeneratedAggregate | _: aggregate.Aggregate => true }
+      .nonEmpty
+    if (!hasGeneratedAgg) {
+      fail(
+        s"""
+           |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan.
+           |${df.queryExecution.simpleString}
+         """.stripMargin)
+    }
+    // Then, check results.
+    checkAnswer(df, expectedResults)
+  }
+
   test("aggregation with codegen") {
     val originalValue = sqlContext.conf.codegenEnabled
     sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
@@ -267,26 +285,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       .unionAll(sqlContext.table("testData"))
       .registerTempTable("testData3x")
 
-    def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = {
-      val df = sql(sqlText)
-      // First, check if we have GeneratedAggregate.
-      var hasGeneratedAgg = false
-      df.queryExecution.executedPlan.foreach {
-        case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true
-        case newAggregate: aggregate.Aggregate => hasGeneratedAgg = true
-        case _ =>
-      }
-      if (!hasGeneratedAgg) {
-        fail(
-          s"""
-             |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan.
-             |${df.queryExecution.simpleString}
-           """.stripMargin)
-      }
-      // Then, check results.
-      checkAnswer(df, expectedResults)
-    }
-
     try {
       // Just to group rows.
       testCodeGen(
@@ -1605,6 +1603,28 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123))))
   }
 
+  test("aggregation with codegen updates peak execution memory") {
+    withSQLConf(
+        (SQLConf.CODEGEN_ENABLED.key, "true"),
+        (SQLConf.USE_SQL_AGGREGATE2.key, "false")) {
+      val sc = sqlContext.sparkContext
+      AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "aggregation with codegen") {
+        testCodeGen(
+          "SELECT key, count(value) FROM testData GROUP BY key",
+          (1 to 100).map(i => Row(i, 1)))
+      }
+    }
+  }
+
+  test("external sorting updates peak execution memory") {
+    withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) {
+      val sc = sqlContext.sparkContext
+      AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sort") {
+        sortTest()
+      }
+    }
+  }
+
   test("SPARK-9511: error with table starting with number") {
     val df = sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, 
i.toString))
       .toDF("num", "str")

http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
index c794984..88bce0e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
@@ -21,6 +21,7 @@ import scala.util.Random
 
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.test.TestSQLContext
@@ -59,6 +60,17 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
     )
   }
 
+  test("sorting updates peak execution memory") {
+    val sc = TestSQLContext.sparkContext
+    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") {
+      checkThatPlansAgree(
+        (1 to 100).map(v => Tuple1(v)).toDF("a"),
+        (child: SparkPlan) => TungstenSort('a.asc :: Nil, true, child),
+        (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child),
+        sortAnswers = false)
+    }
+  }
+
   // Test sorting on different data types
   for (
     dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType);

http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 7c591f6..ef827b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -69,7 +69,8 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
         taskAttemptId = Random.nextInt(10000),
         attemptNumber = 0,
         taskMemoryManager = taskMemoryManager,
-        metricsSystem = null))
+        metricsSystem = null,
+        internalAccumulators = Seq.empty))
 
       try {
         f

http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 0282b25..601a5a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -76,7 +76,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite {
         taskAttemptId = 98456,
         attemptNumber = 0,
         taskMemoryManager = taskMemMgr,
-        metricsSystem = null))
+        metricsSystem = null,
+        internalAccumulators = Seq.empty))
 
       // Create the data converters
       val kExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema)

http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
new file mode 100644
index 0000000..0554e11
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -0,0 +1,94 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+// TODO: uncomment the test here! It is currently failing due to
+// bad interaction with org.apache.spark.sql.test.TestSQLContext.
+
+// scalastyle:off
+//package org.apache.spark.sql.execution.joins
+//
+//import scala.reflect.ClassTag
+//
+//import org.scalatest.BeforeAndAfterAll
+//
+//import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
+//import org.apache.spark.sql.functions._
+//import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest}
+//
+///**
+// * Test various broadcast join operators with unsafe enabled.
+// *
+// * This needs to be its own suite because [[org.apache.spark.sql.test.TestSQLContext]] runs
+// * in local mode, but for tests in this suite we need to run Spark in local-cluster mode.
+// * In particular, the use of [[org.apache.spark.unsafe.map.BytesToBytesMap]] in
+// * [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered without
+// * serializing the hashed relation, which does not happen in local mode.
+// */
+//class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
+//  private var sc: SparkContext = null
+//  private var sqlContext: SQLContext = null
+//
+//  /**
+//   * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled.
+//   */
+//  override def beforeAll(): Unit = {
+//    super.beforeAll()
+//    val conf = new SparkConf()
+//      .setMaster("local-cluster[2,1,1024]")
+//      .setAppName("testing")
+//    sc = new SparkContext(conf)
+//    sqlContext = new SQLContext(sc)
+//    sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true)
+//    sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
+//  }
+//
+//  override def afterAll(): Unit = {
+//    sc.stop()
+//    sc = null
+//    sqlContext = null
+//  }
+//
+//  /**
+//   * Test whether the specified broadcast join updates the peak execution memory accumulator.
+//   */
+//  private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = {
+//    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) {
+//      val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+//      val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+//      // Comparison at the end is for broadcast left semi join
+//      val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
+//      val df3 = df1.join(broadcast(df2), joinExpression, joinType)
+//      val plan = df3.queryExecution.executedPlan
+//      assert(plan.collect { case p: T => p }.size === 1)
+//      plan.executeCollect()
+//    }
+//  }
+//
+//  test("unsafe broadcast hash join updates peak execution memory") {
+//    testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner")
+//  }
+//
+//  test("unsafe broadcast hash outer join updates peak execution memory") {
+//    testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer")
+//  }
+//
+//  test("unsafe broadcast left semi join updates peak execution memory") {
+//    testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi")
+//  }
+//
+//}
+// scalastyle:on

