This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 17430fe4702 [SPARK-45302][PYTHON] Remove PID communication between Python workers when no daemon is used 17430fe4702 is described below commit 17430fe47029f1d27c7913468b95abfd856fddcc Author: Hyukjin Kwon <gurwls...@apache.org> AuthorDate: Wed Sep 27 10:48:17 2023 +0900 [SPARK-45302][PYTHON] Remove PID communication between Python workers when no daemon is used ### What changes were proposed in this pull request? This PR removes the legacy workaround for JDK 8 in `PythonWorkerFactory`. ### Why are the changes needed? Since JDK 9, the JVM can read a child process's PID directly via `Process.toHandle.pid()`, so the Python worker no longer needs to send its PID back through the socket. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing unit tests already cover the configuration where the daemon is disabled. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43087 from HyukjinKwon/SPARK-45302. Lead-authored-by: Hyukjin Kwon <gurwls...@apache.org> Co-authored-by: Hyukjin Kwon <gurwls...@gmail.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 4 ++-- .../scala/org/apache/spark/api/python/PythonRunner.scala | 10 +++++----- .../org/apache/spark/api/python/PythonWorkerFactory.scala | 15 +++++++-------- python/pyspark/daemon.py | 4 ++-- .../sql/connect/streaming/worker/foreach_batch_worker.py | 2 -- .../sql/connect/streaming/worker/listener_worker.py | 2 -- python/pyspark/sql/worker/analyze_udtf.py | 3 --- python/pyspark/worker.py | 3 --- .../spark/sql/execution/python/PythonArrowOutput.scala | 2 +- .../spark/sql/execution/python/PythonUDFRunner.scala | 2 +- 10 files changed, 18 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index e404c9ee8b4..937170b5ee8 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -128,7 +128,7 @@ 
class SparkEnv ( pythonExec: String, workerModule: String, daemonModule: String, - envVars: Map[String, String]): (PythonWorker, Option[Int]) = { + envVars: Map[String, String]): (PythonWorker, Option[Long]) = { synchronized { val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, envVars) pythonWorkers.getOrElseUpdate(key, @@ -139,7 +139,7 @@ class SparkEnv ( private[spark] def createPythonWorker( pythonExec: String, workerModule: String, - envVars: Map[String, String]): (PythonWorker, Option[Int]) = { + envVars: Map[String, String]): (PythonWorker, Option[Long]) = { createPythonWorker( pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index db95e6c2bd6..2a63298d0a1 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -84,7 +84,7 @@ private object BasePythonRunner { private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler") - private def faultHandlerLogPath(pid: Int): Path = { + private def faultHandlerLogPath(pid: Long): Path = { new File(faultHandlerLogDir, pid.toString).toPath } } @@ -200,7 +200,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default")) - val (worker: PythonWorker, pid: Option[Int]) = env.createPythonWorker( + val (worker: PythonWorker, pid: Option[Long]) = env.createPythonWorker( pythonExec, workerModule, daemonModule, envVars.asScala.toMap) // Whether is the worker released into idle pool or closed. 
When any codes try to release or // close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make @@ -253,7 +253,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( startTime: Long, env: SparkEnv, worker: PythonWorker, - pid: Option[Int], + pid: Option[Long], releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[OUT] @@ -463,7 +463,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( startTime: Long, env: SparkEnv, worker: PythonWorker, - pid: Option[Int], + pid: Option[Long], releasedOrClosed: AtomicBoolean, context: TaskContext) extends Iterator[OUT] { @@ -838,7 +838,7 @@ private[spark] class PythonRunner( startTime: Long, env: SparkEnv, worker: PythonWorker, - pid: Option[Int], + pid: Option[Long], releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[Array[Byte]] = { new ReaderIterator( diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index d0776eb2cc7..cef815b22ac 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -76,7 +76,7 @@ private[spark] class PythonWorkerFactory( @GuardedBy("self") private var daemonPort: Int = 0 @GuardedBy("self") - private val daemonWorkers = new mutable.WeakHashMap[PythonWorker, Int]() + private val daemonWorkers = new mutable.WeakHashMap[PythonWorker, Long]() @GuardedBy("self") private val idleWorkers = new mutable.Queue[PythonWorker]() @GuardedBy("self") @@ -91,7 +91,7 @@ private[spark] class PythonWorkerFactory( envVars.getOrElse("PYTHONPATH", ""), sys.env.getOrElse("PYTHONPATH", "")) - def create(): (PythonWorker, Option[Int]) = { + def create(): (PythonWorker, Option[Long]) = { if (useDaemon) { self.synchronized { if (idleWorkers.nonEmpty) { @@ -111,9 +111,9 @@ private[spark] class PythonWorkerFactory( * processes itself to avoid 
the high cost of forking from Java. This currently only works * on UNIX-based systems. */ - private def createThroughDaemon(): (PythonWorker, Option[Int]) = { + private def createThroughDaemon(): (PythonWorker, Option[Long]) = { - def createWorker(): (PythonWorker, Option[Int]) = { + def createWorker(): (PythonWorker, Option[Long]) = { val socketChannel = SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort)) // These calls are blocking. val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt() @@ -153,7 +153,7 @@ private[spark] class PythonWorkerFactory( /** * Launch a worker by executing worker.py (by default) directly and telling it to connect to us. */ - private[spark] def createSimpleWorker(blockingMode: Boolean): (PythonWorker, Option[Int]) = { + private[spark] def createSimpleWorker(blockingMode: Boolean): (PythonWorker, Option[Long]) = { var serverSocketChannel: ServerSocketChannel = null try { serverSocketChannel = ServerSocketChannel.open() @@ -189,8 +189,7 @@ private[spark] class PythonWorkerFactory( try { val socketChannel = serverSocketChannel.accept() authHelper.authClient(socketChannel.socket()) - // TODO: When we drop JDK 8, we can just use workerProcess.pid() - val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt() + val pid = workerProcess.toHandle.pid() if (pid < 0) { throw new IllegalStateException("Python failed to launch worker with code " + pid) } @@ -386,7 +385,7 @@ private[spark] class PythonWorkerFactory( daemonWorkers.get(worker).foreach { pid => // tell daemon to kill worker by pid val output = new DataOutputStream(daemon.getOutputStream) - output.writeInt(pid) + output.writeLong(pid) output.flush() daemon.getOutputStream.flush() } diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index bbbc495d053..b0e06d13bed 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -28,7 +28,7 @@ from errno import EINTR, EAGAIN from socket import AF_INET, AF_INET6, 
SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT -from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer +from pyspark.serializers import read_long, write_int, write_with_length, UTF8Deserializer if len(sys.argv) > 1: import importlib @@ -139,7 +139,7 @@ def manager(): if 0 in ready_fds: try: - worker_pid = read_int(stdin_bin) + worker_pid = read_long(stdin_bin) except EOFError: # Spark told us to exit by closing stdin shutdown(0) diff --git a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py index 06534e355de..022e768c43b 100644 --- a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +++ b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py @@ -87,6 +87,4 @@ if __name__ == "__main__": (sock_file, sock) = local_connect_and_auth(java_port, auth_secret) # There could be a long time between each micro batch. sock.settimeout(None) - write_int(os.getpid(), sock_file) - sock_file.flush() main(sock_file, sock_file) diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py index ed38a788435..bb6bcd5d965 100644 --- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py +++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py @@ -101,6 +101,4 @@ if __name__ == "__main__": (sock_file, sock) = local_connect_and_auth(java_port, auth_secret) # There could be a long time between each listener event. 
sock.settimeout(None) - write_int(os.getpid(), sock_file) - sock_file.flush() main(sock_file, sock_file) diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py index 194cd3db765..6fb3ca995e5 100644 --- a/python/pyspark/sql/worker/analyze_udtf.py +++ b/python/pyspark/sql/worker/analyze_udtf.py @@ -164,7 +164,4 @@ if __name__ == "__main__": java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] (sock_file, _) = local_connect_and_auth(java_port, auth_secret) - # TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8. - write_int(os.getpid(), sock_file) - sock_file.flush() main(sock_file, sock_file) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index a3c7bbb59dd..77481704979 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -1352,7 +1352,4 @@ if __name__ == "__main__": java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] (sock_file, _) = local_connect_and_auth(java_port, auth_secret) - # TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8. 
- write_int(os.getpid(), sock_file) - sock_file.flush() main(sock_file, sock_file) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala index 8f99325e4e0..2e410eae61e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala @@ -49,7 +49,7 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[ startTime: Long, env: SparkEnv, worker: PythonWorker, - pid: Option[Int], + pid: Option[Long], releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[OUT] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index b99517f544d..12c51506b13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -78,7 +78,7 @@ abstract class BasePythonUDFRunner( startTime: Long, env: SparkEnv, worker: PythonWorker, - pid: Option[Int], + pid: Option[Long], releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[Array[Byte]] = { new ReaderIterator( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org