This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d3b6dd13e9e9 [SPARK-51688][PYTHON] Use Unix Domain Socket between
Python and JVM communication
d3b6dd13e9e9 is described below
commit d3b6dd13e9e9f1c995f9c1152d8958a29c8ccd54
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Tue Apr 15 10:35:11 2025 +0900
[SPARK-51688][PYTHON] Use Unix Domain Socket between Python and JVM
communication
### What changes were proposed in this pull request?
This PR proposes to use Unix Domain Socket (UDS) in the communication
between Python process and JVM (except Py4J, which does not support UDS).
It adds a new configuration `spark.python.unix.domain.socket.enabled` that
is disabled by default. When enabled, it uses UDS. When disabled, we use TCP/IP
sockets as before.
When we use UDS, since the data is protected by file permissions, it also
avoids the unnecessary authentication we use for TCP/IP sockets.
### Why are the changes needed?
1. UDS is known to be faster than TCP/IP; see also
https://www.researchgate.net/figure/Performance-Comparison-of-TCP-vs-Unix-Domain-Sockets-as-a-Function-of-Message-Size_fig3_221461399
2. It does not require the network, as it avoids the TCP/IP layer, so we can
avoid network overhead.
### Does this PR introduce _any_ user-facing change?
To the end users, no. This is an implementation-level change.
### How was this patch tested?
Manually ran the tests after enabling this configuration.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #50466 from HyukjinKwon/unix-domain-socket.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
R/pkg/R/context.R | 2 +-
.../org/apache/spark/api/python/PythonRDD.scala | 31 +++----
.../org/apache/spark/api/python/PythonRunner.scala | 100 ++++++++++++---------
.../spark/api/python/PythonWorkerFactory.scala | 70 ++++++++++++---
.../spark/api/python/PythonWorkerUtils.scala | 12 ++-
.../spark/api/python/StreamingPythonRunner.scala | 24 +++--
.../scala/org/apache/spark/api/r/RAuthHelper.scala | 1 +
.../main/scala/org/apache/spark/api/r/RRDD.scala | 6 +-
.../org/apache/spark/internal/config/Python.scala | 24 +++++
.../apache/spark/security/SocketAuthHelper.scala | 14 ++-
.../apache/spark/security/SocketAuthServer.scala | 66 ++++++++++----
.../apache/spark/api/python/PythonRDDSuite.scala | 16 ++--
.../spark/security/SocketAuthHelperSuite.scala | 12 ++-
python/pyspark/core/broadcast.py | 8 +-
python/pyspark/core/context.py | 2 +-
python/pyspark/daemon.py | 28 +++++-
.../deepspeed/tests/test_deepspeed_distributor.py | 4 +-
.../streaming/worker/foreach_batch_worker.py | 8 +-
.../connect/streaming/worker/listener_worker.py | 8 +-
.../streaming/python_streaming_source_runner.py | 8 +-
.../sql/streaming/stateful_processor_api_client.py | 32 ++++---
.../transform_with_state_driver_worker.py | 12 ++-
python/pyspark/sql/worker/analyze_udtf.py | 8 +-
.../pyspark/sql/worker/commit_data_source_write.py | 8 +-
python/pyspark/sql/worker/create_data_source.py | 8 +-
.../sql/worker/data_source_pushdown_filters.py | 8 +-
python/pyspark/sql/worker/lookup_data_sources.py | 8 +-
python/pyspark/sql/worker/plan_data_source_read.py | 8 +-
.../sql/worker/python_streaming_sink_runner.py | 8 +-
.../pyspark/sql/worker/write_into_data_source.py | 8 +-
python/pyspark/taskcontext.py | 22 ++---
python/pyspark/tests/test_appsubmit.py | 2 +
python/pyspark/util.py | 35 ++++++--
python/pyspark/worker.py | 19 ++--
python/pyspark/worker_util.py | 10 ++-
python/run-tests.py | 2 +
.../spark/deploy/yarn/YarnClusterSuite.scala | 14 ++-
.../spark/sql/api/python/PythonSQLUtils.scala | 7 +-
.../streaming/PythonStreamingSourceRunner.scala | 5 +-
.../TransformWithStateInPandasPythonRunner.scala | 57 +++++++++---
.../TransformWithStateInPandasStateServer.scala | 16 ++--
...ransformWithStateInPandasStateServerSuite.scala | 4 +-
42 files changed, 515 insertions(+), 230 deletions(-)
diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R
index eea83aa5ab52..0242e7114978 100644
--- a/R/pkg/R/context.R
+++ b/R/pkg/R/context.R
@@ -181,7 +181,7 @@ parallelize <- function(sc, coll, numSlices = 1) {
parallelism <- as.integer(numSlices)
jserver <- newJObject("org.apache.spark.api.r.RParallelizeServer", sc,
parallelism)
authSecret <- callJMethod(jserver, "secret")
- port <- callJMethod(jserver, "port")
+ port <- callJMethod(jserver, "connInfo")
conn <- socketConnection(
port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
doServerAuth(conn, authSecret)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index d643983ef5df..2152724c4c13 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,6 +19,7 @@ package org.apache.spark.api.python
import java.io._
import java.net._
+import java.nio.channels.{Channels, SocketChannel}
import java.nio.charset.StandardCharsets
import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
@@ -231,9 +232,9 @@ private[spark] object PythonRDD extends Logging {
* server object that can be used to join the JVM serving thread in
Python.
*/
def toLocalIteratorAndServe[T](rdd: RDD[T], prefetchPartitions: Boolean =
false): Array[Any] = {
- val handleFunc = (sock: Socket) => {
- val out = new DataOutputStream(sock.getOutputStream)
- val in = new DataInputStream(sock.getInputStream)
+ val handleFunc = (sock: SocketChannel) => {
+ val out = new DataOutputStream(Channels.newOutputStream(sock))
+ val in = new DataInputStream(Channels.newInputStream(sock))
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
// Collects a partition on each iteration
val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
@@ -287,7 +288,7 @@ private[spark] object PythonRDD extends Logging {
}
val server = new SocketFuncServer(authHelper, "serve toLocalIterator",
handleFunc)
- Array(server.port, server.secret, server)
+ Array(server.connInfo, server.secret, server)
}
def readRDDFromFile(
@@ -831,21 +832,21 @@ private[spark] class PythonBroadcast(@transient var path:
String) extends Serial
def setupEncryptionServer(): Array[Any] = {
encryptionServer = new SocketAuthServer[Unit]("broadcast-encrypt-server") {
- override def handleConnection(sock: Socket): Unit = {
+ override def handleConnection(sock: SocketChannel): Unit = {
val env = SparkEnv.get
- val in = sock.getInputStream()
+ val in = Channels.newInputStream(sock)
val abspath = new File(path).getAbsolutePath
val out = env.serializerManager.wrapForEncryption(new
FileOutputStream(abspath))
DechunkedInputStream.dechunkAndCopyToOutput(in, out)
}
}
- Array(encryptionServer.port, encryptionServer.secret)
+ Array(encryptionServer.connInfo, encryptionServer.secret)
}
def setupDecryptionServer(): Array[Any] = {
decryptionServer = new
SocketAuthServer[Unit]("broadcast-decrypt-server-for-driver") {
- override def handleConnection(sock: Socket): Unit = {
- val out = new DataOutputStream(new
BufferedOutputStream(sock.getOutputStream()))
+ override def handleConnection(sock: SocketChannel): Unit = {
+ val out = new DataOutputStream(new
BufferedOutputStream(Channels.newOutputStream(sock)))
Utils.tryWithSafeFinally {
val in = SparkEnv.get.serializerManager.wrapForEncryption(new
FileInputStream(path))
Utils.tryWithSafeFinally {
@@ -859,7 +860,7 @@ private[spark] class PythonBroadcast(@transient var path:
String) extends Serial
}
}
}
- Array(decryptionServer.port, decryptionServer.secret)
+ Array(decryptionServer.connInfo, decryptionServer.secret)
}
def waitTillBroadcastDataSent(): Unit = decryptionServer.getResult()
@@ -945,8 +946,8 @@ private[spark] class EncryptedPythonBroadcastServer(
val idsAndFiles: Seq[(Long, String)])
extends SocketAuthServer[Unit]("broadcast-decrypt-server") with Logging {
- override def handleConnection(socket: Socket): Unit = {
- val out = new DataOutputStream(new
BufferedOutputStream(socket.getOutputStream()))
+ override def handleConnection(socket: SocketChannel): Unit = {
+ val out = new DataOutputStream(new
BufferedOutputStream(Channels.newOutputStream(socket)))
var socketIn: InputStream = null
// send the broadcast id, then the decrypted data. We don't need to send
the length, the
// the python pickle module just needs a stream.
@@ -962,7 +963,7 @@ private[spark] class EncryptedPythonBroadcastServer(
}
logTrace("waiting for python to accept broadcast data over socket")
out.flush()
- socketIn = socket.getInputStream()
+ socketIn = Channels.newInputStream(socket)
socketIn.read()
logTrace("done serving broadcast data")
} {
@@ -983,8 +984,8 @@ private[spark] class EncryptedPythonBroadcastServer(
private[spark] abstract class PythonRDDServer
extends
SocketAuthServer[JavaRDD[Array[Byte]]]("pyspark-parallelize-server") {
- def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
- val in = sock.getInputStream()
+ def handleConnection(sock: SocketChannel): JavaRDD[Array[Byte]] = {
+ val in = Channels.newInputStream(sock)
val dechunkedInput: InputStream = new DechunkedInputStream(in)
streamToRDD(dechunkedInput)
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 84701ee593c1..c2539ee05f21 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -20,9 +20,9 @@ package org.apache.spark.api.python
import java.io._
import java.net._
import java.nio.ByteBuffer
-import java.nio.channels.SelectionKey
-import java.nio.charset.StandardCharsets.UTF_8
+import java.nio.channels.{AsynchronousCloseException, Channels, SelectionKey,
ServerSocketChannel, SocketChannel}
import java.nio.file.{Files => JavaFiles, Path}
+import java.util.UUID
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean
@@ -201,9 +201,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
// Python accumulator is always set in production except in tests. See
SPARK-27893
private val maybeAccumulator: Option[PythonAccumulator] = Option(accumulator)
- // Expose a ServerSocket to support method calls via socket from Python
side. Only relevant for
- // for tasks that are a part of barrier stage, refer [[BarrierTaskContext]]
for details.
- private[spark] var serverSocket: Option[ServerSocket] = None
+ // Expose a ServerSocketChannel to support method calls via socket from
Python side.
+ // Only relevant for tasks that are a part of barrier stage, refer
+ // `BarrierTaskContext` for details.
+ private[spark] var serverSocketChannel: Option[ServerSocketChannel] = None
// Authentication helper used when serving method calls via socket from
Python side.
private lazy val authHelper = new SocketAuthHelper(conf)
@@ -347,6 +348,11 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
def writeNextInputToStream(dataOut: DataOutputStream): Boolean
def open(dataOut: DataOutputStream): Unit = Utils.logUncaughtExceptions {
+ val isUnixDomainSock =
authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
+ lazy val sockPath = new File(
+ authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR)
+ .getOrElse(System.getProperty("java.io.tmpdir")),
+ s".${UUID.randomUUID()}.sock")
try {
// Partition index
dataOut.writeInt(partitionIndex)
@@ -356,27 +362,34 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
// Init a ServerSocket to accept method calls from Python side.
val isBarrier = context.isInstanceOf[BarrierTaskContext]
if (isBarrier) {
- serverSocket = Some(new ServerSocket(/* port */ 0,
- /* backlog */ 1,
- InetAddress.getByName("localhost")))
- // A call to accept() for ServerSocket shall block infinitely.
- serverSocket.foreach(_.setSoTimeout(0))
+ if (isUnixDomainSock) {
+ serverSocketChannel =
Some(ServerSocketChannel.open(StandardProtocolFamily.UNIX))
+ sockPath.deleteOnExit()
+
serverSocketChannel.get.bind(UnixDomainSocketAddress.of(sockPath.getPath))
+ } else {
+ serverSocketChannel = Some(ServerSocketChannel.open())
+ serverSocketChannel.foreach(_.bind(
+ new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 1))
+ // A call to accept() for ServerSocket shall block infinitely.
+ serverSocketChannel.foreach(_.socket().setSoTimeout(0))
+ }
+
new Thread("accept-connections") {
setDaemon(true)
override def run(): Unit = {
- while (!serverSocket.get.isClosed()) {
- var sock: Socket = null
+ while (serverSocketChannel.get.isOpen()) {
+ var sock: SocketChannel = null
try {
- sock = serverSocket.get.accept()
+ sock = serverSocketChannel.get.accept()
// Wait for function call from python side.
- sock.setSoTimeout(10000)
+ if (!isUnixDomainSock) sock.socket().setSoTimeout(10000)
authHelper.authClient(sock)
- val input = new DataInputStream(sock.getInputStream())
+ val input = new
DataInputStream(Channels.newInputStream(sock))
val requestMethod = input.readInt()
// The BarrierTaskContext function may wait infinitely,
socket shall not timeout
// before the function finishes.
- sock.setSoTimeout(0)
+ if (!isUnixDomainSock) sock.socket().setSoTimeout(0)
requestMethod match {
case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
barrierAndServe(requestMethod, sock)
@@ -385,13 +398,14 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
barrierAndServe(requestMethod, sock, message)
case _ =>
val out = new DataOutputStream(new BufferedOutputStream(
- sock.getOutputStream))
+ Channels.newOutputStream(sock)))
writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out)
}
} catch {
- case e: SocketException if e.getMessage.contains("Socket
closed") =>
- // It is possible that the ServerSocket is not closed, but
the native socket
- // has already been closed, we shall catch and silently
ignore this case.
+ case _: AsynchronousCloseException =>
+ // Ignore to make less noisy. These will be closed when
tasks
+ // are finished by listeners.
+ if (isUnixDomainSock) sockPath.delete()
} finally {
if (sock != null) {
sock.close()
@@ -401,33 +415,35 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
}
}.start()
}
- val secret = if (isBarrier) {
- authHelper.secret
- } else {
- ""
- }
if (isBarrier) {
// Close ServerSocket on task completion.
- serverSocket.foreach { server =>
- context.addTaskCompletionListener[Unit](_ => server.close())
+ serverSocketChannel.foreach { server =>
+ context.addTaskCompletionListener[Unit] { _ =>
+ server.close()
+ if (isUnixDomainSock) sockPath.delete()
+ }
}
- val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0)
- if (boundPort == -1) {
- val message = "ServerSocket failed to bind to Java side."
- logError(message)
- throw new SparkException(message)
+ if (isUnixDomainSock) {
+ logDebug(s"Started ServerSocket on with Unix Domain Socket
$sockPath.")
+ dataOut.writeBoolean(/* isBarrier = */true)
+ dataOut.writeInt(-1)
+ PythonRDD.writeUTF(sockPath.getPath, dataOut)
+ } else {
+ val boundPort: Int =
serverSocketChannel.map(_.socket().getLocalPort).getOrElse(-1)
+ if (boundPort == -1) {
+ val message = "ServerSocket failed to bind to Java side."
+ logError(message)
+ throw new SparkException(message)
+ }
+ logDebug(s"Started ServerSocket on port $boundPort.")
+ dataOut.writeBoolean(/* isBarrier = */true)
+ dataOut.writeInt(boundPort)
+ PythonRDD.writeUTF(authHelper.secret, dataOut)
}
- logDebug(s"Started ServerSocket on port $boundPort.")
- dataOut.writeBoolean(/* isBarrier = */true)
- dataOut.writeInt(boundPort)
} else {
dataOut.writeBoolean(/* isBarrier = */false)
- dataOut.writeInt(0)
}
// Write out the TaskContextInfo
- val secretBytes = secret.getBytes(UTF_8)
- dataOut.writeInt(secretBytes.length)
- dataOut.write(secretBytes, 0, secretBytes.length)
dataOut.writeInt(context.stageId())
dataOut.writeInt(context.partitionId())
dataOut.writeInt(context.attemptNumber())
@@ -485,12 +501,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
/**
* Gateway to call BarrierTaskContext methods.
*/
- def barrierAndServe(requestMethod: Int, sock: Socket, message: String =
""): Unit = {
+ def barrierAndServe(requestMethod: Int, sock: SocketChannel, message:
String = ""): Unit = {
require(
- serverSocket.isDefined,
+ serverSocketChannel.isDefined,
"No available ServerSocket to redirect the BarrierTaskContext method
call."
)
- val out = new DataOutputStream(new
BufferedOutputStream(sock.getOutputStream))
+ val out = new DataOutputStream(new
BufferedOutputStream(Channels.newOutputStream(sock)))
try {
val messages = requestMethod match {
case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
diff --git
a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 19a067076967..64b29585a0d9 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -18,10 +18,11 @@
package org.apache.spark.api.python
import java.io.{DataInputStream, DataOutputStream, EOFException, File,
InputStream}
-import java.net.{InetAddress, InetSocketAddress, SocketException}
+import java.net.{InetAddress, InetSocketAddress, SocketException,
StandardProtocolFamily, UnixDomainSocketAddress}
import java.net.SocketTimeoutException
import java.nio.channels._
import java.util.Arrays
+import java.util.UUID
import java.util.concurrent.TimeUnit
import javax.annotation.concurrent.GuardedBy
@@ -33,6 +34,7 @@ import org.apache.spark._
import org.apache.spark.errors.SparkCoreErrors
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys._
+import org.apache.spark.internal.config.Python.{PYTHON_UNIX_DOMAIN_SOCKET_DIR,
PYTHON_UNIX_DOMAIN_SOCKET_ENABLED}
import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.util.{RedirectThread, Utils}
@@ -97,6 +99,7 @@ private[spark] class PythonWorkerFactory(
}
private val authHelper = new SocketAuthHelper(SparkEnv.get.conf)
+ private val isUnixDomainSock =
authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
@GuardedBy("self")
private var daemon: Process = null
@@ -106,6 +109,8 @@ private[spark] class PythonWorkerFactory(
@GuardedBy("self")
private val daemonWorkers = new mutable.WeakHashMap[PythonWorker,
ProcessHandle]()
@GuardedBy("self")
+ private var daemonSockPath: String = _
+ @GuardedBy("self")
private val idleWorkers = new mutable.Queue[PythonWorker]()
@GuardedBy("self")
private var lastActivityNs = 0L
@@ -152,7 +157,11 @@ private[spark] class PythonWorkerFactory(
private def createThroughDaemon(): (PythonWorker, Option[ProcessHandle]) = {
def createWorker(): (PythonWorker, Option[ProcessHandle]) = {
- val socketChannel = SocketChannel.open(new InetSocketAddress(daemonHost,
daemonPort))
+ val socketChannel = if (isUnixDomainSock) {
+ SocketChannel.open(UnixDomainSocketAddress.of(daemonSockPath))
+ } else {
+ SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort))
+ }
// These calls are blocking.
val pid = new
DataInputStream(Channels.newInputStream(socketChannel)).readInt()
if (pid < 0) {
@@ -161,7 +170,7 @@ private[spark] class PythonWorkerFactory(
val processHandle = ProcessHandle.of(pid).orElseThrow(
() => new IllegalStateException("Python daemon failed to launch
worker.")
)
- authHelper.authToServer(socketChannel.socket())
+ authHelper.authToServer(socketChannel)
socketChannel.configureBlocking(false)
val worker = PythonWorker(socketChannel)
daemonWorkers.put(worker, processHandle)
@@ -192,9 +201,19 @@ private[spark] class PythonWorkerFactory(
private[spark] def createSimpleWorker(
blockingMode: Boolean): (PythonWorker, Option[ProcessHandle]) = {
var serverSocketChannel: ServerSocketChannel = null
+ lazy val sockPath = new File(
+ authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR)
+ .getOrElse(System.getProperty("java.io.tmpdir")),
+ s".${UUID.randomUUID()}.sock")
try {
- serverSocketChannel = ServerSocketChannel.open()
- serverSocketChannel.bind(new
InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 1)
+ if (isUnixDomainSock) {
+ serverSocketChannel =
ServerSocketChannel.open(StandardProtocolFamily.UNIX)
+ sockPath.deleteOnExit()
+ serverSocketChannel.bind(UnixDomainSocketAddress.of(sockPath.getPath))
+ } else {
+ serverSocketChannel = ServerSocketChannel.open()
+ serverSocketChannel.bind(new
InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 1)
+ }
// Create and start the worker
val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m",
workerModule))
@@ -209,9 +228,14 @@ private[spark] class PythonWorkerFactory(
workerEnv.put("PYTHONPATH", pythonPath)
// This is equivalent to setting the -u flag; we use it because ipython
doesn't support -u:
workerEnv.put("PYTHONUNBUFFERED", "YES")
- workerEnv.put("PYTHON_WORKER_FACTORY_PORT",
serverSocketChannel.socket().getLocalPort
- .toString)
- workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
+ if (isUnixDomainSock) {
+ workerEnv.put("PYTHON_WORKER_FACTORY_SOCK_PATH", sockPath.getPath)
+ workerEnv.put("PYTHON_UNIX_DOMAIN_ENABLED", "True")
+ } else {
+ workerEnv.put("PYTHON_WORKER_FACTORY_PORT",
serverSocketChannel.socket().getLocalPort
+ .toString)
+ workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
+ }
if (Utils.preferIPv6) {
workerEnv.put("SPARK_PREFER_IPV6", "True")
}
@@ -233,7 +257,7 @@ private[spark] class PythonWorkerFactory(
throw new SocketTimeoutException(
"Timed out while waiting for the Python worker to connect back")
}
- authHelper.authClient(socketChannel.socket())
+ authHelper.authClient(socketChannel)
// TODO: When we drop JDK 8, we can just use workerProcess.pid()
val pid = new
DataInputStream(Channels.newInputStream(socketChannel)).readInt()
if (pid < 0) {
@@ -254,6 +278,7 @@ private[spark] class PythonWorkerFactory(
} finally {
if (serverSocketChannel != null) {
serverSocketChannel.close()
+ if (isUnixDomainSock) sockPath.delete()
}
}
}
@@ -278,7 +303,15 @@ private[spark] class PythonWorkerFactory(
val workerEnv = pb.environment()
workerEnv.putAll(envVars.asJava)
workerEnv.put("PYTHONPATH", pythonPath)
- workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
+ if (isUnixDomainSock) {
+ workerEnv.put(
+ "PYTHON_WORKER_FACTORY_SOCK_DIR",
+ authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR)
+ .getOrElse(System.getProperty("java.io.tmpdir")))
+ workerEnv.put("PYTHON_UNIX_DOMAIN_ENABLED", "True")
+ } else {
+ workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
+ }
if (Utils.preferIPv6) {
workerEnv.put("SPARK_PREFER_IPV6", "True")
}
@@ -288,7 +321,11 @@ private[spark] class PythonWorkerFactory(
val in = new DataInputStream(daemon.getInputStream)
try {
- daemonPort = in.readInt()
+ if (isUnixDomainSock) {
+ daemonSockPath = PythonWorkerUtils.readUTF(in)
+ } else {
+ daemonPort = in.readInt()
+ }
} catch {
case _: EOFException if daemon.isAlive =>
throw SparkCoreErrors.eofExceptionWhileReadPortNumberError(
@@ -301,10 +338,14 @@ private[spark] class PythonWorkerFactory(
// test that the returned port number is within a valid range.
// note: this does not cover the case where the port number
// is arbitrary data but is also coincidentally within range
- if (daemonPort < 1 || daemonPort > 0xffff) {
+ val isMalformedPort = !isUnixDomainSock && (daemonPort < 1 ||
daemonPort > 0xffff)
+ val isMalformedSockPath = isUnixDomainSock && !new
File(daemonSockPath).exists()
+ val errorMsg =
+ if (isUnixDomainSock) daemonSockPath else f"$daemonPort
(0x$daemonPort%08x)"
+ if (isMalformedPort || isMalformedSockPath) {
val exceptionMessage = f"""
- |Bad data in $daemonModule's standard output. Invalid port number:
- | $daemonPort (0x$daemonPort%08x)
+ |Bad data in $daemonModule's standard output. Invalid port
number/socket path:
+ | $errorMsg
|Python command to execute the daemon was:
| ${command.asScala.mkString(" ")}
|Check that you don't have any unexpected modules or libraries in
@@ -407,6 +448,7 @@ private[spark] class PythonWorkerFactory(
daemon = null
daemonPort = 0
+ daemonSockPath = null
} else {
simpleWorkers.values.foreach(_.destroy())
}
diff --git
a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
index ae3614445be6..0a6def051a34 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
@@ -117,9 +117,15 @@ private[spark] object PythonWorkerUtils extends Logging {
}
}
val server = new EncryptedPythonBroadcastServer(env, idsAndFiles)
- dataOut.writeInt(server.port)
- logTrace(s"broadcast decryption server setup on ${server.port}")
- writeUTF(server.secret, dataOut)
+ server.connInfo match {
+ case portNum: Int =>
+ dataOut.writeInt(portNum)
+ writeUTF(server.secret, dataOut)
+ case sockPath: String =>
+ dataOut.writeInt(-1)
+ writeUTF(sockPath, dataOut)
+ }
+ logTrace(s"broadcast decryption server setup on ${server.connInfo}")
sendBidsToRemove()
idsAndFiles.foreach { case (id, _) =>
// send new broadcast
diff --git
a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
index 6f9708def2f2..7eba574751b4 100644
---
a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
+++
b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
@@ -18,6 +18,7 @@
package org.apache.spark.api.python
import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream,
DataOutputStream}
+import java.nio.channels.Channels
import scala.jdk.CollectionConverters._
@@ -25,7 +26,7 @@ import org.apache.spark.{SparkEnv, SparkPythonException}
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{PYTHON_WORKER_MODULE,
PYTHON_WORKER_RESPONSE, SESSION_ID}
import org.apache.spark.internal.config.BUFFER_SIZE
-import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT
+import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT,
PYTHON_UNIX_DOMAIN_SOCKET_ENABLED}
private[spark] object StreamingPythonRunner {
@@ -45,6 +46,7 @@ private[spark] class StreamingPythonRunner(
sessionId: String,
workerModule: String) extends Logging {
private val conf = SparkEnv.get.conf
+ private val isUnixDomainSock = conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
protected val bufferSize: Int = conf.get(BUFFER_SIZE)
protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
@@ -78,14 +80,20 @@ private[spark] class StreamingPythonRunner(
pythonWorker = Some(worker)
pythonWorkerFactory = Some(workerFactory)
- val socket = pythonWorker.get.channel.socket()
- val stream = new BufferedOutputStream(socket.getOutputStream, bufferSize)
- val dataIn = new DataInputStream(new
BufferedInputStream(socket.getInputStream, bufferSize))
+ val socketChannel = pythonWorker.get.channel
+ val stream = new
BufferedOutputStream(Channels.newOutputStream(socketChannel), bufferSize)
+ val dataIn = new DataInputStream(
+ new BufferedInputStream(Channels.newInputStream(socketChannel),
bufferSize))
val dataOut = new DataOutputStream(stream)
- val originalTimeout = socket.getSoTimeout()
- // Set timeout to 5 minute during initialization config transmission
- socket.setSoTimeout(5 * 60 * 1000)
+ val originalTimeout = if (!isUnixDomainSock) {
+ val timeout = socketChannel.socket().getSoTimeout()
+ // Set timeout to 5 minute during initialization config transmission
+ socketChannel.socket().setSoTimeout(5 * 60 * 1000)
+ Some(timeout)
+ } else {
+ None
+ }
val resFromPython = try {
PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
@@ -111,7 +119,7 @@ private[spark] class StreamingPythonRunner(
// Set timeout back to the original timeout
// Should be infinity by default
- socket.setSoTimeout(originalTimeout)
+ originalTimeout.foreach(v => socketChannel.socket().setSoTimeout(v))
if (resFromPython != 0) {
val errMessage = PythonWorkerUtils.readUTF(dataIn)
diff --git a/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala
b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala
index ac6826a9ec77..5c45986a8f9a 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala
@@ -24,6 +24,7 @@ import org.apache.spark.SparkConf
import org.apache.spark.security.SocketAuthHelper
private[spark] class RAuthHelper(conf: SparkConf) extends
SocketAuthHelper(conf) {
+ override val isUnixDomainSock = false
override protected def readUtf8(s: Socket): String = {
SerDe.readString(new DataInputStream(s.getInputStream()))
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index ff6ed9f86b55..3b309e093970 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -18,7 +18,7 @@
package org.apache.spark.api.r
import java.io.{File, OutputStream}
-import java.net.Socket
+import java.nio.channels.{Channels, SocketChannel}
import java.util.{Map => JMap}
import scala.jdk.CollectionConverters._
@@ -179,8 +179,8 @@ private[spark] class RParallelizeServer(sc:
JavaSparkContext, parallelism: Int)
extends SocketAuthServer[JavaRDD[Array[Byte]]](
new RAuthHelper(SparkEnv.get.conf), "sparkr-parallelize-server") {
- override def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
- val in = sock.getInputStream()
+ override def handleConnection(sock: SocketChannel): JavaRDD[Array[Byte]] = {
+ val in = Channels.newInputStream(sock)
JavaRDD.readRDDFromInputStream(sc.sc, in, parallelism)
}
}
diff --git a/core/src/main/scala/org/apache/spark/internal/config/Python.scala
b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
index 1f827e8dc449..7f9921d58dba 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/Python.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
@@ -70,6 +70,30 @@ private[spark] object Python {
.booleanConf
.createWithDefault(false)
+ val PYTHON_UNIX_DOMAIN_SOCKET_ENABLED =
ConfigBuilder("spark.python.unix.domain.socket.enabled")
+ .doc("When set to true, the Python driver uses a Unix domain socket for
operations like " +
+ "creating or collecting a DataFrame from local data, using accumulators,
and executing " +
+ "Python functions with PySpark such as Python UDFs. This configuration
only applies " +
+ "to Spark Classic and Spark Connect server.")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ val PYTHON_UNIX_DOMAIN_SOCKET_DIR =
ConfigBuilder("spark.python.unix.domain.socket.dir")
+ .doc("When specified, it uses the directory to create Unix domain socket
files. " +
+ "Otherwise, it uses the default location of the temporary directory set
in " +
+ s"'java.io.tmpdir' property. This is used when
${PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key} " +
+ "is enabled.")
+ .internal()
+ .version("4.1.0")
+ .stringConf
+ // UDS requires the length of path lower than 104 characters. We use UUID
(36 characters)
+ // and additional prefix "." (1), postfix ".sock" (5), and the path
separator (1).
+ .checkValue(
+ _.length <= (104 - (36 + 1 + 5 + 1)),
+ s"The directory path should be lower than ${(104 - (36 + 1 + 5 + 1))}")
+ .createOptional
+
private val PYTHON_WORKER_IDLE_TIMEOUT_SECONDS_KEY =
"spark.python.worker.idleTimeoutSeconds"
private val PYTHON_WORKER_KILL_ON_IDLE_TIMEOUT_KEY =
"spark.python.worker.killOnIdleTimeout"
diff --git
a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
index f800553c5388..ecebb97ecfc1 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
@@ -19,9 +19,11 @@ package org.apache.spark.security
import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket
+import java.nio.channels.SocketChannel
import java.nio.charset.StandardCharsets.UTF_8
import org.apache.spark.SparkConf
+import org.apache.spark.internal.config.Python.{PYTHON_UNIX_DOMAIN_SOCKET_DIR,
PYTHON_UNIX_DOMAIN_SOCKET_ENABLED}
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.util.Utils
@@ -35,6 +37,9 @@ import org.apache.spark.util.Utils
* There's no secrecy, so this relies on the sockets being either local or
somehow encrypted.
*/
private[spark] class SocketAuthHelper(val conf: SparkConf) {
+ val isUnixDomainSock: Boolean = conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
+ lazy val sockDir: String =
+
conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR).getOrElse(System.getProperty("java.io.tmpdir"))
val secret = Utils.createSecret(conf)
@@ -47,6 +52,11 @@ private[spark] class SocketAuthHelper(val conf: SparkConf) {
* @param s The client socket.
* @throws IllegalArgumentException If authentication fails.
*/
+ def authClient(socket: SocketChannel): Unit = {
+ if (isUnixDomainSock) return
+ authClient(socket.socket())
+ }
+
def authClient(s: Socket): Unit = {
var shouldClose = true
try {
@@ -80,7 +90,9 @@ private[spark] class SocketAuthHelper(val conf: SparkConf) {
* @param s The socket connected to the server.
* @throws IllegalArgumentException If authentication fails.
*/
- def authToServer(s: Socket): Unit = {
+ def authToServer(socket: SocketChannel): Unit = {
+ if (isUnixDomainSock) return
+ val s = socket.socket()
var shouldClose = true
try {
writeUtf8(secret, s)
diff --git
a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
index 9efe2af5fcc8..b0446a4f2feb 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
@@ -17,8 +17,10 @@
package org.apache.spark.security
-import java.io.{BufferedOutputStream, OutputStream}
-import java.net.{InetAddress, ServerSocket, Socket}
+import java.io.{BufferedOutputStream, File, OutputStream}
+import java.net.{InetAddress, InetSocketAddress, StandardProtocolFamily,
UnixDomainSocketAddress}
+import java.nio.channels.{Channels, ServerSocketChannel, SocketChannel}
+import java.util.UUID
import scala.concurrent.Promise
import scala.concurrent.duration.Duration
@@ -46,44 +48,70 @@ private[spark] abstract class SocketAuthServer[T](
def this(threadName: String) = this(SparkEnv.get, threadName)
private val promise = Promise[T]()
+ private val isUnixDomainSock: Boolean = authHelper.isUnixDomainSock
- private def startServer(): (Int, String) = {
+ private def startServer(): (Any, String) = {
logTrace("Creating listening socket")
- val address = InetAddress.getLoopbackAddress()
- val serverSocket = new ServerSocket(0, 1, address)
+ lazy val sockPath = new File(authHelper.sockDir,
s".${UUID.randomUUID()}.sock")
+
+ val (serverSocketChannel, address) = if (isUnixDomainSock) {
+ val address = UnixDomainSocketAddress.of(sockPath.getPath)
+ val serverChannel = ServerSocketChannel.open(StandardProtocolFamily.UNIX)
+ sockPath.deleteOnExit()
+ serverChannel.bind(address)
+ (serverChannel, address)
+ } else {
+ val address = InetAddress.getLoopbackAddress()
+ val serverChannel = ServerSocketChannel.open()
+ serverChannel.bind(new InetSocketAddress(address, 0), 1)
+ (serverChannel, address)
+ }
+
// Close the socket if no connection in the configured seconds
val timeout = authHelper.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT).toInt
logTrace(s"Setting timeout to $timeout sec")
- serverSocket.setSoTimeout(timeout * 1000)
+ if (!isUnixDomainSock) serverSocketChannel.socket().setSoTimeout(timeout *
1000)
new Thread(threadName) {
setDaemon(true)
override def run(): Unit = {
- var sock: Socket = null
+ var sock: SocketChannel = null
try {
- logTrace(s"Waiting for connection on $address with port
${serverSocket.getLocalPort}")
- sock = serverSocket.accept()
- logTrace(s"Connection accepted from address
${sock.getRemoteSocketAddress}")
+ if (isUnixDomainSock) {
+ logTrace(s"Waiting for connection on $address.")
+ } else {
+ logTrace(
+ s"Waiting for connection on $address with port " +
+ s"${serverSocketChannel.socket().getLocalPort}")
+ }
+ sock = serverSocketChannel.accept()
+ logTrace(s"Connection accepted from address
${sock.getRemoteAddress}")
authHelper.authClient(sock)
logTrace("Client authenticated")
promise.complete(Try(handleConnection(sock)))
} finally {
logTrace("Closing server")
- JavaUtils.closeQuietly(serverSocket)
+ JavaUtils.closeQuietly(serverSocketChannel)
JavaUtils.closeQuietly(sock)
+ if (isUnixDomainSock) sockPath.delete()
}
}
}.start()
- (serverSocket.getLocalPort, authHelper.secret)
+ if (isUnixDomainSock) {
+ (sockPath.getPath, null)
+ } else {
+ (serverSocketChannel.socket().getLocalPort, authHelper.secret)
+ }
}
- val (port, secret) = startServer()
+ // connInfo is either a string (for UDS) or a port number (for TCP/IP).
+ val (connInfo, secret) = startServer()
/**
* Handle a connection which has already been authenticated. Any error from
this function
* will clean up this connection and the entire server, and get propagated
to [[getResult]].
*/
- def handleConnection(sock: Socket): T
+ def handleConnection(sock: SocketChannel): T
/**
* Blocks indefinitely for [[handleConnection]] to finish, and returns that
result. If
@@ -108,9 +136,9 @@ private[spark] abstract class SocketAuthServer[T](
private[spark] class SocketFuncServer(
authHelper: SocketAuthHelper,
threadName: String,
- func: Socket => Unit) extends SocketAuthServer[Unit](authHelper,
threadName) {
+ func: SocketChannel => Unit) extends SocketAuthServer[Unit](authHelper,
threadName) {
- override def handleConnection(sock: Socket): Unit = {
+ override def handleConnection(sock: SocketChannel): Unit = {
func(sock)
}
}
@@ -134,8 +162,8 @@ private[spark] object SocketAuthServer {
def serveToStream(
threadName: String,
authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit):
Array[Any] = {
- val handleFunc = (sock: Socket) => {
- val out = new BufferedOutputStream(sock.getOutputStream())
+ val handleFunc = (sock: SocketChannel) => {
+ val out = new BufferedOutputStream(Channels.newOutputStream(sock))
Utils.tryWithSafeFinally {
writeFunc(out)
} {
@@ -144,6 +172,6 @@ private[spark] object SocketAuthServer {
}
val server = new SocketFuncServer(authHelper, threadName, handleFunc)
- Array(server.port, server.secret, server)
+ Array(server.connInfo, server.secret, server)
}
}
diff --git
a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index 88ad5b3a7483..4efd2870cccb 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -18,7 +18,8 @@
package org.apache.spark.api.python
import java.io.{ByteArrayOutputStream, DataOutputStream, File}
-import java.net.{InetAddress, Socket}
+import java.net.{InetAddress, InetSocketAddress}
+import java.nio.channels.SocketChannel
import java.nio.charset.StandardCharsets
import java.util
@@ -33,6 +34,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext,
SparkFunSuite}
import org.apache.spark.api.java.JavaSparkContext
+import
org.apache.spark.internal.config.Python.PYTHON_UNIX_DOMAIN_SOCKET_ENABLED
import org.apache.spark.rdd.{HadoopRDD, RDD}
import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
import org.apache.spark.util.Utils
@@ -76,10 +78,14 @@ class PythonRDDSuite extends SparkFunSuite with
LocalSparkContext {
}
test("python server error handling") {
- val authHelper = new SocketAuthHelper(new SparkConf())
+ val conf = new SparkConf()
+ conf.set(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key, false.toString)
+ val authHelper = new SocketAuthHelper(conf)
val errorServer = new ExceptionPythonServer(authHelper)
- val client = new Socket(InetAddress.getLoopbackAddress(), errorServer.port)
- authHelper.authToServer(client)
+ val socketChannel = SocketChannel.open(
+ new InetSocketAddress(InetAddress.getLoopbackAddress(),
+ errorServer.connInfo.asInstanceOf[Int]))
+ authHelper.authToServer(socketChannel)
val ex = intercept[Exception] { errorServer.getResult(Duration(1,
"second")) }
assert(ex.getCause().getMessage().contains("exception within
handleConnection"))
}
@@ -87,7 +93,7 @@ class PythonRDDSuite extends SparkFunSuite with
LocalSparkContext {
class ExceptionPythonServer(authHelper: SocketAuthHelper)
extends SocketAuthServer[Unit](authHelper, "error-server") {
- override def handleConnection(sock: Socket): Unit = {
+ override def handleConnection(sock: SocketChannel): Unit = {
throw new Exception("exception within handleConnection")
}
}
diff --git
a/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala
b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala
index e57cb701b628..c5a6199cf4c1 100644
--- a/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala
+++ b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala
@@ -18,14 +18,17 @@ package org.apache.spark.security
import java.io.Closeable
import java.net._
+import java.nio.channels.SocketChannel
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.internal.config._
+import
org.apache.spark.internal.config.Python.PYTHON_UNIX_DOMAIN_SOCKET_ENABLED
import org.apache.spark.util.Utils
class SocketAuthHelperSuite extends SparkFunSuite {
private val conf = new SparkConf()
+ conf.set(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key, false.toString)
private val authHelper = new SocketAuthHelper(conf)
test("successful auth") {
@@ -43,7 +46,9 @@ class SocketAuthHelperSuite extends SparkFunSuite {
test("failed auth") {
Utils.tryWithResource(new ServerThread()) { server =>
Utils.tryWithResource(server.createClient()) { client =>
- val badHelper = new SocketAuthHelper(new
SparkConf().set(AUTH_SECRET_BIT_LENGTH, 128))
+ val badHelper = new SocketAuthHelper(new SparkConf()
+ .set(AUTH_SECRET_BIT_LENGTH, 128)
+ .set(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key, false.toString))
intercept[IllegalArgumentException] {
badHelper.authToServer(client)
}
@@ -66,8 +71,9 @@ class SocketAuthHelperSuite extends SparkFunSuite {
setDaemon(true)
start()
- def createClient(): Socket = {
- new Socket(InetAddress.getLoopbackAddress(), ss.getLocalPort())
+ def createClient(): SocketChannel = {
+ SocketChannel.open(new InetSocketAddress(
+ InetAddress.getLoopbackAddress(), ss.getLocalPort))
}
override def run(): Unit = {
diff --git a/python/pyspark/core/broadcast.py b/python/pyspark/core/broadcast.py
index 69d57c35614d..2d5658284be8 100644
--- a/python/pyspark/core/broadcast.py
+++ b/python/pyspark/core/broadcast.py
@@ -125,8 +125,8 @@ class Broadcast(Generic[T]):
if sc._encryption_enabled:
# with encryption, we ask the jvm to do the encryption for us,
we send it data
# over a socket
- port, auth_secret =
self._python_broadcast.setupEncryptionServer()
- (encryption_sock_file, _) = local_connect_and_auth(port,
auth_secret)
+ conn_info, auth_secret =
self._python_broadcast.setupEncryptionServer()
+ (encryption_sock_file, _) = local_connect_and_auth(conn_info,
auth_secret)
broadcast_out = ChunkedStream(encryption_sock_file, 8192)
else:
# no encryption, we can just write pickled data directly to
the file from python
@@ -270,8 +270,8 @@ class Broadcast(Generic[T]):
# we only need to decrypt it here when encryption is enabled and
# if its on the driver, since executor decryption is handled
already
if self._sc is not None and self._sc._encryption_enabled:
- port, auth_secret =
self._python_broadcast.setupDecryptionServer()
- (decrypted_sock_file, _) = local_connect_and_auth(port,
auth_secret)
+ conn_info, auth_secret =
self._python_broadcast.setupDecryptionServer()
+ (decrypted_sock_file, _) = local_connect_and_auth(conn_info,
auth_secret)
self._python_broadcast.waitTillBroadcastDataSent()
return self.load(decrypted_sock_file)
else:
diff --git a/python/pyspark/core/context.py b/python/pyspark/core/context.py
index 5fcd4ffb0921..4d5c03fd1900 100644
--- a/python/pyspark/core/context.py
+++ b/python/pyspark/core/context.py
@@ -880,7 +880,7 @@ class SparkContext:
if self._encryption_enabled:
# with encryption, we open a server in java and send the data
directly
server = server_func()
- (sock_file, _) = local_connect_and_auth(server.port(),
server.secret())
+ (sock_file, _) = local_connect_and_auth(server.connInfo(),
server.secret())
chunked_out = ChunkedStream(sock_file, 8192)
serializer.dump_stream(data, chunked_out)
chunked_out.close()
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index a23af109ea6d..ca33ce2c39ef 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
+import uuid
import numbers
import os
import signal
@@ -93,8 +93,20 @@ def manager():
# Create a new process group to corral our children
os.setpgid(0, 0)
+ is_unix_domain_sock = os.environ.get("PYTHON_UNIX_DOMAIN_ENABLED",
"false").lower() == "true"
+ socket_path = None
+
# Create a listening socket on the loopback interface
- if os.environ.get("SPARK_PREFER_IPV6", "false").lower() == "true":
+ if is_unix_domain_sock:
+ assert "PYTHON_WORKER_FACTORY_SOCK_DIR" in os.environ
+ socket_path = os.path.join(
+ os.environ["PYTHON_WORKER_FACTORY_SOCK_DIR"],
f".{uuid.uuid4()}.sock"
+ )
+ listen_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ listen_sock.bind(socket_path)
+ listen_sock.listen(max(1024, SOMAXCONN))
+ listen_port = socket_path
+ elif os.environ.get("SPARK_PREFER_IPV6", "false").lower() == "true":
listen_sock = socket.socket(AF_INET6, SOCK_STREAM)
listen_sock.bind(("::1", 0, 0, 0))
listen_sock.listen(max(1024, SOMAXCONN))
@@ -108,10 +120,15 @@ def manager():
# re-open stdin/stdout in 'wb' mode
stdin_bin = os.fdopen(sys.stdin.fileno(), "rb", 4)
stdout_bin = os.fdopen(sys.stdout.fileno(), "wb", 4)
- write_int(listen_port, stdout_bin)
+ if is_unix_domain_sock:
+ write_with_length(listen_port.encode("utf-8"), stdout_bin)
+ else:
+ write_int(listen_port, stdout_bin)
stdout_bin.flush()
def shutdown(code):
+ if socket_path is not None and os.path.exists(socket_path):
+ os.remove(socket_path)
signal.signal(SIGTERM, SIG_DFL)
# Send SIGHUP to notify workers of shutdown
os.kill(0, SIGHUP)
@@ -195,7 +212,10 @@ def manager():
write_int(os.getpid(), outfile)
outfile.flush()
outfile.close()
- authenticated = False
+ authenticated = (
+ os.environ.get("PYTHON_UNIX_DOMAIN_ENABLED",
"false").lower() == "true"
+ or False
+ )
while True:
code = worker(sock, authenticated)
if code == 0:
diff --git a/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py
b/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py
index 66a9b553cc75..e614c347faa9 100644
--- a/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py
+++ b/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py
@@ -227,7 +227,7 @@ class
DeepspeedTorchDistributorDistributedEndToEnd(unittest.TestCase):
conf = conf.set(k, v)
conf = conf.set(
"spark.worker.resource.gpu.discoveryScript",
cls.gpu_discovery_script_file_name
- )
+ ).set("spark.python.unix.domain.socket.enabled", "false")
sc = SparkContext("local-cluster[2,2,512]", cls.__name__, conf=conf)
cls.spark = SparkSession(sc)
@@ -264,7 +264,7 @@ class
DeepspeedDistributorLocalEndToEndTests(unittest.TestCase):
conf = conf.set(k, v)
conf = conf.set(
"spark.driver.resource.gpu.discoveryScript",
cls.gpu_discovery_script_file_name
- )
+ ).set("spark.python.unix.domain.socket.enabled", "false")
sc = SparkContext("local-cluster[2,2,512]", cls.__name__, conf=conf)
cls.spark = SparkSession(sc)
diff --git
a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
index b471769ad428..b819634adb5a 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
@@ -91,9 +91,11 @@ def main(infile: IO, outfile: IO) -> None:
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the
environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH",
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, sock) = local_connect_and_auth(conn_info, auth_secret)
# There could be a long time between each micro batch.
sock.settimeout(None)
write_int(os.getpid(), sock_file)
diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
index a7a5066ca0d7..2c6ce8715994 100644
--- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
@@ -105,9 +105,11 @@ def main(infile: IO, outfile: IO) -> None:
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the
environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH",
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, sock) = local_connect_and_auth(conn_info, auth_secret)
# There could be a long time between each listener event.
sock.settimeout(None)
write_int(os.getpid(), sock_file)
diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py
b/python/pyspark/sql/streaming/python_streaming_source_runner.py
index 11aa4e15ab1e..ab988eb714cc 100644
--- a/python/pyspark/sql/streaming/python_streaming_source_runner.py
+++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py
@@ -204,9 +204,11 @@ def main(infile: IO, outfile: IO) -> None:
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the
environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH",
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, sock) = local_connect_and_auth(conn_info, auth_secret)
# Prevent the socket from timeout error when query trigger interval is
large.
sock.settimeout(None)
write_int(os.getpid(), sock_file)
diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py
b/python/pyspark/sql/streaming/stateful_processor_api_client.py
index 50945198f9c4..e564d7186faa 100644
--- a/python/pyspark/sql/streaming/stateful_processor_api_client.py
+++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -49,22 +49,26 @@ class StatefulProcessorHandleState(Enum):
class StatefulProcessorApiClient:
def __init__(
- self, state_server_port: int, key_schema: StructType, is_driver: bool
= False
+ self, state_server_port: Union[int, str], key_schema: StructType,
is_driver: bool = False
) -> None:
self.key_schema = key_schema
- self._client_socket = socket.socket()
- self._client_socket.connect(("localhost", state_server_port))
-
- # SPARK-51667: We have a pattern of sending messages continuously from
one side
- # (Python -> JVM, and vice versa) before getting response from other
side. Since most
- # messages we are sending are small, this triggers the bad combination
of Nagle's algorithm
- # and delayed ACKs, which can cause a significant delay on the latency.
- # See SPARK-51667 for more details on how this can be a problem.
- #
- # Disabling either would work, but it's more common to disable Nagle's
algorithm; there is
- # lot less reference to disabling delayed ACKs, while there are lots
of resources to
- # disable Nagle's algorithm.
- self._client_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY,
1)
+ if isinstance(state_server_port, str):
+ self._client_socket = socket.socket(socket.AF_UNIX,
socket.SOCK_STREAM)
+ self._client_socket.connect(state_server_port)
+ else:
+ self._client_socket = socket.socket()
+ self._client_socket.connect(("localhost", state_server_port))
+
+ # SPARK-51667: We have a pattern of sending messages continuously
from one side
+ # (Python -> JVM, and vice versa) before getting response from
other side. Since most
+ # messages we are sending are small, this triggers the bad
combination of Nagle's
+ # algorithm and delayed ACKs, which can cause a significant delay
on the latency.
+ # See SPARK-51667 for more details on how this can be a problem.
+ #
+             # Disabling either would work, but it's more common to disable
Nagle's algorithm; there
+             # is far less reference material on disabling delayed ACKs, while
there are many resources on
+             # disabling Nagle's algorithm.
+ self._client_socket.setsockopt(socket.IPPROTO_TCP,
socket.TCP_NODELAY, 1)
self.sockfile = self._client_socket.makefile(
"rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536))
diff --git a/python/pyspark/sql/streaming/transform_with_state_driver_worker.py
b/python/pyspark/sql/streaming/transform_with_state_driver_worker.py
index 99d386f07b5b..3fe7f68a99e5 100644
--- a/python/pyspark/sql/streaming/transform_with_state_driver_worker.py
+++ b/python/pyspark/sql/streaming/transform_with_state_driver_worker.py
@@ -72,9 +72,11 @@ def main(infile: IO, outfile: IO) -> None:
# This driver runner will only be used on the first batch of a query,
# and the following code block should be only run once for each query
run
state_server_port = read_int(infile)
+ if state_server_port == -1:
+ state_server_port = utf8_deserializer.loads(infile)
key_schema =
StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
print(
- f"{log_name} received parameters for UDF. State server port:
{state_server_port}, "
+ f"{log_name} received parameters for UDF. State server port/path:
{state_server_port}, "
f"key schema: {key_schema}.\n"
)
@@ -94,9 +96,11 @@ def main(infile: IO, outfile: IO) -> None:
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the
environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH",
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, sock) = local_connect_and_auth(conn_info, auth_secret)
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/analyze_udtf.py
b/python/pyspark/sql/worker/analyze_udtf.py
index 9247fde78004..1c926f4980a5 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -273,9 +273,11 @@ def main(infile: IO, outfile: IO) -> None:
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the
environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH",
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
# TODO: Remove the following two lines and use `Process.pid()` when we
drop JDK 8.
write_int(os.getpid(), sock_file)
sock_file.flush()
diff --git a/python/pyspark/sql/worker/commit_data_source_write.py
b/python/pyspark/sql/worker/commit_data_source_write.py
index c891d9f083cb..d08d65974dfb 100644
--- a/python/pyspark/sql/worker/commit_data_source_write.py
+++ b/python/pyspark/sql/worker/commit_data_source_write.py
@@ -119,9 +119,11 @@ def main(infile: IO, outfile: IO) -> None:
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the
environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH",
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/create_data_source.py
b/python/pyspark/sql/worker/create_data_source.py
index 33957616c483..424f07012723 100644
--- a/python/pyspark/sql/worker/create_data_source.py
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -184,9 +184,11 @@ def main(infile: IO, outfile: IO) -> None:
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the
environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH",
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/data_source_pushdown_filters.py
b/python/pyspark/sql/worker/data_source_pushdown_filters.py
index 9edbaf3a9b72..0415f450fe0f 100644
--- a/python/pyspark/sql/worker/data_source_pushdown_filters.py
+++ b/python/pyspark/sql/worker/data_source_pushdown_filters.py
@@ -269,7 +269,9 @@ def main(infile: IO, outfile: IO) -> None:
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the
environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH",
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/lookup_data_sources.py
b/python/pyspark/sql/worker/lookup_data_sources.py
index 18737095fa9c..af138ab68965 100644
--- a/python/pyspark/sql/worker/lookup_data_sources.py
+++ b/python/pyspark/sql/worker/lookup_data_sources.py
@@ -104,9 +104,11 @@ def main(infile: IO, outfile: IO) -> None:
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the
environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH",
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py
b/python/pyspark/sql/worker/plan_data_source_read.py
index 7f765a377bea..5edc8185adcf 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -409,9 +409,11 @@ def main(infile: IO, outfile: IO) -> None:
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the
environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH",
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/python_streaming_sink_runner.py
b/python/pyspark/sql/worker/python_streaming_sink_runner.py
index 13b8f4d30786..cf6246b54490 100644
--- a/python/pyspark/sql/worker/python_streaming_sink_runner.py
+++ b/python/pyspark/sql/worker/python_streaming_sink_runner.py
@@ -148,9 +148,11 @@ def main(infile: IO, outfile: IO) -> None:
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the
environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH",
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/write_into_data_source.py
b/python/pyspark/sql/worker/write_into_data_source.py
index 235e5c249f69..d6d055f01e54 100644
--- a/python/pyspark/sql/worker/write_into_data_source.py
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -255,9 +255,11 @@ def main(infile: IO, outfile: IO) -> None:
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the
environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH",
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index 9785664d7a15..957f9d70687b 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -262,8 +262,8 @@ ALL_GATHER_FUNCTION = 2
def _load_from_socket(
- port: Optional[Union[str, int]],
- auth_secret: str,
+ conn_info: Optional[Union[str, int]],
+ auth_secret: Optional[str],
function: int,
all_gather_message: Optional[str] = None,
) -> List[str]:
@@ -271,7 +271,7 @@ def _load_from_socket(
Load data from a given socket, this is a blocking method thus only return
when the socket
connection has been closed.
"""
- (sockfile, sock) = local_connect_and_auth(port, auth_secret)
+ (sockfile, sock) = local_connect_and_auth(conn_info, auth_secret)
# The call may block forever, so no timeout
sock.settimeout(None)
@@ -331,7 +331,7 @@ class BarrierTaskContext(TaskContext):
[1]
"""
- _port: ClassVar[Optional[Union[str, int]]] = None
+ _conn_info: ClassVar[Optional[Union[str, int]]] = None
_secret: ClassVar[Optional[str]] = None
@classmethod
@@ -368,13 +368,13 @@ class BarrierTaskContext(TaskContext):
@classmethod
def _initialize(
- cls: Type["BarrierTaskContext"], port: Optional[Union[str, int]],
secret: str
+ cls: Type["BarrierTaskContext"], conn_info: Optional[Union[str, int]],
secret: Optional[str]
) -> None:
"""
Initialize :class:`BarrierTaskContext`, other methods within
:class:`BarrierTaskContext`
can only be called after BarrierTaskContext is initialized.
"""
- cls._port = port
+ cls._conn_info = conn_info
cls._secret = secret
def barrier(self) -> None:
@@ -393,7 +393,7 @@ class BarrierTaskContext(TaskContext):
calls, in all possible code branches. Otherwise, you may get the job
hanging
or a `SparkException` after timeout.
"""
- if self._port is None or self._secret is None:
+ if self._conn_info is None:
raise PySparkRuntimeError(
errorClass="CALL_BEFORE_INITIALIZE",
messageParameters={
@@ -402,7 +402,7 @@ class BarrierTaskContext(TaskContext):
},
)
else:
- _load_from_socket(self._port, self._secret, BARRIER_FUNCTION)
+ _load_from_socket(self._conn_info, self._secret, BARRIER_FUNCTION)
def allGather(self, message: str = "") -> List[str]:
"""
@@ -422,7 +422,7 @@ class BarrierTaskContext(TaskContext):
"""
if not isinstance(message, str):
raise TypeError("Argument `message` must be of type `str`")
- elif self._port is None or self._secret is None:
+ elif self._conn_info is None:
raise PySparkRuntimeError(
errorClass="CALL_BEFORE_INITIALIZE",
messageParameters={
@@ -431,7 +431,7 @@ class BarrierTaskContext(TaskContext):
},
)
else:
- return _load_from_socket(self._port, self._secret,
ALL_GATHER_FUNCTION, message)
+ return _load_from_socket(self._conn_info, self._secret,
ALL_GATHER_FUNCTION, message)
def getTaskInfos(self) -> List["BarrierTaskInfo"]:
"""
@@ -453,7 +453,7 @@ class BarrierTaskContext(TaskContext):
>>> barrier_info.address
'...:...'
"""
- if self._port is None or self._secret is None:
+ if self._conn_info is None:
raise PySparkRuntimeError(
errorClass="CALL_BEFORE_INITIALIZE",
messageParameters={
diff --git a/python/pyspark/tests/test_appsubmit.py
b/python/pyspark/tests/test_appsubmit.py
index 5f2c8b49d279..909ed0447154 100644
--- a/python/pyspark/tests/test_appsubmit.py
+++ b/python/pyspark/tests/test_appsubmit.py
@@ -36,6 +36,8 @@ class SparkSubmitTests(unittest.TestCase):
"spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
"--conf",
"spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
+ "--conf",
+ "spark.python.unix.domain.socket.enabled=false",
]
def tearDown(self):
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 5a5a8d31e77d..cdfc8d2a4a4f 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -652,9 +652,9 @@ def _create_local_socket(sock_info: "JavaArray") ->
"io.BufferedRWPair":
"""
sockfile: "io.BufferedRWPair"
sock: "socket.socket"
- port: int = sock_info[0]
+ conn_info: int = sock_info[0]
auth_secret: str = sock_info[1]
- sockfile, sock = local_connect_and_auth(port, auth_secret)
+ sockfile, sock = local_connect_and_auth(conn_info, auth_secret)
# The RDD materialization time is unpredictable, if we set a timeout for
socket reading
# operation, it will very possibly fail. See SPARK-18281.
sock.settimeout(None)
@@ -731,7 +731,9 @@ def _local_iterator_from_socket(sock_info: "JavaArray",
serializer: "Serializer"
return iter(PyLocalIterable(sock_info, serializer))
-def local_connect_and_auth(port: Optional[Union[str, int]], auth_secret: str)
-> Tuple:
+def local_connect_and_auth(
+ conn_info: Optional[Union[str, int]], auth_secret: Optional[str]
+) -> Tuple:
"""
Connect to local host, authenticate with it, and return a (sockfile,sock)
for that connection.
Handles IPV4 & IPV6, does some error handling.
@@ -739,26 +741,49 @@ def local_connect_and_auth(port: Optional[Union[str,
int]], auth_secret: str) ->
Parameters
----------
port : str or int, optional
- auth_secret : str
+ auth_secret : str, optional
Returns
-------
tuple
with (sockfile, sock)
"""
+ is_unix_domain_socket = isinstance(conn_info, str) and auth_secret is None
+ if is_unix_domain_socket:
+ sock_path = conn_info
+ assert isinstance(sock_path, str)
+ sock = None
+ try:
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ sock.settimeout(int(os.environ.get("SPARK_AUTH_SOCKET_TIMEOUT",
15)))
+ sock.connect(sock_path)
+ sockfile = sock.makefile("rwb",
int(os.environ.get("SPARK_BUFFER_SIZE", 65536)))
+ return (sockfile, sock)
+ except socket.error as e:
+ if sock is not None:
+ sock.close()
+ raise PySparkRuntimeError(
+ errorClass="CANNOT_OPEN_SOCKET",
+ messageParameters={
+ "errors": "tried to connect to %s, but an error occurred:
%s"
+ % (sock_path, str(e)),
+ },
+ )
+
sock = None
errors = []
# Support for both IPv4 and IPv6.
addr = "127.0.0.1"
if os.environ.get("SPARK_PREFER_IPV6", "false").lower() == "true":
addr = "::1"
- for res in socket.getaddrinfo(addr, port, socket.AF_UNSPEC,
socket.SOCK_STREAM):
+ for res in socket.getaddrinfo(addr, conn_info, socket.AF_UNSPEC,
socket.SOCK_STREAM):
af, socktype, proto, _, sa = res
try:
sock = socket.socket(af, socktype, proto)
sock.settimeout(int(os.environ.get("SPARK_AUTH_SOCKET_TIMEOUT",
15)))
sock.connect(sa)
sockfile = sock.makefile("rwb",
int(os.environ.get("SPARK_BUFFER_SIZE", 65536)))
+ assert isinstance(auth_secret, str)
_do_server_auth(sockfile, auth_secret)
return (sockfile, sock)
except socket.error as e:
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 44a5d0b91131..0724ad42e566 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -45,7 +45,6 @@ from pyspark.serializers import (
write_long,
read_int,
SpecialLengths,
- UTF8Deserializer,
CPickleSerializer,
BatchedSerializer,
)
@@ -1548,6 +1547,8 @@ def read_udfs(pickleSer, infile, eval_type):
or eval_type ==
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF
):
state_server_port = read_int(infile)
+ if state_server_port == -1:
+ state_server_port = utf8_deserializer.loads(infile)
key_schema =
StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
# NOTE: if timezone is set here, that implies respectSessionTimeZone
is True
@@ -1983,8 +1984,6 @@ def main(infile, outfile):
# read inputs only for a barrier task
isBarrier = read_bool(infile)
- boundPort = read_int(infile)
- secret = UTF8Deserializer().loads(infile)
memory_limit_mb = int(os.environ.get("PYSPARK_EXECUTOR_MEMORY_MB",
"-1"))
setup_memory_limits(memory_limit_mb)
@@ -1992,6 +1991,12 @@ def main(infile, outfile):
# initialize global state
taskContext = None
if isBarrier:
+ boundPort = read_int(infile)
+ secret = None
+ if boundPort == -1:
+ boundPort = utf8_deserializer.loads(infile)
+ else:
+ secret = utf8_deserializer.loads(infile)
taskContext = BarrierTaskContext._getOrCreate()
BarrierTaskContext._initialize(boundPort, secret)
# Set the task context instance here, so we can get it by
TaskContext.get for
@@ -2085,9 +2090,11 @@ def main(infile, outfile):
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the
environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ conn_info = os.environ.get(
+ "PYTHON_WORKER_FACTORY_SOCK_PATH",
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+ )
+ auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+ (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
# TODO: Remove the following two lines and use `Process.pid()` when we
drop JDK 8.
write_int(os.getpid(), sock_file)
sock_file.flush()
diff --git a/python/pyspark/worker_util.py b/python/pyspark/worker_util.py
index 5c758d3f83fe..c2f35db8d52d 100644
--- a/python/pyspark/worker_util.py
+++ b/python/pyspark/worker_util.py
@@ -156,9 +156,13 @@ def setup_broadcasts(infile: IO) -> None:
num_broadcast_variables = read_int(infile)
if needs_broadcast_decryption_server:
# read the decrypted data from a server in the jvm
- port = read_int(infile)
- auth_secret = utf8_deserializer.loads(infile)
- (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)
+ conn_info = read_int(infile)
+ auth_secret = None
+ if conn_info == -1:
+ conn_info = utf8_deserializer.loads(infile)
+ else:
+ auth_secret = utf8_deserializer.loads(infile)
+ (broadcast_sock_file, _) = local_connect_and_auth(conn_info,
auth_secret)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
diff --git a/python/run-tests.py b/python/run-tests.py
index 64ac48e210db..8752f264cd75 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -111,6 +111,7 @@ def run_individual_python_test(target_dir, test_name,
pyspark_python, keep_test_
while os.path.isdir(tmp_dir):
tmp_dir = os.path.join(target_dir, str(uuid.uuid4()))
os.mkdir(tmp_dir)
+ sock_dir = os.getenv('TMPDIR') or os.getenv('TEMP') or os.getenv('TMP') or
'/tmp'
env["TMPDIR"] = tmp_dir
metastore_dir = os.path.join(tmp_dir, str(uuid.uuid4()))
while os.path.isdir(metastore_dir):
@@ -124,6 +125,7 @@ def run_individual_python_test(target_dir, test_name,
pyspark_python, keep_test_
"--conf", "spark.driver.extraJavaOptions='{0}'".format(java_options),
"--conf", "spark.executor.extraJavaOptions='{0}'".format(java_options),
"--conf", "spark.sql.warehouse.dir='{0}'".format(metastore_dir),
+ "--conf", "spark.python.unix.domain.socket.dir={0}".format(sock_dir),
"pyspark-shell",
]
diff --git
a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 4408817b0426..b3a792bbfc73 100644
---
a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++
b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -40,6 +40,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
+import
org.apache.spark.internal.config.Python.PYTHON_UNIX_DOMAIN_SOCKET_ENABLED
import org.apache.spark.internal.config.UI._
import org.apache.spark.launcher._
import org.apache.spark.scheduler.{SparkListener,
SparkListenerApplicationStart, SparkListenerExecutorAdded}
@@ -268,11 +269,19 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
}
test("run Python application in yarn-client mode") {
- testPySpark(true)
+ testPySpark(
+ true,
+ // User is unknown in this suite.
+ extraConf = Map(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key -> false.toString)
+ )
}
test("run Python application in yarn-cluster mode") {
- testPySpark(false)
+ testPySpark(
+ false,
+ // User is unknown in this suite.
+ extraConf = Map(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key -> false.toString)
+ )
}
test("run Python application with Spark Connect in yarn-client mode") {
@@ -290,6 +299,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
testPySpark(
clientMode = false,
extraConf = Map(
+ PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key -> false.toString, // User is
unknown in this suite.
"spark.yarn.appMasterEnv.PYSPARK_DRIVER_PYTHON"
-> sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", pythonExecutablePath),
"spark.yarn.appMasterEnv.PYSPARK_PYTHON"
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index 374d38db371a..40779c66600f 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -18,8 +18,7 @@
package org.apache.spark.sql.api.python
import java.io.InputStream
-import java.net.Socket
-import java.nio.channels.Channels
+import java.nio.channels.{Channels, SocketChannel}
import net.razorvine.pickle.{Pickler, Unpickler}
@@ -197,8 +196,8 @@ private[sql] object PythonSQLUtils extends Logging {
private[spark] class ArrowIteratorServer
extends
SocketAuthServer[Iterator[Array[Byte]]]("pyspark-arrow-batches-server") {
- def handleConnection(sock: Socket): Iterator[Array[Byte]] = {
- val in = sock.getInputStream()
+ def handleConnection(sock: SocketChannel): Iterator[Array[Byte]] = {
+ val in = Channels.newInputStream(sock)
val dechunkedInput: InputStream = new DechunkedInputStream(in)
// Create array to consume iterator so that we can safely close the file
ArrowConverters.getBatchesFromStream(Channels.newChannel(dechunkedInput)).toArray.iterator
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
index 89273b7bc80f..3979220618ba 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
@@ -19,6 +19,7 @@
package org.apache.spark.sql.execution.python.streaming
import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream,
DataOutputStream}
+import java.nio.channels.Channels
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
@@ -99,7 +100,7 @@ class PythonStreamingSourceRunner(
pythonWorkerFactory = Some(workerFactory)
val stream = new BufferedOutputStream(
- pythonWorker.get.channel.socket().getOutputStream, bufferSize)
+ Channels.newOutputStream(pythonWorker.get.channel), bufferSize)
dataOut = new DataOutputStream(stream)
PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
@@ -118,7 +119,7 @@ class PythonStreamingSourceRunner(
dataOut.flush()
dataIn = new DataInputStream(
- new
BufferedInputStream(pythonWorker.get.channel.socket().getInputStream,
bufferSize))
+ new
BufferedInputStream(Channels.newInputStream(pythonWorker.get.channel),
bufferSize))
val initStatus = dataIn.readInt()
if (initStatus == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasPythonRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasPythonRunner.scala
index 9b2a2518a7b2..638b2d48ffc4 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasPythonRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasPythonRunner.scala
@@ -17,17 +17,20 @@
package org.apache.spark.sql.execution.python.streaming
-import java.io.{DataInputStream, DataOutputStream}
-import java.net.ServerSocket
+import java.io.{DataInputStream, DataOutputStream, File}
+import java.net.{InetAddress, InetSocketAddress, StandardProtocolFamily,
UnixDomainSocketAddress}
+import java.nio.channels.ServerSocketChannel
+import java.util.UUID
import scala.concurrent.ExecutionContext
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter
-import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions,
PythonFunction, PythonRDD, PythonWorkerUtils, StreamingPythonRunner}
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.Python.{PYTHON_UNIX_DOMAIN_SOCKET_DIR,
PYTHON_UNIX_DOMAIN_SOCKET_ENABLED}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.{BasicPythonArrowOutput,
PythonArrowInput, PythonUDFRunner}
@@ -196,8 +199,13 @@ abstract class
TransformWithStateInPandasPythonBaseRunner[I](
override protected def handleMetadataBeforeExec(stream: DataOutputStream):
Unit = {
super.handleMetadataBeforeExec(stream)
- // Also write the port number for state server
- stream.writeInt(stateServerSocketPort)
+ // Also write the port/path number for state server
+ if (isUnixDomainSock) {
+ stream.writeInt(-1)
+ PythonWorkerUtils.writeUTF(stateServerSocketPath, stream)
+ } else {
+ stream.writeInt(stateServerSocketPort)
+ }
PythonRDD.writeUTF(groupingKeySchema.json, stream)
}
@@ -255,14 +263,19 @@ class TransformWithStateInPandasPythonPreInitRunner(
dataOut = result._1
dataIn = result._2
- // start state server, update socket port
+ // start state server, update socket port/path
startStateServer()
(dataOut, dataIn)
}
def process(): Unit = {
- // Also write the port number for state server
- dataOut.writeInt(stateServerSocketPort)
+ // Also write the port/path number for state server
+ if (isUnixDomainSock) {
+ dataOut.writeInt(-1)
+ PythonWorkerUtils.writeUTF(stateServerSocketPath, dataOut)
+ } else {
+ dataOut.writeInt(stateServerSocketPort)
+ }
PythonWorkerUtils.writeUTF(groupingKeySchema.json, dataOut)
dataOut.flush()
@@ -307,14 +320,27 @@ class TransformWithStateInPandasPythonPreInitRunner(
* in a new daemon thread.
*/
trait TransformWithStateInPandasPythonRunnerUtils extends Logging {
- protected var stateServerSocketPort: Int = 0
- protected var stateServerSocket: ServerSocket = null
+ protected val isUnixDomainSock: Boolean =
SparkEnv.get.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
+ protected var stateServerSocketPort: Int = -1
+ protected var stateServerSocketPath: String = null
+ protected var stateServerSocket: ServerSocketChannel = null
protected def initStateServer(): Unit = {
var failed = false
try {
- stateServerSocket = new ServerSocket(/* port = */ 0,
- /* backlog = */ 1)
- stateServerSocketPort = stateServerSocket.getLocalPort
+ if (isUnixDomainSock) {
+ val sockPath = new File(
+ SparkEnv.get.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR)
+ .getOrElse(System.getProperty("java.io.tmpdir")),
+ s".${UUID.randomUUID()}.sock")
+ stateServerSocket =
ServerSocketChannel.open(StandardProtocolFamily.UNIX)
+ stateServerSocket.bind(UnixDomainSocketAddress.of(sockPath.getPath), 1)
+ sockPath.deleteOnExit()
+ stateServerSocketPath = sockPath.getPath
+ } else {
+ stateServerSocket = ServerSocketChannel.open()
+ .bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 1)
+ stateServerSocketPort = stateServerSocket.socket().getLocalPort
+ }
} catch {
case e: Throwable =>
failed = true
@@ -326,10 +352,13 @@ trait TransformWithStateInPandasPythonRunnerUtils extends
Logging {
}
}
- protected def closeServerSocketChannelSilently(stateServerSocket:
ServerSocket): Unit = {
+ protected def closeServerSocketChannelSilently(stateServerSocket:
ServerSocketChannel): Unit = {
try {
logInfo(log"closing the state server socket")
stateServerSocket.close()
+ if (stateServerSocketPath != null) {
+ new File(stateServerSocketPath).delete
+ }
} catch {
case e: Exception =>
logError(log"failed to close state server socket", e)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala
index f46b66204383..3749fb6b7c50 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.python.streaming
import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream,
DataOutputStream, EOFException}
-import java.net.ServerSocket
+import java.nio.channels.{Channels, ServerSocketChannel}
import java.time.Duration
import scala.collection.mutable
@@ -27,7 +27,9 @@ import com.google.protobuf.ByteString
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.apache.spark.SparkEnv
import org.apache.spark.internal.{Logging, LogKeys, MDC}
+import
org.apache.spark.internal.config.Python.PYTHON_UNIX_DOMAIN_SOCKET_ENABLED
import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.api.python.PythonSQLUtils
import org.apache.spark.sql.catalyst.InternalRow
@@ -52,7 +54,7 @@ import org.apache.spark.util.Utils
* - Requests for managing state variables (e.g. valueState).
*/
class TransformWithStateInPandasStateServer(
- stateServerSocket: ServerSocket,
+ stateServerSocket: ServerSocketChannel,
statefulProcessorHandle: StatefulProcessorHandleImplBase,
groupingKeySchema: StructType,
timeZoneId: String,
@@ -80,6 +82,10 @@ class TransformWithStateInPandasStateServer(
private var inputStream: DataInputStream = _
private var outputStream: DataOutputStream = outputStreamForTest
+ private val isUnixDomainSock = Option(SparkEnv.get)
+ .map(_.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED))
+ .getOrElse(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.defaultValue.get)
+
/** State variable related class variables */
// A map to store the value state name -> (value state, schema, value row
deserializer) mapping.
private val valueStates = if (valueStateMapForTest != null) {
@@ -148,12 +154,12 @@ class TransformWithStateInPandasStateServer(
// Disabling either would work, but it's more common to disable Nagle's
algorithm; there is
// lot less reference to disabling delayed ACKs, while there are lots of
resources to
// disable Nagle's algorithm.
- listeningSocket.setTcpNoDelay(true)
+ if (!isUnixDomainSock) listeningSocket.socket().setTcpNoDelay(true)
inputStream = new DataInputStream(
- new BufferedInputStream(listeningSocket.getInputStream))
+ new BufferedInputStream(Channels.newInputStream(listeningSocket)))
outputStream = new DataOutputStream(
- new BufferedOutputStream(listeningSocket.getOutputStream)
+ new BufferedOutputStream(Channels.newOutputStream(listeningSocket))
)
while (listeningSocket.isConnected &&
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala
index 305a520f6af8..f1e6379a00c8 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.python.streaming
import java.io.DataOutputStream
-import java.net.ServerSocket
+import java.nio.channels.ServerSocketChannel
import scala.collection.mutable
@@ -39,7 +39,7 @@ import org.apache.spark.sql.types.{IntegerType, StructField,
StructType}
class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with
BeforeAndAfterEach {
val stateName = "test"
val iteratorId = "testId"
- val serverSocket: ServerSocket = mock(classOf[ServerSocket])
+ val serverSocket: ServerSocketChannel = mock(classOf[ServerSocketChannel])
val groupingKeySchema: StructType = StructType(Seq())
val stateSchema: StructType = StructType(Array(StructField("value",
IntegerType)))
// Below byte array is a serialized row with a single integer value 1.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]