http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala deleted file mode 100644 index 87a786b..0000000 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala +++ /dev/null @@ -1,273 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.collection - -import java.io.InputStream -import java.nio.IntBuffer -import java.util.Comparator - -import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance} -import org.apache.spark.storage.DiskBlockObjectWriter -import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._ - -/** - * Append-only buffer of key-value pairs, each with a corresponding partition ID, that serializes - * its records upon insert and stores them as raw bytes. - * - * We use two data-structures to store the contents. The serialized records are stored in a - * ChainedBuffer that can expand gracefully as records are added. This buffer is accompanied by a - * metadata buffer that stores pointers into the data buffer as well as the partition ID of each - * record. Each entry in the metadata buffer takes up a fixed amount of space. - * - * Sorting the collection means swapping entries in the metadata buffer - the record buffer need not - * be modified at all. Storing the partition IDs in the metadata buffer means that comparisons can - * happen without following any pointers, which should minimize cache misses. - * - * Currently, only sorting by partition is supported. - * - * Each record is laid out inside the the metaBuffer as follows. keyStart, a long, is split across - * two integers: - * - * +-------------+------------+------------+-------------+ - * | keyStart | keyValLen | partitionId | - * +-------------+------------+------------+-------------+ - * - * The buffer can support up to `536870911 (2 ^ 29 - 1)` records. - * - * @param metaInitialRecords The initial number of entries in the metadata buffer. - * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records. - * @param serializerInstance the serializer used for serializing inserted records. 
- */ -private[spark] class PartitionedSerializedPairBuffer[K, V]( - metaInitialRecords: Int, - kvBlockSize: Int, - serializerInstance: SerializerInstance) - extends WritablePartitionedPairCollection[K, V] with SizeTracker { - - if (serializerInstance.isInstanceOf[JavaSerializerInstance]) { - throw new IllegalArgumentException("PartitionedSerializedPairBuffer does not support" + - " Java-serialized objects.") - } - - require(metaInitialRecords <= MAXIMUM_RECORDS, - s"Can't make capacity bigger than ${MAXIMUM_RECORDS} records") - private var metaBuffer = IntBuffer.allocate(metaInitialRecords * RECORD_SIZE) - - private val kvBuffer: ChainedBuffer = new ChainedBuffer(kvBlockSize) - private val kvOutputStream = new ChainedBufferOutputStream(kvBuffer) - private val kvSerializationStream = serializerInstance.serializeStream(kvOutputStream) - - def insert(partition: Int, key: K, value: V): Unit = { - if (metaBuffer.position == metaBuffer.capacity) { - growMetaBuffer() - } - - val keyStart = kvBuffer.size - kvSerializationStream.writeKey[Any](key) - kvSerializationStream.writeValue[Any](value) - kvSerializationStream.flush() - val keyValLen = (kvBuffer.size - keyStart).toInt - - // keyStart, a long, gets split across two ints - metaBuffer.put(keyStart.toInt) - metaBuffer.put((keyStart >> 32).toInt) - metaBuffer.put(keyValLen) - metaBuffer.put(partition) - } - - /** Double the size of the array because we've reached capacity */ - private def growMetaBuffer(): Unit = { - if (metaBuffer.capacity >= MAXIMUM_META_BUFFER_CAPACITY) { - throw new IllegalStateException(s"Can't insert more than ${MAXIMUM_RECORDS} records") - } - val newCapacity = - if (metaBuffer.capacity * 2 < 0 || metaBuffer.capacity * 2 > MAXIMUM_META_BUFFER_CAPACITY) { - // Overflow - MAXIMUM_META_BUFFER_CAPACITY - } else { - metaBuffer.capacity * 2 - } - val newMetaBuffer = IntBuffer.allocate(newCapacity) - newMetaBuffer.put(metaBuffer.array) - metaBuffer = newMetaBuffer - } - - /** Iterate through the data in a given order. For this class this is not really destructive. 
*/ - override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]]) - : Iterator[((Int, K), V)] = { - sort(keyComparator) - val is = orderedInputStream - val deserStream = serializerInstance.deserializeStream(is) - new Iterator[((Int, K), V)] { - var metaBufferPos = 0 - def hasNext: Boolean = metaBufferPos < metaBuffer.position - def next(): ((Int, K), V) = { - val key = deserStream.readKey[Any]().asInstanceOf[K] - val value = deserStream.readValue[Any]().asInstanceOf[V] - val partition = metaBuffer.get(metaBufferPos + PARTITION) - metaBufferPos += RECORD_SIZE - ((partition, key), value) - } - } - } - - override def estimateSize: Long = metaBuffer.capacity * 4L + kvBuffer.capacity - - override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]]) - : WritablePartitionedIterator = { - sort(keyComparator) - new WritablePartitionedIterator { - // current position in the meta buffer in ints - var pos = 0 - - def writeNext(writer: DiskBlockObjectWriter): Unit = { - val keyStart = getKeyStartPos(metaBuffer, pos) - val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN) - pos += RECORD_SIZE - kvBuffer.read(keyStart, writer, keyValLen) - writer.recordWritten() - } - def nextPartition(): Int = metaBuffer.get(pos + PARTITION) - def hasNext(): Boolean = pos < metaBuffer.position - } - } - - // Visible for testing - def orderedInputStream: OrderedInputStream = { - new OrderedInputStream(metaBuffer, kvBuffer) - } - - private def sort(keyComparator: Option[Comparator[K]]): Unit = { - val comparator = if (keyComparator.isEmpty) { - new Comparator[Int]() { - def compare(partition1: Int, partition2: Int): Int = { - partition1 - partition2 - } - } - } else { - throw new UnsupportedOperationException() - } - - val sorter = new Sorter(new SerializedSortDataFormat) - sorter.sort(metaBuffer, 0, metaBuffer.position / RECORD_SIZE, comparator) - } -} - -private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer) - extends InputStream { - - import PartitionedSerializedPairBuffer._ - - private var metaBufferPos = 0 - private var kvBufferPos = - if (metaBuffer.position > 0) getKeyStartPos(metaBuffer, metaBufferPos) else 0 - - override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length) - - override def read(bytes: Array[Byte], offs: Int, len: Int): Int = { - if (metaBufferPos >= metaBuffer.position) { - return -1 - } - val bytesRemainingInRecord = (metaBuffer.get(metaBufferPos + KEY_VAL_LEN) - - (kvBufferPos - getKeyStartPos(metaBuffer, metaBufferPos))).toInt - val toRead = math.min(bytesRemainingInRecord, len) - kvBuffer.read(kvBufferPos, bytes, offs, toRead) - if (toRead == bytesRemainingInRecord) { - metaBufferPos += RECORD_SIZE - if (metaBufferPos < metaBuffer.position) { - kvBufferPos = getKeyStartPos(metaBuffer, metaBufferPos) - } - } else { - kvBufferPos += toRead - } - toRead - } - - override def read(): Int = { - throw new UnsupportedOperationException() - } -} - -private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuffer] { - - private val META_BUFFER_TMP = new Array[Int](RECORD_SIZE) - - /** Return the sort key for the element at the given index. */ - override protected def getKey(metaBuffer: IntBuffer, pos: Int): Int = { - metaBuffer.get(pos * RECORD_SIZE + PARTITION) - } - - /** Swap two elements. 
*/ - override def swap(metaBuffer: IntBuffer, pos0: Int, pos1: Int): Unit = { - val iOff = pos0 * RECORD_SIZE - val jOff = pos1 * RECORD_SIZE - System.arraycopy(metaBuffer.array, iOff, META_BUFFER_TMP, 0, RECORD_SIZE) - System.arraycopy(metaBuffer.array, jOff, metaBuffer.array, iOff, RECORD_SIZE) - System.arraycopy(META_BUFFER_TMP, 0, metaBuffer.array, jOff, RECORD_SIZE) - } - - /** Copy a single element from src(srcPos) to dst(dstPos). */ - override def copyElement( - src: IntBuffer, - srcPos: Int, - dst: IntBuffer, - dstPos: Int): Unit = { - val srcOff = srcPos * RECORD_SIZE - val dstOff = dstPos * RECORD_SIZE - System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE) - } - - /** - * Copy a range of elements starting at src(srcPos) to dst, starting at dstPos. - * Overlapping ranges are allowed. - */ - override def copyRange( - src: IntBuffer, - srcPos: Int, - dst: IntBuffer, - dstPos: Int, - length: Int): Unit = { - val srcOff = srcPos * RECORD_SIZE - val dstOff = dstPos * RECORD_SIZE - System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE * length) - } - - /** - * Allocates a Buffer that can hold up to 'length' elements. - * All elements of the buffer should be considered invalid until data is explicitly copied in. - */ - override def allocate(length: Int): IntBuffer = { - IntBuffer.allocate(length * RECORD_SIZE) - } -} - -private object PartitionedSerializedPairBuffer { - val KEY_START = 0 // keyStart, a long, gets split across two ints - val KEY_VAL_LEN = 2 - val PARTITION = 3 - val RECORD_SIZE = PARTITION + 1 // num ints of metadata - - val MAXIMUM_RECORDS = Int.MaxValue / RECORD_SIZE // (2 ^ 29) - 1 - val MAXIMUM_META_BUFFER_CAPACITY = MAXIMUM_RECORDS * RECORD_SIZE // (2 ^ 31) - 4 - - def getKeyStartPos(metaBuffer: IntBuffer, metaBufferPos: Int): Long = { - val lower32 = metaBuffer.get(metaBufferPos + KEY_START) - val upper32 = metaBuffer.get(metaBufferPos + KEY_START + 1) - (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL) - } -}
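As an aside on the record layout described in the scaladoc of the deleted PartitionedSerializedPairBuffer above: keyStart is a long that gets split across the first two ints of each metadata record and is reassembled by getKeyStartPos. The following is a minimal standalone sketch, not part of this patch, that mirrors the split in insert() and the reconstruction in getKeyStartPos (the example offset value is illustrative only):

    // Sketch only: pack a long keyStart into two ints and reassemble it.
    object KeyStartPackingSketch {
      // Lower 32 bits first, then upper 32 bits, matching the order insert() writes them.
      def split(keyStart: Long): (Int, Int) =
        (keyStart.toInt, (keyStart >> 32).toInt)

      // Mask the lower word to avoid sign extension when widening back to a long.
      def combine(lower32: Int, upper32: Int): Long =
        (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL)

      def main(args: Array[String]): Unit = {
        val keyStart = 5368709120L  // hypothetical offset into the ChainedBuffer (> Int.MaxValue)
        val (lo, hi) = split(keyStart)
        assert(combine(lo, hi) == keyStart)
      }
    }

Storing the two halves in the int-based metaBuffer keeps every metadata record a fixed RECORD_SIZE ints wide, which is what lets sorting swap metadata entries without touching the serialized record bytes.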
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java new file mode 100644 index 0000000..232ae4d --- /dev/null +++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import org.apache.spark.shuffle.sort.PackedRecordPointer; +import org.junit.Test; +import static org.junit.Assert.*; + +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import static org.apache.spark.shuffle.sort.PackedRecordPointer.*; + +public class PackedRecordPointerSuite { + + @Test + public void heap() { + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock page0 = memoryManager.allocatePage(128); + final MemoryBlock page1 = memoryManager.allocatePage(128); + final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, + page1.getBaseOffset() + 42); + PackedRecordPointer packedPointer = new PackedRecordPointer(); + packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360)); + assertEquals(360, packedPointer.getPartitionId()); + final long recordPointer = packedPointer.getRecordPointer(); + assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer)); + assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer)); + assertEquals(addressInPage1, recordPointer); + memoryManager.cleanUpAllAllocatedMemory(); + } + + @Test + public void offHeap() { + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); + final MemoryBlock page0 = memoryManager.allocatePage(128); + final MemoryBlock page1 = memoryManager.allocatePage(128); + final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, + page1.getBaseOffset() + 42); + PackedRecordPointer packedPointer = new PackedRecordPointer(); + packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360)); + assertEquals(360, packedPointer.getPartitionId()); + final long recordPointer = packedPointer.getRecordPointer(); + assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer)); + assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer)); + 
assertEquals(addressInPage1, recordPointer); + memoryManager.cleanUpAllAllocatedMemory(); + } + + @Test + public void maximumPartitionIdCanBeEncoded() { + PackedRecordPointer packedPointer = new PackedRecordPointer(); + packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID)); + assertEquals(MAXIMUM_PARTITION_ID, packedPointer.getPartitionId()); + } + + @Test + public void partitionIdsGreaterThanMaximumPartitionIdWillOverflowOrTriggerError() { + PackedRecordPointer packedPointer = new PackedRecordPointer(); + try { + // Pointers greater than the maximum partition ID will overflow or trigger an assertion error + packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1)); + assertFalse(MAXIMUM_PARTITION_ID + 1 == packedPointer.getPartitionId()); + } catch (AssertionError e ) { + // pass + } + } + + @Test + public void maximumOffsetInPageCanBeEncoded() { + PackedRecordPointer packedPointer = new PackedRecordPointer(); + long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES - 1); + packedPointer.set(PackedRecordPointer.packPointer(address, 0)); + assertEquals(address, packedPointer.getRecordPointer()); + } + + @Test + public void offsetsPastMaxOffsetInPageWillOverflow() { + PackedRecordPointer packedPointer = new PackedRecordPointer(); + long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES); + packedPointer.set(PackedRecordPointer.packPointer(address, 0)); + assertEquals(0, packedPointer.getRecordPointer()); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java new file mode 100644 index 0000000..1ef3c5f --- /dev/null +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort; + +import java.util.Arrays; +import java.util.Random; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +public class ShuffleInMemorySorterSuite { + + private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { + final byte[] strBytes = new byte[strLength]; + Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength); + return new String(strBytes); + } + + @Test + public void testSortingEmptyInput() { + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(100); + final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); + assert(!iter.hasNext()); + } + + @Test + public void testBasicSorting() throws Exception { + final String[] dataToSort = new String[] { + "Boba", + "Pearls", + "Tapioca", + "Taho", + "Condensed Milk", + "Jasmine", + "Milk Tea", + "Lychee", + "Mango" + }; + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock dataPage = memoryManager.allocatePage(2048); + final Object baseObject = dataPage.getBaseObject(); + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4); + final HashPartitioner hashPartitioner = new HashPartitioner(4); + + // Write the records into the data page and store pointers into the sorter + long position = dataPage.getBaseOffset(); + for (String str : dataToSort) { + final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position); + final byte[] strBytes = str.getBytes("utf-8"); + Platform.putInt(baseObject, position, strBytes.length); + position += 4; + Platform.copyMemory( + strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length); + position += strBytes.length; + sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str)); + } + + // Sort the records + final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); + int prevPartitionId = -1; + Arrays.sort(dataToSort); + for (int i = 0; i < dataToSort.length; i++) { + Assert.assertTrue(iter.hasNext()); + iter.loadNext(); + final int partitionId = iter.packedRecordPointer.getPartitionId(); + Assert.assertTrue(partitionId >= 0 && partitionId <= 3); + Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId, + partitionId >= prevPartitionId); + final long recordAddress = iter.packedRecordPointer.getRecordPointer(); + final int recordLength = Platform.getInt( + memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress)); + final String str = getStringFromDataPage( + memoryManager.getPage(recordAddress), + memoryManager.getOffsetInPage(recordAddress) + 4, // skip over record length + recordLength); + Assert.assertTrue(Arrays.binarySearch(dataToSort, str) != -1); + } + Assert.assertFalse(iter.hasNext()); + } + + @Test + public void testSortingManyNumbers() throws Exception { + ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4); + int[] numbersToSort = new int[128000]; + Random random = new Random(16); + for (int i = 0; i < numbersToSort.length; i++) { + numbersToSort[i] = 
random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1); + sorter.insertRecord(0, numbersToSort[i]); + } + Arrays.sort(numbersToSort); + int[] sorterResult = new int[numbersToSort.length]; + ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); + int j = 0; + while (iter.hasNext()) { + iter.loadNext(); + sorterResult[j] = iter.packedRecordPointer.getPartitionId(); + j += 1; + } + Assert.assertArrayEquals(numbersToSort, sorterResult); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java new file mode 100644 index 0000000..29d9823 --- /dev/null +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -0,0 +1,560 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort; + +import java.io.*; +import java.nio.ByteBuffer; +import java.util.*; + +import scala.*; +import scala.collection.Iterator; +import scala.runtime.AbstractFunction1; + +import com.google.common.collect.Iterators; +import com.google.common.collect.HashMultiset; +import com.google.common.io.ByteStreams; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.lessThan; +import static org.junit.Assert.*; +import static org.mockito.AdditionalAnswers.returnsFirstArg; +import static org.mockito.Answers.RETURNS_SMART_NULLS; +import static org.mockito.Mockito.*; + +import org.apache.spark.*; +import org.apache.spark.io.CompressionCodec$; +import org.apache.spark.io.LZ4CompressionCodec; +import org.apache.spark.io.LZFCompressionCodec; +import org.apache.spark.io.SnappyCompressionCodec; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.network.util.LimitedInputStream; +import org.apache.spark.serializer.*; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.shuffle.sort.SerializedShuffleHandle; +import org.apache.spark.storage.*; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +public class UnsafeShuffleWriterSuite { + + static final int NUM_PARTITITONS = 4; + final TaskMemoryManager taskMemoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS); + File mergedOutputFile; + File tempDir; + long[] partitionSizesInMergedFile; + final LinkedList<File> spillFilesCreated = new LinkedList<File>(); + SparkConf conf; + final Serializer serializer = new KryoSerializer(new SparkConf()); + TaskMetrics taskMetrics; + + @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager; + @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; + @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver; + @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; + @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; + @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep; + + private final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> { + @Override + public OutputStream apply(OutputStream stream) { + if (conf.getBoolean("spark.shuffle.compress", true)) { + return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream); + } else { + return stream; + } + } + } + + @After + public void tearDown() { + Utils.deleteRecursively(tempDir); + final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); + if (leakedMemory != 0) { + fail("Test leaked " + leakedMemory + " bytes of managed memory"); + } + } + + @Before + @SuppressWarnings("unchecked") + public void setUp() throws IOException { + MockitoAnnotations.initMocks(this); + tempDir 
= Utils.createTempDir("test", "test"); + mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); + partitionSizesInMergedFile = null; + spillFilesCreated.clear(); + conf = new SparkConf().set("spark.buffer.pageSize", "128m"); + taskMetrics = new TaskMetrics(); + + when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); + when(shuffleMemoryManager.pageSizeBytes()).thenReturn(128L * 1024 * 1024); + + when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); + when(blockManager.getDiskWriter( + any(BlockId.class), + any(File.class), + any(SerializerInstance.class), + anyInt(), + any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() { + @Override + public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { + Object[] args = invocationOnMock.getArguments(); + + return new DiskBlockObjectWriter( + (File) args[1], + (SerializerInstance) args[2], + (Integer) args[3], + new CompressStream(), + false, + (ShuffleWriteMetrics) args[4] + ); + } + }); + when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer( + new Answer<InputStream>() { + @Override + public InputStream answer(InvocationOnMock invocation) throws Throwable { + assert (invocation.getArguments()[0] instanceof TempShuffleBlockId); + InputStream is = (InputStream) invocation.getArguments()[1]; + if (conf.getBoolean("spark.shuffle.compress", true)) { + return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is); + } else { + return is; + } + } + } + ); + + when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer( + new Answer<OutputStream>() { + @Override + public OutputStream answer(InvocationOnMock invocation) throws Throwable { + assert (invocation.getArguments()[0] instanceof TempShuffleBlockId); + OutputStream os = (OutputStream) invocation.getArguments()[1]; + if (conf.getBoolean("spark.shuffle.compress", true)) { + return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os); + } else { + return os; + } + } + } + ); + + when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); + doAnswer(new Answer<Void>() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; + return null; + } + }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class)); + + when(diskBlockManager.createTempShuffleBlock()).thenAnswer( + new Answer<Tuple2<TempShuffleBlockId, File>>() { + @Override + public Tuple2<TempShuffleBlockId, File> answer( + InvocationOnMock invocationOnMock) throws Throwable { + TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); + File file = File.createTempFile("spillFile", ".spill", tempDir); + spillFilesCreated.add(file); + return Tuple2$.MODULE$.apply(blockId, file); + } + }); + + when(taskContext.taskMetrics()).thenReturn(taskMetrics); + when(taskContext.internalMetricsToAccumulators()).thenReturn(null); + + when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer)); + when(shuffleDep.partitioner()).thenReturn(hashPartitioner); + } + + private UnsafeShuffleWriter<Object, Object> createWriter( + boolean transferToEnabled) throws IOException { + conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); + return new UnsafeShuffleWriter<Object, Object>( + blockManager, + shuffleBlockResolver, + taskMemoryManager, + 
shuffleMemoryManager, + new SerializedShuffleHandle<Object, Object>(0, 1, shuffleDep), + 0, // map id + taskContext, + conf + ); + } + + private void assertSpillFilesWereCleanedUp() { + for (File spillFile : spillFilesCreated) { + assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up", + spillFile.exists()); + } + } + + private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException { + final ArrayList<Tuple2<Object, Object>> recordsList = new ArrayList<Tuple2<Object, Object>>(); + long startOffset = 0; + for (int i = 0; i < NUM_PARTITITONS; i++) { + final long partitionSize = partitionSizesInMergedFile[i]; + if (partitionSize > 0) { + InputStream in = new FileInputStream(mergedOutputFile); + ByteStreams.skipFully(in, startOffset); + in = new LimitedInputStream(in, partitionSize); + if (conf.getBoolean("spark.shuffle.compress", true)) { + in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in); + } + DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in); + Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator(); + while (records.hasNext()) { + Tuple2<Object, Object> record = records.next(); + assertEquals(i, hashPartitioner.getPartition(record._1())); + recordsList.add(record); + } + recordsStream.close(); + startOffset += partitionSize; + } + } + return recordsList; + } + + @Test(expected=IllegalStateException.class) + public void mustCallWriteBeforeSuccessfulStop() throws IOException { + createWriter(false).stop(true); + } + + @Test + public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException { + createWriter(false).stop(false); + } + + class PandaException extends RuntimeException { + } + + @Test(expected=PandaException.class) + public void writeFailurePropagates() throws Exception { + class BadRecords extends scala.collection.AbstractIterator<Product2<Object, Object>> { + @Override public boolean hasNext() { + throw new PandaException(); + } + @Override public Product2<Object, Object> next() { + return null; + } + } + final UnsafeShuffleWriter<Object, Object> writer = createWriter(true); + writer.write(new BadRecords()); + } + + @Test + public void writeEmptyIterator() throws Exception { + final UnsafeShuffleWriter<Object, Object> writer = createWriter(true); + writer.write(Iterators.<Product2<Object, Object>>emptyIterator()); + final Option<MapStatus> mapStatus = writer.stop(true); + assertTrue(mapStatus.isDefined()); + assertTrue(mergedOutputFile.exists()); + assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); + assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleRecordsWritten()); + assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleBytesWritten()); + assertEquals(0, taskMetrics.diskBytesSpilled()); + assertEquals(0, taskMetrics.memoryBytesSpilled()); + } + + @Test + public void writeWithoutSpilling() throws Exception { + // In this example, each partition should have exactly one record: + final ArrayList<Product2<Object, Object>> dataToWrite = + new ArrayList<Product2<Object, Object>>(); + for (int i = 0; i < NUM_PARTITITONS; i++) { + dataToWrite.add(new Tuple2<Object, Object>(i, i)); + } + final UnsafeShuffleWriter<Object, Object> writer = createWriter(true); + writer.write(dataToWrite.iterator()); + final Option<MapStatus> mapStatus = writer.stop(true); + assertTrue(mapStatus.isDefined()); + assertTrue(mergedOutputFile.exists()); + + long sumOfPartitionSizes = 0; + for (long size: partitionSizesInMergedFile) { + // 
All partitions should be the same size: + assertEquals(partitionSizesInMergedFile[0], size); + sumOfPartitionSizes += size; + } + assertEquals(mergedOutputFile.length(), sumOfPartitionSizes); + assertEquals( + HashMultiset.create(dataToWrite), + HashMultiset.create(readRecordsFromFile())); + assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertEquals(0, taskMetrics.diskBytesSpilled()); + assertEquals(0, taskMetrics.memoryBytesSpilled()); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + } + + private void testMergingSpills( + boolean transferToEnabled, + String compressionCodecName) throws IOException { + if (compressionCodecName != null) { + conf.set("spark.shuffle.compress", "true"); + conf.set("spark.io.compression.codec", compressionCodecName); + } else { + conf.set("spark.shuffle.compress", "false"); + } + final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled); + final ArrayList<Product2<Object, Object>> dataToWrite = + new ArrayList<Product2<Object, Object>>(); + for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) { + dataToWrite.add(new Tuple2<Object, Object>(i, i)); + } + writer.insertRecordIntoSorter(dataToWrite.get(0)); + writer.insertRecordIntoSorter(dataToWrite.get(1)); + writer.insertRecordIntoSorter(dataToWrite.get(2)); + writer.insertRecordIntoSorter(dataToWrite.get(3)); + writer.forceSorterToSpill(); + writer.insertRecordIntoSorter(dataToWrite.get(4)); + writer.insertRecordIntoSorter(dataToWrite.get(5)); + writer.closeAndWriteOutput(); + final Option<MapStatus> mapStatus = writer.stop(true); + assertTrue(mapStatus.isDefined()); + assertTrue(mergedOutputFile.exists()); + assertEquals(2, spillFilesCreated.size()); + + long sumOfPartitionSizes = 0; + for (long size: partitionSizesInMergedFile) { + sumOfPartitionSizes += size; + } + assertEquals(sumOfPartitionSizes, mergedOutputFile.length()); + + assertEquals( + HashMultiset.create(dataToWrite), + HashMultiset.create(readRecordsFromFile())); + assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); + assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); + assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + } + + @Test + public void mergeSpillsWithTransferToAndLZF() throws Exception { + testMergingSpills(true, LZFCompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithFileStreamAndLZF() throws Exception { + testMergingSpills(false, LZFCompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithTransferToAndLZ4() throws Exception { + testMergingSpills(true, LZ4CompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithFileStreamAndLZ4() throws Exception { + testMergingSpills(false, LZ4CompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithTransferToAndSnappy() throws Exception { + testMergingSpills(true, SnappyCompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithFileStreamAndSnappy() throws Exception { + testMergingSpills(false, SnappyCompressionCodec.class.getName()); + } + + @Test + 
public void mergeSpillsWithTransferToAndNoCompression() throws Exception { + testMergingSpills(true, null); + } + + @Test + public void mergeSpillsWithFileStreamAndNoCompression() throws Exception { + testMergingSpills(false, null); + } + + @Test + public void writeEnoughDataToTriggerSpill() throws Exception { + when(shuffleMemoryManager.tryToAcquire(anyLong())) + .then(returnsFirstArg()) // Allocate initial sort buffer + .then(returnsFirstArg()) // Allocate initial data page + .thenReturn(0L) // Deny request to allocate new data page + .then(returnsFirstArg()); // Grant new sort buffer and data page. + final UnsafeShuffleWriter<Object, Object> writer = createWriter(false); + final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>(); + final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128]; + for (int i = 0; i < 128 + 1; i++) { + dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray)); + } + writer.write(dataToWrite.iterator()); + verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong()); + assertEquals(2, spillFilesCreated.size()); + writer.stop(true); + readRecordsFromFile(); + assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); + assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); + assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + } + + @Test + public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception { + when(shuffleMemoryManager.tryToAcquire(anyLong())) + .then(returnsFirstArg()) // Allocate initial sort buffer + .then(returnsFirstArg()) // Allocate initial data page + .thenReturn(0L) // Deny request to grow sort buffer + .then(returnsFirstArg()); // Grant new sort buffer and data page. 
+ final UnsafeShuffleWriter<Object, Object> writer = createWriter(false); + final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>(); + for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) { + dataToWrite.add(new Tuple2<Object, Object>(i, i)); + } + writer.write(dataToWrite.iterator()); + verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong()); + assertEquals(2, spillFilesCreated.size()); + writer.stop(true); + readRecordsFromFile(); + assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); + assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); + assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + } + + @Test + public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception { + final UnsafeShuffleWriter<Object, Object> writer = createWriter(false); + final ArrayList<Product2<Object, Object>> dataToWrite = + new ArrayList<Product2<Object, Object>>(); + final byte[] bytes = new byte[(int) (ShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)]; + new Random(42).nextBytes(bytes); + dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(bytes))); + writer.write(dataToWrite.iterator()); + writer.stop(true); + assertEquals( + HashMultiset.create(dataToWrite), + HashMultiset.create(readRecordsFromFile())); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception { + final UnsafeShuffleWriter<Object, Object> writer = createWriter(false); + final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>(); + dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(new byte[1]))); + // We should be able to write a record that's right _at_ the max record size + final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()]; + new Random(42).nextBytes(atMaxRecordSize); + dataToWrite.add(new Tuple2<Object, Object>(2, ByteBuffer.wrap(atMaxRecordSize))); + // Inserting a record that's larger than the max record size + final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1]; + new Random(42).nextBytes(exceedsMaxRecordSize); + dataToWrite.add(new Tuple2<Object, Object>(3, ByteBuffer.wrap(exceedsMaxRecordSize))); + writer.write(dataToWrite.iterator()); + writer.stop(true); + assertEquals( + HashMultiset.create(dataToWrite), + HashMultiset.create(readRecordsFromFile())); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException { + final UnsafeShuffleWriter<Object, Object> writer = createWriter(false); + writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1)); + writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2)); + writer.forceSorterToSpill(); + writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2)); + writer.stop(false); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void testPeakMemoryUsed() throws Exception { + final long recordLengthBytes = 8; + final long pageSizeBytes = 256; + final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; + when(shuffleMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes); + final UnsafeShuffleWriter<Object, 
Object> writer = + new UnsafeShuffleWriter<Object, Object>( + blockManager, + shuffleBlockResolver, + taskMemoryManager, + shuffleMemoryManager, + new SerializedShuffleHandle<>(0, 1, shuffleDep), + 0, // map id + taskContext, + conf); + + // Peak memory should be monotonically increasing. More specifically, every time + // we allocate a new page it should increase by exactly the size of the page. + long previousPeakMemory = writer.getPeakMemoryUsedBytes(); + long newPeakMemory; + try { + for (int i = 0; i < numRecordsPerPage * 10; i++) { + writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1)); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + if (i % numRecordsPerPage == 0 && i != 0) { + // The first page is allocated in constructor, another page will be allocated after + // every numRecordsPerPage records (peak memory should change). + assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); + } else { + assertEquals(previousPeakMemory, newPeakMemory); + } + previousPeakMemory = newPeakMemory; + } + + // Spilling should not change peak memory + writer.forceSorterToSpill(); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + for (int i = 0; i < numRecordsPerPage; i++) { + writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1)); + } + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + + // Closing the writer should not change peak memory + writer.closeAndWriteOutput(); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + } finally { + writer.stop(false); + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java deleted file mode 100644 index 934b7e0..0000000 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.unsafe; - -import org.junit.Test; -import static org.junit.Assert.*; - -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; -import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; -import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*; - -public class PackedRecordPointerSuite { - - @Test - public void heap() { - final TaskMemoryManager memoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); - final MemoryBlock page0 = memoryManager.allocatePage(128); - final MemoryBlock page1 = memoryManager.allocatePage(128); - final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, - page1.getBaseOffset() + 42); - PackedRecordPointer packedPointer = new PackedRecordPointer(); - packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360)); - assertEquals(360, packedPointer.getPartitionId()); - final long recordPointer = packedPointer.getRecordPointer(); - assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer)); - assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer)); - assertEquals(addressInPage1, recordPointer); - memoryManager.cleanUpAllAllocatedMemory(); - } - - @Test - public void offHeap() { - final TaskMemoryManager memoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); - final MemoryBlock page0 = memoryManager.allocatePage(128); - final MemoryBlock page1 = memoryManager.allocatePage(128); - final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, - page1.getBaseOffset() + 42); - PackedRecordPointer packedPointer = new PackedRecordPointer(); - packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360)); - assertEquals(360, packedPointer.getPartitionId()); - final long recordPointer = packedPointer.getRecordPointer(); - assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer)); - assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer)); - assertEquals(addressInPage1, recordPointer); - memoryManager.cleanUpAllAllocatedMemory(); - } - - @Test - public void maximumPartitionIdCanBeEncoded() { - PackedRecordPointer packedPointer = new PackedRecordPointer(); - packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID)); - assertEquals(MAXIMUM_PARTITION_ID, packedPointer.getPartitionId()); - } - - @Test - public void partitionIdsGreaterThanMaximumPartitionIdWillOverflowOrTriggerError() { - PackedRecordPointer packedPointer = new PackedRecordPointer(); - try { - // Pointers greater than the maximum partition ID will overflow or trigger an assertion error - packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1)); - assertFalse(MAXIMUM_PARTITION_ID + 1 == packedPointer.getPartitionId()); - } catch (AssertionError e ) { - // pass - } - } - - @Test - public void maximumOffsetInPageCanBeEncoded() { - PackedRecordPointer packedPointer = new PackedRecordPointer(); - long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES - 1); - packedPointer.set(PackedRecordPointer.packPointer(address, 0)); - assertEquals(address, packedPointer.getRecordPointer()); - } - - @Test - public void offsetsPastMaxOffsetInPageWillOverflow() { - PackedRecordPointer packedPointer = new PackedRecordPointer(); - long address = TaskMemoryManager.encodePageNumberAndOffset(0, 
MAXIMUM_PAGE_SIZE_BYTES); - packedPointer.set(PackedRecordPointer.packPointer(address, 0)); - assertEquals(0, packedPointer.getRecordPointer()); - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java deleted file mode 100644 index 40fefe2..0000000 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.unsafe; - -import java.util.Arrays; -import java.util.Random; - -import org.junit.Assert; -import org.junit.Test; - -import org.apache.spark.HashPartitioner; -import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; -import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; - -public class UnsafeShuffleInMemorySorterSuite { - - private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { - final byte[] strBytes = new byte[strLength]; - Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength); - return new String(strBytes); - } - - @Test - public void testSortingEmptyInput() { - final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(100); - final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); - assert(!iter.hasNext()); - } - - @Test - public void testBasicSorting() throws Exception { - final String[] dataToSort = new String[] { - "Boba", - "Pearls", - "Tapioca", - "Taho", - "Condensed Milk", - "Jasmine", - "Milk Tea", - "Lychee", - "Mango" - }; - final TaskMemoryManager memoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); - final MemoryBlock dataPage = memoryManager.allocatePage(2048); - final Object baseObject = dataPage.getBaseObject(); - final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4); - final HashPartitioner hashPartitioner = new HashPartitioner(4); - - // Write the records into the data page and store pointers into the sorter - long position = dataPage.getBaseOffset(); - for (String str : dataToSort) { - final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position); - final byte[] strBytes = str.getBytes("utf-8"); - Platform.putInt(baseObject, position, 
strBytes.length); - position += 4; - Platform.copyMemory( - strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length); - position += strBytes.length; - sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str)); - } - - // Sort the records - final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); - int prevPartitionId = -1; - Arrays.sort(dataToSort); - for (int i = 0; i < dataToSort.length; i++) { - Assert.assertTrue(iter.hasNext()); - iter.loadNext(); - final int partitionId = iter.packedRecordPointer.getPartitionId(); - Assert.assertTrue(partitionId >= 0 && partitionId <= 3); - Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId, - partitionId >= prevPartitionId); - final long recordAddress = iter.packedRecordPointer.getRecordPointer(); - final int recordLength = Platform.getInt( - memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress)); - final String str = getStringFromDataPage( - memoryManager.getPage(recordAddress), - memoryManager.getOffsetInPage(recordAddress) + 4, // skip over record length - recordLength); - Assert.assertTrue(Arrays.binarySearch(dataToSort, str) != -1); - } - Assert.assertFalse(iter.hasNext()); - } - - @Test - public void testSortingManyNumbers() throws Exception { - UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4); - int[] numbersToSort = new int[128000]; - Random random = new Random(16); - for (int i = 0; i < numbersToSort.length; i++) { - numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1); - sorter.insertRecord(0, numbersToSort[i]); - } - Arrays.sort(numbersToSort); - int[] sorterResult = new int[numbersToSort.length]; - UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); - int j = 0; - while (iter.hasNext()) { - iter.loadNext(); - sorterResult[j] = iter.packedRecordPointer.getPartitionId(); - j += 1; - } - Assert.assertArrayEquals(numbersToSort, sorterResult); - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java deleted file mode 100644 index d218344..0000000 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ /dev/null @@ -1,560 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.unsafe; - -import java.io.*; -import java.nio.ByteBuffer; -import java.util.*; - -import scala.*; -import scala.collection.Iterator; -import scala.reflect.ClassTag; -import scala.runtime.AbstractFunction1; - -import com.google.common.collect.Iterators; -import com.google.common.collect.HashMultiset; -import com.google.common.io.ByteStreams; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.greaterThan; -import static org.hamcrest.Matchers.lessThan; -import static org.junit.Assert.*; -import static org.mockito.AdditionalAnswers.returnsFirstArg; -import static org.mockito.Answers.RETURNS_SMART_NULLS; -import static org.mockito.Mockito.*; - -import org.apache.spark.*; -import org.apache.spark.io.CompressionCodec$; -import org.apache.spark.io.LZ4CompressionCodec; -import org.apache.spark.io.LZFCompressionCodec; -import org.apache.spark.io.SnappyCompressionCodec; -import org.apache.spark.executor.ShuffleWriteMetrics; -import org.apache.spark.executor.TaskMetrics; -import org.apache.spark.network.util.LimitedInputStream; -import org.apache.spark.serializer.*; -import org.apache.spark.scheduler.MapStatus; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.ShuffleMemoryManager; -import org.apache.spark.storage.*; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; -import org.apache.spark.unsafe.memory.TaskMemoryManager; -import org.apache.spark.util.Utils; - -public class UnsafeShuffleWriterSuite { - - static final int NUM_PARTITITONS = 4; - final TaskMemoryManager taskMemoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); - final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS); - File mergedOutputFile; - File tempDir; - long[] partitionSizesInMergedFile; - final LinkedList<File> spillFilesCreated = new LinkedList<File>(); - SparkConf conf; - final Serializer serializer = new KryoSerializer(new SparkConf()); - TaskMetrics taskMetrics; - - @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager; - @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; - @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver; - @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; - @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; - @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep; - - private final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> { - @Override - public OutputStream apply(OutputStream stream) { - if (conf.getBoolean("spark.shuffle.compress", true)) { - return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream); - } else { - return stream; - } - } - } - - @After - public void tearDown() { - Utils.deleteRecursively(tempDir); - final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); - if (leakedMemory != 0) { - fail("Test leaked " + leakedMemory + " bytes of managed memory"); - } - } - - @Before - @SuppressWarnings("unchecked") - public void setUp() throws IOException { - MockitoAnnotations.initMocks(this); - tempDir = Utils.createTempDir("test", 
"test"); - mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); - partitionSizesInMergedFile = null; - spillFilesCreated.clear(); - conf = new SparkConf().set("spark.buffer.pageSize", "128m"); - taskMetrics = new TaskMetrics(); - - when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); - when(shuffleMemoryManager.pageSizeBytes()).thenReturn(128L * 1024 * 1024); - - when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); - when(blockManager.getDiskWriter( - any(BlockId.class), - any(File.class), - any(SerializerInstance.class), - anyInt(), - any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() { - @Override - public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { - Object[] args = invocationOnMock.getArguments(); - - return new DiskBlockObjectWriter( - (File) args[1], - (SerializerInstance) args[2], - (Integer) args[3], - new CompressStream(), - false, - (ShuffleWriteMetrics) args[4] - ); - } - }); - when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer( - new Answer<InputStream>() { - @Override - public InputStream answer(InvocationOnMock invocation) throws Throwable { - assert (invocation.getArguments()[0] instanceof TempShuffleBlockId); - InputStream is = (InputStream) invocation.getArguments()[1]; - if (conf.getBoolean("spark.shuffle.compress", true)) { - return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is); - } else { - return is; - } - } - } - ); - - when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer( - new Answer<OutputStream>() { - @Override - public OutputStream answer(InvocationOnMock invocation) throws Throwable { - assert (invocation.getArguments()[0] instanceof TempShuffleBlockId); - OutputStream os = (OutputStream) invocation.getArguments()[1]; - if (conf.getBoolean("spark.shuffle.compress", true)) { - return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os); - } else { - return os; - } - } - } - ); - - when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); - doAnswer(new Answer<Void>() { - @Override - public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; - return null; - } - }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class)); - - when(diskBlockManager.createTempShuffleBlock()).thenAnswer( - new Answer<Tuple2<TempShuffleBlockId, File>>() { - @Override - public Tuple2<TempShuffleBlockId, File> answer( - InvocationOnMock invocationOnMock) throws Throwable { - TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); - File file = File.createTempFile("spillFile", ".spill", tempDir); - spillFilesCreated.add(file); - return Tuple2$.MODULE$.apply(blockId, file); - } - }); - - when(taskContext.taskMetrics()).thenReturn(taskMetrics); - when(taskContext.internalMetricsToAccumulators()).thenReturn(null); - - when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer)); - when(shuffleDep.partitioner()).thenReturn(hashPartitioner); - } - - private UnsafeShuffleWriter<Object, Object> createWriter( - boolean transferToEnabled) throws IOException { - conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); - return new UnsafeShuffleWriter<Object, Object>( - blockManager, - shuffleBlockResolver, - taskMemoryManager, - shuffleMemoryManager, - new 
UnsafeShuffleHandle<Object, Object>(0, 1, shuffleDep), - 0, // map id - taskContext, - conf - ); - } - - private void assertSpillFilesWereCleanedUp() { - for (File spillFile : spillFilesCreated) { - assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up", - spillFile.exists()); - } - } - - private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException { - final ArrayList<Tuple2<Object, Object>> recordsList = new ArrayList<Tuple2<Object, Object>>(); - long startOffset = 0; - for (int i = 0; i < NUM_PARTITITONS; i++) { - final long partitionSize = partitionSizesInMergedFile[i]; - if (partitionSize > 0) { - InputStream in = new FileInputStream(mergedOutputFile); - ByteStreams.skipFully(in, startOffset); - in = new LimitedInputStream(in, partitionSize); - if (conf.getBoolean("spark.shuffle.compress", true)) { - in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in); - } - DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in); - Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator(); - while (records.hasNext()) { - Tuple2<Object, Object> record = records.next(); - assertEquals(i, hashPartitioner.getPartition(record._1())); - recordsList.add(record); - } - recordsStream.close(); - startOffset += partitionSize; - } - } - return recordsList; - } - - @Test(expected=IllegalStateException.class) - public void mustCallWriteBeforeSuccessfulStop() throws IOException { - createWriter(false).stop(true); - } - - @Test - public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException { - createWriter(false).stop(false); - } - - class PandaException extends RuntimeException { - } - - @Test(expected=PandaException.class) - public void writeFailurePropagates() throws Exception { - class BadRecords extends scala.collection.AbstractIterator<Product2<Object, Object>> { - @Override public boolean hasNext() { - throw new PandaException(); - } - @Override public Product2<Object, Object> next() { - return null; - } - } - final UnsafeShuffleWriter<Object, Object> writer = createWriter(true); - writer.write(new BadRecords()); - } - - @Test - public void writeEmptyIterator() throws Exception { - final UnsafeShuffleWriter<Object, Object> writer = createWriter(true); - writer.write(Iterators.<Product2<Object, Object>>emptyIterator()); - final Option<MapStatus> mapStatus = writer.stop(true); - assertTrue(mapStatus.isDefined()); - assertTrue(mergedOutputFile.exists()); - assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); - assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleRecordsWritten()); - assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleBytesWritten()); - assertEquals(0, taskMetrics.diskBytesSpilled()); - assertEquals(0, taskMetrics.memoryBytesSpilled()); - } - - @Test - public void writeWithoutSpilling() throws Exception { - // In this example, each partition should have exactly one record: - final ArrayList<Product2<Object, Object>> dataToWrite = - new ArrayList<Product2<Object, Object>>(); - for (int i = 0; i < NUM_PARTITITONS; i++) { - dataToWrite.add(new Tuple2<Object, Object>(i, i)); - } - final UnsafeShuffleWriter<Object, Object> writer = createWriter(true); - writer.write(dataToWrite.iterator()); - final Option<MapStatus> mapStatus = writer.stop(true); - assertTrue(mapStatus.isDefined()); - assertTrue(mergedOutputFile.exists()); - - long sumOfPartitionSizes = 0; - for (long size: partitionSizesInMergedFile) { - // All partitions should be the same 
size: - assertEquals(partitionSizesInMergedFile[0], size); - sumOfPartitionSizes += size; - } - assertEquals(mergedOutputFile.length(), sumOfPartitionSizes); - assertEquals( - HashMultiset.create(dataToWrite), - HashMultiset.create(readRecordsFromFile())); - assertSpillFilesWereCleanedUp(); - ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); - assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); - assertEquals(0, taskMetrics.diskBytesSpilled()); - assertEquals(0, taskMetrics.memoryBytesSpilled()); - assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); - } - - private void testMergingSpills( - boolean transferToEnabled, - String compressionCodecName) throws IOException { - if (compressionCodecName != null) { - conf.set("spark.shuffle.compress", "true"); - conf.set("spark.io.compression.codec", compressionCodecName); - } else { - conf.set("spark.shuffle.compress", "false"); - } - final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled); - final ArrayList<Product2<Object, Object>> dataToWrite = - new ArrayList<Product2<Object, Object>>(); - for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) { - dataToWrite.add(new Tuple2<Object, Object>(i, i)); - } - writer.insertRecordIntoSorter(dataToWrite.get(0)); - writer.insertRecordIntoSorter(dataToWrite.get(1)); - writer.insertRecordIntoSorter(dataToWrite.get(2)); - writer.insertRecordIntoSorter(dataToWrite.get(3)); - writer.forceSorterToSpill(); - writer.insertRecordIntoSorter(dataToWrite.get(4)); - writer.insertRecordIntoSorter(dataToWrite.get(5)); - writer.closeAndWriteOutput(); - final Option<MapStatus> mapStatus = writer.stop(true); - assertTrue(mapStatus.isDefined()); - assertTrue(mergedOutputFile.exists()); - assertEquals(2, spillFilesCreated.size()); - - long sumOfPartitionSizes = 0; - for (long size: partitionSizesInMergedFile) { - sumOfPartitionSizes += size; - } - assertEquals(sumOfPartitionSizes, mergedOutputFile.length()); - - assertEquals( - HashMultiset.create(dataToWrite), - HashMultiset.create(readRecordsFromFile())); - assertSpillFilesWereCleanedUp(); - ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); - assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); - assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); - assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); - assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); - assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); - } - - @Test - public void mergeSpillsWithTransferToAndLZF() throws Exception { - testMergingSpills(true, LZFCompressionCodec.class.getName()); - } - - @Test - public void mergeSpillsWithFileStreamAndLZF() throws Exception { - testMergingSpills(false, LZFCompressionCodec.class.getName()); - } - - @Test - public void mergeSpillsWithTransferToAndLZ4() throws Exception { - testMergingSpills(true, LZ4CompressionCodec.class.getName()); - } - - @Test - public void mergeSpillsWithFileStreamAndLZ4() throws Exception { - testMergingSpills(false, LZ4CompressionCodec.class.getName()); - } - - @Test - public void mergeSpillsWithTransferToAndSnappy() throws Exception { - testMergingSpills(true, SnappyCompressionCodec.class.getName()); - } - - @Test - public void mergeSpillsWithFileStreamAndSnappy() throws Exception { - testMergingSpills(false, SnappyCompressionCodec.class.getName()); - } - - @Test - public void 
mergeSpillsWithTransferToAndNoCompression() throws Exception { - testMergingSpills(true, null); - } - - @Test - public void mergeSpillsWithFileStreamAndNoCompression() throws Exception { - testMergingSpills(false, null); - } - - @Test - public void writeEnoughDataToTriggerSpill() throws Exception { - when(shuffleMemoryManager.tryToAcquire(anyLong())) - .then(returnsFirstArg()) // Allocate initial sort buffer - .then(returnsFirstArg()) // Allocate initial data page - .thenReturn(0L) // Deny request to allocate new data page - .then(returnsFirstArg()); // Grant new sort buffer and data page. - final UnsafeShuffleWriter<Object, Object> writer = createWriter(false); - final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>(); - final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128]; - for (int i = 0; i < 128 + 1; i++) { - dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray)); - } - writer.write(dataToWrite.iterator()); - verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong()); - assertEquals(2, spillFilesCreated.size()); - writer.stop(true); - readRecordsFromFile(); - assertSpillFilesWereCleanedUp(); - ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); - assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); - assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); - assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); - assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); - assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); - } - - @Test - public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception { - when(shuffleMemoryManager.tryToAcquire(anyLong())) - .then(returnsFirstArg()) // Allocate initial sort buffer - .then(returnsFirstArg()) // Allocate initial data page - .thenReturn(0L) // Deny request to grow sort buffer - .then(returnsFirstArg()); // Grant new sort buffer and data page. 
- final UnsafeShuffleWriter<Object, Object> writer = createWriter(false); - final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>(); - for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) { - dataToWrite.add(new Tuple2<Object, Object>(i, i)); - } - writer.write(dataToWrite.iterator()); - verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong()); - assertEquals(2, spillFilesCreated.size()); - writer.stop(true); - readRecordsFromFile(); - assertSpillFilesWereCleanedUp(); - ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); - assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); - assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); - assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); - assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); - assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); - } - - @Test - public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception { - final UnsafeShuffleWriter<Object, Object> writer = createWriter(false); - final ArrayList<Product2<Object, Object>> dataToWrite = - new ArrayList<Product2<Object, Object>>(); - final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)]; - new Random(42).nextBytes(bytes); - dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(bytes))); - writer.write(dataToWrite.iterator()); - writer.stop(true); - assertEquals( - HashMultiset.create(dataToWrite), - HashMultiset.create(readRecordsFromFile())); - assertSpillFilesWereCleanedUp(); - } - - @Test - public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception { - final UnsafeShuffleWriter<Object, Object> writer = createWriter(false); - final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>(); - dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(new byte[1]))); - // We should be able to write a record that's right _at_ the max record size - final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()]; - new Random(42).nextBytes(atMaxRecordSize); - dataToWrite.add(new Tuple2<Object, Object>(2, ByteBuffer.wrap(atMaxRecordSize))); - // Inserting a record that's larger than the max record size - final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1]; - new Random(42).nextBytes(exceedsMaxRecordSize); - dataToWrite.add(new Tuple2<Object, Object>(3, ByteBuffer.wrap(exceedsMaxRecordSize))); - writer.write(dataToWrite.iterator()); - writer.stop(true); - assertEquals( - HashMultiset.create(dataToWrite), - HashMultiset.create(readRecordsFromFile())); - assertSpillFilesWereCleanedUp(); - } - - @Test - public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException { - final UnsafeShuffleWriter<Object, Object> writer = createWriter(false); - writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1)); - writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2)); - writer.forceSorterToSpill(); - writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2)); - writer.stop(false); - assertSpillFilesWereCleanedUp(); - } - - @Test - public void testPeakMemoryUsed() throws Exception { - final long recordLengthBytes = 8; - final long pageSizeBytes = 256; - final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; - when(shuffleMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes); - final 
UnsafeShuffleWriter<Object, Object> writer = - new UnsafeShuffleWriter<Object, Object>( - blockManager, - shuffleBlockResolver, - taskMemoryManager, - shuffleMemoryManager, - new UnsafeShuffleHandle<>(0, 1, shuffleDep), - 0, // map id - taskContext, - conf); - - // Peak memory should be monotonically increasing. More specifically, every time - // we allocate a new page it should increase by exactly the size of the page. - long previousPeakMemory = writer.getPeakMemoryUsedBytes(); - long newPeakMemory; - try { - for (int i = 0; i < numRecordsPerPage * 10; i++) { - writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1)); - newPeakMemory = writer.getPeakMemoryUsedBytes(); - if (i % numRecordsPerPage == 0 && i != 0) { - // The first page is allocated in constructor, another page will be allocated after - // every numRecordsPerPage records (peak memory should change). - assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); - } else { - assertEquals(previousPeakMemory, newPeakMemory); - } - previousPeakMemory = newPeakMemory; - } - - // Spilling should not change peak memory - writer.forceSorterToSpill(); - newPeakMemory = writer.getPeakMemoryUsedBytes(); - assertEquals(previousPeakMemory, newPeakMemory); - for (int i = 0; i < numRecordsPerPage; i++) { - writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1)); - } - newPeakMemory = writer.getPeakMemoryUsedBytes(); - assertEquals(previousPeakMemory, newPeakMemory); - - // Closing the writer should not change peak memory - writer.closeAndWriteOutput(); - newPeakMemory = writer.getPeakMemoryUsedBytes(); - assertEquals(previousPeakMemory, newPeakMemory); - } finally { - writer.stop(false); - } - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index 6335817..b8ab227 100644 --- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -17,13 +17,78 @@ package org.apache.spark +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.FileUtils +import org.apache.commons.io.filefilter.TrueFileFilter import org.scalatest.BeforeAndAfterAll +import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.util.Utils + class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with sort-based shuffle. 
+ private var tempDir: File = _ + override def beforeAll() { conf.set("spark.shuffle.manager", "sort") } + + override def beforeEach(): Unit = { + tempDir = Utils.createTempDir() + conf.set("spark.local.dir", tempDir.getAbsolutePath) + } + + override def afterEach(): Unit = { + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } + } + + test("SortShuffleManager properly cleans up files for shuffles that use the serialized path") { + sc = new SparkContext("local", "test", conf) + // Create a shuffled RDD and verify that it actually uses the new serialized map output path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) + .setSerializer(new KryoSerializer(conf)) + val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(SortShuffleManager.canUseSerializedShuffle(shuffleDep)) + ensureFilesAreCleanedUp(shuffledRdd) + } + + test("SortShuffleManager properly cleans up files for shuffles that use the deserialized path") { + sc = new SparkContext("local", "test", conf) + // Create a shuffled RDD and verify that it actually uses the old deserialized map output path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) + .setSerializer(new JavaSerializer(conf)) + val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(!SortShuffleManager.canUseSerializedShuffle(shuffleDep)) + ensureFilesAreCleanedUp(shuffledRdd) + } + + private def ensureFilesAreCleanedUp(shuffledRdd: ShuffledRDD[_, _, _]): Unit = { + def getAllFiles: Set[File] = + FileUtils.listFiles(tempDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet + val filesBeforeShuffle = getAllFiles + // Force the shuffle to be performed + shuffledRdd.count() + // Ensure that the shuffle actually created files that will need to be cleaned up + val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle + assert(filesCreatedByShuffle.map(_.getName) === + Set("shuffle_0_0_0.data", "shuffle_0_0_0.index")) + // Check that the cleanup actually removes the files + sc.env.blockManager.master.removeShuffle(0, blocking = true) + for (file <- filesCreatedByShuffle) { + assert (!file.exists(), s"Shuffle file $file was not cleaned up") + } + } } http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 5b01ddb..3816b8c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1062,10 +1062,10 @@ class DAGSchedulerSuite */ test("don't submit stage until its dependencies map outputs are registered (SPARK-5259)") { val firstRDD = new MyRDD(sc, 3, Nil) - val firstShuffleDep = new ShuffleDependency(firstRDD, null) + val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2)) val firstShuffleId = firstShuffleDep.shuffleId val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep)) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val reduceRdd = new MyRDD(sc, 1, 
List(shuffleDep)) submit(reduceRdd, Array(0)) @@ -1175,7 +1175,7 @@ class DAGSchedulerSuite */ test("register map outputs correctly after ExecutorLost and task Resubmitted") { val firstRDD = new MyRDD(sc, 3, Nil) - val firstShuffleDep = new ShuffleDependency(firstRDD, null) + val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2)) val reduceRdd = new MyRDD(sc, 5, List(firstShuffleDep)) submit(reduceRdd, Array(0))
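
For readers following the new SortShuffleSuite tests above: whether a shuffle takes the serialized ("unsafe") path or the deserialized path is decided by SortShuffleManager.canUseSerializedShuffle, which the Kryo and Java serializer tests exercise. The sketch below only approximates that decision; the parameter names and the partition cap are assumptions for illustration rather than the exact ShuffleDependency API, but it captures why the KryoSerializer-based RDD qualifies while the JavaSerializer-based RDD falls back.

    // Rough sketch of the serialized-shuffle eligibility check exercised by the tests above.
    // Parameter names and the partition cap are illustrative assumptions, not the exact Spark API.
    object SerializedShuffleEligibility {
      // Assumed cap: the partition id must fit in the packed record pointer used by the unsafe path.
      val MaxSerializedModePartitions: Int = 1 << 24

      def canUseSerializedShuffle(
          serializerSupportsRelocation: Boolean, // true for Kryo, false for Java serialization
          needsMapSideCombine: Boolean,
          numPartitions: Int): Boolean = {
        serializerSupportsRelocation &&  // serialized records may be reordered as raw bytes
          !needsMapSideCombine &&        // map-side aggregation would require deserializing keys
          numPartitions <= MaxSerializedModePartitions
      }
    }

Under rules of this shape, the first test (KryoSerializer, 4 partitions, no aggregation) lands on the serialized path, while the second test (JavaSerializer) does not, which is exactly what the two assertions on canUseSerializedShuffle check before verifying file cleanup.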