Repository: spark Updated Branches: refs/heads/master 95034af69 -> 3323b156f
[SPARK-23864][SQL] Add unsafe object writing to UnsafeWriter ## What changes were proposed in this pull request? This PR moves writing of `UnsafeRow`, `UnsafeArrayData` & `UnsafeMapData` out of the `GenerateUnsafeProjection`/`InterpretedUnsafeProjection` classes into the `UnsafeWriter` abstract class. This cleans up the code a little bit, and it should also result in less bytecode on the code-generated path. ## How was this patch tested? Existing tests Author: Herman van Hovell <hvanhov...@databricks.com> Closes #20986 from hvanhovell/SPARK-23864. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3323b156 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3323b156 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3323b156 Branch: refs/heads/master Commit: 3323b156f9c0beb0b3c2b724a6faddc6ffdfe99a Parents: 95034af Author: Herman van Hovell <hvanhov...@databricks.com> Authored: Tue Apr 10 17:32:00 2018 +0200 Committer: Herman van Hovell <hvanhov...@databricks.com> Committed: Tue Apr 10 17:32:00 2018 +0200 ---------------------------------------------------------------------- .../expressions/codegen/UnsafeWriter.java | 72 +++-- .../InterpretedUnsafeProjection.scala | 46 +-- .../codegen/GenerateUnsafeProjection.scala | 322 ++++++++----------- .../spark/sql/types/UserDefinedType.scala | 10 + 4 files changed, 204 insertions(+), 246 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/3323b156/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index de0eb6d..2781655 100644 --- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -16,6 +16,9 @@ */ package org.apache.spark.sql.catalyst.expressions.codegen; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.expressions.UnsafeMapData; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -103,21 +106,7 @@ public abstract class UnsafeWriter { public abstract void write(int ordinal, Decimal input, int precision, int scale); public final void write(int ordinal, UTF8String input) { - final int numBytes = input.numBytes(); - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - input.writeToMemory(getBuffer(), cursor()); - - setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. - increaseCursor(roundedSize); + writeUnalignedBytes(ordinal, input.getBaseObject(), input.getBaseOffset(), input.numBytes()); } public final void write(int ordinal, byte[] input) { @@ -125,20 +114,19 @@ public abstract class UnsafeWriter { } public final void write(int ordinal, byte[] input, int offset, int numBytes) { - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); + writeUnalignedBytes(ordinal, input, Platform.BYTE_ARRAY_OFFSET + offset, numBytes); + } - // grow the global buffer before writing data. 
+ private void writeUnalignedBytes( + int ordinal, + Object baseObject, + long baseOffset, + int numBytes) { + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); grow(roundedSize); - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - Platform.copyMemory( - input, Platform.BYTE_ARRAY_OFFSET + offset, getBuffer(), cursor(), numBytes); - + Platform.copyMemory(baseObject, baseOffset, getBuffer(), cursor(), numBytes); setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. increaseCursor(roundedSize); } @@ -156,6 +144,40 @@ public abstract class UnsafeWriter { increaseCursor(16); } + public final void write(int ordinal, UnsafeRow row) { + writeAlignedBytes(ordinal, row.getBaseObject(), row.getBaseOffset(), row.getSizeInBytes()); + } + + public final void write(int ordinal, UnsafeMapData map) { + writeAlignedBytes(ordinal, map.getBaseObject(), map.getBaseOffset(), map.getSizeInBytes()); + } + + public final void write(UnsafeArrayData array) { + // Unsafe arrays both can be written as a regular array field or as part of a map. This makes + // updating the offset and size dependent on the code path, this is why we currently do not + // provide an method for writing unsafe arrays that also updates the size and offset. 
+ int numBytes = array.getSizeInBytes(); + grow(numBytes); + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + getBuffer(), + cursor(), + numBytes); + increaseCursor(numBytes); + } + + private void writeAlignedBytes( + int ordinal, + Object baseObject, + long baseOffset, + int numBytes) { + grow(numBytes); + Platform.copyMemory(baseObject, baseOffset, getBuffer(), cursor(), numBytes); + setOffsetAndSize(ordinal, numBytes); + increaseCursor(numBytes); + } + protected final void writeBoolean(long offset, boolean value) { Platform.putBoolean(getBuffer(), offset, value); } http://git-wip-us.apache.org/repos/asf/spark/blob/3323b156/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index b31466f..6d69d69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -173,21 +173,17 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { val rowWriter = new UnsafeRowWriter(writer, numFields) val structWriter = generateStructWriter(rowWriter, fields) (v, i) => { - val previousCursor = writer.cursor() v.getStruct(i, fields.length) match { case row: UnsafeRow => - writeUnsafeData( - rowWriter, - row.getBaseObject, - row.getBaseOffset, - row.getSizeInBytes) + writer.write(i, row) case row => + val previousCursor = writer.cursor() // Nested struct. We don't know where this will start because a row can be // variable length, so we need to update the offsets and zero out the bit mask. 
rowWriter.resetRowWriter() structWriter.apply(row) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } - writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case ArrayType(elementType, containsNull) => @@ -214,15 +210,12 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { valueType, valueContainsNull) (v, i) => { - val previousCursor = writer.cursor() v.getMap(i) match { case map: UnsafeMapData => - writeUnsafeData( - valueArrayWriter, - map.getBaseObject, - map.getBaseOffset, - map.getSizeInBytes) + writer.write(i, map) case map => + val previousCursor = writer.cursor() + // preserve 8 bytes to write the key array numBytes later. valueArrayWriter.grow(8) valueArrayWriter.increaseCursor(8) @@ -237,8 +230,8 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { // Write the values. writeArray(valueArrayWriter, valueWriter, map.valueArray()) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } - writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case udt: UserDefinedType[_] => @@ -318,11 +311,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { elementWriter: (SpecializedGetters, Int) => Unit, array: ArrayData): Unit = array match { case unsafe: UnsafeArrayData => - writeUnsafeData( - arrayWriter, - unsafe.getBaseObject, - unsafe.getBaseOffset, - unsafe.getSizeInBytes) + arrayWriter.write(unsafe) case _ => val numElements = array.numElements() arrayWriter.initialize(numElements) @@ -332,23 +321,4 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { i += 1 } } - - /** - * Write an opaque block of data to the buffer. This is used to copy - * [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects. 
- */ - private def writeUnsafeData( - writer: UnsafeWriter, - baseObject: AnyRef, - baseOffset: Long, - sizeInBytes: Int) : Unit = { - writer.grow(sizeInBytes) - Platform.copyMemory( - baseObject, - baseOffset, - writer.getBuffer, - writer.cursor, - sizeInBytes) - writer.increaseCursor(sizeInBytes) - } } http://git-wip-us.apache.org/repos/asf/spark/blob/3323b156/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 4a4d763..2fb441a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -32,14 +32,13 @@ import org.apache.spark.sql.types._ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] { /** Returns true iff we support this data type. 
*/ - def canSupport(dataType: DataType): Boolean = dataType match { + def canSupport(dataType: DataType): Boolean = UserDefinedType.sqlType(dataType) match { case NullType => true - case t: AtomicType => true + case _: AtomicType => true case _: CalendarIntervalType => true case t: StructType => t.forall(field => canSupport(field.dataType)) case t: ArrayType if canSupport(t.elementType) => true case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true - case udt: UserDefinedType[_] => canSupport(udt.sqlType) case _ => false } @@ -47,6 +46,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private def writeStructToBuffer( ctx: CodegenContext, input: String, + index: String, fieldTypes: Seq[DataType], rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. @@ -60,15 +60,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val rowWriterClass = classOf[UnsafeRowWriter].getName val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});") - + val previousCursor = ctx.freshName("previousCursor") s""" - final InternalRow $tmpInput = $input; - if ($tmpInput instanceof UnsafeRow) { - ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", structRowWriter)} - } else { - ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} - } - """ + |final InternalRow $tmpInput = $input; + |if ($tmpInput instanceof UnsafeRow) { + | $rowWriter.write($index, (UnsafeRow) $tmpInput); + |} else { + | // Remember the current cursor so that we can calculate how many bytes are + | // written later. 
+ | final int $previousCursor = $rowWriter.cursor(); + | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} + | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); + |} + """.stripMargin } private def writeExpressionsToBuffer( @@ -95,10 +99,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val writeFields = inputs.zip(inputTypes).zipWithIndex.map { case ((input, dataType), index) => - val dt = dataType match { - case udt: UserDefinedType[_] => udt.sqlType - case other => other - } + val dt = UserDefinedType.sqlType(dataType) val setNull = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => @@ -106,58 +107,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});" case _ => s"$rowWriter.setNullAt($index);" } - val previousCursor = ctx.freshName("previousCursor") - - val writeField = dt match { - case t: StructType => - s""" - // Remember the current cursor so that we can calculate how many bytes are - // written later. - final int $previousCursor = $rowWriter.cursor(); - ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), rowWriter)} - $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case a @ ArrayType(et, _) => - s""" - // Remember the current cursor so that we can calculate how many bytes are - // written later. - final int $previousCursor = $rowWriter.cursor(); - ${writeArrayToBuffer(ctx, input.value, et, rowWriter)} - $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case m @ MapType(kt, vt, _) => - s""" - // Remember the current cursor so that we can calculate how many bytes are - // written later. 
- final int $previousCursor = $rowWriter.cursor(); - ${writeMapToBuffer(ctx, input.value, kt, vt, rowWriter)} - $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case t: DecimalType => - s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});" - - case NullType => "" - - case _ => s"$rowWriter.write($index, ${input.value});" - } + val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) if (input.isNull == "false") { s""" - ${input.code} - ${writeField.trim} - """ + |${input.code} + |${writeField.trim} + """.stripMargin } else { s""" - ${input.code} - if (${input.isNull}) { - ${setNull.trim} - } else { - ${writeField.trim} - } - """ + |${input.code} + |if (${input.isNull}) { + | ${setNull.trim} + |} else { + | ${writeField.trim} + |} + """.stripMargin } } @@ -171,11 +136,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro funcName = "writeFields", arguments = Seq("InternalRow" -> row)) } - s""" - $resetWriter - $writeFieldsCode - """.trim + |$resetWriter + |$writeFieldsCode + """.stripMargin } // TODO: if the nullability of array element is correct, we can use it to save null check. 
@@ -189,10 +153,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") - val et = elementType match { - case udt: UserDefinedType[_] => udt.sqlType - case other => other - } + val et = UserDefinedType.sqlType(elementType) val jt = CodeGenerator.javaType(et) @@ -205,106 +166,100 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);") - val previousCursor = ctx.freshName("previousCursor") val element = CodeGenerator.getValue(tmpInput, et, index) - val writeElement = et match { - case t: StructType => - s""" - final int $previousCursor = $arrayWriter.cursor(); - ${writeStructToBuffer(ctx, element, t.map(_.dataType), arrayWriter)} - $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case a @ ArrayType(et, _) => - s""" - final int $previousCursor = $arrayWriter.cursor(); - ${writeArrayToBuffer(ctx, element, et, arrayWriter)} - $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case m @ MapType(kt, vt, _) => - s""" - final int $previousCursor = $arrayWriter.cursor(); - ${writeMapToBuffer(ctx, element, kt, vt, arrayWriter)} - $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case t: DecimalType => - s"$arrayWriter.write($index, $element, ${t.precision}, ${t.scale});" - - case NullType => "" - - case _ => s"$arrayWriter.write($index, $element);" - } - val primitiveTypeName = - if (CodeGenerator.isPrimitiveType(jt)) CodeGenerator.primitiveTypeName(et) else "" s""" - final ArrayData $tmpInput = $input; - if ($tmpInput instanceof UnsafeArrayData) { - ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", arrayWriter)} - } else { - final int 
$numElements = $tmpInput.numElements(); - $arrayWriter.initialize($numElements); - - for (int $index = 0; $index < $numElements; $index++) { - if ($tmpInput.isNullAt($index)) { - $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); - } else { - $writeElement - } - } - } - """ + |final ArrayData $tmpInput = $input; + |if ($tmpInput instanceof UnsafeArrayData) { + | $rowWriter.write((UnsafeArrayData) $tmpInput); + |} else { + | final int $numElements = $tmpInput.numElements(); + | $arrayWriter.initialize($numElements); + | + | for (int $index = 0; $index < $numElements; $index++) { + | if ($tmpInput.isNullAt($index)) { + | $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); + | } else { + | ${writeElement(ctx, element, index, et, arrayWriter)} + | } + | } + |} + """.stripMargin } // TODO: if the nullability of value element is correct, we can use it to save null check. private def writeMapToBuffer( ctx: CodegenContext, input: String, + index: String, keyType: DataType, valueType: DataType, rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val tmpCursor = ctx.freshName("tmpCursor") + val previousCursor = ctx.freshName("previousCursor") // Writes out unsafe map according to the format described in `UnsafeMapData`. s""" - final MapData $tmpInput = $input; - if ($tmpInput instanceof UnsafeMapData) { - ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", rowWriter)} - } else { - // preserve 8 bytes to write the key array numBytes later. - $rowWriter.grow(8); - $rowWriter.increaseCursor(8); + |final MapData $tmpInput = $input; + |if ($tmpInput instanceof UnsafeMapData) { + | $rowWriter.write($index, (UnsafeMapData) $tmpInput); + |} else { + | // Remember the current cursor so that we can calculate how many bytes are + | // written later. + | final int $previousCursor = $rowWriter.cursor(); + | + | // preserve 8 bytes to write the key array numBytes later. 
+ | $rowWriter.grow(8); + | $rowWriter.increaseCursor(8); + | + | // Remember the current cursor so that we can write numBytes of key array later. + | final int $tmpCursor = $rowWriter.cursor(); + | + | ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} + | + | // Write the numBytes of key array into the first 8 bytes. + | Platform.putLong( + | $rowWriter.getBuffer(), + | $tmpCursor - 8, + | $rowWriter.cursor() - $tmpCursor); + | + | ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} + | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); + |} + """.stripMargin + } - // Remember the current cursor so that we can write numBytes of key array later. - final int $tmpCursor = $rowWriter.cursor(); + private def writeElement( + ctx: CodegenContext, + input: String, + index: String, + dt: DataType, + writer: String): String = dt match { + case t: StructType => + writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer) + + case ArrayType(et, _) => + val previousCursor = ctx.freshName("previousCursor") + s""" + |// Remember the current cursor so that we can calculate how many bytes are + |// written later. + |final int $previousCursor = $writer.cursor(); + |${writeArrayToBuffer(ctx, input, et, writer)} + |$writer.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); + """.stripMargin - ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} - // Write the numBytes of key array into the first 8 bytes. 
- Platform.putLong($rowWriter.getBuffer(), $tmpCursor - 8, $rowWriter.cursor() - $tmpCursor); + case MapType(kt, vt, _) => + writeMapToBuffer(ctx, input, index, kt, vt, writer) - ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} - } - """ - } + case DecimalType.Fixed(precision, scale) => + s"$writer.write($index, $input, $precision, $scale);" - /** - * If the input is already in unsafe format, we don't need to go through all elements/fields, - * we can directly write it. - */ - private def writeUnsafeData(ctx: CodegenContext, input: String, rowWriter: String) = { - val sizeInBytes = ctx.freshName("sizeInBytes") - s""" - final int $sizeInBytes = $input.getSizeInBytes(); - // grow the global buffer before writing data. - $rowWriter.grow($sizeInBytes); - $input.writeToMemory($rowWriter.getBuffer(), $rowWriter.cursor()); - $rowWriter.increaseCursor($sizeInBytes); - """ + case NullType => "" + + case _ => s"$writer.write($index, $input);" } def createCode( @@ -332,10 +287,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val code = s""" - $rowWriter.reset(); - $evalSubexpr - $writeExpressions - """ + |$rowWriter.reset(); + |$evalSubexpr + |$writeExpressions + """.stripMargin // `rowWriter` is declared as a class field, so we can access it directly in methods. 
ExprCode(code, FalseLiteral, StatementValue(s"$rowWriter.getRow()", "UnsafeRow", canDirectAccess = true)) @@ -363,38 +318,39 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val ctx = newCodeGenContext() val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) - val codeBody = s""" - public java.lang.Object generate(Object[] references) { - return new SpecificUnsafeProjection(references); - } - - class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} { - - private Object[] references; - ${ctx.declareMutableStates()} - - public SpecificUnsafeProjection(Object[] references) { - this.references = references; - ${ctx.initMutableStates()} - } - - public void initialize(int partitionIndex) { - ${ctx.initPartition()} - } - - // Scala.Function1 need this - public java.lang.Object apply(java.lang.Object row) { - return apply((InternalRow) row); - } - - public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { - ${eval.code.trim} - return ${eval.value}; - } - - ${ctx.declareAddedFunctions()} - } - """ + val codeBody = + s""" + |public java.lang.Object generate(Object[] references) { + | return new SpecificUnsafeProjection(references); + |} + | + |class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} { + | + | private Object[] references; + | ${ctx.declareMutableStates()} + | + | public SpecificUnsafeProjection(Object[] references) { + | this.references = references; + | ${ctx.initMutableStates()} + | } + | + | public void initialize(int partitionIndex) { + | ${ctx.initPartition()} + | } + | + | // Scala.Function1 need this + | public java.lang.Object apply(java.lang.Object row) { + | return apply((InternalRow) row); + | } + | + | public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { + | ${eval.code.trim} + | return ${eval.value}; + | } + | + | ${ctx.declareAddedFunctions()} + |} + """.stripMargin val code = CodeFormatter.stripOverlappingComments( new CodeAndComment(codeBody, 
ctx.getPlaceHolderToComments())) http://git-wip-us.apache.org/repos/asf/spark/blob/3323b156/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 5a944e7..6af16e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -97,6 +97,16 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa override def catalogString: String = sqlType.simpleString } +private[spark] object UserDefinedType { + /** + * Get the sqlType of a (potential) [[UserDefinedType]]. + */ + def sqlType(dt: DataType): DataType = dt match { + case udt: UserDefinedType[_] => udt.sqlType + case _ => dt + } +} + /** * The user defined type in Python. * --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org