Github user icexelloss commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r211964996
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala ---
    @@ -183,34 +178,106 @@ private[sql] object ArrowConverters {
       }
     
       /**
    -   * Convert a byte array to an ArrowRecordBatch.
    +   * Load a serialized ArrowRecordBatch.
        */
    -  private[arrow] def byteArrayToBatch(
    +  private[arrow] def loadBatch(
           batchBytes: Array[Byte],
           allocator: BufferAllocator): ArrowRecordBatch = {
    -    val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
    -    val reader = new ArrowFileReader(in, allocator)
    -
    -    // Read a batch from a byte stream, ensure the reader is closed
    -    Utils.tryWithSafeFinally {
    -      val root = reader.getVectorSchemaRoot  // throws IOException
    -      val unloader = new VectorUnloader(root)
    -      reader.loadNextBatch()  // throws IOException
    -      unloader.getRecordBatch
    -    } {
    -      reader.close()
    -    }
    +    val in = new ByteArrayInputStream(batchBytes)
    +    MessageSerializer.deserializeRecordBatch(
    +      new ReadChannel(Channels.newChannel(in)), allocator)  // throws IOException
       }
     
    +  /**
    +   * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
    +   */
       private[sql] def toDataFrame(
    -      payloadRDD: JavaRDD[Array[Byte]],
    +      arrowBatchRDD: JavaRDD[Array[Byte]],
           schemaString: String,
           sqlContext: SQLContext): DataFrame = {
    -    val rdd = payloadRDD.rdd.mapPartitions { iter =>
    +    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
    +    val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
    +    val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
           val context = TaskContext.get()
    -      ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context)
    +      ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
         }
    -    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
         sqlContext.internalCreateDataFrame(rdd, schema)
       }
    +
    +  /**
    +   * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches.
    +   */
    +  private[sql] def readArrowStreamFromFile(
    +      sqlContext: SQLContext,
    +      filename: String): JavaRDD[Array[Byte]] = {
    +    val fileStream = new FileInputStream(filename)
    +    try {
    +      // Create array so that we can safely close the file
    +      val batches = getBatchesFromStream(fileStream.getChannel).toArray
    +      // Parallelize the record batches to create an RDD
    +      JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
    +    } finally {
    +      fileStream.close()
    +    }
    +  }
    +
    +  /**
    +   * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches.
    +   */
    +  private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = {
    +
    +    // Create an iterator to get each serialized ArrowRecordBatch from a stream
    +    new Iterator[Array[Byte]] {
    +      var batch: Array[Byte] = readNextBatch()
    +
    +      override def hasNext: Boolean = batch != null
    +
    +      override def next(): Array[Byte] = {
    +        val prevBatch = batch
    +        batch = readNextBatch()
    +        prevBatch
    +      }
    +
    +      def readNextBatch(): Array[Byte] = {
    +        val msgMetadata = MessageSerializer.readMessage(new ReadChannel(in))
    +        if (msgMetadata == null) {
    +          return null
    +        }
    +
    +        // Get the length of the body, which has not been read at this point
    +        val bodyLength = msgMetadata.getMessageBodyLength.toInt
    +
    +        // Only care about RecordBatch data, skip Schema and unsupported Dictionary messages
    +        if (msgMetadata.getMessage.headerType() == MessageHeader.RecordBatch) {
    +
    +          // Create output backed by buffer to hold msg length (int32), msg metadata, msg body
    +          val bbout = new ByteBufferOutputStream(4 + msgMetadata.getMessageLength + bodyLength)
    --- End diff --
    
    Add a comment that this is the serialized form of an Arrow Record Batch?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to