Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199496002 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -3236,13 +3237,50 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { + val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = - toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { outputStream => + val out = new DataOutputStream(outputStream) + val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) + val arrowBatchRdd = getArrowBatchRdd(plan) + val numPartitions = arrowBatchRdd.partitions.length + + // Batches ordered by index of partition + fractional value of batch # in partition + val batchOrder = new ArrayBuffer[Float]() + var partitionCount = 0 + + // Handler to eagerly write batches to Python out of order + def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + if (arrowBatches.nonEmpty) { + batchWriter.writeBatches(arrowBatches.iterator) + (0 until arrowBatches.length).foreach { i => + batchOrder.append(index + i / arrowBatches.length) --- End diff -- This code: `(0 until array.length).map(i => i / array.length)` is guaranteed to produce only zero values, isn't it? Since `i` and `array.length` are both `Int`, the division is integer division and truncates to zero for every `i < array.length`. The code works, since `sortBy` evidently preserves the ordering of equal elements (it is a stable sort), but you may as well do `batchOrder.append(index)` since it's the same — or, if a fractional offset was actually intended, use `i.toFloat / arrowBatches.length`.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org