Repository: spark Updated Branches: refs/heads/master ecf437a64 -> 964b507c7
[SPARK-21583][SQL] Create a ColumnarBatch from ArrowColumnVectors ## What changes were proposed in this pull request? This PR allows the creation of a `ColumnarBatch` from `ReadOnlyColumnVectors` where previously a columnar batch could only allocate vectors internally. This is useful for using `ArrowColumnVectors` in a batch form to do row-based iteration. Also added `ArrowConverters.fromPayloadIterator` which converts an `ArrowPayload` iterator to an `InternalRow` iterator and uses a `ColumnarBatch` internally. ## How was this patch tested? Added a new unit test for creating a `ColumnarBatch` with `ReadOnlyColumnVectors` and a test to verify the roundtrip of rows -> ArrowPayload -> rows, using `toPayloadIterator` and `fromPayloadIterator`. Author: Bryan Cutler <cutl...@gmail.com> Closes #18787 from BryanCutler/arrow-ColumnarBatch-support-SPARK-21583. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/964b507c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/964b507c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/964b507c Branch: refs/heads/master Commit: 964b507c7511cf3f4383cb0fc4026a573034b8cc Parents: ecf437a Author: Bryan Cutler <cutl...@gmail.com> Authored: Thu Aug 31 13:08:52 2017 +0900 Committer: Takuya UESHIN <ues...@databricks.com> Committed: Thu Aug 31 13:08:52 2017 +0900 ---------------------------------------------------------------------- .../sql/execution/arrow/ArrowConverters.scala | 76 +++++++++++++++++++- .../execution/arrow/ArrowConvertersSuite.scala | 29 +++++++- .../vectorized/ColumnarBatchSuite.scala | 54 ++++++++++++++ 3 files changed, 157 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/964b507c/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala ---------------------------------------------------------------------- diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index fa45822..561a067 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.arrow import java.io.ByteArrayOutputStream import java.nio.channels.Channels +import scala.collection.JavaConverters._ + import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ import org.apache.arrow.vector.file._ @@ -28,6 +30,7 @@ import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -35,7 +38,7 @@ import org.apache.spark.util.Utils /** * Store Arrow data in a form that can be serialized by Spark and served to a Python process. */ -private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable { +private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) extends Serializable { /** * Convert the ArrowPayload to an ArrowRecordBatch. @@ -50,6 +53,17 @@ private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Se def asPythonSerializable: Array[Byte] = payload } +/** + * Iterator interface to iterate over Arrow record batches and return rows + */ +private[sql] trait ArrowRowIterator extends Iterator[InternalRow] { + + /** + * Return the schema loaded from the Arrow record batch being iterated over + */ + def schema: StructType +} + private[sql] object ArrowConverters { /** @@ -111,6 +125,66 @@ private[sql] object ArrowConverters { } /** + * Maps Iterator from ArrowPayload to InternalRow. 
Returns a pair containing the row iterator + * and the schema from the first batch of Arrow data read. + */ + private[sql] def fromPayloadIterator( + payloadIter: Iterator[ArrowPayload], + context: TaskContext): ArrowRowIterator = { + val allocator = + ArrowUtils.rootAllocator.newChildAllocator("fromPayloadIterator", 0, Long.MaxValue) + + new ArrowRowIterator { + private var reader: ArrowFileReader = null + private var schemaRead = StructType(Seq.empty) + private var rowIter = if (payloadIter.hasNext) nextBatch() else Iterator.empty + + context.addTaskCompletionListener { _ => + closeReader() + allocator.close() + } + + override def schema: StructType = schemaRead + + override def hasNext: Boolean = rowIter.hasNext || { + closeReader() + if (payloadIter.hasNext) { + rowIter = nextBatch() + true + } else { + allocator.close() + false + } + } + + override def next(): InternalRow = rowIter.next() + + private def closeReader(): Unit = { + if (reader != null) { + reader.close() + reader = null + } + } + + private def nextBatch(): Iterator[InternalRow] = { + val in = new ByteArrayReadableSeekableByteChannel(payloadIter.next().asPythonSerializable) + reader = new ArrowFileReader(in, allocator) + reader.loadNextBatch() // throws IOException + val root = reader.getVectorSchemaRoot // throws IOException + schemaRead = ArrowUtils.fromArrowSchema(root.getSchema) + + val columns = root.getFieldVectors.asScala.map { vector => + new ArrowColumnVector(vector).asInstanceOf[ColumnVector] + }.toArray + + val batch = new ColumnarBatch(schemaRead, columns, root.getRowCount) + batch.setNumRows(root.getRowCount) + batch.rowIterator().asScala + } + } + } + + /** * Convert a byte array to an ArrowRecordBatch. 
*/ private[arrow] def byteArrayToBatch( http://git-wip-us.apache.org/repos/asf/spark/blob/964b507c/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 4893b52..30422b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -29,8 +29,9 @@ import org.apache.arrow.vector.file.json.JsonFileReader import org.apache.arrow.vector.util.Validator import org.scalatest.BeforeAndAfterAll -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -1629,6 +1630,32 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } } + test("roundtrip payloads") { + val inputRows = (0 until 9).map { i => + InternalRow(i) + } :+ InternalRow(null) + + val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) + + val ctx = TaskContext.empty() + val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, ctx) + val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx) + + assert(schema.equals(outputRowIter.schema)) + + var count = 0 + outputRowIter.zipWithIndex.foreach { case (row, i) => + if (i != 9) { + assert(row.getInt(0) == i) + } else { + assert(row.isNullAt(0)) + } + count += 1 + } + + assert(count == inputRows.length) + } + /** Test that a converted DataFrame to 
Arrow record batch equals batch read from JSON file */ private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = { // NOTE: coalesce to single partition because can only load 1 batch in validator http://git-wip-us.apache.org/repos/asf/spark/blob/964b507c/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 08ccbd6..1f21d3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -25,10 +25,13 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Random +import org.apache.arrow.vector.NullableIntVector + import org.apache.spark.SparkFunSuite import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.CalendarInterval @@ -1261,4 +1264,55 @@ class ColumnarBatchSuite extends SparkFunSuite { s"vectorized reader")) } } + + test("create columnar batch from Arrow column vectors") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) + val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true) + .createVector(allocator).asInstanceOf[NullableIntVector] + vector1.allocateNew() + val mutator1 = vector1.getMutator() + val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true) + .createVector(allocator).asInstanceOf[NullableIntVector] + 
vector2.allocateNew() + val mutator2 = vector2.getMutator() + + (0 until 10).foreach { i => + mutator1.setSafe(i, i) + mutator2.setSafe(i + 1, i) + } + mutator1.setNull(10) + mutator1.setValueCount(11) + mutator2.setNull(0) + mutator2.setValueCount(11) + + val columnVectors = Seq(new ArrowColumnVector(vector1), new ArrowColumnVector(vector2)) + + val schema = StructType(Seq(StructField("int1", IntegerType), StructField("int2", IntegerType))) + val batch = new ColumnarBatch(schema, columnVectors.toArray[ColumnVector], 11) + batch.setNumRows(11) + + assert(batch.numCols() == 2) + assert(batch.numRows() == 11) + + val rowIter = batch.rowIterator().asScala + rowIter.zipWithIndex.foreach { case (row, i) => + if (i == 10) { + assert(row.isNullAt(0)) + } else { + assert(row.getInt(0) == i) + } + if (i == 0) { + assert(row.isNullAt(1)) + } else { + assert(row.getInt(1) == i - 1) + } + } + + intercept[java.lang.AssertionError] { + batch.getRow(100) + } + + batch.close() + allocator.close() + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org