Github user icexelloss commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r195552695
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---
    @@ -3236,13 +3236,49 @@ class Dataset[T] private[sql](
       }
     
       /**
    -   * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
    +   * Collect a Dataset as Arrow batches and serve stream to PySpark.
        */
       private[sql] def collectAsArrowToPython(): Array[Any] = {
    +    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
    +
         withAction("collectAsArrowToPython", queryExecution) { plan =>
    -      val iter: Iterator[Array[Byte]] =
    -        toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
    -      PythonRDD.serveIterator(iter, "serve-Arrow")
    +      PythonRDD.serveToStream("serve-Arrow") { out =>
    +        val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
    +        val arrowBatchRdd = getArrowBatchRdd(plan)
    +        val numPartitions = arrowBatchRdd.partitions.length
    +
    +        // Store collection results for worst case of 1 to N-1 partitions
    +        val results = new Array[Array[Array[Byte]]](numPartitions - 1)
    +        var lastIndex = -1  // index of last partition written
    +
    +        // Handler to eagerly write partitions to Python in order
    +        def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
    +          // If result is from next partition in order
    +          if (index - 1 == lastIndex) {
    +            batchWriter.writeBatches(arrowBatches.iterator)
    +            lastIndex += 1
    +            // Write stored partitions that come next in order
    +            while (lastIndex < results.length && results(lastIndex) != null) {
    +              batchWriter.writeBatches(results(lastIndex).iterator)
    +              results(lastIndex) = null
    +              lastIndex += 1
    +            }
    +            // After last batch, end the stream
    +            if (lastIndex == results.length) {
    +              batchWriter.end()
    +            }
    +          } else {
    +            // Store partitions received out of order
    +            results(index - 1) = arrowBatches
    +          }
    +        }
    +
    +        sparkSession.sparkContext.runJob(
    +          arrowBatchRdd,
    +          (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
    +          0 until numPartitions,
    +          handlePartitionBatches)
    --- End diff --
    
    > I did have another idea though, we could stream all partitions to Python 
out of order, then follow with another small batch of data that contains maps 
of partitionIndex to orderReceived. Then the partitions could be put into order 
on the Python side before making the Pandas DataFrame.
    
    This sounds good!


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to