[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user squito commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r213835409

    --- Diff: python/pyspark/context.py ---
    @@ -494,10 +494,14 @@ def f(split, iterator):
                 c = list(c)    # Make it a list so we can compute its length
             batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
             serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
    -        jrdd = self._serialize_to_jvm(c, numSlices, serializer)
    +
    +        def reader_func(temp_filename):
    +            return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices)
    +
    +        jrdd = self._serialize_to_jvm(c, serializer, reader_func)
             return RDD(jrdd, self, serializer)
    -    def _serialize_to_jvm(self, data, parallelism, serializer):
    +    def _serialize_to_jvm(self, data, serializer, reader_func):
    --- End diff --

    Thanks @BryanCutler, sorry I didn't know where to look for those. They look much better than what I would have added!

---
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user gatorsmile commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r213826584

    --- Diff: python/pyspark/context.py ---
    @@ -494,10 +494,14 @@ def f(split, iterator):
                 c = list(c)    # Make it a list so we can compute its length
             batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
             serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
    -        jrdd = self._serialize_to_jvm(c, numSlices, serializer)
    +
    +        def reader_func(temp_filename):
    +            return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices)
    +
    +        jrdd = self._serialize_to_jvm(c, serializer, reader_func)
             return RDD(jrdd, self, serializer)
    -    def _serialize_to_jvm(self, data, parallelism, serializer):
    +    def _serialize_to_jvm(self, data, serializer, reader_func):
    --- End diff --

    Although most of PySpark should be guaranteed by Spark Core and SQL, PySpark is starting to have more and more PySpark-only code. I am not sure how well it is tested.
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user gatorsmile commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r213825858

    --- Diff: python/pyspark/context.py ---
    @@ -494,10 +494,14 @@ def f(split, iterator):
                 c = list(c)    # Make it a list so we can compute its length
             batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
             serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
    -        jrdd = self._serialize_to_jvm(c, numSlices, serializer)
    +
    +        def reader_func(temp_filename):
    +            return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices)
    +
    +        jrdd = self._serialize_to_jvm(c, serializer, reader_func)
             return RDD(jrdd, self, serializer)
    -    def _serialize_to_jvm(self, data, parallelism, serializer):
    +    def _serialize_to_jvm(self, data, serializer, reader_func):
    --- End diff --

    To be honest, I worry about the test coverage of PySpark in general. Can anybody working on PySpark lead the effort to propose a solution for improving the test coverage?
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user BryanCutler commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r213819766

    --- Diff: python/pyspark/context.py ---
    @@ -494,10 +494,14 @@ def f(split, iterator):
                 c = list(c)    # Make it a list so we can compute its length
             batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
             serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
    -        jrdd = self._serialize_to_jvm(c, numSlices, serializer)
    +
    +        def reader_func(temp_filename):
    +            return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices)
    +
    +        jrdd = self._serialize_to_jvm(c, serializer, reader_func)
             return RDD(jrdd, self, serializer)
    -    def _serialize_to_jvm(self, data, parallelism, serializer):
    +    def _serialize_to_jvm(self, data, serializer, reader_func):
    --- End diff --

    I made https://issues.apache.org/jira/browse/SPARK-25272, which will give clearer output showing that the ArrowTests were run.
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user BryanCutler commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r213818193

    --- Diff: python/pyspark/context.py ---
    @@ -494,10 +494,14 @@ def f(split, iterator):
                 c = list(c)    # Make it a list so we can compute its length
             batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
             serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
    -        jrdd = self._serialize_to_jvm(c, numSlices, serializer)
    +
    +        def reader_func(temp_filename):
    +            return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices)
    +
    +        jrdd = self._serialize_to_jvm(c, serializer, reader_func)
             return RDD(jrdd, self, serializer)
    -    def _serialize_to_jvm(self, data, parallelism, serializer):
    +    def _serialize_to_jvm(self, data, serializer, reader_func):
    --- End diff --

    Hey @squito, yes, that's correct: this is in the path that `ArrowTests` exercises with its `createDataFrame` tests. These tests are skipped if pyarrow is not installed, but on our Jenkins it is installed under the Python 3.5 env, so they get tested there. It's a little subtle to see that they were run, since the test output only shows when tests are skipped. You can see that `ArrowTests` shows as skipped for Python 2.7, but not for 3.5.
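The skip behaviour described in the reply above can be illustrated with plain `unittest`. This is only a minimal sketch, not the PySpark test suite: the class name `ArrowPathTest`, the helper names, and the way availability is passed in as a flag are all assumptions made for illustration.

```python
import unittest


def make_arrow_test_case(have_pyarrow):
    """Build a test class that is skipped when the optional dependency is missing."""

    @unittest.skipIf(not have_pyarrow, "pyarrow is not installed")
    class ArrowPathTest(unittest.TestCase):  # hypothetical stand-in for ArrowTests
        def test_create_dataframe_path(self):
            # Placeholder body; a real test would exercise createDataFrame.
            self.assertTrue(have_pyarrow)

    return ArrowPathTest


def run_and_summarize(test_case):
    """Run one test class and return (skipped, failed_or_errored) counts."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(test_case)
    result = unittest.TestResult()
    suite.run(result)
    return len(result.skipped), len(result.failures) + len(result.errors)
```

A run where the dependency is absent reports a skip, while a run where it is present reports nothing unusual, which is why a passing Arrow run can be hard to spot in the output.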
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user squito commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r213815841

    --- Diff: python/pyspark/context.py ---
    @@ -494,10 +494,14 @@ def f(split, iterator):
                 c = list(c)    # Make it a list so we can compute its length
             batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
             serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
    -        jrdd = self._serialize_to_jvm(c, numSlices, serializer)
    +
    +        def reader_func(temp_filename):
    +            return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices)
    +
    +        jrdd = self._serialize_to_jvm(c, serializer, reader_func)
             return RDD(jrdd, self, serializer)
    -    def _serialize_to_jvm(self, data, parallelism, serializer):
    +    def _serialize_to_jvm(self, data, serializer, reader_func):
    --- End diff --

    (if not, I can try to address this in some other work)
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user squito commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r213785551

    --- Diff: python/pyspark/context.py ---
    @@ -494,10 +494,14 @@ def f(split, iterator):
                 c = list(c)    # Make it a list so we can compute its length
             batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
             serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
    -        jrdd = self._serialize_to_jvm(c, numSlices, serializer)
    +
    +        def reader_func(temp_filename):
    +            return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices)
    +
    +        jrdd = self._serialize_to_jvm(c, serializer, reader_func)
             return RDD(jrdd, self, serializer)
    -    def _serialize_to_jvm(self, data, parallelism, serializer):
    +    def _serialize_to_jvm(self, data, serializer, reader_func):
    --- End diff --

    Hi, sorry for the late review here, and this is more just a question for myself: is this aspect tested at all? IIUC, it would be used in `spark.createDataFrame`, but the tests in session.py don't have Arrow enabled, right? Not that I see a bug; I'm mostly just wondering, as I was looking at making my own changes here, and it would be nice to know there are some tests.
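The diff under discussion changes `_serialize_to_jvm` to take a `reader_func` callback instead of a parallelism argument. The body of that method is not shown in the quoted hunk, so the following is only a sketch of the general pattern under stated assumptions: `serialize_via_temp_file` and its `dump` parameter are hypothetical names, and pickle stands in for the real serializer.

```python
import os
import pickle
import tempfile


def serialize_via_temp_file(data, dump, reader_func):
    """Write data to a temp file with dump(data, fileobj), then pass the path to reader_func.

    Hypothetical helper: it mirrors only the (data, serializer, reader_func)
    shape of the new signature, not the actual PySpark implementation.
    """
    temp = tempfile.NamedTemporaryFile(delete=False)
    try:
        try:
            dump(data, temp)
        finally:
            temp.close()  # close before the reader opens the file
        return reader_func(temp.name)
    finally:
        os.unlink(temp.name)  # always remove the temp file


def read_back(path):
    """Toy reader standing in for a JVM-side reader such as readRDDFromFile."""
    with open(path, "rb") as f:
        return pickle.load(f)


round_tripped = serialize_via_temp_file([1, 2, 3], pickle.dump, read_back)
```

Passing the reader as a callback lets each caller decide how the file is consumed, which is what allows a different reader to be plugged in for the Arrow path.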
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user asfgit closed the pull request at: https://github.com/apache/spark/pull/21546 --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user HyukjinKwon commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r212178291 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +178,106 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeRecordBatch( + new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) --- End diff -- Ah, sorry. You are right. I misread. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user BryanCutler commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r212171980

    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala ---
    @@ -183,34 +178,106 @@ private[sql] object ArrowConverters {
       }
       /**
    -   * Convert a byte array to an ArrowRecordBatch.
    +   * Load a serialized ArrowRecordBatch.
        */
    -  private[arrow] def byteArrayToBatch(
    +  private[arrow] def loadBatch(
           batchBytes: Array[Byte],
           allocator: BufferAllocator): ArrowRecordBatch = {
    -    val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
    -    val reader = new ArrowFileReader(in, allocator)
    -
    -    // Read a batch from a byte stream, ensure the reader is closed
    -    Utils.tryWithSafeFinally {
    -      val root = reader.getVectorSchemaRoot  // throws IOException
    -      val unloader = new VectorUnloader(root)
    -      reader.loadNextBatch()  // throws IOException
    -      unloader.getRecordBatch
    -    } {
    -      reader.close()
    -    }
    +    val in = new ByteArrayInputStream(batchBytes)
    +    MessageSerializer.deserializeRecordBatch(
    +      new ReadChannel(Channels.newChannel(in)), allocator)  // throws IOException
       }
    +  /**
    +   * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
    +   */
       private[sql] def toDataFrame(
    -      payloadRDD: JavaRDD[Array[Byte]],
    +      arrowBatchRDD: JavaRDD[Array[Byte]],
           schemaString: String,
           sqlContext: SQLContext): DataFrame = {
    -    val rdd = payloadRDD.rdd.mapPartitions { iter =>
    +    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
    +    val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
    +    val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
           val context = TaskContext.get()
    -      ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context)
    +      ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
         }
    -    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
         sqlContext.internalCreateDataFrame(rdd, schema)
       }
    +
    +  /**
    +   * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches.
    +   */
    +  private[sql] def readArrowStreamFromFile(
    +      sqlContext: SQLContext,
    +      filename: String): JavaRDD[Array[Byte]] = {
    +    val fileStream = new FileInputStream(filename)
    +    try {
    +      // Create array so that we can safely close the file
    +      val batches = getBatchesFromStream(fileStream.getChannel).toArray
    +      // Parallelize the record batches to create an RDD
    +      JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
    --- End diff --

    So this is the length of the array of batches, not the number of records in a batch. The input is split according to the default parallelism config, so if that is 32, we will have an array of 32 batches and then parallelize those into 32 partitions. `parallelize` would usually take one big array of primitives as the first argument, which is then partitioned by the number in the second argument, but this is a little different since we are using batches. Does that answer your question?
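The point made in the reply above, that the second argument to `parallelize` is the number of batches and not the number of records, can be shown with a toy model in plain Python. Everything here is an assumption for illustration: `split_into_batches` and `toy_parallelize` are hypothetical stand-ins and do not reproduce Spark's actual slicing code.

```python
def split_into_batches(records, num_slices):
    """Split records into num_slices contiguous batches of near-equal size."""
    n = len(records)
    bounds = [(i * n) // num_slices for i in range(num_slices + 1)]
    return [records[bounds[i]:bounds[i + 1]] for i in range(num_slices)]


def toy_parallelize(items, num_partitions):
    """Toy stand-in for parallelize: distribute items over num_partitions."""
    return split_into_batches(items, num_partitions)


records = list(range(10))
# Step 1: the input is split into as many batches as the parallelism setting.
batches = split_into_batches(records, 4)
# Step 2: the batches themselves are the items, so each partition gets one batch.
partitions = toy_parallelize(batches, len(batches))
```

With four batches and four partitions, every partition holds exactly one batch, regardless of how many records each batch contains.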
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user BryanCutler commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r212170997

    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---
    @@ -3268,13 +3268,49 @@ class Dataset[T] private[sql](
       }
       /**
    -   * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
    +   * Collect a Dataset as Arrow batches and serve stream to PySpark.
        */
       private[sql] def collectAsArrowToPython(): Array[Any] = {
    +    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
    +
         withAction("collectAsArrowToPython", queryExecution) { plan =>
    -      val iter: Iterator[Array[Byte]] =
    -        toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
    -      PythonRDD.serveIterator(iter, "serve-Arrow")
    +      PythonRDD.serveToStream("serve-Arrow") { out =>
    +        val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
    +        val arrowBatchRdd = toArrowBatchRdd(plan)
    +        val numPartitions = arrowBatchRdd.partitions.length
    +
    +        // Store collection results for worst case of 1 to N-1 partitions
    --- End diff --

    It's not necessary to buffer the first partition because it can be sent to Python right away, so we only need an array of size N-1.
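The N-1 buffer described above can be sketched in Python. This is an illustrative model under assumptions, not the Scala code from the PR: `stream_in_partition_order`, the `arrivals` iterable, and the `send` callback are hypothetical names, and payloads are assumed never to be `None`.

```python
def stream_in_partition_order(arrivals, num_partitions, send):
    """Send partition payloads in index order as they arrive out of order.

    arrivals yields (partition_index, payload) pairs in completion order.
    Partition 0 is never buffered, so only N-1 slots are needed.
    """
    buffered = [None] * max(num_partitions - 1, 0)  # slots for partitions 1..N-1
    next_to_send = 0
    for index, payload in arrivals:
        if index != next_to_send:
            buffered[index - 1] = payload  # arrived early, hold it
            continue
        send(payload)
        next_to_send += 1
        # Flush any buffered partitions that are now next in line.
        while next_to_send < num_partitions and buffered[next_to_send - 1] is not None:
            send(buffered[next_to_send - 1])
            buffered[next_to_send - 1] = None
            next_to_send += 1


sent = []
stream_in_partition_order(
    [(2, "c"), (0, "a"), (3, "d"), (1, "b")], num_partitions=4, send=sent.append)
```

The worst case is the one where partition 0 finishes last: all of the other N-1 partitions are then held in the buffer until it arrives.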
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r212171051 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +178,106 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeRecordBatch( + new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) --- End diff -- yup, thanks for catching that --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r212170606 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +178,106 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeRecordBatch( + new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. 
+ */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// Create an iterator to get each serialized ArrowRecordBatch from a stream +new Iterator[Array[Byte]] { + var batch: Array[Byte] = readNextBatch() + + override def hasNext: Boolean = batch != null + + override def next(): Array[Byte] = { +val prevBatch = batch +batch = readNextBatch() +prevBatch + } + + def readNextBatch(): Array[Byte] = { +val msgMetadata = MessageSerializer.readMessage(new ReadChannel(in)) +if (msgMetadata == null) { + return null +} + +// Get the length of the body, which has not be read at this point +val bodyLength = msgMetadata.getMessageBodyLength.toInt + +// Only care about RecordBatch data, skip Schema and unsupported Dictionary messages +if (msgMetadata.getMessage.headerType() == MessageHeader.RecordBatch) { + + // Create output backed by buffer to hold msg length (int32), msg metadata, msg body + val bbout = new ByteBufferOutputStream(4 + msgMetadata.getMessageLength + bodyLength) --- End diff -- I'll add some more details about what this is doing --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
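The buffer in the quoted diff is sized as `4 + msgMetadata.getMessageLength + bodyLength`, that is, an int32 length prefix followed by the message metadata and the message body. The sketch below shows that kind of length-prefixed framing in Python. It is a toy layout and not the actual Arrow IPC encoding: little-endian byte order, the function names, and the `body_length_of` callback (standing in for `getMessageBodyLength`) are assumptions.

```python
import io
import struct


def frame_message(metadata: bytes, body: bytes) -> bytes:
    """Pack one message as: int32 metadata length, metadata, body."""
    out = io.BytesIO()
    out.write(struct.pack("<i", len(metadata)))  # 4-byte length prefix
    out.write(metadata)
    out.write(body)
    return out.getvalue()


def read_message(stream, body_length_of):
    """Read one framed message; return (metadata, body), or None at end of stream.

    body_length_of is a hypothetical callback that extracts the body length
    from the metadata, since the body has not been read at that point.
    """
    prefix = stream.read(4)
    if len(prefix) < 4:
        return None
    (meta_len,) = struct.unpack("<i", prefix)
    metadata = stream.read(meta_len)
    body = stream.read(body_length_of(metadata))
    return metadata, body
```

Reading stops cleanly when no further length prefix is available, which plays the same role as the `null` check on `readMessage` in the iterator above.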
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user HyukjinKwon commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r212163122

    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala ---
    @@ -183,34 +178,106 @@ private[sql] object ArrowConverters {
       }
       /**
    -   * Convert a byte array to an ArrowRecordBatch.
    +   * Load a serialized ArrowRecordBatch.
        */
    -  private[arrow] def byteArrayToBatch(
    +  private[arrow] def loadBatch(
           batchBytes: Array[Byte],
           allocator: BufferAllocator): ArrowRecordBatch = {
    -    val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
    -    val reader = new ArrowFileReader(in, allocator)
    -
    -    // Read a batch from a byte stream, ensure the reader is closed
    -    Utils.tryWithSafeFinally {
    -      val root = reader.getVectorSchemaRoot  // throws IOException
    -      val unloader = new VectorUnloader(root)
    -      reader.loadNextBatch()  // throws IOException
    -      unloader.getRecordBatch
    -    } {
    -      reader.close()
    -    }
    +    val in = new ByteArrayInputStream(batchBytes)
    +    MessageSerializer.deserializeRecordBatch(
    +      new ReadChannel(Channels.newChannel(in)), allocator)  // throws IOException
       }
    +  /**
    +   * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
    +   */
       private[sql] def toDataFrame(
    -      payloadRDD: JavaRDD[Array[Byte]],
    +      arrowBatchRDD: JavaRDD[Array[Byte]],
           schemaString: String,
           sqlContext: SQLContext): DataFrame = {
    -    val rdd = payloadRDD.rdd.mapPartitions { iter =>
    +    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
    +    val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
    +    val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
           val context = TaskContext.get()
    -      ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context)
    +      ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
         }
    -    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
         sqlContext.internalCreateDataFrame(rdd, schema)
       }
    +
    +  /**
    +   * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches.
    +   */
    +  private[sql] def readArrowStreamFromFile(
    +      sqlContext: SQLContext,
    +      filename: String): JavaRDD[Array[Byte]] = {
    +    val fileStream = new FileInputStream(filename)
    +    try {
    +      // Create array so that we can safely close the file
    +      val batches = getBatchesFromStream(fileStream.getChannel).toArray
    +      // Parallelize the record batches to create an RDD
    +      JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
    --- End diff --

    @BryanCutler, why is this parallelized with `batches.length` as the number of partitions? I thought the data size is usually small, and I am wondering whether this necessarily speeds things up in general. Did I misread?
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user HyukjinKwon commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r212162307 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +178,106 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeRecordBatch( + new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) --- End diff -- nit: i would do `tryWithResource` --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user HyukjinKwon commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r212161411

    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---
    @@ -3268,13 +3268,49 @@ class Dataset[T] private[sql](
       }
       /**
    -   * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
    +   * Collect a Dataset as Arrow batches and serve stream to PySpark.
        */
       private[sql] def collectAsArrowToPython(): Array[Any] = {
    +    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
    +
         withAction("collectAsArrowToPython", queryExecution) { plan =>
    -      val iter: Iterator[Array[Byte]] =
    -        toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
    -      PythonRDD.serveIterator(iter, "serve-Arrow")
    +      PythonRDD.serveToStream("serve-Arrow") { out =>
    +        val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
    +        val arrowBatchRdd = toArrowBatchRdd(plan)
    +        val numPartitions = arrowBatchRdd.partitions.length
    +
    +        // Store collection results for worst case of 1 to N-1 partitions
    --- End diff --

    Would `0 to N-1 partitions` be better?
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user HyukjinKwon commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r212158131

    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala ---
    @@ -111,65 +113,58 @@ private[sql] object ArrowConverters {
                 rowCount += 1
               }
               arrowWriter.finish()
    -          writer.writeBatch()
    +          val batch = unloader.getRecordBatch()
    +          MessageSerializer.serialize(writeChannel, batch)
    +          batch.close()
    --- End diff --

    Should we use `tryWithResource` here too?
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user icexelloss commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r211964996

--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala ---
@@ -183,34 +178,106 @@ private[sql] object ArrowConverters {
   }
   /**
-   * Convert a byte array to an ArrowRecordBatch.
+   * Load a serialized ArrowRecordBatch.
    */
-  private[arrow] def byteArrayToBatch(
+  private[arrow] def loadBatch(
       batchBytes: Array[Byte],
       allocator: BufferAllocator): ArrowRecordBatch = {
-    val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
-    val reader = new ArrowFileReader(in, allocator)
-
-    // Read a batch from a byte stream, ensure the reader is closed
-    Utils.tryWithSafeFinally {
-      val root = reader.getVectorSchemaRoot // throws IOException
-      val unloader = new VectorUnloader(root)
-      reader.loadNextBatch() // throws IOException
-      unloader.getRecordBatch
-    } {
-      reader.close()
-    }
+    val in = new ByteArrayInputStream(batchBytes)
+    MessageSerializer.deserializeRecordBatch(
+      new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException
   }
+  /**
+   * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
+   */
   private[sql] def toDataFrame(
-      payloadRDD: JavaRDD[Array[Byte]],
+      arrowBatchRDD: JavaRDD[Array[Byte]],
       schemaString: String,
       sqlContext: SQLContext): DataFrame = {
-    val rdd = payloadRDD.rdd.mapPartitions { iter =>
+    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+    val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
+    val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
       val context = TaskContext.get()
-      ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context)
+      ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
     }
-    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
     sqlContext.internalCreateDataFrame(rdd, schema)
   }
+
+  /**
+   * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches.
+   */
+  private[sql] def readArrowStreamFromFile(
+      sqlContext: SQLContext,
+      filename: String): JavaRDD[Array[Byte]] = {
+    val fileStream = new FileInputStream(filename)
+    try {
+      // Create array so that we can safely close the file
+      val batches = getBatchesFromStream(fileStream.getChannel).toArray
+      // Parallelize the record batches to create an RDD
+      JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
+    } finally {
+      fileStream.close()
+    }
+  }
+
+  /**
+   * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches.
+   */
+  private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = {
+
+    // Create an iterator to get each serialized ArrowRecordBatch from a stream
+    new Iterator[Array[Byte]] {
+      var batch: Array[Byte] = readNextBatch()
+
+      override def hasNext: Boolean = batch != null
+
+      override def next(): Array[Byte] = {
+        val prevBatch = batch
+        batch = readNextBatch()
+        prevBatch
+      }
+
+      def readNextBatch(): Array[Byte] = {
+        val msgMetadata = MessageSerializer.readMessage(new ReadChannel(in))
+        if (msgMetadata == null) {
+          return null
+        }
+
+        // Get the length of the body, which has not be read at this point
+        val bodyLength = msgMetadata.getMessageBodyLength.toInt
+
+        // Only care about RecordBatch data, skip Schema and unsupported Dictionary messages
+        if (msgMetadata.getMessage.headerType() == MessageHeader.RecordBatch) {
+
+          // Create output backed by buffer to hold msg length (int32), msg metadata, msg body
+          val bbout = new ByteBufferOutputStream(4 + msgMetadata.getMessageLength + bodyLength)
--- End diff --

Add a comment that this is the deserialized form of an Arrow Record Batch?
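The read-ahead iterator in `getBatchesFromStream` can be sketched in Python as follows. This is only an illustration of the pattern (read one message ahead, skip anything that is not a record batch, stop when the stream is exhausted). The framing used here, a one-byte type tag followed by an int32 body length, is a hypothetical stand-in and is not the Arrow IPC format; `RECORD_BATCH`, `SCHEMA`, `frame`, `read_next_batch` and `BatchIterator` are names invented for this sketch.

```python
import io
import struct

RECORD_BATCH = 1  # hypothetical type tag, not an Arrow constant
SCHEMA = 0        # hypothetical type tag, not an Arrow constant


def frame(msg_type, body):
    """Build one message: 1-byte type, int32 body length, then the body."""
    return struct.pack("<bi", msg_type, len(body)) + body


def read_next_batch(stream):
    """Return the next record-batch message as bytes, or None at end of stream."""
    while True:
        header = stream.read(5)
        if len(header) < 5:
            return None  # nothing left to read
        msg_type, body_len = struct.unpack("<bi", header)
        body = stream.read(body_len)
        if msg_type == RECORD_BATCH:
            # Keep header and body together so the message stays self-describing
            return header + body
        # Any other message type is skipped and the loop reads the next one


class BatchIterator:
    """Iterator that always holds the next batch, like the Scala `var batch`."""

    def __init__(self, stream):
        self._stream = stream
        self._batch = read_next_batch(stream)  # read ahead once up front

    def has_next(self):
        return self._batch is not None

    def __iter__(self):
        return self

    def __next__(self):
        if self._batch is None:
            raise StopIteration
        prev_batch = self._batch
        self._batch = read_next_batch(self._stream)
        return prev_batch


stream = io.BytesIO(
    frame(SCHEMA, b"schema") + frame(RECORD_BATCH, b"aa") + frame(RECORD_BATCH, b"bbb")
)
batches = list(BatchIterator(stream))
```

Reading one message ahead is what lets `has_next` answer without consuming anything, which is the same reason the Scala version initializes `batch` eagerly.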
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r205642926

--- Diff: python/pyspark/serializers.py ---
@@ -184,27 +184,67 @@ def loads(self, obj):
         raise NotImplementedError
-class ArrowSerializer(FramedSerializer):
+class BatchOrderSerializer(Serializer):
--- End diff --

Thank you @BryanCutler.
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user BryanCutler commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r205624568

--- Diff: python/pyspark/serializers.py ---
@@ -184,27 +184,67 @@ def loads(self, obj):
         raise NotImplementedError
-class ArrowSerializer(FramedSerializer):
+class BatchOrderSerializer(Serializer):
--- End diff --

@HyukjinKwon those are fair questions, not sure I'm going to have the time to do anything about it right now, but I'll circle back to this later.
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r204961977

--- Diff: python/pyspark/serializers.py ---
@@ -184,27 +184,67 @@ def loads(self, obj):
         raise NotImplementedError
-class ArrowSerializer(FramedSerializer):
+class BatchOrderSerializer(Serializer):
--- End diff --

Since you verified that the performance difference is trivial, I don't think it's a hard requirement to merge this, though. At least, I would just push this in.
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r204960838

--- Diff: python/pyspark/serializers.py ---
@@ -184,27 +184,67 @@ def loads(self, obj):
         raise NotImplementedError
-class ArrowSerializer(FramedSerializer):
+class BatchOrderSerializer(Serializer):
--- End diff --

Ah, okay. I think I understand the benefit. But my impression is that this is something we were already doing. Also, if this is something we could apply to other functionality too, then it sounds to me like somewhat orthogonal work that should be done separately. Another concern is how often we'd actually hit this OOM, because I usually expect the data for createDataFrame from a Pandas DataFrame, or for toPandas, to be small. If the changes were small, it would have been okay with me, but these are fairly large changes that affect a lot of code from the Scala side to the Python side.
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user BryanCutler commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r204868891

--- Diff: python/pyspark/serializers.py ---
@@ -184,27 +184,67 @@ def loads(self, obj):
         raise NotImplementedError
-class ArrowSerializer(FramedSerializer):
+class BatchOrderSerializer(Serializer):
--- End diff --

Yeah, I could separate this but is there anything I can do to alleviate your concern? I'm not sure I'll have the time to try to make another PR before 2.4.0 code freeze and I think this is a really useful memory optimization to help prevent OOM in the driver JVM. Also, I might have to rerun the benchmarks here, just to be thorough, because the previous ones were from quite a while ago.
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user felixcheung commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r204635055

--- Diff: python/pyspark/serializers.py ---
@@ -184,27 +184,67 @@ def loads(self, obj):
         raise NotImplementedError
-class ArrowSerializer(FramedSerializer):
+class BatchOrderSerializer(Serializer):
     """
-    Serializes bytes as Arrow data with the Arrow file format.
+    Deserialize a stream of batches followed by batch order information.
     """
-    def dumps(self, batch):
+    def __init__(self, serializer):
+        self.serializer = serializer
+        self.batch_order = None
+
+    def dump_stream(self, iterator, stream):
+        return self.serializer.dump_stream(iterator, stream)
+
+    def load_stream(self, stream):
+        for batch in self.serializer.load_stream(stream):
+            yield batch
+        num = read_int(stream)
+        self.batch_order = []
+        for i in xrange(num):
+            index = read_int(stream)
+            self.batch_order.append(index)
+        raise StopIteration()
--- End diff --

this seems important...
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r204606545

--- Diff: python/pyspark/serializers.py ---
@@ -184,27 +184,67 @@ def loads(self, obj):
         raise NotImplementedError
-class ArrowSerializer(FramedSerializer):
+class BatchOrderSerializer(Serializer):
--- End diff --

Thanks for elaborating on this, @BryanCutler. Would you mind adding this in a separate PR? I am actually not super sure about this.
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user BryanCutler commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r204484728

--- Diff: python/pyspark/serializers.py ---
@@ -184,27 +184,67 @@ def loads(self, obj):
         raise NotImplementedError
-class ArrowSerializer(FramedSerializer):
+class BatchOrderSerializer(Serializer):
--- End diff --

Yeah, the performance gain by sending out of order batches was small, but the reason this was done was to improve memory usage in the driver JVM. Before this it still had a worst case of buffering the entire dataset in the JVM, but now nothing is buffered and partitions are immediately sent to Python. I think that's a huge improvement that is worth the additional complexity. This method might even be applicable to a `collect()` in Python also.
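The sending side of the idea described here (write each partition as soon as it arrives, then append the information needed to restore order) could look roughly like the sketch below. It is written under loud assumptions: the framing (int32 length prefix, `-1` end marker, then a count and the indices) is hypothetical and is not PySpark's actual wire protocol, each partition is treated as a single batch, partition indices are assumed to be `0..n-1`, and `send_unordered` is a name invented for this example.

```python
import io
import struct


def send_unordered(arrivals, stream):
    """Write payloads in arrival order, then the indices that restore partition order.

    arrivals: iterable of (partition_index, payload_bytes) in completion order.
    Returns the order list that was written, for inspection.
    """
    arrival_pos = {}
    for pos, (partition_index, payload) in enumerate(arrivals):
        # Nothing is buffered: each payload goes out as soon as it is seen
        stream.write(struct.pack(">i", len(payload)))
        stream.write(payload)
        arrival_pos[partition_index] = pos

    stream.write(struct.pack(">i", -1))  # hypothetical end-of-batches marker

    # order[k] is the arrival position of the payload that belongs at position k
    order = [arrival_pos[k] for k in range(len(arrival_pos))]
    stream.write(struct.pack(">i", len(order)))
    for index in order:
        stream.write(struct.pack(">i", index))
    return order


# Partition 2 finishes first, then 0, then 1
arrivals = [(2, b"c"), (0, b"a"), (1, b"b")]
out = io.BytesIO()
order = send_unordered(arrivals, out)
```

Only the small list of integers has to be kept until the end, which is the memory trade-off the comment argues for: the payloads themselves never accumulate on the sending side.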
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user BryanCutler commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r204479024

--- Diff: python/pyspark/sql/dataframe.py ---
@@ -2095,9 +2095,11 @@ def toPandas(self):
                         _check_dataframe_localize_timestamps
                     import pyarrow
-                    tables = self._collectAsArrow()
-                    if tables:
-                        table = pyarrow.concat_tables(tables)
+                    # Collect un-ordered list of batches, and list of correct order indices
+                    batches, batch_order = self._collectAsArrow()
+                    if batches:
--- End diff --

Sure, I was playing around with this being an iterator, but I will change it since it is a list now
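On the receiving side, restoring order from an un-ordered list of batches plus a list of indices is a one-line list comprehension. The sketch below assumes that each entry of `batch_order` is the arrival position of the batch that belongs at that output position; the diff above does not spell out that convention, so treat it as an assumption. `reorder_batches` is a hypothetical helper, and plain strings stand in for Arrow record batches.

```python
def reorder_batches(batches, batch_order):
    """Return batches in their intended order.

    batch_order[k] is assumed to be the arrival position of the batch
    that belongs at output position k.
    """
    if sorted(batch_order) != list(range(len(batches))):
        raise ValueError("batch_order must be a permutation of the arrival positions")
    return [batches[i] for i in batch_order]


# Batches arrived as partition 2, then 0, then 1
arrived = ["p2", "p0", "p1"]
ordered = reorder_batches(arrived, [1, 2, 0])
```

The permutation check is not something the PR is known to do; it is included here only so that a mismatched index list fails loudly instead of silently dropping or duplicating batches.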
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r204324646

--- Diff: python/pyspark/serializers.py ---
@@ -184,27 +184,67 @@ def loads(self, obj):
         raise NotImplementedError
-class ArrowSerializer(FramedSerializer):
+class BatchOrderSerializer(Serializer):
--- End diff --

@BryanCutler, I just read https://github.com/apache/spark/pull/21546#issuecomment-400824224. How much of a performance gain did this make? It looks pretty complicated.
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r204321600

--- Diff: python/pyspark/sql/dataframe.py ---
@@ -2146,14 +2148,15 @@ def toPandas(self):
     def _collectAsArrow(self):
         """
-        Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed
-        and available.
+        Returns all records as a list of ArrowRecordBatches and batch order as a list of indices,
+        pyarrow must be installed and available on driver and worker Python environments.
         .. note:: Experimental.
         """
+        ser = BatchOrderSerializer(ArrowStreamSerializer())
         with SCCallSiteSync(self._sc) as css:
             sock_info = self._jdf.collectAsArrowToPython()
-        return list(_load_from_socket(sock_info, ArrowSerializer()))
+        return list(_load_from_socket(sock_info, ser)), ser.get_batch_order_and_reset()
--- End diff --

Hmmm .. @BryanCutler, would you mind if I ask why this batch order is required?
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r204321247

--- Diff: sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala ---
@@ -1318,18 +1318,52 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
     }
   }
-  test("roundtrip payloads") {
+  test("roundtrip arrow batches") {
     val inputRows = (0 until 9).map { i => InternalRow(i) } :+ InternalRow(null)
     val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
     val ctx = TaskContext.empty()
-    val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, null, ctx)
-    val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx)
+    val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx)
+    val outputRowIter = ArrowConverters.fromBatchIterator(batchIter, schema, null, ctx)
-    assert(schema == outputRowIter.schema)
+    var count = 0
+    outputRowIter.zipWithIndex.foreach { case (row, i) =>
+      if (i != 9) {
+        assert(row.getInt(0) == i)
+      } else {
+        assert(row.isNullAt(0))
+      }
+      count += 1
+    }
+
+    assert(count == inputRows.length)
+  }
+
+  test("ArrowBatchStreamWriter roundtrip") {
+    val inputRows = (0 until 9).map { i =>
+      InternalRow(i)
+    } :+ InternalRow(null)
--- End diff --

tiny nit: it looks like we can inline this.
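The shape of the roundtrip test in this diff (ten rows, the last one null, split into batches of at most five and read back) can be mirrored in plain Python. This is only an analogy for the structure of the check: lists stand in for Arrow record batches, and `to_batches` and `from_batches` are hypothetical helpers, not counterparts of any PySpark or Arrow API.

```python
def to_batches(rows, max_records_per_batch):
    """Group rows into lists of at most max_records_per_batch items."""
    batch = []
    for row in rows:
        batch.append(row)
        if len(batch) == max_records_per_batch:
            yield batch
            batch = []
    if batch:
        yield batch  # final, possibly shorter batch


def from_batches(batches):
    """Flatten batches back into a single stream of rows."""
    for batch in batches:
        for row in batch:
            yield row


input_rows = list(range(9)) + [None]
batches = list(to_batches(input_rows, 5))
output_rows = list(from_batches(batches))

# Same checks as the Scala test: values 0..8, then a null, and a matching count
count = 0
for i, row in enumerate(output_rows):
    if i != 9:
        assert row == i
    else:
        assert row is None
    count += 1
assert count == len(input_rows)
```

Counting rows explicitly, as the Scala test does, guards against a roundtrip that silently drops the trailing batch.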
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r204319644

--- Diff: python/pyspark/sql/dataframe.py ---
@@ -2095,9 +2095,11 @@ def toPandas(self):
                         _check_dataframe_localize_timestamps
                     import pyarrow
-                    tables = self._collectAsArrow()
-                    if tables:
-                        table = pyarrow.concat_tables(tables)
+                    # Collect un-ordered list of batches, and list of correct order indices
+                    batches, batch_order = self._collectAsArrow()
+                    if batches:
--- End diff --

Not a big deal at all and personal preference: I would do this like `len(batches) > 0`.
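One practical difference between `if batches:` and `len(batches) > 0` shows up if `batches` ever stops being a list. The sketch below uses an empty generator as a stand-in for a lazily loaded batch stream; `make_batches` is a hypothetical helper for this example only.

```python
def make_batches():
    """Return an empty generator, standing in for a lazy stream with no batches."""
    return (b for b in [])


as_list = list(make_batches())
as_iter = make_batches()

list_is_truthy = bool(as_list)  # an empty list is falsy
iter_is_truthy = bool(as_iter)  # a generator object is truthy even when it yields nothing

try:
    len(as_iter)
    len_error = None
except TypeError as exc:
    # Generators have no len(), so the explicit form fails loudly
    len_error = exc
```

With a list, the two spellings behave the same. With an iterator, `if batches:` silently takes the non-empty branch, while `len(batches) > 0` raises a `TypeError`, so the explicit form surfaces the type change immediately.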
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user BryanCutler commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r203913178

--- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---
@@ -3349,20 +3385,20 @@ class Dataset[T] private[sql](
     }
   }
-  /** Convert to an RDD of ArrowPayload byte arrays */
-  private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = {
+  /** Convert to an RDD of serialized ArrowRecordBatches. */
+  private[sql] def getArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
--- End diff --

Yeah, I can't remember why I changed it, but I think you're right, so I'll change it back.
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r203790304

--- Diff: python/pyspark/serializers.py ---
@@ -184,27 +184,67 @@ def loads(self, obj):
         raise NotImplementedError
-class ArrowSerializer(FramedSerializer):
+class BatchOrderSerializer(Serializer):
     """
-    Serializes bytes as Arrow data with the Arrow file format.
+    Deserialize a stream of batches followed by batch order information.
     """
-    def dumps(self, batch):
+    def __init__(self, serializer):
+        self.serializer = serializer
+        self.batch_order = None
+
+    def dump_stream(self, iterator, stream):
+        return self.serializer.dump_stream(iterator, stream)
+
+    def load_stream(self, stream):
+        for batch in self.serializer.load_stream(stream):
+            yield batch
+        num = read_int(stream)
+        self.batch_order = []
+        for i in xrange(num):
+            index = read_int(stream)
+            self.batch_order.append(index)
+        raise StopIteration()
--- End diff --

@BryanCutler, I think this will be broken in Python 3.7 (see [PEP 479](https://www.python.org/dev/peps/pep-0479)). Shall we just remove this line or explicitly `return`?
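The PEP 479 concern can be reproduced in a few lines. Under PEP 479, which is the default behavior from Python 3.7, a `StopIteration` raised inside a generator body is converted to a `RuntimeError` instead of quietly ending iteration. The two generators below are simplified stand-ins for `load_stream`, not the PR's code, and the sketch assumes it is run on Python 3.7 or later.

```python
def load_with_raise(batches):
    for batch in batches:
        yield batch
    # On Python 3.7+ this is re-raised to the caller as RuntimeError
    raise StopIteration()


def load_with_return(batches):
    for batch in batches:
        yield batch
    # A plain return (or simply falling off the end) finishes the generator cleanly
    return


ok = list(load_with_return([1, 2, 3]))

try:
    list(load_with_raise([1, 2, 3]))
    raised = None
except RuntimeError as exc:
    raised = exc
```

Either fix suggested in the comment works: dropping the `raise` line lets the generator end by falling off the end of the function, and an explicit `return` makes the intent visible.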
[GitHub] spark pull request #21546: [SPARK-23030][SQL][PYTHON] Use Arrow stream forma...
Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r203760078

--- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---
@@ -3349,20 +3385,20 @@ class Dataset[T] private[sql](
     }
   }
-  /** Convert to an RDD of ArrowPayload byte arrays */
-  private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = {
+  /** Convert to an RDD of serialized ArrowRecordBatches. */
+  private[sql] def getArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
--- End diff --

@BryanCutler, not a big deal at all, but how about a `to` prefix like before? This reminds me of QueryExecution's `toRdd`, and the previous name looks slightly better.