Github user sethah commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r199497456
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
 ---
    @@ -183,34 +182,111 @@ private[sql] object ArrowConverters {
       }
     
       /**
    -   * Convert a byte array to an ArrowRecordBatch.
    +   * Load a serialized ArrowRecordBatch.
        */
    -  private[arrow] def byteArrayToBatch(
    +  private[arrow] def loadBatch(
           batchBytes: Array[Byte],
           allocator: BufferAllocator): ArrowRecordBatch = {
    -    val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
    -    val reader = new ArrowFileReader(in, allocator)
    -
    -    // Read a batch from a byte stream, ensure the reader is closed
    -    Utils.tryWithSafeFinally {
    -      val root = reader.getVectorSchemaRoot  // throws IOException
    -      val unloader = new VectorUnloader(root)
    -      reader.loadNextBatch()  // throws IOException
    -      unloader.getRecordBatch
    -    } {
    -      reader.close()
    -    }
    +    val in = new ByteArrayInputStream(batchBytes)
    +    MessageSerializer.deserializeMessageBatch(new 
ReadChannel(Channels.newChannel(in)), allocator)
    +      .asInstanceOf[ArrowRecordBatch]  // throws IOException
       }
     
    +  /**
    +   * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
    +   */
       private[sql] def toDataFrame(
    -      payloadRDD: JavaRDD[Array[Byte]],
    +      arrowBatchRDD: JavaRDD[Array[Byte]],
           schemaString: String,
           sqlContext: SQLContext): DataFrame = {
    -    val rdd = payloadRDD.rdd.mapPartitions { iter =>
    +    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
    +    val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
    +    val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
           val context = TaskContext.get()
    -      ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), 
context)
    +      ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
         }
    -    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
         sqlContext.internalCreateDataFrame(rdd, schema)
       }
    +
    +  /**
    +   * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches.
    +   */
    +  private[sql] def readArrowStreamFromFile(
    +      sqlContext: SQLContext,
    +      filename: String): JavaRDD[Array[Byte]] = {
    +    val fileStream = new FileInputStream(filename)
    +    try {
    +      // Create array so that we can safely close the file
    +      val batches = getBatchesFromStream(fileStream.getChannel).toArray
    +      // Parallelize the record batches to create an RDD
    +      JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, 
batches.length))
    +    } finally {
    +      fileStream.close()
    +    }
    +  }
    +
    +  /**
    +   * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches.
    +   */
    +  private[sql] def getBatchesFromStream(in: SeekableByteChannel): 
Iterator[Array[Byte]] = {
    +
    +    // TODO: this could be moved to Arrow
    +    def readMessageLength(in: ReadChannel): Int = {
    +      val buffer = ByteBuffer.allocate(4)
    +      if (in.readFully(buffer) != 4) {
    +        return 0
    +      }
    +      MessageSerializer.bytesToInt(buffer.array())
    +    }
    +
    +    // TODO: this could be moved to Arrow
    +    def loadMessage(in: ReadChannel, messageLength: Int, buffer: 
ByteBuffer): Message = {
    +      if (in.readFully(buffer) != messageLength) {
    +        throw new java.io.IOException(
    +          "Unexpected end of stream trying to read message.")
    +      }
    +      buffer.rewind()
    +      Message.getRootAsMessage(buffer)
    +    }
    +
    +
    +    // Create an iterator to get each serialized ArrowRecordBatch from a stream
    +    new Iterator[Array[Byte]] {
    +      val inputChannel = new ReadChannel(in)
    +      var batch: Array[Byte] = readNextBatch()
    +
    +      override def hasNext: Boolean = batch != null
    +
    +      override def next(): Array[Byte] = {
    +        val prevBatch = batch
    +        batch = readNextBatch()
    +        prevBatch
    +      }
    +
    +      def readNextBatch(): Array[Byte] = {
    +        val messageLength = readMessageLength(inputChannel)
    +        if (messageLength == 0) {
    +          return null
    +        }
    +
    +        val buffer = ByteBuffer.allocate(messageLength)
    +        val msg = loadMessage(inputChannel, messageLength, buffer)
    +        val bodyLength = msg.bodyLength().asInstanceOf[Int]
    +
    +        if (msg.headerType() == MessageHeader.RecordBatch) {
    +          val allbuf = ByteBuffer.allocate(4 + messageLength + bodyLength)
    +          allbuf.put(WriteChannel.intToBytes(messageLength))
    +          allbuf.put(buffer)
    +          inputChannel.readFully(allbuf)
    +          allbuf.array()
    +        } else {
    +          if (bodyLength > 0) {
    +            // Skip message body if not a record batch
    --- End diff --
    
    Under what conditions would we expect a non-RecordBatch message with a non-empty body here?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to