viirya commented on code in PR #55552:
URL: https://github.com/apache/spark/pull/55552#discussion_r3158074021
##########
core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala:
##########
@@ -121,6 +121,15 @@ private[spark] object PythonEvalType {
private[spark] object BasePythonRunner extends Logging {
+ /**
+ * Shared thread pool for pipelined writer tasks. Using a cached thread pool ensures that
+ * writer threads are reused across tasks, which keeps JIT-compiled code, branch prediction
+ * history, and CPU caches warm. This avoids the 2-3x serialization slowdown observed when
+ * using freshly created threads.
+ */
+ private[python] lazy val pipelinedWriterThreadPool =
+ ThreadUtils.newDaemonCachedThreadPool("python-udf-pipelined-writer")
Review Comment:
Good point. Changed from the unbounded newDaemonCachedThreadPool to the bounded
overload, newDaemonCachedThreadPool("python-udf-pipelined-writer", maxThreads),
where maxThreads = SparkEnv.get.conf.get(EXECUTOR_CORES). Each task uses at most
one writer thread, and the number of concurrent tasks is bounded by executor
cores, so this is the natural upper bound.
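To illustrate the bound, here is a minimal Python sketch rather than the PR's Scala code. It assumes one writer job per task and a pool capped at the core count. `EXECUTOR_CORES = 4`, `run_tasks`, and `writer_job` are hypothetical names chosen for this example; they are not part of the PR or of Spark.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

# Assumed value standing in for spark.executor.cores (hypothetical for this sketch).
EXECUTOR_CORES = 4


def run_tasks(num_tasks, max_threads):
    """Run num_tasks writer jobs on a pool capped at max_threads.

    Returns the job results and the set of distinct thread names that ran them.
    """
    names = set()
    lock = threading.Lock()

    def writer_job(task_id):
        # Record which pool thread executed this job.
        with lock:
            names.add(threading.current_thread().name)
        return task_id

    with ThreadPoolExecutor(
        max_workers=max_threads, thread_name_prefix="pipelined-writer"
    ) as pool:
        results = list(pool.map(writer_job, range(num_tasks)))
    return results, names


# Many more tasks than cores: threads are reused, never more than the cap.
results, thread_names = run_tasks(num_tasks=32, max_threads=EXECUTOR_CORES)
print(len(thread_names), "distinct writer threads for", len(results), "tasks")
```

However many tasks are submitted, the number of distinct writer threads stays at or below the cap, which is the property the bounded pool is meant to guarantee.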
##########
core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala:
##########
@@ -985,6 +1107,103 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
}
}
+ /**
+ * A dedicated thread that serializes input data and writes it directly to the Python worker
+ * socket in blocking mode. The task main thread simultaneously reads output from the same
+ * socket. TCP sockets are full-duplex, so concurrent read() and write() from different
+ * threads is safe -- they operate on independent OS-level buffers.
+ *
+ * This design achieves true pipeline parallelism without any inter-thread queues or locks:
+ * Writer Thread: serialize batch N -> channel.write(batch N) [blocking]
+ * Reader Thread: channel.read(output N-1) [blocking]
+ * Python: read batch N-1 -> compute -> write output -> read batch N
+ *
+ * Deadlock safety: Python's UDF loop is "read input -> compute -> write output -> repeat".
+ * As long as the reader thread is consuming Python's output (freeing Python's send buffer),
+ * Python will eventually consume input from the socket (freeing the JVM's send buffer for
+ * the writer thread). The reader thread is always actively reading because the task's
+ * downstream operators pull output on demand.
+ *
+ * Unlike the old WriterThread (removed in SPARK-44705), this design uses a blocking socket
+ * in full-duplex mode rather than two threads competing on the same blocking socket with
+ * shared mutable state. The old design's deadlocks were caused by complex interactions
+ * with vectorized readers and monitor threads, not by the fundamental read/write split.
+ */
+ class PipelinedWriterRunnable(
+ worker: PythonWorker,
+ writer: Writer,
+ bufferSize: Int,
+ context: TaskContext)
+ extends Runnable {
+
+ // Capture InputFileBlockHolder from the task thread so we can propagate it
+ // to the writer pool thread. This is needed because upstream scan operators
+ // set InputFileBlockHolder via InheritableThreadLocal, but pool threads
+ // don't inherit from the task thread.
+ private val parentInputFileBlockHolder = InputFileBlockHolder.getThreadLocalValue()
+
+ override def run(): Unit = {
+ // Propagate TaskContext and InputFileBlockHolder to the pool thread so that
+ // upstream operators work correctly.
+ TaskContext.setTaskContext(context)
+ InputFileBlockHolder.setThreadLocalValue(parentInputFileBlockHolder)
+ val bufferStream = new DirectByteBufferOutputStream(bufferSize)
+ val dataOut = new DataOutputStream(bufferStream)
+ try {
+ // Write command/metadata (partition index, task context, broadcasts, UDF definition).
+ writer.open(dataOut)
+ flushToSocket(bufferStream)
+
+ // Write input data in a loop, batching into buffers of ~bufferSize.
+ var hasInput = true
+ while (hasInput && !Thread.currentThread().isInterrupted) {
+ hasInput = writer.writeNextInputToStream(dataOut)
+ if (bufferStream.size() >= bufferSize || !hasInput) {
+ if (!hasInput) {
+ writer.close(dataOut)
+ }
+ flushToSocket(bufferStream)
+ }
+ }
+ } catch {
+ case _: InterruptedException =>
+ Thread.currentThread().interrupt()
+ case t: Throwable if NonFatal(t) || t.isInstanceOf[Exception] =>
Review Comment:
The list(batch) materialization happens for grouped/aggregate UDF serializers
whose load_stream() yields lazy iterators. In sync mode, the same data is
already materialized in mapper(batch_iter) via list(batch_iter), so per-group
peak memory is the same.
The difference is that with queueDepth=2 (the default), up to 2 additional
groups can be buffered in the queue. In the worst case (a skewed key with one
very large group), this could increase peak memory by roughly 2x the group
size. Users can set spark.python.udf.pipelined.queueDepth=1 to reduce this, or
disable pipelined mode entirely for memory-sensitive workloads.
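The buffering effect described above can be sketched with a toy model; this is not the PR's worker code. A producer thread materializes each lazy group with list() and hands it to the consumer through a bounded queue. `run_pipelined`, `QUEUE_DEPTH`, and the synthetic groups are hypothetical names for this example. In this model the number of materialized groups alive at once is at most queue_depth + 2: the queued groups, the one being consumed, and the one the producer holds while blocked on put(). The PR's exact accounting may differ.

```python
import queue
import threading

# Assumed to mirror the default queueDepth mentioned above (hypothetical here).
QUEUE_DEPTH = 2


def run_pipelined(lazy_groups, queue_depth):
    """Materialize lazy groups on a producer thread and consume them here.

    Returns (per-group sums, peak count of materialized groups alive at once).
    """
    q = queue.Queue(maxsize=queue_depth)
    stats = {"live": 0, "peak": 0}
    lock = threading.Lock()
    done = object()  # sentinel marking end of input

    def producer():
        for group in lazy_groups:
            rows = list(group)  # materialize the lazy iterator
            with lock:
                stats["live"] += 1
                stats["peak"] = max(stats["peak"], stats["live"])
            q.put(rows)  # blocks while queue_depth groups are already waiting
        q.put(done)

    t = threading.Thread(target=producer, daemon=True)
    t.start()

    results = []
    while True:
        rows = q.get()
        if rows is done:
            break
        results.append(sum(rows))  # stand-in for the UDF computation
        with lock:
            stats["live"] -= 1  # the group can be released after processing
    t.join()
    return results, stats["peak"]


# 20 synthetic groups of 10 rows each, produced lazily.
lazy_groups = (iter(range(i * 10, (i + 1) * 10)) for i in range(20))
results, peak = run_pipelined(lazy_groups, QUEUE_DEPTH)
print("peak materialized groups:", peak)
```

Lowering `queue_depth` to 1 tightens the bound in this model in the same way the queueDepth setting is described as doing, at the cost of less overlap between producer and consumer.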
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]