Repository: spark
Updated Branches:
  refs/heads/master a29081487 -> b55499a44


[SPARK-8932] Support copy() for UnsafeRows that do not use ObjectPools

We call Row.copy() in many places throughout SQL, but UnsafeRow currently throws 
an UnsupportedOperationException when copy() is called.

Supporting copying when an ObjectPool is used may be difficult, since we may need 
to handle deep-copying of objects in the pool. In addition, this copy() method 
needs to produce a self-contained row object that may be passed around / 
buffered by downstream code that does not understand the UnsafeRow format.
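
As a sketch of that downstream-buffering concern (illustrative only; the
`rows` iterator and `bufferRows` helper are hypothetical and not part of this
patch):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow

    // Hypothetical operator that buffers incoming rows. UnsafeRows handed out
    // by an iterator typically all point into the same reused backing page, so
    // without copy() every buffered entry would alias (and later reflect
    // mutations of) the same bytes.
    def bufferRows(rows: Iterator[UnsafeRow]): Array[InternalRow] =
      rows.map(_.copy()).toArray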

In the long run, we'll need to figure out how to handle the ObjectPool corner 
cases, but that may become unnecessary if other changes are made. Therefore, in 
order to unblock my sort patch (#6444), I propose that we support copy() for the 
cases where UnsafeRow does not use an ObjectPool and continue to throw 
UnsupportedOperationException when an ObjectPool is used.

This patch implements that by modifying UnsafeRow so that it knows the size 
of the row's backing data and can therefore copy that data into a byte array.
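
A minimal sketch of the resulting behavior (mirroring the test changes below;
`buffer`, `fieldTypes`, and `sizeRequired` are assumed to come from a prior
getSizeRequirement/writeRow pair with a null ObjectPool, the supported case):

    val unsafeRow = new UnsafeRow()
    unsafeRow.pointTo(
      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null)
    val unsafeRowCopy = unsafeRow.copy()   // self-contained: owns its own byte array
    unsafeRow.setLong(1, 3)                // mutating the original...
    assert(unsafeRowCopy.getLong(1) == 1)  // ...does not affect the copy
                                           // (field 1 held 1L, as in the suite)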

Author: Josh Rosen <joshro...@databricks.com>

Closes #7306 from JoshRosen/SPARK-8932 and squashes the following commits:

338e6bf [Josh Rosen] Support copy for UnsafeRows that do not use ObjectPools.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b55499a4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b55499a4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b55499a4

Branch: refs/heads/master
Commit: b55499a44ab74e33378211fb0d6940905d7c6318
Parents: a290814
Author: Josh Rosen <joshro...@databricks.com>
Authored: Wed Jul 8 20:28:05 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Wed Jul 8 20:28:05 2015 -0700

----------------------------------------------------------------------
 .../UnsafeFixedWidthAggregationMap.java         | 12 +++--
 .../sql/catalyst/expressions/UnsafeRow.java     | 32 +++++++++++-
 .../expressions/UnsafeRowConverter.scala        | 10 +++-
 .../expressions/UnsafeRowConverterSuite.scala   | 52 +++++++++++++++-----
 4 files changed, 87 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b55499a4/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index 1e79f4b..79d55b3 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -120,9 +120,11 @@ public final class UnsafeFixedWidthAggregationMap {
     this.bufferPool = new ObjectPool(initialCapacity);
 
     InternalRow initRow = initProjection.apply(emptyRow);
-    this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)];
+    int emptyBufferSize = bufferConverter.getSizeRequirement(initRow);
+    this.emptyBuffer = new byte[emptyBufferSize];
     int writtenLength = bufferConverter.writeRow(
-      initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool);
+      initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, emptyBufferSize,
+      bufferPool);
     assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!";
     // re-use the empty buffer only when there is no object saved in pool.
     reuseEmptyBuffer = bufferPool.size() == 0;
@@ -142,6 +144,7 @@ public final class UnsafeFixedWidthAggregationMap {
       groupingKey,
       groupingKeyConversionScratchSpace,
       PlatformDependent.BYTE_ARRAY_OFFSET,
+      groupingKeySize,
       keyPool);
     assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!";
 
@@ -157,7 +160,7 @@ public final class UnsafeFixedWidthAggregationMap {
         // There is some objects referenced by emptyBuffer, so generate a new one
         InternalRow initRow = initProjection.apply(emptyRow);
         bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET,
-          bufferPool);
+          groupingKeySize, bufferPool);
       }
       loc.putNewKey(
         groupingKeyConversionScratchSpace,
@@ -175,6 +178,7 @@ public final class UnsafeFixedWidthAggregationMap {
       address.getBaseObject(),
       address.getBaseOffset(),
       bufferConverter.numFields(),
+      loc.getValueLength(),
       bufferPool
     );
     return currentBuffer;
@@ -214,12 +218,14 @@ public final class UnsafeFixedWidthAggregationMap {
           keyAddress.getBaseObject(),
           keyAddress.getBaseOffset(),
           keyConverter.numFields(),
+          loc.getKeyLength(),
           keyPool
         );
         entry.value.pointTo(
           valueAddress.getBaseObject(),
           valueAddress.getBaseOffset(),
           bufferConverter.numFields(),
+          loc.getValueLength(),
           bufferPool
         );
         return entry;

http://git-wip-us.apache.org/repos/asf/spark/blob/b55499a4/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index aeb64b0..edb7202 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -68,6 +68,9 @@ public final class UnsafeRow extends MutableRow {
   /** The number of fields in this row, used for calculating the bitset width (and in assertions) */
   private int numFields;
 
+  /** The size of this row's backing data, in bytes */
+  private int sizeInBytes;
+
   public int length() { return numFields; }
 
   /** The width of the null tracking bit set, in bytes */
@@ -95,14 +98,17 @@ public final class UnsafeRow extends MutableRow {
    * @param baseObject the base object
    * @param baseOffset the offset within the base object
    * @param numFields the number of fields in this row
+   * @param sizeInBytes the size of this row's backing data, in bytes
    * @param pool the object pool to hold arbitrary objects
    */
-  public void pointTo(Object baseObject, long baseOffset, int numFields, ObjectPool pool) {
+  public void pointTo(
+      Object baseObject, long baseOffset, int numFields, int sizeInBytes, ObjectPool pool) {
     assert numFields >= 0 : "numFields should >= 0";
     this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
     this.baseObject = baseObject;
     this.baseOffset = baseOffset;
     this.numFields = numFields;
+    this.sizeInBytes = sizeInBytes;
     this.pool = pool;
   }
 
@@ -336,9 +342,31 @@ public final class UnsafeRow extends MutableRow {
     }
   }
 
+  /**
+   * Copies this row, returning a self-contained UnsafeRow that stores its data in an internal
+   * byte array rather than referencing data stored in a data page.
+   * <p>
+   * This method is only supported on UnsafeRows that do not use ObjectPools.
+   */
   @Override
   public InternalRow copy() {
-    throw new UnsupportedOperationException();
+    if (pool != null) {
+      throw new UnsupportedOperationException(
+        "Copy is not supported for UnsafeRows that use object pools");
+    } else {
+      UnsafeRow rowCopy = new UnsafeRow();
+      final byte[] rowDataCopy = new byte[sizeInBytes];
+      PlatformDependent.copyMemory(
+        baseObject,
+        baseOffset,
+        rowDataCopy,
+        PlatformDependent.BYTE_ARRAY_OFFSET,
+        sizeInBytes
+      );
+      rowCopy.pointTo(
+        rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes, null);
+      return rowCopy;
+    }
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/spark/blob/b55499a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index 1f39549..6af5e62 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -70,10 +70,16 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
    * @param row the row to convert
    * @param baseObject the base object of the destination address
    * @param baseOffset the base offset of the destination address
+   * @param rowLengthInBytes the length calculated by `getSizeRequirement(row)`
   * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`.
    */
-  def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long, pool: ObjectPool): Int = {
-    unsafeRow.pointTo(baseObject, baseOffset, writers.length, pool)
+  def writeRow(
+      row: InternalRow,
+      baseObject: Object,
+      baseOffset: Long,
+      rowLengthInBytes: Int,
+      pool: ObjectPool): Int = {
+    unsafeRow.pointTo(baseObject, baseOffset, writers.length, rowLengthInBytes, pool)
 
     if (writers.length > 0) {
       // zero-out the bitset

http://git-wip-us.apache.org/repos/asf/spark/blob/b55499a4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 96d4e64..d00aeb4d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -44,19 +44,32 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val sizeRequired: Int = converter.getSizeRequirement(row)
     assert(sizeRequired === 8 + (3 * 8))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+    val numBytesWritten =
+      converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getLong(1) === 1)
     assert(unsafeRow.getInt(2) === 2)
 
+    // We can copy UnsafeRows as long as they don't reference ObjectPools
+    val unsafeRowCopy = unsafeRow.copy()
+    assert(unsafeRowCopy.getLong(0) === 0)
+    assert(unsafeRowCopy.getLong(1) === 1)
+    assert(unsafeRowCopy.getInt(2) === 2)
+
     unsafeRow.setLong(1, 3)
     assert(unsafeRow.getLong(1) === 3)
     unsafeRow.setInt(2, 4)
     assert(unsafeRow.getInt(2) === 4)
+
+    // Mutating the original row should not have changed the copy
+    assert(unsafeRowCopy.getLong(0) === 0)
+    assert(unsafeRowCopy.getLong(1) === 1)
+    assert(unsafeRowCopy.getInt(2) === 2)
   }
 
   test("basic conversion with primitive, string and binary types") {
@@ -73,12 +86,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) +
       ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+    val numBytesWritten = converter.writeRow(
+      row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val unsafeRow = new UnsafeRow()
     val pool = new ObjectPool(10)
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getString(1) === "Hello")
     assert(unsafeRow.get(2) === "World".getBytes)
@@ -96,6 +111,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     unsafeRow.update(2, "Hello World".getBytes)
     assert(unsafeRow.get(2) === "Hello World".getBytes)
     assert(pool.size === 2)
+
+    // We do not support copy() for UnsafeRows that reference ObjectPools
+    intercept[UnsupportedOperationException] {
+      unsafeRow.copy()
+    }
   }
 
   test("basic conversion with primitive, decimal and array") {
@@ -111,12 +131,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val sizeRequired: Int = converter.getSizeRequirement(row)
     assert(sizeRequired === 8 + (8 * 3))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, pool)
+    val numBytesWritten =
+      converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, pool)
     assert(numBytesWritten === sizeRequired)
     assert(pool.size === 2)
 
     val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.get(1) === Decimal(1))
     assert(unsafeRow.get(2) === Array(2))
@@ -142,11 +164,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(sizeRequired === 8 + (8 * 4) +
       ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
-    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+    val numBytesWritten =
+      converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+    unsafeRow.pointTo(
+      buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null)
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getString(1) === "Hello")
     // Date is represented as Int in unsafeRow
@@ -190,12 +214,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns)
     val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
     val numBytesWritten = converter.writeRow(
-      rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+      rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET,
+      sizeRequired, null)
     assert(numBytesWritten === sizeRequired)
 
     val createdFromNull = new UnsafeRow()
     createdFromNull.pointTo(
-      createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+      createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length,
+      sizeRequired, null)
     for (i <- 0 to fieldTypes.length - 1) {
       assert(createdFromNull.isNullAt(i))
     }
@@ -233,10 +259,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val pool = new ObjectPool(1)
     val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2)
     converter.writeRow(
-      rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, pool)
+      rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET,
+      sizeRequired, pool)
     val setToNullAfterCreation = new UnsafeRow()
     setToNullAfterCreation.pointTo(
-      setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+      setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length,
+      sizeRequired, pool)
 
     assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
     assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1))

