Repository: spark
Updated Branches:
  refs/heads/master 95034af69 -> 3323b156f


[SPARK-23864][SQL] Add unsafe object writing to UnsafeWriter

## What changes were proposed in this pull request?
This PR moves writing of `UnsafeRow`, `UnsafeArrayData` & `UnsafeMapData` out 
of the `GenerateUnsafeProjection`/`InterpretedUnsafeProjection` classes into 
the `UnsafeWriter` interface. This cleans up the code a little bit, and it 
should also result in less byte code for the code generated path.

## How was this patch tested?
Existing tests

Author: Herman van Hovell <hvanhov...@databricks.com>

Closes #20986 from hvanhovell/SPARK-23864.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3323b156
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3323b156
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3323b156

Branch: refs/heads/master
Commit: 3323b156f9c0beb0b3c2b724a6faddc6ffdfe99a
Parents: 95034af
Author: Herman van Hovell <hvanhov...@databricks.com>
Authored: Tue Apr 10 17:32:00 2018 +0200
Committer: Herman van Hovell <hvanhov...@databricks.com>
Committed: Tue Apr 10 17:32:00 2018 +0200

----------------------------------------------------------------------
 .../expressions/codegen/UnsafeWriter.java       |  72 +++--
 .../InterpretedUnsafeProjection.scala           |  46 +--
 .../codegen/GenerateUnsafeProjection.scala      | 322 ++++++++-----------
 .../spark/sql/types/UserDefinedType.scala       |  10 +
 4 files changed, 204 insertions(+), 246 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3323b156/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
index de0eb6d..2781655 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -16,6 +16,9 @@
  */
 package org.apache.spark.sql.catalyst.expressions.codegen;
 
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
+import org.apache.spark.sql.catalyst.expressions.UnsafeMapData;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.types.Decimal;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
@@ -103,21 +106,7 @@ public abstract class UnsafeWriter {
   public abstract void write(int ordinal, Decimal input, int precision, int 
scale);
 
   public final void write(int ordinal, UTF8String input) {
-    final int numBytes = input.numBytes();
-    final int roundedSize = 
ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
-
-    // grow the global buffer before writing data.
-    grow(roundedSize);
-
-    zeroOutPaddingBytes(numBytes);
-
-    // Write the bytes to the variable length portion.
-    input.writeToMemory(getBuffer(), cursor());
-
-    setOffsetAndSize(ordinal, numBytes);
-
-    // move the cursor forward.
-    increaseCursor(roundedSize);
+    writeUnalignedBytes(ordinal, input.getBaseObject(), input.getBaseOffset(), 
input.numBytes());
   }
 
   public final void write(int ordinal, byte[] input) {
@@ -125,20 +114,19 @@ public abstract class UnsafeWriter {
   }
 
   public final void write(int ordinal, byte[] input, int offset, int numBytes) 
{
-    final int roundedSize = 
ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length);
+    writeUnalignedBytes(ordinal, input, Platform.BYTE_ARRAY_OFFSET + offset, 
numBytes);
+  }
 
-    // grow the global buffer before writing data.
+  private void writeUnalignedBytes(
+      int ordinal,
+      Object baseObject,
+      long baseOffset,
+      int numBytes) {
+    final int roundedSize = 
ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
     grow(roundedSize);
-
     zeroOutPaddingBytes(numBytes);
-
-    // Write the bytes to the variable length portion.
-    Platform.copyMemory(
-      input, Platform.BYTE_ARRAY_OFFSET + offset, getBuffer(), cursor(), 
numBytes);
-
+    Platform.copyMemory(baseObject, baseOffset, getBuffer(), cursor(), 
numBytes);
     setOffsetAndSize(ordinal, numBytes);
-
-    // move the cursor forward.
     increaseCursor(roundedSize);
   }
 
@@ -156,6 +144,40 @@ public abstract class UnsafeWriter {
     increaseCursor(16);
   }
 
+  public final void write(int ordinal, UnsafeRow row) {
+    writeAlignedBytes(ordinal, row.getBaseObject(), row.getBaseOffset(), 
row.getSizeInBytes());
+  }
+
+  public final void write(int ordinal, UnsafeMapData map) {
+    writeAlignedBytes(ordinal, map.getBaseObject(), map.getBaseOffset(), 
map.getSizeInBytes());
+  }
+
+  public final void write(UnsafeArrayData array) {
+    // Unsafe arrays can be written either as a regular array field or as part of a
+    // map. This makes updating the offset and size dependent on the code path, which
+    // is why we currently do not provide a method for writing unsafe arrays that also
+    // updates the size and offset.
+    int numBytes = array.getSizeInBytes();
+    grow(numBytes);
+    Platform.copyMemory(
+            array.getBaseObject(),
+            array.getBaseOffset(),
+            getBuffer(),
+            cursor(),
+            numBytes);
+    increaseCursor(numBytes);
+  }
+
+  private void writeAlignedBytes(
+      int ordinal,
+      Object baseObject,
+      long baseOffset,
+      int numBytes) {
+    grow(numBytes);
+    Platform.copyMemory(baseObject, baseOffset, getBuffer(), cursor(), 
numBytes);
+    setOffsetAndSize(ordinal, numBytes);
+    increaseCursor(numBytes);
+  }
+
   protected final void writeBoolean(long offset, boolean value) {
     Platform.putBoolean(getBuffer(), offset, value);
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/3323b156/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
index b31466f..6d69d69 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
@@ -173,21 +173,17 @@ object InterpretedUnsafeProjection extends 
UnsafeProjectionCreator {
         val rowWriter = new UnsafeRowWriter(writer, numFields)
         val structWriter = generateStructWriter(rowWriter, fields)
         (v, i) => {
-          val previousCursor = writer.cursor()
           v.getStruct(i, fields.length) match {
             case row: UnsafeRow =>
-              writeUnsafeData(
-                rowWriter,
-                row.getBaseObject,
-                row.getBaseOffset,
-                row.getSizeInBytes)
+              writer.write(i, row)
             case row =>
+              val previousCursor = writer.cursor()
               // Nested struct. We don't know where this will start because a 
row can be
               // variable length, so we need to update the offsets and zero 
out the bit mask.
               rowWriter.resetRowWriter()
               structWriter.apply(row)
+              writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
           }
-          writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
         }
 
       case ArrayType(elementType, containsNull) =>
@@ -214,15 +210,12 @@ object InterpretedUnsafeProjection extends 
UnsafeProjectionCreator {
           valueType,
           valueContainsNull)
         (v, i) => {
-          val previousCursor = writer.cursor()
           v.getMap(i) match {
             case map: UnsafeMapData =>
-              writeUnsafeData(
-                valueArrayWriter,
-                map.getBaseObject,
-                map.getBaseOffset,
-                map.getSizeInBytes)
+              writer.write(i, map)
             case map =>
+              val previousCursor = writer.cursor()
+
               // preserve 8 bytes to write the key array numBytes later.
               valueArrayWriter.grow(8)
               valueArrayWriter.increaseCursor(8)
@@ -237,8 +230,8 @@ object InterpretedUnsafeProjection extends 
UnsafeProjectionCreator {
 
               // Write the values.
               writeArray(valueArrayWriter, valueWriter, map.valueArray())
+              writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
           }
-          writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
         }
 
       case udt: UserDefinedType[_] =>
@@ -318,11 +311,7 @@ object InterpretedUnsafeProjection extends 
UnsafeProjectionCreator {
       elementWriter: (SpecializedGetters, Int) => Unit,
       array: ArrayData): Unit = array match {
     case unsafe: UnsafeArrayData =>
-      writeUnsafeData(
-        arrayWriter,
-        unsafe.getBaseObject,
-        unsafe.getBaseOffset,
-        unsafe.getSizeInBytes)
+      arrayWriter.write(unsafe)
     case _ =>
       val numElements = array.numElements()
       arrayWriter.initialize(numElements)
@@ -332,23 +321,4 @@ object InterpretedUnsafeProjection extends 
UnsafeProjectionCreator {
         i += 1
       }
   }
-
-  /**
-   * Write an opaque block of data to the buffer. This is used to copy
-   * [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects.
-   */
-  private def writeUnsafeData(
-      writer: UnsafeWriter,
-      baseObject: AnyRef,
-      baseOffset: Long,
-      sizeInBytes: Int) : Unit = {
-    writer.grow(sizeInBytes)
-    Platform.copyMemory(
-      baseObject,
-      baseOffset,
-      writer.getBuffer,
-      writer.cursor,
-      sizeInBytes)
-    writer.increaseCursor(sizeInBytes)
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3323b156/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 4a4d763..2fb441a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -32,14 +32,13 @@ import org.apache.spark.sql.types._
 object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], 
UnsafeProjection] {
 
   /** Returns true iff we support this data type. */
-  def canSupport(dataType: DataType): Boolean = dataType match {
+  def canSupport(dataType: DataType): Boolean = 
UserDefinedType.sqlType(dataType) match {
     case NullType => true
-    case t: AtomicType => true
+    case _: AtomicType => true
     case _: CalendarIntervalType => true
     case t: StructType => t.forall(field => canSupport(field.dataType))
     case t: ArrayType if canSupport(t.elementType) => true
     case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
-    case udt: UserDefinedType[_] => canSupport(udt.sqlType)
     case _ => false
   }
 
@@ -47,6 +46,7 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
   private def writeStructToBuffer(
       ctx: CodegenContext,
       input: String,
+      index: String,
       fieldTypes: Seq[DataType],
       rowWriter: String): String = {
     // Puts `input` in a local variable to avoid to re-evaluate it if it's a 
statement.
@@ -60,15 +60,19 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
     val rowWriterClass = classOf[UnsafeRowWriter].getName
     val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter",
       v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});")
-
+    val previousCursor = ctx.freshName("previousCursor")
     s"""
-      final InternalRow $tmpInput = $input;
-      if ($tmpInput instanceof UnsafeRow) {
-        ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", structRowWriter)}
-      } else {
-        ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, 
structRowWriter)}
-      }
-    """
+       |final InternalRow $tmpInput = $input;
+       |if ($tmpInput instanceof UnsafeRow) {
+       |  $rowWriter.write($index, (UnsafeRow) $tmpInput);
+       |} else {
+       |  // Remember the current cursor so that we can calculate how many 
bytes are
+       |  // written later.
+       |  final int $previousCursor = $rowWriter.cursor();
+       |  ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, 
structRowWriter)}
+       |  $rowWriter.setOffsetAndSizeFromPreviousCursor($index, 
$previousCursor);
+       |}
+     """.stripMargin
   }
 
   private def writeExpressionsToBuffer(
@@ -95,10 +99,7 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
 
     val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
       case ((input, dataType), index) =>
-        val dt = dataType match {
-          case udt: UserDefinedType[_] => udt.sqlType
-          case other => other
-        }
+        val dt = UserDefinedType.sqlType(dataType)
 
         val setNull = dt match {
           case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
@@ -106,58 +107,22 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
             s"$rowWriter.write($index, (Decimal) null, ${t.precision}, 
${t.scale});"
           case _ => s"$rowWriter.setNullAt($index);"
         }
-        val previousCursor = ctx.freshName("previousCursor")
-
-        val writeField = dt match {
-          case t: StructType =>
-            s"""
-              // Remember the current cursor so that we can calculate how many 
bytes are
-              // written later.
-              final int $previousCursor = $rowWriter.cursor();
-              ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), 
rowWriter)}
-              $rowWriter.setOffsetAndSizeFromPreviousCursor($index, 
$previousCursor);
-            """
-
-          case a @ ArrayType(et, _) =>
-            s"""
-              // Remember the current cursor so that we can calculate how many 
bytes are
-              // written later.
-              final int $previousCursor = $rowWriter.cursor();
-              ${writeArrayToBuffer(ctx, input.value, et, rowWriter)}
-              $rowWriter.setOffsetAndSizeFromPreviousCursor($index, 
$previousCursor);
-            """
-
-          case m @ MapType(kt, vt, _) =>
-            s"""
-              // Remember the current cursor so that we can calculate how many 
bytes are
-              // written later.
-              final int $previousCursor = $rowWriter.cursor();
-              ${writeMapToBuffer(ctx, input.value, kt, vt, rowWriter)}
-              $rowWriter.setOffsetAndSizeFromPreviousCursor($index, 
$previousCursor);
-            """
-
-          case t: DecimalType =>
-            s"$rowWriter.write($index, ${input.value}, ${t.precision}, 
${t.scale});"
-
-          case NullType => ""
-
-          case _ => s"$rowWriter.write($index, ${input.value});"
-        }
 
+        val writeField = writeElement(ctx, input.value, index.toString, dt, 
rowWriter)
         if (input.isNull == "false") {
           s"""
-            ${input.code}
-            ${writeField.trim}
-          """
+             |${input.code}
+             |${writeField.trim}
+           """.stripMargin
         } else {
           s"""
-            ${input.code}
-            if (${input.isNull}) {
-              ${setNull.trim}
-            } else {
-              ${writeField.trim}
-            }
-          """
+             |${input.code}
+             |if (${input.isNull}) {
+             |  ${setNull.trim}
+             |} else {
+             |  ${writeField.trim}
+             |}
+           """.stripMargin
         }
     }
 
@@ -171,11 +136,10 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
         funcName = "writeFields",
         arguments = Seq("InternalRow" -> row))
     }
-
     s"""
-      $resetWriter
-      $writeFieldsCode
-    """.trim
+       |$resetWriter
+       |$writeFieldsCode
+     """.stripMargin
   }
 
   // TODO: if the nullability of array element is correct, we can use it to 
save null check.
@@ -189,10 +153,7 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
     val numElements = ctx.freshName("numElements")
     val index = ctx.freshName("index")
 
-    val et = elementType match {
-      case udt: UserDefinedType[_] => udt.sqlType
-      case other => other
-    }
+    val et = UserDefinedType.sqlType(elementType)
 
     val jt = CodeGenerator.javaType(et)
 
@@ -205,106 +166,100 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
     val arrayWriterClass = classOf[UnsafeArrayWriter].getName
     val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter",
       v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);")
-    val previousCursor = ctx.freshName("previousCursor")
 
     val element = CodeGenerator.getValue(tmpInput, et, index)
-    val writeElement = et match {
-      case t: StructType =>
-        s"""
-          final int $previousCursor = $arrayWriter.cursor();
-          ${writeStructToBuffer(ctx, element, t.map(_.dataType), arrayWriter)}
-          $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, 
$previousCursor);
-        """
-
-      case a @ ArrayType(et, _) =>
-        s"""
-          final int $previousCursor = $arrayWriter.cursor();
-          ${writeArrayToBuffer(ctx, element, et, arrayWriter)}
-          $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, 
$previousCursor);
-        """
-
-      case m @ MapType(kt, vt, _) =>
-        s"""
-          final int $previousCursor = $arrayWriter.cursor();
-          ${writeMapToBuffer(ctx, element, kt, vt, arrayWriter)}
-          $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, 
$previousCursor);
-        """
-
-      case t: DecimalType =>
-        s"$arrayWriter.write($index, $element, ${t.precision}, ${t.scale});"
-
-      case NullType => ""
-
-      case _ => s"$arrayWriter.write($index, $element);"
-    }
 
-    val primitiveTypeName =
-      if (CodeGenerator.isPrimitiveType(jt)) 
CodeGenerator.primitiveTypeName(et) else ""
     s"""
-      final ArrayData $tmpInput = $input;
-      if ($tmpInput instanceof UnsafeArrayData) {
-        ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", arrayWriter)}
-      } else {
-        final int $numElements = $tmpInput.numElements();
-        $arrayWriter.initialize($numElements);
-
-        for (int $index = 0; $index < $numElements; $index++) {
-          if ($tmpInput.isNullAt($index)) {
-            $arrayWriter.setNull${elementOrOffsetSize}Bytes($index);
-          } else {
-            $writeElement
-          }
-        }
-      }
-    """
+       |final ArrayData $tmpInput = $input;
+       |if ($tmpInput instanceof UnsafeArrayData) {
+       |  $rowWriter.write((UnsafeArrayData) $tmpInput);
+       |} else {
+       |  final int $numElements = $tmpInput.numElements();
+       |  $arrayWriter.initialize($numElements);
+       |
+       |  for (int $index = 0; $index < $numElements; $index++) {
+       |    if ($tmpInput.isNullAt($index)) {
+       |      $arrayWriter.setNull${elementOrOffsetSize}Bytes($index);
+       |    } else {
+       |      ${writeElement(ctx, element, index, et, arrayWriter)}
+       |    }
+       |  }
+       |}
+     """.stripMargin
   }
 
   // TODO: if the nullability of value element is correct, we can use it to 
save null check.
   private def writeMapToBuffer(
       ctx: CodegenContext,
       input: String,
+      index: String,
       keyType: DataType,
       valueType: DataType,
       rowWriter: String): String = {
     // Puts `input` in a local variable to avoid to re-evaluate it if it's a 
statement.
     val tmpInput = ctx.freshName("tmpInput")
     val tmpCursor = ctx.freshName("tmpCursor")
+    val previousCursor = ctx.freshName("previousCursor")
 
     // Writes out unsafe map according to the format described in 
`UnsafeMapData`.
     s"""
-      final MapData $tmpInput = $input;
-      if ($tmpInput instanceof UnsafeMapData) {
-        ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", rowWriter)}
-      } else {
-        // preserve 8 bytes to write the key array numBytes later.
-        $rowWriter.grow(8);
-        $rowWriter.increaseCursor(8);
+       |final MapData $tmpInput = $input;
+       |if ($tmpInput instanceof UnsafeMapData) {
+       |  $rowWriter.write($index, (UnsafeMapData) $tmpInput);
+       |} else {
+       |  // Remember the current cursor so that we can calculate how many 
bytes are
+       |  // written later.
+       |  final int $previousCursor = $rowWriter.cursor();
+       |
+       |  // preserve 8 bytes to write the key array numBytes later.
+       |  $rowWriter.grow(8);
+       |  $rowWriter.increaseCursor(8);
+       |
+       |  // Remember the current cursor so that we can write numBytes of key 
array later.
+       |  final int $tmpCursor = $rowWriter.cursor();
+       |
+       |  ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, 
rowWriter)}
+       |
+       |  // Write the numBytes of key array into the first 8 bytes.
+       |  Platform.putLong(
+       |    $rowWriter.getBuffer(),
+       |    $tmpCursor - 8,
+       |    $rowWriter.cursor() - $tmpCursor);
+       |
+       |  ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, 
rowWriter)}
+       |  $rowWriter.setOffsetAndSizeFromPreviousCursor($index, 
$previousCursor);
+       |}
+     """.stripMargin
+  }
 
-        // Remember the current cursor so that we can write numBytes of key 
array later.
-        final int $tmpCursor = $rowWriter.cursor();
+  private def writeElement(
+      ctx: CodegenContext,
+      input: String,
+      index: String,
+      dt: DataType,
+      writer: String): String = dt match {
+    case t: StructType =>
+      writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer)
+
+    case ArrayType(et, _) =>
+      val previousCursor = ctx.freshName("previousCursor")
+      s"""
+         |// Remember the current cursor so that we can calculate how many 
bytes are
+         |// written later.
+         |final int $previousCursor = $writer.cursor();
+         |${writeArrayToBuffer(ctx, input, et, writer)}
+         |$writer.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
+       """.stripMargin
 
-        ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)}
-        // Write the numBytes of key array into the first 8 bytes.
-        Platform.putLong($rowWriter.getBuffer(), $tmpCursor - 8, 
$rowWriter.cursor() - $tmpCursor);
+    case MapType(kt, vt, _) =>
+      writeMapToBuffer(ctx, input, index, kt, vt, writer)
 
-        ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, 
rowWriter)}
-      }
-    """
-  }
+    case DecimalType.Fixed(precision, scale) =>
+      s"$writer.write($index, $input, $precision, $scale);"
 
-  /**
-   * If the input is already in unsafe format, we don't need to go through all 
elements/fields,
-   * we can directly write it.
-   */
-  private def writeUnsafeData(ctx: CodegenContext, input: String, rowWriter: 
String) = {
-    val sizeInBytes = ctx.freshName("sizeInBytes")
-    s"""
-      final int $sizeInBytes = $input.getSizeInBytes();
-      // grow the global buffer before writing data.
-      $rowWriter.grow($sizeInBytes);
-      $input.writeToMemory($rowWriter.getBuffer(), $rowWriter.cursor());
-      $rowWriter.increaseCursor($sizeInBytes);
-    """
+    case NullType => ""
+
+    case _ => s"$writer.write($index, $input);"
   }
 
   def createCode(
@@ -332,10 +287,10 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
 
     val code =
       s"""
-        $rowWriter.reset();
-        $evalSubexpr
-        $writeExpressions
-      """
+         |$rowWriter.reset();
+         |$evalSubexpr
+         |$writeExpressions
+       """.stripMargin
     // `rowWriter` is declared as a class field, so we can access it directly 
in methods.
     ExprCode(code, FalseLiteral, StatementValue(s"$rowWriter.getRow()", 
"UnsafeRow",
       canDirectAccess = true))
@@ -363,38 +318,39 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
     val ctx = newCodeGenContext()
     val eval = createCode(ctx, expressions, subexpressionEliminationEnabled)
 
-    val codeBody = s"""
-      public java.lang.Object generate(Object[] references) {
-        return new SpecificUnsafeProjection(references);
-      }
-
-      class SpecificUnsafeProjection extends 
${classOf[UnsafeProjection].getName} {
-
-        private Object[] references;
-        ${ctx.declareMutableStates()}
-
-        public SpecificUnsafeProjection(Object[] references) {
-          this.references = references;
-          ${ctx.initMutableStates()}
-        }
-
-        public void initialize(int partitionIndex) {
-          ${ctx.initPartition()}
-        }
-
-        // Scala.Function1 need this
-        public java.lang.Object apply(java.lang.Object row) {
-          return apply((InternalRow) row);
-        }
-
-        public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) {
-          ${eval.code.trim}
-          return ${eval.value};
-        }
-
-        ${ctx.declareAddedFunctions()}
-      }
-      """
+    val codeBody =
+      s"""
+         |public java.lang.Object generate(Object[] references) {
+         |  return new SpecificUnsafeProjection(references);
+         |}
+         |
+         |class SpecificUnsafeProjection extends 
${classOf[UnsafeProjection].getName} {
+         |
+         |  private Object[] references;
+         |  ${ctx.declareMutableStates()}
+         |
+         |  public SpecificUnsafeProjection(Object[] references) {
+         |    this.references = references;
+         |    ${ctx.initMutableStates()}
+         |  }
+         |
+         |  public void initialize(int partitionIndex) {
+         |    ${ctx.initPartition()}
+         |  }
+         |
+         |  // Scala.Function1 needs this
+         |  public java.lang.Object apply(java.lang.Object row) {
+         |    return apply((InternalRow) row);
+         |  }
+         |
+         |  public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) {
+         |    ${eval.code.trim}
+         |    return ${eval.value};
+         |  }
+         |
+         |  ${ctx.declareAddedFunctions()}
+         |}
+       """.stripMargin
 
     val code = CodeFormatter.stripOverlappingComments(
       new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))

http://git-wip-us.apache.org/repos/asf/spark/blob/3323b156/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index 5a944e7..6af16e2 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -97,6 +97,16 @@ abstract class UserDefinedType[UserType >: Null] extends 
DataType with Serializa
   override def catalogString: String = sqlType.simpleString
 }
 
+private[spark] object UserDefinedType {
+  /**
+   * Get the sqlType of a (potential) [[UserDefinedType]].
+   */
+  def sqlType(dt: DataType): DataType = dt match {
+    case udt: UserDefinedType[_] => udt.sqlType
+    case _ => dt
+  }
+}
+
 /**
  * The user defined type in Python.
  *


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to