Repository: spark
Updated Branches:
  refs/heads/master 81a305dd0 -> bd66c7302


[SPARK-25771][PYSPARK] Fix improper synchronization in PythonWorkerFactory

## What changes were proposed in this pull request?

Fix the following issues in PythonWorkerFactory
1. MonitorThread.run uses a wrong lock.
2. `createSimpleWorker` misses `synchronized` when updating `simpleWorkers`.

Other changes are just code-style improvements that make the thread-safety
contract clear.

## How was this patch tested?

Jenkins

Closes #22770 from zsxwing/pwf.

Authored-by: Shixiong Zhu <zsxw...@gmail.com>
Signed-off-by: Shixiong Zhu <zsxw...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bd66c730
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bd66c730
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bd66c730

Branch: refs/heads/master
Commit: bd66c73025c0b947be230178a737fd53812b78dd
Parents: 81a305d
Author: Shixiong Zhu <zsxw...@gmail.com>
Authored: Mon Oct 22 10:07:11 2018 -0700
Committer: Shixiong Zhu <zsxw...@gmail.com>
Committed: Mon Oct 22 10:07:11 2018 -0700

----------------------------------------------------------------------
 .../spark/api/python/PythonWorkerFactory.scala  | 75 +++++++++++---------
 1 file changed, 43 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bd66c730/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 6afa37a..1f2f503 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -21,6 +21,7 @@ import java.io.{DataInputStream, DataOutputStream, 
EOFException, InputStream, Ou
 import java.net.{InetAddress, ServerSocket, Socket, SocketException}
 import java.nio.charset.StandardCharsets
 import java.util.Arrays
+import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -31,7 +32,7 @@ import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.util.{RedirectThread, Utils}
 
 private[spark] class PythonWorkerFactory(pythonExec: String, envVars: 
Map[String, String])
-  extends Logging {
+  extends Logging { self =>
 
   import PythonWorkerFactory._
 
@@ -39,7 +40,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, 
envVars: Map[String
   // pyspark/daemon.py (by default) and tell it to fork new workers for our 
tasks. This daemon
   // currently only works on UNIX-based systems now because it uses signals 
for child management,
   // so we can also fall back to launching workers, pyspark/worker.py (by 
default) directly.
-  val useDaemon = {
+  private val useDaemon = {
     val useDaemonEnabled = 
SparkEnv.get.conf.getBoolean("spark.python.use.daemon", true)
 
     // This flag is ignored on Windows as it's unable to fork.
@@ -51,44 +52,52 @@ private[spark] class PythonWorkerFactory(pythonExec: 
String, envVars: Map[String
   // as expert-only option, and shouldn't be used before knowing what it means 
exactly.
 
   // This configuration indicates the module to run the daemon to execute its 
Python workers.
-  val daemonModule = 
SparkEnv.get.conf.getOption("spark.python.daemon.module").map { value =>
-    logInfo(
-      s"Python daemon module in PySpark is set to [$value] in 
'spark.python.daemon.module', " +
-      "using this to start the daemon up. Note that this configuration only 
has an effect when " +
-      "'spark.python.use.daemon' is enabled and the platform is not Windows.")
-    value
-  }.getOrElse("pyspark.daemon")
+  private val daemonModule =
+    SparkEnv.get.conf.getOption("spark.python.daemon.module").map { value =>
+      logInfo(
+        s"Python daemon module in PySpark is set to [$value] in 
'spark.python.daemon.module', " +
+        "using this to start the daemon up. Note that this configuration only 
has an effect when " +
+        "'spark.python.use.daemon' is enabled and the platform is not 
Windows.")
+      value
+    }.getOrElse("pyspark.daemon")
 
   // This configuration indicates the module to run each Python worker.
-  val workerModule = 
SparkEnv.get.conf.getOption("spark.python.worker.module").map { value =>
-    logInfo(
-      s"Python worker module in PySpark is set to [$value] in 
'spark.python.worker.module', " +
-      "using this to start the worker up. Note that this configuration only 
has an effect when " +
-      "'spark.python.use.daemon' is disabled or the platform is Windows.")
-    value
-  }.getOrElse("pyspark.worker")
+  private val workerModule =
+    SparkEnv.get.conf.getOption("spark.python.worker.module").map { value =>
+      logInfo(
+        s"Python worker module in PySpark is set to [$value] in 
'spark.python.worker.module', " +
+        "using this to start the worker up. Note that this configuration only 
has an effect when " +
+        "'spark.python.use.daemon' is disabled or the platform is Windows.")
+      value
+    }.getOrElse("pyspark.worker")
 
   private val authHelper = new SocketAuthHelper(SparkEnv.get.conf)
 
-  var daemon: Process = null
+  @GuardedBy("self")
+  private var daemon: Process = null
   val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
-  var daemonPort: Int = 0
-  val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
-  val idleWorkers = new mutable.Queue[Socket]()
-  var lastActivity = 0L
+  @GuardedBy("self")
+  private var daemonPort: Int = 0
+  @GuardedBy("self")
+  private val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+  @GuardedBy("self")
+  private val idleWorkers = new mutable.Queue[Socket]()
+  @GuardedBy("self")
+  private var lastActivity = 0L
   new MonitorThread().start()
 
-  var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
+  @GuardedBy("self")
+  private val simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
 
-  val pythonPath = PythonUtils.mergePythonPaths(
+  private val pythonPath = PythonUtils.mergePythonPaths(
     PythonUtils.sparkPythonPath,
     envVars.getOrElse("PYTHONPATH", ""),
     sys.env.getOrElse("PYTHONPATH", ""))
 
   def create(): Socket = {
     if (useDaemon) {
-      synchronized {
-        if (idleWorkers.size > 0) {
+      self.synchronized {
+        if (idleWorkers.nonEmpty) {
           return idleWorkers.dequeue()
         }
       }
@@ -117,7 +126,7 @@ private[spark] class PythonWorkerFactory(pythonExec: 
String, envVars: Map[String
       socket
     }
 
-    synchronized {
+    self.synchronized {
       // Start the daemon if it hasn't been started
       startDaemon()
 
@@ -163,7 +172,9 @@ private[spark] class PythonWorkerFactory(pythonExec: 
String, envVars: Map[String
       try {
         val socket = serverSocket.accept()
         authHelper.authClient(socket)
-        simpleWorkers.put(socket, worker)
+        self.synchronized {
+          simpleWorkers.put(socket, worker)
+        }
         return socket
       } catch {
         case e: Exception =>
@@ -178,7 +189,7 @@ private[spark] class PythonWorkerFactory(pythonExec: 
String, envVars: Map[String
   }
 
   private def startDaemon() {
-    synchronized {
+    self.synchronized {
       // Is it already running?
       if (daemon != null) {
         return
@@ -278,7 +289,7 @@ private[spark] class PythonWorkerFactory(pythonExec: 
String, envVars: Map[String
 
     override def run() {
       while (true) {
-        synchronized {
+        self.synchronized {
           if (lastActivity + IDLE_WORKER_TIMEOUT_MS < 
System.currentTimeMillis()) {
             cleanupIdleWorkers()
             lastActivity = System.currentTimeMillis()
@@ -303,7 +314,7 @@ private[spark] class PythonWorkerFactory(pythonExec: 
String, envVars: Map[String
   }
 
   private def stopDaemon() {
-    synchronized {
+    self.synchronized {
       if (useDaemon) {
         cleanupIdleWorkers()
 
@@ -325,7 +336,7 @@ private[spark] class PythonWorkerFactory(pythonExec: 
String, envVars: Map[String
   }
 
   def stopWorker(worker: Socket) {
-    synchronized {
+    self.synchronized {
       if (useDaemon) {
         if (daemon != null) {
           daemonWorkers.get(worker).foreach { pid =>
@@ -345,7 +356,7 @@ private[spark] class PythonWorkerFactory(pythonExec: 
String, envVars: Map[String
 
   def releaseWorker(worker: Socket) {
     if (useDaemon) {
-      synchronized {
+      self.synchronized {
         lastActivity = System.currentTimeMillis()
         idleWorkers.enqueue(worker)
       }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to