viirya commented on code in PR #55552:
URL: https://github.com/apache/spark/pull/55552#discussion_r3158072585


##########
core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala:
##########
@@ -985,6 +1107,103 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     }
   }
 
+  /**
+   * A dedicated thread that serializes input data and writes it directly to the Python worker
+   * socket in blocking mode. The task main thread simultaneously reads output from the same
+   * socket. TCP sockets are full-duplex, so concurrent read() and write() from different
+   * threads is safe -- they operate on independent OS-level buffers.
+   *
+   * This design achieves true pipeline parallelism without any inter-thread queues or locks:
+   *   Writer Thread:  serialize batch N  ->  channel.write(batch N)    [blocking]
+   *   Reader Thread:  channel.read(output N-1)                        [blocking]
+   *   Python:         read batch N-1  ->  compute  ->  write output  ->  read batch N
+   *
+   * Deadlock safety: Python's UDF loop is "read input -> compute -> write output -> repeat".
+   * As long as the reader thread is consuming Python's output (freeing Python's send buffer),
+   * Python will eventually consume input from the socket (freeing the JVM's send buffer for
+   * the writer thread). The reader thread is always actively reading because the task's
+   * downstream operators pull output on demand.
+   *
+   * Unlike the old WriterThread (removed in SPARK-44705), this design uses a blocking socket
+   * in full-duplex mode rather than two threads competing on the same blocking socket with
+   * shared mutable state. The old design's deadlocks were caused by complex interactions
+   * with vectorized readers and monitor threads, not by the fundamental read/write split.
+   */
+  class PipelinedWriterRunnable(
+      worker: PythonWorker,
+      writer: Writer,
+      bufferSize: Int,
+      context: TaskContext)
+    extends Runnable {
+
+    // Capture InputFileBlockHolder from the task thread so we can propagate it
+    // to the writer pool thread. This is needed because upstream scan operators
+    // set InputFileBlockHolder via InheritableThreadLocal, but pool threads
+    // don't inherit from the task thread.
+    private val parentInputFileBlockHolder = InputFileBlockHolder.getThreadLocalValue()
+
+    override def run(): Unit = {
+      // Propagate TaskContext and InputFileBlockHolder to the pool thread so that
+      // upstream operators work correctly.
+      TaskContext.setTaskContext(context)
+      InputFileBlockHolder.setThreadLocalValue(parentInputFileBlockHolder)
+      val bufferStream = new DirectByteBufferOutputStream(bufferSize)
+      val dataOut = new DataOutputStream(bufferStream)
+      try {
+        // Write command/metadata (partition index, task context, broadcasts, UDF definition).
+        writer.open(dataOut)
+        flushToSocket(bufferStream)
+
+        // Write input data in a loop, batching into buffers of ~bufferSize.
+        var hasInput = true
+        while (hasInput && !Thread.currentThread().isInterrupted) {
+          hasInput = writer.writeNextInputToStream(dataOut)
+          if (bufferStream.size() >= bufferSize || !hasInput) {
+            if (!hasInput) {
+              writer.close(dataOut)
+            }
+            flushToSocket(bufferStream)
+          }
+        }
+      } catch {
+        case _: InterruptedException =>
+          Thread.currentThread().interrupt()
+        case t: Throwable if NonFatal(t) || t.isInstanceOf[Exception] =>
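
A minimal, self-contained sketch of the full-duplex pattern described in the Scaladoc above, under stated assumptions: a loopback socket stands in for the Python worker connection, a hypothetical worker thread that doubles each int stands in for Python's "read input -> compute -> write output -> repeat" loop, and FullDuplexSketch/run are illustrative names, not Spark APIs. One thread touches only the socket's output side, the calling thread touches only its input side, and there is no queue between them.

```scala
import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException}
import java.net.{InetAddress, ServerSocket, Socket}

// Hypothetical sketch: one socket, one writer thread, one reader thread, no queues.
object FullDuplexSketch {
  /** Streams 0 until n to a loopback worker that doubles each value; returns its outputs. */
  def run(n: Int): Vector[Int] = {
    val server = new ServerSocket(0, 1, InetAddress.getLoopbackAddress)
    try {
      // Stand-in for the Python worker: read input -> compute -> write output -> repeat.
      val worker = new Thread(() => {
        val conn = server.accept()
        try {
          val in = new DataInputStream(new BufferedInputStream(conn.getInputStream))
          val out = new DataOutputStream(new BufferedOutputStream(conn.getOutputStream))
          var done = false
          while (!done) {
            try out.writeInt(in.readInt() * 2)
            catch { case _: EOFException => done = true } // writer half-closed its side
          }
          out.flush()
        } finally conn.close()
      })
      worker.start()

      val socket = new Socket(InetAddress.getLoopbackAddress, server.getLocalPort)
      try {
        // Writer thread: touches only the socket's output stream.
        val writer = new Thread(() => {
          val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream))
          (0 until n).foreach(i => out.writeInt(i))
          out.flush()
          socket.shutdownOutput() // half-close: the worker sees EOF, reads still work
        })
        writer.start()

        // Calling thread: touches only the input stream, reading while the writer writes.
        val in = new DataInputStream(new BufferedInputStream(socket.getInputStream))
        val results = Vector.fill(n)(in.readInt())
        writer.join()
        worker.join()
        results
      } finally socket.close()
    } finally server.close()
  }
}
```

With a large enough n, a single-threaded "write every input, then read" variant could block once both directions' buffers fill: the worker would stall writing results nobody reads and stop consuming input. Keeping a reader active while a separate thread writes is what avoids that, which is the deadlock-safety argument in the Scaladoc.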

Review Comment:
   The list(batch) materialization happens for grouped/aggregate UDF serializers whose load_stream() yields lazy iterators. In sync mode, the same data is already materialized in mapper(batch_iter) via list(batch_iter), so per-group peak memory is the same.

   The difference is that with queueDepth=2 (default), up to 2 additional groups can be buffered in the queue. In the worst case (a skewed key with one very large group), this could increase peak memory by ~2x the group size. Users can set spark.python.udf.pipelined.queueDepth=1 to reduce this, or disable pipelined mode entirely for memory-sensitive workloads.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.