Github user icexelloss commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r197552655
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala ---
    @@ -183,34 +182,131 @@ private[sql] object ArrowConverters {
       }
     
       /**
    -   * Convert a byte array to an ArrowRecordBatch.
    +   * Load a serialized ArrowRecordBatch.
        */
    -  private[arrow] def byteArrayToBatch(
    +  private[arrow] def loadBatch(
           batchBytes: Array[Byte],
           allocator: BufferAllocator): ArrowRecordBatch = {
    -    val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
    -    val reader = new ArrowFileReader(in, allocator)
    -
    -    // Read a batch from a byte stream, ensure the reader is closed
    -    Utils.tryWithSafeFinally {
    -      val root = reader.getVectorSchemaRoot  // throws IOException
    -      val unloader = new VectorUnloader(root)
    -      reader.loadNextBatch()  // throws IOException
    -      unloader.getRecordBatch
    -    } {
    -      reader.close()
    -    }
    +    val in = new ByteArrayInputStream(batchBytes)
    +    MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator)
    +      .asInstanceOf[ArrowRecordBatch]  // throws IOException
       }
     
    +  /**
    +   * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
    +   */
       private[sql] def toDataFrame(
    -      payloadRDD: JavaRDD[Array[Byte]],
    +      arrowBatchRDD: JavaRDD[Array[Byte]],
           schemaString: String,
           sqlContext: SQLContext): DataFrame = {
    -    val rdd = payloadRDD.rdd.mapPartitions { iter =>
    +    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
    +    val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
    +    val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
           val context = TaskContext.get()
    -      ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context)
    +      ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
         }
    -    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
         sqlContext.internalCreateDataFrame(rdd, schema)
       }
    +
    +  /**
    +   * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches.
    +   */
    +  private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String):
    +  JavaRDD[Array[Byte]] = {
    +    val fileStream = new FileInputStream(filename)
    +    try {
    +      // Create array so that we can safely close the file
    +      val batches = getBatchesFromStream(fileStream.getChannel).toArray
    +      // Parallelize the record batches to create an RDD
    +      JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
    +    } finally {
    +      fileStream.close()
    +    }
    +  }
    +
    +  /**
    +   * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches.
    +   */
    +  private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = {
    +
    +    // TODO: simplify in super class
    +    class RecordBatchMessageReader(inputChannel: SeekableByteChannel) {
    --- End diff --
    
    Yeah.. one way to do it is to write a new MessageReader interface that reads Arrow messages from a Channel:
    
    ```
    public interface OnHeapMessageChannelReader {
      /**
       * Read the next message in the sequence.
       *
       * @return The read message or null if reached the end of the message sequence
       * @throws IOException
       */
      Message readNextMessage() throws IOException;

      /**
       * When a message is followed by a body of data, read that data into an on-heap
       * ByteBuffer. This should only be called when a Message has a body length > 0.
       *
       * @param message Read message that is followed by a body of data
       * @return A ByteBuffer containing the body of the message that was read
       * @throws IOException
       */
      ByteBuffer readMessageBody(Message message) throws IOException;

      ...
    }
    ```
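    
    The read loop in `getBatchesFromStream` could then just consume that interface. Here's a rough, untested sketch of what I mean (the `OnHeapMessageChannelReader` name is the placeholder interface above and the callback is also a placeholder, not a real Arrow API); it keeps only record batch messages and skips everything else:
    
    ```
    import java.io.IOException;
    import java.nio.ByteBuffer;
    import java.util.function.BiConsumer;
    import org.apache.arrow.flatbuf.Message;
    import org.apache.arrow.flatbuf.MessageHeader;
    
    public class RecordBatchMessageLoop {
      /** Feed every record batch message (metadata + body) to the given callback. */
      public static void forEachRecordBatch(
          OnHeapMessageChannelReader reader,
          BiConsumer<Message, ByteBuffer> callback) throws IOException {
        Message message = reader.readNextMessage();
        while (message != null) {
          if (message.headerType() == MessageHeader.RecordBatch) {
            // The body holds the batch buffers; the metadata is the flatbuffer message itself.
            callback.accept(message, reader.readMessageBody(message));
          } else if (message.bodyLength() > 0) {
            // Skip the body of messages we don't need here (e.g. dictionary batches).
            reader.readMessageBody(message);
          }
          message = reader.readNextMessage();
        }
      }
    }
    ```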
    
    We might need to duplicate some logic from https://github.com/apache/arrow/blob/master/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageChannelReader.java#L33
    
    For record batches that's not too bad because the logic is pretty simple, but the downside is that we would be using low-level Arrow APIs, which are not guaranteed to be stable.
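    
    For reference, the piece we'd be duplicating is roughly the following (untested sketch; it assumes the stream framing is a 4-byte little-endian length prefix followed by the flatbuffer Message, which is my reading of MessageChannelReader and exactly the kind of low-level detail that could change under us):
    
    ```
    import java.io.IOException;
    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;
    import java.nio.channels.ReadableByteChannel;
    import org.apache.arrow.flatbuf.Message;
    
    // Would implement the OnHeapMessageChannelReader interface sketched above.
    public class OnHeapMessageChannelReaderImpl {
      private final ReadableByteChannel in;
    
      public OnHeapMessageChannelReaderImpl(ReadableByteChannel in) {
        this.in = in;
      }
    
      /** Read the next message, or return null at the end of the stream. */
      public Message readNextMessage() throws IOException {
        // Assumes a 4-byte little-endian length prefix before each flatbuffer Message.
        ByteBuffer lengthBuffer = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
        if (!readFully(lengthBuffer)) {
          return null;  // clean end of stream
        }
        int messageLength = lengthBuffer.getInt(0);
        if (messageLength == 0) {
          return null;  // a zero length marks the end of the stream
        }
        ByteBuffer messageBuffer = ByteBuffer.allocate(messageLength);
        if (!readFully(messageBuffer)) {
          throw new IOException("Unexpected end of stream while reading a message");
        }
        return Message.getRootAsMessage(messageBuffer);
      }
    
      /** Read the body that follows a message into an on-heap buffer. */
      public ByteBuffer readMessageBody(Message message) throws IOException {
        // A real implementation would guard against bodies larger than 2GB.
        ByteBuffer body = ByteBuffer.allocate((int) message.bodyLength());
        if (!readFully(body)) {
          throw new IOException("Unexpected end of stream while reading a message body");
        }
        return body;
      }
    
      // Fill the buffer completely, returning false only on a clean end of stream.
      private boolean readFully(ByteBuffer buffer) throws IOException {
        while (buffer.hasRemaining()) {
          if (in.read(buffer) < 0) {
            if (buffer.position() == 0) {
              return false;
            }
            throw new IOException("Channel ended in the middle of a read");
          }
        }
        buffer.rewind();
        return true;
      }
    }
    ```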
    
    @BryanCutler what kind of static function do you think we need to add on the Arrow side?

