Github user icexelloss commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r195562446

--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala ---

@@ -183,34 +182,131 @@ private[sql] object ArrowConverters {
   }

   /**
-   * Convert a byte array to an ArrowRecordBatch.
+   * Load a serialized ArrowRecordBatch.
    */
-  private[arrow] def byteArrayToBatch(
+  private[arrow] def loadBatch(
       batchBytes: Array[Byte],
       allocator: BufferAllocator): ArrowRecordBatch = {
-    val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
-    val reader = new ArrowFileReader(in, allocator)
-
-    // Read a batch from a byte stream, ensure the reader is closed
-    Utils.tryWithSafeFinally {
-      val root = reader.getVectorSchemaRoot  // throws IOException
-      val unloader = new VectorUnloader(root)
-      reader.loadNextBatch()  // throws IOException
-      unloader.getRecordBatch
-    } {
-      reader.close()
-    }
+    val in = new ByteArrayInputStream(batchBytes)
+    MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator)
+      .asInstanceOf[ArrowRecordBatch]  // throws IOException
   }

+  /**
+   * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
+   */
   private[sql] def toDataFrame(
-      payloadRDD: JavaRDD[Array[Byte]],
+      arrowBatchRDD: JavaRDD[Array[Byte]],
       schemaString: String,
       sqlContext: SQLContext): DataFrame = {
-    val rdd = payloadRDD.rdd.mapPartitions { iter =>
+    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+    val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
+    val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
       val context = TaskContext.get()
-      ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context)
+      ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
     }
-    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
     sqlContext.internalCreateDataFrame(rdd, schema)
   }
+
+  /**
+   * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches.
+   */
+  private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String):
+      JavaRDD[Array[Byte]] = {
+    val fileStream = new FileInputStream(filename)
+    try {
+      // Create array so that we can safely close the file
+      val batches = getBatchesFromStream(fileStream.getChannel).toArray
+      // Parallelize the record batches to create an RDD
+      JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
+    } finally {
+      fileStream.close()
+    }
+  }
+
+  /**
+   * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches.
+   */
+  private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = {
+
+    // TODO: simplify in super class
+    class RecordBatchMessageReader(inputChannel: SeekableByteChannel) {
--- End diff --

I wonder if this would work instead? In a for-loop:

(1) Read the next batch into a VectorSchemaRoot (copies the data into Arrow memory).
(2) Use VectorUnloader to unload the VectorSchemaRoot into an ArrowRecordBatch (no copy).
(3) Use MessageSerializer.serialize to write the ArrowRecordBatch to a ByteChannel (copies from Arrow memory into Java memory).

It seems we cannot read directly from the socket into Java memory anyway (we have to go through the Arrow memory allocator).
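For concreteness, here is a minimal sketch of that loop, assuming the stream-based reader in Arrow's Java library (exact package names and signatures vary across Arrow versions, and `batchesFromStream` is a hypothetical helper, not code from this PR):

```scala
import java.io.ByteArrayOutputStream
import java.nio.channels.{Channels, ReadableByteChannel}

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.VectorUnloader
import org.apache.arrow.vector.ipc.{ArrowStreamReader, WriteChannel}
import org.apache.arrow.vector.ipc.message.MessageSerializer

// Hypothetical helper: re-serialize every batch in an Arrow stream into a
// Java byte array, going through the Arrow allocator as outlined above.
def batchesFromStream(channel: ReadableByteChannel): Iterator[Array[Byte]] = {
  val allocator = new RootAllocator(Long.MaxValue)
  val reader = new ArrowStreamReader(channel, allocator)
  val root = reader.getVectorSchemaRoot // each loadNextBatch() fills this root
  val unloader = new VectorUnloader(root)

  new Iterator[Array[Byte]] {
    private var more = reader.loadNextBatch() // (1) copy into Arrow memory
    if (!more) reader.close()

    override def hasNext: Boolean = more

    override def next(): Array[Byte] = {
      val batch = unloader.getRecordBatch // (2) no copy
      val out = new ByteArrayOutputStream()
      try {
        // (3) copy from Arrow memory into Java memory
        MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), batch)
      } finally {
        batch.close()
      }
      more = reader.loadNextBatch()
      if (!more) reader.close() // releases Arrow memory once the stream is done
      out.toByteArray
    }
  }
}
```

If that holds up, `getBatchesFromStream` could be reduced to roughly this shape, at the cost of one round trip through the Arrow allocator per batch, which, per the note above, seems unavoidable anyway.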