This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8aaff558394 [SPARK-44705][PYTHON] Make PythonRunner single-threaded
8aaff558394 is described below

commit 8aaff55839493e80e3ce376f928c04aa8f31d18c
Author: Utkarsh <utkarsh.agar...@databricks.com>
AuthorDate: Fri Aug 11 10:34:05 2023 +0900

    [SPARK-44705][PYTHON] Make PythonRunner single-threaded
    
    ### What changes were proposed in this pull request?
    PythonRunner, a utility that executes Python UDFs in Spark, currently uses
    two threads in a producer-consumer model. This multi-threading model is
    problematic and confusing, as Spark's execution model within a task is
    commonly understood to be single-threaded.
    More importantly, this departure into double-threaded execution has resulted
    in a series of customer issues involving
    [race conditions](https://issues.apache.org/jira/browse/SPARK-33277) and
    [deadlocks](https://issues.apache.org/jira/browse/SPARK-38677) between the
    threads, because the code was hard to reason about. There have been multiple
    attempts to rein in these issues, viz.,
    [fix 1](https://issues.apache.org/jira/browse/SPARK-22535),
    [fix 2](https://github.com/apache/spark/pull/30177), [fix 3 [...]
    
    #### Current Execution Model in Spark for Python UDFs
    For queries containing Python UDFs, the main Java task thread spins up a new
    writer thread to pipe data from the child Spark plan into the Python worker
    evaluating the UDF. The writer thread runs in a tight loop: it evaluates the
    child Spark plan and feeds the resulting output to the Python worker. The
    main task thread simultaneously consumes the Python UDF’s output and
    evaluates the parent Spark plan to produce the final result.
    The I/O to and from the Python worker uses blocking Java sockets,
    necessitating two threads: one responsible for input to the Python worker
    and the other for its output. With only one thread, it is easy to run into a
    deadlock. For example, the task thread can block forever waiting for output
    from the Python worker; that output will never arrive until input is
    supplied to the Python worker, which cannot happen because the task thread
    is blocked waiting on the output.
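    Below is a minimal, hypothetical sketch of that failure mode (not Spark
    code; the worker address and the serialized `rows` iterator are illustrative
    assumptions). It shows why a single thread over blocking sockets deadlocks
    once the socket buffers fill up:
    
    ```
    import java.net.Socket
    
    // A single thread writes all input before reading any output. Once the
    // worker's receive buffer and this side's send buffer are both full,
    // write() blocks forever: the worker cannot drain its own output because
    // nobody here is reading it yet.
    def feedThenRead(workerHost: String, workerPort: Int, rows: Iterator[Array[Byte]]): Int = {
      val socket = new Socket(workerHost, workerPort)
      val out = socket.getOutputStream
      val in = socket.getInputStream
      rows.foreach(row => out.write(row)) // may block indefinitely once buffers fill up
      out.flush()
      in.read()                           // never reached in the deadlock scenario
    }
    ```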
    
    #### Proposed Fix
    
    The proposed fix is to move to the standard single-threaded execution model
    within a task, i.e., to do away with the writer thread. In addition to
    mitigating the crashes, the fix reduces the complexity of the existing code
    by removing the many safety checks that were put in place to track down
    deadlocks in the double-threaded execution model.
    
    In the new model, the main task thread alternates between feeding input to
    and consuming output from the Python worker, using non-blocking I/O through
    Java’s
    [SocketChannel](https://docs.oracle.com/javase/7/docs/api/java/nio/channels/SocketChannel.html).
    See the `read()` method in the code below for approximately how this is
    achieved.
    
    ```
    class PythonUDFRunner {
    
      private var nextRow: Row = _
      private var endOfStream = false
      private var childHasNext = true
      private var buffer: ByteBuffer = _
    
      def hasNext(): Boolean = nextRow != null || {
        if (!endOfStream) {
          read(buffer)
          nextRow = deserialize(buffer)
          hasNext()
        } else {
          false
        }
      }
    
      def next(): Row = {
        if (hasNext()) {
          val outputRow = nextRow
          nextRow = null
          outputRow
        } else {
          null
        }
      }
    
      def read(buf: ByteBuffer): Unit = {
        var n = 0
        while (n == 0) {
          // Alternate between reading from and writing to the Python worker
          // using non-blocking I/O.
          if (pythonWorker.isReadable) {
            n = pythonWorker.read(buf)
          }
          if (pythonWorker.isWritable) {
            consumeChildPlanAndWriteDataToPythonWorker()
          }
        }
      }
    
      def consumeChildPlanAndWriteDataToPythonWorker(): Unit = {
        // Tracks whether the connection to the Python worker can be written to.
        var socketAcceptsInput = true
        while (socketAcceptsInput && (childHasNext || buffer.hasRemaining)) {
          if (!buffer.hasRemaining && childHasNext) {
            // Consume data from the child and buffer it.
            writeToBuffer(childPlan.next(), buffer)
            childHasNext = childPlan.hasNext()
            if (!childHasNext) {
              // Exhausted the child plan’s output. Write a keyword to the
              // Python worker signaling the end of data input.
              writeToBuffer(endOfStream)
            }
          }
          // Try to write as much buffered data as possible to the Python worker.
          while (buffer.hasRemaining && socketAcceptsInput) {
            val n = writeToPythonWorker(buffer)
            // `writeToPythonWorker()` returns 0 when the socket cannot accept
            // more data right now.
            socketAcceptsInput = n > 0
          }
        }
      }
    }
    
    ```
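    For reference, here is a simplified, self-contained sketch of the same
    single-threaded multiplexing using `java.nio` primitives. It mirrors the
    `ReaderInputStream.read()` loop added to `PythonRunner.scala` in this patch,
    but the method name and the `writeMoreInput` callback (standing in for
    consuming the child plan) are illustrative, not the actual API:
    
    ```
    import java.nio.ByteBuffer
    import java.nio.channels.{SelectionKey, Selector, SocketChannel}
    
    // Assumes `channel` is configured non-blocking and registered with
    // `selector` for SelectionKey.OP_READ | SelectionKey.OP_WRITE.
    // One thread multiplexes reads and writes on the single channel: select()
    // blocks until the channel is readable and/or writable, so the loop never
    // spins while neither side can make progress.
    def readFromWorker(
        channel: SocketChannel,
        selector: Selector,
        key: SelectionKey,
        readBuf: ByteBuffer)(writeMoreInput: () => Unit): Int = {
      var n = 0
      while (n == 0) {
        selector.select()
        if (key.isReadable) {
          n = channel.read(readBuf) // drain output produced by the Python worker
        }
        if (key.isWritable) {
          writeMoreInput()          // feed more buffered input to the worker
        }
      }
      n
    }
    ```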
    ### Why are the changes needed?
    This PR makes PythonRunner single-threaded, making it easier to reason about
    and improving code health.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing tests.
    
    Closes #42385 from utkarsh39/SPARK-44705.
    
    Authored-by: Utkarsh <utkarsh.agar...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../org/apache/spark/ContextAwareIterator.scala    |   2 +
 .../src/main/scala/org/apache/spark/SparkEnv.scala |  15 +-
 .../org/apache/spark/api/python/PythonRDD.scala    |  22 +-
 .../org/apache/spark/api/python/PythonRunner.scala | 362 ++++++++++++++-------
 .../spark/api/python/PythonWorkerFactory.scala     | 105 +++---
 .../spark/api/python/PythonWorkerUtils.scala       |   6 +-
 .../spark/api/python/StreamingPythonRunner.scala   |  15 +-
 .../apache/spark/rdd/InputFileBlockHolder.scala    |  11 +
 .../spark/util/DirectByteBufferOutputStream.scala  |  85 +++++
 .../ApplyInPandasWithStatePythonRunner.scala       |  34 +-
 .../sql/execution/python/ArrowPythonRunner.scala   |   8 +-
 .../execution/python/BatchEvalPythonUDTFExec.scala |  11 +-
 .../python/CoGroupedArrowPythonRunner.scala        |  20 +-
 .../python/EvalPythonEvaluatorFactory.scala        |   5 +-
 .../sql/execution/python/EvaluatePython.scala      |   3 +-
 .../python/MapInBatchEvaluatorFactory.scala        |   5 +-
 .../sql/execution/python/PythonArrowInput.scala    |  81 +++--
 .../sql/execution/python/PythonArrowOutput.scala   |  13 +-
 .../sql/execution/python/PythonForeachWriter.scala |  69 +++-
 .../sql/execution/python/PythonUDFRunner.scala     |  37 ++-
 .../python/UserDefinedPythonFunction.scala         |  63 +++-
 21 files changed, 666 insertions(+), 306 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/ContextAwareIterator.scala 
b/core/src/main/scala/org/apache/spark/ContextAwareIterator.scala
index 84ae93f1788..facb03365e8 100644
--- a/core/src/main/scala/org/apache/spark/ContextAwareIterator.scala
+++ b/core/src/main/scala/org/apache/spark/ContextAwareIterator.scala
@@ -30,8 +30,10 @@ import org.apache.spark.annotation.DeveloperApi
  * Thus, we should use [[ContextAwareIterator]] to stop consuming after the 
task ends.
  *
  * @since 3.1.0
+ * @deprecated since 4.0.0 as its only usage for Python evaluation is now 
extinct
  */
 @DeveloperApi
+@deprecated("Only usage for Python evaluation is now extinct", "3.5.0")
 class ContextAwareIterator[+T](val context: TaskContext, val delegate: 
Iterator[T])
   extends Iterator[T] {
 
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala 
b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index eef99c26e77..e404c9ee8b4 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -18,7 +18,6 @@
 package org.apache.spark
 
 import java.io.File
-import java.net.Socket
 import java.util.Locale
 
 import scala.collection.JavaConverters._
@@ -30,7 +29,7 @@ import com.google.common.cache.CacheBuilder
 import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.api.python.PythonWorkerFactory
+import org.apache.spark.api.python.{PythonWorker, PythonWorkerFactory}
 import org.apache.spark.broadcast.BroadcastManager
 import org.apache.spark.executor.ExecutorBackend
 import org.apache.spark.internal.{config, Logging}
@@ -129,7 +128,7 @@ class SparkEnv (
       pythonExec: String,
       workerModule: String,
       daemonModule: String,
-      envVars: Map[String, String]): (java.net.Socket, Option[Int]) = {
+      envVars: Map[String, String]): (PythonWorker, Option[Int]) = {
     synchronized {
       val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, 
envVars)
       pythonWorkers.getOrElseUpdate(key,
@@ -140,7 +139,7 @@ class SparkEnv (
   private[spark] def createPythonWorker(
       pythonExec: String,
       workerModule: String,
-      envVars: Map[String, String]): (java.net.Socket, Option[Int]) = {
+      envVars: Map[String, String]): (PythonWorker, Option[Int]) = {
     createPythonWorker(
       pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, 
envVars)
   }
@@ -150,7 +149,7 @@ class SparkEnv (
       workerModule: String,
       daemonModule: String,
       envVars: Map[String, String],
-      worker: Socket): Unit = {
+      worker: PythonWorker): Unit = {
     synchronized {
       val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, 
envVars)
       pythonWorkers.get(key).foreach(_.stopWorker(worker))
@@ -161,7 +160,7 @@ class SparkEnv (
       pythonExec: String,
       workerModule: String,
       envVars: Map[String, String],
-      worker: Socket): Unit = {
+      worker: PythonWorker): Unit = {
     destroyPythonWorker(
       pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, 
envVars, worker)
   }
@@ -171,7 +170,7 @@ class SparkEnv (
       workerModule: String,
       daemonModule: String,
       envVars: Map[String, String],
-      worker: Socket): Unit = {
+      worker: PythonWorker): Unit = {
     synchronized {
       val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, 
envVars)
       pythonWorkers.get(key).foreach(_.releaseWorker(worker))
@@ -182,7 +181,7 @@ class SparkEnv (
       pythonExec: String,
       workerModule: String,
       envVars: Map[String, String],
-      worker: Socket): Unit = {
+      worker: PythonWorker): Unit = {
     releasePythonWorker(
       pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, 
envVars, worker)
   }
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 91fd92d4422..a2f2d566db5 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -137,7 +137,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends 
RDD[(Long, Array[Byte]
 private[spark] object PythonRDD extends Logging {
 
   // remember the broadcasts sent to each worker
-  private val workerBroadcasts = new mutable.WeakHashMap[Socket, 
mutable.Set[Long]]()
+  private val workerBroadcasts = new mutable.WeakHashMap[PythonWorker, 
mutable.Set[Long]]()
 
   // Authentication helper used when serving iterator data.
   private lazy val authHelper = {
@@ -145,7 +145,7 @@ private[spark] object PythonRDD extends Logging {
     new SocketAuthHelper(conf)
   }
 
-  def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = {
+  def getWorkerBroadcasts(worker: PythonWorker): mutable.Set[Long] = {
     synchronized {
       workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
     }
@@ -300,7 +300,11 @@ private[spark] object PythonRDD extends Logging {
     new PythonBroadcast(path)
   }
 
-  def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream): 
Unit = {
+  /**
+   * Writes the next element of the iterator `iter` to `dataOut`. Returns true 
if any data was
+   * written to the stream. Returns false if no data was written as the 
iterator has been exhausted.
+   */
+  def writeNextElementToStream[T](iter: Iterator[T], dataOut: 
DataOutputStream): Boolean = {
 
     def write(obj: Any): Unit = obj match {
       case null =>
@@ -318,8 +322,18 @@ private[spark] object PythonRDD extends Logging {
       case other =>
         throw new SparkException("Unexpected element type " + other.getClass)
     }
+    if (iter.hasNext) {
+      write(iter.next())
+      true
+    } else {
+      false
+    }
+  }
 
-    iter.foreach(write)
+  def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream): 
Unit = {
+    while (writeNextElementToStream(iter, dataOut)) {
+      // Nothing.
+    }
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 0173de75ff2..d7801d2e83b 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -19,6 +19,8 @@ package org.apache.spark.api.python
 
 import java.io._
 import java.net._
+import java.nio.ByteBuffer
+import java.nio.channels.SelectionKey
 import java.nio.charset.StandardCharsets
 import java.nio.charset.StandardCharsets.UTF_8
 import java.nio.file.{Files => JavaFiles, Path}
@@ -32,6 +34,7 @@ import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
 import org.apache.spark.internal.config.Python._
+import org.apache.spark.rdd.InputFileBlockHolder
 import 
org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, 
PYSPARK_MEMORY_LOCAL_PROPERTY}
 import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.util._
@@ -103,6 +106,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
 
   private val conf = SparkEnv.get.conf
   protected val bufferSize: Int = conf.get(BUFFER_SIZE)
+  protected val timelyFlushEnabled: Boolean = false
+  protected val timelyFlushTimeoutNanos: Long = 0
   protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
   private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
   private val faultHandlerEnabled = 
conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED)
@@ -143,7 +148,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
   // Python accumulator is always set in production except in tests. See 
SPARK-27893
   private val maybeAccumulator: Option[PythonAccumulatorV2] = 
Option(accumulator)
 
-  // Expose a ServerSocket to support method calls via socket from Python side.
+  // Expose a ServerSocket to support method calls via socket from Python 
side. Only relevant for
+  // for tasks that are a part of barrier stage, refer [[BarrierTaskContext]] 
for details.
   private[spark] var serverSocket: Option[ServerSocket] = None
 
   // Authentication helper used when serving method calls via socket from 
Python side.
@@ -194,7 +200,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
 
     envVars.put("SPARK_JOB_ARTIFACT_UUID", 
jobArtifactUUID.getOrElse("default"))
 
-    val (worker: Socket, pid: Option[Int]) = env.createPythonWorker(
+    val (worker: PythonWorker, pid: Option[Int]) = env.createPythonWorker(
       pythonExec, workerModule, daemonModule, envVars.asScala.toMap)
     // Whether is the worker released into idle pool or closed. When any codes 
try to release or
     // close a worker, they should use `releasedOrClosed.compareAndSet` to 
flip the state to make
@@ -202,22 +208,19 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     val releasedOrClosed = new AtomicBoolean(false)
 
     // Start a thread to feed the process input from our parent's iterator
-    val writerThread = newWriterThread(env, worker, inputIterator, 
partitionIndex, context)
+    val writer = newWriter(env, worker, inputIterator, partitionIndex, context)
 
     context.addTaskCompletionListener[Unit] { _ =>
-      writerThread.shutdownOnTaskCompletion()
       if (!reuseWorker || releasedOrClosed.compareAndSet(false, true)) {
         try {
-          worker.close()
+          worker.stop()
         } catch {
           case e: Exception =>
-            logWarning("Failed to close worker socket", e)
+            logWarning("Failed to stop worker")
         }
       }
     }
 
-    writerThread.start()
-    new WriterMonitorThread(SparkEnv.get, worker, writerThread, 
context).start()
     if (reuseWorker) {
       val key = (worker, context.taskAttemptId)
       // SPARK-35009: avoid creating multiple monitor threads for the same 
python worker
@@ -230,68 +233,49 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     }
 
     // Return an iterator that read lines from the process's stdout
-    val stream = new DataInputStream(new 
BufferedInputStream(worker.getInputStream, bufferSize))
-
+    val dataIn = new DataInputStream(
+      new BufferedInputStream(new ReaderInputStream(worker, writer), 
bufferSize))
     val stdoutIterator = newReaderIterator(
-      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, 
context)
+      dataIn, writer, startTime, env, worker, pid, releasedOrClosed, context)
     new InterruptibleIterator(context, stdoutIterator)
   }
 
-  protected def newWriterThread(
+  protected def newWriter(
       env: SparkEnv,
-      worker: Socket,
+      worker: PythonWorker,
       inputIterator: Iterator[IN],
       partitionIndex: Int,
-      context: TaskContext): WriterThread
+      context: TaskContext): Writer
 
   protected def newReaderIterator(
       stream: DataInputStream,
-      writerThread: WriterThread,
+      writer: Writer,
       startTime: Long,
       env: SparkEnv,
-      worker: Socket,
+      worker: PythonWorker,
       pid: Option[Int],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[OUT]
 
   /**
-   * The thread responsible for writing the data from the PythonRDD's parent 
iterator to the
+   * Responsible for writing the data from the PythonRDD's parent iterator to 
the
    * Python process.
    */
-  abstract class WriterThread(
+  abstract class Writer(
       env: SparkEnv,
-      worker: Socket,
+      worker: PythonWorker,
       inputIterator: Iterator[IN],
       partitionIndex: Int,
-      context: TaskContext)
-    extends Thread(s"stdout writer for $pythonExec") {
+      context: TaskContext) {
 
-    @volatile private var _exception: Throwable = null
+    @volatile private var _exception: Throwable = _
 
     private val pythonIncludes = 
funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
     private val broadcastVars = 
funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
 
-    setDaemon(true)
-
     /** Contains the throwable thrown while writing the parent iterator to the 
Python process. */
     def exception: Option[Throwable] = Option(_exception)
 
-    /**
-     * Terminates the writer thread and waits for it to exit, ignoring any 
exceptions that may occur
-     * due to cleanup.
-     */
-    def shutdownOnTaskCompletion(): Unit = {
-      assert(context.isCompleted)
-      this.interrupt()
-      // Task completion listeners that run after this method returns may 
invalidate
-      // `inputIterator`. For example, when `inputIterator` was generated by 
the off-heap vectorized
-      // reader, a task completion listener will free the underlying off-heap 
buffers. If the writer
-      // thread is still running when `inputIterator` is invalidated, it can 
cause a use-after-free
-      // bug that crashes the executor (SPARK-33277). Therefore this method 
must wait for the writer
-      // thread to exit before returning.
-      this.join()
-    }
-
     /**
      * Writes a command section to the stream connected to the Python worker.
      */
@@ -299,14 +283,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
 
     /**
      * Writes input data to the stream connected to the Python worker.
+     * Returns true if any data was written to the stream, false if the input 
is exhausted.
      */
-    protected def writeIteratorToStream(dataOut: DataOutputStream): Unit
+    def writeNextInputToStream(dataOut: DataOutputStream): Boolean
 
-    override def run(): Unit = Utils.logUncaughtExceptions {
+    def open(dataOut: DataOutputStream): Unit = Utils.logUncaughtExceptions {
       try {
-        TaskContext.setTaskContext(context)
-        val stream = new BufferedOutputStream(worker.getOutputStream, 
bufferSize)
-        val dataOut = new DataOutputStream(stream)
         // Partition index
         dataOut.writeInt(partitionIndex)
 
@@ -367,21 +349,25 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         } else {
           ""
         }
-        // Close ServerSocket on task completion.
-        serverSocket.foreach { server =>
-          context.addTaskCompletionListener[Unit](_ => server.close())
-        }
-        val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0)
-        if (boundPort == -1) {
-          val message = "ServerSocket failed to bind to Java side."
-          logError(message)
-          throw new SparkException(message)
-        } else if (isBarrier) {
+        if (isBarrier) {
+          // Close ServerSocket on task completion.
+          serverSocket.foreach { server =>
+            context.addTaskCompletionListener[Unit](_ => server.close())
+          }
+          val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0)
+          if (boundPort == -1) {
+            val message = "ServerSocket failed to bind to Java side."
+            logError(message)
+            throw new SparkException(message)
+          }
           logDebug(s"Started ServerSocket on port $boundPort.")
+          dataOut.writeBoolean(/* isBarrier = */true)
+          dataOut.writeInt(boundPort)
+        } else {
+          dataOut.writeBoolean(/* isBarrier = */false)
+          dataOut.writeInt(0)
         }
         // Write out the TaskContextInfo
-        dataOut.writeBoolean(isBarrier)
-        dataOut.writeInt(boundPort)
         val secretBytes = secret.getBytes(UTF_8)
         dataOut.writeInt(secretBytes.length)
         dataOut.write(secretBytes, 0, secretBytes.length)
@@ -412,30 +398,33 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
 
         dataOut.writeInt(evalType)
         writeCommand(dataOut)
-        writeIteratorToStream(dataOut)
 
-        dataOut.writeInt(SpecialLengths.END_OF_STREAM)
         dataOut.flush()
       } catch {
-        case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) =>
+        case t: Throwable if NonFatal(t) || t.isInstanceOf[Exception] =>
           if (context.isCompleted || context.isInterrupted) {
             logDebug("Exception/NonFatal Error thrown after task completion 
(likely due to " +
               "cleanup)", t)
-            if (!worker.isClosed) {
-              Utils.tryLog(worker.shutdownOutput())
+            if (worker.channel.isConnected) {
+              Utils.tryLog(worker.channel.shutdownOutput())
             }
           } else {
             // We must avoid throwing exceptions/NonFatals here, because the 
thread uncaught
             // exception handler will kill the whole executor (see
             // org.apache.spark.executor.Executor).
             _exception = t
-            if (!worker.isClosed) {
-              Utils.tryLog(worker.shutdownOutput())
+            if (worker.channel.isConnected) {
+              Utils.tryLog(worker.channel.shutdownOutput())
             }
           }
       }
     }
 
+    def close(dataOut: DataOutputStream): Unit = {
+      dataOut.writeInt(SpecialLengths.END_OF_STREAM)
+      dataOut.flush()
+    }
+
     /**
      * Gateway to call BarrierTaskContext methods.
      */
@@ -470,10 +459,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
 
   abstract class ReaderIterator(
       stream: DataInputStream,
-      writerThread: WriterThread,
+      writer: Writer,
       startTime: Long,
       env: SparkEnv,
-      worker: Socket,
+      worker: PythonWorker,
       pid: Option[Int],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext)
@@ -531,7 +520,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       val obj = new Array[Byte](exLength)
       stream.readFully(obj)
       new PythonException(new String(obj, StandardCharsets.UTF_8),
-        writerThread.exception.orNull)
+        writer.exception.orNull)
     }
 
     protected def handleEndOfDataSection(): Unit = {
@@ -554,10 +543,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         logDebug("Exception thrown after task interruption", e)
         throw new 
TaskKilledException(context.getKillReason().getOrElse("unknown reason"))
 
-      case e: Exception if writerThread.exception.isDefined =>
+      case e: Exception if writer.exception.isDefined =>
         logError("Python worker exited unexpectedly (crashed)", e)
-        logError("This may have been caused by a prior exception:", 
writerThread.exception.get)
-        throw writerThread.exception.get
+        logError("This may have been caused by a prior exception:", 
writer.exception.get)
+        throw writer.exception.get
 
       case eof: EOFException if faultHandlerEnabled && pid.isDefined &&
           JavaFiles.exists(BasePythonRunner.faultHandlerLogPath(pid.get)) =>
@@ -576,7 +565,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
    * interrupts disabled. In that case we will need to explicitly kill the 
worker, otherwise the
    * threads can block indefinitely.
    */
-  class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
+  class MonitorThread(env: SparkEnv, worker: PythonWorker, context: 
TaskContext)
     extends Thread(s"Worker Monitor for $pythonExec") {
 
     /** How long to wait before killing the python worker if a task cannot be 
interrupted. */
@@ -620,60 +609,185 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     }
   }
 
-  /**
-   * This thread monitors the WriterThread and kills it in case of deadlock.
-   *
-   * A deadlock can arise if the task completes while the writer thread is 
sending input to the
-   * Python process (e.g. due to the use of `take()`), and the Python process 
is still producing
-   * output. When the inputs are sufficiently large, this can result in a 
deadlock due to the use of
-   * blocking I/O (SPARK-38677). To resolve the deadlock, we need to close the 
socket.
-   */
-  class WriterMonitorThread(
-      env: SparkEnv, worker: Socket, writerThread: WriterThread, context: 
TaskContext)
-    extends Thread(s"Writer Monitor for $pythonExec (writer thread id 
${writerThread.getId})") {
-
+  class ReaderInputStream(worker: PythonWorker, writer: Writer) extends 
InputStream {
+    private[this] var writerIfbhThreadLocalValue: Object = null
+    private[this] val temp = new Array[Byte](1)
+    private[this] val bufferStream = new DirectByteBufferOutputStream()
     /**
-     * How long to wait before closing the socket if the writer thread has not 
exited after the task
-     * ends.
+     * Buffers data to be written to the Python worker until the socket is
+     * available for write.
+     * A best-effort attempt is made to not grow the buffer beyond 
"spark.buffer.size". See
+     * `writeAdditionalInputToPythonWorker()` for details.
      */
-    private val taskKillTimeout = env.conf.get(PYTHON_TASK_KILL_TIMEOUT)
+    private[this] var buffer: ByteBuffer = _
+    private[this] var hasInput = true
 
-    setDaemon(true)
+    writer.open(new DataOutputStream(bufferStream))
+    buffer = bufferStream.toByteBuffer
 
-    override def run(): Unit = {
-      // Wait until the task is completed (or the writer thread exits, in 
which case this thread has
-      // nothing to do).
-      while (!context.isCompleted && writerThread.isAlive) {
-        Thread.sleep(2000)
+    override def read(): Int = {
+      val n = read(temp)
+      if (n <= 0) {
+        -1
+      } else {
+        // Signed byte to unsigned integer
+        temp(0) & 0xff
       }
-      if (writerThread.isAlive) {
-        Thread.sleep(taskKillTimeout)
-        // If the writer thread continues running, this indicates a deadlock. 
Kill the worker to
-        // resolve the deadlock.
-        if (writerThread.isAlive) {
-          try {
-            // Mimic the task name used in `Executor` to help the user find 
out the task to blame.
-            val taskName = s"${context.partitionId}.${context.attemptNumber} " 
+
-              s"in stage ${context.stageId} (TID ${context.taskAttemptId})"
-            logWarning(
-              s"Detected deadlock while completing task $taskName: " +
-                "Attempting to kill Python Worker")
-            env.destroyPythonWorker(
-              pythonExec, workerModule, daemonModule, envVars.asScala.toMap, 
worker)
-          } catch {
-            case e: Exception =>
-              logError("Exception when trying to kill worker", e)
+    }
+
+    override def read(b: Array[Byte], off: Int, len: Int): Int = {
+      // The code below manipulates the InputFileBlockHolder thread local in 
order
+      // to prevent behavior changes in the input_file_name() expression due 
to the switch from
+      // multi-threaded to single-threaded Python execution (SPARK-44705).
+      //
+      // Prior to that change, scan operations feeding into PythonRunner would 
be evaluated in
+      // "writer" threads that were child threads of the main task thread. As 
a result, when
+      // a scan operation hit end-of-input and called 
InputFileBlockHolder.unset(), the effects
+      // of unset() would only occur in the writer thread and not the main 
task thread: this
+      // meant that code "downstream" of a PythonRunner would continue to 
observe the writer's
+      // last pre-unset() value (i.e. the last read filename).
+      //
+      // Switching to a single-threaded Python runner changed this behavior: 
now, unset() would
+      // impact operators both upstream and downstream of the PythonRunner and 
this would cause
+      // unset()'s effects to be immediately visible to downstream operators, 
in turn causing the
+      // input_file_name() expression to return empty filenames in situations 
where it previously
+      // would have returned the last non-empty filename.
+      //
+      // To avoid this behavior change, the code below simulates the behavior 
of the
+      // InputFileBlockHolder's inheritable thread local:
+      //
+      //  - Detect whether code that previously would have run in the writer 
thread has changed
+      //    the thread local value itself. Note that the thread local holds a 
mutable
+      //    AtomicReference, so the thread local's value only changes objects 
when unset() is
+      //    called.
+      //  - If an object change was detected, then henceforth we will swap 
between the "main"
+      //    and "writer" thread local values when context switching between 
upstream and
+      //    downstream operator execution.
+      //
+      // This issue is subtle and several other alternative approaches were 
considered
+      val buf = ByteBuffer.wrap(b, off, len)
+      var n = 0
+      while (n == 0) {
+        worker.selector.select()
+        if (worker.selectionKey.isReadable) {
+          n = worker.channel.read(buf)
+        }
+        if (worker.selectionKey.isWritable) {
+          val mainIfbhThreadLocalValue = 
InputFileBlockHolder.getThreadLocalValue()
+          // Check whether the writer's thread local value has diverged from 
its parent's value:
+          if (writerIfbhThreadLocalValue eq null) {
+            // Default case (which is why it appears first): the writer's 
thread local value
+            // is the same object as the main code, so no need to swap before 
executing the
+            // writer code.
+            try {
+              // Execute the writer code:
+              writeAdditionalInputToPythonWorker()
+            } finally {
+              // Check whether the writer code changed the thread local value:
+              val maybeNewIfbh = InputFileBlockHolder.getThreadLocalValue()
+              if (maybeNewIfbh ne mainIfbhThreadLocalValue) {
+                // The writer thread change the thread local, so henceforth we 
need to
+                // swap. Store the writer thread's value and restore the old 
main thread
+                // value:
+                writerIfbhThreadLocalValue = maybeNewIfbh
+                
InputFileBlockHolder.setThreadLocalValue(mainIfbhThreadLocalValue)
+              }
+            }
+          } else {
+            // The writer thread and parent thread have different values, so 
we must swap
+            // them when switching between writer and parent code:
+            try {
+              // Swap in the writer value:
+              
InputFileBlockHolder.setThreadLocalValue(writerIfbhThreadLocalValue)
+              try {
+                // Execute the writer code:
+                writeAdditionalInputToPythonWorker()
+              } finally {
+                // Store an updated writer thread value:
+                writerIfbhThreadLocalValue = 
InputFileBlockHolder.getThreadLocalValue()
+              }
+            } finally {
+              // Restore the main thread's value:
+              
InputFileBlockHolder.setThreadLocalValue(mainIfbhThreadLocalValue)
+            }
           }
         }
       }
+      n
+    }
+
+    private var lastFlushTime = System.nanoTime()
+
+    /**
+     * Returns false if `timelyFlushEnabled` is disabled.
+     *
+     * Otherwise, returns true if `buffer` should be flushed before any 
additional data is
+     * written to it.
+     * For small input rows the data might stay in the buffer for long before 
it is sent to the
+     * Python worker. We should flush the buffer periodically so that the 
downstream can make
+     * continued progress.
+     */
+    private def shouldFlush(): Boolean = {
+      if (!timelyFlushEnabled) {
+        false
+      } else {
+        val currentTime = System.nanoTime()
+        if (currentTime - lastFlushTime > timelyFlushTimeoutNanos) {
+          lastFlushTime = currentTime
+          bufferStream.size() > 0
+        } else {
+          false
+        }
+      }
+    }
+
+    /**
+     * Reads input data from `writer.inputIterator` into `buffer` and writes 
the buffer to the
+     * Python worker if the socket is available for writing.
+     */
+    private def writeAdditionalInputToPythonWorker(): Unit = {
+      var acceptsInput = true
+      while (acceptsInput && (hasInput || buffer.hasRemaining)) {
+        if (!buffer.hasRemaining && hasInput) {
+          // No buffered data is available. Try to read input into the buffer.
+          bufferStream.reset()
+          // Set the `buffer` to null to make it eligible for GC
+          buffer = null
+
+          val dataOut = new DataOutputStream(bufferStream)
+          // Try not to grow the buffer much beyond `bufferSize`. This is 
inevitable for large
+          // input rows.
+          while (bufferStream.size() < bufferSize && hasInput && 
!shouldFlush()) {
+            hasInput = writer.writeNextInputToStream(dataOut)
+          }
+          if (!hasInput) {
+            // Reached the end of the input.
+            writer.close(dataOut)
+          }
+          buffer = bufferStream.toByteBuffer
+        }
+
+        // Try to write as much buffered data as possible to the socket.
+        while (buffer.hasRemaining && acceptsInput) {
+          val n = worker.channel.write(buffer)
+          acceptsInput = n > 0
+        }
+      }
+
+      if (!hasInput && !buffer.hasRemaining) {
+        // We no longer have any data to write to the socket.
+        worker.selectionKey.interestOps(SelectionKey.OP_READ)
+        bufferStream.close()
+      }
     }
   }
+
 }
 
 private[spark] object PythonRunner {
 
   // already running worker monitor threads for worker and task attempts ID 
pairs
-  val runningMonitorThreads = ConcurrentHashMap.newKeySet[(Socket, Long)]()
+  val runningMonitorThreads = ConcurrentHashMap.newKeySet[(PythonWorker, 
Long)]()
 
   private var printPythonInfo: AtomicBoolean = new AtomicBoolean(true)
 
@@ -693,13 +807,13 @@ private[spark] class PythonRunner(
   extends BasePythonRunner[Array[Byte], Array[Byte]](
     funcs, PythonEvalType.NON_UDF, Array(Array(0)), jobArtifactUUID) {
 
-  protected override def newWriterThread(
+  protected override def newWriter(
       env: SparkEnv,
-      worker: Socket,
+      worker: PythonWorker,
       inputIterator: Iterator[Array[Byte]],
       partitionIndex: Int,
-      context: TaskContext): WriterThread = {
-    new WriterThread(env, worker, inputIterator, partitionIndex, context) {
+      context: TaskContext): Writer = {
+    new Writer(env, worker, inputIterator, partitionIndex, context) {
 
       protected override def writeCommand(dataOut: DataOutputStream): Unit = {
         val command = funcs.head.funcs.head.command
@@ -707,28 +821,32 @@ private[spark] class PythonRunner(
         dataOut.write(command.toArray)
       }
 
-      protected override def writeIteratorToStream(dataOut: DataOutputStream): 
Unit = {
-        PythonRDD.writeIteratorToStream(inputIterator, dataOut)
-        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+      override def writeNextInputToStream(dataOut: DataOutputStream): Boolean 
= {
+        if (PythonRDD.writeNextElementToStream(inputIterator, dataOut)) {
+          true
+        } else {
+          dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+          false
+        }
       }
     }
   }
 
   protected override def newReaderIterator(
       stream: DataInputStream,
-      writerThread: WriterThread,
+      writer: Writer,
       startTime: Long,
       env: SparkEnv,
-      worker: Socket,
+      worker: PythonWorker,
       pid: Option[Int],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[Array[Byte]] = {
     new ReaderIterator(
-      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, 
context) {
+      stream, writer, startTime, env, worker, pid, releasedOrClosed, context) {
 
       protected override def read(): Array[Byte] = {
-        if (writerThread.exception.isDefined) {
-          throw writerThread.exception.get
+        if (writer.exception.isDefined) {
+          throw writer.exception.get
         }
         try {
           stream.readInt() match {
diff --git 
a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 4ba6dd949b1..1db8748c327 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -18,7 +18,8 @@
 package org.apache.spark.api.python
 
 import java.io.{DataInputStream, DataOutputStream, EOFException, File, 
InputStream}
-import java.net.{InetAddress, ServerSocket, Socket, SocketException}
+import java.net.{InetAddress, InetSocketAddress, SocketException}
+import java.nio.channels._
 import java.util.Arrays
 import java.util.concurrent.TimeUnit
 import javax.annotation.concurrent.GuardedBy
@@ -33,6 +34,14 @@ import org.apache.spark.internal.config.Python._
 import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.util.{RedirectThread, Utils}
 
+case class PythonWorker(channel: SocketChannel, selector: Selector, 
selectionKey: SelectionKey) {
+  def stop(): Unit = {
+    selectionKey.cancel()
+    selector.close()
+    channel.close()
+  }
+}
+
 private[spark] class PythonWorkerFactory(
     pythonExec: String,
     workerModule: String,
@@ -67,32 +76,33 @@ private[spark] class PythonWorkerFactory(
   @GuardedBy("self")
   private var daemonPort: Int = 0
   @GuardedBy("self")
-  private val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+  private val daemonWorkers = new mutable.WeakHashMap[PythonWorker, Int]()
   @GuardedBy("self")
-  private val idleWorkers = new mutable.Queue[Socket]()
+  private val idleWorkers = new mutable.Queue[PythonWorker]()
   @GuardedBy("self")
   private var lastActivityNs = 0L
   new MonitorThread().start()
 
   @GuardedBy("self")
-  private val simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
+  private val simpleWorkers = new mutable.WeakHashMap[PythonWorker, Process]()
 
   private val pythonPath = PythonUtils.mergePythonPaths(
     PythonUtils.sparkPythonPath,
     envVars.getOrElse("PYTHONPATH", ""),
     sys.env.getOrElse("PYTHONPATH", ""))
 
-  def create(): (Socket, Option[Int]) = {
+  def create(): (PythonWorker, Option[Int]) = {
     if (useDaemon) {
       self.synchronized {
         if (idleWorkers.nonEmpty) {
           val worker = idleWorkers.dequeue()
+          worker.selectionKey.interestOps(SelectionKey.OP_READ | 
SelectionKey.OP_WRITE)
           return (worker, daemonWorkers.get(worker))
         }
       }
       createThroughDaemon()
     } else {
-      createSimpleWorker()
+      createSimpleWorker(blockingMode = false)
     }
   }
 
@@ -101,18 +111,25 @@ private[spark] class PythonWorkerFactory(
    * processes itself to avoid the high cost of forking from Java. This 
currently only works
    * on UNIX-based systems.
    */
-  private def createThroughDaemon(): (Socket, Option[Int]) = {
+  private def createThroughDaemon(): (PythonWorker, Option[Int]) = {
 
-    def createSocket(): (Socket, Option[Int]) = {
-      val socket = new Socket(daemonHost, daemonPort)
-      val pid = new DataInputStream(socket.getInputStream).readInt()
+    def createWorker(): (PythonWorker, Option[Int]) = {
+      val socketChannel = SocketChannel.open(new InetSocketAddress(daemonHost, 
daemonPort))
+      // These calls are blocking.
+      val pid = new 
DataInputStream(Channels.newInputStream(socketChannel)).readInt()
       if (pid < 0) {
         throw new IllegalStateException("Python daemon failed to launch worker 
with code " + pid)
       }
 
-      authHelper.authToServer(socket)
-      daemonWorkers.put(socket, pid)
-      (socket, Some(pid))
+      authHelper.authToServer(socketChannel.socket())
+      socketChannel.configureBlocking(false)
+      val selector = Selector.open()
+      val selectionKey = socketChannel.register(selector,
+        SelectionKey.OP_READ | SelectionKey.OP_WRITE)
+      val worker = PythonWorker(socketChannel, selector, selectionKey)
+
+      daemonWorkers.put(worker, pid)
+      (worker, Some(pid))
     }
 
     self.synchronized {
@@ -121,14 +138,14 @@ private[spark] class PythonWorkerFactory(
 
       // Attempt to connect, restart and retry once if it fails
       try {
-        createSocket()
+        createWorker()
       } catch {
         case exc: SocketException =>
           logWarning("Failed to open socket to Python daemon:", exc)
           logWarning("Assuming that daemon unexpectedly quit, attempting to 
restart")
           stopDaemon()
           startDaemon()
-          createSocket()
+          createWorker()
       }
     }
   }
@@ -136,10 +153,11 @@ private[spark] class PythonWorkerFactory(
   /**
    * Launch a worker by executing worker.py (by default) directly and telling 
it to connect to us.
    */
-  private[spark] def createSimpleWorker(): (Socket, Option[Int]) = {
-    var serverSocket: ServerSocket = null
+  private[spark] def createSimpleWorker(blockingMode: Boolean): (PythonWorker, 
Option[Int]) = {
+    var serverSocketChannel: ServerSocketChannel = null
     try {
-      serverSocket = new ServerSocket(0, 1, InetAddress.getLoopbackAddress())
+      serverSocketChannel = ServerSocketChannel.open()
+      serverSocketChannel.bind(new 
InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 1)
 
       // Create and start the worker
       val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", 
workerModule))
@@ -154,38 +172,49 @@ private[spark] class PythonWorkerFactory(
       workerEnv.put("PYTHONPATH", pythonPath)
       // This is equivalent to setting the -u flag; we use it because ipython 
doesn't support -u:
       workerEnv.put("PYTHONUNBUFFERED", "YES")
-      workerEnv.put("PYTHON_WORKER_FACTORY_PORT", 
serverSocket.getLocalPort.toString)
+      workerEnv.put("PYTHON_WORKER_FACTORY_PORT", 
serverSocketChannel.socket().getLocalPort
+        .toString)
       workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
       if (Utils.preferIPv6) {
         workerEnv.put("SPARK_PREFER_IPV6", "True")
       }
-      val worker = pb.start()
+      val workerProcess = pb.start()
 
       // Redirect worker stdout and stderr
-      redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream)
+      redirectStreamsToStderr(workerProcess.getInputStream, 
workerProcess.getErrorStream)
 
       // Wait for it to connect to our socket, and validate the auth secret.
-      serverSocket.setSoTimeout(10000)
+      serverSocketChannel.socket().setSoTimeout(10000)
 
       try {
-        val socket = serverSocket.accept()
-        authHelper.authClient(socket)
-        // TODO: When we drop JDK 8, we can just use worker.pid()
-        val pid = new DataInputStream(socket.getInputStream).readInt()
+        val socketChannel = serverSocketChannel.accept()
+        authHelper.authClient(socketChannel.socket())
+        // TODO: When we drop JDK 8, we can just use workerProcess.pid()
+        val pid = new 
DataInputStream(Channels.newInputStream(socketChannel)).readInt()
         if (pid < 0) {
           throw new IllegalStateException("Python failed to launch worker with 
code " + pid)
         }
+        if (!blockingMode) {
+          socketChannel.configureBlocking(false)
+        }
+        val selector = Selector.open()
+        val selectionKey = if (blockingMode) {
+          null
+        } else {
+          socketChannel.register(selector, SelectionKey.OP_READ | 
SelectionKey.OP_WRITE)
+        }
+        val worker = PythonWorker(socketChannel, selector, selectionKey)
         self.synchronized {
-          simpleWorkers.put(socket, worker)
+          simpleWorkers.put(worker, workerProcess)
         }
-        return (socket, Some(pid))
+        return (worker, Some(pid))
       } catch {
         case e: Exception =>
           throw new SparkException("Python worker failed to connect back.", e)
       }
     } finally {
-      if (serverSocket != null) {
-        serverSocket.close()
+      if (serverSocketChannel != null) {
+        serverSocketChannel.close()
       }
     }
     null
@@ -320,11 +349,10 @@ private[spark] class PythonWorkerFactory(
     while (idleWorkers.nonEmpty) {
       val worker = idleWorkers.dequeue()
       try {
-        // the worker will exit after closing the socket
-        worker.close()
+        worker.stop()
       } catch {
         case e: Exception =>
-          logWarning("Failed to close worker socket", e)
+          logWarning("Failed to stop worker socket", e)
       }
     }
   }
@@ -351,7 +379,7 @@ private[spark] class PythonWorkerFactory(
     stopDaemon()
   }
 
-  def stopWorker(worker: Socket): Unit = {
+  def stopWorker(worker: PythonWorker): Unit = {
     self.synchronized {
       if (useDaemon) {
         if (daemon != null) {
@@ -367,22 +395,21 @@ private[spark] class PythonWorkerFactory(
         simpleWorkers.get(worker).foreach(_.destroy())
       }
     }
-    worker.close()
+    worker.stop()
   }
 
-  def releaseWorker(worker: Socket): Unit = {
+  def releaseWorker(worker: PythonWorker): Unit = {
     if (useDaemon) {
       self.synchronized {
         lastActivityNs = System.nanoTime()
         idleWorkers.enqueue(worker)
       }
     } else {
-      // Cleanup the worker socket. This will also cause the Python worker to 
exit.
       try {
-        worker.close()
+        worker.stop()
       } catch {
         case e: Exception =>
-          logWarning("Failed to close worker socket", e)
+          logWarning("Failed to close worker", e)
       }
     }
   }
diff --git 
a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
index b6ab031d388..3f7b11a40ad 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.api.python
 
 import java.io.{DataInputStream, DataOutputStream, File}
-import java.net.Socket
 import java.nio.charset.StandardCharsets
 
 import org.apache.spark.{SparkEnv, SparkFiles}
@@ -76,7 +75,7 @@ private[spark] object PythonWorkerUtils extends Logging {
    */
   def writeBroadcasts(
       broadcastVars: Seq[Broadcast[PythonBroadcast]],
-      worker: Socket,
+      worker: PythonWorker,
       env: SparkEnv,
       dataOut: DataOutputStream): Unit = {
     // Broadcast variables
@@ -117,9 +116,6 @@ private[spark] object PythonWorkerUtils extends Logging {
         dataOut.writeLong(id)
       }
       dataOut.flush()
-      logTrace("waiting for python to read decrypted broadcast data from 
server")
-      server.waitTillBroadcastDataSent()
-      logTrace("done sending decrypted data to python")
     } else {
       sendBidsToRemove()
       for (broadcast <- broadcastVars) {
diff --git 
a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala 
b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
index fdfe388db2d..e82052e41be 100644
--- 
a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
+++ 
b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.api.python
 
 import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, 
DataOutputStream}
-import java.net.Socket
 
 import scala.collection.JavaConverters._
 
@@ -50,7 +49,8 @@ private[spark] class StreamingPythonRunner(
 
   private val envVars: java.util.Map[String, String] = func.envVars
   private val pythonExec: String = func.pythonExec
-  private var pythonWorker: Option[Socket] = None
+  private var pythonWorker: Option[PythonWorker] = None
+  private var pythonWorkerFactory: Option[PythonWorkerFactory] = None
   protected val pythonVer: String = func.pythonVer
 
   /**
@@ -71,14 +71,17 @@ private[spark] class StreamingPythonRunner(
     val prevConf = conf.get(PYTHON_USE_DAEMON)
     conf.set(PYTHON_USE_DAEMON, false)
     try {
-      val (worker, _) = env.createPythonWorker(
-        pythonExec, workerModule, envVars.asScala.toMap)
+      val workerFactory =
+        new PythonWorkerFactory(pythonExec, workerModule, 
envVars.asScala.toMap)
+      val (worker: PythonWorker, _) = 
workerFactory.createSimpleWorker(blockingMode = true)
       pythonWorker = Some(worker)
+      pythonWorkerFactory = Some(workerFactory)
     } finally {
       conf.set(PYTHON_USE_DAEMON, prevConf)
     }
 
-    val stream = new BufferedOutputStream(pythonWorker.get.getOutputStream, 
bufferSize)
+    val stream = new BufferedOutputStream(
+      pythonWorker.get.channel.socket().getOutputStream, bufferSize)
     val dataOut = new DataOutputStream(stream)
 
     PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
@@ -93,7 +96,7 @@ private[spark] class StreamingPythonRunner(
     dataOut.flush()
 
     val dataIn = new DataInputStream(
-      new BufferedInputStream(pythonWorker.get.getInputStream, bufferSize))
+      new 
BufferedInputStream(pythonWorker.get.channel.socket().getInputStream, 
bufferSize))
 
     val resFromPython = dataIn.readInt()
     logInfo(s"Runner initialization returned $resFromPython")
diff --git 
a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala 
b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala
index 8230144025f..5f2a9dd2743 100644
--- a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala
@@ -55,6 +55,14 @@ private[spark] object InputFileBlockHolder {
         new AtomicReference(new FileBlock)
     }
 
+  private[spark] def setThreadLocalValue(ref: Object): Unit = {
+    inputBlock.set(ref.asInstanceOf[AtomicReference[FileBlock]])
+  }
+
+  private[spark] def getThreadLocalValue(): Object = {
+    inputBlock.get()
+  }
+
   /**
    * Returns the holding file name or empty string if it is unknown.
    */
@@ -72,6 +80,9 @@ private[spark] object InputFileBlockHolder {
 
   /**
    * Sets the thread-local input block.
+   *
+   * Callers of this method must ensure a task completion listener has been 
registered to unset()
+   * the thread local in the task thread.
    */
   def set(filePath: String, startOffset: Long, length: Long): Unit = {
     require(filePath != null, "filePath cannot be null")
diff --git 
a/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala 
b/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala
new file mode 100644
index 00000000000..a4145bb36ac
--- /dev/null
+++ 
b/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io.OutputStream
+import java.nio.ByteBuffer
+
+import org.apache.spark.storage.StorageUtils
+import org.apache.spark.unsafe.Platform
+
+/**
+ * An output stream that dumps data into a direct byte buffer. The byte buffer 
grows in size
+ * as more data is written to the stream.
+ * @param capacity The initial capacity of the direct byte buffer
+ */
+private[spark] class DirectByteBufferOutputStream(capacity: Int) extends 
OutputStream {
+  private var buffer = Platform.allocateDirectBuffer(capacity)
+
+  def this() = this(32)
+
+  override def write(b: Int): Unit = {
+    ensureCapacity(buffer.position() + 1)
+    buffer.put(b.toByte)
+  }
+
+  override def write(b: Array[Byte], off: Int, len: Int): Unit = {
+    ensureCapacity(buffer.position() + len)
+    buffer.put(b, off, len)
+  }
+
+  private def ensureCapacity(minCapacity: Int): Unit = {
+    if (minCapacity > buffer.capacity()) grow(minCapacity)
+  }
+
+  /**
+   * Grows the current buffer to at least `minCapacity` capacity.
+   * As a side effect, all references to the old buffer will be invalidated.
+   */
+  private def grow(minCapacity: Int): Unit = {
+    val oldCapacity = buffer.capacity()
+    var newCapacity = oldCapacity << 1
+    if (newCapacity < minCapacity) newCapacity = minCapacity
+    val oldBuffer = buffer
+    oldBuffer.flip()
+    val newBuffer = ByteBuffer.allocateDirect(newCapacity)
+    newBuffer.put(oldBuffer)
+    StorageUtils.dispose(oldBuffer)
+    buffer = newBuffer
+  }
+
+  def reset(): Unit = buffer.clear()
+
+  def size(): Int = buffer.position()
+
+  /**
+   * Any subsequent call to [[close()]], [[write()]], [[reset()]] will 
invalidate the buffer
+   * returned by this method.
+   */
+  def toByteBuffer: ByteBuffer = {
+    val outputBuffer = buffer.duplicate()
+    outputBuffer.flip()
+    outputBuffer
+  }
+
+  override def close(): Unit = {
+    // Eagerly free the direct byte buffer without waiting for GC to reduce 
memory pressure.
+    StorageUtils.dispose(buffer)
+  }
+
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index d4c535fe76a..a60d0beeeed 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -55,7 +55,7 @@ class ApplyInPandasWithStatePythonRunner(
     evalType: Int,
     argOffsets: Array[Array[Int]],
     inputSchema: StructType,
-    override protected val timeZoneId: String,
+    _timeZoneId: String,
     initialWorkerConf: Map[String, String],
     stateEncoder: ExpressionEncoder[Row],
     keySchema: StructType,
@@ -73,8 +73,10 @@ class ApplyInPandasWithStatePythonRunner(
 
   private val sqlConf = SQLConf.get
 
-  override protected val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA)
-
+  // Use lazy vals so these fields are initialized before they are accessed in
+  // [[PythonArrowInput]]'s constructor.
+  override protected lazy val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA)
+  override protected lazy val timeZoneId: String = _timeZoneId
   override val errorOnDuplicatedFieldNames: Boolean = true
 
   override val simplifiedTraceback: Boolean = 
sqlConf.pysparkSimplifiedTraceback
@@ -113,37 +115,41 @@ class ApplyInPandasWithStatePythonRunner(
     // Also write the schema for state value
     PythonRDD.writeUTF(stateValueSchema.json, stream)
   }
-
+  private var pandasWriter: ApplyInPandasWithStateWriter = _
   /**
    * Read the (key, state, values) from input iterator and construct Arrow 
RecordBatches, and
    * write constructed RecordBatches to the writer.
    *
    * See [[ApplyInPandasWithStateWriter]] for more details.
    */
-  protected def writeIteratorToArrowStream(
+  protected def writeNextInputToArrowStream(
       root: VectorSchemaRoot,
       writer: ArrowStreamWriter,
       dataOut: DataOutputStream,
-      inputIterator: Iterator[InType]): Unit = {
-    val w = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch)
-
-    while (inputIterator.hasNext) {
+      inputIterator: Iterator[InType]): Boolean = {
+    if (pandasWriter == null) {
+      pandasWriter = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch)
+    }
+    if (inputIterator.hasNext) {
       val startData = dataOut.size()
       val (keyRow, groupState, dataIter) = inputIterator.next()
       assert(dataIter.hasNext, "should have at least one data row!")
-      w.startNewGroup(keyRow, groupState)
+      pandasWriter.startNewGroup(keyRow, groupState)
 
       while (dataIter.hasNext) {
         val dataRow = dataIter.next()
-        w.writeRow(dataRow)
+        pandasWriter.writeRow(dataRow)
       }
 
-      w.finalizeGroup()
+      pandasWriter.finalizeGroup()
       val deltaData = dataOut.size() - startData
       pythonMetrics("pythonDataSent") += deltaData
+      true
+    } else {
+      pandasWriter.finalizeData()
+      super[PythonArrowInput].close()
+      false
     }
-
-    w.finalizeData()
   }
 
   /**
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index d9bce96c477..0f26d8f21f8 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -31,8 +31,8 @@ class ArrowPythonRunner(
     funcs: Seq[ChainedPythonFunctions],
     evalType: Int,
     argOffsets: Array[Array[Int]],
-    protected override val schema: StructType,
-    protected override val timeZoneId: String,
+    _schema: StructType,
+    _timeZoneId: String,
     protected override val largeVarTypes: Boolean,
     protected override val workerConf: Map[String, String],
     val pythonMetrics: Map[String, SQLMetric],
@@ -50,6 +50,10 @@ class ArrowPythonRunner(
 
   override val simplifiedTraceback: Boolean = 
SQLConf.get.pysparkSimplifiedTraceback
 
+  // Use lazy vals so these fields are initialized before they are accessed in
+  // [[PythonArrowInput]]'s constructor.
+  override protected lazy val timeZoneId: String = _timeZoneId
+  override protected lazy val schema: StructType = _schema
   override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
   require(
     bufferSize >= 4,
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
index 9dae874e3ed..6c8412f8b37 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
@@ -18,14 +18,13 @@
 package org.apache.spark.sql.execution.python
 
 import java.io.DataOutputStream
-import java.net.Socket
 
 import scala.collection.JavaConverters._
 
 import net.razorvine.pickle.Unpickler
 
 import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext}
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, 
PythonWorkerUtils}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, 
PythonWorker, PythonWorkerUtils}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.GenericArrayData
@@ -101,13 +100,13 @@ class PythonUDTFRunner(
     Seq(ChainedPythonFunctions(Seq(udtf.func))),
     PythonEvalType.SQL_TABLE_UDF, Array(argOffsets), pythonMetrics, 
jobArtifactUUID) {
 
-  protected override def newWriterThread(
+  protected override def newWriter(
       env: SparkEnv,
-      worker: Socket,
+      worker: PythonWorker,
       inputIterator: Iterator[Array[Byte]],
       partitionIndex: Int,
-      context: TaskContext): WriterThread = {
-    new PythonUDFWriterThread(env, worker, inputIterator, partitionIndex, 
context) {
+      context: TaskContext): Writer = {
+    new PythonUDFWriter(env, worker, inputIterator, partitionIndex, context) {
 
       protected override def writeCommand(dataOut: DataOutputStream): Unit = {
         PythonUDTFRunner.writeUDTF(dataOut, udtf, argOffsets)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index eef8be7c940..bd901545bb0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -18,13 +18,12 @@
 package org.apache.spark.sql.execution.python
 
 import java.io.DataOutputStream
-import java.net.Socket
 
 import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.ipc.ArrowStreamWriter
 
 import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, 
PythonRDD}
+import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, 
PythonRDD, PythonWorker}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.arrow.ArrowWriter
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -60,14 +59,14 @@ class CoGroupedArrowPythonRunner(
 
   override val simplifiedTraceback: Boolean = 
SQLConf.get.pysparkSimplifiedTraceback
 
-  protected def newWriterThread(
+  protected def newWriter(
       env: SparkEnv,
-      worker: Socket,
+      worker: PythonWorker,
       inputIterator: Iterator[(Iterator[InternalRow], Iterator[InternalRow])],
       partitionIndex: Int,
-      context: TaskContext): WriterThread = {
+      context: TaskContext): Writer = {
 
-    new WriterThread(env, worker, inputIterator, partitionIndex, context) {
+    new Writer(env, worker, inputIterator, partitionIndex, context) {
 
       protected override def writeCommand(dataOut: DataOutputStream): Unit = {
 
@@ -81,10 +80,10 @@ class CoGroupedArrowPythonRunner(
         PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
       }
 
-      protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
+      override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {
         // For each we first send the number of dataframes in each group then send
         // first df, then send second df.  End of data is marked by sending 0.
-        while (inputIterator.hasNext) {
+        if (inputIterator.hasNext) {
           val startData = dataOut.size()
           dataOut.writeInt(2)
           val (nextLeft, nextRight) = inputIterator.next()
@@ -93,8 +92,11 @@ class CoGroupedArrowPythonRunner(
 
           val deltaData = dataOut.size() - startData
           pythonMetrics("pythonDataSent") += deltaData
+          true
+        } else {
+          dataOut.writeInt(0)
+          false
         }
-        dataOut.writeInt(0)
       }
 
       private def writeGroup(
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala
index 10bb3a45be9..373e17c0aa3 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala
@@ -21,7 +21,7 @@ import java.io.File
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{ContextAwareIterator, PartitionEvaluator, 
PartitionEvaluatorFactory, SparkEnv, TaskContext}
+import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory, 
SparkEnv, TaskContext}
 import org.apache.spark.api.python.ChainedPythonFunctions
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -62,7 +62,6 @@ abstract class EvalPythonEvaluatorFactory(
         iters: Iterator[InternalRow]*): Iterator[InternalRow] = {
       val iter = iters.head
       val context = TaskContext.get()
-      val contextAwareIterator = new ContextAwareIterator(context, iter)
 
       // The queue used to buffer input rows so we can drain it to
       // combine input with output from Python.
@@ -97,7 +96,7 @@ abstract class EvalPythonEvaluatorFactory(
       }.toArray)
 
       // Add rows to queue to join later with the result.
-      val projectedRowIter = contextAwareIterator.map { inputRow =>
+      val projectedRowIter = iter.map { inputRow =>
         queue.add(inputRow.asInstanceOf[UnsafeRow])
         projection(inputRow)
       }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 8d2f788e05c..6664acf9572 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -24,7 +24,6 @@ import scala.collection.JavaConverters._
 
 import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler}
 
-import org.apache.spark.{ContextAwareIterator, TaskContext}
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -302,7 +301,7 @@ object EvaluatePython {
   def javaToPython(rdd: RDD[Any]): RDD[Array[Byte]] = {
     rdd.mapPartitions { iter =>
       registerPicklers()  // let it called in executor
-      new SerDeUtil.AutoBatchedPickler(new 
ContextAwareIterator(TaskContext.get, iter))
+      new SerDeUtil.AutoBatchedPickler(iter)
     }
   }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
index 1e15aa7f777..6f501e1411a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.python
 
 import scala.collection.JavaConverters._
 
-import org.apache.spark.{ContextAwareIterator, PartitionEvaluator, 
PartitionEvaluatorFactory, TaskContext}
+import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory, 
TaskContext}
 import org.apache.spark.api.python.ChainedPythonFunctions
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -52,11 +52,10 @@ class MapInBatchEvaluatorFactory(
       // Single function with one struct.
       val argOffsets = Array(Array(0))
       val context = TaskContext.get()
-      val contextAwareIterator = new ContextAwareIterator(context, inputIter)
 
       // Here we wrap it via another row so that Python sides understand it
       // as a DataFrame.
-      val wrappedIter = contextAwareIterator.map(InternalRow(_))
+      val wrappedIter = inputIter.map(InternalRow(_))
 
       // DO NOT use iter.grouped(). See BatchIterator.
       val batchIter =
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index 5c99a3f9808..00ee3a17563 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -17,14 +17,14 @@
 package org.apache.spark.sql.execution.python
 
 import java.io.DataOutputStream
-import java.net.Socket
 
 import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.ipc.ArrowStreamWriter
 
 import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, 
PythonRDD}
+import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, 
PythonRDD, PythonWorker}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.arrow
 import org.apache.spark.sql.execution.arrow.ArrowWriter
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
@@ -48,11 +48,11 @@ private[python] trait PythonArrowInput[IN] { self: 
BasePythonRunner[IN, _] =>
 
   protected def pythonMetrics: Map[String, SQLMetric]
 
-  protected def writeIteratorToArrowStream(
+  protected def writeNextInputToArrowStream(
       root: VectorSchemaRoot,
       writer: ArrowStreamWriter,
       dataOut: DataOutputStream,
-      inputIterator: Iterator[IN]): Unit
+      inputIterator: Iterator[IN]): Boolean
 
   protected def writeUDF(
       dataOut: DataOutputStream,
@@ -68,51 +68,46 @@ private[python] trait PythonArrowInput[IN] { self: 
BasePythonRunner[IN, _] =>
       PythonRDD.writeUTF(v, stream)
     }
   }
+  private val arrowSchema = ArrowUtils.toArrowSchema(
+    schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
+  private val allocator =
+    ArrowUtils.rootAllocator.newChildAllocator(s"stdout writer for $pythonExec", 0, Long.MaxValue)
+  protected val root = VectorSchemaRoot.create(arrowSchema, allocator)
+  protected var writer: ArrowStreamWriter = _
+
+  protected def close(): Unit = {
+    Utils.tryWithSafeFinally {
+      // `end` writes the footer to the output stream; it doesn't clean up any resources.
+      // It could throw an exception if the output stream is already closed, so it should be
+      // in the try block.
+      writer.end()
+    } {
+      root.close()
+      allocator.close()
+    }
+  }
 
-  protected override def newWriterThread(
+  protected override def newWriter(
       env: SparkEnv,
-      worker: Socket,
+      worker: PythonWorker,
       inputIterator: Iterator[IN],
       partitionIndex: Int,
-      context: TaskContext): WriterThread = {
-    new WriterThread(env, worker, inputIterator, partitionIndex, context) {
-
+      context: TaskContext): Writer = {
+    new Writer(env, worker, inputIterator, partitionIndex, context) {
       protected override def writeCommand(dataOut: DataOutputStream): Unit = {
         handleMetadataBeforeExec(dataOut)
         writeUDF(dataOut, funcs, argOffsets)
       }
 
-      protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
-        val arrowSchema = ArrowUtils.toArrowSchema(
-          schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
-        val allocator = ArrowUtils.rootAllocator.newChildAllocator(
-          s"stdout writer for $pythonExec", 0, Long.MaxValue)
-        val root = VectorSchemaRoot.create(arrowSchema, allocator)
+      override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {

-        Utils.tryWithSafeFinally {
-          val writer = new ArrowStreamWriter(root, null, dataOut)
+        if (writer == null) {
+          writer = new ArrowStreamWriter(root, null, dataOut)
           writer.start()
-
-          writeIteratorToArrowStream(root, writer, dataOut, inputIterator)
-
-          // end writes footer to the output stream and doesn't clean any resources.
-          // It could throw exception if the output stream is closed, so it should be
-          // in the try block.
-          writer.end()
-        } {
-          // If we close root and allocator in TaskCompletionListener, there could be a race
-          // condition where the writer thread keeps writing to the VectorSchemaRoot while
-          // it's being closed by the TaskCompletion listener.
-          // Closing root and allocator here is cleaner because root and allocator is owned
-          // by the writer thread and is only visible to the writer thread.
-          //
-          // If the writer thread is interrupted by TaskCompletionListener, it should either
-          // (1) in the try block, in which case it will get an InterruptedException when
-          // performing io, and goes into the finally block or (2) in the finally block,
-          // in which case it will ignore the interruption and close the resources.
-          root.close()
-          allocator.close()
          }
+
+        assert(writer != null)
+        writeNextInputToArrowStream(root, writer, dataOut, inputIterator)
       }
     }
   }
@@ -120,15 +115,15 @@ private[python] trait PythonArrowInput[IN] { self: 
BasePythonRunner[IN, _] =>
 
 private[python] trait BasicPythonArrowInput extends 
PythonArrowInput[Iterator[InternalRow]] {
   self: BasePythonRunner[Iterator[InternalRow], _] =>
+  private val arrowWriter: arrow.ArrowWriter = ArrowWriter.create(root)
 
-  protected def writeIteratorToArrowStream(
+  protected def writeNextInputToArrowStream(
       root: VectorSchemaRoot,
       writer: ArrowStreamWriter,
       dataOut: DataOutputStream,
-      inputIterator: Iterator[Iterator[InternalRow]]): Unit = {
-    val arrowWriter = ArrowWriter.create(root)
+      inputIterator: Iterator[Iterator[InternalRow]]): Boolean = {
 
-    while (inputIterator.hasNext) {
+    if (inputIterator.hasNext) {
       val startData = dataOut.size()
       val nextBatch = inputIterator.next()
 
@@ -141,6 +136,10 @@ private[python] trait BasicPythonArrowInput extends 
PythonArrowInput[Iterator[In
       arrowWriter.reset()
       val deltaData = dataOut.size() - startData
       pythonMetrics("pythonDataSent") += deltaData
+      true
+    } else {
+      super[PythonArrowInput].close()
+      false
     }
   }
 }
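
As a rough illustration of the new contract (not code from this patch): each call to `writeNextInputToStream` now writes at most one batch or group of input and reports whether more remains, so the task thread can interleave these write steps with reads of the worker's output instead of delegating them to a separate writer thread. A hypothetical, simplified sketch, with `PythonInputWriter` standing in for the patch's `Writer`:

```
import java.io.DataOutputStream

object SingleThreadedInputPump {
  // Stand-in for the patch's Writer, which exposes writeNextInputToStream(dataOut): Boolean.
  trait PythonInputWriter {
    def writeNextInputToStream(dataOut: DataOutputStream): Boolean
  }

  // Hypothetical driver, for illustration only: the real code in BasePythonRunner
  // interleaves these write steps with non-blocking reads of the worker's output
  // rather than looping straight to the end.
  def writeAllInput(writer: PythonInputWriter, dataOut: DataOutputStream): Unit = {
    // Each call writes at most one batch/group; false means the end-of-input marker
    // has been written and the input-side resources have been released.
    while (writer.writeNextInputToStream(dataOut)) {}
    dataOut.flush()
  }
}
```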
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
index c12c690f776..8f99325e4e0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -17,7 +17,6 @@
 package org.apache.spark.sql.execution.python
 
 import java.io.DataInputStream
-import java.net.Socket
 import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.JavaConverters._
@@ -26,7 +25,7 @@ import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.ipc.ArrowStreamReader
 
 import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.api.python.{BasePythonRunner, SpecialLengths}
+import org.apache.spark.api.python.{BasePythonRunner, PythonWorker, 
SpecialLengths}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.ArrowUtils
@@ -46,16 +45,16 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { 
self: BasePythonRunner[
 
   protected def newReaderIterator(
       stream: DataInputStream,
-      writerThread: WriterThread,
+      writer: Writer,
       startTime: Long,
       env: SparkEnv,
-      worker: Socket,
+      worker: PythonWorker,
       pid: Option[Int],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[OUT] = {
 
     new ReaderIterator(
-      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, 
context) {
+      stream, writer, startTime, env, worker, pid, releasedOrClosed, context) {
 
       private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
         s"stdin reader for $pythonExec", 0, Long.MaxValue)
@@ -80,8 +79,8 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { 
self: BasePythonRunner[
       }
 
       protected override def read(): OUT = {
-        if (writerThread.exception.isDefined) {
-          throw writerThread.exception.get
+        if (writer.exception.isDefined) {
+          throw writer.exception.get
         }
         try {
           if (reader != null && batchLoaded) {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
index 3857f084bcb..a229931cec8 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
@@ -31,6 +31,44 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{NextIterator, Utils}
 
+/**
+ * Writes the rows buffered in [[UnsafeRowBuffer]] to the Python worker.
+ * Any exceptions encountered will be cached to be read later by the parent thread.
+ */
+class WriterThread(outputIterator: Iterator[Array[Byte]])
+  extends Thread(s"Thread streaming data to the Python worker") {
+
+  @volatile var _exception: Throwable = _
+
+  override def run(): Unit = {
+    try {
+      // [[PythonForeachWriter]] is a sink, so the Python worker does not generate any output.
+      // The `hasNext()` and `next()` calls are an indirect way to ship the input data to the
+      // Python worker: consuming the Python worker's output iterator, as a side effect, drives
+      // the write of the input data to the Python worker through [[org.apache.spark.api.python.
+      // BasePythonRunner.ReaderInputStream.writeAdditionalInputToPythonWorker]].
+      if (outputIterator.hasNext) {
+        outputIterator.next()
+      }
+    } catch {
+      // Cache exceptions seen while evaluating the Python function on the streamed input. The
+      // parent thread will eventually rethrow this cached exception.
+      case t: Throwable =>
+        _exception = t
+    }
+  }
+}
+
+/**
+ * This writer works as follows:
+ * - Rows streamed through a `process()` call on the
+ *   [[org.apache.spark.sql.execution.streaming.QueryExecutionThread]] are buffered in the
+ *   `UnsafeRowBuffer`.
+ * - The [[WriterThread]] streams the buffered data to the Python worker.
+ * - Once the streaming query ends, [[close()]] is called, which signals the buffer to mark the
+ *   end of the streaming input. The streaming query execution thread waits for the
+ *   [[WriterThread]] to complete and rethrows any exceptions seen by it.
+ */
 class PythonForeachWriter(func: PythonFunction, schema: StructType)
   extends ForeachWriter[UnsafeRow] {
 
@@ -58,8 +96,11 @@ class PythonForeachWriter(func: PythonFunction, schema: 
StructType)
   private lazy val outputIterator =
     pythonRunner.compute(inputByteIterator, context.partitionId(), context)
 
+  private lazy val writerThread = new WriterThread(outputIterator)
+
   override def open(partitionId: Long, version: Long): Boolean = {
     outputIterator  // initialize everything
+    writerThread.start()
     TaskContext.get.addTaskCompletionListener[Unit] { _ => buffer.close() }
     true
   }
@@ -68,9 +109,15 @@ class PythonForeachWriter(func: PythonFunction, schema: 
StructType)
     buffer.add(value)
   }
 
+  /**
+   * Waits for the writer thread to finish evaluating the Python function. Throws any exceptions
+   * seen by the writer thread.
+   */
   override def close(errorOrNull: Throwable): Unit = {
     buffer.allRowsAdded()
-    if (outputIterator.hasNext) outputIterator.next() // to throw python exception if there was one
+    writerThread.join()
+    // Throw Python exception if there was one.
+    if (writerThread._exception != null) throw writerThread._exception
   }
 }
 
@@ -78,18 +125,20 @@ object PythonForeachWriter {
 
   /**
   * A buffer that is designed for the sole purpose of buffering UnsafeRows in PythonForeachWriter.
-   * It is designed to be used with only 1 writer thread (i.e. JVM task thread) and only 1 reader
-   * thread (i.e. PythonRunner writing thread that reads from the buffer and writes to the Python
-   * worker stdin). Adds to the buffer are non-blocking, and reads through the buffer's iterator
-   * are blocking, that is, it blocks until new data is available or all data has been added.
+   * It is designed to be used with only two threads: the QueryExecutionThread, which writes data
+   * to the buffer, and the [[WriterThread]], which reads from the buffer and writes to the
+   * Python worker stdin. Adds to the buffer are non-blocking, and reads through the buffer's
+   * iterator are blocking, that is, they block until new data is available or all data has been
+   * added.
    *
    * Internally, it uses a [[HybridRowQueue]] to buffer the rows in a practically unlimited queue
    * across memory and local disk. However, HybridRowQueue is designed to be used only with
-   * EvalPythonExec where the reader is always behind the writer, that is, the reader does not
-   * try to read n+1 rows if the writer has only written n rows at any point of time. This
-   * assumption is not true for PythonForeachWriter where rows may be added at a different rate as
-   * they are consumed by the python worker. Hence, to maintain the invariant of the reader being
-   * behind the writer while using HybridRowQueue, the buffer does the following
+   * EvalPythonExec, where the buffer's consumer is always behind the buffer's populator, that is,
+   * the [[WriterThread]] does not try to read n + 1 rows if the streaming thread has only
+   * written n rows at any point of time. This assumption is not true for PythonForeachWriter,
+   * where rows may be added at a different rate than they are consumed by the Python worker.
+   * Hence, to maintain the invariant of the reader being behind the writer while using
+   * HybridRowQueue, the buffer does the following:
    * - Keeps a count of the rows in the HybridRowQueue
    * - Blocks the buffer's consuming iterator when the count is 0 so that the reader does not
    *   try to read more rows than what has been written.
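
The exception handling around the new `WriterThread` boils down to a simple cache-and-rethrow pattern: the background thread records any failure, and the parent thread surfaces it after `join()`. A self-contained sketch of that generic pattern (names and the sample failure are illustrative, not Spark code):

```
// Generic illustration of the cache-and-rethrow pattern used by WriterThread above.
class CachingThread(work: () => Unit) extends Thread("example-writer") {
  @volatile var exception: Throwable = _

  override def run(): Unit = {
    try work() catch { case t: Throwable => exception = t }
  }
}

object CachingThreadExample {
  def main(args: Array[String]): Unit = {
    val t = new CachingThread(() => throw new RuntimeException("python worker failed"))
    t.start()
    t.join()
    // The parent rethrows the cached failure, mirroring PythonForeachWriter.close().
    if (t.exception != null) throw t.exception
  }
}
```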
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index 22083e0473b..bc27ee6919d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.execution.python
 
 import java.io._
-import java.net._
 import java.util.concurrent.atomic.AtomicBoolean
 
 import org.apache.spark._
@@ -44,40 +43,42 @@ abstract class BasePythonUDFRunner(
 
   override val simplifiedTraceback: Boolean = 
SQLConf.get.pysparkSimplifiedTraceback
 
-  abstract class PythonUDFWriterThread(
+  abstract class PythonUDFWriter(
       env: SparkEnv,
-      worker: Socket,
+      worker: PythonWorker,
       inputIterator: Iterator[Array[Byte]],
       partitionIndex: Int,
       context: TaskContext)
-    extends WriterThread(env, worker, inputIterator, partitionIndex, context) {
+    extends Writer(env, worker, inputIterator, partitionIndex, context) {
 
-    protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
+    override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {
       val startData = dataOut.size()
-
-      PythonRDD.writeIteratorToStream(inputIterator, dataOut)
-      dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
-
+      val wroteData = PythonRDD.writeNextElementToStream(inputIterator, dataOut)
+      if (!wroteData) {
+        // Reached the end of input.
+        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+      }
       val deltaData = dataOut.size() - startData
       pythonMetrics("pythonDataSent") += deltaData
+      wroteData
     }
   }
 
   protected override def newReaderIterator(
       stream: DataInputStream,
-      writerThread: WriterThread,
+      writer: Writer,
       startTime: Long,
       env: SparkEnv,
-      worker: Socket,
+      worker: PythonWorker,
       pid: Option[Int],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[Array[Byte]] = {
     new ReaderIterator(
-      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, 
context) {
+      stream, writer, startTime, env, worker, pid, releasedOrClosed, context) {
 
       protected override def read(): Array[Byte] = {
-        if (writerThread.exception.isDefined) {
-          throw writerThread.exception.get
+        if (writer.exception.isDefined) {
+          throw writer.exception.get
         }
         try {
           stream.readInt() match {
@@ -110,13 +111,13 @@ class PythonUDFRunner(
     jobArtifactUUID: Option[String])
   extends BasePythonUDFRunner(funcs, evalType, argOffsets, pythonMetrics, 
jobArtifactUUID) {
 
-  protected override def newWriterThread(
+  protected override def newWriter(
       env: SparkEnv,
-      worker: Socket,
+      worker: PythonWorker,
       inputIterator: Iterator[Array[Byte]],
       partitionIndex: Int,
-      context: TaskContext): WriterThread = {
-    new PythonUDFWriterThread(env, worker, inputIterator, partitionIndex, 
context) {
+      context: TaskContext): Writer = {
+    new PythonUDFWriter(env, worker, inputIterator, partitionIndex, context) {
 
       protected override def writeCommand(dataOut: DataOutputStream): Unit = {
         PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
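
The `PythonUDFWriter` above follows a simple framing contract: write one element per call, then emit an end-of-data marker once the iterator is exhausted. A generic, self-contained sketch of that contract (the length-prefixed framing and the sentinel value here are illustrative, not Spark's exact wire format):

```
import java.io.DataOutputStream

object FramedWriteExample {
  // Illustrative sentinel; Spark uses SpecialLengths.END_OF_DATA_SECTION.
  private val EndOfData = -1

  // Returns true while the caller should keep calling; false once the sentinel is written.
  def writeNext(it: Iterator[Array[Byte]], out: DataOutputStream): Boolean = {
    if (it.hasNext) {
      val bytes = it.next()
      out.writeInt(bytes.length) // length-prefixed element
      out.write(bytes)
      true
    } else {
      out.writeInt(EndOfData)    // tell the worker that no more input is coming
      false
    }
  }
}
```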
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 36cb2e17835..5fa9c89b3d1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -17,8 +17,9 @@
 
 package org.apache.spark.sql.execution.python
 
-import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, 
DataOutputStream, EOFException}
-import java.net.Socket
+import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, 
DataOutputStream, EOFException, InputStream}
+import java.nio.ByteBuffer
+import java.nio.channels.SelectionKey
 import java.nio.charset.StandardCharsets
 import java.util.HashMap
 
@@ -27,7 +28,7 @@ import scala.collection.JavaConverters._
 import net.razorvine.pickle.Pickler
 
 import org.apache.spark.{JobArtifactSet, SparkEnv, SparkException}
-import org.apache.spark.api.python.{PythonEvalType, PythonFunction, 
PythonWorkerUtils, SpecialLengths}
+import org.apache.spark.api.python.{PythonEvalType, PythonFunction, 
PythonWorker, PythonWorkerUtils, SpecialLengths}
 import org.apache.spark.internal.config.BUFFER_SIZE
 import org.apache.spark.internal.config.Python._
 import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession}
@@ -36,6 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Generate, 
LogicalPlan, OneRo
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.util.DirectByteBufferOutputStream
 
 /**
  * A user-defined Python function. This is used by the Python API.
@@ -205,13 +207,14 @@ object UserDefinedPythonTableFunction {
     val pickler = new Pickler(/* useMemo = */ true,
       /* valueCompare = */ false)
 
-    val (worker: Socket, _) =
+    val (worker: PythonWorker, _) =
       env.createPythonWorker(pythonExec, workerModule, envVars.asScala.toMap)
     var releasedOrClosed = false
+    val bufferStream = new DirectByteBufferOutputStream()
     try {
-      val dataOut =
-        new DataOutputStream(new BufferedOutputStream(worker.getOutputStream, bufferSize))
-      val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
+      val dataOut = new DataOutputStream(new BufferedOutputStream(bufferStream, bufferSize))
+      val dataIn = new DataInputStream(new BufferedInputStream(
+        new WorkerInputStream(worker, bufferStream), bufferSize))
 
       PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
       PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, 
dataOut)
@@ -276,4 +279,50 @@ object UserDefinedPythonTableFunction {
       }
     }
   }
+
+  /**
+   * A wrapper around the non-blocking I/O used to write to/read from the worker.
+   *
+   * Since we use non-blocking I/O to communicate with workers (see SPARK-44705),
+   * a wrapper is needed to do I/O with the worker.
+   * This is a simplified port of `PythonRunner.ReaderInputStream` that only supports
+   * writing all the input at once and then reading all the output.
+   */
+  private class WorkerInputStream(
+      worker: PythonWorker, bufferStream: DirectByteBufferOutputStream) extends InputStream {
+
+    private[this] val temp = new Array[Byte](1)
+
+    override def read(): Int = {
+      val n = read(temp)
+      if (n <= 0) {
+        -1
+      } else {
+        // Signed byte to unsigned integer
+        temp(0) & 0xff
+      }
+    }
+
+    override def read(b: Array[Byte], off: Int, len: Int): Int = {
+      val buf = ByteBuffer.wrap(b, off, len)
+      var n = 0
+      while (n == 0) {
+        worker.selector.select()
+        if (worker.selectionKey.isReadable) {
+          n = worker.channel.read(buf)
+        }
+        if (worker.selectionKey.isWritable) {
+          val buffer = bufferStream.toByteBuffer
+          var acceptsInput = true
+          while (acceptsInput && buffer.hasRemaining) {
+            val n = worker.channel.write(buffer)
+            acceptsInput = n > 0
+          }
+          // We no longer have any data to write to the socket.
+          worker.selectionKey.interestOps(SelectionKey.OP_READ)
+        }
+      }
+      n
+    }
+  }
 }
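
The `WorkerInputStream` above leans on the standard Java NIO selector pattern: flush buffered output while the channel is writable, and read only when it is readable. A self-contained sketch of that pattern against a plain `SocketChannel` (the host, port, and payload are placeholders; this is not Spark code):

```
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.{SelectionKey, Selector, SocketChannel}

object NonBlockingRoundTrip {
  def main(args: Array[String]): Unit = {
    // Connect (blocking), then switch to non-blocking mode for selector-driven I/O.
    val channel = SocketChannel.open(new InetSocketAddress("localhost", 9999))
    channel.configureBlocking(false)
    val selector = Selector.open()
    val key = channel.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE)

    val pendingOutput = ByteBuffer.wrap("ping".getBytes("UTF-8"))
    val readBuffer = ByteBuffer.allocate(1024)
    var done = false

    while (!done) {
      selector.select()
      if (key.isWritable && pendingOutput.hasRemaining) {
        channel.write(pendingOutput)
        if (!pendingOutput.hasRemaining) {
          // Nothing left to send; from now on only wait for the reply.
          key.interestOps(SelectionKey.OP_READ)
        }
      }
      if (key.isReadable) {
        done = channel.read(readBuffer) != 0 // stop on data or end-of-stream
      }
    }
    channel.close()
    selector.close()
  }
}
```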


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

