This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 17430fe4702 [SPARK-45302][PYTHON] Remove PID communication between 
Python workers when no daemon is used
17430fe4702 is described below

commit 17430fe47029f1d27c7913468b95abfd856fddcc
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Wed Sep 27 10:48:17 2023 +0900

    [SPARK-45302][PYTHON] Remove PID communication between Python workers when 
no daemon is used
    
    ### What changes were proposed in this pull request?
    
    This PR removes the legacy workaround for JDK 8 in `PythonWorkerFactory`.
    
    ### Why are the changes needed?
    
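    Since JDK 9, the worker's PID can be obtained directly from the launched process (`Process.toHandle().pid()`), so the Python worker no longer needs to send its PID back through the socket.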
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing unit tests already cover the case where the daemon is disabled.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43087 from HyukjinKwon/SPARK-45302.
    
    Lead-authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Co-authored-by: Hyukjin Kwon <gurwls...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 core/src/main/scala/org/apache/spark/SparkEnv.scala       |  4 ++--
 .../scala/org/apache/spark/api/python/PythonRunner.scala  | 10 +++++-----
 .../org/apache/spark/api/python/PythonWorkerFactory.scala | 15 +++++++--------
 python/pyspark/daemon.py                                  |  4 ++--
 .../sql/connect/streaming/worker/foreach_batch_worker.py  |  2 --
 .../sql/connect/streaming/worker/listener_worker.py       |  2 --
 python/pyspark/sql/worker/analyze_udtf.py                 |  3 ---
 python/pyspark/worker.py                                  |  3 ---
 .../spark/sql/execution/python/PythonArrowOutput.scala    |  2 +-
 .../spark/sql/execution/python/PythonUDFRunner.scala      |  2 +-
 10 files changed, 18 insertions(+), 29 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala 
b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index e404c9ee8b4..937170b5ee8 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -128,7 +128,7 @@ class SparkEnv (
       pythonExec: String,
       workerModule: String,
       daemonModule: String,
-      envVars: Map[String, String]): (PythonWorker, Option[Int]) = {
+      envVars: Map[String, String]): (PythonWorker, Option[Long]) = {
     synchronized {
       val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, 
envVars)
       pythonWorkers.getOrElseUpdate(key,
@@ -139,7 +139,7 @@ class SparkEnv (
   private[spark] def createPythonWorker(
       pythonExec: String,
       workerModule: String,
-      envVars: Map[String, String]): (PythonWorker, Option[Int]) = {
+      envVars: Map[String, String]): (PythonWorker, Option[Long]) = {
     createPythonWorker(
       pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, 
envVars)
   }
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index db95e6c2bd6..2a63298d0a1 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -84,7 +84,7 @@ private object BasePythonRunner {
 
   private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = 
"faulthandler")
 
-  private def faultHandlerLogPath(pid: Int): Path = {
+  private def faultHandlerLogPath(pid: Long): Path = {
     new File(faultHandlerLogDir, pid.toString).toPath
   }
 }
@@ -200,7 +200,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
 
     envVars.put("SPARK_JOB_ARTIFACT_UUID", 
jobArtifactUUID.getOrElse("default"))
 
-    val (worker: PythonWorker, pid: Option[Int]) = env.createPythonWorker(
+    val (worker: PythonWorker, pid: Option[Long]) = env.createPythonWorker(
       pythonExec, workerModule, daemonModule, envVars.asScala.toMap)
     // Whether is the worker released into idle pool or closed. When any codes 
try to release or
     // close a worker, they should use `releasedOrClosed.compareAndSet` to 
flip the state to make
@@ -253,7 +253,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       startTime: Long,
       env: SparkEnv,
       worker: PythonWorker,
-      pid: Option[Int],
+      pid: Option[Long],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[OUT]
 
@@ -463,7 +463,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       startTime: Long,
       env: SparkEnv,
       worker: PythonWorker,
-      pid: Option[Int],
+      pid: Option[Long],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext)
     extends Iterator[OUT] {
@@ -838,7 +838,7 @@ private[spark] class PythonRunner(
       startTime: Long,
       env: SparkEnv,
       worker: PythonWorker,
-      pid: Option[Int],
+      pid: Option[Long],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[Array[Byte]] = {
     new ReaderIterator(
diff --git 
a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index d0776eb2cc7..cef815b22ac 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -76,7 +76,7 @@ private[spark] class PythonWorkerFactory(
   @GuardedBy("self")
   private var daemonPort: Int = 0
   @GuardedBy("self")
-  private val daemonWorkers = new mutable.WeakHashMap[PythonWorker, Int]()
+  private val daemonWorkers = new mutable.WeakHashMap[PythonWorker, Long]()
   @GuardedBy("self")
   private val idleWorkers = new mutable.Queue[PythonWorker]()
   @GuardedBy("self")
@@ -91,7 +91,7 @@ private[spark] class PythonWorkerFactory(
     envVars.getOrElse("PYTHONPATH", ""),
     sys.env.getOrElse("PYTHONPATH", ""))
 
-  def create(): (PythonWorker, Option[Int]) = {
+  def create(): (PythonWorker, Option[Long]) = {
     if (useDaemon) {
       self.synchronized {
         if (idleWorkers.nonEmpty) {
@@ -111,9 +111,9 @@ private[spark] class PythonWorkerFactory(
    * processes itself to avoid the high cost of forking from Java. This 
currently only works
    * on UNIX-based systems.
    */
-  private def createThroughDaemon(): (PythonWorker, Option[Int]) = {
+  private def createThroughDaemon(): (PythonWorker, Option[Long]) = {
 
-    def createWorker(): (PythonWorker, Option[Int]) = {
+    def createWorker(): (PythonWorker, Option[Long]) = {
       val socketChannel = SocketChannel.open(new InetSocketAddress(daemonHost, 
daemonPort))
       // These calls are blocking.
       val pid = new 
DataInputStream(Channels.newInputStream(socketChannel)).readInt()
@@ -153,7 +153,7 @@ private[spark] class PythonWorkerFactory(
   /**
    * Launch a worker by executing worker.py (by default) directly and telling 
it to connect to us.
    */
-  private[spark] def createSimpleWorker(blockingMode: Boolean): (PythonWorker, 
Option[Int]) = {
+  private[spark] def createSimpleWorker(blockingMode: Boolean): (PythonWorker, 
Option[Long]) = {
     var serverSocketChannel: ServerSocketChannel = null
     try {
       serverSocketChannel = ServerSocketChannel.open()
@@ -189,8 +189,7 @@ private[spark] class PythonWorkerFactory(
       try {
         val socketChannel = serverSocketChannel.accept()
         authHelper.authClient(socketChannel.socket())
-        // TODO: When we drop JDK 8, we can just use workerProcess.pid()
-        val pid = new 
DataInputStream(Channels.newInputStream(socketChannel)).readInt()
+        val pid = workerProcess.toHandle.pid()
         if (pid < 0) {
           throw new IllegalStateException("Python failed to launch worker with 
code " + pid)
         }
@@ -386,7 +385,7 @@ private[spark] class PythonWorkerFactory(
           daemonWorkers.get(worker).foreach { pid =>
             // tell daemon to kill worker by pid
             val output = new DataOutputStream(daemon.getOutputStream)
-            output.writeInt(pid)
+            output.writeLong(pid)
             output.flush()
             daemon.getOutputStream.flush()
           }
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index bbbc495d053..b0e06d13bed 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -28,7 +28,7 @@ from errno import EINTR, EAGAIN
 from socket import AF_INET, AF_INET6, SOCK_STREAM, SOMAXCONN
 from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT
 
-from pyspark.serializers import read_int, write_int, write_with_length, 
UTF8Deserializer
+from pyspark.serializers import read_long, write_int, write_with_length, 
UTF8Deserializer
 
 if len(sys.argv) > 1:
     import importlib
@@ -139,7 +139,7 @@ def manager():
 
             if 0 in ready_fds:
                 try:
-                    worker_pid = read_int(stdin_bin)
+                    worker_pid = read_long(stdin_bin)
                 except EOFError:
                     # Spark told us to exit by closing stdin
                     shutdown(0)
diff --git 
a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py 
b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
index 06534e355de..022e768c43b 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
@@ -87,6 +87,4 @@ if __name__ == "__main__":
     (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
     # There could be a long time between each micro batch.
     sock.settimeout(None)
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
     main(sock_file, sock_file)
diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py 
b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
index ed38a788435..bb6bcd5d965 100644
--- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
@@ -101,6 +101,4 @@ if __name__ == "__main__":
     (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
     # There could be a long time between each listener event.
     sock.settimeout(None)
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
     main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/analyze_udtf.py 
b/python/pyspark/sql/worker/analyze_udtf.py
index 194cd3db765..6fb3ca995e5 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -164,7 +164,4 @@ if __name__ == "__main__":
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
     (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
-    # TODO: Remove the following two lines and use `Process.pid()` when we 
drop JDK 8.
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
     main(sock_file, sock_file)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index a3c7bbb59dd..77481704979 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1352,7 +1352,4 @@ if __name__ == "__main__":
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
     (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
-    # TODO: Remove the following two lines and use `Process.pid()` when we 
drop JDK 8.
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
     main(sock_file, sock_file)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
index 8f99325e4e0..2e410eae61e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -49,7 +49,7 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { 
self: BasePythonRunner[
       startTime: Long,
       env: SparkEnv,
       worker: PythonWorker,
-      pid: Option[Int],
+      pid: Option[Long],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[OUT] = {
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index b99517f544d..12c51506b13 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -78,7 +78,7 @@ abstract class BasePythonUDFRunner(
       startTime: Long,
       env: SparkEnv,
       worker: PythonWorker,
-      pid: Option[Int],
+      pid: Option[Long],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[Array[Byte]] = {
     new ReaderIterator(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to