Repository: spark
Updated Branches:
  refs/heads/master df733cbea -> 3d1535d48


[SPARK-9520] [SQL] Support in-place sort in UnsafeFixedWidthAggregationMap

This pull request adds a sortedIterator method to
UnsafeFixedWidthAggregationMap that sorts the map's data in place by the
grouping key.

This is needed so we can fall back to external sorting for aggregation.
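
For context, a minimal sketch of the intended call pattern (sortedIterator
and KVIterator are from this patch; the surrounding fallback logic and the
spill writer are hypothetical):

    val iter = map.sortedIterator()  // sorts the map's records in place
    while (iter.next()) {            // next() may now throw IOException
      // getKey/getValue return reused UnsafeRow buffers; copy if retained.
      spillWriter.write(iter.getKey, iter.getValue)  // hypothetical writer
    }
    iter.close()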

Author: Reynold Xin <r...@databricks.com>

Closes #7849 from rxin/bytes2bytes-sorting and squashes the following commits:

75018c6 [Reynold Xin] Updated documentation.
81a8694 [Reynold Xin] [SPARK-9520][SQL] Support in-place sort in UnsafeFixedWidthAggregationMap.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3d1535d4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3d1535d4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3d1535d4

Branch: refs/heads/master
Commit: 3d1535d48822281953de1e8447de86fad728412a
Parents: df733cb
Author: Reynold Xin <r...@databricks.com>
Authored: Sat Aug 1 13:20:26 2015 -0700
Committer: Josh Rosen <joshro...@databricks.com>
Committed: Sat Aug 1 13:20:26 2015 -0700

----------------------------------------------------------------------
 .../spark/unsafe/map/BytesToBytesMap.java       |  41 ++++++--
 .../sql/catalyst/expressions/Projection.scala   |   2 +
 .../expressions/codegen/GenerateOrdering.scala  |  12 ++-
 .../UnsafeFixedWidthAggregationMap.java         | 100 ++++++++++++++++++-
 .../spark/sql/execution/SortPrefixUtils.scala   |  18 +++-
 .../UnsafeFixedWidthAggregationMapSuite.scala   |  34 +++++++
 .../org/apache/spark/unsafe/KVIterator.java     |   4 +-
 7 files changed, 196 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3d1535d4/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 481375f..cf222b7 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -23,6 +23,8 @@ import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
 
+import javax.annotation.Nullable;
+
 import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -217,6 +219,7 @@ public final class BytesToBytesMap {
     private final Iterator<MemoryBlock> dataPagesIterator;
     private final Location loc;
 
+    private MemoryBlock currentPage;
     private int currentRecordNumber = 0;
     private Object pageBaseObject;
     private long offsetInPage;
@@ -232,7 +235,7 @@ public final class BytesToBytesMap {
     }
 
     private void advanceToNextPage() {
-      final MemoryBlock currentPage = dataPagesIterator.next();
+      currentPage = dataPagesIterator.next();
       pageBaseObject = currentPage.getBaseObject();
       offsetInPage = currentPage.getBaseOffset();
     }
@@ -249,7 +252,7 @@ public final class BytesToBytesMap {
         advanceToNextPage();
        totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage);
       }
-      loc.with(pageBaseObject, offsetInPage);
+      loc.with(currentPage, offsetInPage);
       offsetInPage += 8 + totalLength;
       currentRecordNumber++;
       return loc;
@@ -346,14 +349,19 @@ public final class BytesToBytesMap {
     private int keyLength;
     private int valueLength;
 
+    /**
+     * Memory page containing the record. Only set if created by {@link BytesToBytesMap#iterator()}.
+     */
+    @Nullable private MemoryBlock memoryPage;
+
     private void updateAddressesAndSizes(long fullKeyAddress) {
       updateAddressesAndSizes(
         taskMemoryManager.getPage(fullKeyAddress),
         taskMemoryManager.getOffsetInPage(fullKeyAddress));
     }
 
-    private void updateAddressesAndSizes(final Object page, final long keyOffsetInPage) {
-      long position = keyOffsetInPage;
+    private void updateAddressesAndSizes(final Object page, final long offsetInPage) {
+      long position = offsetInPage;
       final int totalLength = PlatformDependent.UNSAFE.getInt(page, position);
       position += 4;
       keyLength = PlatformDependent.UNSAFE.getInt(page, position);
@@ -366,7 +374,7 @@ public final class BytesToBytesMap {
       valueMemoryLocation.setObjAndOffset(page, position);
     }
 
-    Location with(int pos, int keyHashcode, boolean isDefined) {
+    private Location with(int pos, int keyHashcode, boolean isDefined) {
       this.pos = pos;
       this.isDefined = isDefined;
       this.keyHashcode = keyHashcode;
@@ -377,13 +385,22 @@ public final class BytesToBytesMap {
       return this;
     }
 
-    Location with(Object page, long keyOffsetInPage) {
+    private Location with(MemoryBlock page, long offsetInPage) {
       this.isDefined = true;
-      updateAddressesAndSizes(page, keyOffsetInPage);
+      this.memoryPage = page;
+      updateAddressesAndSizes(page.getBaseObject(), offsetInPage);
       return this;
     }
 
     /**
+     * Returns the memory page that contains the current record.
+     * This is only valid if this is returned by {@link BytesToBytesMap#iterator()}.
+     */
+    public MemoryBlock getMemoryPage() {
+      return this.memoryPage;
+    }
+
+    /**
      * Returns true if the key is defined at this position, and false otherwise.
      */
     public boolean isDefined() {
@@ -538,7 +555,7 @@ public final class BytesToBytesMap {
       long insertCursor = dataPageInsertOffset;
 
       // Compute all of our offsets up-front:
-      final long totalLengthOffset = insertCursor;
+      final long recordOffset = insertCursor;
       insertCursor += 4;
       final long keyLengthOffset = insertCursor;
       insertCursor += 4;
@@ -547,7 +564,7 @@ public final class BytesToBytesMap {
       final long valueDataOffsetInPage = insertCursor;
       insertCursor += valueLengthBytes; // word used to store the value size
 
-      PlatformDependent.UNSAFE.putInt(dataPageBaseObject, totalLengthOffset,
+      PlatformDependent.UNSAFE.putInt(dataPageBaseObject, recordOffset,
         keyLengthBytes + valueLengthBytes);
       PlatformDependent.UNSAFE.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes);
       // Copy the key
@@ -569,7 +586,7 @@ public final class BytesToBytesMap {
       numElements++;
       bitset.set(pos);
       final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
-        dataPage, totalLengthOffset);
+        dataPage, recordOffset);
       longArray.set(pos * 2, storedKeyAddress);
       longArray.set(pos * 2 + 1, keyHashcode);
       updateAddressesAndSizes(storedKeyAddress);
@@ -618,6 +635,10 @@ public final class BytesToBytesMap {
     assert(dataPages.isEmpty());
   }
 
+  public TaskMemoryManager getTaskMemoryManager() {
+    return taskMemoryManager;
+  }
+
   /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */
   public long getTotalMemoryConsumption() {
     long totalDataPagesSize = 0L;

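The new Location.getMemoryPage accessor is what lets sortedIterator()
re-encode a stable address for each record. A sketch of that pattern
(names from this diff; only Locations produced by map.iterator() carry a
page, and the key data starts 8 bytes past the record's two length words):

    val iter = map.iterator()
    while (iter.hasNext) {
      val loc    = iter.next()
      val page   = loc.getMemoryPage               // null unless from iterator()
      val keyOff = loc.getKeyAddress.getBaseOffset
      // Encode (pageNumber, offset) of the record start for later sorting.
      val addr = map.getTaskMemoryManager.encodePageNumberAndOffset(page, keyOff - 8)
    }
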
http://git-wip-us.apache.org/repos/asf/spark/blob/3d1535d4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index d79325a..000be70 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -125,6 +125,8 @@ object UnsafeProjection {
     GenerateUnsafeProjection.generate(exprs)
   }
 
+  def create(expr: Expression): UnsafeProjection = create(Seq(expr))
+
   /**
   * Returns an UnsafeProjection for given sequence of Expressions, which will be bound to
    * `inputSchema`.

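The new single-expression overload simply wraps the Seq-based factory.
A sketch of how this patch uses it (assuming the usual catalyst expression
imports; the LongType column is illustrative):

    val bound = BoundReference(0, LongType, nullable = true)
    // Equivalent calls: create(Seq(expr)) and the new create(expr)
    val prefixProj = UnsafeProjection.create(SortPrefix(SortOrder(bound, Ascending)))
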
http://git-wip-us.apache.org/repos/asf/spark/blob/3d1535d4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index dbd4616..cc848aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -21,6 +21,7 @@ import org.apache.spark.Logging
 import org.apache.spark.annotation.Private
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.StructType
 
 /**
  * Inherits some default implementation for Java from `Ordering[Row]`
@@ -43,7 +44,16 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
   protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
     in.map(BindReferences.bindReference(_, inputSchema))
 
-  protected def create(ordering: Seq[SortOrder]): Ordering[InternalRow] = {
+  /**
+   * Creates a code gen ordering for sorting this schema, in ascending order.
+   */
+  def create(schema: StructType): BaseOrdering = {
+    create(schema.zipWithIndex.map { case (field, ordinal) =>
+      SortOrder(BoundReference(ordinal, field.dataType, nullable = true), Ascending)
+    })
+  }
+
+  protected def create(ordering: Seq[SortOrder]): BaseOrdering = {
     val ctx = newCodeGenContext()
 
     val comparisons = ordering.map { order =>

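A sketch of the new schema-based factory (the sample schema is
illustrative; the returned BaseOrdering compares InternalRows on every
column, ascending):

    import org.apache.spark.sql.types._

    val schema = StructType(Seq(
      StructField("word", StringType),
      StructField("count", IntegerType)))
    val ordering = GenerateOrdering.create(schema)
    // ordering.compare(rowA, rowB) < 0 iff rowA sorts before rowB
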
http://git-wip-us.apache.org/repos/asf/spark/blob/3d1535d4/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index c18b6de..a0a8dd5 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -17,19 +17,26 @@
 
 package org.apache.spark.sql.execution;
 
+import java.io.IOException;
+
 import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
-import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.unsafe.KVIterator;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.map.BytesToBytesMap;
+import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.memory.MemoryLocation;
 import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
+import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
 
 /**
  * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
@@ -225,4 +232,93 @@ public final class UnsafeFixedWidthAggregationMap {
     System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
   }
 
+  /**
+   * Sorts the key/value data in this map in place, and returns them as an iterator.
+   *
+   * The only memory that is allocated is the address/prefix array, 16 bytes per record.
+   */
+  public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() {
+    int numElements = map.numElements();
+    final int numKeyFields = groupingKeySchema.size();
+    TaskMemoryManager memoryManager = map.getTaskMemoryManager();
+
+    UnsafeExternalRowSorter.PrefixComputer prefixComp =
+      SortPrefixUtils.createPrefixGenerator(groupingKeySchema);
+    PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(groupingKeySchema);
+
+    final BaseOrdering ordering = GenerateOrdering.create(groupingKeySchema);
+    RecordComparator recordComparator = new RecordComparator() {
+      private final UnsafeRow row1 = new UnsafeRow();
+      private final UnsafeRow row2 = new UnsafeRow();
+
+      @Override
+      public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
+        row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1);
+        row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1);
+        return ordering.compare(row1, row2);
+      }
+    };
+
+    // Insert the records into the in-memory sorter.
+    final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
+      memoryManager, recordComparator, prefixComparator, numElements);
+
+    BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator();
+    UnsafeRow row = new UnsafeRow();
+    while (iter.hasNext()) {
+      final BytesToBytesMap.Location loc = iter.next();
+      final Object baseObject = loc.getKeyAddress().getBaseObject();
+      final long baseOffset = loc.getKeyAddress().getBaseOffset();
+
+      // Get encoded memory address
+      MemoryBlock page = loc.getMemoryPage();
+      long address = memoryManager.encodePageNumberAndOffset(page, baseOffset - 8);
+
+      // Compute prefix
+      row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength());
+      final long prefix = prefixComp.computePrefix(row);
+
+      sorter.insertRecord(address, prefix);
+    }
+
+    // Return the sorted result as an iterator.
+    return new KVIterator<UnsafeRow, UnsafeRow>() {
+
+      private UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
+      private final UnsafeRow key = new UnsafeRow();
+      private final UnsafeRow value = new UnsafeRow();
+      private int numValueFields = aggregationBufferSchema.size();
+
+      @Override
+      public boolean next() throws IOException {
+        if (sortedIterator.hasNext()) {
+          sortedIterator.loadNext();
+          Object baseObj = sortedIterator.getBaseObject();
+          long recordOffset = sortedIterator.getBaseOffset();
+          int recordLen = sortedIterator.getRecordLength();
+          int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
+          key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
+          value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, recordLen - keyLen);
+          return true;
+        } else {
+          return false;
+        }
+      }
+
+      @Override
+      public UnsafeRow getKey() {
+        return key;
+      }
+
+      @Override
+      public UnsafeRow getValue() {
+        return value;
+      }
+
+      @Override
+      public void close() {
+        // Do nothing
+      }
+    };
+  }
 }

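The sort itself is two-level: the 8-byte prefix is compared first, and the
codegen'd full-key ordering breaks ties. A conceptual sketch of that
decision (the real loop lives inside UnsafeInMemorySorter; compareRecords
is hypothetical):

    def compareRecords(prefixA: Long, prefixB: Long,
                       rowA: UnsafeRow, rowB: UnsafeRow): Int = {
      val byPrefix = prefixComparator.compare(prefixA, prefixB)  // cheap path
      if (byPrefix != 0) byPrefix
      else ordering.compare(rowA, rowB)                          // full key compare
    }
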
http://git-wip-us.apache.org/repos/asf/spark/blob/3d1535d4/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index a2145b1..17d4166 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -18,7 +18,8 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator}
 
@@ -46,4 +47,19 @@ object SortPrefixUtils {
       case _ => NoOpPrefixComparator
     }
   }
+
+  def getPrefixComparator(schema: StructType): PrefixComparator = {
+    val field = schema.head
+    getPrefixComparator(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending))
+  }
+
+  def createPrefixGenerator(schema: StructType): UnsafeExternalRowSorter.PrefixComputer = {
+    val boundReference = BoundReference(0, schema.head.dataType, nullable = true)
+    val prefixProjection = UnsafeProjection.create(SortPrefix(SortOrder(boundReference, Ascending)))
+    new UnsafeExternalRowSorter.PrefixComputer {
+      override def computePrefix(row: InternalRow): Long = {
+        prefixProjection.apply(row).getLong(0)
+      }
+    }
+  }
 }

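A sketch tying the two new helpers together (both look only at the
schema's first column; the sample schema is illustrative):

    import org.apache.spark.sql.types._

    val schema     = StructType(Seq(StructField("key", StringType)))
    val computer   = SortPrefixUtils.createPrefixGenerator(schema)
    val comparator = SortPrefixUtils.getPrefixComparator(schema)
    // comparator.compare(computer.computePrefix(a), computer.computePrefix(b))
    // agrees with the full ordering whenever the two prefixes differ.
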
http://git-wip-us.apache.org/repos/asf/spark/blob/3d1535d4/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 6a2c51c..098bdd0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -140,4 +140,38 @@ class UnsafeFixedWidthAggregationMapSuite
     map.free()
   }
 
+  test("test sorting") {
+    val map = new UnsafeFixedWidthAggregationMap(
+      emptyAggregationBuffer,
+      aggBufferSchema,
+      groupKeySchema,
+      taskMemoryManager,
+      shuffleMemoryManager,
+      128, // initial capacity
+      PAGE_SIZE_BYTES,
+      false // disable perf metrics
+    )
+
+    val rand = new Random(42)
+    val groupKeys: Set[String] = Seq.fill(512) {
+      Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString
+    }.toSet
+    groupKeys.foreach { keyString =>
+      val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString)))
+      buf.setInt(0, keyString.length)
+      assert(buf != null)
+    }
+
+    val out = new scala.collection.mutable.ArrayBuffer[String]
+    val iter = map.sortedIterator()
+    while (iter.next()) {
+      assert(iter.getKey.getString(0).length === iter.getValue.getInt(0))
+      out += iter.getKey.getString(0)
+    }
+
+    assert(out === groupKeys.toSeq.sorted)
+
+    map.free()
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3d1535d4/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java b/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java
index fb16340..5c9d5d9 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java
@@ -17,9 +17,11 @@
 
 package org.apache.spark.unsafe;
 
+import java.io.IOException;
+
 public abstract class KVIterator<K, V> {
 
-  public abstract boolean next();
+  public abstract boolean next() throws IOException;
 
   public abstract K getKey();
 

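Since next() can now throw, callers should handle IOException and still
close the iterator; a minimal hypothetical helper:

    import org.apache.spark.unsafe.KVIterator

    // Hypothetical: apply f to every pair, closing the iterator on all paths.
    def foreachKV[K, V](iter: KVIterator[K, V])(f: (K, V) => Unit): Unit = {
      try {
        while (iter.next()) {  // may throw IOException when disk-backed
          f(iter.getKey, iter.getValue)
        }
      } finally {
        iter.close()
      }
    }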
