This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 203942c58318 [SPARK-51688][PYTHON][FOLLOW-UP] Implement UDS in Accumulators
203942c58318 is described below
commit 203942c583187b6d0a012ff2b2d6aab5c664bd39
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Tue Apr 15 16:52:01 2025 +0900
[SPARK-51688][PYTHON][FOLLOW-UP] Implement UDS in Accumulators
### What changes were proposed in this pull request?
This PR is a follow-up of https://github.com/apache/spark/pull/50466 that
enables Unix Domain Sockets in Accumulators as well. This is the last PR
needed to complete UDS support in PySpark.
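For readers unfamiliar with the mechanism, below is a minimal, self-contained sketch of the transport pattern this change adopts: serving length-prefixed updates over a Unix domain socket via `socketserver.UnixStreamServer` instead of a localhost TCP socket. The path, handler name, and framing are illustrative only, not the actual PySpark classes in this diff.

```python
# Illustrative sketch only -- not the PySpark implementation. It mirrors
# the idea in this PR: a stream server bound to a filesystem path rather
# than a (host, port) pair, receiving length-prefixed binary updates.
import os
import socket
import socketserver
import struct
import threading

SOCKET_PATH = "/tmp/.demo-accumulator.sock"  # hypothetical path


class LengthPrefixedHandler(socketserver.StreamRequestHandler):
    def handle(self) -> None:
        # Read a 4-byte big-endian length, then that many payload bytes.
        header = self.rfile.read(4)
        if len(header) < 4:
            return
        (num_bytes,) = struct.unpack("!i", header)
        payload = self.rfile.read(num_bytes)
        print(f"received {len(payload)} bytes")
        # Acknowledge with a single byte, as the real server does.
        self.wfile.write(struct.pack("!b", 1))


if os.path.exists(SOCKET_PATH):
    os.remove(SOCKET_PATH)  # UDS files persist; clean up stale ones first

server = socketserver.UnixStreamServer(SOCKET_PATH, LengthPrefixedHandler)
threading.Thread(target=server.serve_forever, daemon=True).start()

# Client side: connect to the path instead of (host, port). No auth token
# is exchanged; filesystem permissions guard the socket instead.
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client:
    client.connect(SOCKET_PATH)
    client.sendall(struct.pack("!i", 5) + b"hello")
    assert client.recv(1) == struct.pack("!b", 1)

server.shutdown()
server.server_close()
os.remove(SOCKET_PATH)
```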
### Why are the changes needed?
This was not handled in the original PR because it was a bit complicated, so
it was separated out into this follow-up. See the original PR for details.
### Does this PR introduce _any_ user-facing change?
See https://github.com/apache/spark/pull/50466
### How was this patch tested?
CI in this PR, with the feature enabled by default. A scheduled build will
also be set up through https://github.com/apache/spark/pull/50585.
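For reference, the feature is gated by the `spark.python.unix.domain.socket.enabled` configuration that this diff reads in context.py. Below is a hedged sketch of turning it on for local testing; the config keys are taken from this diff, but exact defaults and availability should be confirmed against the released documentation.

```python
# Assumed usage based on the config keys visible in this diff; verify
# against the shipped documentation before relying on it.
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .config("spark.python.unix.domain.socket.enabled", "true")
    # Optional: where the .sock file is created; context.py falls back
    # to java.io.tmpdir when this is unset.
    .config("spark.python.unix.domain.socket.dir", "/tmp")
    .getOrCreate()
)
```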
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #50587 from HyukjinKwon/SPARK-51688-followup.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../org/apache/spark/api/python/PythonRDD.scala | 47 +++++++++++------
python/pyspark/accumulators.py | 60 ++++++++++++++++------
python/pyspark/core/context.py | 21 ++++++--
3 files changed, 94 insertions(+), 34 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 2152724c4c13..e1f16fe32ebe 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -39,8 +39,9 @@ import org.apache.spark.api.python.PythonFunction.PythonAccumulator
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
import org.apache.spark.internal.{Logging, MDC}
-import org.apache.spark.internal.LogKeys.{HOST, PORT}
+import org.apache.spark.internal.LogKeys.{HOST, PORT, SOCKET_ADDRESS}
import org.apache.spark.internal.config.BUFFER_SIZE
+import org.apache.spark.internal.config.Python.PYTHON_UNIX_DOMAIN_SOCKET_ENABLED
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer, SocketFuncServer}
@@ -717,35 +718,50 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By
* collects a list of pickled strings that we pass to Python through a socket.
*/
private[spark] class PythonAccumulatorV2(
- @transient private val serverHost: String,
- private val serverPort: Int,
- private val secretToken: String)
+ @transient private val serverHost: Option[String],
+ private val serverPort: Option[Int],
+ private val secretToken: Option[String],
+ @transient private val socketPath: Option[String])
extends CollectionAccumulator[Array[Byte]] with Logging {
- Utils.checkHost(serverHost)
+ // Unix domain socket
+ def this(socketPath: String) = this(None, None, None, Some(socketPath))
+ // TCP socket
+ def this(serverHost: String, serverPort: Int, secretToken: String) = this(
+ Some(serverHost), Some(serverPort), Some(secretToken), None)
+
+ serverHost.foreach(Utils.checkHost)
val bufferSize = SparkEnv.get.conf.get(BUFFER_SIZE)
+ val isUnixDomainSock = SparkEnv.get.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
/**
* We try to reuse a single Socket to transfer accumulator updates, as they are all added
* by the DAGScheduler's single-threaded RpcEndpoint anyway.
*/
- @transient private var socket: Socket = _
+ @transient private var socket: SocketChannel = _
- private def openSocket(): Socket = synchronized {
- if (socket == null || socket.isClosed) {
- socket = new Socket(serverHost, serverPort)
- logInfo(log"Connected to AccumulatorServer at host: ${MDC(HOST,
serverHost)}" +
- log" port: ${MDC(PORT, serverPort)}")
+ private def openSocket(): SocketChannel = synchronized {
+ if (socket == null || !socket.isOpen) {
+ if (isUnixDomainSock) {
+ socket = SocketChannel.open(UnixDomainSocketAddress.of(socketPath.get))
+ logInfo(log"Connected to AccumulatorServer at socket:
${MDC(SOCKET_ADDRESS, serverHost)}")
+ } else {
+ socket = SocketChannel.open(new InetSocketAddress(serverHost.get,
serverPort.get))
+ logInfo(log"Connected to AccumulatorServer at host: ${MDC(HOST,
serverHost)}" +
+ log" port: ${MDC(PORT, serverPort)}")
+ }
// send the secret just for the initial authentication when opening a
new connection
-
socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
+ secretToken.foreach { token =>
+
Channels.newOutputStream(socket).write(token.getBytes(StandardCharsets.UTF_8))
+ }
}
socket
}
// Need to override so the types match with PythonFunction
override def copyAndReset(): PythonAccumulatorV2 = {
- new PythonAccumulatorV2(serverHost, serverPort, secretToken)
+ new PythonAccumulatorV2(serverHost, serverPort, secretToken, socketPath)
}
override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized {
@@ -758,8 +774,9 @@ private[spark] class PythonAccumulatorV2(
} else {
// This happens on the master, where we pass the updates to Python through a socket
val socket = openSocket()
- val in = socket.getInputStream
- val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
+ val in = Channels.newInputStream(socket)
+ val out = new DataOutputStream(
+ new BufferedOutputStream(Channels.newOutputStream(socket), bufferSize))
val values = other.value
out.writeInt(values.size)
for (array <- values.asScala) {
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 93a64d8eef10..59f7856688ee 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -15,12 +15,13 @@
# limitations under the License.
#
+import os
import sys
import select
import struct
-import socketserver as SocketServer
+import socketserver
import threading
-from typing import Callable, Dict, Generic, Tuple, Type, TYPE_CHECKING, TypeVar, Union
+from typing import Callable, Dict, Generic, Tuple, Type, TYPE_CHECKING, TypeVar, Union, Optional
from pyspark.serializers import read_int, CPickleSerializer
from pyspark.errors import PySparkRuntimeError
@@ -252,7 +253,7 @@ FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0) # type: ignore[type-var]
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) # type: ignore[type-var]
-class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
+class UpdateRequestHandler(socketserver.StreamRequestHandler):
"""
This handler will keep polling updates from the same socket until the
@@ -293,37 +294,64 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
"The value of the provided token to the AccumulatorServer
is not correct."
)
- # first we keep polling till we've received the authentication token
- poll(authenticate_and_accum_updates)
+ # Unix Domain Sockets do not need authentication.
+ if auth_token is not None:
+ # first we keep polling till we've received the authentication token
+ poll(authenticate_and_accum_updates)
+
# now we've authenticated, don't need to check for the token anymore
poll(accum_updates)
-class AccumulatorServer(SocketServer.TCPServer):
+class AccumulatorTCPServer(socketserver.TCPServer):
+ server_shutdown = False
+
def __init__(
self,
server_address: Tuple[str, int],
RequestHandlerClass: Type["socketserver.BaseRequestHandler"],
auth_token: str,
):
- SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass)
+ super().__init__(server_address, RequestHandlerClass)
self.auth_token = auth_token
- """
- A simple TCP server that intercepts shutdown() in order to interrupt
- our continuous polling on the handler.
- """
+ def shutdown(self) -> None:
+ self.server_shutdown = True
+ super().shutdown()
+ self.server_close()
+
+
+class AccumulatorUnixServer(socketserver.UnixStreamServer):
server_shutdown = False
+ def __init__(
+ self, socket_path: str, RequestHandlerClass: Type[socketserver.BaseRequestHandler]
+ ):
+ super().__init__(socket_path, RequestHandlerClass)
+ self.auth_token = None
+
def shutdown(self) -> None:
self.server_shutdown = True
- SocketServer.TCPServer.shutdown(self)
+ super().shutdown()
self.server_close()
+ if os.path.exists(self.server_address): # type: ignore[arg-type]
+ os.remove(self.server_address) # type: ignore[arg-type]
+
+
+def _start_update_server(
+ auth_token: str, is_unix_domain_sock: bool, socket_path: Optional[str] = None
+) -> Union[AccumulatorTCPServer, AccumulatorUnixServer]:
+ """Start a TCP or Unix Domain Socket server for accumulator updates."""
+ if is_unix_domain_sock:
+ assert socket_path is not None
+ if os.path.exists(socket_path):
+ os.remove(socket_path)
+ server = AccumulatorUnixServer(socket_path, UpdateRequestHandler)
+ else:
+ server = AccumulatorTCPServer(
+ ("localhost", 0), UpdateRequestHandler, auth_token
+ ) # type: ignore[assignment]
-
-def _start_update_server(auth_token: str) -> AccumulatorServer:
- """Start a TCP server to receive accumulator updates in a daemon thread,
and returns it"""
- server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler,
auth_token)
thread = threading.Thread(target=server.serve_forever)
thread.daemon = True
thread.start()
diff --git a/python/pyspark/core/context.py b/python/pyspark/core/context.py
index 4d5c03fd1900..9098cb3805c3 100644
--- a/python/pyspark/core/context.py
+++ b/python/pyspark/core/context.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+import uuid
import os
import shutil
import signal
@@ -305,11 +306,25 @@ class SparkContext:
# they will be passed back to us through a TCP server
assert self._gateway is not None
auth_token = self._gateway.gateway_parameters.auth_token
+ is_unix_domain_sock = (
+ self._conf.get("spark.python.unix.domain.socket.enabled",
"false").lower() == "true"
+ )
+ socket_path = None
+ if is_unix_domain_sock:
+ socket_dir = self._conf.get("spark.python.unix.domain.socket.dir")
+ if socket_dir is None:
+ socket_dir = getattr(self._jvm, "java.lang.System").getProperty("java.io.tmpdir")
+ socket_path = os.path.join(socket_dir, f".{uuid.uuid4()}.sock")
start_update_server = accumulators._start_update_server
- self._accumulatorServer = start_update_server(auth_token)
- (host, port) = self._accumulatorServer.server_address
+ self._accumulatorServer = start_update_server(auth_token, is_unix_domain_sock, socket_path)
assert self._jvm is not None
- self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
+ if is_unix_domain_sock:
+ self._javaAccumulator = self._jvm.PythonAccumulatorV2(
+ self._accumulatorServer.server_address
+ )
+ else:
+ (host, port) = self._accumulatorServer.server_address # type: ignore[misc]
+ self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
self._jsc.sc().register(self._javaAccumulator)
# If encryption is enabled, we need to setup a server in the jvm to read broadcast
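
End to end, accumulator behavior is unchanged from the user's perspective; only the update transport differs. A quick smoke test one might run with the flag enabled (illustrative, not part of this patch):

```python
# Hypothetical smoke test: accumulator updates flow from executors back
# to the driver over the Unix domain socket when the flag is enabled.
from pyspark import SparkContext

sc = SparkContext.getOrCreate()
acc = sc.accumulator(0)
sc.parallelize(range(100)).foreach(lambda x: acc.add(x))
assert acc.value == 4950  # sum(range(100))
```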
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]