Repository: spark Updated Branches: refs/heads/branch-1.5 e7329ab31 -> 29756ff11
http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 9201d1e..450ab7b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -57,8 +57,9 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark } val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() - val task = new ResultTask[String, String](0, 0, - sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) + val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array) + val task = new ResultTask[String, String]( + 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty) intercept[RuntimeException] { task.run(0, 0, null) } @@ -66,7 +67,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark } test("all TaskCompletionListeners should be called even if some fail") { - val context = new TaskContextImpl(0, 0, 0, 0, null, null) + val context = TaskContext.empty() val listener = mock(classOf[TaskCompletionListener]) context.addTaskCompletionListener(_ => throw new Exception("blah")) context.addTaskCompletionListener(listener) http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 3abb99c..f7cc4bb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -136,7 +136,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex /** * A Task implementation that results in a large serialized task. */ -class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) { +class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) { val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024) val random = new Random(0) random.nextBytes(randomBuffer) http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala index db718ec..05b3afe 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala @@ -138,7 +138,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { shuffleHandle, reduceId, reduceId + 1, - new TaskContextImpl(0, 0, 0, 0, null, null), + TaskContext.empty(), blockManager, mapOutputTracker) http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index cf8bd8a..828153b 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -29,7 +29,7 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.PrivateMethodTester -import org.apache.spark.{SparkFunSuite, TaskContextImpl} +import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener @@ -95,7 +95,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContextImpl(0, 0, 0, 0, null, null), + TaskContext.empty(), transfer, blockManager, blocksByAddress, @@ -165,7 +165,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null) + val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, @@ -227,7 +227,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null) + val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala new file mode 100644 index 0000000..98f9314 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite, Success} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler._ +import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab} +import org.apache.spark.ui.scope.RDDOperationGraphListener + +class StagePageSuite extends SparkFunSuite with LocalSparkContext { + + test("peak execution memory only displayed if unsafe is enabled") { + val unsafeConf = "spark.sql.unsafe.enabled" + val conf = new SparkConf().set(unsafeConf, "true") + val html = renderStagePage(conf).toString().toLowerCase + val targetString = "peak execution memory" + assert(html.contains(targetString)) + // Disable unsafe and make sure it's not there + val conf2 = new SparkConf().set(unsafeConf, "false") + val html2 = renderStagePage(conf2).toString().toLowerCase + assert(!html2.contains(targetString)) + } + + /** + * Render a stage page started with the given conf and return the HTML. + * This also runs a dummy stage to populate the page with useful content. + */ + private def renderStagePage(conf: SparkConf): Seq[Node] = { + val jobListener = new JobProgressListener(conf) + val graphListener = new RDDOperationGraphListener(conf) + val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS) + val request = mock(classOf[HttpServletRequest]) + when(tab.conf).thenReturn(conf) + when(tab.progressListener).thenReturn(jobListener) + when(tab.operationGraphListener).thenReturn(graphListener) + when(tab.appName).thenReturn("testing") + when(tab.headerTabs).thenReturn(Seq.empty) + when(request.getParameter("id")).thenReturn("0") + when(request.getParameter("attempt")).thenReturn("0") + val page = new StagePage(tab) + + // Simulate a stage in job progress listener + val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details") + val taskInfo = new TaskInfo(0, 0, 0, 0, "0", "localhost", TaskLocality.ANY, false) + jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) + jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) + taskInfo.markSuccessful() + jobListener.onTaskEnd( + SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty)) + jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo)) + page.render(request) + } + +} http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 9c362f0..12e9baf 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -399,4 +399,19 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } + test("external aggregation updates peak execution memory") { + val conf = createSparkConf(loadDefaults = false) + .set("spark.shuffle.memoryFraction", "0.001") + .set("spark.shuffle.manager", "hash") // make sure we're not also using ExternalSorter + sc = new SparkContext("local", "test", conf) + // No spilling + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map without spilling") { + sc.parallelize(1 to 10, 2).map { i => (i, i) }.reduceByKey(_ + _).count() + } + // With spilling + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map with spilling") { + sc.parallelize(1 to 1000 * 1000, 2).map { i => (i, i) }.reduceByKey(_ + _).count() + } + } + } http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 986cd86..bdb0f4d 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -692,7 +692,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { sortWithoutBreakingSortingContracts(createSparkConf(true, false)) } - def sortWithoutBreakingSortingContracts(conf: SparkConf) { + private def sortWithoutBreakingSortingContracts(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.01") conf.set("spark.shuffle.manager", "sort") sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) @@ -743,5 +743,15 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } sorter2.stop() - } + } + + test("sorting updates peak execution memory") { + val conf = createSparkConf(loadDefaults = false, kryo = false) + .set("spark.shuffle.manager", "sort") + sc = new SparkContext("local", "test", conf) + // Avoid aggregating here to make sure we're not also using ExternalAppendOnlyMap + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sorter") { + sc.parallelize(1 to 1000, 2).repartition(100).count() + } + } } http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 5e4c623..193906d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -106,6 +106,13 @@ final class UnsafeExternalRowSorter { sorter.spill(); } + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsage() { + return sorter.getPeakMemoryUsedBytes(); + } + private void cleanupResources() { sorter.freeMemory(); } http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java ---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 9e2c933..43d06ce 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -209,6 +209,14 @@ public final class UnsafeFixedWidthAggregationMap { } /** + * The memory used by this map's managed structures, in bytes. + * Note that this is also the peak memory used by this map, since the map is append-only. + */ + public long getMemoryUsage() { + return map.getTotalMemoryConsumption(); + } + + /** * Free the memory associated with this map. This is idempotent and can be called multiple times. */ public void free() { http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index cd87b8d..bf4905d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import java.io.IOException -import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -263,11 +263,12 @@ case class GeneratedAggregate( assert(iter.hasNext, "There should be at least one row for this path") log.info("Using Unsafe-based aggregator") val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") + val taskContext = TaskContext.get() val aggregationMap = new UnsafeFixedWidthAggregationMap( newAggregationBuffer(EmptyRow), aggregationBufferSchema, groupKeySchema, - TaskContext.get.taskMemoryManager(), + taskContext.taskMemoryManager(), SparkEnv.get.shuffleMemoryManager, 1024 * 16, // initial capacity pageSizeBytes, @@ -284,6 +285,10 @@ case class GeneratedAggregate( updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow)) } + // Record memory used in the process + taskContext.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(aggregationMap.getMemoryUsage) + new Iterator[InternalRow] { private[this] val mapIterator = aggregationMap.iterator() private[this] val resultProjection = resultProjectionBuilder() @@ -300,7 +305,7 @@ case class GeneratedAggregate( } else { // This is the last element in the iterator, so let's free the buffer. Before we do, // though, we need to make a defensive copy of the result so that we don't return an - // object that might contain dangling pointers to the freed memory + // object that might contain dangling pointers to the freed memory. val resultCopy = result.copy() aggregationMap.free() resultCopy http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 624efc1..e73e252 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -70,7 +71,14 @@ case class BroadcastHashJoin( val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => - hashJoin(streamedIter, broadcastRelation.value) + val hashedRelation = broadcastRelation.value + hashedRelation match { + case unsafe: UnsafeHashedRelation => + TaskContext.get().internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + case _ => + } + hashJoin(streamedIter, hashedRelation) } } } http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index 309716a..c35e439 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -75,6 +76,13 @@ case class BroadcastHashOuterJoin( val hashTable = broadcastRelation.value val keyGenerator = streamedKeyGenerator + hashTable match { + case unsafe: UnsafeHashedRelation => + TaskContext.get().internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + case _ => + } + joinType match { case LeftOuter => streamedIter.flatMap(currentRow => { http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index a605939..5bd06fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -51,7 +52,14 @@ case class BroadcastLeftSemiJoinHash( val broadcastedRelation = sparkContext.broadcast(hashRelation) left.execute().mapPartitions { streamIter => - hashSemiJoin(streamIter, broadcastedRelation.value) + val hashedRelation = broadcastedRelation.value + hashedRelation match { + case unsafe: UnsafeHashedRelation => + TaskContext.get().internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + case _ => + } + hashSemiJoin(streamIter, hashedRelation) } } } http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index cc8bbfd..58b4236 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -183,8 +183,27 @@ private[joins] final class UnsafeHashedRelation( private[joins] def this() = this(null) // Needed for serialization // Use BytesToBytesMap in executor for better performance (it's created when deserialization) + // This is used in broadcast joins and distributed mode only @transient private[this] var binaryMap: BytesToBytesMap = _ + /** + * Return the size of the unsafe map on the executors. + * + * For broadcast joins, this hashed relation is bigger on the driver because it is + * represented as a Java hash map there. While serializing the map to the executors, + * however, we rehash the contents in a binary map to reduce the memory footprint on + * the executors. + * + * For non-broadcast joins or in local mode, return 0. + */ + def getUnsafeSize: Long = { + if (binaryMap != null) { + binaryMap.getTotalMemoryConsumption + } else { + 0 + } + } + override def get(key: InternalRow): Seq[InternalRow] = { val unsafeKey = key.asInstanceOf[UnsafeRow] @@ -214,7 +233,7 @@ private[joins] final class UnsafeHashedRelation( } } else { - // Use the JavaHashMap in Local mode or ShuffleHashJoin + // Use the Java HashMap in local mode or for non-broadcast joins (e.g. ShuffleHashJoin) hashTable.get(unsafeKey) } } @@ -316,6 +335,7 @@ private[joins] object UnsafeHashedRelation { keyGenerator: UnsafeProjection, sizeEstimate: Int): HashedRelation = { + // Use a Java hash table here because unsafe maps expect fixed size records val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate) // Create a mapping of buildKeys -> rows http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 92cf328..3192b6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ @@ -76,6 +77,11 @@ case class ExternalSort( val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering)) sorter.insertAll(iterator.map(r => (r.copy(), null))) val baseIterator = sorter.iterator.map(_._1) + val context = TaskContext.get() + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) // TODO(marmbrus): The complex type signature below thwarts inference for no reason. CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) }, preservesPartitioning = true) @@ -137,7 +143,11 @@ case class TungstenSort( if (testSpillFrequency > 0) { sorter.setTestSpillFrequency(testSpillFrequency) } - sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + val taskContext = TaskContext.get() + taskContext.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage) + sortedIterator }, preservesPartitioning = true) } http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f1abae0..29dfcf2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -21,6 +21,7 @@ import java.sql.Timestamp import org.scalatest.BeforeAndAfterAll +import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException @@ -258,6 +259,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } } + private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { + val df = sql(sqlText) + // First, check if we have GeneratedAggregate. + val hasGeneratedAgg = df.queryExecution.executedPlan + .collect { case _: GeneratedAggregate | _: aggregate.Aggregate => true } + .nonEmpty + if (!hasGeneratedAgg) { + fail( + s""" + |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. + |${df.queryExecution.simpleString} + """.stripMargin) + } + // Then, check results. + checkAnswer(df, expectedResults) + } + test("aggregation with codegen") { val originalValue = sqlContext.conf.codegenEnabled sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) @@ -267,26 +285,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { .unionAll(sqlContext.table("testData")) .registerTempTable("testData3x") - def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { - val df = sql(sqlText) - // First, check if we have GeneratedAggregate. - var hasGeneratedAgg = false - df.queryExecution.executedPlan.foreach { - case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true - case newAggregate: aggregate.Aggregate => hasGeneratedAgg = true - case _ => - } - if (!hasGeneratedAgg) { - fail( - s""" - |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. - |${df.queryExecution.simpleString} - """.stripMargin) - } - // Then, check results. - checkAnswer(df, expectedResults) - } - try { // Just to group rows. testCodeGen( @@ -1605,6 +1603,28 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) } + test("aggregation with codegen updates peak execution memory") { + withSQLConf( + (SQLConf.CODEGEN_ENABLED.key, "true"), + (SQLConf.USE_SQL_AGGREGATE2.key, "false")) { + val sc = sqlContext.sparkContext + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "aggregation with codegen") { + testCodeGen( + "SELECT key, count(value) FROM testData GROUP BY key", + (1 to 100).map(i => Row(i, 1))) + } + } + } + + test("external sorting updates peak execution memory") { + withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) { + val sc = sqlContext.sparkContext + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sort") { + sortTest() + } + } + } + test("SPARK-9511: error with table starting with number") { val df = sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)) .toDF("num", "str") http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index c794984..88bce0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.scalatest.BeforeAndAfterAll +import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.TestSQLContext @@ -59,6 +60,17 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { ) } + test("sorting updates peak execution memory") { + val sc = TestSQLContext.sparkContext + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") { + checkThatPlansAgree( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => TungstenSort('a.asc :: Nil, true, child), + (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child), + sortAnswers = false) + } + } + // Test sorting on different data types for ( dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 7c591f6..ef827b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -69,7 +69,8 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { taskAttemptId = Random.nextInt(10000), attemptNumber = 0, taskMemoryManager = taskMemoryManager, - metricsSystem = null)) + metricsSystem = null, + internalAccumulators = Seq.empty)) try { f http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 0282b25..601a5a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -76,7 +76,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite { taskAttemptId = 98456, attemptNumber = 0, taskMemoryManager = taskMemMgr, - metricsSystem = null)) + metricsSystem = null, + internalAccumulators = Seq.empty)) // Create the data converters val kExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema) http://git-wip-us.apache.org/repos/asf/spark/blob/29756ff1/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala new file mode 100644 index 0000000..0554e11 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -0,0 +1,94 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +// TODO: uncomment the test here! It is currently failing due to +// bad interaction with org.apache.spark.sql.test.TestSQLContext. + +// scalastyle:off +//package org.apache.spark.sql.execution.joins +// +//import scala.reflect.ClassTag +// +//import org.scalatest.BeforeAndAfterAll +// +//import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} +//import org.apache.spark.sql.functions._ +//import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} +// +///** +// * Test various broadcast join operators with unsafe enabled. +// * +// * This needs to be its own suite because [[org.apache.spark.sql.test.TestSQLContext]] runs +// * in local mode, but for tests in this suite we need to run Spark in local-cluster mode. +// * In particular, the use of [[org.apache.spark.unsafe.map.BytesToBytesMap]] in +// * [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered without +// * serializing the hashed relation, which does not happen in local mode. +// */ +//class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { +// private var sc: SparkContext = null +// private var sqlContext: SQLContext = null +// +// /** +// * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. +// */ +// override def beforeAll(): Unit = { +// super.beforeAll() +// val conf = new SparkConf() +// .setMaster("local-cluster[2,1,1024]") +// .setAppName("testing") +// sc = new SparkContext(conf) +// sqlContext = new SQLContext(sc) +// sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) +// sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) +// } +// +// override def afterAll(): Unit = { +// sc.stop() +// sc = null +// sqlContext = null +// } +// +// /** +// * Test whether the specified broadcast join updates the peak execution memory accumulator. +// */ +// private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { +// AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) { +// val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") +// val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") +// // Comparison at the end is for broadcast left semi join +// val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") +// val df3 = df1.join(broadcast(df2), joinExpression, joinType) +// val plan = df3.queryExecution.executedPlan +// assert(plan.collect { case p: T => p }.size === 1) +// plan.executeCollect() +// } +// } +// +// test("unsafe broadcast hash join updates peak execution memory") { +// testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner") +// } +// +// test("unsafe broadcast hash outer join updates peak execution memory") { +// testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer") +// } +// +// test("unsafe broadcast left semi join updates peak execution memory") { +// testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi") +// } +// +//} +// scalastyle:on --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org