Repository: spark
Updated Branches:
  refs/heads/master a400ab516 -> f99cb5615

[SPARK-9330][SQL] Create specialized getStruct getter in InternalRow.

Also took the chance to rearrange some of the methods in UnsafeRow to group static/private/public things together.

Author: Reynold Xin <r...@databricks.com>

Closes #7654 from rxin/getStruct and squashes the following commits:

b491a09 [Reynold Xin] Fixed typo.
48d77e5 [Reynold Xin] [SPARK-9330][SQL] Create specialized getStruct getter in InternalRow.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f99cb561
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f99cb561
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f99cb561

Branch: refs/heads/master
Commit: f99cb5615cbc0b469d52af6bd08f8bf888af58f3
Parents: a400ab5
Author: Reynold Xin <r...@databricks.com>
Authored: Fri Jul 24 19:29:01 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Fri Jul 24 19:29:01 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/UnsafeRow.java | 87 +++++++++++++-------
 .../sql/catalyst/CatalystTypeConverters.scala | 2 +-
 .../apache/spark/sql/catalyst/InternalRow.scala | 22 +++--
 .../catalyst/expressions/BoundAttribute.scala | 1 +
 .../expressions/codegen/CodeGenerator.scala | 5 +-
 5 files changed, 77 insertions(+), 40 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/f99cb561/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index a898660..225f6e6 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -51,28 +51,9 @@ import static org.apache.spark.sql.types.DataTypes.*; */ public final class UnsafeRow extends MutableRow { - private Object baseObject; - private long baseOffset; - - public Object getBaseObject() { return baseObject; } - public long getBaseOffset() { return baseOffset; } - public int getSizeInBytes() { return sizeInBytes; } - - /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ - private int numFields; - - /** The size of this row's backing data, in bytes) */ - private int sizeInBytes; - - @Override - public int numFields() { return numFields; } - - /** The width of the null tracking bit set, in bytes */ - private int bitSetWidthInBytes; - - private long getFieldOffset(int ordinal) { - return baseOffset + bitSetWidthInBytes + ordinal * 8L; - } + ////////////////////////////////////////////////////////////////////////////// + // Static methods + ////////////////////////////////////////////////////////////////////////////// public static int calculateBitSetWidthInBytes(int numFields) { return ((numFields / 64) + (numFields % 64 == 0 ? 
0 : 1)) * 8; @@ -103,7 +84,7 @@ public final class UnsafeRow extends MutableRow { DoubleType, DateType, TimestampType - }))); + }))); // We support get() on a superset of the types for which we support set(): final Set<DataType> _readableFieldTypes = new HashSet<>( @@ -115,12 +96,48 @@ public final class UnsafeRow extends MutableRow { readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); } + ////////////////////////////////////////////////////////////////////////////// + // Private fields and methods + ////////////////////////////////////////////////////////////////////////////// + + private Object baseObject; + private long baseOffset; + + /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ + private int numFields; + + /** The size of this row's backing data, in bytes) */ + private int sizeInBytes; + + private void setNotNullAt(int i) { + assertIndexIsValid(i); + BitSetMethods.unset(baseObject, baseOffset, i); + } + + /** The width of the null tracking bit set, in bytes */ + private int bitSetWidthInBytes; + + private long getFieldOffset(int ordinal) { + return baseOffset + bitSetWidthInBytes + ordinal * 8L; + } + + ////////////////////////////////////////////////////////////////////////////// + // Public methods + ////////////////////////////////////////////////////////////////////////////// + /** * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called, * since the value returned by this constructor is equivalent to a null pointer. */ public UnsafeRow() { } + public Object getBaseObject() { return baseObject; } + public long getBaseOffset() { return baseOffset; } + public int getSizeInBytes() { return sizeInBytes; } + + @Override + public int numFields() { return numFields; } + /** * Update this UnsafeRow to point to different backing data. 
* @@ -130,7 +147,7 @@ public final class UnsafeRow extends MutableRow { * @param sizeInBytes the size of this row's backing data, in bytes */ public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeInBytes) { - assert numFields >= 0 : "numFields should >= 0"; + assert numFields >= 0 : "numFields (" + numFields + ") should >= 0"; this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); this.baseObject = baseObject; this.baseOffset = baseOffset; @@ -153,11 +170,6 @@ public final class UnsafeRow extends MutableRow { PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), 0); } - private void setNotNullAt(int i) { - assertIndexIsValid(i); - BitSetMethods.unset(baseObject, baseOffset, i); - } - @Override public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); @@ -316,6 +328,21 @@ public final class UnsafeRow extends MutableRow { return getUTF8String(i).toString(); } + @Override + public UnsafeRow getStruct(int i, int numFields) { + if (isNullAt(i)) { + return null; + } else { + assertIndexIsValid(i); + final long offsetAndSize = getLong(i); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + final UnsafeRow row = new UnsafeRow(); + row.pointTo(baseObject, baseOffset + offset, numFields, size); + return row; + } + } + /** * Copies this row, returning a self-contained UnsafeRow that stores its data in an internal * byte array rather than referencing data stored in a data page. 
@@ -388,7 +415,7 @@ public final class UnsafeRow extends MutableRow { */ public byte[] getBytes() { if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET - && (((byte[]) baseObject).length == sizeInBytes)) { + && (((byte[]) baseObject).length == sizeInBytes)) { return (byte[]) baseObject; } else { byte[] bytes = new byte[sizeInBytes]; http://git-wip-us.apache.org/repos/asf/spark/blob/f99cb561/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 5c3072a..7416ddb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -271,7 +271,7 @@ object CatalystTypeConverters { } override def toScalaImpl(row: InternalRow, column: Int): Row = - toScala(row.get(column).asInstanceOf[InternalRow]) + toScala(row.getStruct(column, structType.size)) } private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] { http://git-wip-us.apache.org/repos/asf/spark/blob/f99cb561/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index efc4fae..f248b1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -52,6 +52,21 @@ abstract class InternalRow extends Serializable { def getDouble(i: Int): Double = getAs[Double](i) + def 
getUTF8String(i: Int): UTF8String = getAs[UTF8String](i) + + def getBinary(i: Int): Array[Byte] = getAs[Array[Byte]](i) + + // This is only use for test + def getString(i: Int): String = getAs[UTF8String](i).toString + + /** + * Returns a struct from ordinal position. + * + * @param ordinal position to get the struct from. + * @param numFields number of fields the struct type has + */ + def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs[InternalRow](ordinal) + override def toString: String = s"[${this.mkString(",")}]" /** @@ -145,13 +160,6 @@ abstract class InternalRow extends Serializable { */ def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) - def getUTF8String(i: Int): UTF8String = getAs[UTF8String](i) - - def getBinary(i: Int): Array[Byte] = getAs[Array[Byte]](i) - - // This is only use for test - def getString(i: Int): String = getAs[UTF8String](i).toString - // Custom hashCode function that matches the efficient code generated version. 
override def hashCode: Int = { var result: Int = 37 http://git-wip-us.apache.org/repos/asf/spark/blob/f99cb561/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 6aa4930..1f7adcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -48,6 +48,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) case DoubleType => input.getDouble(ordinal) case StringType => input.getUTF8String(ordinal) case BinaryType => input.getBinary(ordinal) + case t: StructType => input.getStruct(ordinal, t.size) case _ => input.get(ordinal) } } http://git-wip-us.apache.org/repos/asf/spark/blob/f99cb561/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 48225e1..4a90f1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -109,6 +109,7 @@ class CodeGenContext { case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)" case StringType => s"$row.getUTF8String($ordinal)" case BinaryType => s"$row.getBinary($ordinal)" + case t: StructType => s"$row.getStruct($ordinal, ${t.size})" case _ => 
s"($jt)$row.apply($ordinal)" } } @@ -249,13 +250,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected val mutableRowType: String = classOf[MutableRow].getName protected val genericMutableRowType: String = classOf[GenericMutableRow].getName - protected def declareMutableStates(ctx: CodeGenContext) = { + protected def declareMutableStates(ctx: CodeGenContext): String = { ctx.mutableStates.map { case (javaType, variableName, _) => s"private $javaType $variableName;" }.mkString("\n ") } - protected def initMutableStates(ctx: CodeGenContext) = { + protected def initMutableStates(ctx: CodeGenContext): String = { ctx.mutableStates.map(_._3).mkString("\n ") } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org