http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
deleted file mode 100644
index 87a786b..0000000
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
+++ /dev/null
@@ -1,273 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.io.InputStream
-import java.nio.IntBuffer
-import java.util.Comparator
-
-import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance}
-import org.apache.spark.storage.DiskBlockObjectWriter
-import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._
-
-/**
- * Append-only buffer of key-value pairs, each with a corresponding partition ID, that serializes
- * its records upon insert and stores them as raw bytes.
- *
- * We use two data-structures to store the contents. The serialized records are stored in a
- * ChainedBuffer that can expand gracefully as records are added. This buffer is accompanied by a
- * metadata buffer that stores pointers into the data buffer as well as the partition ID of each
- * record. Each entry in the metadata buffer takes up a fixed amount of space.
- *
- * Sorting the collection means swapping entries in the metadata buffer - the record buffer need not
- * be modified at all. Storing the partition IDs in the metadata buffer means that comparisons can
- * happen without following any pointers, which should minimize cache misses.
- *
- * Currently, only sorting by partition is supported.
- *
- * Each record is laid out inside the the metaBuffer as follows. keyStart, a long, is split across
- * two integers:
- *
- *   +-------------+------------+------------+-------------+
- *   |         keyStart         | keyValLen  | partitionId |
- *   +-------------+------------+------------+-------------+
- *
- * The buffer can support up to `536870911 (2 ^ 29 - 1)` records.
- *
- * @param metaInitialRecords The initial number of entries in the metadata buffer.
- * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records.
- * @param serializerInstance the serializer used for serializing inserted records.
- */
-private[spark] class PartitionedSerializedPairBuffer[K, V](
-    metaInitialRecords: Int,
-    kvBlockSize: Int,
-    serializerInstance: SerializerInstance)
-  extends WritablePartitionedPairCollection[K, V] with SizeTracker {
-
-  if (serializerInstance.isInstanceOf[JavaSerializerInstance]) {
-    throw new IllegalArgumentException("PartitionedSerializedPairBuffer does not support" +
-      " Java-serialized objects.")
-  }
-
-  require(metaInitialRecords <= MAXIMUM_RECORDS,
-    s"Can't make capacity bigger than ${MAXIMUM_RECORDS} records")
-  private var metaBuffer = IntBuffer.allocate(metaInitialRecords * RECORD_SIZE)
-
-  private val kvBuffer: ChainedBuffer = new ChainedBuffer(kvBlockSize)
-  private val kvOutputStream = new ChainedBufferOutputStream(kvBuffer)
-  private val kvSerializationStream = serializerInstance.serializeStream(kvOutputStream)
-
-  def insert(partition: Int, key: K, value: V): Unit = {
-    if (metaBuffer.position == metaBuffer.capacity) {
-      growMetaBuffer()
-    }
-
-    val keyStart = kvBuffer.size
-    kvSerializationStream.writeKey[Any](key)
-    kvSerializationStream.writeValue[Any](value)
-    kvSerializationStream.flush()
-    val keyValLen = (kvBuffer.size - keyStart).toInt
-
-    // keyStart, a long, gets split across two ints
-    metaBuffer.put(keyStart.toInt)
-    metaBuffer.put((keyStart >> 32).toInt)
-    metaBuffer.put(keyValLen)
-    metaBuffer.put(partition)
-  }
-
-  /** Double the size of the array because we've reached capacity */
-  private def growMetaBuffer(): Unit = {
-    if (metaBuffer.capacity >= MAXIMUM_META_BUFFER_CAPACITY) {
-      throw new IllegalStateException(s"Can't insert more than ${MAXIMUM_RECORDS} records")
-    }
-    val newCapacity =
-      if (metaBuffer.capacity * 2 < 0 || metaBuffer.capacity * 2 > MAXIMUM_META_BUFFER_CAPACITY) {
-        // Overflow
-        MAXIMUM_META_BUFFER_CAPACITY
-      } else {
-        metaBuffer.capacity * 2
-      }
-    val newMetaBuffer = IntBuffer.allocate(newCapacity)
-    newMetaBuffer.put(metaBuffer.array)
-    metaBuffer = newMetaBuffer
-  }
-
-  /** Iterate through the data in a given order. For this class this is not really destructive. */
-  override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
-    : Iterator[((Int, K), V)] = {
-    sort(keyComparator)
-    val is = orderedInputStream
-    val deserStream = serializerInstance.deserializeStream(is)
-    new Iterator[((Int, K), V)] {
-      var metaBufferPos = 0
-      def hasNext: Boolean = metaBufferPos < metaBuffer.position
-      def next(): ((Int, K), V) = {
-        val key = deserStream.readKey[Any]().asInstanceOf[K]
-        val value = deserStream.readValue[Any]().asInstanceOf[V]
-        val partition = metaBuffer.get(metaBufferPos + PARTITION)
-        metaBufferPos += RECORD_SIZE
-        ((partition, key), value)
-      }
-    }
-  }
-
-  override def estimateSize: Long = metaBuffer.capacity * 4L + kvBuffer.capacity
-
-  override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
-    : WritablePartitionedIterator = {
-    sort(keyComparator)
-    new WritablePartitionedIterator {
-      // current position in the meta buffer in ints
-      var pos = 0
-
-      def writeNext(writer: DiskBlockObjectWriter): Unit = {
-        val keyStart = getKeyStartPos(metaBuffer, pos)
-        val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN)
-        pos += RECORD_SIZE
-        kvBuffer.read(keyStart, writer, keyValLen)
-        writer.recordWritten()
-      }
-      def nextPartition(): Int = metaBuffer.get(pos + PARTITION)
-      def hasNext(): Boolean = pos < metaBuffer.position
-    }
-  }
-
-  // Visible for testing
-  def orderedInputStream: OrderedInputStream = {
-    new OrderedInputStream(metaBuffer, kvBuffer)
-  }
-
-  private def sort(keyComparator: Option[Comparator[K]]): Unit = {
-    val comparator = if (keyComparator.isEmpty) {
-      new Comparator[Int]() {
-        def compare(partition1: Int, partition2: Int): Int = {
-          partition1 - partition2
-        }
-      }
-    } else {
-      throw new UnsupportedOperationException()
-    }
-
-    val sorter = new Sorter(new SerializedSortDataFormat)
-    sorter.sort(metaBuffer, 0, metaBuffer.position / RECORD_SIZE, comparator)
-  }
-}
-
-private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer)
-    extends InputStream {
-
-  import PartitionedSerializedPairBuffer._
-
-  private var metaBufferPos = 0
-  private var kvBufferPos =
-    if (metaBuffer.position > 0) getKeyStartPos(metaBuffer, metaBufferPos) else 0
-
-  override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length)
-
-  override def read(bytes: Array[Byte], offs: Int, len: Int): Int = {
-    if (metaBufferPos >= metaBuffer.position) {
-      return -1
-    }
-    val bytesRemainingInRecord = (metaBuffer.get(metaBufferPos + KEY_VAL_LEN) -
-      (kvBufferPos - getKeyStartPos(metaBuffer, metaBufferPos))).toInt
-    val toRead = math.min(bytesRemainingInRecord, len)
-    kvBuffer.read(kvBufferPos, bytes, offs, toRead)
-    if (toRead == bytesRemainingInRecord) {
-      metaBufferPos += RECORD_SIZE
-      if (metaBufferPos < metaBuffer.position) {
-        kvBufferPos = getKeyStartPos(metaBuffer, metaBufferPos)
-      }
-    } else {
-      kvBufferPos += toRead
-    }
-    toRead
-  }
-
-  override def read(): Int = {
-    throw new UnsupportedOperationException()
-  }
-}
-
-private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuffer] {
-
-  private val META_BUFFER_TMP = new Array[Int](RECORD_SIZE)
-
-  /** Return the sort key for the element at the given index. */
-  override protected def getKey(metaBuffer: IntBuffer, pos: Int): Int = {
-    metaBuffer.get(pos * RECORD_SIZE + PARTITION)
-  }
-
-  /** Swap two elements. */
-  override def swap(metaBuffer: IntBuffer, pos0: Int, pos1: Int): Unit = {
-    val iOff = pos0 * RECORD_SIZE
-    val jOff = pos1 * RECORD_SIZE
-    System.arraycopy(metaBuffer.array, iOff, META_BUFFER_TMP, 0, RECORD_SIZE)
-    System.arraycopy(metaBuffer.array, jOff, metaBuffer.array, iOff, RECORD_SIZE)
-    System.arraycopy(META_BUFFER_TMP, 0, metaBuffer.array, jOff, RECORD_SIZE)
-  }
-
-  /** Copy a single element from src(srcPos) to dst(dstPos). */
-  override def copyElement(
-      src: IntBuffer,
-      srcPos: Int,
-      dst: IntBuffer,
-      dstPos: Int): Unit = {
-    val srcOff = srcPos * RECORD_SIZE
-    val dstOff = dstPos * RECORD_SIZE
-    System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE)
-  }
-
-  /**
-   * Copy a range of elements starting at src(srcPos) to dst, starting at dstPos.
-   * Overlapping ranges are allowed.
-   */
-  override def copyRange(
-      src: IntBuffer,
-      srcPos: Int,
-      dst: IntBuffer,
-      dstPos: Int,
-      length: Int): Unit = {
-    val srcOff = srcPos * RECORD_SIZE
-    val dstOff = dstPos * RECORD_SIZE
-    System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE * length)
-  }
-
-  /**
-   * Allocates a Buffer that can hold up to 'length' elements.
-   * All elements of the buffer should be considered invalid until data is explicitly copied in.
-   */
-  override def allocate(length: Int): IntBuffer = {
-    IntBuffer.allocate(length * RECORD_SIZE)
-  }
-}
-
-private object PartitionedSerializedPairBuffer {
-  val KEY_START = 0 // keyStart, a long, gets split across two ints
-  val KEY_VAL_LEN = 2
-  val PARTITION = 3
-  val RECORD_SIZE = PARTITION + 1 // num ints of metadata
-
-  val MAXIMUM_RECORDS = Int.MaxValue / RECORD_SIZE // (2 ^ 29) - 1
-  val MAXIMUM_META_BUFFER_CAPACITY = MAXIMUM_RECORDS * RECORD_SIZE // (2 ^ 31) - 4
-
-  def getKeyStartPos(metaBuffer: IntBuffer, metaBufferPos: Int): Long = {
-    val lower32 = metaBuffer.get(metaBufferPos + KEY_START)
-    val upper32 = metaBuffer.get(metaBufferPos + KEY_START + 1)
-    (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL)
-  }
-}
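The class comment above describes keyStart, a long offset into the ChainedBuffer, being stored as two metadata ints. As a hedged illustration (hypothetical helper names, not part of the deleted file), the round trip looks like this in Java:

// Sketch of the keyStart packing described above: a long offset is stored as two ints
// (lower 32 bits first, then upper 32 bits) and reassembled without sign-extension bugs.
public final class KeyStartPackingSketch {
  static int[] split(long keyStart) {
    return new int[] { (int) keyStart, (int) (keyStart >> 32) };
  }

  static long join(int lower32, int upper32) {
    return ((long) upper32 << 32) | (lower32 & 0xFFFFFFFFL);
  }

  public static void main(String[] args) {
    long keyStart = (1L << 33) + 7;  // an offset larger than Integer.MAX_VALUE
    int[] parts = split(keyStart);
    System.out.println(join(parts[0], parts[1]) == keyStart);  // true
  }
}

This mirrors getKeyStartPos above, which masks the lower word with 0xFFFFFFFFL before OR-ing in the shifted upper word.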

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
new file mode 100644
index 0000000..232ae4d
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import org.apache.spark.shuffle.sort.PackedRecordPointer;
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import static org.apache.spark.shuffle.sort.PackedRecordPointer.*;
+
+public class PackedRecordPointerSuite {
+
+  @Test
+  public void heap() {
+    final TaskMemoryManager memoryManager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+    final MemoryBlock page0 = memoryManager.allocatePage(128);
+    final MemoryBlock page1 = memoryManager.allocatePage(128);
+    final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
+      page1.getBaseOffset() + 42);
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+    assertEquals(360, packedPointer.getPartitionId());
+    final long recordPointer = packedPointer.getRecordPointer();
+    assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
+    assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
+    assertEquals(addressInPage1, recordPointer);
+    memoryManager.cleanUpAllAllocatedMemory();
+  }
+
+  @Test
+  public void offHeap() {
+    final TaskMemoryManager memoryManager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
+    final MemoryBlock page0 = memoryManager.allocatePage(128);
+    final MemoryBlock page1 = memoryManager.allocatePage(128);
+    final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
+      page1.getBaseOffset() + 42);
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+    assertEquals(360, packedPointer.getPartitionId());
+    final long recordPointer = packedPointer.getRecordPointer();
+    assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
+    assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
+    assertEquals(addressInPage1, recordPointer);
+    memoryManager.cleanUpAllAllocatedMemory();
+  }
+
+  @Test
+  public void maximumPartitionIdCanBeEncoded() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID));
+    assertEquals(MAXIMUM_PARTITION_ID, packedPointer.getPartitionId());
+  }
+
+  @Test
+  public void partitionIdsGreaterThanMaximumPartitionIdWillOverflowOrTriggerError() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    try {
+      // Pointers greater than the maximum partition ID will overflow or trigger an assertion error
+      packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1));
+      assertFalse(MAXIMUM_PARTITION_ID  + 1 == packedPointer.getPartitionId());
+    } catch (AssertionError e ) {
+      // pass
+    }
+  }
+
+  @Test
+  public void maximumOffsetInPageCanBeEncoded() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES - 1);
+    packedPointer.set(PackedRecordPointer.packPointer(address, 0));
+    assertEquals(address, packedPointer.getRecordPointer());
+  }
+
+  @Test
+  public void offsetsPastMaxOffsetInPageWillOverflow() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES);
+    packedPointer.set(PackedRecordPointer.packPointer(address, 0));
+    assertEquals(0, packedPointer.getRecordPointer());
+  }
+}
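For readers unfamiliar with the pointer packing these tests exercise: PackedRecordPointer stores a partition id and a record address in a single long. The sketch below shows the general idea with an illustrative 24-bit/40-bit split; the exact bit layout is an assumption here, not taken from the class itself.

// Hedged sketch: partition id in the upper bits, record address in the lower bits.
// The 24/40 split below is illustrative only.
public final class PointerPackingSketch {
  static final int PARTITION_BITS = 24;
  static final int ADDRESS_BITS = 64 - PARTITION_BITS;
  static final long ADDRESS_MASK = (1L << ADDRESS_BITS) - 1;

  static long pack(long recordAddress, int partitionId) {
    return ((long) partitionId << ADDRESS_BITS) | (recordAddress & ADDRESS_MASK);
  }

  static int partitionId(long packed) {
    return (int) (packed >>> ADDRESS_BITS);
  }

  static long recordAddress(long packed) {
    return packed & ADDRESS_MASK;
  }

  public static void main(String[] args) {
    long packed = pack(0xABCDEFL, 360);
    System.out.println(partitionId(packed));    // 360
    System.out.println(recordAddress(packed));  // 11259375 (0xABCDEF)
  }
}

Packing both values into one word is what keeps each sort entry to a single 8-byte comparison, and it is also why partition ids above MAXIMUM_PARTITION_ID cannot be represented, as the overflow test above checks.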

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
new file mode 100644
index 0000000..1ef3c5f
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.util.Arrays;
+import java.util.Random;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+public class ShuffleInMemorySorterSuite {
+
+  private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) {
+    final byte[] strBytes = new byte[strLength];
+    Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength);
+    return new String(strBytes);
+  }
+
+  @Test
+  public void testSortingEmptyInput() {
+    final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(100);
+    final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
+    assert(!iter.hasNext());
+  }
+
+  @Test
+  public void testBasicSorting() throws Exception {
+    final String[] dataToSort = new String[] {
+      "Boba",
+      "Pearls",
+      "Tapioca",
+      "Taho",
+      "Condensed Milk",
+      "Jasmine",
+      "Milk Tea",
+      "Lychee",
+      "Mango"
+    };
+    final TaskMemoryManager memoryManager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+    final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+    final Object baseObject = dataPage.getBaseObject();
+    final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
+    final HashPartitioner hashPartitioner = new HashPartitioner(4);
+
+    // Write the records into the data page and store pointers into the sorter
+    long position = dataPage.getBaseOffset();
+    for (String str : dataToSort) {
+      final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position);
+      final byte[] strBytes = str.getBytes("utf-8");
+      Platform.putInt(baseObject, position, strBytes.length);
+      position += 4;
+      Platform.copyMemory(
+        strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length);
+      position += strBytes.length;
+      sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str));
+    }
+
+    // Sort the records
+    final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
+    int prevPartitionId = -1;
+    Arrays.sort(dataToSort);
+    for (int i = 0; i < dataToSort.length; i++) {
+      Assert.assertTrue(iter.hasNext());
+      iter.loadNext();
+      final int partitionId = iter.packedRecordPointer.getPartitionId();
+      Assert.assertTrue(partitionId >= 0 && partitionId <= 3);
+      Assert.assertTrue("Partition id " + partitionId + " should be >= prev id 
" + prevPartitionId,
+        partitionId >= prevPartitionId);
+      final long recordAddress = iter.packedRecordPointer.getRecordPointer();
+      final int recordLength = Platform.getInt(
+        memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress));
+      final String str = getStringFromDataPage(
+        memoryManager.getPage(recordAddress),
+        memoryManager.getOffsetInPage(recordAddress) + 4, // skip over record length
+        recordLength);
+      Assert.assertTrue(Arrays.binarySearch(dataToSort, str) != -1);
+    }
+    Assert.assertFalse(iter.hasNext());
+  }
+
+  @Test
+  public void testSortingManyNumbers() throws Exception {
+    ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
+    int[] numbersToSort = new int[128000];
+    Random random = new Random(16);
+    for (int i = 0; i < numbersToSort.length; i++) {
+      numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1);
+      sorter.insertRecord(0, numbersToSort[i]);
+    }
+    Arrays.sort(numbersToSort);
+    int[] sorterResult = new int[numbersToSort.length];
+    ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
+    int j = 0;
+    while (iter.hasNext()) {
+      iter.loadNext();
+      sorterResult[j] = iter.packedRecordPointer.getPartitionId();
+      j += 1;
+    }
+    Assert.assertArrayEquals(numbersToSort, sorterResult);
+  }
+}
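testBasicSorting above writes each record into the data page as a 4-byte length followed by the string's bytes, then reads it back through the sorted pointers. A minimal sketch of that length-prefixed layout, using java.nio.ByteBuffer instead of Spark's Platform memory helpers, is:

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

// Sketch of the length-prefixed record layout: a 4-byte length, then the payload bytes.
public final class LengthPrefixedRecordSketch {
  public static void main(String[] args) {
    ByteBuffer page = ByteBuffer.allocate(2048);

    byte[] strBytes = "Milk Tea".getBytes(StandardCharsets.UTF_8);
    int recordOffset = page.position();
    page.putInt(strBytes.length);   // write the length prefix
    page.put(strBytes);             // write the payload

    int recordLength = page.getInt(recordOffset);   // read the prefix back
    byte[] roundTrip = new byte[recordLength];
    page.position(recordOffset + 4);                // skip over the record length
    page.get(roundTrip);
    System.out.println(new String(roundTrip, StandardCharsets.UTF_8));  // Milk Tea
  }
}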

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
new file mode 100644
index 0000000..29d9823
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -0,0 +1,560 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.*;
+import java.nio.ByteBuffer;
+import java.util.*;
+
+import scala.*;
+import scala.collection.Iterator;
+import scala.runtime.AbstractFunction1;
+
+import com.google.common.collect.Iterators;
+import com.google.common.collect.HashMultiset;
+import com.google.common.io.ByteStreams;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.lessThan;
+import static org.junit.Assert.*;
+import static org.mockito.AdditionalAnswers.returnsFirstArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.*;
+import org.apache.spark.io.CompressionCodec$;
+import org.apache.spark.io.LZ4CompressionCodec;
+import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.io.SnappyCompressionCodec;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.serializer.*;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.shuffle.sort.SerializedShuffleHandle;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+public class UnsafeShuffleWriterSuite {
+
+  static final int NUM_PARTITITONS = 4;
+  final TaskMemoryManager taskMemoryManager =
+    new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+  final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS);
+  File mergedOutputFile;
+  File tempDir;
+  long[] partitionSizesInMergedFile;
+  final LinkedList<File> spillFilesCreated = new LinkedList<File>();
+  SparkConf conf;
+  final Serializer serializer = new KryoSerializer(new SparkConf());
+  TaskMetrics taskMetrics;
+
+  @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
+  @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
+  @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver;
+  @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
+  @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
+  @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;
+
+  private final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+    @Override
+    public OutputStream apply(OutputStream stream) {
+      if (conf.getBoolean("spark.shuffle.compress", true)) {
+        return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
+      } else {
+        return stream;
+      }
+    }
+  }
+
+  @After
+  public void tearDown() {
+    Utils.deleteRecursively(tempDir);
+    final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
+    if (leakedMemory != 0) {
+      fail("Test leaked " + leakedMemory + " bytes of managed memory");
+    }
+  }
+
+  @Before
+  @SuppressWarnings("unchecked")
+  public void setUp() throws IOException {
+    MockitoAnnotations.initMocks(this);
+    tempDir = Utils.createTempDir("test", "test");
+    mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
+    partitionSizesInMergedFile = null;
+    spillFilesCreated.clear();
+    conf = new SparkConf().set("spark.buffer.pageSize", "128m");
+    taskMetrics = new TaskMetrics();
+
+    when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
+    when(shuffleMemoryManager.pageSizeBytes()).thenReturn(128L * 1024 * 1024);
+
+    when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
+    when(blockManager.getDiskWriter(
+      any(BlockId.class),
+      any(File.class),
+      any(SerializerInstance.class),
+      anyInt(),
+      any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
+      @Override
+      public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
+        Object[] args = invocationOnMock.getArguments();
+
+        return new DiskBlockObjectWriter(
+          (File) args[1],
+          (SerializerInstance) args[2],
+          (Integer) args[3],
+          new CompressStream(),
+          false,
+          (ShuffleWriteMetrics) args[4]
+        );
+      }
+    });
+    when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer(
+      new Answer<InputStream>() {
+        @Override
+        public InputStream answer(InvocationOnMock invocation) throws Throwable {
+          assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
+          InputStream is = (InputStream) invocation.getArguments()[1];
+          if (conf.getBoolean("spark.shuffle.compress", true)) {
+            return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is);
+          } else {
+            return is;
+          }
+        }
+      }
+    );
+
+    when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer(
+      new Answer<OutputStream>() {
+        @Override
+        public OutputStream answer(InvocationOnMock invocation) throws Throwable {
+          assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
+          OutputStream os = (OutputStream) invocation.getArguments()[1];
+          if (conf.getBoolean("spark.shuffle.compress", true)) {
+            return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os);
+          } else {
+            return os;
+          }
+        }
+      }
+    );
+
+    when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
+    doAnswer(new Answer<Void>() {
+      @Override
+      public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+        partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
+        return null;
+      }
+    }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));
+
+    when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
+      new Answer<Tuple2<TempShuffleBlockId, File>>() {
+        @Override
+        public Tuple2<TempShuffleBlockId, File> answer(
+          InvocationOnMock invocationOnMock) throws Throwable {
+          TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
+          File file = File.createTempFile("spillFile", ".spill", tempDir);
+          spillFilesCreated.add(file);
+          return Tuple2$.MODULE$.apply(blockId, file);
+        }
+      });
+
+    when(taskContext.taskMetrics()).thenReturn(taskMetrics);
+    when(taskContext.internalMetricsToAccumulators()).thenReturn(null);
+
+    when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
+    when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
+  }
+
+  private UnsafeShuffleWriter<Object, Object> createWriter(
+      boolean transferToEnabled) throws IOException {
+    conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
+    return new UnsafeShuffleWriter<Object, Object>(
+      blockManager,
+      shuffleBlockResolver,
+      taskMemoryManager,
+      shuffleMemoryManager,
+      new SerializedShuffleHandle<Object, Object>(0, 1, shuffleDep),
+      0, // map id
+      taskContext,
+      conf
+    );
+  }
+
+  private void assertSpillFilesWereCleanedUp() {
+    for (File spillFile : spillFilesCreated) {
+      assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
+        spillFile.exists());
+    }
+  }
+
+  private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
+    final ArrayList<Tuple2<Object, Object>> recordsList = new ArrayList<Tuple2<Object, Object>>();
+    long startOffset = 0;
+    for (int i = 0; i < NUM_PARTITITONS; i++) {
+      final long partitionSize = partitionSizesInMergedFile[i];
+      if (partitionSize > 0) {
+        InputStream in = new FileInputStream(mergedOutputFile);
+        ByteStreams.skipFully(in, startOffset);
+        in = new LimitedInputStream(in, partitionSize);
+        if (conf.getBoolean("spark.shuffle.compress", true)) {
+          in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
+        }
+        DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in);
+        Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
+        while (records.hasNext()) {
+          Tuple2<Object, Object> record = records.next();
+          assertEquals(i, hashPartitioner.getPartition(record._1()));
+          recordsList.add(record);
+        }
+        recordsStream.close();
+        startOffset += partitionSize;
+      }
+    }
+    return recordsList;
+  }
+
+  @Test(expected=IllegalStateException.class)
+  public void mustCallWriteBeforeSuccessfulStop() throws IOException {
+    createWriter(false).stop(true);
+  }
+
+  @Test
+  public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException {
+    createWriter(false).stop(false);
+  }
+
+  class PandaException extends RuntimeException {
+  }
+
+  @Test(expected=PandaException.class)
+  public void writeFailurePropagates() throws Exception {
+    class BadRecords extends scala.collection.AbstractIterator<Product2<Object, Object>> {
+      @Override public boolean hasNext() {
+        throw new PandaException();
+      }
+      @Override public Product2<Object, Object> next() {
+        return null;
+      }
+    }
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+    writer.write(new BadRecords());
+  }
+
+  @Test
+  public void writeEmptyIterator() throws Exception {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+    writer.write(Iterators.<Product2<Object, Object>>emptyIterator());
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    assertTrue(mapStatus.isDefined());
+    assertTrue(mergedOutputFile.exists());
+    assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile);
+    assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleRecordsWritten());
+    assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleBytesWritten());
+    assertEquals(0, taskMetrics.diskBytesSpilled());
+    assertEquals(0, taskMetrics.memoryBytesSpilled());
+  }
+
+  @Test
+  public void writeWithoutSpilling() throws Exception {
+    // In this example, each partition should have exactly one record:
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    for (int i = 0; i < NUM_PARTITITONS; i++) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
+    }
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+    writer.write(dataToWrite.iterator());
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    assertTrue(mapStatus.isDefined());
+    assertTrue(mergedOutputFile.exists());
+
+    long sumOfPartitionSizes = 0;
+    for (long size: partitionSizesInMergedFile) {
+      // All partitions should be the same size:
+      assertEquals(partitionSizesInMergedFile[0], size);
+      sumOfPartitionSizes += size;
+    }
+    assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
+    assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertEquals(0, taskMetrics.diskBytesSpilled());
+    assertEquals(0, taskMetrics.memoryBytesSpilled());
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  private void testMergingSpills(
+      boolean transferToEnabled,
+      String compressionCodecName) throws IOException {
+    if (compressionCodecName != null) {
+      conf.set("spark.shuffle.compress", "true");
+      conf.set("spark.io.compression.codec", compressionCodecName);
+    } else {
+      conf.set("spark.shuffle.compress", "false");
+    }
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
+    }
+    writer.insertRecordIntoSorter(dataToWrite.get(0));
+    writer.insertRecordIntoSorter(dataToWrite.get(1));
+    writer.insertRecordIntoSorter(dataToWrite.get(2));
+    writer.insertRecordIntoSorter(dataToWrite.get(3));
+    writer.forceSorterToSpill();
+    writer.insertRecordIntoSorter(dataToWrite.get(4));
+    writer.insertRecordIntoSorter(dataToWrite.get(5));
+    writer.closeAndWriteOutput();
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    assertTrue(mapStatus.isDefined());
+    assertTrue(mergedOutputFile.exists());
+    assertEquals(2, spillFilesCreated.size());
+
+    long sumOfPartitionSizes = 0;
+    for (long size: partitionSizesInMergedFile) {
+      sumOfPartitionSizes += size;
+    }
+    assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
+
+    assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndLZF() throws Exception {
+    testMergingSpills(true, LZFCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndLZF() throws Exception {
+    testMergingSpills(false, LZFCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndLZ4() throws Exception {
+    testMergingSpills(true, LZ4CompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
+    testMergingSpills(false, LZ4CompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndSnappy() throws Exception {
+    testMergingSpills(true, SnappyCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
+    testMergingSpills(false, SnappyCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
+    testMergingSpills(true, null);
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
+    testMergingSpills(false, null);
+  }
+
+  @Test
+  public void writeEnoughDataToTriggerSpill() throws Exception {
+    when(shuffleMemoryManager.tryToAcquire(anyLong()))
+      .then(returnsFirstArg()) // Allocate initial sort buffer
+      .then(returnsFirstArg()) // Allocate initial data page
+      .thenReturn(0L) // Deny request to allocate new data page
+      .then(returnsFirstArg());  // Grant new sort buffer and data page.
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+    final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
+    for (int i = 0; i < 128 + 1; i++) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray));
+    }
+    writer.write(dataToWrite.iterator());
+    verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+    assertEquals(2, spillFilesCreated.size());
+    writer.stop(true);
+    readRecordsFromFile();
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  @Test
+  public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
+    when(shuffleMemoryManager.tryToAcquire(anyLong()))
+      .then(returnsFirstArg()) // Allocate initial sort buffer
+      .then(returnsFirstArg()) // Allocate initial data page
+      .thenReturn(0L) // Deny request to grow sort buffer
+      .then(returnsFirstArg());  // Grant new sort buffer and data page.
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+    for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
+    }
+    writer.write(dataToWrite.iterator());
+    verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+    assertEquals(2, spillFilesCreated.size());
+    writer.stop(true);
+    readRecordsFromFile();
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  @Test
+  public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    final byte[] bytes = new byte[(int) (ShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
+    new Random(42).nextBytes(bytes);
+    dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(bytes)));
+    writer.write(dataToWrite.iterator());
+    writer.stop(true);
+    assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+    dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(new byte[1])));
+    // We should be able to write a record that's right _at_ the max record size
+    final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()];
+    new Random(42).nextBytes(atMaxRecordSize);
+    dataToWrite.add(new Tuple2<Object, Object>(2, ByteBuffer.wrap(atMaxRecordSize)));
+    // Inserting a record that's larger than the max record size
+    final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1];
+    new Random(42).nextBytes(exceedsMaxRecordSize);
+    dataToWrite.add(new Tuple2<Object, Object>(3, ByteBuffer.wrap(exceedsMaxRecordSize)));
+    writer.write(dataToWrite.iterator());
+    writer.stop(true);
+    assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+    writer.forceSorterToSpill();
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+    writer.stop(false);
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void testPeakMemoryUsed() throws Exception {
+    final long recordLengthBytes = 8;
+    final long pageSizeBytes = 256;
+    final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
+    when(shuffleMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes);
+    final UnsafeShuffleWriter<Object, Object> writer =
+      new UnsafeShuffleWriter<Object, Object>(
+        blockManager,
+        shuffleBlockResolver,
+        taskMemoryManager,
+        shuffleMemoryManager,
+        new SerializedShuffleHandle<>(0, 1, shuffleDep),
+        0, // map id
+        taskContext,
+        conf);
+
+    // Peak memory should be monotonically increasing. More specifically, every time
+    // we allocate a new page it should increase by exactly the size of the page.
+    long previousPeakMemory = writer.getPeakMemoryUsedBytes();
+    long newPeakMemory;
+    try {
+      for (int i = 0; i < numRecordsPerPage * 10; i++) {
+        writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+        newPeakMemory = writer.getPeakMemoryUsedBytes();
+        if (i % numRecordsPerPage == 0 && i != 0) {
+          // The first page is allocated in constructor, another page will be allocated after
+          // every numRecordsPerPage records (peak memory should change).
+          assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
+        } else {
+          assertEquals(previousPeakMemory, newPeakMemory);
+        }
+        previousPeakMemory = newPeakMemory;
+      }
+
+      // Spilling should not change peak memory
+      writer.forceSorterToSpill();
+      newPeakMemory = writer.getPeakMemoryUsedBytes();
+      assertEquals(previousPeakMemory, newPeakMemory);
+      for (int i = 0; i < numRecordsPerPage; i++) {
+        writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+      }
+      newPeakMemory = writer.getPeakMemoryUsedBytes();
+      assertEquals(previousPeakMemory, newPeakMemory);
+
+      // Closing the writer should not change peak memory
+      writer.closeAndWriteOutput();
+      newPeakMemory = writer.getPeakMemoryUsedBytes();
+      assertEquals(previousPeakMemory, newPeakMemory);
+    } finally {
+      writer.stop(false);
+    }
+  }
+}
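Several of the spill tests above rely on Mockito's ordered stubbing: consecutive calls to shuffleMemoryManager.tryToAcquire return different answers, so one allocation request is denied and the writer is forced to spill. A stripped-down sketch of that pattern, with a hypothetical MemoryGranter interface standing in for the real manager, is:

import static org.mockito.AdditionalAnswers.returnsFirstArg;
import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

// Sketch of chained stubbing: each call to the mocked method consumes the next answer.
public final class ChainedStubbingSketch {
  // Hypothetical stand-in for the memory manager's acquire call.
  interface MemoryGranter {
    long tryToAcquire(long requested);
  }

  public static void main(String[] args) {
    MemoryGranter granter = mock(MemoryGranter.class);
    when(granter.tryToAcquire(anyLong()))
      .then(returnsFirstArg())  // first request granted in full
      .thenReturn(0L)           // second request denied (this is what forces a spill)
      .then(returnsFirstArg()); // later requests granted again
    System.out.println(granter.tryToAcquire(64));  // 64
    System.out.println(granter.tryToAcquire(64));  // 0
    System.out.println(granter.tryToAcquire(64));  // 64
  }
}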

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
deleted file mode 100644
index 934b7e0..0000000
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-import org.junit.Test;
-import static org.junit.Assert.*;
-
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
-import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*;
-
-public class PackedRecordPointerSuite {
-
-  @Test
-  public void heap() {
-    final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
-    final MemoryBlock page0 = memoryManager.allocatePage(128);
-    final MemoryBlock page1 = memoryManager.allocatePage(128);
-    final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
-      page1.getBaseOffset() + 42);
-    PackedRecordPointer packedPointer = new PackedRecordPointer();
-    packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
-    assertEquals(360, packedPointer.getPartitionId());
-    final long recordPointer = packedPointer.getRecordPointer();
-    assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
-    assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
-    assertEquals(addressInPage1, recordPointer);
-    memoryManager.cleanUpAllAllocatedMemory();
-  }
-
-  @Test
-  public void offHeap() {
-    final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
-    final MemoryBlock page0 = memoryManager.allocatePage(128);
-    final MemoryBlock page1 = memoryManager.allocatePage(128);
-    final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
-      page1.getBaseOffset() + 42);
-    PackedRecordPointer packedPointer = new PackedRecordPointer();
-    packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
-    assertEquals(360, packedPointer.getPartitionId());
-    final long recordPointer = packedPointer.getRecordPointer();
-    assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
-    assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
-    assertEquals(addressInPage1, recordPointer);
-    memoryManager.cleanUpAllAllocatedMemory();
-  }
-
-  @Test
-  public void maximumPartitionIdCanBeEncoded() {
-    PackedRecordPointer packedPointer = new PackedRecordPointer();
-    packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID));
-    assertEquals(MAXIMUM_PARTITION_ID, packedPointer.getPartitionId());
-  }
-
-  @Test
-  public void partitionIdsGreaterThanMaximumPartitionIdWillOverflowOrTriggerError() {
-    PackedRecordPointer packedPointer = new PackedRecordPointer();
-    try {
-      // Pointers greater than the maximum partition ID will overflow or trigger an assertion error
-      packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1));
-      assertFalse(MAXIMUM_PARTITION_ID  + 1 == packedPointer.getPartitionId());
-    } catch (AssertionError e ) {
-      // pass
-    }
-  }
-
-  @Test
-  public void maximumOffsetInPageCanBeEncoded() {
-    PackedRecordPointer packedPointer = new PackedRecordPointer();
-    long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES - 1);
-    packedPointer.set(PackedRecordPointer.packPointer(address, 0));
-    assertEquals(address, packedPointer.getRecordPointer());
-  }
-
-  @Test
-  public void offsetsPastMaxOffsetInPageWillOverflow() {
-    PackedRecordPointer packedPointer = new PackedRecordPointer();
-    long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES);
-    packedPointer.set(PackedRecordPointer.packPointer(address, 0));
-    assertEquals(0, packedPointer.getRecordPointer());
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
deleted file mode 100644
index 40fefe2..0000000
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-import java.util.Arrays;
-import java.util.Random;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-import org.apache.spark.HashPartitioner;
-import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
-
-public class UnsafeShuffleInMemorySorterSuite {
-
-  private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) {
-    final byte[] strBytes = new byte[strLength];
-    Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength);
-    return new String(strBytes);
-  }
-
-  @Test
-  public void testSortingEmptyInput() {
-    final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(100);
-    final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
-    assert(!iter.hasNext());
-  }
-
-  @Test
-  public void testBasicSorting() throws Exception {
-    final String[] dataToSort = new String[] {
-      "Boba",
-      "Pearls",
-      "Tapioca",
-      "Taho",
-      "Condensed Milk",
-      "Jasmine",
-      "Milk Tea",
-      "Lychee",
-      "Mango"
-    };
-    final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
-    final MemoryBlock dataPage = memoryManager.allocatePage(2048);
-    final Object baseObject = dataPage.getBaseObject();
-    final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
-    final HashPartitioner hashPartitioner = new HashPartitioner(4);
-
-    // Write the records into the data page and store pointers into the sorter
-    long position = dataPage.getBaseOffset();
-    for (String str : dataToSort) {
-      final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position);
-      final byte[] strBytes = str.getBytes("utf-8");
-      Platform.putInt(baseObject, position, strBytes.length);
-      position += 4;
-      Platform.copyMemory(
-        strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length);
-      position += strBytes.length;
-      sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str));
-    }
-
-    // Sort the records
-    final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
-    int prevPartitionId = -1;
-    Arrays.sort(dataToSort);
-    for (int i = 0; i < dataToSort.length; i++) {
-      Assert.assertTrue(iter.hasNext());
-      iter.loadNext();
-      final int partitionId = iter.packedRecordPointer.getPartitionId();
-      Assert.assertTrue(partitionId >= 0 && partitionId <= 3);
-      Assert.assertTrue("Partition id " + partitionId + " should be >= prev id 
" + prevPartitionId,
-        partitionId >= prevPartitionId);
-      final long recordAddress = iter.packedRecordPointer.getRecordPointer();
-      final int recordLength = Platform.getInt(
-        memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress));
-      final String str = getStringFromDataPage(
-        memoryManager.getPage(recordAddress),
-        memoryManager.getOffsetInPage(recordAddress) + 4, // skip over record length
-        recordLength);
-      Assert.assertTrue(Arrays.binarySearch(dataToSort, str) != -1);
-    }
-    Assert.assertFalse(iter.hasNext());
-  }
-
-  @Test
-  public void testSortingManyNumbers() throws Exception {
-    UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
-    int[] numbersToSort = new int[128000];
-    Random random = new Random(16);
-    for (int i = 0; i < numbersToSort.length; i++) {
-      numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1);
-      sorter.insertRecord(0, numbersToSort[i]);
-    }
-    Arrays.sort(numbersToSort);
-    int[] sorterResult = new int[numbersToSort.length];
-    UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
-    int j = 0;
-    while (iter.hasNext()) {
-      iter.loadNext();
-      sorterResult[j] = iter.packedRecordPointer.getPartitionId();
-      j += 1;
-    }
-    Assert.assertArrayEquals(numbersToSort, sorterResult);
-  }
-}
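
The suite removed above exercises the sorter against length-prefixed records: each record is written into a data page as a 4-byte length followed by the raw bytes, and only an encoded page/offset pointer plus the partition id ever reach the sorter. Below is a minimal stand-alone sketch of that layout step, using a plain byte[] and ByteBuffer instead of Spark's Platform and TaskMemoryManager so it runs on its own; the class and variable names are illustrative only.

    // Illustrative only: mirrors the length-prefixed layout the deleted test writes with
    // Platform.putInt/copyMemory, but without any Spark classes on the classpath.
    import java.nio.ByteBuffer;
    import java.nio.charset.StandardCharsets;

    public class LengthPrefixedLayoutSketch {
      public static void main(String[] args) {
        byte[] page = new byte[2048];              // stand-in for the data page
        ByteBuffer buf = ByteBuffer.wrap(page);
        String[] records = { "Boba", "Pearls", "Tapioca" };
        int[] offsets = new int[records.length];   // stand-in for the record pointers

        for (int i = 0; i < records.length; i++) {
          offsets[i] = buf.position();
          byte[] bytes = records[i].getBytes(StandardCharsets.UTF_8);
          buf.putInt(bytes.length);                // 4-byte length prefix
          buf.put(bytes);                          // record payload
        }

        // Reading a record back: length prefix first, then the payload.
        ByteBuffer reader = ByteBuffer.wrap(page);
        reader.position(offsets[1]);
        int len = reader.getInt();
        byte[] out = new byte[len];
        reader.get(out);
        System.out.println(new String(out, StandardCharsets.UTF_8)); // "Pearls"
      }
    }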

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
deleted file mode 100644
index d218344..0000000
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ /dev/null
@@ -1,560 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-import java.io.*;
-import java.nio.ByteBuffer;
-import java.util.*;
-
-import scala.*;
-import scala.collection.Iterator;
-import scala.reflect.ClassTag;
-import scala.runtime.AbstractFunction1;
-
-import com.google.common.collect.Iterators;
-import com.google.common.collect.HashMultiset;
-import com.google.common.io.ByteStreams;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.greaterThan;
-import static org.hamcrest.Matchers.lessThan;
-import static org.junit.Assert.*;
-import static org.mockito.AdditionalAnswers.returnsFirstArg;
-import static org.mockito.Answers.RETURNS_SMART_NULLS;
-import static org.mockito.Mockito.*;
-
-import org.apache.spark.*;
-import org.apache.spark.io.CompressionCodec$;
-import org.apache.spark.io.LZ4CompressionCodec;
-import org.apache.spark.io.LZFCompressionCodec;
-import org.apache.spark.io.SnappyCompressionCodec;
-import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.executor.TaskMetrics;
-import org.apache.spark.network.util.LimitedInputStream;
-import org.apache.spark.serializer.*;
-import org.apache.spark.scheduler.MapStatus;
-import org.apache.spark.shuffle.IndexShuffleBlockResolver;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
-import org.apache.spark.storage.*;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
-import org.apache.spark.util.Utils;
-
-public class UnsafeShuffleWriterSuite {
-
-  static final int NUM_PARTITITONS = 4;
-  final TaskMemoryManager taskMemoryManager =
-    new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
-  final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS);
-  File mergedOutputFile;
-  File tempDir;
-  long[] partitionSizesInMergedFile;
-  final LinkedList<File> spillFilesCreated = new LinkedList<File>();
-  SparkConf conf;
-  final Serializer serializer = new KryoSerializer(new SparkConf());
-  TaskMetrics taskMetrics;
-
-  @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
-  @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
-  @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver;
-  @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
-  @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
-  @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;
-
-  private final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
-    @Override
-    public OutputStream apply(OutputStream stream) {
-      if (conf.getBoolean("spark.shuffle.compress", true)) {
-        return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
-      } else {
-        return stream;
-      }
-    }
-  }
-
-  @After
-  public void tearDown() {
-    Utils.deleteRecursively(tempDir);
-    final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
-    if (leakedMemory != 0) {
-      fail("Test leaked " + leakedMemory + " bytes of managed memory");
-    }
-  }
-
-  @Before
-  @SuppressWarnings("unchecked")
-  public void setUp() throws IOException {
-    MockitoAnnotations.initMocks(this);
-    tempDir = Utils.createTempDir("test", "test");
-    mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
-    partitionSizesInMergedFile = null;
-    spillFilesCreated.clear();
-    conf = new SparkConf().set("spark.buffer.pageSize", "128m");
-    taskMetrics = new TaskMetrics();
-
-    when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
-    when(shuffleMemoryManager.pageSizeBytes()).thenReturn(128L * 1024 * 1024);
-
-    when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
-    when(blockManager.getDiskWriter(
-      any(BlockId.class),
-      any(File.class),
-      any(SerializerInstance.class),
-      anyInt(),
-      any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
-      @Override
-      public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
-        Object[] args = invocationOnMock.getArguments();
-
-        return new DiskBlockObjectWriter(
-          (File) args[1],
-          (SerializerInstance) args[2],
-          (Integer) args[3],
-          new CompressStream(),
-          false,
-          (ShuffleWriteMetrics) args[4]
-        );
-      }
-    });
-    when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer(
-      new Answer<InputStream>() {
-        @Override
-        public InputStream answer(InvocationOnMock invocation) throws Throwable {
-          assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
-          InputStream is = (InputStream) invocation.getArguments()[1];
-          if (conf.getBoolean("spark.shuffle.compress", true)) {
-            return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is);
-          } else {
-            return is;
-          }
-        }
-      }
-    );
-
-    when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer(
-      new Answer<OutputStream>() {
-        @Override
-        public OutputStream answer(InvocationOnMock invocation) throws Throwable {
-          assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
-          OutputStream os = (OutputStream) invocation.getArguments()[1];
-          if (conf.getBoolean("spark.shuffle.compress", true)) {
-            return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os);
-          } else {
-            return os;
-          }
-        }
-      }
-    );
-
-    when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
-    doAnswer(new Answer<Void>() {
-      @Override
-      public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
-        partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
-        return null;
-      }
-    }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));
-
-    when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
-      new Answer<Tuple2<TempShuffleBlockId, File>>() {
-        @Override
-        public Tuple2<TempShuffleBlockId, File> answer(
-          InvocationOnMock invocationOnMock) throws Throwable {
-          TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
-          File file = File.createTempFile("spillFile", ".spill", tempDir);
-          spillFilesCreated.add(file);
-          return Tuple2$.MODULE$.apply(blockId, file);
-        }
-      });
-
-    when(taskContext.taskMetrics()).thenReturn(taskMetrics);
-    when(taskContext.internalMetricsToAccumulators()).thenReturn(null);
-
-    when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
-    when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
-  }
-
-  private UnsafeShuffleWriter<Object, Object> createWriter(
-      boolean transferToEnabled) throws IOException {
-    conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
-    return new UnsafeShuffleWriter<Object, Object>(
-      blockManager,
-      shuffleBlockResolver,
-      taskMemoryManager,
-      shuffleMemoryManager,
-      new UnsafeShuffleHandle<Object, Object>(0, 1, shuffleDep),
-      0, // map id
-      taskContext,
-      conf
-    );
-  }
-
-  private void assertSpillFilesWereCleanedUp() {
-    for (File spillFile : spillFilesCreated) {
-      assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
-        spillFile.exists());
-    }
-  }
-
-  private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
-    final ArrayList<Tuple2<Object, Object>> recordsList = new ArrayList<Tuple2<Object, Object>>();
-    long startOffset = 0;
-    for (int i = 0; i < NUM_PARTITITONS; i++) {
-      final long partitionSize = partitionSizesInMergedFile[i];
-      if (partitionSize > 0) {
-        InputStream in = new FileInputStream(mergedOutputFile);
-        ByteStreams.skipFully(in, startOffset);
-        in = new LimitedInputStream(in, partitionSize);
-        if (conf.getBoolean("spark.shuffle.compress", true)) {
-          in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
-        }
-        DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in);
-        Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
-        while (records.hasNext()) {
-          Tuple2<Object, Object> record = records.next();
-          assertEquals(i, hashPartitioner.getPartition(record._1()));
-          recordsList.add(record);
-        }
-        recordsStream.close();
-        startOffset += partitionSize;
-      }
-    }
-    return recordsList;
-  }
-
-  @Test(expected=IllegalStateException.class)
-  public void mustCallWriteBeforeSuccessfulStop() throws IOException {
-    createWriter(false).stop(true);
-  }
-
-  @Test
-  public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException {
-    createWriter(false).stop(false);
-  }
-
-  class PandaException extends RuntimeException {
-  }
-
-  @Test(expected=PandaException.class)
-  public void writeFailurePropagates() throws Exception {
-    class BadRecords extends scala.collection.AbstractIterator<Product2<Object, Object>> {
-      @Override public boolean hasNext() {
-        throw new PandaException();
-      }
-      @Override public Product2<Object, Object> next() {
-        return null;
-      }
-    }
-    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
-    writer.write(new BadRecords());
-  }
-
-  @Test
-  public void writeEmptyIterator() throws Exception {
-    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
-    writer.write(Iterators.<Product2<Object, Object>>emptyIterator());
-    final Option<MapStatus> mapStatus = writer.stop(true);
-    assertTrue(mapStatus.isDefined());
-    assertTrue(mergedOutputFile.exists());
-    assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile);
-    assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleRecordsWritten());
-    assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleBytesWritten());
-    assertEquals(0, taskMetrics.diskBytesSpilled());
-    assertEquals(0, taskMetrics.memoryBytesSpilled());
-  }
-
-  @Test
-  public void writeWithoutSpilling() throws Exception {
-    // In this example, each partition should have exactly one record:
-    final ArrayList<Product2<Object, Object>> dataToWrite =
-      new ArrayList<Product2<Object, Object>>();
-    for (int i = 0; i < NUM_PARTITITONS; i++) {
-      dataToWrite.add(new Tuple2<Object, Object>(i, i));
-    }
-    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
-    writer.write(dataToWrite.iterator());
-    final Option<MapStatus> mapStatus = writer.stop(true);
-    assertTrue(mapStatus.isDefined());
-    assertTrue(mergedOutputFile.exists());
-
-    long sumOfPartitionSizes = 0;
-    for (long size: partitionSizesInMergedFile) {
-      // All partitions should be the same size:
-      assertEquals(partitionSizesInMergedFile[0], size);
-      sumOfPartitionSizes += size;
-    }
-    assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
-    assertEquals(
-      HashMultiset.create(dataToWrite),
-      HashMultiset.create(readRecordsFromFile()));
-    assertSpillFilesWereCleanedUp();
-    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
-    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
-    assertEquals(0, taskMetrics.diskBytesSpilled());
-    assertEquals(0, taskMetrics.memoryBytesSpilled());
-    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
-  }
-
-  private void testMergingSpills(
-      boolean transferToEnabled,
-      String compressionCodecName) throws IOException {
-    if (compressionCodecName != null) {
-      conf.set("spark.shuffle.compress", "true");
-      conf.set("spark.io.compression.codec", compressionCodecName);
-    } else {
-      conf.set("spark.shuffle.compress", "false");
-    }
-    final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
-    final ArrayList<Product2<Object, Object>> dataToWrite =
-      new ArrayList<Product2<Object, Object>>();
-    for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
-      dataToWrite.add(new Tuple2<Object, Object>(i, i));
-    }
-    writer.insertRecordIntoSorter(dataToWrite.get(0));
-    writer.insertRecordIntoSorter(dataToWrite.get(1));
-    writer.insertRecordIntoSorter(dataToWrite.get(2));
-    writer.insertRecordIntoSorter(dataToWrite.get(3));
-    writer.forceSorterToSpill();
-    writer.insertRecordIntoSorter(dataToWrite.get(4));
-    writer.insertRecordIntoSorter(dataToWrite.get(5));
-    writer.closeAndWriteOutput();
-    final Option<MapStatus> mapStatus = writer.stop(true);
-    assertTrue(mapStatus.isDefined());
-    assertTrue(mergedOutputFile.exists());
-    assertEquals(2, spillFilesCreated.size());
-
-    long sumOfPartitionSizes = 0;
-    for (long size: partitionSizesInMergedFile) {
-      sumOfPartitionSizes += size;
-    }
-    assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
-
-    assertEquals(
-      HashMultiset.create(dataToWrite),
-      HashMultiset.create(readRecordsFromFile()));
-    assertSpillFilesWereCleanedUp();
-    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
-    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
-    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
-    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
-    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
-    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
-  }
-
-  @Test
-  public void mergeSpillsWithTransferToAndLZF() throws Exception {
-    testMergingSpills(true, LZFCompressionCodec.class.getName());
-  }
-
-  @Test
-  public void mergeSpillsWithFileStreamAndLZF() throws Exception {
-    testMergingSpills(false, LZFCompressionCodec.class.getName());
-  }
-
-  @Test
-  public void mergeSpillsWithTransferToAndLZ4() throws Exception {
-    testMergingSpills(true, LZ4CompressionCodec.class.getName());
-  }
-
-  @Test
-  public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
-    testMergingSpills(false, LZ4CompressionCodec.class.getName());
-  }
-
-  @Test
-  public void mergeSpillsWithTransferToAndSnappy() throws Exception {
-    testMergingSpills(true, SnappyCompressionCodec.class.getName());
-  }
-
-  @Test
-  public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
-    testMergingSpills(false, SnappyCompressionCodec.class.getName());
-  }
-
-  @Test
-  public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
-    testMergingSpills(true, null);
-  }
-
-  @Test
-  public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
-    testMergingSpills(false, null);
-  }
-
-  @Test
-  public void writeEnoughDataToTriggerSpill() throws Exception {
-    when(shuffleMemoryManager.tryToAcquire(anyLong()))
-      .then(returnsFirstArg()) // Allocate initial sort buffer
-      .then(returnsFirstArg()) // Allocate initial data page
-      .thenReturn(0L) // Deny request to allocate new data page
-      .then(returnsFirstArg());  // Grant new sort buffer and data page.
-    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
-    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
-    final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
-    for (int i = 0; i < 128 + 1; i++) {
-      dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray));
-    }
-    writer.write(dataToWrite.iterator());
-    verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
-    assertEquals(2, spillFilesCreated.size());
-    writer.stop(true);
-    readRecordsFromFile();
-    assertSpillFilesWereCleanedUp();
-    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
-    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
-    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
-    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
-    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
-    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
-  }
-
-  @Test
-  public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
-    when(shuffleMemoryManager.tryToAcquire(anyLong()))
-      .then(returnsFirstArg()) // Allocate initial sort buffer
-      .then(returnsFirstArg()) // Allocate initial data page
-      .thenReturn(0L) // Deny request to grow sort buffer
-      .then(returnsFirstArg());  // Grant new sort buffer and data page.
-    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
-    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
-    for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
-      dataToWrite.add(new Tuple2<Object, Object>(i, i));
-    }
-    writer.write(dataToWrite.iterator());
-    verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
-    assertEquals(2, spillFilesCreated.size());
-    writer.stop(true);
-    readRecordsFromFile();
-    assertSpillFilesWereCleanedUp();
-    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
-    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
-    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
-    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
-    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
-    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
-  }
-
-  @Test
-  public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception {
-    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
-    final ArrayList<Product2<Object, Object>> dataToWrite =
-      new ArrayList<Product2<Object, Object>>();
-    final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
-    new Random(42).nextBytes(bytes);
-    dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(bytes)));
-    writer.write(dataToWrite.iterator());
-    writer.stop(true);
-    assertEquals(
-      HashMultiset.create(dataToWrite),
-      HashMultiset.create(readRecordsFromFile()));
-    assertSpillFilesWereCleanedUp();
-  }
-
-  @Test
-  public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
-    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
-    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
-    dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(new byte[1])));
-    // We should be able to write a record that's right _at_ the max record size
-    final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()];
-    new Random(42).nextBytes(atMaxRecordSize);
-    dataToWrite.add(new Tuple2<Object, Object>(2, ByteBuffer.wrap(atMaxRecordSize)));
-    // Inserting a record that's larger than the max record size
-    final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1];
-    new Random(42).nextBytes(exceedsMaxRecordSize);
-    dataToWrite.add(new Tuple2<Object, Object>(3, ByteBuffer.wrap(exceedsMaxRecordSize)));
-    writer.write(dataToWrite.iterator());
-    writer.stop(true);
-    assertEquals(
-      HashMultiset.create(dataToWrite),
-      HashMultiset.create(readRecordsFromFile()));
-    assertSpillFilesWereCleanedUp();
-  }
-
-  @Test
-  public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
-    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
-    writer.forceSorterToSpill();
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
-    writer.stop(false);
-    assertSpillFilesWereCleanedUp();
-  }
-
-  @Test
-  public void testPeakMemoryUsed() throws Exception {
-    final long recordLengthBytes = 8;
-    final long pageSizeBytes = 256;
-    final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
-    when(shuffleMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes);
-    final UnsafeShuffleWriter<Object, Object> writer =
-      new UnsafeShuffleWriter<Object, Object>(
-        blockManager,
-        shuffleBlockResolver,
-        taskMemoryManager,
-        shuffleMemoryManager,
-        new UnsafeShuffleHandle<>(0, 1, shuffleDep),
-        0, // map id
-        taskContext,
-        conf);
-
-    // Peak memory should be monotonically increasing. More specifically, every time
-    // we allocate a new page it should increase by exactly the size of the page.
-    long previousPeakMemory = writer.getPeakMemoryUsedBytes();
-    long newPeakMemory;
-    try {
-      for (int i = 0; i < numRecordsPerPage * 10; i++) {
-        writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
-        newPeakMemory = writer.getPeakMemoryUsedBytes();
-        if (i % numRecordsPerPage == 0 && i != 0) {
-          // The first page is allocated in constructor, another page will be allocated after
-          // every numRecordsPerPage records (peak memory should change).
-          assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
-        } else {
-          assertEquals(previousPeakMemory, newPeakMemory);
-        }
-        previousPeakMemory = newPeakMemory;
-      }
-
-      // Spilling should not change peak memory
-      writer.forceSorterToSpill();
-      newPeakMemory = writer.getPeakMemoryUsedBytes();
-      assertEquals(previousPeakMemory, newPeakMemory);
-      for (int i = 0; i < numRecordsPerPage; i++) {
-        writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
-      }
-      newPeakMemory = writer.getPeakMemoryUsedBytes();
-      assertEquals(previousPeakMemory, newPeakMemory);
-
-      // Closing the writer should not change peak memory
-      writer.closeAndWriteOutput();
-      newPeakMemory = writer.getPeakMemoryUsedBytes();
-      assertEquals(previousPeakMemory, newPeakMemory);
-    } finally {
-      writer.stop(false);
-    }
-  }
-}
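
The testPeakMemoryUsed case deleted above encodes a simple invariant: the reported peak only moves when a fresh page is allocated, and then by exactly the page size, while spilling (which frees pages) must leave the peak untouched. The toy tracker below restates that invariant; it is not a Spark class, and the names are made up for the sketch.

    // Toy model of the invariant checked by the deleted testPeakMemoryUsed; not Spark code.
    public class PeakMemorySketch {
      private long used = 0;
      private long peak = 0;

      void allocatePage(long pageSizeBytes) {
        used += pageSizeBytes;
        peak = Math.max(peak, used);     // peak rises by exactly one page size
      }

      void spill() {
        used = 0;                        // pages are freed, but the peak is sticky
      }

      long getPeakMemoryUsedBytes() {
        return peak;
      }

      public static void main(String[] args) {
        PeakMemorySketch m = new PeakMemorySketch();
        m.allocatePage(256);
        m.allocatePage(256);
        long beforeSpill = m.getPeakMemoryUsedBytes();
        m.spill();
        // Prints "512 512": spilling does not change the reported peak.
        System.out.println(beforeSpill + " " + m.getPeakMemoryUsedBytes());
      }
    }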

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
index 6335817..b8ab227 100644
--- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
@@ -17,13 +17,78 @@
 
 package org.apache.spark
 
+import java.io.File
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.FileUtils
+import org.apache.commons.io.filefilter.TrueFileFilter
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.util.Utils
+
 class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
 
   // This test suite should run all tests in ShuffleSuite with sort-based shuffle.
 
+  private var tempDir: File = _
+
   override def beforeAll() {
     conf.set("spark.shuffle.manager", "sort")
   }
+
+  override def beforeEach(): Unit = {
+    tempDir = Utils.createTempDir()
+    conf.set("spark.local.dir", tempDir.getAbsolutePath)
+  }
+
+  override def afterEach(): Unit = {
+    try {
+      Utils.deleteRecursively(tempDir)
+    } finally {
+      super.afterEach()
+    }
+  }
+
+  test("SortShuffleManager properly cleans up files for shuffles that use the serialized path") {
+    sc = new SparkContext("local", "test", conf)
+    // Create a shuffled RDD and verify that it actually uses the new serialized map output path
+    val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+    val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+      .setSerializer(new KryoSerializer(conf))
+    val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+    assert(SortShuffleManager.canUseSerializedShuffle(shuffleDep))
+    ensureFilesAreCleanedUp(shuffledRdd)
+  }
+
+  test("SortShuffleManager properly cleans up files for shuffles that use the deserialized path") {
+    sc = new SparkContext("local", "test", conf)
+    // Create a shuffled RDD and verify that it actually uses the old deserialized map output path
+    val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+    val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+      .setSerializer(new JavaSerializer(conf))
+    val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+    assert(!SortShuffleManager.canUseSerializedShuffle(shuffleDep))
+    ensureFilesAreCleanedUp(shuffledRdd)
+  }
+
+  private def ensureFilesAreCleanedUp(shuffledRdd: ShuffledRDD[_, _, _]): Unit = {
+    def getAllFiles: Set[File] =
+      FileUtils.listFiles(tempDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+    val filesBeforeShuffle = getAllFiles
+    // Force the shuffle to be performed
+    shuffledRdd.count()
+    // Ensure that the shuffle actually created files that will need to be cleaned up
+    val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+    filesCreatedByShuffle.map(_.getName) should be (Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+    // Check that the cleanup actually removes the files
+    sc.env.blockManager.master.removeShuffle(0, blocking = true)
+    for (file <- filesCreatedByShuffle) {
+      assert (!file.exists(), s"Shuffle file $file was not cleaned up")
+    }
+  }
 }
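
The two new SortShuffleSuite tests above hinge on which path SortShuffleManager.canUseSerializedShuffle selects: Kryo output can be relocated after serialization, so the dependency qualifies for the serialized path, while JavaSerializer output cannot. A minimal sketch of that decision is below, assuming the criteria are relocatable serialized output, no map-side aggregation, and a partition count within the 24-bit packed-pointer limit; the class name, method signature, and constant are illustrative and are not Spark's API.

    // Illustrative restatement of the serialized-shuffle eligibility check described above.
    public class SerializedShuffleCheckSketch {
      static final int MAX_PARTITIONS_FOR_SERIALIZED_MODE = 1 << 24;  // assumed limit

      static boolean canUseSerializedShuffle(
          boolean serializerSupportsRelocation,
          boolean hasMapSideAggregation,
          int numPartitions) {
        return serializerSupportsRelocation
            && !hasMapSideAggregation
            && numPartitions <= MAX_PARTITIONS_FOR_SERIALIZED_MODE;
      }

      public static void main(String[] args) {
        // Kryo-like serializer, no aggregation, 4 partitions -> serialized path.
        System.out.println(canUseSerializedShuffle(true, false, 4));   // true
        // JavaSerializer-like serializer -> falls back to the deserialized path.
        System.out.println(canUseSerializedShuffle(false, false, 4));  // false
      }
    }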

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 5b01ddb..3816b8c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -1062,10 +1062,10 @@ class DAGSchedulerSuite
    */
   test("don't submit stage until its dependencies map outputs are registered (SPARK-5259)") {
     val firstRDD = new MyRDD(sc, 3, Nil)
-    val firstShuffleDep = new ShuffleDependency(firstRDD, null)
+    val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2))
     val firstShuffleId = firstShuffleDep.shuffleId
     val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep))
-    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
     val reduceRdd = new MyRDD(sc, 1, List(shuffleDep))
     submit(reduceRdd, Array(0))
 
@@ -1175,7 +1175,7 @@ class DAGSchedulerSuite
    */
   test("register map outputs correctly after ExecutorLost and task Resubmitted") {
     val firstRDD = new MyRDD(sc, 3, Nil)
-    val firstShuffleDep = new ShuffleDependency(firstRDD, null)
+    val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2))
     val reduceRdd = new MyRDD(sc, 5, List(firstShuffleDep))
     submit(reduceRdd, Array(0))
 

