This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 203942c58318 [SPARK-51688][PYTHON][FOLLOW-UP] Implement UDS in Accumulators
203942c58318 is described below

commit 203942c583187b6d0a012ff2b2d6aab5c664bd39
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Tue Apr 15 16:52:01 2025 +0900

    [SPARK-51688][PYTHON][FOLLOW-UP] Implement UDS in Accumulators
    
    ### What changes were proposed in this pull request?
    
    This PR is a follow-up of https://github.com/apache/spark/pull/50466 that enables Unix Domain Sockets in Accumulators as well. This is the last PR needed to complete UDS support in PySpark.
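
    At a high level, the change swaps the accumulator server's TCP endpoint for a filesystem path. A rough, stdlib-only sketch of the two listener kinds (illustrative names and paths, not this patch's code):

        import os
        import socketserver

        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                pass  # accumulator updates would be read here

        # TCP: bind to an ephemeral localhost port.
        tcp = socketserver.TCPServer(("localhost", 0), Handler)
        tcp.server_close()

        # UDS: bind to a filesystem path instead of host:port.
        path = "/tmp/example.sock"  # hypothetical path
        if os.path.exists(path):
            os.remove(path)  # a stale socket file would make bind() fail
        uds = socketserver.UnixStreamServer(path, Handler)
        uds.server_close()
        os.remove(path)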
    
    ### Why are the changes needed?
    
    This was not handled in the original PR because it was a bit more complicated, so it was separated out into this follow-up. See the original PR for details.
    
    ### Does this PR introduce _any_ user-facing change?
    
    See https://github.com/apache/spark/pull/50466
    
    ### How was this patch tested?
    
    CI in this PR with the feature enabled by default. A scheduled build will also be set up through https://github.com/apache/spark/pull/50585.
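
    To try it locally, the flag read in context.py below can be flipped explicitly; a hedged example, not the canonical test setup:

        from pyspark import SparkConf, SparkContext

        conf = (
            SparkConf()
            .setMaster("local[1]")
            .setAppName("uds-demo")  # hypothetical app name
            .set("spark.python.unix.domain.socket.enabled", "true")
        )
        sc = SparkContext(conf=conf)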
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #50587 from HyukjinKwon/SPARK-51688-followup.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../org/apache/spark/api/python/PythonRDD.scala    | 47 +++++++++++------
 python/pyspark/accumulators.py                     | 60 ++++++++++++++++------
 python/pyspark/core/context.py                     | 21 ++++++--
 3 files changed, 94 insertions(+), 34 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 2152724c4c13..e1f16fe32ebe 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -39,8 +39,9 @@ import org.apache.spark.api.python.PythonFunction.PythonAccumulator
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.input.PortableDataStream
 import org.apache.spark.internal.{Logging, MDC}
-import org.apache.spark.internal.LogKeys.{HOST, PORT}
+import org.apache.spark.internal.LogKeys.{HOST, PORT, SOCKET_ADDRESS}
 import org.apache.spark.internal.config.BUFFER_SIZE
+import org.apache.spark.internal.config.Python.PYTHON_UNIX_DOMAIN_SOCKET_ENABLED
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer, SocketFuncServer}
@@ -717,35 +718,50 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By
  * collects a list of pickled strings that we pass to Python through a socket.
  */
 private[spark] class PythonAccumulatorV2(
-    @transient private val serverHost: String,
-    private val serverPort: Int,
-    private val secretToken: String)
+    @transient private val serverHost: Option[String],
+    private val serverPort: Option[Int],
+    private val secretToken: Option[String],
+    @transient private val socketPath: Option[String])
   extends CollectionAccumulator[Array[Byte]] with Logging {
 
-  Utils.checkHost(serverHost)
+  // Unix domain socket
+  def this(socketPath: String) = this(None, None, None, Some(socketPath))
+  // TCP socket
+  def this(serverHost: String, serverPort: Int, secretToken: String) = this(
+    Some(serverHost), Some(serverPort), Some(secretToken), None)
+
+  serverHost.foreach(Utils.checkHost)
 
   val bufferSize = SparkEnv.get.conf.get(BUFFER_SIZE)
+  val isUnixDomainSock = SparkEnv.get.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
 
   /**
    * We try to reuse a single Socket to transfer accumulator updates, as they are all added
    * by the DAGScheduler's single-threaded RpcEndpoint anyway.
    */
-  @transient private var socket: Socket = _
+  @transient private var socket: SocketChannel = _
 
-  private def openSocket(): Socket = synchronized {
-    if (socket == null || socket.isClosed) {
-      socket = new Socket(serverHost, serverPort)
-      logInfo(log"Connected to AccumulatorServer at host: ${MDC(HOST, 
serverHost)}" +
-        log" port: ${MDC(PORT, serverPort)}")
+  private def openSocket(): SocketChannel = synchronized {
+    if (socket == null || !socket.isOpen) {
+      if (isUnixDomainSock) {
+        socket = SocketChannel.open(UnixDomainSocketAddress.of(socketPath.get))
+        logInfo(log"Connected to AccumulatorServer at socket: 
${MDC(SOCKET_ADDRESS, serverHost)}")
+      } else {
+        socket = SocketChannel.open(new InetSocketAddress(serverHost.get, serverPort.get))
+        logInfo(log"Connected to AccumulatorServer at host: ${MDC(HOST, serverHost.get)}" +
+          log" port: ${MDC(PORT, serverPort.get)}")
+      }
       // send the secret just for the initial authentication when opening a new connection
-      socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
+      secretToken.foreach { token =>
+        Channels.newOutputStream(socket).write(token.getBytes(StandardCharsets.UTF_8))
+      }
     }
     socket
   }
 
   // Need to override so the types match with PythonFunction
   override def copyAndReset(): PythonAccumulatorV2 = {
-    new PythonAccumulatorV2(serverHost, serverPort, secretToken)
+    new PythonAccumulatorV2(serverHost, serverPort, secretToken, socketPath)
   }
 
   override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized {
@@ -758,8 +774,9 @@ private[spark] class PythonAccumulatorV2(
     } else {
       // This happens on the master, where we pass the updates to Python through a socket
       val socket = openSocket()
-      val in = socket.getInputStream
-      val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
+      val in = Channels.newInputStream(socket)
+      val out = new DataOutputStream(
+        new BufferedOutputStream(Channels.newOutputStream(socket), bufferSize))
       val values = other.value
       out.writeInt(values.size)
       for (array <- values.asScala) {
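
A Python-side mirror of the openSocket() logic above, shown here only to make the transport selection easy to follow (a sketch; the helper name is made up):

    import socket

    def open_accumulator_socket(socket_path=None, host=None, port=None, token=None):
        # UDS first, matching PythonAccumulatorV2.openSocket(): no auth token
        # is sent because access control comes from filesystem permissions.
        if socket_path is not None:
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            sock.connect(socket_path)
        else:
            sock = socket.create_connection((host, port))
            sock.sendall(token.encode("utf-8"))  # one-time auth per connection
        return sock
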
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 93a64d8eef10..59f7856688ee 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -15,12 +15,13 @@
 # limitations under the License.
 #
 
+import os
 import sys
 import select
 import struct
-import socketserver as SocketServer
+import socketserver
 import threading
-from typing import Callable, Dict, Generic, Tuple, Type, TYPE_CHECKING, TypeVar, Union
+from typing import Callable, Dict, Generic, Tuple, Type, TYPE_CHECKING, TypeVar, Union, Optional
 
 from pyspark.serializers import read_int, CPickleSerializer
 from pyspark.errors import PySparkRuntimeError
@@ -252,7 +253,7 @@ FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)  # type: ignore[type-var]
 COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)  # type: ignore[type-var]
 
 
-class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
+class UpdateRequestHandler(socketserver.StreamRequestHandler):
 
     """
     This handler will keep polling updates from the same socket until the
@@ -293,37 +294,64 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
                     "The value of the provided token to the AccumulatorServer 
is not correct."
                 )
 
-        # first we keep polling till we've received the authentication token
-        poll(authenticate_and_accum_updates)
+        # Unix Domain Sockets do not need authentication.
+        if auth_token is not None:
+            # first we keep polling till we've received the authentication token
+            poll(authenticate_and_accum_updates)
+
         # now we've authenticated, don't need to check for the token anymore
         poll(accum_updates)
 
 
-class AccumulatorServer(SocketServer.TCPServer):
+class AccumulatorTCPServer(socketserver.TCPServer):
+    server_shutdown = False
+
     def __init__(
         self,
         server_address: Tuple[str, int],
         RequestHandlerClass: Type["socketserver.BaseRequestHandler"],
         auth_token: str,
     ):
-        SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass)
+        super().__init__(server_address, RequestHandlerClass)
         self.auth_token = auth_token
 
-    """
-    A simple TCP server that intercepts shutdown() in order to interrupt
-    our continuous polling on the handler.
-    """
+    def shutdown(self) -> None:
+        self.server_shutdown = True
+        super().shutdown()
+        self.server_close()
+
+
+class AccumulatorUnixServer(socketserver.UnixStreamServer):
     server_shutdown = False
 
+    def __init__(
+        self, socket_path: str, RequestHandlerClass: Type[socketserver.BaseRequestHandler]
+    ):
+        super().__init__(socket_path, RequestHandlerClass)
+        self.auth_token = None
+
     def shutdown(self) -> None:
         self.server_shutdown = True
-        SocketServer.TCPServer.shutdown(self)
+        super().shutdown()
         self.server_close()
+        if os.path.exists(self.server_address):  # type: ignore[arg-type]
+            os.remove(self.server_address)  # type: ignore[arg-type]
+
+
+def _start_update_server(
+    auth_token: str, is_unix_domain_sock: bool, socket_path: Optional[str] = None
+) -> Union[AccumulatorTCPServer, AccumulatorUnixServer]:
+    """Start a TCP or Unix Domain Socket server for accumulator updates."""
+    if is_unix_domain_sock:
+        assert socket_path is not None
+        if os.path.exists(socket_path):
+            os.remove(socket_path)
+        server = AccumulatorUnixServer(socket_path, UpdateRequestHandler)
+    else:
+        server = AccumulatorTCPServer(
+            ("localhost", 0), UpdateRequestHandler, auth_token
+        )  # type: ignore[assignment]
 
-
-def _start_update_server(auth_token: str) -> AccumulatorServer:
-    """Start a TCP server to receive accumulator updates in a daemon thread, 
and returns it"""
-    server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, 
auth_token)
     thread = threading.Thread(target=server.serve_forever)
     thread.daemon = True
     thread.start()
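
A hedged usage sketch of the reworked _start_update_server above (paths and token values are illustrative):

    # TCP mode: the OS picks an ephemeral port; server_address is (host, port).
    tcp_server = _start_update_server("secret-token", is_unix_domain_sock=False)
    host, port = tcp_server.server_address

    # UDS mode: server_address is the socket path itself; the token is unused.
    uds_server = _start_update_server("", True, "/tmp/accum-demo.sock")
    print(uds_server.server_address)

    tcp_server.shutdown()  # also closes the server; the UDS variant unlinks its path
    uds_server.shutdown()
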
diff --git a/python/pyspark/core/context.py b/python/pyspark/core/context.py
index 4d5c03fd1900..9098cb3805c3 100644
--- a/python/pyspark/core/context.py
+++ b/python/pyspark/core/context.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+import uuid
 import os
 import shutil
 import signal
@@ -305,11 +306,25 @@ class SparkContext:
         # they will be passed back to us through a TCP server
         assert self._gateway is not None
         auth_token = self._gateway.gateway_parameters.auth_token
+        is_unix_domain_sock = (
+            self._conf.get("spark.python.unix.domain.socket.enabled", "false").lower() == "true"
+        )
+        socket_path = None
+        if is_unix_domain_sock:
+            socket_dir = self._conf.get("spark.python.unix.domain.socket.dir")
+            if socket_dir is None:
+                socket_dir = getattr(self._jvm, "java.lang.System").getProperty("java.io.tmpdir")
+            socket_path = os.path.join(socket_dir, f".{uuid.uuid4()}.sock")
         start_update_server = accumulators._start_update_server
-        self._accumulatorServer = start_update_server(auth_token)
-        (host, port) = self._accumulatorServer.server_address
+        self._accumulatorServer = start_update_server(auth_token, is_unix_domain_sock, socket_path)
         assert self._jvm is not None
-        self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
+        if is_unix_domain_sock:
+            self._javaAccumulator = self._jvm.PythonAccumulatorV2(
+                self._accumulatorServer.server_address
+            )
+        else:
+            (host, port) = self._accumulatorServer.server_address  # type: ignore[misc]
+            self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
         self._jsc.sc().register(self._javaAccumulator)
 
         # If encryption is enabled, we need to setup a server in the jvm to read broadcast
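
For reference, the socket-path selection in the hunk above can be approximated outside the JVM like this (tempfile.gettempdir() stands in for java.io.tmpdir; an approximation, not this patch's code):

    import os
    import tempfile
    import uuid

    # An explicit directory (spark.python.unix.domain.socket.dir) would win;
    # otherwise fall back to the system temp dir, as context.py does.
    socket_dir = None  # would come from the Spark conf
    if socket_dir is None:
        socket_dir = tempfile.gettempdir()
    socket_path = os.path.join(socket_dir, f".{uuid.uuid4()}.sock")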

