[SPARK-4550] In sort-based shuffle, store map outputs in serialized form. Refer to the JIRA for the design doc and some perf results.
I wanted to call out up front some of the potentially more controversial changes: * Map outputs are only stored in serialized form when Kryo is in use (and, in addition, only when no key ordering is requested, Kryo auto-reset is enabled, and spark.shuffle.sort.serializeMapOutputs is not set to false). I'm still unsure whether Java-serialized objects can be relocated. At the very least, Java serialization writes out a stream header, which causes problems with the current approach, so I decided to leave investigating this to future work. * The shuffle now explicitly operates on key-value pairs instead of arbitrary objects. Data is written to shuffle files as alternating keys and values instead of key-value tuples. `BlockObjectWriter.write` now accepts a key argument and a value argument instead of a single arbitrary object. * The map output buffer can hold a maximum of Integer.MAX_VALUE bytes, though this wouldn't be terribly difficult to change. * When spilling occurs, the objects that are still in memory at merge time end up being serialized and deserialized an extra time. Author: Sandy Ryza <sa...@cloudera.com> Closes #4450 from sryza/sandy-spark-4550 and squashes the following commits: 8c70dd9 [Sandy Ryza] Fix serialization 9c16fe6 [Sandy Ryza] Fix a couple tests and move getAutoReset to KryoSerializerInstance 6c54e06 [Sandy Ryza] Fix scalastyle d8462d8 [Sandy Ryza] SPARK-4550 Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0a2b15ce Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0a2b15ce Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0a2b15ce Branch: refs/heads/master Commit: 0a2b15ce43cf6096e1a7ae060b7c8a4010ce3b92 Parents: a9fc505 Author: Sandy Ryza <sa...@cloudera.com> Authored: Thu Apr 30 23:14:14 2015 -0700 Committer: Patrick Wendell <patr...@databricks.com> Committed: Thu Apr 30 23:14:14 2015 -0700 ---------------------------------------------------------------------- .../spark/serializer/KryoSerializer.scala | 10 + .../apache/spark/serializer/Serializer.scala | 31 +++ .../spark/shuffle/hash/HashShuffleWriter.scala | 2 +- 
.../spark/storage/BlockObjectWriter.scala | 37 ++- .../storage/ShuffleBlockFetcherIterator.scala | 6 +- .../spark/util/collection/ChainedBuffer.scala | 144 +++++++++++ .../util/collection/ExternalAppendOnlyMap.scala | 6 +- .../spark/util/collection/ExternalSorter.scala | 144 ++++++----- .../spark/util/collection/PairIterator.scala | 24 ++ .../collection/PartitionedAppendOnlyMap.scala | 44 ++++ .../util/collection/PartitionedPairBuffer.scala | 92 +++++++ .../PartitionedSerializedPairBuffer.scala | 254 +++++++++++++++++++ .../collection/SizeTrackingAppendOnlyMap.scala | 2 +- .../collection/SizeTrackingPairBuffer.scala | 86 ------- .../collection/SizeTrackingPairCollection.scala | 34 --- .../WritablePartitionedPairCollection.scala | 113 +++++++++ .../spark/serializer/KryoSerializerSuite.scala | 15 ++ .../spark/serializer/TestSerializer.scala | 4 +- .../shuffle/hash/HashShuffleManagerSuite.scala | 12 +- .../spark/storage/BlockObjectWriterSuite.scala | 8 +- .../util/collection/ChainedBufferSuite.scala | 143 +++++++++++ .../util/collection/ExternalSorterSuite.scala | 189 ++++++++++---- .../PartitionedSerializedPairBufferSuite.scala | 149 +++++++++++ .../sql/execution/SparkSqlSerializer2.scala | 38 ++- .../apache/spark/tools/StoragePerfTester.scala | 5 +- 25 files changed, 1321 insertions(+), 271 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 754832b..b7bc087 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -200,6 +200,16 @@ private[spark] class KryoSerializerInstance(ks: 
KryoSerializer) extends Serializ override def deserializeStream(s: InputStream): DeserializationStream = { new KryoDeserializationStream(kryo, s) } + + /** + * Returns true if auto-reset is on. The only reason this would be false is if the user-supplied + * registrator explicitly turns auto-reset off. + */ + def getAutoReset(): Boolean = { + val field = classOf[Kryo].getDeclaredField("autoReset") + field.setAccessible(true) + field.get(kryo).asInstanceOf[Boolean] + } } /** http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/serializer/Serializer.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index ca6e971..c381672 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -101,7 +101,12 @@ abstract class SerializerInstance { */ @DeveloperApi abstract class SerializationStream { + /** The most general-purpose method to write an object. */ def writeObject[T: ClassTag](t: T): SerializationStream + /** Writes the object representing the key of a key-value pair. */ + def writeKey[T: ClassTag](key: T): SerializationStream = writeObject(key) + /** Writes the object representing the value of a key-value pair. */ + def writeValue[T: ClassTag](value: T): SerializationStream = writeObject(value) def flush(): Unit def close(): Unit @@ -120,7 +125,12 @@ abstract class SerializationStream { */ @DeveloperApi abstract class DeserializationStream { + /** The most general-purpose method to read an object. */ def readObject[T: ClassTag](): T + /** Reads the object representing the key of a key-value pair. */ + def readKey[T: ClassTag](): T = readObject[T]() + /** Reads the object representing the value of a key-value pair. 
*/ + def readValue[T: ClassTag](): T = readObject[T]() def close(): Unit /** @@ -141,4 +151,25 @@ abstract class DeserializationStream { DeserializationStream.this.close() } } + + /** + * Read the elements of this stream through an iterator over key-value pairs. This can only be + * called once, as reading each element will consume data from the input source. + */ + def asKeyValueIterator: Iterator[(Any, Any)] = new NextIterator[(Any, Any)] { + override protected def getNext() = { + try { + (readKey[Any](), readValue[Any]()) + } catch { + case eof: EOFException => { + finished = true + null + } + } + } + + override protected def close() { + DeserializationStream.this.close() + } + } } http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 755f17d..cd27c9e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -63,7 +63,7 @@ private[spark] class HashShuffleWriter[K, V]( for (elem <- iter) { val bucketId = dep.partitioner.getPartition(elem._1) - shuffle.writers(bucketId).write(elem) + shuffle.writers(bucketId).write(elem._1, elem._2) } } http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 1483379..499dd97 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ 
b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -33,7 +33,7 @@ import org.apache.spark.util.Utils * This interface does not support concurrent writes. Also, once the writer has * been opened, it cannot be reopened again. */ -private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { +private[spark] abstract class BlockObjectWriter(val blockId: BlockId) extends OutputStream { def open(): BlockObjectWriter @@ -54,9 +54,14 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { def revertPartialWritesAndClose() /** - * Writes an object. + * Writes a key-value pair. */ - def write(value: Any) + def write(key: Any, value: Any) + + /** + * Notify the writer that a record worth of bytes has been written with writeBytes. + */ + def recordWritten() /** * Returns the file segment of committed data that this Writer has written. @@ -203,12 +208,32 @@ private[spark] class DiskBlockObjectWriter( } } - override def write(value: Any) { + override def write(key: Any, value: Any) { + if (!initialized) { + open() + } + + objOut.writeKey(key) + objOut.writeValue(value) + numRecordsWritten += 1 + writeMetrics.incShuffleRecordsWritten(1) + + if (numRecordsWritten % 32 == 0) { + updateBytesWritten() + } + } + + override def write(b: Int): Unit = throw new UnsupportedOperationException() + + override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = { if (!initialized) { open() } - objOut.writeObject(value) + bs.write(kvBytes, offs, len) + } + + override def recordWritten(): Unit = { numRecordsWritten += 1 writeMetrics.incShuffleRecordsWritten(1) @@ -238,7 +263,7 @@ private[spark] class DiskBlockObjectWriter( } // For testing - private[spark] def flush() { + private[spark] override def flush() { objOut.flush() bs.flush() } http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala 
---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index f337952..d0faab6 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,14 +17,12 @@ package org.apache.spark.storage -import java.io.{InputStream, IOException} import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} -import scala.util.{Failure, Success, Try} +import scala.util.{Failure, Try} import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.network.BlockTransferService import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.serializer.{SerializerInstance, Serializer} @@ -301,7 +299,7 @@ final class ShuffleBlockFetcherIterator( // the scheduler gets a FetchFailedException. Try(buf.createInputStream()).map { is0 => val is = blockManager.wrapForCompression(blockId, is0) - val iter = serializerInstance.deserializeStream(is).asIterator + val iter = serializerInstance.deserializeStream(is).asKeyValueIterator CompletionIterator[Any, Iterator[Any]](iter, { // Once the iterator is exhausted, release the buffer and set currentResult to null // so we don't release it again in cleanup. 
http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala new file mode 100644 index 0000000..a60bffe --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.io.OutputStream + +import scala.collection.mutable.ArrayBuffer + +/** + * A logical byte buffer that wraps a list of byte arrays. All the byte arrays have equal size. The + * advantage of this over a standard ArrayBuffer is that it can grow without claiming large amounts + * of memory and needing to copy the full contents. The disadvantage is that the contents don't + * occupy a contiguous segment of memory. 
+ */ +private[spark] class ChainedBuffer(chunkSize: Int) { + private val chunkSizeLog2 = (math.log(chunkSize) / math.log(2)).toInt + assert(math.pow(2, chunkSizeLog2).toInt == chunkSize, + s"ChainedBuffer chunk size $chunkSize must be a power of two") + private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]() + private var _size: Int = _ + + /** + * Feed bytes from this buffer into a BlockObjectWriter. + * + * @param pos Offset in the buffer to read from. + * @param os OutputStream to read into. + * @param len Number of bytes to read. + */ + def read(pos: Int, os: OutputStream, len: Int): Unit = { + if (pos + len > _size) { + throw new IndexOutOfBoundsException( + s"Read of $len bytes at position $pos would go past size ${_size} of buffer") + } + var chunkIndex = pos >> chunkSizeLog2 + var posInChunk = pos - (chunkIndex << chunkSizeLog2) + var written = 0 + while (written < len) { + val toRead = math.min(len - written, chunkSize - posInChunk) + os.write(chunks(chunkIndex), posInChunk, toRead) + written += toRead + chunkIndex += 1 + posInChunk = 0 + } + } + + /** + * Read bytes from this buffer into a byte array. + * + * @param pos Offset in the buffer to read from. + * @param bytes Byte array to read into. + * @param offs Offset in the byte array to read to. + * @param len Number of bytes to read. + */ + def read(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = { + if (pos + len > _size) { + throw new IndexOutOfBoundsException( + s"Read of $len bytes at position $pos would go past size of buffer") + } + var chunkIndex = pos >> chunkSizeLog2 + var posInChunk = pos - (chunkIndex << chunkSizeLog2) + var written = 0 + while (written < len) { + val toRead = math.min(len - written, chunkSize - posInChunk) + System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead) + written += toRead + chunkIndex += 1 + posInChunk = 0 + } + } + + /** + * Write bytes from a byte array into this buffer. 
+ * + * @param pos Offset in the buffer to write to. + * @param bytes Byte array to write from. + * @param offs Offset in the byte array to write from. + * @param len Number of bytes to write. + */ + def write(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = { + if (pos > _size) { + throw new IndexOutOfBoundsException( + s"Write at position $pos starts after end of buffer ${_size}") + } + // Grow if needed + val endChunkIndex = (pos + len - 1) >> chunkSizeLog2 + while (endChunkIndex >= chunks.length) { + chunks += new Array[Byte](chunkSize) + } + + var chunkIndex = pos >> chunkSizeLog2 + var posInChunk = pos - (chunkIndex << chunkSizeLog2) + var written = 0 + while (written < len) { + val toWrite = math.min(len - written, chunkSize - posInChunk) + System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite) + written += toWrite + chunkIndex += 1 + posInChunk = 0 + } + + _size = math.max(_size, pos + len) + } + + /** + * Total size of buffer that can be written to without allocating additional memory. + */ + def capacity: Int = chunks.size * chunkSize + + /** + * Size of the logical buffer. + */ + def size: Int = _size +} + +/** + * Output stream that writes to a ChainedBuffer. 
+ */ +private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream { + private var pos = 0 + + override def write(b: Int): Unit = { + throw new UnsupportedOperationException() + } + + override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = { + chainedBuffer.write(pos, bytes, offs, len) + pos += len + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index f912049..b850973 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -174,7 +174,7 @@ class ExternalAppendOnlyMap[K, V, C]( val it = currentMap.destructiveSortedIterator(keyComparator) while (it.hasNext) { val kv = it.next() - writer.write(kv) + writer.write(kv._1, kv._2) objectsWritten += 1 if (objectsWritten == serializerBatchSize) { @@ -435,7 +435,9 @@ class ExternalAppendOnlyMap[K, V, C]( */ private def readNextItem(): (K, C) = { try { - val item = deserializeStream.readObject().asInstanceOf[(K, C)] + val k = deserializeStream.readKey().asInstanceOf[K] + val c = deserializeStream.readValue().asInstanceOf[C] + val item = (k, c) objectsRead += 1 if (objectsRead == serializerBatchSize) { objectsRead = 0 http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 4ed8a74..b7306cd 100644 --- 
a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -26,7 +26,7 @@ import scala.collection.mutable import com.google.common.io.ByteStreams import org.apache.spark._ -import org.apache.spark.serializer.{DeserializationStream, Serializer} +import org.apache.spark.serializer._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.storage.{BlockObjectWriter, BlockId} @@ -66,10 +66,11 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId} * * At a high level, this class works internally as follows: * - * - We repeatedly fill up buffers of in-memory data, using either a SizeTrackingAppendOnlyMap if - * we want to combine by key, or an simple SizeTrackingBuffer if we don't. Inside these buffers, - * we sort elements of type ((Int, K), C) where the Int is the partition ID. This is done to - * avoid calling the partitioner multiple times on the same key (e.g. for RangePartitioner). + * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if + * we want to combine by key, or a PartitionedSerializedPairBuffer or PartitionedPairBuffer if we + * don't. Inside these buffers, we sort elements by partition ID and then possibly also by key. + * To avoid calling the partitioner multiple times with each key, we store the partition ID + * alongside each record. * * - When each buffer reaches our memory limit, we spill it to a file. 
This file is sorted first * by partition ID and possibly second by key or by hash code of the key, if we want to do @@ -96,7 +97,7 @@ private[spark] class ExternalSorter[K, V, C]( partitioner: Option[Partitioner] = None, ordering: Option[Ordering[K]] = None, serializer: Option[Serializer] = None) - extends Logging with Spillable[SizeTrackingPairCollection[(Int, K), C]] { + extends Logging with Spillable[WritablePartitionedPairCollection[K, C]] { private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1) private val shouldPartition = numPartitions > 1 @@ -126,11 +127,22 @@ private[spark] class ExternalSorter[K, V, C]( if (shouldPartition) partitioner.get.getPartition(key) else 0 } + private val metaInitialRecords = 256 + private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB + private val useSerializedPairBuffer = + !ordering.isDefined && conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) && + ser.isInstanceOf[KryoSerializer] && + serInstance.asInstanceOf[KryoSerializerInstance].getAutoReset + // Data structures to store in-memory objects before we spill. Depending on whether we have an // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we // store them in an array buffer. 
- private var map = new SizeTrackingAppendOnlyMap[(Int, K), C] - private var buffer = new SizeTrackingPairBuffer[(Int, K), C] + private var map = new PartitionedAppendOnlyMap[K, C] + private var buffer = if (useSerializedPairBuffer) { + new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance) + } else { + new PartitionedPairBuffer[K, C] + } // Total spilling statistics private var _diskBytesSpilled = 0L @@ -163,33 +175,6 @@ private[spark] class ExternalSorter[K, V, C]( } }) - // A comparator for (Int, K) pairs that orders them by only their partition ID - private val partitionComparator: Comparator[(Int, K)] = new Comparator[(Int, K)] { - override def compare(a: (Int, K), b: (Int, K)): Int = { - a._1 - b._1 - } - } - - // A comparator that orders (Int, K) pairs by partition ID and then possibly by key - private val partitionKeyComparator: Comparator[(Int, K)] = { - if (ordering.isDefined || aggregator.isDefined) { - // Sort by partition ID then key comparator - new Comparator[(Int, K)] { - override def compare(a: (Int, K), b: (Int, K)): Int = { - val partitionDiff = a._1 - b._1 - if (partitionDiff != 0) { - partitionDiff - } else { - keyComparator.compare(a._2, b._2) - } - } - } - } else { - // Just sort it by partition ID - partitionComparator - } - } - // Information about a spilled file. Includes sizes in bytes of "batches" written by the // serializer as we periodically reset its stream, as well as number of elements in each // partition, used to efficiently keep track of partitions when merging. 
@@ -221,16 +206,18 @@ private[spark] class ExternalSorter[K, V, C]( } else if (bypassMergeSort) { // SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies if (records.hasNext) { - spillToPartitionFiles(records.map { kv => - ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C]) - }) + spillToPartitionFiles( + WritablePartitionedIterator.fromIterator(records.map { kv => + ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C]) + }) + ) } } else { // Stick values into our buffer while (records.hasNext) { addElementsRead() val kv = records.next() - buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C]) + buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C]) maybeSpillCollection(usingMap = false) } } @@ -248,11 +235,15 @@ private[spark] class ExternalSorter[K, V, C]( if (usingMap) { if (maybeSpill(map, map.estimateSize())) { - map = new SizeTrackingAppendOnlyMap[(Int, K), C] + map = new PartitionedAppendOnlyMap[K, C] } } else { if (maybeSpill(buffer, buffer.estimateSize())) { - buffer = new SizeTrackingPairBuffer[(Int, K), C] + buffer = if (useSerializedPairBuffer) { + new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance) + } else { + new PartitionedPairBuffer[K, C] + } } } } @@ -260,7 +251,7 @@ private[spark] class ExternalSorter[K, V, C]( /** * Spill the current in-memory collection to disk, adding a new file to spills, and clear it. 
*/ - override protected[this] def spill(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = { + override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = { if (bypassMergeSort) { spillToPartitionFiles(collection) } else { @@ -277,7 +268,7 @@ private[spark] class ExternalSorter[K, V, C]( * * @param collection whichever collection we're using (map or buffer) */ - private def spillToMergeableFile(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = { + private def spillToMergeableFile(collection: WritablePartitionedPairCollection[K, C]): Unit = { assert(!bypassMergeSort) // Because these files may be read during shuffle, their compression must be controlled by @@ -308,14 +299,10 @@ private[spark] class ExternalSorter[K, V, C]( var success = false try { - val it = collection.destructiveSortedIterator(partitionKeyComparator) + val it = collection.destructiveSortedWritablePartitionedIterator(comparator) while (it.hasNext) { - val elem = it.next() - val partitionId = elem._1._1 - val key = elem._1._2 - val value = elem._2 - writer.write(key) - writer.write(value) + val partitionId = it.nextPartition() + it.writeNext(writer) elementsPerPartition(partitionId) += 1 objectsWritten += 1 @@ -357,11 +344,11 @@ private[spark] class ExternalSorter[K, V, C]( * * @param collection whichever collection we're using (map or buffer) */ - private def spillToPartitionFiles(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = { - spillToPartitionFiles(collection.iterator) + private def spillToPartitionFiles(collection: WritablePartitionedPairCollection[K, C]): Unit = { + spillToPartitionFiles(collection.writablePartitionedIterator()) } - private def spillToPartitionFiles(iterator: Iterator[((Int, K), C)]): Unit = { + private def spillToPartitionFiles(iterator: WritablePartitionedIterator): Unit = { assert(bypassMergeSort) // Create our file writers if we haven't done so yet @@ -385,11 +372,8 @@ private[spark] class 
ExternalSorter[K, V, C]( // No need to sort stuff, just write each element out while (iterator.hasNext) { - val elem = iterator.next() - val partitionId = elem._1._1 - val key = elem._1._2 - val value = elem._2 - partitionWriters(partitionId).write((key, value)) + val partitionId = iterator.nextPartition() + iterator.writeNext(partitionWriters(partitionId)) } } @@ -618,8 +602,8 @@ private[spark] class ExternalSorter[K, V, C]( if (finished || deserializeStream == null) { return null } - val k = deserializeStream.readObject().asInstanceOf[K] - val c = deserializeStream.readObject().asInstanceOf[C] + val k = deserializeStream.readKey().asInstanceOf[K] + val c = deserializeStream.readValue().asInstanceOf[C] lastPartitionId = partitionId // Start reading the next batch if we're done with this one indexInBatch += 1 @@ -695,27 +679,27 @@ private[spark] class ExternalSorter[K, V, C]( */ def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { val usingMap = aggregator.isDefined - val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer + val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer if (spills.isEmpty && partitionWriters == null) { // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps // we don't even need to sort by anything other than partition ID if (!ordering.isDefined) { // The user hasn't requested sorted keys, so only sort by partition ID, not key - groupByPartition(collection.destructiveSortedIterator(partitionComparator)) + groupByPartition(collection.partitionedDestructiveSortedIterator(None)) } else { // We do need to sort by both partition ID and key - groupByPartition(collection.destructiveSortedIterator(partitionKeyComparator)) + groupByPartition(collection.partitionedDestructiveSortedIterator(Some(keyComparator))) } } else if (bypassMergeSort) { // Read data from each partition file and merge it together with the data in memory; // 
note that there's no ordering or aggregator in this case -- we just partition objects - val collIter = groupByPartition(collection.destructiveSortedIterator(partitionComparator)) + val collIter = groupByPartition(collection.partitionedDestructiveSortedIterator(None)) collIter.map { case (partitionId, values) => (partitionId, values ++ readPartitionFile(partitionWriters(partitionId))) } } else { // Merge spilled and in-memory data - merge(spills, collection.destructiveSortedIterator(partitionKeyComparator)) + merge(spills, collection.partitionedDestructiveSortedIterator(comparator)) } } @@ -762,15 +746,29 @@ private[spark] class ExternalSorter[K, V, C]( context.taskMetrics.shuffleWriteMetrics.foreach( _.incShuffleWriteTime(System.nanoTime - writeStartTime)) } + } else if (spills.isEmpty && partitionWriters == null) { + // Case where we only have in-memory data + val collection = if (aggregator.isDefined) map else buffer + val it = collection.destructiveSortedWritablePartitionedIterator(comparator) + while (it.hasNext) { + val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, + context.taskMetrics.shuffleWriteMetrics.get) + val partitionId = it.nextPartition() + while (it.hasNext && it.nextPartition() == partitionId) { + it.writeNext(writer) + } + writer.commitAndClose() + val segment = writer.fileSegment() + lengths(partitionId) = segment.length + } } else { - // Either we're not bypassing merge-sort or we have only in-memory data; get an iterator by - // partition and just write everything directly. + // Not bypassing merge-sort; get an iterator by partition and just write everything directly. 
for ((id, elements) <- this.partitionedIterator) { if (elements.hasNext) { val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, context.taskMetrics.shuffleWriteMetrics.get) for (elem <- elements) { - writer.write(elem) + writer.write(elem._1, elem._2) } writer.commitAndClose() val segment = writer.fileSegment() @@ -799,7 +797,7 @@ private[spark] class ExternalSorter[K, V, C]( if (writer.isOpen) { writer.commitAndClose() } - blockManager.diskStore.getValues(writer.blockId, ser).get.asInstanceOf[Iterator[Product2[K, C]]] + new PairIterator[K, C](blockManager.diskStore.getValues(writer.blockId, ser).get) } def stop(): Unit = { @@ -829,6 +827,14 @@ private[spark] class ExternalSorter[K, V, C]( (0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered))) } + private def comparator: Option[Comparator[K]] = { + if (ordering.isDefined || aggregator.isDefined) { + Some(keyComparator) + } else { + None + } + } + /** * An iterator that reads only the elements for a given partition ID from an underlying buffered * stream, assuming this partition is the next one to be read. Used to make it easier to return http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala b/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala new file mode 100644 index 0000000..d75959f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +private[spark] class PairIterator[K, V](iter: Iterator[Any]) extends Iterator[(K, V)] { + def hasNext: Boolean = iter.hasNext + + def next(): (K, V) = (iter.next().asInstanceOf[K], iter.next().asInstanceOf[V]) +} http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala new file mode 100644 index 0000000..e2e2f1f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.util.Comparator + +import org.apache.spark.util.collection.WritablePartitionedPairCollection._ + +/** + * Implementation of WritablePartitionedPairCollection that wraps a map in which the keys are tuples + * of (partition ID, K) + */ +private[spark] class PartitionedAppendOnlyMap[K, V] + extends SizeTrackingAppendOnlyMap[(Int, K), V] with WritablePartitionedPairCollection[K, V] { + + def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]]) + : Iterator[((Int, K), V)] = { + val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator) + destructiveSortedIterator(comparator) + } + + def writablePartitionedIterator(): WritablePartitionedIterator = { + WritablePartitionedIterator.fromIterator(super.iterator) + } + + def insert(partition: Int, key: K, value: V): Unit = { + update((partition, key), value) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala new file mode 100644 index 0000000..e8332e1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.util.Comparator + +import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.util.collection.WritablePartitionedPairCollection._ + +/** + * Append-only buffer of key-value pairs, each with a corresponding partition ID, that keeps track + * of its estimated size in bytes. + */ +private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64) + extends WritablePartitionedPairCollection[K, V] with SizeTracker +{ + require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") + require(initialCapacity >= 1, "Invalid initial capacity") + + // Basic growable array data structure. We use a single array of AnyRef to hold both the keys + // and the values, so that we can sort them efficiently with KVArraySortDataFormat. 
+ private var capacity = initialCapacity + private var curSize = 0 + private var data = new Array[AnyRef](2 * initialCapacity) + + /** Add an element into the buffer */ + def insert(partition: Int, key: K, value: V): Unit = { + if (curSize == capacity) { + growArray() + } + data(2 * curSize) = (partition, key.asInstanceOf[AnyRef]) + data(2 * curSize + 1) = value.asInstanceOf[AnyRef] + curSize += 1 + afterUpdate() + } + + /** Double the size of the array because we've reached capacity */ + private def growArray(): Unit = { + if (capacity == (1 << 29)) { + // Doubling the capacity would create an array bigger than Int.MaxValue, so don't + throw new Exception("Can't grow buffer beyond 2^29 elements") + } + val newCapacity = capacity * 2 + val newArray = new Array[AnyRef](2 * newCapacity) + System.arraycopy(data, 0, newArray, 0, 2 * capacity) + data = newArray + capacity = newCapacity + resetSamples() + } + + /** Iterate through the data in a given order. For this class this is not really destructive. 
*/ + override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]]) + : Iterator[((Int, K), V)] = { + val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator) + new Sorter(new KVArraySortDataFormat[(Int, K), AnyRef]).sort(data, 0, curSize, comparator) + iterator + } + + override def writablePartitionedIterator(): WritablePartitionedIterator = { + WritablePartitionedIterator.fromIterator(iterator) + } + + private def iterator(): Iterator[((Int, K), V)] = new Iterator[((Int, K), V)] { + var pos = 0 + + override def hasNext: Boolean = pos < curSize + + override def next(): ((Int, K), V) = { + if (!hasNext) { + throw new NoSuchElementException + } + val pair = (data(2 * pos).asInstanceOf[(Int, K)], data(2 * pos + 1).asInstanceOf[V]) + pos += 1 + pair + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala new file mode 100644 index 0000000..b5ca0c6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.io.InputStream +import java.nio.IntBuffer +import java.util.Comparator + +import org.apache.spark.SparkEnv +import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance} +import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._ + +/** + * Append-only buffer of key-value pairs, each with a corresponding partition ID, that serializes + * its records upon insert and stores them as raw bytes. + * + * We use two data-structures to store the contents. The serialized records are stored in a + * ChainedBuffer that can expand gracefully as records are added. This buffer is accompanied by a + * metadata buffer that stores pointers into the data buffer as well as the partition ID of each + * record. Each entry in the metadata buffer takes up a fixed amount of space. + * + * Sorting the collection means swapping entries in the metadata buffer - the record buffer need not + * be modified at all. Storing the partition IDs in the metadata buffer means that comparisons can + * happen without following any pointers, which should minimize cache misses. + * + * Currently, only sorting by partition is supported. + * + * @param metaInitialRecords The initial number of entries in the metadata buffer. + * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records. + * @param serializerInstance the serializer used for serializing inserted records. 
+ */ +private[spark] class PartitionedSerializedPairBuffer[K, V]( + metaInitialRecords: Int, + kvBlockSize: Int, + serializerInstance: SerializerInstance) + extends WritablePartitionedPairCollection[K, V] with SizeTracker { + + if (serializerInstance.isInstanceOf[JavaSerializerInstance]) { + throw new IllegalArgumentException("PartitionedSerializedPairBuffer does not support" + + " Java-serialized objects.") + } + + private var metaBuffer = IntBuffer.allocate(metaInitialRecords * RECORD_SIZE) + + private val kvBuffer: ChainedBuffer = new ChainedBuffer(kvBlockSize) + private val kvOutputStream = new ChainedBufferOutputStream(kvBuffer) + private val kvSerializationStream = serializerInstance.serializeStream(kvOutputStream) + + def insert(partition: Int, key: K, value: V): Unit = { + if (metaBuffer.position == metaBuffer.capacity) { + growMetaBuffer() + } + + val keyStart = kvBuffer.size + if (keyStart < 0) { + throw new Exception(s"Can't grow buffer beyond ${1 << 31} bytes") + } + kvSerializationStream.writeObject[Any](key) + kvSerializationStream.flush() + val valueStart = kvBuffer.size + kvSerializationStream.writeObject[Any](value) + kvSerializationStream.flush() + val valueEnd = kvBuffer.size + + metaBuffer.put(keyStart) + metaBuffer.put(valueStart) + metaBuffer.put(valueEnd) + metaBuffer.put(partition) + } + + /** Double the size of the array because we've reached capacity */ + private def growMetaBuffer(): Unit = { + if (metaBuffer.capacity.toLong * 2 > Int.MaxValue) { + // Doubling the capacity would create an array bigger than Int.MaxValue, so don't + throw new Exception(s"Can't grow buffer beyond ${Int.MaxValue} bytes") + } + val newMetaBuffer = IntBuffer.allocate(metaBuffer.capacity * 2) + newMetaBuffer.put(metaBuffer.array) + metaBuffer = newMetaBuffer + } + + /** Iterate through the data in a given order. For this class this is not really destructive. 
*/ + override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]]) + : Iterator[((Int, K), V)] = { + sort(keyComparator) + val is = orderedInputStream + val deserStream = serializerInstance.deserializeStream(is) + new Iterator[((Int, K), V)] { + var metaBufferPos = 0 + def hasNext: Boolean = metaBufferPos < metaBuffer.position + def next(): ((Int, K), V) = { + val key = deserStream.readKey[Any]().asInstanceOf[K] + val value = deserStream.readValue[Any]().asInstanceOf[V] + val partition = metaBuffer.get(metaBufferPos + PARTITION) + metaBufferPos += RECORD_SIZE + ((partition, key), value) + } + } + } + + override def estimateSize: Long = metaBuffer.capacity * 4 + kvBuffer.capacity + + override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]]) + : WritablePartitionedIterator = { + sort(keyComparator) + writablePartitionedIterator + } + + override def writablePartitionedIterator(): WritablePartitionedIterator = { + new WritablePartitionedIterator { + // current position in the meta buffer in ints + var pos = 0 + + def writeNext(writer: BlockObjectWriter): Unit = { + val keyStart = metaBuffer.get(pos + KEY_START) + val valueEnd = metaBuffer.get(pos + VAL_END) + pos += RECORD_SIZE + kvBuffer.read(keyStart, writer, valueEnd - keyStart) + writer.recordWritten() + } + def nextPartition(): Int = metaBuffer.get(pos + PARTITION) + def hasNext(): Boolean = pos < metaBuffer.position + } + } + + // Visible for testing + def orderedInputStream: OrderedInputStream = { + new OrderedInputStream(metaBuffer, kvBuffer) + } + + private def sort(keyComparator: Option[Comparator[K]]): Unit = { + val comparator = if (keyComparator.isEmpty) { + new Comparator[Int]() { + def compare(partition1: Int, partition2: Int): Int = { + partition1 - partition2 + } + } + } else { + throw new UnsupportedOperationException() + } + + val sorter = new Sorter(new SerializedSortDataFormat) + sorter.sort(metaBuffer, 0, metaBuffer.position / 
RECORD_SIZE, comparator) + } +} + +private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer) + extends InputStream { + + private var metaBufferPos = 0 + private var kvBufferPos = + if (metaBuffer.position > 0) metaBuffer.get(metaBufferPos + KEY_START) else 0 + + override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length) + + override def read(bytes: Array[Byte], offs: Int, len: Int): Int = { + if (metaBufferPos >= metaBuffer.position) { + return -1 + } + val bytesRemainingInRecord = metaBuffer.get(metaBufferPos + VAL_END) - kvBufferPos + val toRead = math.min(bytesRemainingInRecord, len) + kvBuffer.read(kvBufferPos, bytes, offs, toRead) + if (toRead == bytesRemainingInRecord) { + metaBufferPos += RECORD_SIZE + if (metaBufferPos < metaBuffer.position) { + kvBufferPos = metaBuffer.get(metaBufferPos + KEY_START) + } + } else { + kvBufferPos += toRead + } + toRead + } + + override def read(): Int = { + throw new UnsupportedOperationException() + } +} + +private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuffer] { + + private val META_BUFFER_TMP = new Array[Int](RECORD_SIZE) + + /** Return the sort key for the element at the given index. */ + override protected def getKey(metaBuffer: IntBuffer, pos: Int): Int = { + metaBuffer.get(pos * RECORD_SIZE + PARTITION) + } + + /** Swap two elements. */ + override def swap(metaBuffer: IntBuffer, pos0: Int, pos1: Int): Unit = { + val iOff = pos0 * RECORD_SIZE + val jOff = pos1 * RECORD_SIZE + System.arraycopy(metaBuffer.array, iOff, META_BUFFER_TMP, 0, RECORD_SIZE) + System.arraycopy(metaBuffer.array, jOff, metaBuffer.array, iOff, RECORD_SIZE) + System.arraycopy(META_BUFFER_TMP, 0, metaBuffer.array, jOff, RECORD_SIZE) + } + + /** Copy a single element from src(srcPos) to dst(dstPos). 
*/ + override def copyElement( + src: IntBuffer, + srcPos: Int, + dst: IntBuffer, + dstPos: Int): Unit = { + val srcOff = srcPos * RECORD_SIZE + val dstOff = dstPos * RECORD_SIZE + System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE) + } + + /** + * Copy a range of elements starting at src(srcPos) to dst, starting at dstPos. + * Overlapping ranges are allowed. + */ + override def copyRange( + src: IntBuffer, + srcPos: Int, + dst: IntBuffer, + dstPos: Int, + length: Int): Unit = { + val srcOff = srcPos * RECORD_SIZE + val dstOff = dstPos * RECORD_SIZE + System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE * length) + } + + /** + * Allocates a Buffer that can hold up to 'length' elements. + * All elements of the buffer should be considered invalid until data is explicitly copied in. + */ + override def allocate(length: Int): IntBuffer = { + IntBuffer.allocate(length * RECORD_SIZE) + } +} + +private[spark] object PartitionedSerializedPairBuffer { + val KEY_START = 0 + val VAL_START = 1 + val VAL_END = 2 + val PARTITION = 3 + val RECORD_SIZE = Seq(KEY_START, VAL_START, VAL_END, PARTITION).size // num ints of metadata +} http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala index eb4de41..722f78b 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala @@ -21,7 +21,7 @@ package org.apache.spark.util.collection * An append-only map that keeps track of its estimated size in bytes. 
*/ private[spark] class SizeTrackingAppendOnlyMap[K, V] - extends AppendOnlyMap[K, V] with SizeTracker with SizeTrackingPairCollection[K, V] + extends AppendOnlyMap[K, V] with SizeTracker { override def update(key: K, value: V): Unit = { super.update(key, value) http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala deleted file mode 100644 index 9e9c16c..0000000 --- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.collection - -import java.util.Comparator - -/** - * Append-only buffer of key-value pairs that keeps track of its estimated size in bytes. 
- */ -private[spark] class SizeTrackingPairBuffer[K, V](initialCapacity: Int = 64) - extends SizeTracker with SizeTrackingPairCollection[K, V] -{ - require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") - require(initialCapacity >= 1, "Invalid initial capacity") - - // Basic growable array data structure. We use a single array of AnyRef to hold both the keys - // and the values, so that we can sort them efficiently with KVArraySortDataFormat. - private var capacity = initialCapacity - private var curSize = 0 - private var data = new Array[AnyRef](2 * initialCapacity) - - /** Add an element into the buffer */ - def insert(key: K, value: V): Unit = { - if (curSize == capacity) { - growArray() - } - data(2 * curSize) = key.asInstanceOf[AnyRef] - data(2 * curSize + 1) = value.asInstanceOf[AnyRef] - curSize += 1 - afterUpdate() - } - - /** Total number of elements in buffer */ - override def size: Int = curSize - - /** Iterate over the elements of the buffer */ - override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { - var pos = 0 - - override def hasNext: Boolean = pos < curSize - - override def next(): (K, V) = { - if (!hasNext) { - throw new NoSuchElementException - } - val pair = (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V]) - pos += 1 - pair - } - } - - /** Double the size of the array because we've reached capacity */ - private def growArray(): Unit = { - if (capacity == (1 << 29)) { - // Doubling the capacity would create an array bigger than Int.MaxValue, so don't - throw new Exception("Can't grow buffer beyond 2^29 elements") - } - val newCapacity = capacity * 2 - val newArray = new Array[AnyRef](2 * newCapacity) - System.arraycopy(data, 0, newArray, 0, 2 * capacity) - data = newArray - capacity = newCapacity - resetSamples() - } - - /** Iterate through the data in a given order. For this class this is not really destructive. 
*/ - override def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = { - new Sorter(new KVArraySortDataFormat[K, AnyRef]).sort(data, 0, curSize, keyComparator) - iterator - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala deleted file mode 100644 index faa4e2b..0000000 --- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.collection - -import java.util.Comparator - -/** - * A common interface for our size-tracking collections of key-value pairs, which are used in - * external operations. These all support estimating the size and obtaining a memory-efficient - * sorted iterator. 
- */ -// TODO: should extend Iterable[Product2[K, V]] instead of (K, V) -private[spark] trait SizeTrackingPairCollection[K, V] extends Iterable[(K, V)] { - /** Estimate the collection's current memory usage in bytes. */ - def estimateSize(): Long - - /** Iterate through the data in a given key order. This may destroy the underlying collection. */ - def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] -} http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala new file mode 100644 index 0000000..f26d161 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.collection + +import java.util.Comparator + +import org.apache.spark.storage.BlockObjectWriter + +/** + * A common interface for size-tracking collections of key-value pairs that + * - Have an associated partition for each key-value pair. + * - Support a memory-efficient sorted iterator + * - Support a WritablePartitionedIterator for writing the contents directly as bytes. + */ +private[spark] trait WritablePartitionedPairCollection[K, V] { + /** + * Insert a key-value pair with a partition into the collection + */ + def insert(partition: Int, key: K, value: V): Unit + + /** + * Iterate through the data in order of partition ID and then the given comparator. This may + * destroy the underlying collection. + */ + def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]]) + : Iterator[((Int, K), V)] + + /** + * Iterate through the data and write out the elements instead of returning them. Records are + * returned in order of their partition ID and then the given comparator. + * This may destroy the underlying collection. + */ + def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]]) + : WritablePartitionedIterator = { + WritablePartitionedIterator.fromIterator(partitionedDestructiveSortedIterator(keyComparator)) + } + + /** + * Iterate through the data and write out the elements instead of returning them. + */ + def writablePartitionedIterator(): WritablePartitionedIterator +} + +private[spark] object WritablePartitionedPairCollection { + /** + * A comparator for (Int, K) pairs that orders them by only their partition ID. + */ + def partitionComparator[K]: Comparator[(Int, K)] = new Comparator[(Int, K)] { + override def compare(a: (Int, K), b: (Int, K)): Int = { + a._1 - b._1 + } + } + + /** + * A comparator for (Int, K) pairs that orders them both by their partition ID and a key ordering. 
+ */ + def partitionKeyComparator[K](keyComparator: Comparator[K]): Comparator[(Int, K)] = { + new Comparator[(Int, K)] { + override def compare(a: (Int, K), b: (Int, K)): Int = { + val partitionDiff = a._1 - b._1 + if (partitionDiff != 0) { + partitionDiff + } else { + keyComparator.compare(a._2, b._2) + } + } + } + } +} + +/** + * Iterator that writes elements to a BlockObjectWriter instead of returning them. Each element + * has an associated partition. + */ +private[spark] trait WritablePartitionedIterator { + def writeNext(writer: BlockObjectWriter): Unit + + def hasNext(): Boolean + + def nextPartition(): Int +} + +private[spark] object WritablePartitionedIterator { + def fromIterator(it: Iterator[((Int, _), _)]): WritablePartitionedIterator = { + new WritablePartitionedIterator { + var cur = if (it.hasNext) it.next() else null + + def writeNext(writer: BlockObjectWriter): Unit = { + writer.write(cur._1._2, cur._2) + cur = if (it.hasNext) it.next() else null + } + + def hasNext(): Boolean = cur != null + + def nextPartition(): Int = cur._1._1 + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 1b13559..778a7ee 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -280,6 +280,15 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { val thrown = intercept[SparkException](ser.serialize(largeObject)) assert(thrown.getMessage.contains(kryoBufferMaxProperty)) } + + test("getAutoReset") { + val ser = new KryoSerializer(new SparkConf).newInstance().asInstanceOf[KryoSerializerInstance] + 
assert(ser.getAutoReset) + val conf = new SparkConf().set("spark.kryo.registrator", + classOf[RegistratorWithoutAutoReset].getName) + val ser2 = new KryoSerializer(conf).newInstance().asInstanceOf[KryoSerializerInstance] + assert(!ser2.getAutoReset) + } } @@ -313,4 +322,10 @@ object KryoTest { k.register(classOf[java.util.HashMap[_, _]]) } } + + class RegistratorWithoutAutoReset extends KryoRegistrator { + override def registerClasses(k: Kryo) { + k.setAutoReset(false) + } + } } http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala index 963264c..86fcf44 100644 --- a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala +++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag /** - * A serializer implementation that always return a single element in a deserialization stream. + * A serializer implementation that always returns two elements in a deserialization stream. 
*/ class TestSerializer extends Serializer { override def newInstance(): TestSerializerInstance = new TestSerializerInstance @@ -51,7 +51,7 @@ class TestDeserializationStream extends DeserializationStream { override def readObject[T: ClassTag](): T = { count += 1 - if (count == 2) { + if (count == 3) { throw new EOFException } new Object().asInstanceOf[T] http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala index 7d76435..84384bb 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -59,8 +59,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { val shuffle1 = shuffleBlockManager.forMapTask(1, 1, 1, new JavaSerializer(conf), new ShuffleWriteMetrics) for (writer <- shuffle1.writers) { - writer.write("test1") - writer.write("test2") + writer.write("test1", "value") + writer.write("test2", "value") } for (writer <- shuffle1.writers) { writer.commitAndClose() @@ -73,8 +73,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { new ShuffleWriteMetrics) for (writer <- shuffle2.writers) { - writer.write("test3") - writer.write("test4") + writer.write("test3", "value") + writer.write("test4", "vlue") } for (writer <- shuffle2.writers) { writer.commitAndClose() @@ -91,8 +91,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { val shuffle3 = shuffleBlockManager.forMapTask(1, 3, 1, new JavaSerializer(testConf), new ShuffleWriteMetrics) for (writer <- shuffle3.writers) { - writer.write("test3") - writer.write("test4") + writer.write("test3", "value") + 
writer.write("test4", "value") } for (writer <- shuffle3.writers) { writer.commitAndClose() http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala index 003a728..43ef469 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala @@ -32,7 +32,7 @@ class BlockObjectWriterSuite extends FunSuite { val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) - writer.write(Long.box(20)) + writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write assert(writeMetrics.shuffleRecordsWritten === 1) // Metrics don't update on every write @@ -40,7 +40,7 @@ class BlockObjectWriterSuite extends FunSuite { // After 32 writes, metrics should update for (i <- 0 until 32) { writer.flush() - writer.write(Long.box(i)) + writer.write(Long.box(i), Long.box(i)) } assert(writeMetrics.shuffleBytesWritten > 0) assert(writeMetrics.shuffleRecordsWritten === 33) @@ -54,7 +54,7 @@ class BlockObjectWriterSuite extends FunSuite { val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) - writer.write(Long.box(20)) + writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write assert(writeMetrics.shuffleRecordsWritten === 1) // Metrics don't update on every write @@ -62,7 +62,7 @@ class BlockObjectWriterSuite extends FunSuite { // After 32 writes, metrics should update for (i <- 0 until 32) { writer.flush() - writer.write(Long.box(i)) + 
writer.write(Long.box(i), Long.box(i)) } assert(writeMetrics.shuffleBytesWritten > 0) assert(writeMetrics.shuffleRecordsWritten === 33) http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala new file mode 100644 index 0000000..c0c38cd --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.util.collection
+
+import java.nio.ByteBuffer
+
+import org.scalatest.FunSuite
+import org.scalatest.Matchers._
+
+class ChainedBufferSuite extends FunSuite {
+  test("write and read at start") {
+    // write from start of source array
+    val buffer = new ChainedBuffer(8)
+    buffer.capacity should be (0)
+    verifyWriteAndRead(buffer, 0, 0, 0, 4)
+    buffer.capacity should be (8)
+
+    // write from middle of source array
+    verifyWriteAndRead(buffer, 0, 5, 0, 4)
+    buffer.capacity should be (8)
+
+    // read to middle of target array
+    verifyWriteAndRead(buffer, 0, 0, 5, 4)
+    buffer.capacity should be (8)
+
+    // write up to border
+    verifyWriteAndRead(buffer, 0, 0, 0, 8)
+    buffer.capacity should be (8)
+
+    // expand into second buffer
+    verifyWriteAndRead(buffer, 0, 0, 0, 12)
+    buffer.capacity should be (16)
+
+    // expand into multiple buffers
+    verifyWriteAndRead(buffer, 0, 0, 0, 28)
+    buffer.capacity should be (32)
+  }
+
+  test("write and read at middle") {
+    val buffer = new ChainedBuffer(8)
+
+    // fill to a middle point
+    verifyWriteAndRead(buffer, 0, 0, 0, 3)
+
+    // write from start of source array
+    verifyWriteAndRead(buffer, 3, 0, 0, 4)
+    buffer.capacity should be (8)
+
+    // write from middle of source array
+    verifyWriteAndRead(buffer, 3, 5, 0, 4)
+    buffer.capacity should be (8)
+
+    // read to middle of target array
+    verifyWriteAndRead(buffer, 3, 0, 5, 4)
+    buffer.capacity should be (8)
+
+    // write up to border
+    verifyWriteAndRead(buffer, 3, 0, 0, 5)
+    buffer.capacity should be (8)
+
+    // expand into second buffer
+    verifyWriteAndRead(buffer, 3, 0, 0, 12)
+    buffer.capacity should be (16)
+
+    // expand into multiple buffers
+    verifyWriteAndRead(buffer, 3, 0, 0, 28)
+    buffer.capacity should be (32)
+  }
+
+  test("write and read at later buffer") {
+    val buffer = new ChainedBuffer(8)
+
+    // fill to a point inside the second buffer
+    verifyWriteAndRead(buffer, 0, 0, 0, 11)
+
+    // write from start of source array
+    verifyWriteAndRead(buffer, 11, 0, 0, 4)
+    buffer.capacity should be (16)
+
+    // write from middle of source array
+    verifyWriteAndRead(buffer, 11, 5, 0, 4)
+    buffer.capacity should be (16)
+
+    // read to middle of target array
+    verifyWriteAndRead(buffer, 11, 0, 5, 4)
+    buffer.capacity should be (16)
+
+    // write up to border
+    verifyWriteAndRead(buffer, 11, 0, 0, 5)
+    buffer.capacity should be (16)
+
+    // expand into third buffer
+    verifyWriteAndRead(buffer, 11, 0, 0, 12)
+    buffer.capacity should be (24)
+
+    // expand into multiple buffers
+    verifyWriteAndRead(buffer, 11, 0, 0, 28)
+    buffer.capacity should be (40)
+  }
+
+  // Used to make sure we're writing different bytes each time: each verifyWriteAndRead call
+  // fills its source with consecutive values starting here (as Bytes), then adds 100.
+  var rangeStart = 0
+
+  /** Writes `length` bytes into `buffer`, reads them back, and asserts they are unchanged.
+   * @param buffer The buffer to write to and read from.
+   * @param offsetInBuffer The offset to write to in the buffer.
+   * @param offsetInSource The offset in the array that the bytes are written from.
+   * @param offsetInTarget The offset in the array to read the bytes into.
+   * @param length The number of bytes to read and write.
+   */
+  def verifyWriteAndRead(
+      buffer: ChainedBuffer,
+      offsetInBuffer: Int,
+      offsetInSource: Int,
+      offsetInTarget: Int,
+      length: Int): Unit = {
+    val source = new Array[Byte](offsetInSource + length)
+    (rangeStart until rangeStart + length).map(_.toByte).copyToArray(source, offsetInSource)
+    buffer.write(offsetInBuffer, source, offsetInSource, length)
+    val target = new Array[Byte](offsetInTarget + length)
+    buffer.read(offsetInBuffer, target, offsetInTarget, length)
+    // "(" must stay on this line; a newline before it ends the statement and skips the check.
+    ByteBuffer.wrap(source, offsetInSource, length) should be (
+      ByteBuffer.wrap(target, offsetInTarget, length))
+    rangeStart += 100
+  }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org