Repository: spark
Updated Branches:
  refs/heads/master 67cc89ff0 -> 3074f575a


[SPARK-15391] [SQL] manage the temporary memory of timsort

## What changes were proposed in this pull request?

Currently, the memory for the temporary buffer used by TimSort is always
allocated on-heap without bookkeeping, which could cause OOM in both on-heap
and off-heap mode.

This PR manages that memory by preallocating it together with the pointer
array, the same as RadixSort. It works for both on-heap and off-heap mode.

This PR also changes the loadFactor of BytesToBytesMap to 0.5 (it was 0.70),
which enables us to use radix sort and also makes sure that we have enough
memory for TimSort.

## How was this patch tested?

Existing tests.

Author: Davies Liu <dav...@databricks.com>

Closes #13318 from davies/fix_timsort.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3074f575
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3074f575
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3074f575

Branch: refs/heads/master
Commit: 3074f575a3c84108fddab3f5f56eb1929a4b2cff
Parents: 67cc89f
Author: Davies Liu <dav...@databricks.com>
Authored: Fri Jun 3 16:45:09 2016 -0700
Committer: Davies Liu <davies....@gmail.com>
Committed: Fri Jun 3 16:45:09 2016 -0700

----------------------------------------------------------------------
 .../apache/spark/unsafe/memory/MemoryBlock.java |  2 +-
 .../shuffle/sort/ShuffleInMemorySorter.java     | 38 ++++++++++++-----
 .../shuffle/sort/ShuffleSortDataFormat.java     | 13 +++---
 .../spark/unsafe/map/BytesToBytesMap.java       |  3 +-
 .../unsafe/sort/UnsafeInMemorySorter.java       | 43 +++++++++++++-------
 .../unsafe/sort/UnsafeSortDataFormat.java       | 13 +++---
 .../shuffle/sort/UnsafeShuffleWriterSuite.java  | 35 +++++++++-------
 .../map/AbstractBytesToBytesMapSuite.java       |  2 +-
 .../util/collection/ExternalSorterSuite.scala   |  4 +-
 .../collection/unsafe/sort/RadixSortSuite.scala |  3 +-
 .../sql/execution/UnsafeKVExternalSorter.java   |  8 +++-
 .../sql/execution/joins/HashedRelation.scala    | 12 +++---
 .../sql/execution/benchmark/SortBenchmark.scala |  3 +-
 13 files changed, 115 insertions(+), 64 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3074f575/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
----------------------------------------------------------------------
diff --git 
a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java 
b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
index e3e7947..1bc924d 100644
--- 
a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
+++ 
b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
@@ -51,6 +51,6 @@ public class MemoryBlock extends MemoryLocation {
    * Creates a memory block pointing to the memory used by the long array.
    */
   public static MemoryBlock fromLongArray(final long[] array) {
-    return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 
8);
+    return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 
8L);
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3074f575/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java 
b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
index 75a0e80..dc36809 100644
--- 
a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
+++ 
b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -22,12 +22,12 @@ import java.util.Comparator;
 import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.collection.Sorter;
 import org.apache.spark.util.collection.unsafe.sort.RadixSort;
 
 final class ShuffleInMemorySorter {
 
-  private final Sorter<PackedRecordPointer, LongArray> sorter;
   private static final class SortComparator implements 
Comparator<PackedRecordPointer> {
     @Override
     public int compare(PackedRecordPointer left, PackedRecordPointer right) {
@@ -44,6 +44,9 @@ final class ShuffleInMemorySorter {
    * An array of record pointers and partition ids that have been encoded by
    * {@link PackedRecordPointer}. The sort operates on this array instead of 
directly manipulating
    * records.
+   *
+   * Only part of the array will be used to store the pointers, the rest part 
is preserved as
+   * temporary buffer for sorting.
    */
   private LongArray array;
 
@@ -54,14 +57,14 @@ final class ShuffleInMemorySorter {
   private final boolean useRadixSort;
 
   /**
-   * Set to 2x for radix sort to reserve extra memory for sorting, otherwise 
1x.
+   * The position in the pointer array where new records can be inserted.
    */
-  private final int memoryAllocationFactor;
+  private int pos = 0;
 
   /**
-   * The position in the pointer array where new records can be inserted.
+   * How many records could be inserted, because part of the array should be 
left for sorting.
    */
-  private int pos = 0;
+  private int usableCapacity = 0;
 
   private int initialSize;
 
@@ -70,9 +73,14 @@ final class ShuffleInMemorySorter {
     assert (initialSize > 0);
     this.initialSize = initialSize;
     this.useRadixSort = useRadixSort;
-    this.memoryAllocationFactor = useRadixSort ? 2 : 1;
     this.array = consumer.allocateArray(initialSize);
-    this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE);
+    this.usableCapacity = getUsableCapacity();
+  }
+
+  private int getUsableCapacity() {
+    // Radix sort requires same amount of used memory as buffer, Tim sort 
requires
+    // half of the used memory as buffer.
+    return (int) (array.size() / (useRadixSort ? 2 : 1.5));
   }
 
   public void free() {
@@ -89,7 +97,8 @@ final class ShuffleInMemorySorter {
   public void reset() {
     if (consumer != null) {
       consumer.freeArray(array);
-      this.array = consumer.allocateArray(initialSize);
+      array = consumer.allocateArray(initialSize);
+      usableCapacity = getUsableCapacity();
     }
     pos = 0;
   }
@@ -101,14 +110,15 @@ final class ShuffleInMemorySorter {
       array.getBaseOffset(),
       newArray.getBaseObject(),
       newArray.getBaseOffset(),
-      array.size() * (8 / memoryAllocationFactor)
+      pos * 8L
     );
     consumer.freeArray(array);
     array = newArray;
+    usableCapacity = getUsableCapacity();
   }
 
   public boolean hasSpaceForAnotherRecord() {
-    return pos < array.size() / memoryAllocationFactor;
+    return pos < usableCapacity;
   }
 
   public long getMemoryUsage() {
@@ -170,6 +180,14 @@ final class ShuffleInMemorySorter {
         PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX,
         PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false);
     } else {
+      MemoryBlock unused = new MemoryBlock(
+        array.getBaseObject(),
+        array.getBaseOffset() + pos * 8L,
+        (array.size() - pos) * 8L);
+      LongArray buffer = new LongArray(unused);
+      Sorter<PackedRecordPointer, LongArray> sorter =
+        new Sorter<>(new ShuffleSortDataFormat(buffer));
+
       sorter.sort(array, 0, pos, SORT_COMPARATOR);
     }
     return new ShuffleSorterIterator(pos, array, offset);

http://git-wip-us.apache.org/repos/asf/spark/blob/3074f575/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java 
b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
index 1e924d2..717bdd7 100644
--- 
a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
+++ 
b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
@@ -19,14 +19,15 @@ package org.apache.spark.shuffle.sort;
 
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.LongArray;
-import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.collection.SortDataFormat;
 
 final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, 
LongArray> {
 
-  public static final ShuffleSortDataFormat INSTANCE = new 
ShuffleSortDataFormat();
+  private final LongArray buffer;
 
-  private ShuffleSortDataFormat() { }
+  ShuffleSortDataFormat(LongArray buffer) {
+    this.buffer = buffer;
+  }
 
   @Override
   public PackedRecordPointer getKey(LongArray data, int pos) {
@@ -70,8 +71,8 @@ final class ShuffleSortDataFormat extends 
SortDataFormat<PackedRecordPointer, Lo
 
   @Override
   public LongArray allocate(int length) {
-    // This buffer is used temporary (usually small), so it's fine to 
allocated from JVM heap.
-    return new LongArray(MemoryBlock.fromLongArray(new long[length]));
+    assert (length <= buffer.size()) :
+      "the buffer is smaller than required: " + buffer.size() + " < " + length;
+    return buffer;
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3074f575/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java 
b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 6c00608..dc04025 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -221,7 +221,8 @@ public final class BytesToBytesMap extends MemoryConsumer {
       SparkEnv.get() != null ? SparkEnv.get().blockManager() :  null,
       SparkEnv.get() != null ? SparkEnv.get().serializerManager() :  null,
       initialCapacity,
-      0.70,
+      // In order to re-use the longArray for sorting, the load factor cannot 
be larger than 0.5.
+      0.5,
       pageSizeBytes,
       enablePerfMetrics);
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/3074f575/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 0cce792..c7b070f 100644
--- 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -25,6 +25,7 @@ import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.collection.Sorter;
 
 /**
@@ -69,8 +70,6 @@ public final class UnsafeInMemorySorter {
   private final MemoryConsumer consumer;
   private final TaskMemoryManager memoryManager;
   @Nullable
-  private final Sorter<RecordPointerAndKeyPrefix, LongArray> sorter;
-  @Nullable
   private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
 
   /**
@@ -80,13 +79,11 @@ public final class UnsafeInMemorySorter {
   private final PrefixComparators.RadixSortSupport radixSortSupport;
 
   /**
-   * Set to 2x for radix sort to reserve extra memory for sorting, otherwise 
1x.
-   */
-  private final int memoryAllocationFactor;
-
-  /**
    * Within this buffer, position {@code 2 * i} holds a pointer pointer to the 
record at
    * index {@code i}, while position {@code 2 * i + 1} in the array holds an 
8-byte key prefix.
+   *
+   * Only part of the array will be used to store the pointers, the rest part 
is preserved as
+   * temporary buffer for sorting.
    */
   private LongArray array;
 
@@ -95,6 +92,11 @@ public final class UnsafeInMemorySorter {
    */
   private int pos = 0;
 
+  /**
+   * How many records could be inserted, because part of the array should be 
left for sorting.
+   */
+  private int usableCapacity = 0;
+
   private long initialSize;
 
   private long totalSortTimeNanos = 0L;
@@ -121,7 +123,6 @@ public final class UnsafeInMemorySorter {
     this.memoryManager = memoryManager;
     this.initialSize = array.size();
     if (recordComparator != null) {
-      this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
       this.sortComparator = new SortComparator(recordComparator, 
prefixComparator, memoryManager);
       if (canUseRadixSort && prefixComparator instanceof 
PrefixComparators.RadixSortSupport) {
         this.radixSortSupport = 
(PrefixComparators.RadixSortSupport)prefixComparator;
@@ -129,12 +130,17 @@ public final class UnsafeInMemorySorter {
         this.radixSortSupport = null;
       }
     } else {
-      this.sorter = null;
       this.sortComparator = null;
       this.radixSortSupport = null;
     }
-    this.memoryAllocationFactor = this.radixSortSupport != null ? 2 : 1;
     this.array = array;
+    this.usableCapacity = getUsableCapacity();
+  }
+
+  private int getUsableCapacity() {
+    // Radix sort requires same amount of used memory as buffer, Tim sort 
requires
+    // half of the used memory as buffer.
+    return (int) (array.size() / (radixSortSupport != null ? 2 : 1.5));
   }
 
   /**
@@ -150,7 +156,8 @@ public final class UnsafeInMemorySorter {
   public void reset() {
     if (consumer != null) {
       consumer.freeArray(array);
-      this.array = consumer.allocateArray(initialSize);
+      array = consumer.allocateArray(initialSize);
+      usableCapacity = getUsableCapacity();
     }
     pos = 0;
   }
@@ -174,7 +181,7 @@ public final class UnsafeInMemorySorter {
   }
 
   public boolean hasSpaceForAnotherRecord() {
-    return pos + 1 < (array.size() / memoryAllocationFactor);
+    return pos + 1 < usableCapacity;
   }
 
   public void expandPointerArray(LongArray newArray) {
@@ -186,9 +193,10 @@ public final class UnsafeInMemorySorter {
       array.getBaseOffset(),
       newArray.getBaseObject(),
       newArray.getBaseOffset(),
-      array.size() * (8 / memoryAllocationFactor));
+      pos * 8L);
     consumer.freeArray(array);
     array = newArray;
+    usableCapacity = getUsableCapacity();
   }
 
   /**
@@ -275,13 +283,20 @@ public final class UnsafeInMemorySorter {
   public SortedIterator getSortedIterator() {
     int offset = 0;
     long start = System.nanoTime();
-    if (sorter != null) {
+    if (sortComparator != null) {
       if (this.radixSortSupport != null) {
         // TODO(ekl) we should handle NULL values before radix sort for 
efficiency, since they
         // force a full-width sort (and we cannot radix-sort nullable long 
fields at all).
         offset = RadixSort.sortKeyPrefixArray(
           array, pos / 2, 0, 7, radixSortSupport.sortDescending(), 
radixSortSupport.sortSigned());
       } else {
+        MemoryBlock unused = new MemoryBlock(
+          array.getBaseObject(),
+          array.getBaseOffset() + pos * 8L,
+          (array.size() - pos) * 8L);
+        LongArray buffer = new LongArray(unused);
+        Sorter<RecordPointerAndKeyPrefix, LongArray> sorter =
+          new Sorter<>(new UnsafeSortDataFormat(buffer));
         sorter.sort(array, 0, pos / 2, sortComparator);
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/3074f575/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
index 7bda769..430bf67 100644
--- 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
+++ 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
@@ -19,7 +19,6 @@ package org.apache.spark.util.collection.unsafe.sort;
 
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.LongArray;
-import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.collection.SortDataFormat;
 
 /**
@@ -32,9 +31,11 @@ import org.apache.spark.util.collection.SortDataFormat;
 public final class UnsafeSortDataFormat
   extends SortDataFormat<RecordPointerAndKeyPrefix, LongArray> {
 
-  public static final UnsafeSortDataFormat INSTANCE = new 
UnsafeSortDataFormat();
+  private final LongArray buffer;
 
-  private UnsafeSortDataFormat() { }
+  public UnsafeSortDataFormat(LongArray buffer) {
+    this.buffer = buffer;
+  }
 
   @Override
   public RecordPointerAndKeyPrefix getKey(LongArray data, int pos) {
@@ -83,9 +84,9 @@ public final class UnsafeSortDataFormat
 
   @Override
   public LongArray allocate(int length) {
-    assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too 
large";
-    // This is used as temporary buffer, it's fine to allocate from JVM heap.
-    return new LongArray(MemoryBlock.fromLongArray(new long[length * 2]));
+    assert (length * 2 <= buffer.size()) :
+      "the buffer is smaller than required: " + buffer.size() + " < " + 
(length * 2);
+    return buffer;
   }
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3074f575/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git 
a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
 
b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index f9dc20d..7dd61f8 100644
--- 
a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ 
b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -21,12 +21,15 @@ import java.io.*;
 import java.nio.ByteBuffer;
 import java.util.*;
 
-import scala.*;
+import scala.Option;
+import scala.Product2;
+import scala.Tuple2;
+import scala.Tuple2$;
 import scala.collection.Iterator;
 import scala.runtime.AbstractFunction1;
 
-import com.google.common.collect.Iterators;
 import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Iterators;
 import com.google.common.io.ByteStreams;
 import org.junit.After;
 import org.junit.Before;
@@ -35,29 +38,33 @@ import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.greaterThan;
-import static org.hamcrest.Matchers.lessThan;
-import static org.junit.Assert.*;
-import static org.mockito.Answers.RETURNS_SMART_NULLS;
-import static org.mockito.Mockito.*;
 
-import org.apache.spark.*;
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.ShuffleDependency;
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.io.CompressionCodec$;
 import org.apache.spark.io.LZ4CompressionCodec;
 import org.apache.spark.io.LZFCompressionCodec;
 import org.apache.spark.io.SnappyCompressionCodec;
-import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.network.util.LimitedInputStream;
-import org.apache.spark.serializer.*;
 import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.serializer.*;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.storage.*;
-import org.apache.spark.memory.TestMemoryManager;
-import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.Utils;
 
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.lessThan;
+import static org.junit.Assert.*;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Mockito.*;
+
 public class UnsafeShuffleWriterSuite {
 
   static final int NUM_PARTITITONS = 4;

http://git-wip-us.apache.org/repos/asf/spark/blob/3074f575/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
----------------------------------------------------------------------
diff --git 
a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
 
b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 84b82f5..fc127f0 100644
--- 
a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ 
b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -589,7 +589,7 @@ public abstract class AbstractBytesToBytesMapSuite {
   @Test
   public void multipleValuesForSameKey() {
     BytesToBytesMap map =
-      new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 
1, 0.75, 1024, false);
+      new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 
1, 0.5, 1024, false);
     try {
       int i;
       for (i = 0; i < 1024; i++) {

http://git-wip-us.apache.org/repos/asf/spark/blob/3074f575/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
 
b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index 699f7fa..6bcc601 100644
--- 
a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ 
b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -106,8 +106,10 @@ class ExternalSorterSuite extends SparkFunSuite with 
LocalSparkContext {
     // that can trigger copyRange() in TimSort.mergeLo() or TimSort.mergeHi()
     val ref = Array.tabulate[Long](size) { i => if (i < size / 2) size / 2 + i 
else i }
     val buf = new LongArray(MemoryBlock.fromLongArray(ref))
+    val tmp = new Array[Long](size/2)
+    val tmpBuf = new LongArray(MemoryBlock.fromLongArray(tmp))
 
-    new Sorter(UnsafeSortDataFormat.INSTANCE).sort(
+    new Sorter(new UnsafeSortDataFormat(tmpBuf)).sort(
       buf, 0, size, new Comparator[RecordPointerAndKeyPrefix] {
         override def compare(
             r1: RecordPointerAndKeyPrefix,

http://git-wip-us.apache.org/repos/asf/spark/blob/3074f575/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
 
b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
index def0752..1d26d4a 100644
--- 
a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
+++ 
b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
@@ -93,7 +93,8 @@ class RadixSortSuite extends SparkFunSuite with Logging {
   }
 
   private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: 
PrefixComparator) {
-    new Sorter(UnsafeSortDataFormat.INSTANCE).sort(
+    val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new 
Array[Long](buf.size().toInt)))
+    new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
       buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] {
         override def compare(
             r1: RecordPointerAndKeyPrefix,

http://git-wip-us.apache.org/repos/asf/spark/blob/3074f575/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
 
b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 38dbfef..bb823cd 100644
--- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ 
b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -73,6 +73,8 @@ public final class UnsafeKVExternalSorter {
     PrefixComparator prefixComparator = 
SortPrefixUtils.getPrefixComparator(keySchema);
     BaseOrdering ordering = GenerateOrdering.create(keySchema);
     KVComparator recordComparator = new KVComparator(ordering, 
keySchema.length());
+    boolean canUseRadixSort = keySchema.length() == 1 &&
+      SortPrefixUtils.canSortFullyWithPrefix(keySchema.apply(0));
 
     TaskMemoryManager taskMemoryManager = taskContext.taskMemoryManager();
 
@@ -86,14 +88,16 @@ public final class UnsafeKVExternalSorter {
         prefixComparator,
         /* initialSize */ 4096,
         pageSizeBytes,
-        keySchema.length() == 1 && 
SortPrefixUtils.canSortFullyWithPrefix(keySchema.apply(0)));
+        canUseRadixSort);
     } else {
+      // The array will be used to do in-place sort, which require half of the 
space to be empty.
+      assert(map.numKeys() <= map.getArray().size() / 2);
       // During spilling, the array in map will not be used, so we can borrow 
that and use it
       // as the underline array for in-memory sorter (it's always large 
enough).
       // Since we will not grow the array, it's fine to pass `null` as 
consumer.
       final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
         null, taskMemoryManager, recordComparator, prefixComparator, 
map.getArray(),
-        false /* TODO(ekl) we can only radix sort if the BytesToBytes load 
factor is <= 0.5 */);
+        canUseRadixSort);
 
       // We cannot use the destructive iterator here because we are reusing 
the existing memory
       // pages in BytesToBytesMap to hold records during sorting.

http://git-wip-us.apache.org/repos/asf/spark/blob/3074f575/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index cd6b97a..412e8c5 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -540,7 +540,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: 
TaskMemoryManager, cap
       Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, 
Platform.LONG_ARRAY_OFFSET,
         cursor - Platform.LONG_ARRAY_OFFSET)
       page = newPage
-      freeMemory(used * 8)
+      freeMemory(used * 8L)
     }
 
     // copy the bytes of UnsafeRow
@@ -599,7 +599,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: 
TaskMemoryManager, cap
       i += 2
     }
     old_array = null  // release the reference to old array
-    freeMemory(n * 8)
+    freeMemory(n * 8L)
   }
 
   /**
@@ -610,7 +610,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: 
TaskMemoryManager, cap
     // Convert to dense mode if it does not require more memory or could fit 
within L1 cache
     if (range < array.length || range < 1024) {
       try {
-        ensureAcquireMemory((range + 1) * 8)
+        ensureAcquireMemory((range + 1) * 8L)
       } catch {
         case e: SparkException =>
           // there is no enough memory to convert
@@ -628,7 +628,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: 
TaskMemoryManager, cap
       val old_length = array.length
       array = denseArray
       isDense = true
-      freeMemory(old_length * 8)
+      freeMemory(old_length * 8L)
     }
   }
 
@@ -637,11 +637,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: 
TaskMemoryManager, cap
    */
   def free(): Unit = {
     if (page != null) {
-      freeMemory(page.length * 8)
+      freeMemory(page.length * 8L)
       page = null
     }
     if (array != null) {
-      freeMemory(array.length * 8)
+      freeMemory(array.length * 8L)
       array = null
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/3074f575/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
index 0e1868d..9964b73 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
@@ -36,7 +36,8 @@ import org.apache.spark.util.random.XORShiftRandom
 class SortBenchmark extends BenchmarkBase {
 
   private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: 
PrefixComparator) {
-    new Sorter(UnsafeSortDataFormat.INSTANCE).sort(
+    val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new 
Array[Long](buf.size().toInt)))
+    new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
       buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] {
         override def compare(
           r1: RecordPointerAndKeyPrefix,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to