Github user BryanCutler commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r199287248
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---
    @@ -3236,13 +3237,50 @@ class Dataset[T] private[sql](
       }
     
       /**
    -   * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
    +   * Collect a Dataset as Arrow batches and serve stream to PySpark.
        */
       private[sql] def collectAsArrowToPython(): Array[Any] = {
    +    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
    +
         withAction("collectAsArrowToPython", queryExecution) { plan =>
    -      val iter: Iterator[Array[Byte]] =
    -        toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
    -      PythonRDD.serveIterator(iter, "serve-Arrow")
    +      PythonRDD.serveToStream("serve-Arrow") { outputStream =>
    +        val out = new DataOutputStream(outputStream)
    +        val batchWriter = new ArrowBatchStreamWriter(schema, out, 
timeZoneId)
    +        val arrowBatchRdd = getArrowBatchRdd(plan)
    +        val numPartitions = arrowBatchRdd.partitions.length
    +
    +        // Batches ordered by index of partition + batch number for that partition
    +        val batchOrder = new ArrayBuffer[Int]()
    +        var partitionCount = 0
    +
    +        // Handler to eagerly write batches to Python out of order
    +        def handlePartitionBatches(index: Int, arrowBatches: 
Array[Array[Byte]]): Unit = {
    +          if (arrowBatches.nonEmpty) {
    +            batchWriter.writeBatches(arrowBatches.iterator)
    +            (0 until arrowBatches.length).foreach { i =>
    +              batchOrder.append(index + i)
    +            }
    +          }
    +          partitionCount += 1
    +
    +          // After last batch, end the stream and write batch order
    +          if (partitionCount == numPartitions) {
    +            batchWriter.end()
    +            out.writeInt(batchOrder.length)
    +            // Batch order indices are from 0 to N-1 batches, sorted by order they arrived
    +            batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, i) =>
    --- End diff --
    
    Yeah, it looks like something wasn't quite right with the batch indexing.
    I fixed it and added your test. Thanks, @sethah!


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to