This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 203942c58318 [SPARK-51688][PYTHON][FOLLOW-UP] Implement UDS in Accumulators
203942c58318 is described below

commit 203942c583187b6d0a012ff2b2d6aab5c664bd39
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Tue Apr 15 16:52:01 2025 +0900

    [SPARK-51688][PYTHON][FOLLOW-UP] Implement UDS in Accumulators

    ### What changes were proposed in this pull request?

    This PR is a follow-up of https://github.com/apache/spark/pull/50466 that enables Unix Domain Sockets in Accumulators as well. This is the last PR needed to complete UDS support in PySpark.

    ### Why are the changes needed?

    This was not handled in the original PR because it was a bit more involved, so it was separated out. See the original PR.

    ### Does this PR introduce _any_ user-facing change?

    See https://github.com/apache/spark/pull/50466.

    ### How was this patch tested?

    CI in this PR, with the feature enabled by default. A scheduled build will also be set up through https://github.com/apache/spark/pull/50585.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #50587 from HyukjinKwon/SPARK-51688-followup.

    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../org/apache/spark/api/python/PythonRDD.scala | 47 +++++++++++------
 python/pyspark/accumulators.py                  | 60 ++++++++++++++++------
 python/pyspark/core/context.py                  | 21 ++++++--
 3 files changed, 94 insertions(+), 34 deletions(-)
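For illustration, enabling the Unix Domain Socket path from user code could look like the sketch below; it only uses the two configuration keys that the `context.py` hunk further down reads (`spark.python.unix.domain.socket.enabled` and `spark.python.unix.domain.socket.dir`), and the directory value here is a placeholder:

```python
# Sketch: turning on the UDS-based accumulator channel via Spark conf.
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .config("spark.python.unix.domain.socket.enabled", "true")
    # Optional; when unset, context.py falls back to the JVM's java.io.tmpdir.
    .config("spark.python.unix.domain.socket.dir", "/tmp")
    .getOrCreate()
)

# Accumulator updates from executors now reach the driver over a Unix domain
# socket instead of a localhost TCP connection.
acc = spark.sparkContext.accumulator(0)
spark.sparkContext.parallelize(range(10)).foreach(lambda x: acc.add(x))
print(acc.value)  # 45
```

Whether the flag is on by default is governed by the JVM-side `PYTHON_UNIX_DOMAIN_SOCKET_ENABLED` config used in the Scala hunk below; see https://github.com/apache/spark/pull/50466 for the overall rollout.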
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 2152724c4c13..e1f16fe32ebe 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -39,8 +39,9 @@ import org.apache.spark.api.python.PythonFunction.PythonAccumulator
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.input.PortableDataStream
 import org.apache.spark.internal.{Logging, MDC}
-import org.apache.spark.internal.LogKeys.{HOST, PORT}
+import org.apache.spark.internal.LogKeys.{HOST, PORT, SOCKET_ADDRESS}
 import org.apache.spark.internal.config.BUFFER_SIZE
+import org.apache.spark.internal.config.Python.PYTHON_UNIX_DOMAIN_SOCKET_ENABLED
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer, SocketFuncServer}
@@ -717,35 +718,50 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By
  * collects a list of pickled strings that we pass to Python through a socket.
  */
 private[spark] class PythonAccumulatorV2(
-    @transient private val serverHost: String,
-    private val serverPort: Int,
-    private val secretToken: String)
+    @transient private val serverHost: Option[String],
+    private val serverPort: Option[Int],
+    private val secretToken: Option[String],
+    @transient private val socketPath: Option[String])
   extends CollectionAccumulator[Array[Byte]] with Logging {
 
-  Utils.checkHost(serverHost)
+  // Unix domain socket
+  def this(socketPath: String) = this(None, None, None, Some(socketPath))
+  // TPC socket
+  def this(serverHost: String, serverPort: Int, secretToken: String) = this(
+    Some(serverHost), Some(serverPort), Some(secretToken), None)
+
+  serverHost.foreach(Utils.checkHost)
 
   val bufferSize = SparkEnv.get.conf.get(BUFFER_SIZE)
+  val isUnixDomainSock = SparkEnv.get.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
 
   /**
    * We try to reuse a single Socket to transfer accumulator updates, as they are all added
   * by the DAGScheduler's single-threaded RpcEndpoint anyway.
    */
-  @transient private var socket: Socket = _
+  @transient private var socket: SocketChannel = _
 
-  private def openSocket(): Socket = synchronized {
-    if (socket == null || socket.isClosed) {
-      socket = new Socket(serverHost, serverPort)
-      logInfo(log"Connected to AccumulatorServer at host: ${MDC(HOST, serverHost)}" +
-        log" port: ${MDC(PORT, serverPort)}")
+  private def openSocket(): SocketChannel = synchronized {
+    if (socket == null || !socket.isOpen) {
+      if (isUnixDomainSock) {
+        socket = SocketChannel.open(UnixDomainSocketAddress.of(socketPath.get))
+        logInfo(log"Connected to AccumulatorServer at socket: ${MDC(SOCKET_ADDRESS, serverHost)}")
+      } else {
+        socket = SocketChannel.open(new InetSocketAddress(serverHost.get, serverPort.get))
+        logInfo(log"Connected to AccumulatorServer at host: ${MDC(HOST, serverHost)}" +
+          log" port: ${MDC(PORT, serverPort)}")
+      }
       // send the secret just for the initial authentication when opening a new connection
-      socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
+      secretToken.foreach { token =>
+        Channels.newOutputStream(socket).write(token.getBytes(StandardCharsets.UTF_8))
+      }
     }
     socket
   }
 
   // Need to override so the types match with PythonFunction
   override def copyAndReset(): PythonAccumulatorV2 = {
-    new PythonAccumulatorV2(serverHost, serverPort, secretToken)
+    new PythonAccumulatorV2(serverHost, serverPort, secretToken, socketPath)
   }
 
   override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized {
@@ -758,8 +774,9 @@ private[spark] class PythonAccumulatorV2(
     } else {
       // This happens on the master, where we pass the updates to Python through a socket
       val socket = openSocket()
-      val in = socket.getInputStream
-      val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
+      val in = Channels.newInputStream(socket)
+      val out = new DataOutputStream(
+        new BufferedOutputStream(Channels.newOutputStream(socket), bufferSize))
       val values = other.value
       out.writeInt(values.size)
       for (array <- values.asScala) {
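The Scala hunk above switches the driver-side client between `new InetSocketAddress(host, port)` and `UnixDomainSocketAddress.of(path)` on the same `SocketChannel` API, and only writes the secret token on the TCP path. For readers more at home in Python than Java NIO, the same client-side switch looks roughly like this (a sketch only; the path, port, and token are placeholders, and the real client is the Scala code above):

```python
import socket


def open_accumulator_client(use_unix_socket: bool) -> socket.socket:
    """Rough Python analogue of the openSocket() logic above (sketch only)."""
    if use_unix_socket:
        # Unix domain socket: the address is a filesystem path, no secret is sent.
        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        s.connect("/tmp/example-accumulator.sock")  # placeholder path
    else:
        # TCP: connect to host/port and authenticate with the shared secret first.
        s = socket.create_connection(("localhost", 50000))  # placeholder address
        s.sendall(b"placeholder-secret-token")
    return s
```

Skipping the token on the Unix-socket path matches the handler change in `accumulators.py` below, which only runs the authentication poll when an auth token is present.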
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 93a64d8eef10..59f7856688ee 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -15,12 +15,13 @@
 # limitations under the License.
 #
 
+import os
 import sys
 import select
 import struct
-import socketserver as SocketServer
+import socketserver
 import threading
-from typing import Callable, Dict, Generic, Tuple, Type, TYPE_CHECKING, TypeVar, Union
+from typing import Callable, Dict, Generic, Tuple, Type, TYPE_CHECKING, TypeVar, Union, Optional
 
 from pyspark.serializers import read_int, CPickleSerializer
 from pyspark.errors import PySparkRuntimeError
@@ -252,7 +253,7 @@ FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)  # type: ignore[type-var]
 COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)  # type: ignore[type-var]
 
 
-class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
+class UpdateRequestHandler(socketserver.StreamRequestHandler):
     """
     This handler will keep polling updates from the same socket until the
@@ -293,37 +294,64 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
                         "The value of the provided token to the AccumulatorServer is not correct."
                     )
 
-        # first we keep polling till we've received the authentication token
-        poll(authenticate_and_accum_updates)
+        # Unix Domain Socket does not need the auth.
+        if auth_token is not None:
+            # first we keep polling till we've received the authentication token
+            poll(authenticate_and_accum_updates)
+
         # now we've authenticated, don't need to check for the token anymore
         poll(accum_updates)
 
 
-class AccumulatorServer(SocketServer.TCPServer):
+class AccumulatorTCPServer(socketserver.TCPServer):
+    server_shutdown = False
+
     def __init__(
         self,
         server_address: Tuple[str, int],
         RequestHandlerClass: Type["socketserver.BaseRequestHandler"],
         auth_token: str,
     ):
-        SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass)
+        super().__init__(server_address, RequestHandlerClass)
         self.auth_token = auth_token
 
-    """
-    A simple TCP server that intercepts shutdown() in order to interrupt
-    our continuous polling on the handler.
-    """
+    def shutdown(self) -> None:
+        self.server_shutdown = True
+        super().shutdown()
+        self.server_close()
+
+
+class AccumulatorUnixServer(socketserver.UnixStreamServer):
     server_shutdown = False
 
+    def __init__(
+        self, socket_path: str, RequestHandlerClass: Type[socketserver.BaseRequestHandler]
+    ):
+        super().__init__(socket_path, RequestHandlerClass)
+        self.auth_token = None
+
     def shutdown(self) -> None:
         self.server_shutdown = True
-        SocketServer.TCPServer.shutdown(self)
+        super().shutdown()
         self.server_close()
+        if os.path.exists(self.server_address):  # type: ignore[arg-type]
+            os.remove(self.server_address)  # type: ignore[arg-type]
+
+
+def _start_update_server(
+    auth_token: str, is_unix_domain_sock: bool, socket_path: Optional[str] = None
+) -> Union[AccumulatorTCPServer, AccumulatorUnixServer]:
+    """Start a TCP or Unix Domain Socket server for accumulator updates."""
+    if is_unix_domain_sock:
+        assert socket_path is not None
+        if os.path.exists(socket_path):
+            os.remove(socket_path)
+        server = AccumulatorUnixServer(socket_path, UpdateRequestHandler)
+    else:
+        server = AccumulatorTCPServer(
+            ("localhost", 0), UpdateRequestHandler, auth_token
+        )  # type: ignore[assignment]
 
-
-def _start_update_server(auth_token: str) -> AccumulatorServer:
-    """Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
-    server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token)
     thread = threading.Thread(target=server.serve_forever)
     thread.daemon = True
     thread.start()
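Both server classes added above follow the same standard-library `socketserver` pattern: a single `StreamRequestHandler` shared by a TCP and a Unix-domain server, with the Unix variant deleting its socket file on shutdown. A stripped-down, self-contained sketch of that pattern (not Spark's handler; the echo logic and socket path are placeholders):

```python
import os
import socketserver
import threading


class EchoHandler(socketserver.StreamRequestHandler):
    """Placeholder handler; PySpark's real handler polls for pickled accumulator updates."""

    def handle(self) -> None:
        for line in self.rfile:
            self.wfile.write(line)


class CleanupUnixServer(socketserver.UnixStreamServer):
    """Unix-domain variant that removes its socket file when shut down."""

    def shutdown(self) -> None:
        super().shutdown()
        self.server_close()
        if os.path.exists(self.server_address):
            os.remove(self.server_address)


if __name__ == "__main__":
    path = "/tmp/example-echo.sock"  # placeholder socket path
    if os.path.exists(path):
        os.remove(path)
    server = CleanupUnixServer(path, EchoHandler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    # ... clients would connect to `path` here ...
    server.shutdown()
```

Keeping the handler identical for both transports is what lets `_start_update_server` choose the server class purely from the `is_unix_domain_sock` flag.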
- """ + def shutdown(self) -> None: + self.server_shutdown = True + super().shutdown() + self.server_close() + + +class AccumulatorUnixServer(socketserver.UnixStreamServer): server_shutdown = False + def __init__( + self, socket_path: str, RequestHandlerClass: Type[socketserver.BaseRequestHandler] + ): + super().__init__(socket_path, RequestHandlerClass) + self.auth_token = None + def shutdown(self) -> None: self.server_shutdown = True - SocketServer.TCPServer.shutdown(self) + super().shutdown() self.server_close() + if os.path.exists(self.server_address): # type: ignore[arg-type] + os.remove(self.server_address) # type: ignore[arg-type] + + +def _start_update_server( + auth_token: str, is_unix_domain_sock: bool, socket_path: Optional[str] = None +) -> Union[AccumulatorTCPServer, AccumulatorUnixServer]: + """Start a TCP or Unix Domain Socket server for accumulator updates.""" + if is_unix_domain_sock: + assert socket_path is not None + if os.path.exists(socket_path): + os.remove(socket_path) + server = AccumulatorUnixServer(socket_path, UpdateRequestHandler) + else: + server = AccumulatorTCPServer( + ("localhost", 0), UpdateRequestHandler, auth_token + ) # type: ignore[assignment] - -def _start_update_server(auth_token: str) -> AccumulatorServer: - """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" - server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token) thread = threading.Thread(target=server.serve_forever) thread.daemon = True thread.start() diff --git a/python/pyspark/core/context.py b/python/pyspark/core/context.py index 4d5c03fd1900..9098cb3805c3 100644 --- a/python/pyspark/core/context.py +++ b/python/pyspark/core/context.py @@ -15,6 +15,7 @@ # limitations under the License. # +import uuid import os import shutil import signal @@ -305,11 +306,25 @@ class SparkContext: # they will be passed back to us through a TCP server assert self._gateway is not None auth_token = self._gateway.gateway_parameters.auth_token + is_unix_domain_sock = ( + self._conf.get("spark.python.unix.domain.socket.enabled", "false").lower() == "true" + ) + socket_path = None + if is_unix_domain_sock: + socket_dir = self._conf.get("spark.python.unix.domain.socket.dir") + if socket_dir is None: + socket_dir = getattr(self._jvm, "java.lang.System").getProperty("java.io.tmpdir") + socket_path = os.path.join(socket_dir, f".{uuid.uuid4()}.sock") start_update_server = accumulators._start_update_server - self._accumulatorServer = start_update_server(auth_token) - (host, port) = self._accumulatorServer.server_address + self._accumulatorServer = start_update_server(auth_token, is_unix_domain_sock, socket_path) assert self._jvm is not None - self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token) + if is_unix_domain_sock: + self._javaAccumulator = self._jvm.PythonAccumulatorV2( + self._accumulatorServer.server_address + ) + else: + (host, port) = self._accumulatorServer.server_address # type: ignore[misc] + self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token) self._jsc.sc().register(self._javaAccumulator) # If encryption is enabled, we need to setup a server in the jvm to read broadcast --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org