Repository: spark
Updated Branches:
  refs/heads/branch-2.4 adfd1057d -> eff1c5016


[SPARK-25822][PYSPARK] Fix a race condition when releasing a Python worker

## What changes were proposed in this pull request?

There is a race condition when releasing a Python worker: if 
`ReaderIterator.handleEndOfDataSection` is not running in the task thread and 
the task is terminated early (for example by `take(N)`), the task completion 
listener may close the worker while `handleEndOfDataSection` still puts the 
worker back into the worker pool for reuse.
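
The pre-fix logic is a plain check-then-act on the `released` flag, which two 
threads can pass at the same time. Here is a self-contained sketch (illustrative 
only, not Spark code; the object name, latch, and thread labels are invented for 
the demo) showing both sides observing `false`:

```scala
import java.util.concurrent.CountDownLatch
import java.util.concurrent.atomic.AtomicBoolean

object RacyFlagDemo extends App {
  val released = new AtomicBoolean(false)
  val ready = new CountDownLatch(2)

  def tryAct(name: String): Unit = {
    ready.countDown()
    ready.await() // line both threads up right before the check
    if (!released.get) { // check-then-act: the check and the action are not atomic
      println(s"$name saw released=false and acted")
      released.set(true)
    }
  }

  val a = new Thread(() => tryAct("completion listener"))
  val b = new Thread(() => tryAct("handleEndOfDataSection"))
  a.start(); b.start()
  a.join(); b.join()
  // Often prints both lines: two "winners", which is the race.
}
```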

https://github.com/zsxwing/spark/commit/0e07b483d2e7c68f3b5c3c118d0bf58c501041b7 is a patch to reproduce this issue.

A user also reported this issue on the mailing list: 
http://mail-archives.apache.org/mod_mbox/spark-user/201610.mbox/%3CCAAUq=H+YLUEpd23nwvq13Ms5hOStkhX3ao4f4zQV6sgO5zM-xAmail.gmail.com%3E

This PR fixes the issue by using `compareAndSet` to make sure we never 
return a closed worker to the worker pool.
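
For illustration, here is a minimal runnable sketch of that single-winner 
pattern; the close/release bodies are `println` placeholders rather than the 
real worker calls:

```scala
import java.util.concurrent.atomic.AtomicBoolean

object SingleWinnerDemo extends App {
  val releasedOrClosed = new AtomicBoolean(false)

  // Stand-in for the task completion listener closing the worker.
  def close(): Unit =
    if (releasedOrClosed.compareAndSet(false, true)) println("closed worker")

  // Stand-in for handleEndOfDataSection returning the worker to the pool.
  def release(): Unit =
    if (releasedOrClosed.compareAndSet(false, true)) println("returned worker to pool")

  val a = new Thread(() => close())
  val b = new Thread(() => release())
  a.start(); b.start()
  a.join(); b.join()
  // Exactly one line prints: compareAndSet lets only one caller flip the flag.
}
```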

## How was this patch tested?

Jenkins.

Closes #22816 from zsxwing/fix-socket-closed.

Authored-by: Shixiong Zhu <zsxw...@gmail.com>
Signed-off-by: Takuya UESHIN <ues...@databricks.com>
(cherry picked from commit 86d469aeaa492c0642db09b27bb0879ead5d7166)
Signed-off-by: Takuya UESHIN <ues...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/eff1c501
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/eff1c501
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/eff1c501

Branch: refs/heads/branch-2.4
Commit: eff1c5016cc78ad3d26e5e83cd0d72c56c35cf0a
Parents: adfd105
Author: Shixiong Zhu <zsxw...@gmail.com>
Authored: Fri Oct 26 13:53:51 2018 +0900
Committer: Takuya UESHIN <ues...@databricks.com>
Committed: Fri Oct 26 13:54:14 2018 +0900

----------------------------------------------------------------------
 .../apache/spark/api/python/PythonRunner.scala  | 21 ++++++++++----------
 .../execution/python/ArrowPythonRunner.scala    |  4 ++--
 .../sql/execution/python/PythonUDFRunner.scala  |  4 ++--
 3 files changed, 15 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/eff1c501/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 6e53a04..f73e95e 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -106,15 +106,17 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", memoryMb.get.toString)
     }
     val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
-    // Whether is the worker released into idle pool
-    val released = new AtomicBoolean(false)
+    // Whether the worker has been released into the idle pool or closed. Any code that tries to
+    // release or close a worker must flip this flag with `releasedOrClosed.compareAndSet`, so
+    // that there is exactly one winner that releases or closes the worker.
+    val releasedOrClosed = new AtomicBoolean(false)
 
     // Start a thread to feed the process input from our parent's iterator
     val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context)
 
     context.addTaskCompletionListener[Unit] { _ =>
       writerThread.shutdownOnTaskCompletion()
-      if (!reuseWorker || !released.get) {
+      if (!reuseWorker || releasedOrClosed.compareAndSet(false, true)) {
         try {
           worker.close()
         } catch {
@@ -131,7 +133,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
 
     val stdoutIterator = newReaderIterator(
-      stream, writerThread, startTime, env, worker, released, context)
+      stream, writerThread, startTime, env, worker, releasedOrClosed, context)
     new InterruptibleIterator(context, stdoutIterator)
   }
 
@@ -148,7 +150,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
-      released: AtomicBoolean,
+      releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[OUT]
 
   /**
@@ -392,7 +394,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
-      released: AtomicBoolean,
+      releasedOrClosed: AtomicBoolean,
       context: TaskContext)
     extends Iterator[OUT] {
 
@@ -463,9 +465,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       }
       // Check whether the worker is ready to be re-used.
       if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
-        if (reuseWorker) {
+        if (reuseWorker && releasedOrClosed.compareAndSet(false, true)) {
           env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
-          released.set(true)
         }
       }
       eos = true
@@ -565,9 +566,9 @@ private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions])
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
-      released: AtomicBoolean,
+      releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[Array[Byte]] = {
-    new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) {
+    new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
 
       protected override def read(): Array[Byte] = {
         if (writerThread.exception.isDefined) {

http://git-wip-us.apache.org/repos/asf/spark/blob/eff1c501/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 18992d7..04623b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -117,9 +117,9 @@ class ArrowPythonRunner(
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
-      released: AtomicBoolean,
+      releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[ColumnarBatch] = {
-    new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) {
+    new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
 
       private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
         s"stdin reader for $pythonExec", 0, Long.MaxValue)

http://git-wip-us.apache.org/repos/asf/spark/blob/eff1c501/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index cc61faa..752d271 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -59,9 +59,9 @@ class PythonUDFRunner(
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
-      released: AtomicBoolean,
+      releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[Array[Byte]] = {
-    new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) {
+    new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
 
       protected override def read(): Array[Byte] = {
         if (writerThread.exception.isDefined) {

