[SPARK-9285][SQL] Remove InternalRow's inheritance from Row. I also changed InternalRow's size/length function to numFields, to make it more obvious that it is not about bytes, but the number of fields.
Author: Reynold Xin <r...@databricks.com> Closes #7626 from rxin/internalRow and squashes the following commits: e124daf [Reynold Xin] Fixed test case. 805ceb7 [Reynold Xin] Commented out the failed test suite. f8a9ca5 [Reynold Xin] Fixed more bugs. Still at least one more remaining. 76d9081 [Reynold Xin] Fixed data sources. 7807f70 [Reynold Xin] Fixed DataFrameSuite. cb60cd2 [Reynold Xin] Code review & small bug fixes. 0a2948b [Reynold Xin] Fixed style. 3280d03 [Reynold Xin] [SPARK-9285][SQL] Remove InternalRow's inheritance from Row. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/431ca39b Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/431ca39b Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/431ca39b Branch: refs/heads/master Commit: 431ca39be51352dfcdacc87de7e64c2af313558d Parents: 3aec9f4 Author: Reynold Xin <r...@databricks.com> Authored: Fri Jul 24 09:37:36 2015 -0700 Committer: Reynold Xin <r...@databricks.com> Committed: Fri Jul 24 09:37:36 2015 -0700 ---------------------------------------------------------------------- .../apache/spark/mllib/linalg/Matrices.scala | 4 +- .../org/apache/spark/mllib/linalg/Vectors.scala | 4 +- .../sql/catalyst/expressions/UnsafeRow.java | 9 +- .../sql/catalyst/CatalystTypeConverters.scala | 14 +- .../apache/spark/sql/catalyst/InternalRow.scala | 153 +++++++++++++---- .../spark/sql/catalyst/expressions/Cast.scala | 4 +- .../sql/catalyst/expressions/Projection.scala | 168 +++++++++---------- .../expressions/SpecificMutableRow.scala | 4 +- .../sql/catalyst/expressions/aggregates.scala | 4 +- .../codegen/GenerateProjection.scala | 2 +- .../expressions/complexTypeExtractors.scala | 4 +- .../spark/sql/catalyst/expressions/rows.scala | 57 ++++--- .../scala/org/apache/spark/sql/RowTest.scala | 10 -- .../sql/catalyst/expressions/CastSuite.scala | 24 ++- .../catalyst/expressions/ComplexTypeSuite.scala | 7 +- 
.../apache/spark/sql/columnar/ColumnType.scala | 2 +- .../columnar/InMemoryColumnarTableScan.scala | 12 +- .../sql/execution/SparkSqlSerializer2.scala | 10 +- .../datasources/DataSourceStrategy.scala | 4 +- .../sql/execution/datasources/commands.scala | 53 +++--- .../spark/sql/execution/datasources/ddl.scala | 16 +- .../apache/spark/sql/execution/pythonUDFs.scala | 4 +- .../spark/sql/expressions/aggregate/udaf.scala | 3 +- .../apache/spark/sql/jdbc/JDBCRelation.scala | 3 +- .../apache/spark/sql/json/JSONRelation.scala | 6 +- .../sql/parquet/CatalystRowConverter.scala | 10 +- .../sql/parquet/ParquetTableOperations.scala | 2 +- .../spark/sql/parquet/ParquetTableSupport.scala | 12 +- .../apache/spark/sql/parquet/newParquet.scala | 6 +- .../apache/spark/sql/sources/interfaces.scala | 22 ++- .../scala/org/apache/spark/sql/RowSuite.scala | 4 +- .../apache/spark/sql/sources/DDLTestSuite.scala | 5 +- .../spark/sql/sources/PrunedScanSuite.scala | 2 +- .../spark/sql/sources/TableScanSuite.scala | 2 +- .../hive/execution/InsertIntoHiveTable.scala | 9 +- .../spark/sql/hive/hiveWriterContainers.scala | 8 +- .../apache/spark/sql/hive/orc/OrcRelation.scala | 8 +- .../CommitFailureTestRelationSuite.scala | 47 ++++++ .../sources/ParquetHadoopFsRelationSuite.scala | 139 +++++++++++++++ .../SimpleTextHadoopFsRelationSuite.scala | 57 +++++++ .../sql/sources/hadoopFsRelationSuites.scala | 166 ------------------ 41 files changed, 647 insertions(+), 433 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 55da0e0..b6e2c30 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -174,8 +174,8 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { override def deserialize(datum: Any): Matrix = { datum match { case row: InternalRow => - require(row.length == 7, - s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7") + require(row.numFields == 7, + s"MatrixUDT.deserialize given row with length ${row.numFields} but requires length == 7") val tpe = row.getByte(0) val numRows = row.getInt(1) val numCols = row.getInt(2) http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 9067b3b..c884aad 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -203,8 +203,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { override def deserialize(datum: Any): Vector = { datum match { case row: InternalRow => - require(row.length == 4, - s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") + require(row.numFields == 4, + s"VectorUDT.deserialize given row with length ${row.numFields} but requires length == 4") val tpe = row.getByte(0) tpe match { case 0 => http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index fa1216b..a898660 100644 --- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -64,7 +64,8 @@ public final class UnsafeRow extends MutableRow { /** The size of this row's backing data, in bytes) */ private int sizeInBytes; - public int length() { return numFields; } + @Override + public int numFields() { return numFields; } /** The width of the null tracking bit set, in bytes */ private int bitSetWidthInBytes; @@ -218,12 +219,12 @@ public final class UnsafeRow extends MutableRow { } @Override - public int size() { - return numFields; + public Object get(int i) { + throw new UnsupportedOperationException(); } @Override - public Object get(int i) { + public <T> T getAs(int i) { throw new UnsupportedOperationException(); } http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index bfaee04..5c3072a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -140,14 +140,14 @@ object CatalystTypeConverters { private object IdentityConverter extends CatalystTypeConverter[Any, Any, Any] { override def toCatalystImpl(scalaValue: Any): Any = scalaValue override def toScala(catalystValue: Any): Any = catalystValue - override def toScalaImpl(row: InternalRow, column: Int): Any = row(column) + override def toScalaImpl(row: InternalRow, column: Int): Any = row.get(column) } private case class UDTConverter( udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { override def 
toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue) - override def toScalaImpl(row: InternalRow, column: Int): Any = toScala(row(column)) + override def toScalaImpl(row: InternalRow, column: Int): Any = toScala(row.get(column)) } /** Converter for arrays, sequences, and Java iterables. */ @@ -184,7 +184,7 @@ object CatalystTypeConverters { } override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] = - toScala(row(column).asInstanceOf[Seq[Any]]) + toScala(row.get(column).asInstanceOf[Seq[Any]]) } private case class MapConverter( @@ -227,7 +227,7 @@ object CatalystTypeConverters { } override def toScalaImpl(row: InternalRow, column: Int): Map[Any, Any] = - toScala(row(column).asInstanceOf[Map[Any, Any]]) + toScala(row.get(column).asInstanceOf[Map[Any, Any]]) } private case class StructConverter( @@ -260,9 +260,9 @@ object CatalystTypeConverters { if (row == null) { null } else { - val ar = new Array[Any](row.size) + val ar = new Array[Any](row.numFields) var idx = 0 - while (idx < row.size) { + while (idx < row.numFields) { ar(idx) = converters(idx).toScala(row, idx) idx += 1 } @@ -271,7 +271,7 @@ object CatalystTypeConverters { } override def toScalaImpl(row: InternalRow, column: Int): Row = - toScala(row(column).asInstanceOf[InternalRow]) + toScala(row.get(column).asInstanceOf[InternalRow]) } private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] { http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index c7ec49b..efc4fae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -25,48 +25,139 @@ import org.apache.spark.unsafe.types.UTF8String * An abstract class for row used internal in Spark SQL, which only contain the columns as * internal types. */ -abstract class InternalRow extends Row { +abstract class InternalRow extends Serializable { - def getUTF8String(i: Int): UTF8String = getAs[UTF8String](i) + def numFields: Int - def getBinary(i: Int): Array[Byte] = getAs[Array[Byte]](i) + def get(i: Int): Any - // This is only use for test - override def getString(i: Int): String = getAs[UTF8String](i).toString - - // These expensive API should not be used internally. - final override def getDecimal(i: Int): java.math.BigDecimal = - throw new UnsupportedOperationException - final override def getDate(i: Int): java.sql.Date = - throw new UnsupportedOperationException - final override def getTimestamp(i: Int): java.sql.Timestamp = - throw new UnsupportedOperationException - final override def getSeq[T](i: Int): Seq[T] = throw new UnsupportedOperationException - final override def getList[T](i: Int): java.util.List[T] = throw new UnsupportedOperationException - final override def getMap[K, V](i: Int): scala.collection.Map[K, V] = - throw new UnsupportedOperationException - final override def getJavaMap[K, V](i: Int): java.util.Map[K, V] = - throw new UnsupportedOperationException - final override def getStruct(i: Int): Row = throw new UnsupportedOperationException - final override def getAs[T](fieldName: String): T = throw new UnsupportedOperationException - final override def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] = - throw new UnsupportedOperationException - - // A default implementation to change the return type - override def copy(): InternalRow = this + // TODO: Remove this. 
+ def apply(i: Int): Any = get(i) + + def getAs[T](i: Int): T = get(i).asInstanceOf[T] + + def isNullAt(i: Int): Boolean = get(i) == null + + def getBoolean(i: Int): Boolean = getAs[Boolean](i) + + def getByte(i: Int): Byte = getAs[Byte](i) + + def getShort(i: Int): Short = getAs[Short](i) + + def getInt(i: Int): Int = getAs[Int](i) + + def getLong(i: Int): Long = getAs[Long](i) + + def getFloat(i: Int): Float = getAs[Float](i) + + def getDouble(i: Int): Double = getAs[Double](i) + + override def toString: String = s"[${this.mkString(",")}]" + + /** + * Make a copy of the current [[InternalRow]] object. + */ + def copy(): InternalRow = this + + /** Returns true if there are any NULL values in this row. */ + def anyNull: Boolean = { + val len = numFields + var i = 0 + while (i < len) { + if (isNullAt(i)) { return true } + i += 1 + } + false + } + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[InternalRow]) { + return false + } + + val other = o.asInstanceOf[InternalRow] + if (other eq null) { + return false + } + + val len = numFields + if (len != other.numFields) { + return false + } + + var i = 0 + while (i < len) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = get(i) + val o2 = other.get(i) + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + + /* ---------------------- utility methods for Scala ---------------------- */ /** - * Returns true if we can check equality for these 2 rows. 
- * Equality check between external row and internal row is not allowed. - * Here we do this check to prevent call `equals` on internal row with external row. + * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq. */ - protected override def canEqual(other: Row) = other.isInstanceOf[InternalRow] + def toSeq: Seq[Any] = { + val n = numFields + val values = new Array[Any](n) + var i = 0 + while (i < n) { + values.update(i, get(i)) + i += 1 + } + values.toSeq + } + + /** Displays all elements of this sequence in a string (without a separator). */ + def mkString: String = toSeq.mkString + + /** Displays all elements of this sequence in a string using a separator string. */ + def mkString(sep: String): String = toSeq.mkString(sep) + + /** + * Displays all elements of this traversable or iterator in a string using + * start, end, and separator strings. + */ + def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) + + def getUTF8String(i: Int): UTF8String = getAs[UTF8String](i) + + def getBinary(i: Int): Array[Byte] = getAs[Array[Byte]](i) + + // This is only use for test + def getString(i: Int): String = getAs[UTF8String](i).toString // Custom hashCode function that matches the efficient code generated version. 
override def hashCode: Int = { var result: Int = 37 var i = 0 - while (i < length) { + val len = numFields + while (i < len) { val update: Int = if (isNullAt(i)) { 0 http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index c66854d..47ad3e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -382,8 +382,8 @@ case class Cast(child: Expression, dataType: DataType) val newRow = new GenericMutableRow(from.fields.length) buildCast[InternalRow](_, row => { var i = 0 - while (i < row.length) { - newRow.update(i, if (row.isNullAt(i)) null else casts(i)(row(i))) + while (i < row.numFields) { + newRow.update(i, if (row.isNullAt(i)) null else casts(i)(row.get(i))) i += 1 } newRow.copy() http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 04872fb..dbda05a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -176,49 +176,49 @@ class JoinedRow extends InternalRow { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length: Int = row1.length + row2.length + override def numFields: Int = row1.numFields + 
row2.numFields override def getUTF8String(i: Int): UTF8String = { - if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) } override def getBinary(i: Int): Array[Byte] = { - if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) + if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) 
override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) override def copy(): InternalRow = { - val totalSize = row1.length + row2.length + val totalSize = row1.numFields + row2.numFields val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { - copiedValues(i) = apply(i) + copiedValues(i) = get(i) i += 1 } new GenericInternalRow(copiedValues) @@ -278,49 +278,49 @@ class JoinedRow2 extends InternalRow { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length: Int = row1.length + row2.length + override def numFields: Int = row1.numFields + row2.numFields override def getUTF8String(i: Int): UTF8String = { - if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) } override def getBinary(i: Int): Array[Byte] = { - if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) + if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) override def getDouble(i: Int): Double = - if (i < row1.length) 
row1.getDouble(i) else row2.getDouble(i - row1.length) + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) override def copy(): InternalRow = { - val totalSize = row1.length + row2.length + val totalSize = row1.numFields + row2.numFields val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { - copiedValues(i) = apply(i) + copiedValues(i) = get(i) i += 1 } new GenericInternalRow(copiedValues) @@ -374,50 +374,50 @@ class JoinedRow3 extends InternalRow { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length: Int = row1.length + row2.length + override def numFields: Int = row1.numFields + row2.numFields override def getUTF8String(i: Int): UTF8String = { - if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) } override def getBinary(i: Int): Array[Byte] = { - if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - 
row1.length) + if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) override def copy(): InternalRow = { - val totalSize = row1.length + row2.length + val totalSize = row1.numFields + row2.numFields val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { - copiedValues(i) = apply(i) + copiedValues(i) = get(i) i += 1 } new GenericInternalRow(copiedValues) @@ -471,50 +471,50 @@ class 
JoinedRow4 extends InternalRow { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length: Int = row1.length + row2.length + override def numFields: Int = row1.numFields + row2.numFields override def getUTF8String(i: Int): UTF8String = { - if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) } override def getBinary(i: Int): Array[Byte] = { - if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) + if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) 
override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) override def copy(): InternalRow = { - val totalSize = row1.length + row2.length + val totalSize = row1.numFields + row2.numFields val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { - copiedValues(i) = apply(i) + copiedValues(i) = get(i) i += 1 } new GenericInternalRow(copiedValues) @@ -568,50 +568,50 @@ class JoinedRow5 extends InternalRow { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length: Int = row1.length + row2.length + override def numFields: Int = row1.numFields + row2.numFields override def getUTF8String(i: Int): UTF8String = { - if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) } override def getBinary(i: Int): Array[Byte] = { - if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) + if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) 
else row2.getLong(i - row1.length) + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) override def copy(): InternalRow = { - val totalSize = row1.length + row2.length + val totalSize = row1.numFields + row2.numFields val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { - copiedValues(i) = apply(i) + copiedValues(i) = get(i) i += 1 } new GenericInternalRow(copiedValues) @@ -665,50 +665,50 @@ class JoinedRow6 extends InternalRow { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length: Int = row1.length + row2.length + override def numFields: Int = row1.numFields + row2.numFields override def getUTF8String(i: Int): UTF8String = { - if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) } override def getBinary(i: Int): Array[Byte] = { - if (i < row1.length) row1.getBinary(i) else 
row2.getBinary(i - row1.length) + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) + if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) override def copy(): InternalRow = { - val totalSize = row1.length + row2.length + val totalSize = row1.numFields + row2.numFields val copiedValues 
= new Array[Any](totalSize) var i = 0 while(i < totalSize) { - copiedValues(i) = apply(i) + copiedValues(i) = get(i) i += 1 } new GenericInternalRow(copiedValues) http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 6f291d2..4b4833b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -211,7 +211,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR def this() = this(Seq.empty) - override def length: Int = values.length + override def numFields: Int = values.length override def toSeq: Seq[Any] = values.map(_.boxed).toSeq @@ -245,7 +245,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String.fromString(value)) - override def getString(ordinal: Int): String = apply(ordinal).toString + override def getString(ordinal: Int): String = get(ordinal).toString override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 73fde4e..62b6cc8 
100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -675,7 +675,7 @@ case class CombineSetsAndSumFunction( val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] val inputIterator = inputSetEval.iterator while (inputIterator.hasNext) { - seen.add(inputIterator.next) + seen.add(inputIterator.next()) } } @@ -685,7 +685,7 @@ case class CombineSetsAndSumFunction( null } else { Cast(Literal( - casted.iterator.map(f => f.apply(0)).reduceLeft( + casted.iterator.map(f => f.get(0)).reduceLeft( base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), base.dataType).eval(null) } http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 405d6b0..f0efc4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -178,7 +178,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { $initColumns } - public int length() { return ${expressions.length};} + public int numFields() { return ${expressions.length};} protected boolean[] nullBits = new boolean[${expressions.length}]; public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } 
http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 5504781..c91122c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -110,7 +110,7 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int) override def toString: String = s"$child.${field.name}" protected override def nullSafeEval(input: Any): Any = - input.asInstanceOf[InternalRow](ordinal) + input.asInstanceOf[InternalRow].get(ordinal) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, eval => { @@ -142,7 +142,7 @@ case class GetArrayStructFields( protected override def nullSafeEval(input: Any): Any = { input.asInstanceOf[Seq[InternalRow]].map { row => - if (row == null) null else row(ordinal) + if (row == null) null else row.get(ordinal) } } http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index d78be5a..53779dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -44,9 +44,10 @@ abstract class MutableRow extends 
InternalRow { } override def copy(): InternalRow = { - val arr = new Array[Any](length) + val n = numFields + val arr = new Array[Any](n) var i = 0 - while (i < length) { + while (i < n) { arr(i) = get(i) i += 1 } @@ -55,35 +56,22 @@ abstract class MutableRow extends InternalRow { } /** - * A row implementation that uses an array of objects as the underlying storage. - */ -trait ArrayBackedRow { - self: Row => - - protected val values: Array[Any] - - override def toSeq: Seq[Any] = values.toSeq - - def length: Int = values.length - - override def get(i: Int): Any = values(i) - - def setNullAt(i: Int): Unit = { values(i) = null} - - def update(i: Int, value: Any): Unit = { values(i) = value } -} - -/** * A row implementation that uses an array of objects as the underlying storage. Note that, while * the array is not copied, and thus could technically be mutated after creation, this is not * allowed. */ -class GenericRow(protected[sql] val values: Array[Any]) extends Row with ArrayBackedRow { +class GenericRow(protected[sql] val values: Array[Any]) extends Row { /** No-arg constructor for serialization. */ protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) + override def length: Int = values.length + + override def get(i: Int): Any = values(i) + + override def toSeq: Seq[Any] = values.toSeq + override def copy(): Row = this } @@ -101,34 +89,49 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) * Note that, while the array is not copied, and thus could technically be mutated after creation, * this is not allowed. */ -class GenericInternalRow(protected[sql] val values: Array[Any]) - extends InternalRow with ArrayBackedRow { +class GenericInternalRow(protected[sql] val values: Array[Any]) extends InternalRow { /** No-arg constructor for serialization. 
*/ protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) + override def toSeq: Seq[Any] = values.toSeq + + override def numFields: Int = values.length + + override def get(i: Int): Any = values(i) + override def copy(): InternalRow = this } /** * This is used for serialization of Python DataFrame */ -class GenericInternalRowWithSchema(values: Array[Any], override val schema: StructType) +class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType) extends GenericInternalRow(values) { /** No-arg constructor for serialization. */ protected def this() = this(null, null) - override def fieldIndex(name: String): Int = schema.fieldIndex(name) + def fieldIndex(name: String): Int = schema.fieldIndex(name) } -class GenericMutableRow(val values: Array[Any]) extends MutableRow with ArrayBackedRow { +class GenericMutableRow(val values: Array[Any]) extends MutableRow { /** No-arg constructor for serialization. */ protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) + override def toSeq: Seq[Any] = values.toSeq + + override def numFields: Int = values.length + + override def get(i: Int): Any = values(i) + + override def setNullAt(i: Int): Unit = { values(i) = null} + + override def update(i: Int, value: Any): Unit = { values(i) = value } + override def copy(): InternalRow = new GenericInternalRow(values.clone()) } http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index 878a1bb..01ff84c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -83,15 +83,5 @@ class RowTest extends FunSpec with Matchers { it("equality check for internal 
rows") { internalRow shouldEqual internalRow2 } - - it("throws an exception when check equality between external and internal rows") { - def assertError(f: => Unit): Unit = { - val e = intercept[UnsupportedOperationException](f) - e.getMessage.contains("cannot check equality between external and internal rows") - } - - assertError(internalRow.equals(externalRow)) - assertError(externalRow.equals(internalRow)) - } } } http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index facf65c..408353c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Test suite for data type casting expression [[Cast]]. 
@@ -580,14 +581,21 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from struct") { val struct = Literal.create( - InternalRow("123", "abc", "", null), + InternalRow( + UTF8String.fromString("123"), + UTF8String.fromString("abc"), + UTF8String.fromString(""), + null), StructType(Seq( StructField("a", StringType, nullable = true), StructField("b", StringType, nullable = true), StructField("c", StringType, nullable = true), StructField("d", StringType, nullable = true)))) val struct_notNull = Literal.create( - InternalRow("123", "abc", ""), + InternalRow( + UTF8String.fromString("123"), + UTF8String.fromString("abc"), + UTF8String.fromString("")), StructType(Seq( StructField("a", StringType, nullable = false), StructField("b", StringType, nullable = false), @@ -676,8 +684,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("complex casting") { val complex = Literal.create( InternalRow( - Seq("123", "abc", ""), - Map("a" -> "123", "b" -> "abc", "c" -> ""), + Seq(UTF8String.fromString("123"), UTF8String.fromString("abc"), UTF8String.fromString("")), + Map( + UTF8String.fromString("a") -> UTF8String.fromString("123"), + UTF8String.fromString("b") -> UTF8String.fromString("abc"), + UTF8String.fromString("c") -> UTF8String.fromString("")), InternalRow(0)), StructType(Seq( StructField("a", @@ -700,7 +711,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(ret.resolved === true) checkEvaluation(ret, InternalRow( Seq(123, null, null), - Map("a" -> true, "b" -> true, "c" -> false), + Map( + UTF8String.fromString("a") -> true, + UTF8String.fromString("b") -> true, + UTF8String.fromString("c") -> false), InternalRow(0L))) } http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index a8aee8f..fc84277 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -150,12 +151,14 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { test("CreateNamedStruct with literal field") { val row = InternalRow(1, 2, 3) val c1 = 'a.int.at(0) - checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), InternalRow(1, "y"), row) + checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), + InternalRow(1, UTF8String.fromString("y")), row) } test("CreateNamedStruct from all literal fields") { checkEvaluation( - CreateNamedStruct(Seq("a", "x", "b", 2.0)), InternalRow("x", 2.0), InternalRow.empty) + CreateNamedStruct(Seq("a", "x", "b", 2.0)), + InternalRow(UTF8String.fromString("x"), 2.0), InternalRow.empty) } test("test dsl for complex type") { http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 9d8415f..ac42bde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -309,7 +309,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { override def actualSize(row: InternalRow, ordinal: Int): Int = { - row.getString(ordinal).getBytes("utf-8").length + 4 + row.getUTF8String(ordinal).numBytes() + 4 } override def append(v: UTF8String, buffer: ByteBuffer): Unit = { http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 3872096..5d5b069 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -134,13 +134,13 @@ private[sql] case class InMemoryRelation( // may result malformed rows, causing ArrayIndexOutOfBoundsException, which is somewhat // hard to decipher. assert( - row.size == columnBuilders.size, - s"""Row column number mismatch, expected ${output.size} columns, but got ${row.size}. - |Row content: $row - """.stripMargin) + row.numFields == columnBuilders.size, + s"Row column number mismatch, expected ${output.size} columns, " + + s"but got ${row.numFields}." 
+ + s"\nRow content: $row") var i = 0 - while (i < row.length) { + while (i < row.numFields) { columnBuilders(i).appendFrom(row, i) i += 1 } @@ -304,7 +304,7 @@ private[sql] case class InMemoryColumnarTableScan( // Extract rows via column accessors new Iterator[InternalRow] { - private[this] val rowLen = nextRow.length + private[this] val rowLen = nextRow.numFields override def next(): InternalRow = { var i = 0 while (i < rowLen) { http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index c87e206..83c4e87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -25,7 +25,6 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.serializer._ -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} import org.apache.spark.sql.types._ @@ -53,7 +52,7 @@ private[sql] class Serializer2SerializationStream( private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut) override def writeObject[T: ClassTag](t: T): SerializationStream = { - val kv = t.asInstanceOf[Product2[Row, Row]] + val kv = t.asInstanceOf[Product2[InternalRow, InternalRow]] writeKey(kv._1) writeValue(kv._2) @@ -66,7 +65,7 @@ private[sql] class Serializer2SerializationStream( } override def writeValue[T: ClassTag](t: T): SerializationStream = { - writeRowFunc(t.asInstanceOf[Row]) + writeRowFunc(t.asInstanceOf[InternalRow]) this } @@ -205,8 +204,9 @@ private[sql] object 
SparkSqlSerializer2 { /** * The util function to create the serialization function based on the given schema. */ - def createSerializationFunction(schema: Array[DataType], out: DataOutputStream): Row => Unit = { - (row: Row) => + def createSerializationFunction(schema: Array[DataType], out: DataOutputStream) + : InternalRow => Unit = { + (row: InternalRow) => // If the schema is null, the returned function does nothing when it get called. if (schema != null) { var i = 0 http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2b40092..7f452da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -206,7 +206,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val mutableRow = new SpecificMutableRow(dataTypes) iterator.map { dataRow => var i = 0 - while (i < mutableRow.length) { + while (i < mutableRow.numFields) { mergers(i)(mutableRow, dataRow, i) i += 1 } @@ -315,7 +315,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { if (relation.relation.needConversion) { execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType)) } else { - rdd.map(_.asInstanceOf[InternalRow]) + rdd.asInstanceOf[RDD[InternalRow]] } } http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala ---------------------------------------------------------------------- diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala index cd2aa7f..d551f38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala @@ -174,14 +174,19 @@ private[sql] case class InsertIntoHadoopFsRelation( try { writerContainer.executorSideSetup(taskContext) - val converter: InternalRow => Row = if (needsConversion) { - CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] + if (needsConversion) { + val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) + .asInstanceOf[InternalRow => Row] + while (iterator.hasNext) { + val internalRow = iterator.next() + writerContainer.outputWriterForRow(internalRow).write(converter(internalRow)) + } } else { - r: InternalRow => r.asInstanceOf[Row] - } - while (iterator.hasNext) { - val internalRow = iterator.next() - writerContainer.outputWriterForRow(internalRow).write(converter(internalRow)) + while (iterator.hasNext) { + val internalRow = iterator.next() + writerContainer.outputWriterForRow(internalRow) + .asInstanceOf[OutputWriterInternal].writeInternal(internalRow) + } } writerContainer.commitTask() @@ -248,17 +253,23 @@ private[sql] case class InsertIntoHadoopFsRelation( val partitionProj = newProjection(codegenEnabled, partitionCasts, output) val dataProj = newProjection(codegenEnabled, dataOutput, output) - val dataConverter: InternalRow => Row = if (needsConversion) { - CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] + if (needsConversion) { + val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) + .asInstanceOf[InternalRow => Row] + while (iterator.hasNext) { + val internalRow = iterator.next() + val partitionPart = partitionProj(internalRow) + val dataPart = 
converter(dataProj(internalRow)) + writerContainer.outputWriterForRow(partitionPart).write(dataPart) + } } else { - r: InternalRow => r.asInstanceOf[Row] - } - - while (iterator.hasNext) { - val internalRow = iterator.next() - val partitionPart = partitionProj(internalRow) - val dataPart = dataConverter(dataProj(internalRow)) - writerContainer.outputWriterForRow(partitionPart).write(dataPart) + while (iterator.hasNext) { + val internalRow = iterator.next() + val partitionPart = partitionProj(internalRow) + val dataPart = dataProj(internalRow) + writerContainer.outputWriterForRow(partitionPart) + .asInstanceOf[OutputWriterInternal].writeInternal(dataPart) + } } writerContainer.commitTask() @@ -530,8 +541,12 @@ private[sql] class DynamicPartitionWriterContainer( while (i < partitionColumns.length) { val col = partitionColumns(i) val partitionValueString = { - val string = row.getString(i) - if (string.eq(null)) defaultPartitionName else PartitioningUtils.escapePathName(string) + val string = row.getUTF8String(i) + if (string.eq(null)) { + defaultPartitionName + } else { + PartitioningUtils.escapePathName(string.toString) + } } if (i > 0) { http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index c8033d3..1f2797e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -23,11 +23,11 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext, SaveMode} +import 
org.apache.spark.sql.{AnalysisException, DataFrame, Row, SQLContext, SaveMode} +import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, InternalRow} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -415,12 +415,12 @@ private[sql] case class CreateTempTableUsing( provider: String, options: Map[String, String]) extends RunnableCommand { - def run(sqlContext: SQLContext): Seq[InternalRow] = { + def run(sqlContext: SQLContext): Seq[Row] = { val resolved = ResolvedDataSource( sqlContext, userSpecifiedSchema, Array.empty[String], provider, options) sqlContext.registerDataFrameAsTable( DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) - Seq.empty + Seq.empty[Row] } } @@ -432,20 +432,20 @@ private[sql] case class CreateTempTableUsingAsSelect( options: Map[String, String], query: LogicalPlan) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { val df = DataFrame(sqlContext, query) val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) sqlContext.registerDataFrameAsTable( DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) - Seq.empty + Seq.empty[Row] } } private[sql] case class RefreshTable(databaseName: String, tableName: String) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { // Refresh the given table's metadata first. 
sqlContext.catalog.refreshTable(databaseName, tableName) @@ -464,7 +464,7 @@ private[sql] case class RefreshTable(databaseName: String, tableName: String) sqlContext.cacheManager.cacheQuery(df, Some(tableName)) } - Seq.empty[InternalRow] + Seq.empty[Row] } } http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index e6e27a8..40bf03a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -126,9 +126,9 @@ object EvaluatePython { case (null, _) => null case (row: InternalRow, struct: StructType) => - val values = new Array[Any](row.size) + val values = new Array[Any](row.numFields) var i = 0 - while (i < row.size) { + while (i < row.numFields) { values(i) = toJava(row(i), struct.fields(i).dataType) i += 1 } http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala index 6c49a90..46f0fac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala @@ -148,7 +148,7 @@ class InputAggregationBuffer private[sql] ( toCatalystConverters: Array[Any => Any], toScalaConverters: Array[Any => Any], bufferOffset: Int, - var underlyingInputBuffer: Row) + var underlyingInputBuffer: InternalRow) extends AggregationBuffer(toCatalystConverters, 
toScalaConverters, bufferOffset) { override def get(i: Int): Any = { @@ -156,6 +156,7 @@ class InputAggregationBuffer private[sql] ( throw new IllegalArgumentException( s"Could not access ${i}th value in this buffer because it only has $length values.") } + // TODO: Use buffer schema to avoid using generic getter. toScalaConverters(i)(underlyingInputBuffer(offsets(i))) } http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 4d3aac4..41d0ecb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -128,6 +128,7 @@ private[sql] case class JDBCRelation( override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { val driver: String = DriverRegistry.getDriverClassName(url) + // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sqlContext.sparkContext, schema, @@ -137,7 +138,7 @@ private[sql] case class JDBCRelation( table, requiredColumns, filters, - parts).map(_.asInstanceOf[Row]) + parts).asInstanceOf[RDD[Row]] } override def insert(data: DataFrame, overwrite: Boolean): Unit = { http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 922794a..562b058 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -154,17 +154,19 @@ private[sql] class JSONRelation( } override def buildScan(): RDD[Row] = { + // Rely on type erasure hack to pass RDD[InternalRow] back as RDD[Row] JacksonParser( baseRDD(), schema, - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) + sqlContext.conf.columnNameOfCorruptRecord).asInstanceOf[RDD[Row]] } override def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] = { + // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JacksonParser( baseRDD(), StructType.fromAttributes(requiredColumns), - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) + sqlContext.conf.columnNameOfCorruptRecord).asInstanceOf[RDD[Row]] } override def insert(data: DataFrame, overwrite: Boolean): Unit = { http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala index 0c3d8fd..b5e4263 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala @@ -28,7 +28,7 @@ import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveCo import org.apache.parquet.schema.Type.Repetition import org.apache.parquet.schema.{GroupType, PrimitiveType, Type} -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -55,8 +55,8 @@ private[parquet] trait ParentContainerUpdater { private[parquet] object NoopUpdater extends 
ParentContainerUpdater /** - * A [[CatalystRowConverter]] is used to convert Parquet "structs" into Spark SQL [[Row]]s. Since - * any Parquet record is also a struct, this converter can also be used as root converter. + * A [[CatalystRowConverter]] is used to convert Parquet "structs" into Spark SQL [[InternalRow]]s. + * Since any Parquet record is also a struct, this converter can also be used as root converter. * * When used as a root converter, [[NoopUpdater]] should be used since root converters don't have * any "parent" container. @@ -108,7 +108,7 @@ private[parquet] class CatalystRowConverter( override def start(): Unit = { var i = 0 - while (i < currentRow.length) { + while (i < currentRow.numFields) { currentRow.setNullAt(i) i += 1 } @@ -178,7 +178,7 @@ private[parquet] class CatalystRowConverter( case t: StructType => new CatalystRowConverter(parquetType.asGroupType(), t, new ParentContainerUpdater { - override def set(value: Any): Unit = updater.set(value.asInstanceOf[Row].copy()) + override def set(value: Any): Unit = updater.set(value.asInstanceOf[InternalRow].copy()) }) case t: UserDefinedType[_] => http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 28cba5e..8cab27d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -178,7 +178,7 @@ private[sql] case class ParquetTableScan( val row = iter.next()._2.asInstanceOf[InternalRow] var i = 0 - while (i < row.size) { + while (i < row.numFields) { mutableRow(i) = row(i) i += 1 } 
http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index d1040bf..c7c58e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -208,9 +208,9 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo override def write(record: InternalRow): Unit = { val attributesSize = attributes.size - if (attributesSize > record.size) { - throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row ($attributesSize > ${record.size})") + if (attributesSize > record.numFields) { + throw new IndexOutOfBoundsException("Trying to write more fields than contained in row " + + s"($attributesSize > ${record.numFields})") } var index = 0 @@ -378,9 +378,9 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo private[parquet] class MutableRowWriteSupport extends RowWriteSupport { override def write(record: InternalRow): Unit = { val attributesSize = attributes.size - if (attributesSize > record.size) { - throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row ($attributesSize > ${record.size})") + if (attributesSize > record.numFields) { + throw new IndexOutOfBoundsException("Trying to write more fields than contained in row " + + s"($attributesSize > ${record.numFields})") } var index = 0 http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala ---------------------------------------------------------------------- diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index c384697..8ec228c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -61,7 +61,7 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider { // NOTE: This class is instantiated and used on executor side only, no need to be serializable. private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) - extends OutputWriter { + extends OutputWriterInternal { private val recordWriter: RecordWriter[Void, InternalRow] = { val outputFormat = { @@ -86,7 +86,7 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext outputFormat.getRecordWriter(context) } - override def write(row: Row): Unit = recordWriter.write(null, row.asInstanceOf[InternalRow]) + override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) override def close(): Unit = recordWriter.close(context) } @@ -324,7 +324,7 @@ private[sql] class ParquetRelation2( new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) } } - }.values.map(_.asInstanceOf[Row]) + }.values.asInstanceOf[RDD[Row]] // type erasure hack to pass RDD[InternalRow] as RDD[Row] } } http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 7cd005b..119bac7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -345,6 +345,18 @@ abstract class OutputWriter { } /** + * This is an 
internal, private version of [[OutputWriter]] with a writeInternal method that + * accepts an [[InternalRow]] rather than a [[Row]]. Data sources that return this must have + * the conversion flag set to false. + */ +private[sql] abstract class OutputWriterInternal extends OutputWriter { + + override def write(row: Row): Unit = throw new UnsupportedOperationException + + def writeInternal(row: InternalRow): Unit +} + +/** * ::Experimental:: * A [[BaseRelation]] that provides much of the common code required for formats that store their * data to an HDFS compatible filesystem. @@ -592,12 +604,12 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio BoundReference(dataSchema.fieldIndex(col), field.dataType, field.nullable) }.toSeq - val rdd = buildScan(inputFiles) - val converted = + val rdd: RDD[Row] = buildScan(inputFiles) + val converted: RDD[InternalRow] = if (needConversion) { RDDConversions.rowToRowRdd(rdd, dataSchema.fields.map(_.dataType)) } else { - rdd.map(_.asInstanceOf[InternalRow]) + rdd.asInstanceOf[RDD[InternalRow]] } converted.mapPartitions { rows => val buildProjection = if (codegenEnabled) { @@ -606,8 +618,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio () => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes) } val mutableProjection = buildProjection() - rows.map(r => mutableProjection(r).asInstanceOf[Row]) - } + rows.map(r => mutableProjection(r)) + }.asInstanceOf[RDD[Row]] } /** http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 7cc6ffd..0e5c5ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala 
@@ -35,14 +35,14 @@ class RowSuite extends SparkFunSuite { expected.update(2, false) expected.update(3, null) val actual1 = Row(2147483647, "this is a string", false, null) - assert(expected.size === actual1.size) + assert(expected.numFields === actual1.size) assert(expected.getInt(0) === actual1.getInt(0)) assert(expected.getString(1) === actual1.getString(1)) assert(expected.getBoolean(2) === actual1.getBoolean(2)) assert(expected(3) === actual1(3)) val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null)) - assert(expected.size === actual2.size) + assert(expected.numFields === actual2.size) assert(expected.getInt(0) === actual2.getInt(0)) assert(expected.getString(1) === actual2.getString(1)) assert(expected.getBoolean(2) === actual2.getBoolean(2)) http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index da53ec1..84855ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -61,9 +61,10 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo override def needConversion: Boolean = false override def buildScan(): RDD[Row] = { + // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] sqlContext.sparkContext.parallelize(from to to).map { e => - InternalRow(UTF8String.fromString(s"people$e"), e * 2): Row - } + InternalRow(UTF8String.fromString(s"people$e"), e * 2) + }.asInstanceOf[RDD[Row]] } } http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala 
---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 257526f..0d51834 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -131,7 +131,7 @@ class PrunedScanSuite extends DataSourceTest { queryExecution) } - if (rawOutput.size != expectedColumns.size) { + if (rawOutput.numFields != expectedColumns.size) { fail(s"Wrong output row. Got $rawOutput\n$queryExecution") } } http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 143aadc..5e189c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -93,7 +93,7 @@ case class AllDataTypesScan( InternalRow(i, UTF8String.fromString(i.toString)), InternalRow(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")), InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1)))))) - } + }.asInstanceOf[RDD[Row]] } } http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 8202e55..34b6294 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -122,7 +122,7 @@ case class InsertIntoHiveTable( * * Note: this is run once and then kept to avoid double insertions. */ - protected[sql] lazy val sideEffectResult: Seq[InternalRow] = { + protected[sql] lazy val sideEffectResult: Seq[Row] = { // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer // instances within the closure, since Serializer is not serializable while TableDesc is. val tableDesc = table.tableDesc @@ -252,13 +252,12 @@ case class InsertIntoHiveTable( // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. // TODO: implement hive compatibility as rules. - Seq.empty[InternalRow] + Seq.empty[Row] } - override def executeCollect(): Array[Row] = - sideEffectResult.toArray + override def executeCollect(): Array[Row] = sideEffectResult.toArray protected override def doExecute(): RDD[InternalRow] = { - sqlContext.sparkContext.parallelize(sideEffectResult, 1) + sqlContext.sparkContext.parallelize(sideEffectResult.asInstanceOf[Seq[InternalRow]], 1) } } http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index ecc78a5..8850e06 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -34,6 +34,7 @@ import org.apache.hadoop.hive.common.FileUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil import 
org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.types._ @@ -94,7 +95,9 @@ private[hive] class SparkHiveWriterContainer( "part-" + numberFormat.format(splitID) + extension } - def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = writer + def getLocalFileWriter(row: InternalRow, schema: StructType): FileSinkOperator.RecordWriter = { + writer + } def close() { // Seems the boolean value passed into close does not matter. @@ -197,7 +200,8 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) } - override def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = { + override def getLocalFileWriter(row: InternalRow, schema: StructType) + : FileSinkOperator.RecordWriter = { def convertToHiveRawString(col: String, value: Any): String = { val raw = String.valueOf(value) schema(col).dataType match { http://git-wip-us.apache.org/repos/asf/spark/blob/431ca39b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index de63ee5..10623dc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -66,7 +66,7 @@ private[orc] class OrcOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors { + extends OutputWriterInternal with SparkHadoopMapRedUtil 
with HiveInspectors { private val serializer = { val table = new Properties() @@ -119,9 +119,9 @@ private[orc] class OrcOutputWriter( ).asInstanceOf[RecordWriter[NullWritable, Writable]] } - override def write(row: Row): Unit = { + override def writeInternal(row: InternalRow): Unit = { var i = 0 - while (i < row.length) { + while (i < row.numFields) { reusableOutputBuffer(i) = wrappers(i)(row(i)) i += 1 } @@ -192,7 +192,7 @@ private[sql] class OrcRelation( filters: Array[Filter], inputPaths: Array[FileStatus]): RDD[Row] = { val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes - OrcTableScan(output, this, filters, inputPaths).execute().map(_.asInstanceOf[Row]) + OrcTableScan(output, this, filters, inputPaths).execute().asInstanceOf[RDD[Row]] } override def prepareJobForWrite(job: Job): OutputWriterFactory = { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org