Github user felixcheung commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r194954051
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---
    @@ -3236,13 +3236,49 @@ class Dataset[T] private[sql](
       }
     
       /**
    -   * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
    +   * Collect a Dataset as Arrow batches and serve stream to PySpark.
        */
       private[sql] def collectAsArrowToPython(): Array[Any] = {
    +    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
    +
         withAction("collectAsArrowToPython", queryExecution) { plan =>
    -      val iter: Iterator[Array[Byte]] =
    -        toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
    -      PythonRDD.serveIterator(iter, "serve-Arrow")
    +      PythonRDD.serveToStream("serve-Arrow") { out =>
    +        val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
    +        val arrowBatchRdd = getArrowBatchRdd(plan)
    +        val numPartitions = arrowBatchRdd.partitions.length
    +
    +        // Store collection results for worst case of 1 to N-1 partitions
    +        val results = new Array[Array[Array[Byte]]](numPartitions - 1)
    +        var lastIndex = -1  // index of last partition written
    +
    +        // Handler to eagerly write partitions to Python in order
    +        def handlePartitionBatches(index: Int, arrowBatches: 
Array[Array[Byte]]): Unit = {
    +          // If result is from next partition in order
    +          if (index - 1 == lastIndex) {
    +            batchWriter.writeBatches(arrowBatches.iterator)
    +            lastIndex += 1
    +            // Write stored partitions that come next in order
    +            while (lastIndex < results.length && results(lastIndex) != null) {
    +              batchWriter.writeBatches(results(lastIndex).iterator)
    +              results(lastIndex) = null
    +              lastIndex += 1
    +            }
    +            // After last batch, end the stream
    +            if (lastIndex == results.length) {
    +              batchWriter.end()
    +            }
    +          } else {
    +            // Store partitions received out of order
    +            results(index - 1) = arrowBatches
    +          }
    +        }
    +
    +        sparkSession.sparkContext.runJob(
    +          arrowBatchRdd,
    +          (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
    +          0 until numPartitions,
    +          handlePartitionBatches)
    --- End diff --
    
    +1 to chunking if we can. I recall Bryan saying that for grouped UDFs we need the entire set.
    
    Also, I'm not sure whether the Python side makes any assumption about how much of the partition ends up in each chunk (there shouldn't be any?).
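    
    For reference, a minimal standalone sketch of the in-order flush pattern the diff uses: partition results can reach the driver out of order, so anything other than the next expected partition is buffered and flushed once all of its predecessors have been written. Names and the simulated callbacks here are illustrative only, not Spark's actual internals.
    
        // Standalone sketch: write partition results in partition order even
        // though the per-partition result callbacks may fire out of order.
        object InOrderFlushSketch {
          def main(args: Array[String]): Unit = {
            val numPartitions = 4
            // Buffer for partitions 1..N-1; partition 0 is always written first.
            val results = new Array[Array[String]](numPartitions - 1)
            var lastIndex = -1  // index of the last partition written
    
            def write(batches: Array[String]): Unit = batches.foreach(println)
    
            def handlePartition(index: Int, batches: Array[String]): Unit = {
              if (index - 1 == lastIndex) {
                // This is the next partition in order: write it immediately.
                write(batches)
                lastIndex += 1
                // Flush any buffered partitions that are now next in order.
                while (lastIndex < results.length && results(lastIndex) != null) {
                  write(results(lastIndex))
                  results(lastIndex) = null
                  lastIndex += 1
                }
                // All partitions written: end the stream.
                if (lastIndex == results.length) println("<end of stream>")
              } else {
                // Arrived early: stash it until its predecessors are written.
                results(index - 1) = batches
              }
            }
    
            // Simulate partitions finishing out of order: 2, 0, 3, 1.
            handlePartition(2, Array("p2-b0"))
            handlePartition(0, Array("p0-b0", "p0-b1"))
            handlePartition(3, Array("p3-b0"))
            handlePartition(1, Array("p1-b0"))
          }
        }
    
    As the diff's own comment notes, the worst case keeps results for 1 to N-1 partitions buffered on the driver, which is what the chunking suggestion above would help mitigate.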


