This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new c6aaa18e6cfd Revert "[SPARK-45302][PYTHON] Remove PID communication between Pythonworkers when no demon is used" c6aaa18e6cfd is described below commit c6aaa18e6cfd49b434f782171e42778012672b80 Author: Hyukjin Kwon <gurwls...@apache.org> AuthorDate: Thu Apr 25 11:57:23 2024 +0900 Revert "[SPARK-45302][PYTHON] Remove PID communication between Pythonworkers when no demon is used" ### What changes were proposed in this pull request? This PR reverts https://github.com/apache/spark/pull/43087. ### Why are the changes needed? To clean up those workers. I will make a refactoring PR soon. I will bring them back again with a refactoring PR. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46195 from HyukjinKwon/SPARK-45302-revert. Authored-by: Hyukjin Kwon <gurwls...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 6 +++--- .../scala/org/apache/spark/api/python/PythonRunner.scala | 10 +++++----- .../org/apache/spark/api/python/PythonWorkerFactory.scala | 15 ++++++++------- python/pyspark/daemon.py | 4 ++-- .../sql/connect/streaming/worker/foreach_batch_worker.py | 2 ++ .../sql/connect/streaming/worker/listener_worker.py | 2 ++ .../sql/streaming/python_streaming_source_runner.py | 2 ++ python/pyspark/sql/worker/analyze_udtf.py | 3 +++ python/pyspark/sql/worker/commit_data_source_write.py | 2 ++ python/pyspark/sql/worker/create_data_source.py | 2 ++ python/pyspark/sql/worker/lookup_data_sources.py | 2 ++ python/pyspark/sql/worker/plan_data_source_read.py | 2 ++ python/pyspark/sql/worker/python_streaming_sink_runner.py | 2 ++ python/pyspark/sql/worker/write_into_data_source.py | 2 ++ python/pyspark/worker.py | 3 +++ .../spark/sql/execution/python/PythonArrowOutput.scala | 2 +- .../spark/sql/execution/python/PythonUDFRunner.scala | 2 +- 17 files changed, 44 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 50d0358004d4..e1c84d181a2f 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -142,7 +142,7 @@ class SparkEnv ( workerModule: String, daemonModule: String, envVars: Map[String, String], - useDaemon: Boolean): (PythonWorker, Option[Long]) = { + useDaemon: Boolean): (PythonWorker, Option[Int]) = { synchronized { val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, envVars) val workerFactory = pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory( @@ -161,7 +161,7 @@ class SparkEnv ( pythonExec: String, workerModule: String, envVars: Map[String, String], - useDaemon: Boolean): (PythonWorker, Option[Long]) = { + useDaemon: Boolean): (PythonWorker, Option[Int]) = { createPythonWorker( pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars, useDaemon) } @@ -170,7 +170,7 @@ class SparkEnv ( pythonExec: String, workerModule: String, daemonModule: String, - envVars: Map[String, String]): (PythonWorker, Option[Long]) = { + envVars: Map[String, String]): (PythonWorker, Option[Int]) = { val useDaemon = conf.get(Python.PYTHON_USE_DAEMON) createPythonWorker( pythonExec, workerModule, daemonModule, envVars, useDaemon) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 17cb0c5a55dd..7ff782db210d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -88,7 +88,7 @@ private object BasePythonRunner { private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler") - private def faultHandlerLogPath(pid: Long): Path = { + private def faultHandlerLogPath(pid: Int): Path = { new File(faultHandlerLogDir, pid.toString).toPath } } @@ -204,7 +204,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default")) - val (worker: PythonWorker, pid: Option[Long]) = env.createPythonWorker( + val (worker: PythonWorker, pid: Option[Int]) = env.createPythonWorker( pythonExec, workerModule, daemonModule, envVars.asScala.toMap) // Whether is the worker released into idle pool or closed. When any codes try to release or // close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make @@ -257,7 +257,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( startTime: Long, env: SparkEnv, worker: PythonWorker, - pid: Option[Long], + pid: Option[Int], releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[OUT] @@ -465,7 +465,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( startTime: Long, env: SparkEnv, worker: PythonWorker, - pid: Option[Long], + pid: Option[Int], releasedOrClosed: AtomicBoolean, context: TaskContext) extends Iterator[OUT] { @@ -842,7 +842,7 @@ private[spark] class PythonRunner( startTime: Long, env: SparkEnv, worker: PythonWorker, - pid: Option[Long], + pid: Option[Int], releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[Array[Byte]] = { new ReaderIterator( diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index eb740b72987c..f8260e177cc3 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -92,7 +92,7 @@ private[spark] class PythonWorkerFactory( envVars.getOrElse("PYTHONPATH", ""), sys.env.getOrElse("PYTHONPATH", "")) - def create(): (PythonWorker, Option[Long]) = { + def create(): (PythonWorker, Option[Int]) = { if (useDaemon) { self.synchronized { // Pull from idle workers until we one that is alive, otherwise create a new one. @@ -102,7 +102,7 @@ private[spark] class PythonWorkerFactory( if (workerHandle.isAlive()) { try { worker.selectionKey.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE) - return (worker, Some(workerHandle.pid())) + return (worker, Some(workerHandle.pid().toInt)) } catch { case c: CancelledKeyException => /* pass */ } @@ -122,9 +122,9 @@ private[spark] class PythonWorkerFactory( * processes itself to avoid the high cost of forking from Java. This currently only works * on UNIX-based systems. */ - private def createThroughDaemon(): (PythonWorker, Option[Long]) = { + private def createThroughDaemon(): (PythonWorker, Option[Int]) = { - def createWorker(): (PythonWorker, Option[Long]) = { + def createWorker(): (PythonWorker, Option[Int]) = { val socketChannel = SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort)) // These calls are blocking. val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt() @@ -165,7 +165,7 @@ private[spark] class PythonWorkerFactory( /** * Launch a worker by executing worker.py (by default) directly and telling it to connect to us. */ - private[spark] def createSimpleWorker(blockingMode: Boolean): (PythonWorker, Option[Long]) = { + private[spark] def createSimpleWorker(blockingMode: Boolean): (PythonWorker, Option[Int]) = { var serverSocketChannel: ServerSocketChannel = null try { serverSocketChannel = ServerSocketChannel.open() @@ -209,7 +209,8 @@ private[spark] class PythonWorkerFactory( "Timed out while waiting for the Python worker to connect back") } authHelper.authClient(socketChannel.socket()) - val pid = workerProcess.toHandle.pid() + // TODO: When we drop JDK 8, we can just use workerProcess.pid() + val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt() if (pid < 0) { throw new IllegalStateException("Python failed to launch worker with code " + pid) } @@ -405,7 +406,7 @@ private[spark] class PythonWorkerFactory( daemonWorkers.get(worker).foreach { processHandle => // tell daemon to kill worker by pid val output = new DataOutputStream(daemon.getOutputStream) - output.writeLong(processHandle.pid()) + output.writeInt(processHandle.pid().toInt) output.flush() daemon.getOutputStream.flush() } diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index b0e06d13beda..bbbc495d053e 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -28,7 +28,7 @@ from errno import EINTR, EAGAIN from socket import AF_INET, AF_INET6, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT -from pyspark.serializers import read_long, write_int, write_with_length, UTF8Deserializer +from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer if len(sys.argv) > 1: import importlib @@ -139,7 +139,7 @@ def manager(): if 0 in ready_fds: try: - worker_pid = read_long(stdin_bin) + worker_pid = read_int(stdin_bin) except EOFError: # Spark told us to exit by closing stdin shutdown(0) diff --git a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py index 80cc69126916..0c92de6372b6 100644 --- a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +++ b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py @@ -96,4 +96,6 @@ if __name__ == "__main__": (sock_file, sock) = local_connect_and_auth(java_port, auth_secret) # There could be a long time between each micro batch. sock.settimeout(None) + write_int(os.getpid(), sock_file) + sock_file.flush() main(sock_file, sock_file) diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py index 3709e50ba026..a7a5066ca0d7 100644 --- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py +++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py @@ -110,4 +110,6 @@ if __name__ == "__main__": (sock_file, sock) = local_connect_and_auth(java_port, auth_secret) # There could be a long time between each listener event. sock.settimeout(None) + write_int(os.getpid(), sock_file) + sock_file.flush() main(sock_file, sock_file) diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py index 76f9048e3edb..8109403b42dd 100644 --- a/python/pyspark/sql/streaming/python_streaming_source_runner.py +++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py @@ -163,4 +163,6 @@ if __name__ == "__main__": java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + write_int(os.getpid(), sock_file) + sock_file.flush() main(sock_file, sock_file) diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py index d0a24363c0c1..7dafb87c4221 100644 --- a/python/pyspark/sql/worker/analyze_udtf.py +++ b/python/pyspark/sql/worker/analyze_udtf.py @@ -264,4 +264,7 @@ if __name__ == "__main__": java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + # TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8. + write_int(os.getpid(), sock_file) + sock_file.flush() main(sock_file, sock_file) diff --git a/python/pyspark/sql/worker/commit_data_source_write.py b/python/pyspark/sql/worker/commit_data_source_write.py index 530f18ef8288..cf22c19ab3eb 100644 --- a/python/pyspark/sql/worker/commit_data_source_write.py +++ b/python/pyspark/sql/worker/commit_data_source_write.py @@ -117,4 +117,6 @@ if __name__ == "__main__": java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + write_int(os.getpid(), sock_file) + sock_file.flush() main(sock_file, sock_file) diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py index 1f11b65f44c7..33394cdff876 100644 --- a/python/pyspark/sql/worker/create_data_source.py +++ b/python/pyspark/sql/worker/create_data_source.py @@ -187,4 +187,6 @@ if __name__ == "__main__": java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + write_int(os.getpid(), sock_file) + sock_file.flush() main(sock_file, sock_file) diff --git a/python/pyspark/sql/worker/lookup_data_sources.py b/python/pyspark/sql/worker/lookup_data_sources.py index 7f0127b71946..6da9d5925f63 100644 --- a/python/pyspark/sql/worker/lookup_data_sources.py +++ b/python/pyspark/sql/worker/lookup_data_sources.py @@ -95,4 +95,6 @@ if __name__ == "__main__": java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + write_int(os.getpid(), sock_file) + sock_file.flush() main(sock_file, sock_file) diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py index 6c0d48caefeb..8a8b2cab91d8 100644 --- a/python/pyspark/sql/worker/plan_data_source_read.py +++ b/python/pyspark/sql/worker/plan_data_source_read.py @@ -299,4 +299,6 @@ if __name__ == "__main__": java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + write_int(os.getpid(), sock_file) + sock_file.flush() main(sock_file, sock_file) diff --git a/python/pyspark/sql/worker/python_streaming_sink_runner.py b/python/pyspark/sql/worker/python_streaming_sink_runner.py index ba0a8037de60..d14585eab51d 100644 --- a/python/pyspark/sql/worker/python_streaming_sink_runner.py +++ b/python/pyspark/sql/worker/python_streaming_sink_runner.py @@ -137,4 +137,6 @@ if __name__ == "__main__": java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + write_int(os.getpid(), sock_file) + sock_file.flush() main(sock_file, sock_file) diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py index ad8717cb33b5..5714f35cbe71 100644 --- a/python/pyspark/sql/worker/write_into_data_source.py +++ b/python/pyspark/sql/worker/write_into_data_source.py @@ -229,4 +229,6 @@ if __name__ == "__main__": java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + write_int(os.getpid(), sock_file) + sock_file.flush() main(sock_file, sock_file) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 41f6c35bc445..e9c259e68a27 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -1868,4 +1868,7 @@ if __name__ == "__main__": java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + # TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8. + write_int(os.getpid(), sock_file) + sock_file.flush() main(sock_file, sock_file) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala index 90922d89ad10..e7d4aa9f0460 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala @@ -49,7 +49,7 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[ startTime: Long, env: SparkEnv, worker: PythonWorker, - pid: Option[Long], + pid: Option[Int], releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[OUT] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index bbe9fbfc748d..87ff5a0ec433 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -80,7 +80,7 @@ abstract class BasePythonUDFRunner( startTime: Long, env: SparkEnv, worker: PythonWorker, - pid: Option[Long], + pid: Option[Int], releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[Array[Byte]] = { new ReaderIterator( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org