Repository: spark
Updated Branches:
  refs/heads/master 50d3242d6 -> 96aa01378


[SPARK-8492] [SQL] support binaryType in UnsafeRow

Support BinaryType in UnsafeRow, just like StringType.

Also change the layout of StringType and BinaryType in UnsafeRow by combining the
offset and size into a single long, which limits the size of a row to under 2G
(given that a single buffer cannot be larger than 2G in the JVM).
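
A minimal sketch of the new layout, using a hypothetical helper (plain Java;
these names are illustrative, not the actual UnsafeRow API): each
variable-length field's 8-byte word holds the offset in its high 32 bits and
the size in its low 32 bits.

    // Hypothetical helper, for illustration only.
    final class OffsetAndSize {
      static long pack(int offset, int size) {
        // Both values are non-negative and below 2G, so each fits in 32 bits.
        return ((long) offset << 32) | (size & 0xFFFFFFFFL);
      }

      static int offset(long word) {
        return (int) (word >> 32);
      }

      static int size(long word) {
        return (int) (word & ((1L << 32) - 1));
      }
    }

Because each half is a 32-bit int, neither a field's offset nor its size can
exceed 2G, which is where the row-size limit comes from.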

Author: Davies Liu <dav...@databricks.com>

Closes #6911 from davies/unsafe_bin and squashes the following commits:

d68706f [Davies Liu] update comment
519f698 [Davies Liu] address comment
98a964b [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_bin
180b49d [Davies Liu] fix zero-out
22e4c0a [Davies Liu] zero-out padding bytes
6abfe93 [Davies Liu] fix style
447dea0 [Davies Liu] support binaryType in UnsafeRow


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/96aa0137
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/96aa0137
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/96aa0137

Branch: refs/heads/master
Commit: 96aa01378e3b3dbb4601d31c7312a311cb65b22e
Parents: 50d3242
Author: Davies Liu <dav...@databricks.com>
Authored: Mon Jun 22 15:22:17 2015 -0700
Committer: Davies Liu <dav...@databricks.com>
Committed: Mon Jun 22 15:22:17 2015 -0700

----------------------------------------------------------------------
 .../UnsafeFixedWidthAggregationMap.java         |  8 ---
 .../sql/catalyst/expressions/UnsafeRow.java     | 34 ++++++-----
 .../expressions/UnsafeRowConverter.scala        | 60 +++++++++++++++-----
 .../expressions/UnsafeRowConverterSuite.scala   | 16 +++---
 4 files changed, 72 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/96aa0137/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index f7849eb..83f2a31 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
-import java.util.Arrays;
 import java.util.Iterator;
 
 import org.apache.spark.sql.catalyst.InternalRow;
@@ -142,14 +141,7 @@ public final class UnsafeFixedWidthAggregationMap {
     final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey);
     // Make sure that the buffer is large enough to hold the key. If it's not, grow it:
     if (groupingKeySize > groupingKeyConversionScratchSpace.length) {
-      // This new array will be initially zero, so there's no need to zero it out here
       groupingKeyConversionScratchSpace = new byte[groupingKeySize];
-    } else {
-      // Zero out the buffer that's used to hold the current row. This is necessary in order
-      // to ensure that rows hash properly, since garbage data from the previous row could
-      // otherwise end up as padding in this row. As a performance optimization, we only zero out
-      // the portion of the buffer that we'll actually write to.
-      Arrays.fill(groupingKeyConversionScratchSpace, 0, groupingKeySize, (byte) 0);
     }
     final int actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow(
       groupingKey,

http://git-wip-us.apache.org/repos/asf/spark/blob/96aa0137/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index ed04d2e..bb2f207 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -47,7 +47,8 @@ import static org.apache.spark.sql.types.DataTypes.*;
  * In the `values` region, we store one 8-byte word per field. For fields that hold fixed-length
  * primitive types, such as long, double, or int, we store the value directly in the word. For
  * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the
- * base address of the row) that points to the beginning of the variable-length field.
+ * base address of the row) that points to the beginning of the variable-length field, and length
+ * (they are combined into a long).
  *
  * Instances of `UnsafeRow` act as pointers to row data stored in this format.
  */
@@ -92,6 +93,7 @@ public final class UnsafeRow extends BaseMutableRow {
    */
   public static final Set<DataType> readableFieldTypes;
 
+  // TODO: support DecimalType
   static {
     settableFieldTypes = Collections.unmodifiableSet(
       new HashSet<DataType>(
@@ -111,7 +113,8 @@ public final class UnsafeRow extends BaseMutableRow {
     // We support get() on a superset of the types for which we support set():
     final Set<DataType> _readableFieldTypes = new HashSet<DataType>(
       Arrays.asList(new DataType[]{
-        StringType
+        StringType,
+        BinaryType
       }));
     _readableFieldTypes.addAll(settableFieldTypes);
     readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
@@ -222,11 +225,6 @@ public final class UnsafeRow extends BaseMutableRow {
   }
 
   @Override
-  public void setString(int ordinal, String value) {
-    throw new UnsupportedOperationException();
-  }
-
-  @Override
   public int size() {
     return numFields;
   }
@@ -249,6 +247,8 @@ public final class UnsafeRow extends BaseMutableRow {
       return null;
     } else if (dataType == StringType) {
       return getUTF8String(i);
+    } else if (dataType == BinaryType) {
+      return getBinary(i);
     } else {
       throw new UnsupportedOperationException();
     }
@@ -311,19 +311,23 @@ public final class UnsafeRow extends BaseMutableRow {
   }
 
   public UTF8String getUTF8String(int i) {
+    return UTF8String.fromBytes(getBinary(i));
+  }
+
+  public byte[] getBinary(int i) {
     assertIndexIsValid(i);
-    final long offsetToStringSize = getLong(i);
-    final int stringSizeInBytes =
-      (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize);
-    final byte[] strBytes = new byte[stringSizeInBytes];
+    final long offsetAndSize = getLong(i);
+    final int offset = (int)(offsetAndSize >> 32);
+    final int size = (int)(offsetAndSize & ((1L << 32) - 1));
+    final byte[] bytes = new byte[size];
     PlatformDependent.copyMemory(
       baseObject,
-      baseOffset + offsetToStringSize + 8,  // The `+ 8` is to skip past the size to get the data
-      strBytes,
+      baseOffset + offset,
+      bytes,
       PlatformDependent.BYTE_ARRAY_OFFSET,
-      stringSizeInBytes
+      size
     );
-    return UTF8String.fromBytes(strBytes);
+    return bytes;
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/spark/blob/96aa0137/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index 72f740e..89adaf0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst.util.DateUtils
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -72,6 +70,19 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
    */
   def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long): Int = {
     unsafeRow.pointTo(baseObject, baseOffset, writers.length, null)
+
+    if (writers.length > 0) {
+      // zero-out the bitset
+      var n = writers.length / 64
+      while (n >= 0) {
+        PlatformDependent.UNSAFE.putLong(
+          unsafeRow.getBaseObject,
+          unsafeRow.getBaseOffset + n * 8,
+          0L)
+        n -= 1
+      }
+    }
+
     var fieldNumber = 0
     var appendCursor: Int = fixedLengthSize
     while (fieldNumber < writers.length) {
@@ -122,6 +133,7 @@ private object UnsafeColumnWriter {
       case FloatType => FloatUnsafeColumnWriter
       case DoubleType => DoubleUnsafeColumnWriter
       case StringType => StringUnsafeColumnWriter
+      case BinaryType => BinaryUnsafeColumnWriter
       case DateType => IntUnsafeColumnWriter
       case TimestampType => LongUnsafeColumnWriter
       case t =>
@@ -141,6 +153,7 @@ private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
 private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
 private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
 private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
+private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter
 
 private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
   // Primitives don't write to the variable-length region:
@@ -235,10 +248,13 @@ private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWr
   }
 }
 
-private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter {
+private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter {
+
+  def getBytes(source: InternalRow, column: Int): Array[Byte]
+
   def getSize(source: InternalRow, column: Int): Int = {
-    val numBytes = source.get(column).asInstanceOf[UTF8String].getBytes.length
-    8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+    val numBytes = getBytes(source, column).length
+    ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
   }
 
   override def write(
@@ -246,19 +262,33 @@ private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter {
       target: UnsafeRow,
       column: Int,
       appendCursor: Int): Int = {
-    val value = source.get(column).asInstanceOf[UTF8String]
-    val baseObject = target.getBaseObject
-    val baseOffset = target.getBaseOffset
-    val numBytes = value.getBytes.length
-    PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, 
numBytes)
+    val offset = target.getBaseOffset + appendCursor
+    val bytes = getBytes(source, column)
+    val numBytes = bytes.length
+    if ((numBytes & 0x07) > 0) {
+      // zero-out the padding bytes
+      PlatformDependent.UNSAFE.putLong(target.getBaseObject, offset + ((numBytes >> 3) << 3), 0L)
+    }
     PlatformDependent.copyMemory(
-      value.getBytes,
+      bytes,
       PlatformDependent.BYTE_ARRAY_OFFSET,
-      baseObject,
-      baseOffset + appendCursor + 8,
+      target.getBaseObject,
+      offset,
       numBytes
     )
-    target.setLong(column, appendCursor)
-    8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+    target.setLong(column, (appendCursor.toLong << 32L) | numBytes.toLong)
+    ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+  }
+}
+
+private class StringUnsafeColumnWriter private() extends BytesUnsafeColumnWriter {
+  def getBytes(source: InternalRow, column: Int): Array[Byte] = {
+    source.getAs[UTF8String](column).getBytes
+  }
+}
+
+private class BinaryUnsafeColumnWriter private() extends BytesUnsafeColumnWriter {
+  def getBytes(source: InternalRow, column: Int): Array[Byte] = {
+    source.getAs[Array[Byte]](column)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/96aa0137/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 721ef8a..d8f3351 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -23,8 +23,8 @@ import java.util.Arrays
 import org.scalatest.Matchers
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types._
 import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.array.ByteArrayMethods
 
@@ -52,19 +52,19 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     unsafeRow.getInt(2) should be (2)
   }
 
-  test("basic conversion with primitive and string types") {
-    val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType)
+  test("basic conversion with primitive, string and binary types") {
+    val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType)
     val converter = new UnsafeRowConverter(fieldTypes)
 
     val row = new SpecificMutableRow(fieldTypes)
     row.setLong(0, 0)
     row.setString(1, "Hello")
-    row.setString(2, "World")
+    row.update(2, "World".getBytes)
 
     val sizeRequired: Int = converter.getSizeRequirement(row)
     sizeRequired should be (8 + (8 * 3) +
-      ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) +
-      ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8))
+      ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) +
+      ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
     val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
     numBytesWritten should be (sizeRequired)
@@ -73,7 +73,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
     unsafeRow.getLong(0) should be (0)
     unsafeRow.getString(1) should be ("Hello")
-    unsafeRow.getString(2) should be ("World")
+    unsafeRow.getBinary(2) should be ("World".getBytes)
   }
 
   test("basic conversion with primitive, string, date and timestamp types") {
@@ -88,7 +88,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
 
     val sizeRequired: Int = converter.getSizeRequirement(row)
     sizeRequired should be (8 + (8 * 4) +
-      ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8))
+      ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length))
     val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
     val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
     numBytesWritten should be (sizeRequired)

