Github user icexelloss commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r211964996 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +178,106 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { - val in = new ByteArrayReadableSeekableByteChannel(batchBytes) - val reader = new ArrowFileReader(in, allocator) - - // Read a batch from a byte stream, ensure the reader is closed - Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch - } { - reader.close() - } + val in = new ByteArrayInputStream(batchBytes) + MessageSerializer.deserializeRecordBatch( + new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { - val rdd = payloadRDD.rdd.mapPartitions { iter => + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone + val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } - val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { + val fileStream = new FileInputStream(filename) + try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) + } finally { + fileStream.close() + } + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. 
+ */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + + // Create an iterator to get each serialized ArrowRecordBatch from a stream + new Iterator[Array[Byte]] { + var batch: Array[Byte] = readNextBatch() + + override def hasNext: Boolean = batch != null + + override def next(): Array[Byte] = { + val prevBatch = batch + batch = readNextBatch() + prevBatch + } + + def readNextBatch(): Array[Byte] = { + val msgMetadata = MessageSerializer.readMessage(new ReadChannel(in)) + if (msgMetadata == null) { + return null + } + + // Get the length of the body, which has not been read at this point + val bodyLength = msgMetadata.getMessageBodyLength.toInt + + // Only care about RecordBatch data, skip Schema and unsupported Dictionary messages + if (msgMetadata.getMessage.headerType() == MessageHeader.RecordBatch) { + + // Create output backed by buffer to hold msg length (int32), msg metadata, msg body + val bbout = new ByteBufferOutputStream(4 + msgMetadata.getMessageLength + bodyLength) --- End diff -- Add a comment that this is the deserialized form of an Arrow Record Batch?
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org