Github user BryanCutler commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r196247663
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala ---
    @@ -183,34 +182,131 @@ private[sql] object ArrowConverters {
       }
     
       /**
    -   * Convert a byte array to an ArrowRecordBatch.
    +   * Load a serialized ArrowRecordBatch.
        */
    -  private[arrow] def byteArrayToBatch(
    +  private[arrow] def loadBatch(
           batchBytes: Array[Byte],
           allocator: BufferAllocator): ArrowRecordBatch = {
    -    val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
    -    val reader = new ArrowFileReader(in, allocator)
    -
    -    // Read a batch from a byte stream, ensure the reader is closed
    -    Utils.tryWithSafeFinally {
    -      val root = reader.getVectorSchemaRoot  // throws IOException
    -      val unloader = new VectorUnloader(root)
    -      reader.loadNextBatch()  // throws IOException
    -      unloader.getRecordBatch
    -    } {
    -      reader.close()
    -    }
    +    val in = new ByteArrayInputStream(batchBytes)
    +    MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator)
    +      .asInstanceOf[ArrowRecordBatch]  // throws IOException
       }
     
    +  /**
    +   * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
    +   */
       private[sql] def toDataFrame(
    -      payloadRDD: JavaRDD[Array[Byte]],
    +      arrowBatchRDD: JavaRDD[Array[Byte]],
           schemaString: String,
           sqlContext: SQLContext): DataFrame = {
    -    val rdd = payloadRDD.rdd.mapPartitions { iter =>
    +    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
    +    val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
    +    val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
           val context = TaskContext.get()
    -      ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context)
    +      ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
         }
    -    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
         sqlContext.internalCreateDataFrame(rdd, schema)
       }
    +
    +  /**
    +   * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches.
    +   */
    +  private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String):
    +  JavaRDD[Array[Byte]] = {
    +    val fileStream = new FileInputStream(filename)
    +    try {
    +      // Create array so that we can safely close the file
    +      val batches = getBatchesFromStream(fileStream.getChannel).toArray
    +      // Parallelize the record batches to create an RDD
    +      JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
    +    } finally {
    +      fileStream.close()
    +    }
    +  }
    +
    +  /**
    +   * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches.
    +   */
    +    private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = {
    +
    +    // TODO: simplify in super class
    +    class RecordBatchMessageReader(inputChannel: SeekableByteChannel) {
    --- End diff ---
    
I think we should definitely avoid extra copies of the data in the JVM if possible, since we are trying to be efficient here. This process doesn't really seem complex to me: the specialized code here is about 10 lines and just reads bytes from an input stream into a byte buffer. Can you clarify a bit more why you think this should be in a separate PR?
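
The kind of read I mean is roughly the following shape (a minimal sketch, not the code in this PR; the `readFully` helper and its signature are made up for illustration): the bytes go from the channel straight into a single ByteBuffer, with no intermediate array copy.

```scala
import java.io.EOFException
import java.nio.ByteBuffer
import java.nio.channels.ReadableByteChannel

// Hypothetical helper, not the PR's actual code: fill a ByteBuffer of the
// requested size directly from the channel so the batch body is not copied
// through an intermediate byte array.
def readFully(channel: ReadableByteChannel, size: Int): ByteBuffer = {
  val buffer = ByteBuffer.allocate(size)
  while (buffer.hasRemaining) {
    if (channel.read(buffer) < 0) {
      throw new EOFException("Unexpected end of stream while reading a record batch")
    }
  }
  buffer.flip()  // prepare the buffer for the consumer to read
  buffer
}
```

Roughly that shape is all the specialized reading logic needs to do per message body.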


---
