grundprinzip commented on code in PR #38659:
URL: https://github.com/apache/spark/pull/38659#discussion_r1029721265


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##########
@@ -213,58 +214,115 @@ private[sql] object ArrowConverters extends Logging {
     }.next()
   }
 
-  /**
-   * Maps iterator from serialized ArrowRecordBatches to InternalRows.
-   */
-  private[sql] def fromBatchIterator(
+  private[sql] abstract class InternalRowIterator(

Review Comment:
   Please add a comment explaining how this is supposed to be used. The name sounds very innocent :)



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##########
@@ -213,58 +214,115 @@ private[sql] object ArrowConverters extends Logging {
     }.next()
   }
 
-  /**
-   * Maps iterator from serialized ArrowRecordBatches to InternalRows.
-   */
-  private[sql] def fromBatchIterator(
+  private[sql] abstract class InternalRowIterator(
       arrowBatchIter: Iterator[Array[Byte]],
-      schema: StructType,
-      timeZoneId: String,
-      context: TaskContext): Iterator[InternalRow] = {
-    val allocator =
-      ArrowUtils.rootAllocator.newChildAllocator("fromBatchIterator", 0, Long.MaxValue)
-
-    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
-    val root = VectorSchemaRoot.create(arrowSchema, allocator)
-
-    new Iterator[InternalRow] {
-      private var rowIter = if (arrowBatchIter.hasNext) nextBatch() else Iterator.empty
-
-      if (context != null) context.addTaskCompletionListener[Unit] { _ =>
-        root.close()
-        allocator.close()
-      }
+      context: TaskContext)
+      extends Iterator[InternalRow] {
+    // Keep all the resources we have opened in order, should be closed in reverse order finally.
+    val resources = new ArrayBuffer[AutoCloseable]()
+    protected val allocator: BufferAllocator = ArrowUtils.rootAllocator.newChildAllocator(
+      s"to${this.getClass.getSimpleName}",
+      0,
+      Long.MaxValue)
+    resources.append(allocator)
+
+    private var rowIterAndSchema =
+      if (arrowBatchIter.hasNext) nextBatch() else (Iterator.empty, null)
+    // We will ensure schemas parsed from every batch are the same
+    val schema: StructType = rowIterAndSchema._2
+
+    if (context != null) context.addTaskCompletionListener[Unit] { _ =>
+      closeAll(resources.reverse: _*)
+    }
 
-      override def hasNext: Boolean = rowIter.hasNext || {
-        if (arrowBatchIter.hasNext) {
-          rowIter = nextBatch()
-          true
-        } else {
-          root.close()
-          allocator.close()
-          false
+    override def hasNext: Boolean = rowIterAndSchema._1.hasNext || {
+      if (arrowBatchIter.hasNext) {
+        rowIterAndSchema = nextBatch()
+        if (schema != rowIterAndSchema._2) {
+          throw new IllegalArgumentException(
+            s"ArrowBatch iterator contain 2 batches with" +
+              s" different schema: $schema and ${rowIterAndSchema._2}")
         }
+        rowIterAndSchema._1.hasNext
+      } else {
+        closeAll(resources.reverse: _*)
+        false
       }
+    }
 
-      override def next(): InternalRow = rowIter.next()
+    override def next(): InternalRow = rowIterAndSchema._1.next()
 
-      private def nextBatch(): Iterator[InternalRow] = {
-        val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(), allocator)
-        val vectorLoader = new VectorLoader(root)
-        vectorLoader.load(arrowRecordBatch)
-        arrowRecordBatch.close()
+    def nextBatch(): (Iterator[InternalRow], StructType)
+  }
 
-        val columns = root.getFieldVectors.asScala.map { vector =>
-          new ArrowColumnVector(vector).asInstanceOf[ColumnVector]
-        }.toArray
+  private[sql] class InternalRowIteratorWithoutSchema(
+      arrowBatchIter: Iterator[Array[Byte]],
+      schema: StructType,
+      timeZoneId: String,
+      context: TaskContext)
+      extends InternalRowIterator(arrowBatchIter, context) {
+
+    override def nextBatch(): (Iterator[InternalRow], StructType) = {
+      val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+      val root = VectorSchemaRoot.create(arrowSchema, allocator)
+      resources.append(root)
+      val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(), allocator)
+      val vectorLoader = new VectorLoader(root)
+      vectorLoader.load(arrowRecordBatch)
+      arrowRecordBatch.close()
+      (vectorSchemaRootToIter(root), schema)
+    }
+  }
 
-        val batch = new ColumnarBatch(columns)
-        batch.setNumRows(root.getRowCount)
-        batch.rowIterator().asScala
+  private[sql] class InternalRowIteratorWithSchema(
+      arrowBatchIter: Iterator[Array[Byte]],
+      context: TaskContext)
+      extends InternalRowIterator(arrowBatchIter, context) {
+    override def nextBatch(): (Iterator[InternalRow], StructType) = {
+      val reader =
+        new ArrowStreamReader(new ByteArrayInputStream(arrowBatchIter.next()), allocator)
+      val root = if (reader.loadNextBatch()) reader.getVectorSchemaRoot else null
+      resources.append(reader, root)
+      if (root == null) {
+        (Iterator.empty, null)
+      } else {
+        (vectorSchemaRootToIter(root), ArrowUtils.fromArrowSchema(root.getSchema))
       }
     }
   }
 
+  /**
+   * Maps iterator from serialized ArrowRecordBatches to InternalRows.
+   */
+  private[sql] def fromBatchIterator(
+      arrowBatchIter: Iterator[Array[Byte]],
+      schema: StructType,
+      timeZoneId: String,
+      context: TaskContext): Iterator[InternalRow] = new InternalRowIteratorWithoutSchema(
+    arrowBatchIter, schema, timeZoneId, context
+  )
+
+  /**
+   * Maps iterator from serialized ArrowRecordBatches to InternalRows. Different from
+   * [[fromBatchIterator]], each input arrow batch starts with the schema.
+   */
+  private[sql] def fromBatchWithSchemaIterator(
+      arrowBatchIter: Iterator[Array[Byte]],
+      context: TaskContext): (Iterator[InternalRow], StructType) = {
+    val iterator = new InternalRowIteratorWithSchema(arrowBatchIter, context)
+    (iterator, iterator.schema)
+  }
+
+  private def vectorSchemaRootToIter(root: VectorSchemaRoot): Iterator[InternalRow] = {

Review Comment:
   Is this really such a trivial change? Could we add a bit of documentation?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##########
@@ -213,58 +214,115 @@ private[sql] object ArrowConverters extends Logging {
     }.next()
   }
 
-  /**
-   * Maps iterator from serialized ArrowRecordBatches to InternalRows.
-   */
-  private[sql] def fromBatchIterator(
+  private[sql] abstract class InternalRowIterator(
       arrowBatchIter: Iterator[Array[Byte]],
-      schema: StructType,
-      timeZoneId: String,
-      context: TaskContext): Iterator[InternalRow] = {
-    val allocator =
-      ArrowUtils.rootAllocator.newChildAllocator("fromBatchIterator", 0, Long.MaxValue)
-
-    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
-    val root = VectorSchemaRoot.create(arrowSchema, allocator)
-
-    new Iterator[InternalRow] {
-      private var rowIter = if (arrowBatchIter.hasNext) nextBatch() else Iterator.empty
-
-      if (context != null) context.addTaskCompletionListener[Unit] { _ =>
-        root.close()
-        allocator.close()
-      }
+      context: TaskContext)
+      extends Iterator[InternalRow] {
+    // Keep all the resources we have opened in order, should be closed in reverse order finally.
+    val resources = new ArrayBuffer[AutoCloseable]()
+    protected val allocator: BufferAllocator = ArrowUtils.rootAllocator.newChildAllocator(
+      s"to${this.getClass.getSimpleName}",
+      0,
+      Long.MaxValue)
+    resources.append(allocator)
+
+    private var rowIterAndSchema =
+      if (arrowBatchIter.hasNext) nextBatch() else (Iterator.empty, null)
+    // We will ensure schemas parsed from every batch are the same

Review Comment:
   ```suggestion
       // We will ensure schemas parsed from every batch are the same.
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to