GitHub user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r194948874
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---
    @@ -3236,13 +3236,49 @@ class Dataset[T] private[sql](
       }
     
       /**
    -   * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
    +   * Collect a Dataset as Arrow batches and serve stream to PySpark.
        */
       private[sql] def collectAsArrowToPython(): Array[Any] = {
    +    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
    +
         withAction("collectAsArrowToPython", queryExecution) { plan =>
    -      val iter: Iterator[Array[Byte]] =
    -        toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
    -      PythonRDD.serveIterator(iter, "serve-Arrow")
    +      PythonRDD.serveToStream("serve-Arrow") { out =>
    +        val batchWriter = new ArrowBatchStreamWriter(schema, out, 
timeZoneId)
    +        val arrowBatchRdd = getArrowBatchRdd(plan)
    +        val numPartitions = arrowBatchRdd.partitions.length
    +
    +        // Store collection results for worst case of 1 to N-1 partitions
    +        val results = new Array[Array[Array[Byte]]](numPartitions - 1)
    +        var lastIndex = -1  // index of last partition written
    +
    +        // Handler to eagerly write partitions to Python in order
    +        def handlePartitionBatches(index: Int, arrowBatches: 
Array[Array[Byte]]): Unit = {
    +          // If result is from next partition in order
    +          if (index - 1 == lastIndex) {
    +            batchWriter.writeBatches(arrowBatches.iterator)
    +            lastIndex += 1
    +            // Write stored partitions that come next in order
    +            while (lastIndex < results.length && results(lastIndex) != 
null) {
    +              batchWriter.writeBatches(results(lastIndex).iterator)
    +              results(lastIndex) = null
    +              lastIndex += 1
    +            }
    +            // After last batch, end the stream
    +            if (lastIndex == results.length) {
    +              batchWriter.end()
    +            }
    +          } else {
    +            // Store partitions received out of order
    +            results(index - 1) = arrowBatches
    +          }
    +        }
    +
    +        sparkSession.sparkContext.runJob(
    +          arrowBatchRdd,
    +          (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
    +          0 until numPartitions,
    +          handlePartitionBatches)
    --- End diff --
    
    Instead of collecting all partitions at once and holding out-of-order
    partitions in the driver until the preceding partitions arrive, would it be
    better to incrementally run jobs on the partitions in order and stream each
    result to the Python side? That way we would not need to hold out-of-order
    partitions in the driver.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to