Repository: spark
Updated Branches:
  refs/heads/master ecf437a64 -> 964b507c7


[SPARK-21583][SQL] Create a ColumnarBatch from ArrowColumnVectors

## What changes were proposed in this pull request?

This PR allows a `ColumnarBatch` to be created from existing `ReadOnlyColumnVectors`, where previously a columnar batch could only allocate its vectors internally. This is useful for using `ArrowColumnVectors` in batch form to do row-based iteration. Also added `ArrowConverters.fromPayloadIterator`, which converts an `ArrowPayload` iterator to an `InternalRow` iterator and uses a `ColumnarBatch` internally.
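
For reference, a minimal sketch of the new usage, mirroring the `ColumnarBatchSuite` test added in this patch (the single `NullableIntVector` setup and the "example" allocator name are just illustrative):

```scala
import scala.collection.JavaConverters._

import org.apache.arrow.vector.NullableIntVector

import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
import org.apache.spark.sql.types._

// Populate an Arrow vector with the ints 0..9.
val allocator = ArrowUtils.rootAllocator.newChildAllocator("example", 0, Long.MaxValue)
val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true)
  .createVector(allocator).asInstanceOf[NullableIntVector]
vector.allocateNew()
val mutator = vector.getMutator()
(0 until 10).foreach(i => mutator.setSafe(i, i))
mutator.setValueCount(10)

// New in this patch: build a ColumnarBatch directly from existing column vectors.
val schema = StructType(Seq(StructField("int", IntegerType)))
val columns = Array[ColumnVector](new ArrowColumnVector(vector))
val batch = new ColumnarBatch(schema, columns, 10)
batch.setNumRows(10)

// Row-based iteration over the Arrow-backed batch.
batch.rowIterator().asScala.foreach(row => println(row.getInt(0)))

batch.close()
allocator.close()
```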

## How was this patch tested?

Added a new unit test for creating a `ColumnarBatch` with `ReadOnlyColumnVectors`, and a test to verify the round trip of rows -> `ArrowPayload` -> rows using `toPayloadIterator` and `fromPayloadIterator`.
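
The round trip exercised by the new test looks roughly like the following (taken from the suite below; `toPayloadIterator` and `fromPayloadIterator` are `private[sql]`, and `TaskContext.empty()` is only intended for tests):

```scala
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
val inputRows = (0 until 9).map(i => InternalRow(i)) :+ InternalRow(null)

val ctx = TaskContext.empty()
// Rows -> ArrowPayloads (existing API), then ArrowPayloads -> rows (added by this patch).
val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, ctx)
val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx)

// The schema is recovered from the first Arrow record batch.
assert(outputRowIter.schema == schema)
```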

Author: Bryan Cutler <cutl...@gmail.com>

Closes #18787 from BryanCutler/arrow-ColumnarBatch-support-SPARK-21583.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/964b507c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/964b507c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/964b507c

Branch: refs/heads/master
Commit: 964b507c7511cf3f4383cb0fc4026a573034b8cc
Parents: ecf437a
Author: Bryan Cutler <cutl...@gmail.com>
Authored: Thu Aug 31 13:08:52 2017 +0900
Committer: Takuya UESHIN <ues...@databricks.com>
Committed: Thu Aug 31 13:08:52 2017 +0900

----------------------------------------------------------------------
 .../sql/execution/arrow/ArrowConverters.scala   | 76 +++++++++++++++++++-
 .../execution/arrow/ArrowConvertersSuite.scala  | 29 +++++++-
 .../vectorized/ColumnarBatchSuite.scala         | 54 ++++++++++++++
 3 files changed, 157 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/964b507c/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index fa45822..561a067 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.arrow
 import java.io.ByteArrayOutputStream
 import java.nio.channels.Channels
 
+import scala.collection.JavaConverters._
+
 import org.apache.arrow.memory.BufferAllocator
 import org.apache.arrow.vector._
 import org.apache.arrow.vector.file._
@@ -28,6 +30,7 @@ import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
 
 import org.apache.spark.TaskContext
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -35,7 +38,7 @@ import org.apache.spark.util.Utils
 /**
  * Store Arrow data in a form that can be serialized by Spark and served to a Python process.
  */
-private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable {
+private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) extends Serializable {
 
   /**
    * Convert the ArrowPayload to an ArrowRecordBatch.
@@ -50,6 +53,17 @@ private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Se
   def asPythonSerializable: Array[Byte] = payload
 }
 
+/**
+ * Iterator interface to iterate over Arrow record batches and return rows
+ */
+private[sql] trait ArrowRowIterator extends Iterator[InternalRow] {
+
+  /**
+   * Return the schema loaded from the Arrow record batch being iterated over
+   */
+  def schema: StructType
+}
+
 private[sql] object ArrowConverters {
 
   /**
@@ -111,6 +125,66 @@ private[sql] object ArrowConverters {
   }
 
   /**
+   * Maps Iterator from ArrowPayload to InternalRow. Returns an ArrowRowIterator that also
+   * exposes the schema from the first batch of Arrow data read.
+   */
+  private[sql] def fromPayloadIterator(
+      payloadIter: Iterator[ArrowPayload],
+      context: TaskContext): ArrowRowIterator = {
+    val allocator =
+      ArrowUtils.rootAllocator.newChildAllocator("fromPayloadIterator", 0, Long.MaxValue)
+
+    new ArrowRowIterator {
+      private var reader: ArrowFileReader = null
+      private var schemaRead = StructType(Seq.empty)
+      private var rowIter = if (payloadIter.hasNext) nextBatch() else Iterator.empty
+
+      context.addTaskCompletionListener { _ =>
+        closeReader()
+        allocator.close()
+      }
+
+      override def schema: StructType = schemaRead
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        closeReader()
+        if (payloadIter.hasNext) {
+          rowIter = nextBatch()
+          true
+        } else {
+          allocator.close()
+          false
+        }
+      }
+
+      override def next(): InternalRow = rowIter.next()
+
+      private def closeReader(): Unit = {
+        if (reader != null) {
+          reader.close()
+          reader = null
+        }
+      }
+
+      private def nextBatch(): Iterator[InternalRow] = {
+        val in = new ByteArrayReadableSeekableByteChannel(payloadIter.next().asPythonSerializable)
+        reader = new ArrowFileReader(in, allocator)
+        reader.loadNextBatch()  // throws IOException
+        val root = reader.getVectorSchemaRoot  // throws IOException
+        schemaRead = ArrowUtils.fromArrowSchema(root.getSchema)
+
+        val columns = root.getFieldVectors.asScala.map { vector =>
+          new ArrowColumnVector(vector).asInstanceOf[ColumnVector]
+        }.toArray
+
+        val batch = new ColumnarBatch(schemaRead, columns, root.getRowCount)
+        batch.setNumRows(root.getRowCount)
+        batch.rowIterator().asScala
+      }
+    }
+  }
+
+  /**
    * Convert a byte array to an ArrowRecordBatch.
    */
   private[arrow] def byteArrayToBatch(

http://git-wip-us.apache.org/repos/asf/spark/blob/964b507c/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
index 4893b52..30422b6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
@@ -29,8 +29,9 @@ import org.apache.arrow.vector.file.json.JsonFileReader
 import org.apache.arrow.vector.util.Validator
 import org.scalatest.BeforeAndAfterAll
 
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils
@@ -1629,6 +1630,32 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
     }
   }
 
+  test("roundtrip payloads") {
+    val inputRows = (0 until 9).map { i =>
+      InternalRow(i)
+    } :+ InternalRow(null)
+
+    val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
+
+    val ctx = TaskContext.empty()
+    val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, ctx)
+    val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx)
+
+    assert(schema.equals(outputRowIter.schema))
+
+    var count = 0
+    outputRowIter.zipWithIndex.foreach { case (row, i) =>
+      if (i != 9) {
+        assert(row.getInt(0) == i)
+      } else {
+        assert(row.isNullAt(0))
+      }
+      count += 1
+    }
+
+    assert(count == inputRows.length)
+  }
+
   /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */
   private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = {
     // NOTE: coalesce to single partition because can only load 1 batch in validator

http://git-wip-us.apache.org/repos/asf/spark/blob/964b507c/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 08ccbd6..1f21d3c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -25,10 +25,13 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.util.Random
 
+import org.apache.arrow.vector.NullableIntVector
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.memory.MemoryMode
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.arrow.ArrowUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.types.CalendarInterval
@@ -1261,4 +1264,55 @@ class ColumnarBatchSuite extends SparkFunSuite {
         s"vectorized reader"))
     }
   }
+
+  test("create columnar batch from Arrow column vectors") {
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue)
+    val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true)
+      .createVector(allocator).asInstanceOf[NullableIntVector]
+    vector1.allocateNew()
+    val mutator1 = vector1.getMutator()
+    val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true)
+      .createVector(allocator).asInstanceOf[NullableIntVector]
+    vector2.allocateNew()
+    val mutator2 = vector2.getMutator()
+
+    (0 until 10).foreach { i =>
+      mutator1.setSafe(i, i)
+      mutator2.setSafe(i + 1, i)
+    }
+    mutator1.setNull(10)
+    mutator1.setValueCount(11)
+    mutator2.setNull(0)
+    mutator2.setValueCount(11)
+
+    val columnVectors = Seq(new ArrowColumnVector(vector1), new ArrowColumnVector(vector2))
+
+    val schema = StructType(Seq(StructField("int1", IntegerType), StructField("int2", IntegerType)))
+    val batch = new ColumnarBatch(schema, columnVectors.toArray[ColumnVector], 11)
+    batch.setNumRows(11)
+
+    assert(batch.numCols() == 2)
+    assert(batch.numRows() == 11)
+
+    val rowIter = batch.rowIterator().asScala
+    rowIter.zipWithIndex.foreach { case (row, i) =>
+      if (i == 10) {
+        assert(row.isNullAt(0))
+      } else {
+        assert(row.getInt(0) == i)
+      }
+      if (i == 0) {
+        assert(row.isNullAt(1))
+      } else {
+        assert(row.getInt(1) == i - 1)
+      }
+    }
+
+    intercept[java.lang.AssertionError] {
+      batch.getRow(100)
+    }
+
+    batch.close()
+    allocator.close()
+  }
 }

